Skip to content

Commit ef6ae49

Browse files
authored
feat: unify all Daft type to Python type conversions (#4972)
## Changes Made We used to have five separate code paths for converting Daft types to Python: 1. Literal -> Python 2. `Series.to_pylist` 3. `Series.__iter__` 4. the SQL planner's `lit_to_py_any` 5. Series casting logic This PR makes all of them go through the literal to Python code path, making the conversion logic consistent across our library. Also added extension type to literal. TODO: - [x] add conversion mapping to docs ## Related Issues <!-- Link to related GitHub issues, e.g., "Closes #123" --> ## Checklist - [x] Documented in API Docs (if applicable) - [x] Documented in User Guide (if applicable) - [x] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation - [x] Documentation builds and is formatted properly (tag @/ccmao1130 for docs review)
1 parent 23549cd commit ef6ae49

File tree

24 files changed

+408
-472
lines changed

24 files changed

+408
-472
lines changed

Cargo.lock

Lines changed: 13 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ common-error = {path = "src/common/error", default-features = false}
235235
common-file-formats = {path = "src/common/file-formats"}
236236
common-image = {path = "src/common/image"}
237237
common-metrics = {path = "src/common/metrics"}
238+
common-ndarray = {path = "src/common/ndarray"}
238239
common-runtime = {path = "src/common/runtime", default-features = false}
239240
daft-algebra = {path = "src/daft-algebra"}
240241
daft-context = {path = "src/daft-context"}
@@ -273,9 +274,11 @@ jaq-core = "2.2.0"
273274
jaq-json = {version = "1.1.2", features = ["serde_json"]}
274275
jaq-std = "2.1.1"
275276
mur3 = "0.1.0"
277+
ndarray = "0.16.1"
276278
num-derive = "0.4.2"
277279
num-format = "0.4.4"
278280
num-traits = "0.2"
281+
numpy = "0.25.0"
279282
opentelemetry = {version = "0.29", features = ["trace", "metrics"]}
280283
opentelemetry-otlp = {version = "0.29", features = ["grpc-tonic"]}
281284
opentelemetry_sdk = "0.29"
@@ -348,7 +351,7 @@ features = ['async']
348351
path = "src/parquet2"
349352

350353
[workspace.dependencies.pyo3]
351-
features = ["extension-module", "multiple-pymethods", "abi3-py39", "indexmap", "chrono"]
354+
features = ["extension-module", "multiple-pymethods", "abi3-py39", "indexmap", "chrono", "chrono-tz"]
352355
version = "0.25.1"
353356

354357
[workspace.dependencies.pyo3-async-runtimes]

daft/daft/__init__.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,6 +1327,8 @@ class PySeries:
13271327
def from_pylist(name: str, pylist: list[Any], dtype: PyDataType) -> PySeries: ...
13281328
def to_pylist(self) -> list[Any]: ...
13291329
def to_arrow(self) -> pa.Array: ...
1330+
def __iter__(self) -> PySeriesIterator: ...
1331+
def __getitem__(self, index: int) -> Any: ...
13301332
def __abs__(self) -> PySeries: ...
13311333
def __add__(self, other: PySeries) -> PySeries: ...
13321334
def __sub__(self, other: PySeries) -> PySeries: ...
@@ -1402,6 +1404,10 @@ class PySeries:
14021404
@staticmethod
14031405
def _debug_bincode_deserialize(b: bytes) -> PySeries: ...
14041406

1407+
class PySeriesIterator:
1408+
def __next__(self) -> Any: ...
1409+
def __iter__(self) -> PySeriesIterator: ...
1410+
14051411
class PyShowOptions:
14061412
pass
14071413

daft/series.py

Lines changed: 8 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,51 +5,18 @@
55

66
import daft.daft as native
77
from daft.arrow_utils import ensure_array, ensure_chunked_array
8-
from daft.daft import CountMode, ImageFormat, ImageMode, PyRecordBatch, PySeries
8+
from daft.daft import CountMode, ImageFormat, ImageMode, PyRecordBatch, PySeries, PySeriesIterator
99
from daft.datatype import DataType, TimeUnit, _ensure_registered_super_ext_type
1010
from daft.dependencies import np, pa, pd
1111
from daft.schema import Field
1212
from daft.utils import pyarrow_supports_fixed_shape_tensor
1313

1414
if TYPE_CHECKING:
1515
import builtins
16-
from collections.abc import Iterator
1716

1817
from daft.daft import PyDataType
1918

2019

21-
class SeriesIterable:
22-
"""Iterable wrapper for Series that efficiently handles different data types."""
23-
24-
def __init__(self, series: Series):
25-
self.series = series
26-
27-
def __iter__(self) -> Iterator[Any]:
28-
dt = self.series.datatype()
29-
if dt == DataType.python():
30-
31-
def yield_pylist() -> Iterator[Any]:
32-
yield from self.series._series.to_pylist()
33-
34-
return yield_pylist()
35-
elif dt._should_cast_to_python():
36-
37-
def yield_pylist() -> Iterator[Any]:
38-
yield from self.series._series.cast(DataType.python()._dtype).to_pylist()
39-
40-
return yield_pylist()
41-
else:
42-
43-
def arrow_to_py() -> Iterator[Any]:
44-
# We directly call .to_arrow() on the internal PySeries object since the case
45-
# above has already captured the fixed shape tensor case.
46-
arrow_data = self.series._series.to_arrow()
47-
for item in arrow_data:
48-
yield None if item is None else item.as_py()
49-
50-
return arrow_to_py()
51-
52-
5320
class Series:
5421
"""A Daft Series is an array of data of a single type, and is usually a column in a DataFrame."""
5522

@@ -58,10 +25,9 @@ class Series:
5825
def __init__(self) -> None:
5926
raise NotImplementedError("We do not support creating a Series via __init__ ")
6027

61-
def __iter__(self) -> Iterator[Any]:
28+
def __iter__(self) -> PySeriesIterator:
6229
"""Return an iterator over the elements of the Series."""
63-
iterable = SeriesIterable(self)
64-
return iterable.__iter__()
30+
return self._series.__iter__()
6531

6632
@staticmethod
6733
def _from_pyseries(pyseries: PySeries) -> Series:
@@ -227,19 +193,6 @@ def from_pandas(cls, data: pd.Series[Any], name: str = "pd_series", dtype: DataT
227193
def cast(self, dtype: DataType) -> Series:
228194
return Series._from_pyseries(self._series.cast(dtype._dtype))
229195

230-
def _cast_to_python(self) -> Series:
231-
"""Convert this Series into a Series of Python objects.
232-
233-
Call Series.to_pylist() and create a new Series from the raw Pylist directly.
234-
235-
This logic is needed by the Rust implementation of cast(),
236-
but is written here (i.e. not in Rust) for conciseness.
237-
238-
Do not call this method directly in Python; call cast() instead.
239-
"""
240-
pylist = self.to_pylist()
241-
return Series.from_pylist(pylist, self.name(), dtype=DataType.python())
242-
243196
def _pycast_to_pynative(self, typefn: type, dtype: PyDataType) -> Series:
244197
"""Apply Python-level casting to this Series.
245198
@@ -290,12 +243,7 @@ def to_arrow(self) -> pa.Array:
290243

291244
def to_pylist(self) -> list[Any]:
292245
"""Convert this Series to a Python list."""
293-
if self.datatype().is_python():
294-
return self._series.to_pylist()
295-
elif self.datatype()._should_cast_to_python():
296-
return self._series.cast(DataType.python()._dtype).to_pylist()
297-
else:
298-
return self._series.to_arrow().to_pylist()
246+
return self._series.to_pylist()
299247

300248
def filter(self, mask: Series) -> Series:
301249
if not isinstance(mask, Series):
@@ -339,9 +287,12 @@ def murmur3_32(self) -> Series:
339287
def __repr__(self) -> str:
340288
return repr(self._series)
341289

290+
def __getitem__(self, index: int) -> Any:
291+
return self._series[index]
292+
342293
def __bool__(self) -> bool:
343294
raise ValueError(
344-
"Series don't have a truth value." "If you reached this error using `and` / `or`, use `&` / `|` instead."
295+
"Series don't have a truth value. If you reached this error using `and` / `or`, use `&` / `|` instead."
345296
)
346297

347298
def __len__(self) -> int:

docs/SUMMARY.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@
5656
* [Expressions](api/expressions.md)
5757
* [Functions](api/functions/)
5858
* [User-Defined Functions](api/udf.md)
59+
* [Data Types](api/datatypes.md)
5960
* [Window](api/window.md)
6061
* [Sessions](api/sessions.md)
6162
* [Catalogs & Tables](api/catalogs_tables.md)
6263
* [Schema](api/schema.md)
63-
* [Data Types](api/datatypes.md)
6464
* [Aggregations](api/aggregations.md)
6565
* [Series](api/series.md)
6666
* [Spark Connect](api/spark_connect.md)

docs/api/datatypes.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,42 @@
11
# DataTypes
22

3+
## Type Conversions
4+
5+
### Daft to Python
6+
7+
<!-- Note: the conversions here should match the behavior of the Rust `impl IntoPyObject for Literal`: `src/daft-core/src/lit/python.rs` -->
8+
9+
This table shows the mapping from Daft DataTypes to Python types, as done in places such as [`Series.to_pylist`][daft.series.Series.to_pylist], [`Expression.cast`][daft.expressions.Expression.cast] to Python type, and arguments passed into functions decorated with `@daft.func`.
10+
11+
| Daft DataType | Python Type |
12+
|----------------------------------------------------------------------|-------------------------------------------------------------------------------------|
13+
| Null | `None` |
14+
| Boolean | `bool` |
15+
| Utf8 | `str` |
16+
| Binary, FixedSizeBinary | `bytes` |
17+
| Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64             | `int`                                                                               |
18+
| Timestamp | `datetime.datetime` |
19+
| Date | `datetime.date` |
20+
| Time | `datetime.time` |
21+
| Duration | `datetime.timedelta` |
22+
| Interval | not supported |
23+
| Float32, Float64 | `float` |
24+
| Decimal | `decimal.Decimal` |
25+
| List[T], FixedSizeList[T] | `list[T]` |
26+
| Struct \{<field_1\>: T1, <field_2\>: T2, ...\} | `dict[str, T1 | T2 | ...]` |
27+
| Map[K, V] | `dict[K, V]` |
28+
| Tensor[T], FixedShapeTensor[T] | `numpy.typing.NDArray[T]` |
29+
| SparseTensor[T], FixedShapeSparseTensor[T] | `{`<br>`"values": T,`<br>`"indices": list[int],`<br>`"shape": list[int]`<br>`}` |
30+
| Embedding[T] | `numpy.typing.NDArray[T]` |
31+
| Image | `numpy.typing.NDArray[numpy.uint8 | numpy.uint16 | numpy.float32]` |
32+
| Python | `Any` |
33+
| Extension[T] | `T` |
34+
35+
### Python to Daft
36+
TODO
37+
38+
## daft.DataType
39+
340
Daft provides simple DataTypes that are ubiquitous in many DataFrames such as numbers, strings and dates - all the way up to more complex types like tensors and images. Learn more about [DataTypes](../core_concepts.md#datatypes) in the Daft User Guide.
441

542
::: daft.datatype.DataType

src/common/image/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
arrow2 = {workspace = true}
33
bincode = {workspace = true}
44
common-error = {path = "../error", default-features = false}
5+
common-ndarray = {workspace = true}
56
daft-schema = {path = "../../daft-schema", default-features = false}
67
image = {workspace = true, features = [
78
"gif",
@@ -13,6 +14,7 @@ image = {workspace = true, features = [
1314
"bmp",
1415
"hdr"
1516
]}
17+
ndarray = {workspace = true}
1618
serde = {workspace = true}
1719

1820
[lints]

src/common/image/src/image.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,47 @@
11
use core::slice;
22
use std::{hash::Hash, ops::Deref};
33

4-
use image::{DynamicImage, ImageBuffer};
4+
use common_ndarray::NdArray;
5+
use image::{flat::SampleLayout, DynamicImage, ImageBuffer, Pixel};
6+
use ndarray::{Array3, ShapeBuilder};
57
use serde::{Deserialize, Deserializer, Serialize, Serializer};
68

79
/// Wrapper around image::DynamicImage to implement certain traits
810
#[derive(Debug, Clone, PartialEq)]
911
pub struct Image(pub DynamicImage);
1012

13+
impl Image {
14+
pub fn into_ndarray(self) -> Box<dyn NdArray> {
15+
fn into_ndarray3<P: Pixel>(buf: ImageBuffer<P, Vec<P::Subpixel>>) -> Array3<P::Subpixel> {
16+
let SampleLayout {
17+
channels,
18+
channel_stride,
19+
height,
20+
height_stride,
21+
width,
22+
width_stride,
23+
} = buf.sample_layout();
24+
let shape = (height as usize, width as usize, channels as usize);
25+
let strides = (height_stride, width_stride, channel_stride);
26+
Array3::from_shape_vec(shape.strides(strides), buf.into_raw()).unwrap()
27+
}
28+
29+
match self.0 {
30+
DynamicImage::ImageLuma8(buf) => Box::new(into_ndarray3(buf).into_dyn()),
31+
DynamicImage::ImageLumaA8(buf) => Box::new(into_ndarray3(buf).into_dyn()),
32+
DynamicImage::ImageRgb8(buf) => Box::new(into_ndarray3(buf).into_dyn()),
33+
DynamicImage::ImageRgba8(buf) => Box::new(into_ndarray3(buf).into_dyn()),
34+
DynamicImage::ImageLuma16(buf) => Box::new(into_ndarray3(buf).into_dyn()),
35+
DynamicImage::ImageLumaA16(buf) => Box::new(into_ndarray3(buf).into_dyn()),
36+
DynamicImage::ImageRgb16(buf) => Box::new(into_ndarray3(buf).into_dyn()),
37+
DynamicImage::ImageRgba16(buf) => Box::new(into_ndarray3(buf).into_dyn()),
38+
DynamicImage::ImageRgb32F(buf) => Box::new(into_ndarray3(buf).into_dyn()),
39+
DynamicImage::ImageRgba32F(buf) => Box::new(into_ndarray3(buf).into_dyn()),
40+
_ => unimplemented!("unsupported DynamicImage variant"),
41+
}
42+
}
43+
}
44+
1145
impl Deref for Image {
1246
type Target = DynamicImage;
1347

@@ -51,7 +85,7 @@ impl Hash for Image {
5185
};
5286
buffer_slice.hash(state);
5387
}
54-
_ => todo!(),
88+
_ => unimplemented!("unsupported DynamicImage variant"),
5589
}
5690
}
5791
}

src/common/ndarray/Cargo.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[dependencies]
2+
ndarray = {workspace = true}
3+
numpy = {workspace = true, optional = true}
4+
pyo3 = {workspace = true, optional = true}
5+
6+
[features]
7+
python = ["dep:pyo3", "dep:numpy"]
8+
9+
[lints]
10+
workspace = true
11+
12+
[package]
13+
edition = {workspace = true}
14+
name = "common-ndarray"
15+
version = {workspace = true}

src/common/ndarray/src/lib.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
use ndarray::ArrayD;
2+
#[cfg(feature = "python")]
3+
use pyo3::{Bound, PyAny, Python};
4+
5+
/// Trait to allow dynamic dispatch over ndarray arrays of any type.
6+
pub trait NdArray {
7+
#[cfg(feature = "python")]
8+
fn into_py(self: Box<Self>, py: Python) -> Bound<PyAny>;
9+
}
10+
11+
#[cfg(not(feature = "python"))]
12+
impl<A> NdArray for ArrayD<A> {}
13+
14+
#[cfg(feature = "python")]
15+
impl<A> NdArray for ArrayD<A>
16+
where
17+
A: numpy::Element,
18+
{
19+
fn into_py(self: Box<Self>, py: Python) -> Bound<PyAny> {
20+
use numpy::IntoPyArray;
21+
22+
self.into_pyarray(py).into_any()
23+
}
24+
}

0 commit comments

Comments
 (0)