Commit d62a475

Backward-compatible partition handling
1 parent f69869c commit d62a475

3 files changed: +75, -4 lines changed

3 files changed

+75
-4
lines changed

awswrangler/distributed/_distributed.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ def initialize_ray(
     cpu_count : Optional[int]
         Number of CPUs to assign to each raylet, by default None
     gpu_count : Optional[int]
-        Number of GPUs to assign to each raylet, by default 0
+        Number of GPUs to assign to each raylet, by default None
     """
     if not ray.is_initialized():
         # Detect an existing cluster
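
This hunk only corrects the documented default for gpu_count. For context, a minimal usage sketch; the import path and the availability of the distributed add-on are assumptions, not confirmed by this diff:

# Hypothetical usage sketch, assuming initialize_ray is exposed from
# awswrangler.distributed and that both counts are optional as documented.
from awswrangler.distributed import initialize_ray

# Leaving both counts as None (the documented default) defers the per-raylet
# CPU/GPU assignment to Ray itself.
initialize_ray(cpu_count=None, gpu_count=None)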

awswrangler/distributed/datasources/parquet_datasource.py

Lines changed: 72 additions & 2 deletions
@@ -1,25 +1,95 @@
 """Distributed ParquetDatasource Module."""
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, Iterator, List, Optional, Union
 
 # fs required to implicitly trigger S3 subsystem initialization
 import pyarrow.fs # noqa: F401 pylint: disable=unused-import
+import ray
+from ray.data._internal.output_buffer import BlockOutputBuffer
 from ray.data._internal.remote_fn import cached_remote_fn
 from ray.data.block import Block, BlockAccessor, BlockMetadata
+from ray.data.context import DatasetContext
 from ray.data.datasource import BlockWritePathProvider, DefaultBlockWritePathProvider, Reader
 from ray.data.datasource.datasource import WriteResult
 from ray.data.datasource.file_based_datasource import (
     _resolve_paths_and_filesystem,
     _S3FileSystemWrapper,
     _wrap_s3_serialization_workaround,
 )
-from ray.data.datasource.parquet_datasource import _ParquetDatasourceReader
+from ray.data.datasource.parquet_datasource import (
+    PARQUET_READER_ROW_BATCH_SIZE,
+    _deserialize_pieces_with_retry,
+    _ParquetDatasourceReader,
+    _SerializedPiece,
+)
 from ray.types import ObjectRef
 
+from awswrangler._arrow import _add_table_partitions
+
 _logger: logging.Logger = logging.getLogger(__name__)
 
 
+def _read_pieces(
+    block_udf: Optional[Callable[[Block[Any]], Block[Any]]],
+    reader_args: Any,
+    columns: Optional[List[str]],
+    schema: Optional[Union[type, "pyarrow.lib.Schema"]],
+    serialized_pieces: List[_SerializedPiece],
+) -> Iterator["pyarrow.Table"]:
+    # This import is necessary to load the tensor extension type.
+    from ray.data.extensions.tensor_extension import ( # type: ignore # noqa: F401, E501 # pylint: disable=import-outside-toplevel, unused-import
+        ArrowTensorType,
+    )
+
+    # Deserialize after loading the filesystem class.
+    pieces: List["pyarrow._dataset.ParquetFileFragment"] = _deserialize_pieces_with_retry(serialized_pieces)
+
+    # Ensure that we're reading at least one dataset fragment.
+    assert len(pieces) > 0
+
+    import pyarrow as pa # pylint: disable=import-outside-toplevel
+
+    ctx = DatasetContext.get_current()
+    output_buffer = BlockOutputBuffer(
+        block_udf=block_udf,
+        target_max_block_size=ctx.target_max_block_size,
+    )
+
+    _logger.debug("Reading %s parquet pieces", len(pieces))
+    use_threads = reader_args.pop("use_threads", False)
+    path_root = reader_args.pop("path_root", None)
+    for piece in pieces:
+        batches = piece.to_batches(
+            use_threads=use_threads,
+            columns=columns,
+            schema=schema,
+            batch_size=PARQUET_READER_ROW_BATCH_SIZE,
+            **reader_args,
+        )
+        for batch in batches:
+            # Table creation is wrapped inside _add_table_partitions
+            # to add columns with partition values when dataset=True
+            # and cast them to categorical
+            table = _add_table_partitions(
+                table=pa.Table.from_batches([batch], schema=schema),
+                path=f"s3://{piece.path}",
+                path_root=path_root,
+            )
+            # If the table is empty, drop it.
+            if table.num_rows > 0:
+                output_buffer.add_block(table)
+                if output_buffer.has_next():
+                    yield output_buffer.next()
+    output_buffer.finalize()
+    if output_buffer.has_next():
+        yield output_buffer.next()
+
+
+# Patch _read_pieces function
+ray.data.datasource.parquet_datasource._read_pieces = _read_pieces # pylint: disable=protected-access
+
+
 class UserProvidedKeyBlockWritePathProvider(BlockWritePathProvider):
     """Block write path provider.
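
The patched _read_pieces mirrors the upstream Ray reader except that pa.Table.from_batches is wrapped in _add_table_partitions, so partition values encoded in each fragment's S3 path are re-attached as columns when path_root is provided. The real helper lives in awswrangler._arrow and also casts the columns to categorical; the following is only a minimal, self-contained sketch of the idea, assuming Hive-style key=value path segments:

from typing import Optional

import pyarrow as pa


def _add_partitions_sketch(table: pa.Table, path: str, path_root: Optional[str]) -> pa.Table:
    """Hypothetical stand-in for awswrangler._arrow._add_table_partitions (illustration only)."""
    if not path_root:
        return table
    # Take the path segments between the dataset root and the file name,
    # e.g. "year=2022/month=06" for s3://bucket/ds/year=2022/month=06/part-0.parquet.
    relative = path[len(path_root):].strip("/")
    for segment in relative.split("/")[:-1]:
        if "=" in segment:
            key, value = segment.split("=", 1)
            # Append the partition value as a constant column.
            table = table.append_column(key, pa.array([value] * table.num_rows))
    return table

With path_root="s3://bucket/ds", a fragment at s3://bucket/ds/year=2022/month=06/part-0.parquet would gain string columns "year" and "month" under this sketch.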

awswrangler/s3/_read_parquet.py

Lines changed: 2 additions & 1 deletion
@@ -350,6 +350,7 @@ def _read_parquet(
             schema=schema,
             columns=columns,
             dataset_kwargs=dataset_kwargs,
+            path_root=path_root,
         )
         return _to_modin(dataset=dataset, to_pandas_kwargs=arrow_kwargs)
 
@@ -475,7 +476,7 @@ def read_parquet(
         If integer is provided, specified number is used.
     parallelism : int, optional
         The requested parallelism of the read. Only used when `distributed` add-on is installed.
-        Parallelism may be limited by the number of files of the dataset. 200 by default.
+        Parallelism may be limited by the number of files of the dataset. -1 (autodetect) by default.
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session is used if None is received.
     s3_additional_kwargs : Optional[Dict[str, Any]]
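
Threading path_root into the datasource is what lets the patched reader restore partition columns in the distributed code path. A hedged usage sketch follows; the bucket and layout are placeholders, and the partition columns named in the comments are what a Hive-partitioned dataset would be expected to yield:

import awswrangler as wr

# Placeholder dataset laid out as s3://my-bucket/my-dataset/year=2022/month=06/...
df = wr.s3.read_parquet(
    path="s3://my-bucket/my-dataset/",
    dataset=True,  # treat the prefix as a dataset root so partition columns are added back
)
print(df.dtypes)  # expect "year" and "month" among the columns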
