Skip to content

Commit 0a3e4d5

Browse files
authored
(enhance) configure scheduling options, remove dependencies on internal ray impl (#1734)
* Update ray remote decorator to pass options
* Remove dependency on ray._internal methods
* Add scheduling parameters
* Remove pack scheduling; typing/linting fixes
* Linting/comments
* Fix write_block remote fn call
* Remove default scheduling strategy
* Remove _open_input_source overload as it's not used by ParquetReader
* Clarify comments
1 parent 4fd0914 commit 0a3e4d5

File tree

9 files changed

+403
-87
lines changed

9 files changed

+403
-87
lines changed

awswrangler/distributed/ray/_core.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,27 +51,46 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
5151
return wrapper
5252

5353

54-
def ray_remote(function: Callable[..., Any]) -> Callable[..., Any]:
54+
def ray_remote(**options: Any) -> Callable[..., Any]:
5555
"""
56-
Decorate callable to wrap within ray.remote.
56+
Decorate with @ray.remote providing .options().
5757
5858
Parameters
5959
----------
60-
function : Callable[..., Any]
61-
Callable as input to ray.remote.
60+
options : Any
61+
Ray remote options
6262
6363
Returns
6464
-------
6565
Callable[..., Any]
6666
"""
67-
# Access the source function if it exists
68-
function = getattr(function, "_source_func", function)
6967

70-
@wraps(function)
71-
def wrapper(*args: Any, **kwargs: Any) -> Any:
72-
return ray.remote(ray_logger(function)).remote(*args, **kwargs) # type: ignore
68+
def remote_decorator(function: Callable[..., Any]) -> Callable[..., Any]:
69+
"""
70+
Decorate callable to wrap within ray.remote.
7371
74-
return wrapper
72+
Parameters
73+
----------
74+
function : Callable[..., Any]
75+
Callable as input to ray.remote.
76+
77+
Returns
78+
-------
79+
Callable[..., Any]
80+
"""
81+
# Access the source function if it exists
82+
function = getattr(function, "_source_func", function)
83+
84+
@wraps(function)
85+
def wrapper(*args: Any, **kwargs: Any) -> Any:
86+
remote_fn = ray.remote(ray_logger(function))
87+
if options:
88+
remote_fn = remote_fn.options(**options)
89+
return remote_fn.remote(*args, **kwargs) # type: ignore
90+
91+
return wrapper
92+
93+
return remote_decorator
7594

7695

7796
def ray_get(futures: Union["ray.ObjectRef[Any]", List["ray.ObjectRef[Any]"]]) -> Any:

awswrangler/distributed/ray/_core.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class RayLogger:
1313
def get_logger(self, name: Union[str, Any] = None) -> logging.Logger: ...
1414

1515
def ray_logger(function: Callable[..., Any]) -> Callable[..., Any]: ...
16-
def ray_remote(function: Callable[..., Any]) -> Callable[..., Any]: ...
16+
def ray_remote(**options: Any) -> Callable[..., Any]: ...
1717
def ray_get(futures: List[Any]) -> Any: ...
1818
def initialize_ray(
1919
address: Optional[str] = None,

awswrangler/distributed/ray/_register.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,15 @@ def register_ray() -> None:
2727
_select_query,
2828
_select_object_content,
2929
_wait_object_batch,
30+
]:
31+
# Schedule for maximum concurrency
32+
engine.register_func(func, ray_remote(scheduling_strategy="SPREAD")(func))
33+
34+
for func in [
3035
_write_batch,
3136
_write_df,
3237
]:
33-
engine.register_func(func, ray_remote(func))
38+
engine.register_func(func, ray_remote()(func))
3439

3540
if memory_format.get() == MemoryFormatEnum.MODIN:
3641
from awswrangler.distributed.ray.modin._data_types import pyarrow_types_from_pandas_distributed

0 commit comments

Comments
 (0)