Commit 0d4ba08
feat(optimizer): Add Lance count() pushdown optimization (#4969)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent 950922d commit 0d4ba08

File tree

22 files changed: +627 −15 lines changed
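In user-facing terms, this optimization lets a whole-table count over a Lance dataset be answered from Lance metadata instead of a full scan. A minimal sketch of the query shape it targets, assuming daft is installed with Lance support and a local dataset at ./events.lance (path illustrative):

import daft

# With this commit, a plain row count over a Lance dataset should plan as a
# single one-row scan task backed by Lance's count_rows() metadata call,
# rather than reading every fragment.
df = daft.read_lance("./events.lance")
print(df.count_rows())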

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default.

daft/daft/__init__.pyi

Lines changed: 10 additions & 0 deletions
@@ -868,18 +868,28 @@ class PyPushdowns:
     filters: PyExpr | None
     partition_filters: PyExpr | None
     limit: int | None
+    aggregation: PyExpr | None

     def __init__(
         self,
         columns: list[str] | None = None,
         filters: PyExpr | None = None,
         partition_filters: PyExpr | None = None,
         limit: int | None = None,
+        aggregation: PyExpr | None = None,
     ) -> None: ...
     def filter_required_column_names(self) -> list[str]:
         """List of field names that are required by the filter predicate."""
         ...

+    def aggregation_required_column_names(self) -> list[str]:
+        """List of field names that are required by the aggregation predicate."""
+        ...
+
+    def aggregation_count_mode(self) -> CountMode:
+        """Count mode of the aggregation predicate."""
+        ...
+
 PyArrowParquetType = tuple[pa.Field, dict[str, str], pa.Array, int]

 def read_parquet(
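A short sketch of how this new stub surface fits together. Building the aggregation through daft.col(...).count(...) and reaching for the internal ._expr handle are assumptions for illustration, not documented API:

import daft
from daft.daft import CountMode, PyPushdowns

# Hypothetical: attach a count aggregation to pushdowns and inspect it via the
# new accessors. Expression.count accepting a CountMode and the internal
# ._expr handle are assumptions here.
agg = daft.col("id").count(CountMode.All)
pushdowns = PyPushdowns(aggregation=agg._expr)
print(pushdowns.aggregation_count_mode())             # CountMode.All
print(pushdowns.aggregation_required_column_names())  # ["id"]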

daft/io/lance/lance_scan.py

Lines changed: 78 additions & 4 deletions
@@ -1,10 +1,12 @@
 # ruff: noqa: I002
 # isort: dont-add-import: from __future__ import annotations

+import logging
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

-from daft.daft import PyExpr, PyPartitionField, PyPushdowns, PyRecordBatch, ScanTask
+from daft.daft import CountMode, PyExpr, PyPartitionField, PyPushdowns, PyRecordBatch, ScanTask
+from daft.dependencies import pa
 from daft.io.scan import ScanOperator
 from daft.logical.schema import Schema
 from daft.recordbatch import RecordBatch
@@ -13,15 +15,16 @@
 if TYPE_CHECKING:
     import lance
-    import pyarrow
+
+logger = logging.getLogger(__name__)


 # TODO support fts and fast_search
 def _lancedb_table_factory_function(
     ds: "lance.LanceDataset",
     fragment_ids: Optional[list[int]] = None,
     required_columns: Optional[list[str]] = None,
-    filter: Optional["pyarrow.compute.Expression"] = None,
+    filter: Optional["pa.compute.Expression"] = None,
     limit: Optional[int] = None,
 ) -> Iterator[PyRecordBatch]:
     fragments = [ds.get_fragment(id) for id in (fragment_ids or [])]
@@ -31,6 +34,30 @@ def _lancedb_table_factory_function(
     return (RecordBatch.from_arrow_record_batches([rb], rb.schema)._recordbatch for rb in scanner.to_batches())


+def _lancedb_count_result_function(
+    ds: "lance.LanceDataset",
+    required_column: str,
+    filters: Optional[list[Any]] = None,
+) -> Iterator[PyRecordBatch]:
+    """Use LanceDB's API to count rows and return a record batch with the count result."""
+    count = 0
+    if filters is None:
+        logger.debug("Using metadata for counting all rows (no filters)")
+        count = ds.count_rows()
+    else:
+        # Filters rule out the metadata shortcut: scan and count the matching rows.
+        logger.debug("Counting rows with filters applied")
+        scanner = ds.scanner(filter=filters)
+        for batch in scanner.to_batches():
+            count += batch.num_rows
+
+    arrow_schema = pa.schema([pa.field(required_column, pa.uint64())])
+    arrow_array = pa.array([count], type=pa.uint64())
+    arrow_batch = pa.RecordBatch.from_arrays([arrow_array], [required_column])
+    result_batch = RecordBatch.from_arrow_record_batches([arrow_batch], arrow_schema)._recordbatch
+    return (result_batch for _ in [1])
+
+
 class LanceDBScanOperator(ScanOperator, SupportsPushdownFilters):
     def __init__(self, ds: "lance.LanceDataset"):
         self._ds = ds
@@ -57,6 +84,14 @@ def can_absorb_limit(self) -> bool:
     def can_absorb_select(self) -> bool:
         return False

+    def supports_count_pushdown(self) -> bool:
+        """Returns whether this scan operator supports count pushdown."""
+        return True
+
+    def supported_count_modes(self) -> list[CountMode]:
+        """Returns the count modes supported by this scan operator."""
+        return [CountMode.All]
+
     def multiline_display(self) -> list[str]:
         return [
             self.display_name(),
@@ -95,6 +130,46 @@ def to_scan_tasks(self, pushdowns: PyPushdowns) -> Iterator[ScanTask]:
                 else pushdowns.columns + filter_required_column_names
             )
         )
+
+        # Check if there is a count aggregation pushdown
+        if (
+            pushdowns.aggregation is not None
+            and pushdowns.aggregation_count_mode() is not None
+            and pushdowns.aggregation_required_column_names()
+        ):
+            count_mode = pushdowns.aggregation_count_mode()
+            fields = pushdowns.aggregation_required_column_names()
+
+            if count_mode not in self.supported_count_modes():
+                logger.warning(
+                    "Count mode %s is not supported for pushdown, falling back to original logic",
+                    count_mode,
+                )
+                yield from self._create_regular_scan_tasks(pushdowns, required_columns)
+                return  # otherwise the count task below would also be emitted
+
+            # TODO: If there are pushed filters, convert them to Arrow expressions
+            filters = None
+
+            new_schema = Schema.from_pyarrow_schema(pa.schema([pa.field(fields[0], pa.uint64())]))
+            yield ScanTask.python_factory_func_scan_task(
+                module=_lancedb_count_result_function.__module__,
+                func_name=_lancedb_count_result_function.__name__,
+                func_args=(self._ds, fields[0], filters),
+                schema=new_schema._schema,
+                num_rows=1,
+                size_bytes=None,
+                pushdowns=pushdowns,
+                stats=None,
+            )
+        else:
+            # Regular scan without count pushdown
+            yield from self._create_regular_scan_tasks(pushdowns, required_columns)
+
+    def _create_regular_scan_tasks(
+        self, pushdowns: PyPushdowns, required_columns: Optional[list[str]]
+    ) -> Iterator[ScanTask]:
+        """Create regular scan tasks without count pushdown."""
         # TODO: figure out how to translate Pushdowns into LanceDB filters
         filters = None
         fragments = self._ds.get_fragments()
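The count result that _lancedb_count_result_function builds above is just a single-row Arrow batch; the construction can be reproduced standalone with pyarrow (the count value here is made up):

import pyarrow as pa

# One uint64 column, named after the counted field, holding the row count.
count, required_column = 12345, "id"
schema = pa.schema([pa.field(required_column, pa.uint64())])
batch = pa.RecordBatch.from_arrays([pa.array([count], type=pa.uint64())], schema=schema)
print(batch.num_rows, batch.schema)  # 1; id: uint64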

daft/io/scan.py

Lines changed: 4 additions & 0 deletions
@@ -79,3 +79,7 @@ def to_scan_tasks(self, pushdowns: PyPushdowns) -> Iterator[ScanTask]:
     def as_pushdown_filter(self) -> SupportsPushdownFilters | None:
         """Returns this scan operator as a SupportsPushdownFilters if it supports pushdown filters."""
         raise NotImplementedError()
+
+    def supports_count_pushdown(self) -> bool:
+        """Returns true if this scan can accept count pushdowns."""
+        return False
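Because the base class defaults to False, operators opt in explicitly, as LanceDBScanOperator does above. A minimal sketch of such an override; MyScanOperator is illustrative only:

from daft.daft import CountMode
from daft.io.scan import ScanOperator

class MyScanOperator(ScanOperator):
    # Illustrative only: schema(), display_name(), to_scan_tasks(), and the
    # remaining abstract ScanOperator methods are omitted for brevity.
    def supports_count_pushdown(self) -> bool:
        return True

    def supported_count_modes(self) -> "list[CountMode]":
        return [CountMode.All]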

src/common/scan-info/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ common-display = {path = "../display", default-features = false}
 common-error = {path = "../error", default-features = false}
 common-file-formats = {path = "../file-formats", default-features = false}
 daft-algebra = {path = "../../daft-algebra", default-features = false}
+daft-core = {path = "../../daft-core", default-features = false}
 daft-dsl = {path = "../../daft-dsl", default-features = false}
 daft-schema = {path = "../../daft-schema", default-features = false}
 fnv = "1.0.7"

src/common/scan-info/src/pushdowns.rs

Lines changed: 33 additions & 1 deletion
@@ -28,11 +28,16 @@ pub struct Pushdowns {
     /// The `filters` field is kept for backward compatibility;
     /// it represents all current filters.
     pub pushed_filters: Option<Vec<ExprRef>>,
+
+    /// Optional aggregation pushdown.
+    /// This is used to indicate that the scan operator can perform an aggregation.
+    /// This is useful for scans that can perform aggregations like `count`.
+    pub aggregation: Option<ExprRef>,
 }

 impl Default for Pushdowns {
     fn default() -> Self {
-        Self::new(None, None, None, None, None)
+        Self::new(None, None, None, None, None, None)
     }
 }

@@ -44,6 +49,7 @@ impl Pushdowns {
         columns: Option<Arc<Vec<String>>>,
         limit: Option<usize>,
         sharder: Option<Sharder>,
+        aggregation: Option<ExprRef>,
     ) -> Self {
         Self {
             filters,
@@ -52,6 +58,7 @@ impl Pushdowns {
             limit,
             sharder,
             pushed_filters: None,
+            aggregation,
         }
     }

@@ -72,6 +79,7 @@ impl Pushdowns {
             limit,
             sharder: self.sharder.clone(),
             pushed_filters: self.pushed_filters.clone(),
+            aggregation: self.aggregation.clone(),
         }
     }

@@ -84,6 +92,7 @@ impl Pushdowns {
             limit: self.limit,
             sharder: self.sharder.clone(),
             pushed_filters: self.pushed_filters.clone(),
+            aggregation: self.aggregation.clone(),
         }
     }

@@ -96,6 +105,7 @@ impl Pushdowns {
             limit: self.limit,
             sharder: self.sharder.clone(),
             pushed_filters: self.pushed_filters.clone(),
+            aggregation: self.aggregation.clone(),
         }
     }

@@ -108,6 +118,7 @@ impl Pushdowns {
             limit: self.limit,
             sharder: self.sharder.clone(),
             pushed_filters: self.pushed_filters.clone(),
+            aggregation: self.aggregation.clone(),
         }
     }

@@ -120,6 +131,7 @@ impl Pushdowns {
             limit: self.limit,
             sharder,
             pushed_filters: self.pushed_filters.clone(),
+            aggregation: self.aggregation.clone(),
         }
     }

@@ -132,6 +144,20 @@ impl Pushdowns {
             limit: self.limit,
             sharder: self.sharder.clone(),
             pushed_filters,
+            aggregation: self.aggregation.clone(),
+        }
+    }
+
+    #[must_use]
+    pub fn with_aggregation(&self, aggregation: Option<ExprRef>) -> Self {
+        Self {
+            filters: self.filters.clone(),
+            partition_filters: self.partition_filters.clone(),
+            columns: self.columns.clone(),
+            limit: self.limit,
+            sharder: self.sharder.clone(),
+            pushed_filters: self.pushed_filters.clone(),
+            aggregation,
         }
     }

@@ -153,6 +179,9 @@ impl Pushdowns {
         if let Some(sharder) = &self.sharder {
             res.push(format!("Sharder = {sharder}"));
         }
+        if let Some(aggregation) = &self.aggregation {
+            res.push(format!("Aggregation pushdown = {aggregation}"));
+        }
         res
     }

@@ -187,6 +216,9 @@ impl DisplayAs for Pushdowns {
         if let Some(sharder) = &self.sharder {
             sub_items.push(format!("sharder: {sharder}"));
         }
+        if let Some(aggregation) = &self.aggregation {
+            sub_items.push(format!("aggregation: {aggregation}"));
+        }
         s.push_str(&sub_items.join(", "));
         s.push('}');
         s

src/common/scan-info/src/python.rs

Lines changed: 31 additions & 1 deletion
@@ -3,7 +3,8 @@ use pyo3::prelude::*;
 pub mod pylib {
     use std::sync::Arc;

-    use daft_dsl::python::PyExpr;
+    use daft_core::count_mode::CountMode;
+    use daft_dsl::{python::PyExpr, AggExpr, Expr};
     use daft_schema::python::field::PyField;
     use pyo3::{exceptions::PyAttributeError, prelude::*, pyclass};
     use serde::{Deserialize, Serialize};
@@ -164,19 +165,22 @@
         partition_filters = None,
         columns = None,
         limit = None,
+        aggregation = None,
     ))]
     pub fn new(
         filters: Option<PyExpr>,
         partition_filters: Option<PyExpr>,
         columns: Option<Vec<String>>,
         limit: Option<usize>,
+        aggregation: Option<PyExpr>,
     ) -> Self {
         let pushdowns = Pushdowns::new(
             filters.map(|f| f.expr),
             partition_filters.map(|f| f.expr),
             columns.map(Arc::new),
             limit,
             None,
+            aggregation.map(|f| f.expr),
         );
         Self(Arc::new(pushdowns))
     }
@@ -212,12 +216,38 @@
         self.0.columns.as_deref().cloned()
     }

+    #[getter]
+    #[must_use]
+    pub fn aggregation(&self) -> Option<PyExpr> {
+        self.0
+            .aggregation
+            .as_ref()
+            .map(|e| PyExpr { expr: e.clone() })
+    }
+
     pub fn filter_required_column_names(&self) -> Option<Vec<String>> {
         self.0
             .filters
             .as_ref()
             .map(daft_dsl::optimization::get_required_columns)
     }
+
+    pub fn aggregation_required_column_names(&self) -> Option<Vec<String>> {
+        self.0
+            .aggregation
+            .as_ref()
+            .map(daft_dsl::optimization::get_required_columns)
+    }
+
+    pub fn aggregation_count_mode(&self) -> Option<CountMode> {
+        match self.0.aggregation.as_ref() {
+            Some(expr) => match expr.as_ref() {
+                Expr::Agg(AggExpr::Count(_, count_mode)) => Some(*count_mode),
+                _ => None,
+            },
+            None => None,
+        }
+    }
     }
 }

src/common/scan-info/src/scan_operator.rs

Lines changed: 8 additions & 0 deletions
@@ -33,6 +33,14 @@ pub trait ScanOperator: Send + Sync + Debug {
     fn can_absorb_shard(&self) -> bool;
     fn multiline_display(&self) -> Vec<String>;

+    fn supports_count_pushdown(&self) -> bool {
+        false
+    }
+
+    fn supported_count_modes(&self) -> Vec<daft_core::count_mode::CountMode> {
+        Vec::new()
+    }
+
     /// If cfg provided, `to_scan_tasks` should apply the appropriate transformations
     /// (merging, splitting) to the outputted scan tasks
     fn to_scan_tasks(&self, pushdowns: Pushdowns) -> DaftResult<Vec<ScanTaskLikeRef>>;

src/common/scan-info/src/test/mod.rs

Lines changed: 5 additions & 0 deletions
@@ -29,6 +29,7 @@ pub struct DummyScanOperator {
     pub schema: SchemaRef,
     pub num_scan_tasks: u32,
     pub num_rows_per_task: Option<usize>,
+    pub supports_count_pushdown_flag: bool,
 }

 #[typetag::serde]
@@ -145,6 +146,10 @@ impl ScanOperator for DummyScanOperator {
         vec!["DummyScanOperator".to_string()]
     }

+    fn supports_count_pushdown(&self) -> bool {
+        self.supports_count_pushdown_flag
+    }
+
     fn to_scan_tasks(&self, pushdowns: Pushdowns) -> DaftResult<Vec<ScanTaskLikeRef>> {
         Ok((0..self.num_scan_tasks)
             .map(|i| {
