Commit 5c3cea5

perf: Implement count pushdown for parquet (#5038)
## Changes Made

#4969 implemented count pushdowns for Lance. We can do the same for Parquet and read only the Parquet metadata. For example, I have a bucket at `s3://desmond-test/big.parquet/` with the following setup:

- 161 parquet files
- Each parquet file has `256000` rows of strings with 1024 random characters, giving a total size of 1.2GB per file
- The footer itself is only `2247` bytes long

Before, doing a count takes ~28s:

```
In [2]: %time daft.read_parquet("s3://desmond-test/big.parquet").count().show()
╭──────────╮
│ count    │
│ ---      │
│ UInt64   │
╞══════════╡
│ 41216000 │
╰──────────╯

(Showing first 1 of 1 rows)
CPU times: user 1min 4s, sys: 1min 46s, total: 2min 50s
Wall time: 28.2 s
```

After, it takes ~4s:

```
In [2]: %time daft.read_parquet("s3://desmond-test/big.parquet").count().show()
╭──────────╮
│ count    │
│ ---      │
│ UInt64   │
╞══════════╡
│ 41216000 │
╰──────────╯

(Showing first 1 of 1 rows)
CPU times: user 301 ms, sys: 94.2 ms, total: 395 ms
Wall time: 3.89 s
```
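The count pushdown works because a Parquet file's footer already stores the total row count, so the query can be answered from a few KB of metadata per file instead of scanning roughly 190 GB of column data across the 161 files above. Below is a minimal, standalone sketch of that idea using the Apache Arrow `parquet` crate and a hypothetical local file path; it is not the code added in this PR, which goes through Daft's own reader:

```rust
use std::fs::File;

// Assumes the Apache Arrow `parquet` crate; Daft uses its own parquet reader internally.
use parquet::file::reader::{FileReader, SerializedFileReader};

/// Count rows by reading only the footer metadata; no column data pages are touched.
fn count_rows(path: &str) -> Result<i64, Box<dyn std::error::Error>> {
    let file = File::open(path)?;
    let reader = SerializedFileReader::new(file)?;
    // The total row count is stored in the file-level metadata in the footer.
    Ok(reader.metadata().file_metadata().num_rows())
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical local file; the benchmark above reads from S3.
    println!("rows: {}", count_rows("big.parquet")?);
    Ok(())
}
```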
1 parent 8f1fee8 commit 5c3cea5

File tree: 8 files changed, +146 -38 lines changed


src/daft-local-execution/src/sources/scan_task.rs

Lines changed: 44 additions & 29 deletions

```diff
@@ -13,6 +13,7 @@ use common_runtime::{combine_stream, get_compute_pool_num_threads, get_io_runtime
 use common_scan_info::{Pushdowns, ScanTaskLike};
 use daft_core::prelude::{AsArrow, Int64Array, SchemaRef, Utf8Array};
 use daft_csv::{CsvConvertOptions, CsvParseOptions, CsvReadOptions};
+use daft_dsl::{AggExpr, Expr};
 use daft_io::IOStatsRef;
 use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions};
 use daft_micropartition::MicroPartition;
@@ -470,36 +471,50 @@ async fn stream_scan_task(
             chunk_size: chunk_size_from_config,
             ..
         }) => {
-            let parquet_chunk_size = chunk_size_from_config.or(chunk_size);
-            let inference_options =
-                ParquetSchemaInferenceOptions::new(Some(*coerce_int96_timestamp_unit));
-
-            let delete_rows = delete_map.as_ref().and_then(|m| m.get(url).cloned());
-            let row_groups = if let Some(ChunkSpec::Parquet(row_groups)) = source.get_chunk_spec() {
-                Some(row_groups.clone())
+            if let Some(aggregation) = &scan_task.pushdowns.aggregation
+                && let Expr::Agg(AggExpr::Count(_, _)) = aggregation.as_ref()
+            {
+                daft_parquet::read::stream_parquet_count_pushdown(
+                    url,
+                    io_client,
+                    Some(io_stats),
+                    field_id_mapping.clone(),
+                    aggregation,
+                )
+                .await?
             } else {
-                None
-            };
-            let metadata = scan_task
-                .sources
-                .first()
-                .and_then(|s| s.get_parquet_metadata().cloned());
-            daft_parquet::read::stream_parquet(
-                url,
-                file_column_names.as_deref(),
-                scan_task.pushdowns.limit,
-                row_groups,
-                scan_task.pushdowns.filters.clone(),
-                io_client,
-                Some(io_stats),
-                &inference_options,
-                field_id_mapping.clone(),
-                metadata,
-                maintain_order,
-                delete_rows,
-                parquet_chunk_size,
-            )
-            .await?
+                let parquet_chunk_size = chunk_size_from_config.or(chunk_size);
+                let inference_options =
+                    ParquetSchemaInferenceOptions::new(Some(*coerce_int96_timestamp_unit));
+
+                let delete_rows = delete_map.as_ref().and_then(|m| m.get(url).cloned());
+                let row_groups =
+                    if let Some(ChunkSpec::Parquet(row_groups)) = source.get_chunk_spec() {
+                        Some(row_groups.clone())
+                    } else {
+                        None
+                    };
+                let metadata = scan_task
+                    .sources
+                    .first()
+                    .and_then(|s| s.get_parquet_metadata().cloned());
+                daft_parquet::read::stream_parquet(
+                    url,
+                    file_column_names.as_deref(),
+                    scan_task.pushdowns.limit,
+                    row_groups,
+                    scan_task.pushdowns.filters.clone(),
+                    io_client,
+                    Some(io_stats),
+                    &inference_options,
+                    field_id_mapping.clone(),
+                    metadata,
+                    maintain_order,
+                    delete_rows,
+                    parquet_chunk_size,
+                )
+                .await?
+            }
         }
         FileFormatConfig::Csv(cfg) => {
             let schema_of_file = scan_task.schema.clone();
```

src/daft-logical-plan/src/optimization/optimizer.rs

Lines changed: 1 addition & 0 deletions

```diff
@@ -837,6 +837,7 @@ mod tests {
             )))))
             .with_columns(Some(Arc::new(vec!["a".to_string()]))),
         )
+        .aggregate(vec![unresolved_col("a").sum()], vec![])?
         .build();

         let scan_materializer_and_stats_enricher = get_scan_materializer_and_stats_enricher();
```

src/daft-logical-plan/src/optimization/rules/push_down_aggregation.rs

Lines changed: 21 additions & 2 deletions

```diff
@@ -1,6 +1,6 @@
 use std::sync::Arc;

-use common_error::DaftResult;
+use common_error::{DaftError, DaftResult};
 use common_treenode::{Transformed, TreeNode};
 use daft_core::{count_mode::CountMode, prelude::Schema};
 use daft_dsl::{AggExpr, Expr, ExprRef};
@@ -65,7 +65,16 @@ impl OptimizerRule for PushDownAggregation {
                         SourceInfo::Physical(new_external_info).into(),
                     ))
                     .into();
-                    Ok(Transformed::yes(new_source))
+                    // Scan operators may produce partial counts over multiple scan tasks (e.g., distributed parquet reads), so we still need to sum them.
+                    let new_aggregate = Aggregate::try_new(
+                        new_source,
+                        vec![Arc::new(Expr::Agg(AggExpr::Sum(count_expr(
+                            &aggregations[0],
+                        )?)))],
+                        groupby.clone(),
+                    )?
+                    .into();
+                    Ok(Transformed::yes(new_aggregate))
                 } else {
                     Ok(Transformed::no(node.clone()))
                 }
@@ -93,6 +102,15 @@ fn is_count_expr(expr: &ExprRef) -> Option<&CountMode> {
     }
 }

+fn count_expr(expr: &ExprRef) -> DaftResult<ExprRef> {
+    match expr.as_ref() {
+        Expr::Agg(AggExpr::Count(expr, _)) => Ok(expr.clone()),
+        _ => Err(DaftError::InternalError(
+            "Tried to get count expression from non-count expression".to_string(),
+        )),
+    }
+}
+
 // Check if the count mode is supported for pushdown
 // Currently only CountMode::All is fully supported
 fn is_count_mode_supported(count_mode: &CountMode) -> bool {
@@ -150,6 +168,7 @@ mod tests {
                 CountMode::All,
             ))))),
         )
+        .aggregate(vec![unresolved_col("a").sum()], vec![])?
         .build();

         assert_optimized_plan_eq(plan, expected)?;
```
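The comment added in this rule is the crux of the plan rewrite: the pushed-down count is computed per scan task, so the optimizer keeps a `Sum` aggregation on top of the source to combine the partial counts into the global total. A toy illustration of that final step (plain Rust, reusing the benchmark numbers from the commit message; not code from this PR):

```rust
// Each scan task contributes a footer-derived partial count; the remaining
// Sum aggregation in the rewritten plan combines them into the global count.
fn final_count(partial_counts: &[u64]) -> u64 {
    partial_counts.iter().sum()
}

fn main() {
    // 161 files with 256_000 rows each, as in the benchmark setup above.
    let partials = vec![256_000u64; 161];
    assert_eq!(final_count(&partials), 41_216_000);
}
```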

src/daft-micropartition/src/micropartition.rs

Lines changed: 29 additions & 1 deletion

```diff
@@ -15,7 +15,7 @@ use common_runtime::get_io_runtime;
 use common_scan_info::Pushdowns;
 use daft_core::prelude::*;
 use daft_csv::{CsvConvertOptions, CsvParseOptions, CsvReadOptions};
-use daft_dsl::ExprRef;
+use daft_dsl::{AggExpr, Expr, ExprRef};
 use daft_io::{IOClient, IOConfig, IOStatsContext, IOStatsRef};
 use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions};
 use daft_parquet::{
@@ -514,6 +514,11 @@ impl MicroPartition {
                 parquet_metadata,
                 chunk_size,
                 scan_task.generated_fields.clone(),
+                scan_task
+                    .pushdowns
+                    .aggregation
+                    .as_ref()
+                    .map(|agg| agg.as_ref()),
             )
             .context(DaftCoreComputeSnafu)
         }
@@ -1106,6 +1111,7 @@ pub fn read_parquet_into_micropartition<T: AsRef<str>>(
     parquet_metadata: Option<Vec<Arc<FileMetaData>>>,
     chunk_size: Option<usize>,
     generated_fields: Option<SchemaRef>,
+    aggregation_pushdown: Option<&Expr>,
 ) -> DaftResult<MicroPartition> {
     if let Some(so) = start_offset
         && so > 0
@@ -1187,6 +1193,28 @@ pub fn read_parquet_into_micropartition<T: AsRef<str>>(
         (metadata, schemas)
     };

+    // Handle count pushdown aggregation optimization.
+    if let Some(Expr::Agg(AggExpr::Count(_, _))) = aggregation_pushdown {
+        let count: usize = metadata.iter().map(|m| m.num_rows).sum();
+        let count_field = daft_core::datatypes::Field::new(
+            aggregation_pushdown.unwrap().name(),
+            daft_core::datatypes::DataType::UInt64,
+        );
+        let count_array =
+            UInt64Array::from_iter(count_field.clone(), std::iter::once(Some(count as u64)));
+        let count_batch = daft_recordbatch::RecordBatch::new_with_size(
+            Schema::new(vec![count_field]),
+            vec![count_array.into_series()],
+            1,
+        )
+        .context(DaftCoreComputeSnafu)?;
+        return Ok(MicroPartition::new_loaded(
+            count_batch.schema.clone(),
+            Arc::new(vec![count_batch]),
+            None,
+        ));
+    }
+
     let any_stats_avail = metadata
         .iter()
         .flat_map(|m| m.row_groups.values())
```
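When the eager (non-streaming) read path sees a count pushdown, it skips reading row groups entirely and materializes a single one-row batch whose only column is named after the aggregation expression. A rough sketch of that shape using arrow-rs types rather than Daft's `RecordBatch`/`Series` (an assumption made only to keep the example self-contained):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, UInt64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

/// Build a 1-row, 1-column batch holding the metadata-derived count, with the
/// column named after the aggregation expression so it matches the plan's schema.
fn count_batch(agg_name: &str, count: u64) -> Result<RecordBatch, ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new(agg_name, DataType::UInt64, false)]));
    let column: ArrayRef = Arc::new(UInt64Array::from(vec![count]));
    RecordBatch::try_new(schema, vec![column])
}

fn main() -> Result<(), ArrowError> {
    let batch = count_batch("count", 41_216_000)?;
    assert_eq!(batch.num_rows(), 1);
    Ok(())
}
```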

src/daft-micropartition/src/python.rs

Lines changed: 2 additions & 0 deletions

```diff
@@ -758,6 +758,7 @@ impl PyMicroPartition {
                 None,
                 None,
                 None,
+                None,
             )
         })?;
         Ok(mp.into())
@@ -819,6 +820,7 @@ impl PyMicroPartition {
                 None,
                 chunk_size,
                 None,
+                None,
             )
         })?;
         Ok(mp.into())
```

src/daft-parquet/src/read.rs

Lines changed: 29 additions & 0 deletions

```diff
@@ -1021,6 +1021,35 @@ pub async fn read_parquet_metadata_bulk(
     all_metadatas.into_iter().collect::<DaftResult<Vec<_>>>()
 }

+/// Optimized for count pushdowns: we can get the count from metadata without reading all data.
+pub async fn stream_parquet_count_pushdown(
+    url: &str,
+    io_client: Arc<IOClient>,
+    io_stats: Option<IOStatsRef>,
+    field_id_mapping: Option<Arc<BTreeMap<i32, Field>>>,
+    aggregation: &ExprRef,
+) -> DaftResult<BoxStream<'static, DaftResult<RecordBatch>>> {
+    let parquet_metadata =
+        read_parquet_metadata(url, io_client, io_stats, field_id_mapping.clone()).await?;
+
+    // Currently only CountMode::All is supported for count pushdown.
+    let count = parquet_metadata.num_rows;
+    let count_field = daft_core::datatypes::Field::new(
+        aggregation.name(),
+        daft_core::datatypes::DataType::UInt64,
+    );
+    let count_array =
+        UInt64Array::from_iter(count_field.clone(), std::iter::once(Some(count as u64)));
+    let count_batch = daft_recordbatch::RecordBatch::new_with_size(
+        Schema::new(vec![count_field]),
+        vec![count_array.into_series()],
+        1,
+    )?;
+    Ok(Box::pin(futures::stream::once(
+        async move { Ok(count_batch) },
+    )))
+}
+
 pub fn read_parquet_statistics(
     uris: &Series,
     io_client: Arc<IOClient>,
```
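`stream_parquet_count_pushdown` fetches only the footer via `read_parquet_metadata` and then wraps its single precomputed batch in a one-element stream, so the count path plugs into the same streaming interface as a normal parquet read. A stripped-down sketch of that wrapping pattern, where `RecordBatch` is a stand-in struct rather than Daft's type and the `futures`/`tokio` crates are assumed:

```rust
use futures::stream::{self, BoxStream, StreamExt};

// Stand-in for Daft's RecordBatch, just to keep the sketch self-contained.
#[derive(Debug)]
struct RecordBatch {
    num_rows: u64,
}

/// Package one precomputed result as a boxed async stream.
fn count_result_stream(count: u64) -> BoxStream<'static, RecordBatch> {
    Box::pin(stream::once(async move { RecordBatch { num_rows: count } }))
}

#[tokio::main]
async fn main() {
    let mut batches = count_result_stream(41_216_000);
    while let Some(batch) = batches.next().await {
        println!("count batch: {batch:?}");
    }
}
```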

src/daft-scan/src/glob.rs

Lines changed: 1 addition & 1 deletion

```diff
@@ -353,7 +353,7 @@ impl ScanOperator for GlobScanOperator {
     }

     fn supports_count_pushdown(&self) -> bool {
-        false
+        self.file_format_config.file_format() == FileFormat::Parquet
     }

     fn multiline_display(&self) -> Vec<String> {
```

src/daft-scan/src/lib.rs

Lines changed: 19 additions & 5 deletions

```diff
@@ -646,7 +646,15 @@ impl ScanTask {
     #[must_use]
     pub fn materialized_schema(&self) -> SchemaRef {
         match (&self.generated_fields, &self.pushdowns.columns) {
-            (None, None) => self.schema.clone(),
+            (None, None) => {
+                if let Some(aggregation) = &self.pushdowns.aggregation {
+                    Arc::new(Schema::new(vec![aggregation
+                        .to_field(&self.schema)
+                        .expect("Casting to aggregation field should not fail")]))
+                } else {
+                    self.schema.clone()
+                }
+            }
             _ => {
                 let schema_with_generated_fields =
                     if let Some(generated_fields) = &self.generated_fields {
@@ -657,10 +665,16 @@ impl ScanTask {
                     };

                 let mut fields = schema_with_generated_fields.fields().to_vec();
-
-                // Filter the schema based on the pushdown column filters.
-                if let Some(columns) = &self.pushdowns.columns {
-                    fields.retain(|field| columns.contains(&field.name));
+                if let Some(aggregation) = &self.pushdowns.aggregation {
+                    // If we have a pushdown aggregation, the only field in the schema is the aggregation.
+                    fields = vec![aggregation
+                        .to_field(&schema_with_generated_fields)
+                        .expect("Casting to aggregation field should not fail")];
+                } else {
+                    // Filter the schema based on the pushdown column filters.
+                    if let Some(columns) = &self.pushdowns.columns {
+                        fields.retain(|field| columns.contains(&field.name));
+                    }
                 }

                 Arc::new(Schema::new(fields))
```
