"""Distributed ParquetDatasource Module."""

import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, Iterator, List, Optional, Union

# fs required to implicitly trigger S3 subsystem initialization
import pyarrow.fs  # noqa: F401 pylint: disable=unused-import

+import ray
+from ray.data._internal.output_buffer import BlockOutputBuffer
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data.block import Block, BlockAccessor, BlockMetadata
+from ray.data.context import DatasetContext
from ray.data.datasource import BlockWritePathProvider, DefaultBlockWritePathProvider, Reader
from ray.data.datasource.datasource import WriteResult
from ray.data.datasource.file_based_datasource import (
    _resolve_paths_and_filesystem,
    _S3FileSystemWrapper,
    _wrap_s3_serialization_workaround,
)
-from ray.data.datasource.parquet_datasource import _ParquetDatasourceReader
+from ray.data.datasource.parquet_datasource import (
+    PARQUET_READER_ROW_BATCH_SIZE,
+    _deserialize_pieces_with_retry,
+    _ParquetDatasourceReader,
+    _SerializedPiece,
+)
from ray.types import ObjectRef

+from awswrangler._arrow import _add_table_partitions
+
_logger: logging.Logger = logging.getLogger(__name__)


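+# Descriptive note: this function mirrors Ray's module-level _read_pieces, reading the
+# serialized Parquet fragments into Arrow tables and adding partition-value columns via
+# awswrangler's _add_table_partitions before buffering the tables into output blocks.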
+def _read_pieces(
+    block_udf: Optional[Callable[[Block[Any]], Block[Any]]],
+    reader_args: Any,
+    columns: Optional[List[str]],
+    schema: Optional[Union[type, "pyarrow.lib.Schema"]],
+    serialized_pieces: List[_SerializedPiece],
+) -> Iterator["pyarrow.Table"]:
+    # This import is necessary to load the tensor extension type.
+    from ray.data.extensions.tensor_extension import (  # type: ignore # noqa: F401, E501 # pylint: disable=import-outside-toplevel, unused-import
+        ArrowTensorType,
+    )
+
+    # Deserialize after loading the filesystem class.
+    pieces: List["pyarrow._dataset.ParquetFileFragment"] = _deserialize_pieces_with_retry(serialized_pieces)
+
+    # Ensure that we're reading at least one dataset fragment.
+    assert len(pieces) > 0
+
+    import pyarrow as pa  # pylint: disable=import-outside-toplevel
+
+    ctx = DatasetContext.get_current()
+    output_buffer = BlockOutputBuffer(
+        block_udf=block_udf,
+        target_max_block_size=ctx.target_max_block_size,
+    )
+
+    _logger.debug("Reading %s parquet pieces", len(pieces))
+    use_threads = reader_args.pop("use_threads", False)
+    path_root = reader_args.pop("path_root", None)
+    for piece in pieces:
+        batches = piece.to_batches(
+            use_threads=use_threads,
+            columns=columns,
+            schema=schema,
+            batch_size=PARQUET_READER_ROW_BATCH_SIZE,
+            **reader_args,
+        )
+        for batch in batches:
+            # Table creation is wrapped inside _add_table_partitions
+            # to add columns with partition values when dataset=True
+            # and cast them to categorical
+            table = _add_table_partitions(
+                table=pa.Table.from_batches([batch], schema=schema),
+                path=f"s3://{piece.path}",
+                path_root=path_root,
+            )
+            # If the table is empty, drop it.
+            if table.num_rows > 0:
+                output_buffer.add_block(table)
+                if output_buffer.has_next():
+                    yield output_buffer.next()
+    output_buffer.finalize()
+    if output_buffer.has_next():
+        yield output_buffer.next()
+
+
+# Patch _read_pieces function
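+# _ParquetDatasourceReader's read tasks look up _read_pieces on the
+# ray.data.datasource.parquet_datasource module when they execute, so reassigning the
+# module attribute below routes those reads through the partition-aware reader above.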
+ray.data.datasource.parquet_datasource._read_pieces = _read_pieces  # pylint: disable=protected-access
+
+
class UserProvidedKeyBlockWritePathProvider(BlockWritePathProvider):
    """Block write path provider.
