Commit 0b511b1

Read metadata to get count_rows

1 parent: 282dbc5

File tree

3 files changed: +160 / -153 lines

    daft/io/iceberg/iceberg_scan.py
    tests/integration/iceberg/docker-compose/provision.py
    tests/integration/iceberg/test_iceberg_reads.py

3 files changed

+160
-153
lines changed

daft/io/iceberg/iceberg_scan.py

Lines changed: 13 additions & 78 deletions
@@ -39,25 +39,15 @@


 def _iceberg_count_result_function(total_count: int, field_name: str) -> Iterator[PyRecordBatch]:
-    """Construct Iceberg count query result.
-
-    This function creates a single-row result containing the count value,
-    which is used by the count pushdown optimization.
-    """
+    """Construct Iceberg count query result."""
     try:
-        # Create Arrow schema and array for the count result
         arrow_schema = pa.schema([pa.field(field_name, pa.uint64())])
         arrow_array = pa.array([total_count], type=pa.uint64())
         arrow_batch = pa.RecordBatch.from_arrays([arrow_array], [field_name])

-        # Convert to Daft RecordBatch
-        result_batch = RecordBatch.from_arrow_record_batches([arrow_batch], arrow_schema)._recordbatch
-
         logger.debug("Generated Iceberg count result: %s=%d", field_name, total_count)

-        # Yield the result batch (generator pattern)
-        yield result_batch
-
+        yield RecordBatch.from_arrow_record_batches([arrow_batch], arrow_schema)._recordbatch
     except Exception as e:
         logger.error("Failed to construct Iceberg count result: %s", e)
         raise

@@ -266,26 +256,25 @@ def _create_regular_scan_tasks(self, pushdowns: PyPushdowns) -> Iterator[ScanTask]:
         return iter(scan_tasks)

     def _create_count_scan_task(self, pushdowns: PyPushdowns, field_name: str) -> Iterator[ScanTask]:
-        """Create count pushdown scan task using Iceberg metadata.
-
-        This method leverages Iceberg's manifest files to calculate the total row count
-        without reading the actual data files, providing significant performance improvements.
-        """
+        """Create count pushdown scan task using Iceberg metadata."""
         try:
-            # Calculate total count from Iceberg metadata
-            total_count = self._calculate_total_rows_from_metadata()
+            from pyiceberg.table.snapshots import TOTAL_RECORDS

-            # Create result schema with the count field
+            if self._snapshot_id is None:
+                snapshot = self._table.current_snapshot()
+            else:
+                snapshot = self._table.snapshot_by_id(self._snapshot_id)
+
+            total_count = int(snapshot.summary.get(TOTAL_RECORDS, 0))
             result_schema = Schema.from_pyarrow_schema(pa.schema([pa.field(field_name, pa.uint64())]))

-            # Create Python factory function scan task
             scan_task = ScanTask.python_factory_func_scan_task(
                 module=_iceberg_count_result_function.__module__,
                 func_name=_iceberg_count_result_function.__name__,
                 func_args=(total_count, field_name),
                 schema=result_schema._schema,
-                num_rows=1,  # Count result is always a single row
-                size_bytes=8,  # uint64 size
+                num_rows=1,
+                size_bytes=8,
                 pushdowns=pushdowns,
                 stats=None,
             )

@@ -294,54 +283,9 @@ def _create_count_scan_task(self, pushdowns: PyPushdowns, field_name: str) -> Iterator[ScanTask]:
             yield scan_task

         except Exception as e:
-            logger.error("Failed to create Iceberg count pushdown task: %s", e)
-            # Fallback to regular scan if count pushdown fails
-            logger.warning("Falling back to regular scan due to count pushdown failure")
+            logger.error("Failed to create Iceberg count pushdown task: %s, now falling back to regular scan", e)
             yield from self._create_regular_scan_tasks(pushdowns)

-    def _calculate_total_rows_from_metadata(self) -> int:
-        """Calculate total row count from Iceberg manifest metadata.
-
-        This method reads the manifest files to aggregate record_count information
-        from all data files without accessing the actual data.
-        """
-        try:
-            # Get scan plan from Iceberg table
-            iceberg_tasks = self._table.scan(
-                limit=None,  # No limit for count calculation
-                snapshot_id=self._snapshot_id,
-            ).plan_files()
-
-            total_rows = 0
-            total_deleted = 0
-
-            # Aggregate row counts from all data files
-            for task in iceberg_tasks:
-                data_file = task.file
-                total_rows += data_file.record_count
-
-                # Handle delete files (for Iceberg MOR - Merge-on-Read)
-                for delete_file in task.delete_files:
-                    # For now, we'll use a simple estimation for delete files
-                    # In a production implementation, this could be more sophisticated
-                    total_deleted += delete_file.record_count
-
-            # Calculate final count (ensure non-negative)
-            final_count = max(0, total_rows - total_deleted)
-
-            logger.info(
-                "Calculated Iceberg count from metadata: total_rows=%d, deleted_rows=%d, final_count=%d",
-                total_rows,
-                total_deleted,
-                final_count,
-            )
-
-            return final_count
-
-        except Exception as e:
-            logger.error("Failed to calculate total rows from Iceberg metadata: %s", e)
-            raise
-
     def can_absorb_filter(self) -> bool:
         return False

@@ -352,16 +296,7 @@ def can_absorb_select(self) -> bool:
         return True

     def supports_count_pushdown(self) -> bool:
-        """Returns whether this scan operator supports count pushdown.
-
-        Iceberg supports count pushdown by leveraging metadata stored in manifest files.
-        Each data file's record_count is available without reading the actual data.
-        """
         return True

     def supported_count_modes(self) -> list[CountMode]:
-        """Returns the count modes supported by this scan operator.
-
-        Currently only supports COUNT(*) which corresponds to CountMode.All.
-        """
        return [CountMode.All]
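
Note on the change above: the count no longer comes from summing record_count over planned data files; it is read from the snapshot summary that Iceberg writers maintain. A minimal standalone sketch of the same lookup with PyIceberg is below; the catalog and table names are placeholders and not part of this commit:

    from pyiceberg.catalog import load_catalog
    from pyiceberg.table.snapshots import TOTAL_RECORDS  # key for the "total-records" summary entry

    # Placeholder names; point these at your own catalog configuration.
    catalog = load_catalog("local")
    table = catalog.load_table("default.test_all_types")

    # Use the current snapshot, or table.snapshot_by_id(...) for a specific snapshot.
    snapshot = table.current_snapshot()

    # The summary lives in the table metadata, so no data files are opened to answer COUNT(*).
    if snapshot is None or snapshot.summary is None:
        total_rows = 0
    else:
        total_rows = int(snapshot.summary.get(TOTAL_RECORDS, 0))
    print(total_rows)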

tests/integration/iceberg/docker-compose/provision.py

Lines changed: 140 additions & 0 deletions
@@ -427,3 +427,143 @@
 """)

 spark.sql("INSERT INTO default.test_snapshotting VALUES (4, 1)")
+
+
+###
+# MOR (Merge-on-Read) complex scenario test table
+# Used to test the accuracy of the count pushdown in complex delete-file scenarios
+###
+
+spark.sql(
+    """
+    CREATE OR REPLACE TABLE default.test_overlapping_deletes (
+        id integer,
+        name string,
+        value double,
+        category string
+    )
+    USING iceberg
+    TBLPROPERTIES (
+        'write.delete.mode'='merge-on-read',
+        'write.update.mode'='merge-on-read',
+        'write.merge.mode'='merge-on-read',
+        'format-version'='2'
+    );
+    """
+)
+
+spark.sql(
+    """
+    INSERT INTO default.test_overlapping_deletes
+    VALUES
+        (1, 'Alice', 100.0, 'A'),
+        (2, 'Bob', 200.0, 'B'),
+        (3, 'Charlie', 300.0, 'A'),
+        (4, 'David', 400.0, 'B'),
+        (5, 'Eve', 500.0, 'A'),
+        (6, 'Frank', 600.0, 'B'),
+        (7, 'Grace', 700.0, 'A'),
+        (8, 'Henry', 800.0, 'B'),
+        (9, 'Ivy', 900.0, 'A'),
+        (10, 'Jack', 1000.0, 'B'),
+        (11, 'Kate', 1100.0, 'A'),
+        (12, 'Leo', 1200.0, 'B'),
+        (13, 'Mary', 1300.0, 'A'),
+        (14, 'Nick', 1400.0, 'B'),
+        (15, 'Olivia', 1500.0, 'A');
+    """
+)
+
+spark.sql(
+    """
+    DELETE FROM default.test_overlapping_deletes WHERE id <= 5
+    """
+)
+
+spark.sql(
+    """
+    DELETE FROM default.test_overlapping_deletes WHERE id <= 3
+    """
+)
+
+spark.sql(
+    """
+    DELETE FROM default.test_overlapping_deletes WHERE id >= 4 AND id <= 8
+    """
+)
+
+# Mixed delete type test table - tests the mixed handling of position deletes and equality deletes
+
+spark.sql(
+    """
+    CREATE OR REPLACE TABLE default.test_mixed_delete_types (
+        id integer,
+        name string,
+        age integer,
+        department string,
+        salary double,
+        active boolean
+    )
+    USING iceberg
+    TBLPROPERTIES (
+        'write.delete.mode'='merge-on-read',
+        'write.update.mode'='merge-on-read',
+        'write.merge.mode'='merge-on-read',
+        'format-version'='2'
+    );
+    """
+)
+
+spark.sql(
+    """
+    INSERT INTO default.test_mixed_delete_types
+    VALUES
+        (1, 'Alice', 25, 'Engineering', 75000.0, true),
+        (2, 'Bob', 30, 'Marketing', 65000.0, true),
+        (3, 'Charlie', 35, 'Engineering', 85000.0, true),
+        (4, 'David', 28, 'Sales', 60000.0, false),
+        (5, 'Eve', 32, 'Engineering', 90000.0, true),
+        (6, 'Frank', 45, 'Marketing', 70000.0, true),
+        (7, 'Grace', 29, 'Engineering', 80000.0, true),
+        (8, 'Henry', 38, 'Sales', 55000.0, false),
+        (9, 'Ivy', 26, 'Engineering', 78000.0, true),
+        (10, 'Jack', 33, 'Marketing', 68000.0, true),
+        (11, 'Kate', 31, 'Engineering', 82000.0, true),
+        (12, 'Leo', 27, 'Sales', 58000.0, true),
+        (13, 'Mary', 34, 'Engineering', 88000.0, true),
+        (14, 'Nick', 29, 'Marketing', 66000.0, false),
+        (15, 'Olivia', 36, 'Engineering', 92000.0, true),
+        (16, 'Paul', 40, 'Sales', 62000.0, true),
+        (17, 'Quinn', 28, 'Engineering', 76000.0, true),
+        (18, 'Rachel', 32, 'Marketing', 69000.0, true),
+        (19, 'Steve', 37, 'Engineering', 87000.0, true),
+        (20, 'Tina', 30, 'Sales', 61000.0, false);
+    """
+)
+
+spark.sql(
+    """
+    DELETE FROM default.test_mixed_delete_types WHERE id IN (2, 5, 8, 11, 14)
+    """
+)
+
+spark.sql(
+    """
+    DELETE FROM default.test_mixed_delete_types WHERE department = 'Sales' AND active = false
+    """
+)
+
+spark.sql(
+    """
+    DELETE FROM default.test_mixed_delete_types WHERE age < 30 AND salary < 70000
+    """
+)
+
+spark.sql(
+    """
+    INSERT INTO default.test_mixed_delete_types
+    VALUES
+        (2, 'Lily', 60, 'Sales', 2000.0, true),
+        (21, 'Lucy', 28, 'Engineering', 76000.0, true);
+    """
+)
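
As a sanity check on the two fixtures added above, the expected surviving row counts can be worked out by hand from the INSERT/DELETE statements and verified in the same Spark session. The check below is illustrative only and not part of this commit:

    # Hand-derived expectations from the statements above:
    #   test_overlapping_deletes: 15 rows inserted, the deletes cover ids 1-8   -> 7 rows left (ids 9-15)
    #   test_mixed_delete_types:  20 rows inserted, deletes remove ids 2, 4, 5, 8, 11, 12, 14, 20,
    #                             then two more rows are inserted               -> 14 rows left
    for table_name, expected in [("test_overlapping_deletes", 7), ("test_mixed_delete_types", 14)]:
        count = spark.sql(f"SELECT COUNT(*) AS cnt FROM default.{table_name}").collect()[0]["cnt"]
        print(table_name, count, "expected:", expected)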

tests/integration/iceberg/test_iceberg_reads.py

Lines changed: 7 additions & 75 deletions
@@ -45,6 +45,8 @@ def test_daft_iceberg_table_open(local_iceberg_tables, local_iceberg_catalog):
     "test_add_new_column",
     "test_new_column_with_no_data",
     "test_table_rename",
+    "test_overlapping_deletes",
+    "test_mixed_delete_types",
     # Partition evolution currently not supported, see issue: https://github.com/Eventual-Inc/Daft/issues/2249
     # "test_evolve_partitioning",
 ]

@@ -69,6 +71,8 @@ def test_daft_iceberg_table_open(local_iceberg_tables, local_iceberg_catalog):
     "test_add_new_column": ["idx"],
     "test_new_column_with_no_data": [],
     "test_table_rename": [],
+    "test_overlapping_deletes": [],
+    "test_mixed_delete_types": [],
 }


@@ -240,62 +244,12 @@ def test_daft_iceberg_table_mor_predicate_collect_correct(table_name, local_iceberg_catalog):
 class TestIcebergCountPushdown:
     """Test suite for Iceberg Count pushdown optimization."""

-    @pytest.mark.integration()
-    def test_count_pushdown_basic(self, local_iceberg_catalog, capsys):
-        """Test basic count(*) pushdown functionality."""
-        catalog_name, pyiceberg_catalog = local_iceberg_catalog
-        tab = pyiceberg_catalog.load_table("default.test_all_types")
-
-        # Test Daft count with pushdown
-        df = daft.read_table(f"{catalog_name}.default.test_all_types").count()
-        _ = capsys.readouterr()
-        df.explain(True)
-        actual = capsys.readouterr()
-        assert "daft.io.iceberg.iceberg_scan:_iceberg_count_result_function" in actual.out
-
-        daft_count = df.collect().to_pydict()["count"][0]
-
-        # Compare with PyIceberg count
-        iceberg_count = len(tab.scan().to_arrow())
-
-        assert daft_count == iceberg_count
-
-    @pytest.mark.integration()
-    def test_count_pushdown_empty_table(self, local_iceberg_catalog, capsys):
-        """Test count pushdown on empty table."""
-        catalog_name, pyiceberg_catalog = local_iceberg_catalog
-
-        # Use a table that might be empty or create logic to test empty scenario
-        try:
-            tab = pyiceberg_catalog.load_table("default.test_new_column_with_no_data")
-            df = daft.read_table(f"{catalog_name}.default.test_new_column_with_no_data").count()
-
-            _ = capsys.readouterr()
-            df.explain(True)
-            actual = capsys.readouterr()
-            assert "daft.io.iceberg.iceberg_scan:_iceberg_count_result_function" in actual.out
-
-            daft_count = df.collect().to_pydict()["count"][0]
-
-            # Compare with PyIceberg count
-            iceberg_count = len(tab.scan().to_arrow())
-
-            assert daft_count == iceberg_count
-        except Exception:
-            # If table doesn't exist or has issues, skip this test
-            pytest.skip("Empty table test requires specific table setup")
-
     @pytest.mark.integration()
     @pytest.mark.parametrize(
         "table_name",
-        [
-            "test_partitioned_by_identity",
-            "test_partitioned_by_bucket",
-            "test_partitioned_by_days",
-            "test_partitioned_by_years",
-        ],
+        WORKING_SHOW_COLLECT,
     )
-    def test_count_pushdown_partitioned_tables(self, table_name, local_iceberg_catalog, capsys):
+    def test_count_pushdown_basic(self, table_name, local_iceberg_catalog, capsys):
         """Test count pushdown on partitioned tables."""
         catalog_name, pyiceberg_catalog = local_iceberg_catalog
         tab = pyiceberg_catalog.load_table(f"default.{table_name}")

@@ -345,7 +299,7 @@ def test_count_pushdown_with_column_selection(self, local_iceberg_catalog, capsys):

         # Test count with column selection (should still use pushdown)
         df = daft.read_table(f"{catalog_name}.default.test_all_types")
-        df = df.select("id") if "id" in df.column_names else df.select(df.column_names[0]).count()
+        df = df.select("id").count() if "id" in df.column_names else df.select(df.column_names[0]).count()

         _ = capsys.readouterr()
         df.explain(True)

@@ -407,25 +361,3 @@ def test_count_pushdown_snapshot_consistency(self, local_iceberg_catalog, capsys):
         except Exception:
             # If snapshotting table doesn't exist, skip this test
             pytest.skip("Snapshot test requires test_snapshotting table")
-
-    @pytest.mark.integration()
-    @pytest.mark.parametrize("table_name", ["test_positional_mor_deletes", "test_positional_mor_double_deletes"])
-    def test_count_pushdown_with_deletes(self, table_name, local_iceberg_catalog, capsys):
-        """Test count pushdown on tables with MOR (Merge-On-Read) deletes."""
-        catalog_name, pyiceberg_catalog = local_iceberg_catalog
-        tab = pyiceberg_catalog.load_table(f"default.{table_name}")
-
-        # Test Daft count on table with deletes
-        df = daft.read_table(f"{catalog_name}.default.{table_name}").count()
-
-        _ = capsys.readouterr()
-        df.explain(True)
-        actual = capsys.readouterr()
-        assert "daft.io.iceberg.iceberg_scan:_iceberg_count_result_function" in actual.out
-
-        daft_count = df.collect().to_pydict()["count"][0]
-
-        # Compare with PyIceberg count (should account for deletes)
-        iceberg_count = len(tab.scan().to_arrow())
-
-        assert daft_count == iceberg_count
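
For a quick manual check outside of pytest, the same pushdown can be exercised interactively. The catalog prefix below is a placeholder for whatever catalog your session has attached (the tests obtain theirs from the local_iceberg_catalog fixture):

    import daft

    # Placeholder catalog name; attach or register your own Iceberg catalog first.
    df = daft.read_table("my_catalog.default.test_overlapping_deletes").count()

    # When the pushdown applies, the explained plan references the metadata-based factory function
    # daft.io.iceberg.iceberg_scan:_iceberg_count_result_function.
    df.explain(True)

    print(df.collect().to_pydict()["count"][0])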
