Skip to content

Commit 6a8767c

Browse files
authored
feat: unnest param on @daft.func (#5132)
## Changes Made

Added the `unnest` argument on `@daft.func`, allowing for the automatic unnesting of multiple return values.

Example:

```py
>>> import daft
>>> from daft import DataType
>>>
>>> @daft.func(return_dtype=DataType.struct({"int": DataType.int64(), "str": DataType.string()}), unnest=True)
... def my_multi_return(val: int):
...     return {"int": val * 2, "str": str(val) * 2}
>>> df = daft.from_pydict({"x": [1, 2, 3]})
>>> df.select(my_multi_return(df["x"])).collect()
╭───────┬──────╮
│ int   ┆ str  │
│ ---   ┆ ---  │
│ Int64 ┆ Utf8 │
╞═══════╪══════╡
│ 2     ┆ 11   │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
│ 4     ┆ 22   │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
│ 6     ┆ 33   │
╰───────┴──────╯
(Showing first 3 of 3 rows)
```

## Related Issues

<!-- Link to related GitHub issues, e.g., "Closes #123" -->

## Checklist

- [x] Documented in API Docs (if applicable)
- [x] Documented in User Guide (if applicable)
- [x] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation
- [x] Documentation builds and is formatted properly (tag @/ccmao1130 for docs review)
1 parent 798dce8 commit 6a8767c

File tree

5 files changed

+130
-11
lines changed

5 files changed

+130
-11
lines changed

daft/udf/__init__.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class _PartialUdf:
2727
"""Helper class to provide typing overloads for using `daft.func` as a decorator."""
2828

2929
return_dtype: DataTypeLike | None
30+
unnest: bool
3031

3132
@overload
3233
def __call__(self, fn: Callable[P, Iterator[T]]) -> GeneratorUdf[P, T]: ... # type: ignore[overload-overlap]
@@ -35,9 +36,9 @@ def __call__(self, fn: Callable[P, T]) -> RowWiseUdf[P, T]: ...
3536

3637
def __call__(self, fn: Callable[P, Any]) -> GeneratorUdf[P, Any] | RowWiseUdf[P, Any]:
3738
if isgeneratorfunction(fn):
38-
return GeneratorUdf(fn, return_dtype=self.return_dtype)
39+
return GeneratorUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest)
3940
else:
40-
return RowWiseUdf(fn, return_dtype=self.return_dtype)
41+
return RowWiseUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest)
4142

4243

4344
class _DaftFuncDecorator:
@@ -54,6 +55,7 @@ class _DaftFuncDecorator:
5455
5556
Args:
5657
return_dtype: The data type that this function should return or yield. If not specified, it is derived from the function's return type hint.
58+
unnest: Whether to unnest/flatten out return type fields into columns. Return dtype must be `DataType.struct` when this is set to true. Defaults to false.
5759
5860
Examples:
5961
Basic Example
@@ -184,21 +186,46 @@ class _DaftFuncDecorator:
184186
╰───────┴─────────╯
185187
<BLANKLINE>
186188
(Showing first 7 of 7 rows)
189+
190+
Unnesting multiple return fields
191+
192+
>>> import daft
193+
>>> from daft import DataType
194+
>>> @daft.func(return_dtype=DataType.struct({"int": DataType.int64(), "str": DataType.string()}), unnest=True)
195+
... def my_multi_return(val: int):
196+
... return {"int": val * 2, "str": str(val) * 2}
197+
>>> df = daft.from_pydict({"x": [1, 2, 3]})
198+
>>> df.select(my_multi_return(df["x"])).collect()
199+
╭───────┬──────╮
200+
│ int ┆ str │
201+
│ --- ┆ --- │
202+
│ Int64 ┆ Utf8 │
203+
╞═══════╪══════╡
204+
│ 2 ┆ 11 │
205+
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
206+
│ 4 ┆ 22 │
207+
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
208+
│ 6 ┆ 33 │
209+
╰───────┴──────╯
210+
<BLANKLINE>
211+
(Showing first 3 of 3 rows)
187212
"""
188213

189214
@overload
190-
def __new__(cls, *, return_dtype: DataTypeLike | None = None) -> _PartialUdf: ... # type: ignore[misc]
215+
def __new__(cls, *, return_dtype: DataTypeLike | None = None, unnest: bool = False) -> _PartialUdf: ... # type: ignore[misc]
191216
@overload
192217
def __new__( # type: ignore[misc]
193-
cls, fn: Callable[P, Iterator[T]], *, return_dtype: DataTypeLike | None = None
218+
cls, fn: Callable[P, Iterator[T]], *, return_dtype: DataTypeLike | None = None, unnest: bool = False
194219
) -> GeneratorUdf[P, T]: ...
195220
@overload
196-
def __new__(cls, fn: Callable[P, T], *, return_dtype: DataTypeLike | None = None) -> RowWiseUdf[P, T]: ... # type: ignore[misc]
221+
def __new__( # type: ignore[misc]
222+
cls, fn: Callable[P, T], *, return_dtype: DataTypeLike | None = None, unnest: bool = False
223+
) -> RowWiseUdf[P, T]: ...
197224

198225
def __new__( # type: ignore[misc]
199-
cls, fn: Callable[P, Any] | None = None, *, return_dtype: DataTypeLike | None = None
226+
cls, fn: Callable[P, Any] | None = None, *, return_dtype: DataTypeLike | None = None, unnest: bool = False
200227
) -> _PartialUdf | GeneratorUdf[P, Any] | RowWiseUdf[P, Any]:
201-
partial_udf = _PartialUdf(return_dtype=return_dtype)
228+
partial_udf = _PartialUdf(return_dtype=return_dtype, unnest=unnest)
202229
return partial_udf if fn is None else partial_udf(fn)
203230

204231

daft/udf/generator.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ class GeneratorUdf(Generic[P, T]):
3232
If no values are yielded for an input, a null value is inserted.
3333
"""
3434

35-
def __init__(self, fn: Callable[P, Iterator[T]], return_dtype: DataTypeLike | None):
35+
def __init__(self, fn: Callable[P, Iterator[T]], return_dtype: DataTypeLike | None, unnest: bool):
3636
self._inner = fn
3737
self.name = get_unique_function_name(fn)
38+
self.unnest = unnest
3839

3940
# attempt to extract return type from an Iterator or Generator type hint
4041
if return_dtype is None:
@@ -56,6 +57,11 @@ def __init__(self, fn: Callable[P, Iterator[T]], return_dtype: DataTypeLike | No
5657
return_dtype = args[0]
5758
self.return_dtype = DataType._infer_type(return_dtype)
5859

60+
if self.unnest and not self.return_dtype.is_struct():
61+
raise ValueError(
62+
f"Expected Daft function `return_dtype` to be `DataType.struct` when `unnest=True`, instead found: {self.return_dtype}"
63+
)
64+
5965
@overload
6066
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Iterator[T]: ...
6167
@overload
@@ -78,6 +84,11 @@ def inner_rowwise(*args: P.args, **kwargs: P.kwargs) -> list[T]:
7884

7985
return_dtype_rowwise = DataType.list(self.return_dtype)
8086

81-
return Expression._from_pyexpr(
87+
expr = Expression._from_pyexpr(
8288
row_wise_udf(self.name, inner_rowwise, return_dtype_rowwise._dtype, (args, kwargs), expr_args)
8389
).explode()
90+
91+
if self.unnest:
92+
expr = expr.unnest()
93+
94+
return expr

daft/udf/row_wise.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@ class RowWiseUdf(Generic[P, T]):
3333
Row-wise functions are called with data from one row at a time, and map that to a single output value for that row.
3434
"""
3535

36-
def __init__(self, fn: Callable[P, T], return_dtype: DataTypeLike | None):
36+
def __init__(self, fn: Callable[P, T], return_dtype: DataTypeLike | None, unnest: bool):
3737
self._inner = fn
3838
self.name = get_unique_function_name(fn)
39+
self.unnest = unnest
3940

4041
if return_dtype is None:
4142
type_hints = get_type_hints(fn)
@@ -47,6 +48,11 @@ def __init__(self, fn: Callable[P, T], return_dtype: DataTypeLike | None):
4748
return_dtype = type_hints["return"]
4849
self.return_dtype = DataType._infer_type(return_dtype)
4950

51+
if self.unnest and not self.return_dtype.is_struct():
52+
raise ValueError(
53+
f"Expected Daft function `return_dtype` to be `DataType.struct` when `unnest=True`, instead found: {self.return_dtype}"
54+
)
55+
5056
@overload
5157
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...
5258
@overload
@@ -62,10 +68,15 @@ def __call__(self, *args: Any, **kwargs: Any) -> Expression | T:
6268
if len(expr_args) == 0:
6369
return self._inner(*args, **kwargs)
6470

65-
return Expression._from_pyexpr(
71+
expr = Expression._from_pyexpr(
6672
row_wise_udf(self.name, self._inner, self.return_dtype._dtype, (args, kwargs), expr_args)
6773
)
6874

75+
if self.unnest:
76+
expr = expr.unnest()
77+
78+
return expr
79+
6980

7081
def __call_async_batch(
7182
fn: Callable[..., Awaitable[Any]],

tests/udf/test_generator_udf.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import collections.abc
44
import typing
55

6+
import pytest
7+
68
import daft
79

810

@@ -85,3 +87,44 @@ def my_gen_func(input: int) -> collections.abc.Generator[str, None, None]:
8587
df = df.select(my_gen_func(df["input"]).alias("output"))
8688

8789
assert df.schema() == daft.Schema.from_pydict({"output": daft.DataType.string()})
90+
91+
92+
def test_generator_udf_unnest():
93+
@daft.func(
94+
return_dtype=daft.DataType.struct({"id": daft.DataType.int64(), "value": daft.DataType.string()}), unnest=True
95+
)
96+
def create_records(count: int, base_value: str):
97+
for i in range(count):
98+
yield {"id": i, "value": f"{base_value}_{i}"}
99+
100+
df = daft.from_pydict({"count": [2, 3, 1], "base": ["a", "b", "c"]})
101+
result = df.select(create_records(df["count"], df["base"])).to_pydict()
102+
103+
expected = {"id": [0, 1, 0, 1, 2, 0], "value": ["a_0", "a_1", "b_0", "b_1", "b_2", "c_0"]}
104+
assert result == expected
105+
106+
107+
def test_generator_udf_unnest_empty_generator():
108+
@daft.func(
109+
return_dtype=daft.DataType.struct({"x": daft.DataType.int64(), "y": daft.DataType.string()}), unnest=True
110+
)
111+
def empty_gen(n: int):
112+
if n > 0:
113+
yield {"x": n, "y": str(n)}
114+
115+
df = daft.from_pydict({"n": [0, 1, 2]})
116+
result = df.select(empty_gen(df["n"])).to_pydict()
117+
118+
expected = {"x": [None, 1, 2], "y": [None, "1", "2"]}
119+
assert result == expected
120+
121+
122+
def test_generator_udf_unnest_error_non_struct():
123+
with pytest.raises(
124+
ValueError, match="Expected Daft function `return_dtype` to be `DataType.struct` when `unnest=True`"
125+
):
126+
127+
@daft.func(return_dtype=daft.DataType.string(), unnest=True)
128+
def invalid_unnest_generator(n: int):
129+
for i in range(n):
130+
yield str(i)

tests/udf/test_row_wise_udf.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,30 @@ async def my_async_stringify_and_sum(a: int, b: int) -> str:
123123
df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]})
124124
async_df = df.select(my_async_stringify_and_sum(col("x"), col("y")))
125125
assert async_df.to_pydict() == {"x": ["5", "7", "9"]}
126+
127+
128+
def test_row_wise_udf_unnest():
129+
@daft.func(
130+
return_dtype=daft.DataType.struct(
131+
{"id": daft.DataType.int64(), "name": daft.DataType.string(), "score": daft.DataType.float64()}
132+
),
133+
unnest=True,
134+
)
135+
def create_record(value: int):
136+
return {"id": value, "name": f"item_{value}", "score": value * 1.5}
137+
138+
df = daft.from_pydict({"value": [1, 2, 3]})
139+
result = df.select(create_record(col("value"))).to_pydict()
140+
141+
expected = {"id": [1, 2, 3], "name": ["item_1", "item_2", "item_3"], "score": [1.5, 3.0, 4.5]}
142+
assert result == expected
143+
144+
145+
def test_row_wise_udf_unnest_error_non_struct():
146+
with pytest.raises(
147+
ValueError, match="Expected Daft function `return_dtype` to be `DataType.struct` when `unnest=True`"
148+
):
149+
150+
@daft.func(return_dtype=daft.DataType.int64(), unnest=True)
151+
def invalid_unnest(a: int):
152+
return a

0 commit comments

Comments
 (0)