diff --git a/awswrangler/__init__.py b/awswrangler/__init__.py
index 254d2a74e..9f71fd021 100644
--- a/awswrangler/__init__.py
+++ b/awswrangler/__init__.py
@@ -32,7 +32,7 @@
 )
 from awswrangler.__metadata__ import __description__, __license__, __title__, __version__  # noqa
 from awswrangler._config import config  # noqa
-from awswrangler._distributed import initialize_ray
+from awswrangler.distributed import initialize_ray
 
 if config.distributed:
     initialize_ray()
diff --git a/awswrangler/_threading.py b/awswrangler/_threading.py
new file mode 100644
index 000000000..40a433337
--- /dev/null
+++ b/awswrangler/_threading.py
@@ -0,0 +1,40 @@
+"""Threading Module (PRIVATE)."""
+
+import concurrent.futures
+import itertools
+import logging
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
+
+import boto3
+
+from awswrangler import _utils
+from awswrangler._config import config
+
+if TYPE_CHECKING or config.distributed:
+    from awswrangler.distributed._pool import _RayPoolExecutor
+
+_logger: logging.Logger = logging.getLogger(__name__)
+
+
+def _get_executor(use_threads: Union[bool, int]) -> Union["_ThreadPoolExecutor", "_RayPoolExecutor"]:
+    executor = _RayPoolExecutor if config.distributed else _ThreadPoolExecutor
+    return executor(use_threads)  # type: ignore
+
+
+class _ThreadPoolExecutor:
+    def __init__(self, use_threads: Union[bool, int]):
+        super().__init__()
+        self._exec: Optional[concurrent.futures.ThreadPoolExecutor] = None
+        self._cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
+        if self._cpus > 1:
+            self._exec = concurrent.futures.ThreadPoolExecutor(max_workers=self._cpus)  # pylint: disable=R1732
+
+    def map(self, func: Callable[..., List[str]], boto3_session: boto3.Session, *iterables: Any) -> List[Any]:
+        _logger.debug("Map: %s", func)
+        if self._exec is not None:
+            # Convert the boto3 session into picklable primitives before handing it to the worker threads
+            boto3_primitives = _utils.boto3_to_primitives(boto3_session=boto3_session)
+            args = (itertools.repeat(boto3_primitives), *iterables)
+            return list(self._exec.map(func, *args))
+        # Single-threaded
+        return list(map(func, *(itertools.repeat(boto3_session), *iterables)))  # type: ignore
diff --git a/awswrangler/_utils.py b/awswrangler/_utils.py
index deb979dd6..9aeea443d 100644
--- a/awswrangler/_utils.py
+++ b/awswrangler/_utils.py
@@ -8,16 +8,22 @@
 import random
 import time
 from concurrent.futures import FIRST_COMPLETED, Future, wait
-from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union, cast
 
 import boto3
 import botocore.config
 import numpy as np
 import pandas as pd
+import pyarrow as pa
 
 from awswrangler import _config, exceptions
 from awswrangler.__metadata__ import __version__
-from awswrangler._config import apply_configs
+from awswrangler._config import apply_configs, config
+
+if TYPE_CHECKING or config.distributed:
+    import ray
+
+    from awswrangler.distributed._utils import _arrow_refs_to_df
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
@@ -401,3 +407,18 @@ def check_schema_changes(columns_types: Dict[str, str], table_input: Optional[Di
                     f"Schema change detected: Data type change on column {c} "
                     f"(Old type: {catalog_cols[c]} / New type {t})."
                 )
+
+
+def pylist_to_arrow(mapping: List[Dict[str, Any]]) -> pa.Table:
+    names = list(mapping[0].keys()) if mapping else []
+    arrays = []
+    for n in names:
+        v = [row[n] if n in row else None for row in mapping]
+        arrays.append(v)
+    return pa.Table.from_arrays(arrays, names)
+
+
+def table_refs_to_df(tables: Union[List[pa.Table], List["ray.ObjectRef"]], kwargs: Dict[str, Any]) -> pd.DataFrame:  # type: ignore  # noqa: E501
+    if isinstance(tables[0], pa.Table):
+        return ensure_df_is_mutable(pa.concat_tables(tables).to_pandas(**kwargs))
+    return _arrow_refs_to_df(arrow_refs=tables, kwargs=kwargs)  # type: ignore
diff --git a/awswrangler/distributed/__init__.py b/awswrangler/distributed/__init__.py
index debacce0f..bbb649cc5 100644
--- a/awswrangler/distributed/__init__.py
+++ b/awswrangler/distributed/__init__.py
@@ -1 +1,8 @@
 """Distributed Module."""
+
+from awswrangler.distributed._distributed import initialize_ray, ray_remote  # noqa
+
+__all__ = [
+    "ray_remote",
+    "initialize_ray",
+]
diff --git a/awswrangler/_distributed.py b/awswrangler/distributed/_distributed.py
similarity index 100%
rename from awswrangler/_distributed.py
rename to awswrangler/distributed/_distributed.py
diff --git a/awswrangler/distributed/_pool.py b/awswrangler/distributed/_pool.py
new file mode 100644
index 000000000..5348f1972
--- /dev/null
+++ b/awswrangler/distributed/_pool.py
@@ -0,0 +1,23 @@
+"""Ray Pool Module (PRIVATE)."""
+
+import itertools
+import logging
+from typing import Any, Callable, List, Optional, Union
+
+import boto3
+from ray.util.multiprocessing import Pool
+
+_logger: logging.Logger = logging.getLogger(__name__)
+
+
+class _RayPoolExecutor:
+    def __init__(self, processes: Optional[Union[bool, int]] = None):
+        self._exec: Pool = Pool(processes=None if isinstance(processes, bool) else processes)
+
+    def map(self, func: Callable[..., List[str]], _: boto3.Session, *args: Any) -> List[Any]:
+        futures = []
+        _logger.debug("Ray map: %s", func)
+        # Discard the boto3.Session object & call the function asynchronously
+        for arg in zip(itertools.repeat(None), *args):
+            futures.append(self._exec.apply_async(func, arg))
+        return [f.get() for f in futures]
diff --git a/awswrangler/distributed/_utils.py b/awswrangler/distributed/_utils.py
index d0a956b35..8b70b7e04 100644
--- a/awswrangler/distributed/_utils.py
+++ b/awswrangler/distributed/_utils.py
@@ -3,9 +3,9 @@
 from typing import Any, Callable, Dict, List, Optional
 
 import modin.pandas as pd
+import pyarrow as pa
 import ray
 from modin.distributed.dataframe.pandas.partitions import from_partitions
-from pyarrow import Table
 from ray.data.impl.arrow_block import ArrowBlockAccessor
 from ray.data.impl.remote_fn import cached_remote_fn
 
@@ -14,7 +14,7 @@ def _block_to_df(
     block: Any,
     kwargs: Dict[str, Any],
     dtype: Optional[Dict[str, str]] = None,
-) -> Table:
+) -> pa.Table:
     block = ArrowBlockAccessor.for_block(block)
     df = block._table.to_pandas(**kwargs)  # pylint: disable=protected-access
     return df.astype(dtype=dtype) if dtype else df
diff --git a/awswrangler/lakeformation/_read.py b/awswrangler/lakeformation/_read.py
index 00b741feb..c5b6bf674 100644
--- a/awswrangler/lakeformation/_read.py
+++ b/awswrangler/lakeformation/_read.py
@@ -1,30 +1,27 @@
 """Amazon Lake Formation Module gathering all read functions."""
-import concurrent.futures
 import itertools
 import logging
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import boto3
 import pandas as pd
-from pyarrow import NativeFile, RecordBatchStreamReader, Table, concat_tables
+from pyarrow import NativeFile, RecordBatchStreamReader, Table
 
 from awswrangler import _data_types, _utils, catalog
-from awswrangler._config import apply_configs, config
-from awswrangler._distributed import ray_remote
+from awswrangler._config import apply_configs
+from awswrangler._threading import _get_executor
 from awswrangler.catalog._utils import _catalog_id, _transaction_id
+from awswrangler.distributed import ray_remote
 from awswrangler.lakeformation._utils import commit_transaction, start_transaction, wait_query
 
-if config.distributed:
-    from awswrangler.distributed._utils import _arrow_refs_to_df
-
 _logger: logging.Logger = logging.getLogger(__name__)
 
 
 @ray_remote
 def _get_work_unit_results(
+    boto3_session: Optional[boto3.Session],
     query_id: str,
     token_work_unit: Tuple[str, int],
-    boto3_session: Optional[boto3.Session] = None,
 ) -> Table:
     _logger.debug("Query id: %s Token work unit: %s", query_id, token_work_unit)
     client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=boto3_session)
@@ -66,8 +63,7 @@ def _resolve_sql_query(
         )
         next_token = response.get("NextToken", None)
         scan_kwargs["NextToken"] = next_token
-
-    tables: List[Table] = []
+    executor = _get_executor(use_threads=use_threads)
     kwargs = {
         "use_threads": use_threads,
         "split_blocks": True,
@@ -80,39 +76,13 @@ def _resolve_sql_query(
         "safe": safe,
         "types_mapper": _data_types.pyarrow2pandas_extension if map_types else None,
     }
-    if config.distributed:
-        return _arrow_refs_to_df(
-            list(
-                _get_work_unit_results(
-                    query_id=query_id,
-                    token_work_unit=token_work_unit,
-                )
-                for token_work_unit in token_work_units
-            ),
-            kwargs=kwargs,
-        )
-    if use_threads is False:
-        tables = list(
-            _get_work_unit_results(
-                query_id=query_id,
-                token_work_unit=token_work_unit,
-                boto3_session=boto3_session,
-            )
-            for token_work_unit in token_work_units
-        )
-    else:
-        cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
-        with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor:
-            tables = list(
-                executor.map(
-                    _get_work_unit_results,
-                    itertools.repeat(query_id),
-                    token_work_units,
-                    itertools.repeat(boto3_session),
-                )
-            )
-    table = concat_tables(tables)
-    return _utils.ensure_df_is_mutable(df=table.to_pandas(**kwargs))
+    tables = executor.map(
+        _get_work_unit_results,
+        boto3_session,
+        itertools.repeat(query_id),
+        token_work_units,
+    )
+    return _utils.table_refs_to_df(tables=tables, kwargs=kwargs)
 
 
 @apply_configs
diff --git a/awswrangler/s3/_select.py b/awswrangler/s3/_select.py
index 62d9a8f58..78d2b0601 100644
--- a/awswrangler/s3/_select.py
+++ b/awswrangler/s3/_select.py
@@ -1,18 +1,25 @@
 """Amazon S3 Select Module (PRIVATE)."""
 
-import concurrent.futures
 import itertools
 import json
 import logging
 import pprint
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
 
 import boto3
 import pandas as pd
+import pyarrow as pa
 
 from awswrangler import _utils, exceptions
+from awswrangler._config import config
+from awswrangler._threading import _get_executor
+from awswrangler._utils import pylist_to_arrow, table_refs_to_df
+from awswrangler.distributed import ray_remote
 from awswrangler.s3._describe import size_objects
 
+if TYPE_CHECKING or config.distributed:
+    import ray
+
 _logger: logging.Logger = logging.getLogger(__name__)
 
 _RANGE_CHUNK_SIZE: int = int(1024 * 1024)
@@ -23,11 +30,12 @@ def _gen_scan_range(obj_size: int) -> Iterator[Tuple[int, int]]:
         yield (i, i + min(_RANGE_CHUNK_SIZE, obj_size - i))
 
 
+@ray_remote
 def _select_object_content(
-    args: Dict[str, Any],
     boto3_session: Optional[boto3.Session],
+    args: Dict[str, Any],
     scan_range: Optional[Tuple[int, int]] = None,
-) -> List[Dict[str, Any]]:
+) -> Union[pa.Table, "ray.ObjectRef"]:  # type: ignore
     client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session)
 
     if scan_range:
@@ -44,12 +52,15 @@ def _select_object_content(
             # Record end can either be a partial record or a return char
             partial_record = records.pop()
             payload_records.extend([json.loads(record) for record in records])
-    return payload_records
+    return pylist_to_arrow(payload_records)
 
 
 def _paginate_stream(
-    args: Dict[str, Any], path: str, use_threads: Union[bool, int], boto3_session: Optional[boto3.Session]
-) -> pd.DataFrame:
+    args: Dict[str, Any],
+    path: str,
+    executor: Any,
+    boto3_session: Optional[boto3.Session],
+) -> List[pa.Table]:
     obj_size: int = size_objects(  # type: ignore
         path=[path],
         use_threads=False,
@@ -58,28 +69,12 @@
     if obj_size is None:
         raise exceptions.InvalidArgumentValue(f"S3 object w/o defined size: {path}")
     scan_ranges = _gen_scan_range(obj_size=obj_size)
-
-    if use_threads is False:
-        stream_records = list(
-            _select_object_content(
-                args=args,
-                boto3_session=boto3_session,
-                scan_range=scan_range,
-            )
-            for scan_range in scan_ranges
-        )
-    else:
-        cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
-        with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor:
-            stream_records = list(
-                executor.map(
-                    _select_object_content,
-                    itertools.repeat(args),
-                    itertools.repeat(boto3_session),
-                    scan_ranges,
-                )
-            )
-    return pd.DataFrame([item for sublist in stream_records for item in sublist])  # Flatten list of lists
+    return executor.map(  # type: ignore
+        _select_object_content,
+        boto3_session,
+        itertools.repeat(args),
+        scan_ranges,
+    )
 
 
 def select_query(
@@ -195,7 +190,7 @@ def select_query(
     if s3_additional_kwargs:
         args.update(s3_additional_kwargs)
     _logger.debug("args:\n%s", pprint.pformat(args))
-
+    executor = _get_executor(use_threads=use_threads)
     if any(
         [
             compression,
@@ -205,5 +200,20 @@
     ):  # Scan range is only supported for uncompressed CSV/JSON, CSV (without quoted delimiters)
         # and JSON objects (in LINES mode only)
         _logger.debug("Scan ranges are not supported given provided input.")
-        return pd.DataFrame(_select_object_content(args=args, boto3_session=boto3_session))
-    return _paginate_stream(args=args, path=path, use_threads=use_threads, boto3_session=boto3_session)
+        tables = [_select_object_content(boto3_session=boto3_session, args=args)]
+    else:
+        tables = _paginate_stream(args=args, path=path, executor=executor, boto3_session=boto3_session)
+    kwargs = {
+        "use_threads": use_threads,
+        "split_blocks": True,
+        "self_destruct": True,
+        "integer_object_nulls": False,
+        "date_as_object": True,
+        "ignore_metadata": True,
+        "strings_to_categorical": False,
+        # TODO: Additional pyarrow args to consider
+        # "categories": categories,
+        # "safe": safe,
+        # "types_mapper": _data_types.pyarrow2pandas_extension if map_types else None,
+    }
+    return table_refs_to_df(tables, kwargs=kwargs)
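
Usage sketch (illustrative only, not part of the patch): how a caller is expected to combine the _get_executor() and table_refs_to_df() helpers introduced above. The worker and wrapper names below (_fetch_chunk, read_chunks) are hypothetical stand-ins for real call sites such as _get_work_unit_results or _select_object_content.

    import itertools

    import boto3
    import pyarrow as pa

    from awswrangler._threading import _get_executor
    from awswrangler._utils import table_refs_to_df


    def _fetch_chunk(boto3_session, query_id, chunk_id):
        # First positional argument follows the executor contract: the thread pool
        # passes boto3 session primitives, while _RayPoolExecutor replaces it with None.
        return pa.Table.from_pydict({"query_id": [query_id], "chunk": [chunk_id]})


    def read_chunks(query_id, chunk_ids, use_threads=True):
        # _get_executor returns _ThreadPoolExecutor locally, or _RayPoolExecutor when
        # config.distributed is set; both expose map(func, boto3_session, *iterables).
        executor = _get_executor(use_threads=use_threads)
        tables = executor.map(
            _fetch_chunk,
            boto3.Session(),
            itertools.repeat(query_id),
            chunk_ids,
        )
        # With the thread pool this is a list of pyarrow.Table; the @ray_remote-decorated
        # workers in this diff may instead yield ray.ObjectRef handles, and
        # table_refs_to_df resolves either case into a single mutable pandas DataFrame.
        return table_refs_to_df(tables, kwargs={})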