Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions awswrangler/distributed/ray/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,27 +51,46 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return wrapper


def ray_remote(function: Callable[..., Any]) -> Callable[..., Any]:
def ray_remote(**options: Any) -> Callable[..., Any]:
"""
Decorate callable to wrap within ray.remote.
Decorate with @ray.remote providing .options().

Parameters
----------
function : Callable[..., Any]
Callable as input to ray.remote.
options : Any
Ray remote options

Returns
-------
Callable[..., Any]
"""
# Access the source function if it exists
function = getattr(function, "_source_func", function)

@wraps(function)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return ray.remote(ray_logger(function)).remote(*args, **kwargs) # type: ignore
def remote_decorator(function: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorate callable to wrap within ray.remote.

return wrapper
Parameters
----------
function : Callable[..., Any]
Callable as input to ray.remote.

Returns
-------
Callable[..., Any]
"""
# Access the source function if it exists
function = getattr(function, "_source_func", function)

@wraps(function)
def wrapper(*args: Any, **kwargs: Any) -> Any:
remote_fn = ray.remote(ray_logger(function))
if options:
remote_fn = remote_fn.options(**options)
return remote_fn.remote(*args, **kwargs) # type: ignore

return wrapper

return remote_decorator


def ray_get(futures: Union["ray.ObjectRef[Any]", List["ray.ObjectRef[Any]"]]) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion awswrangler/distributed/ray/_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class RayLogger:
def get_logger(self, name: Union[str, Any] = None) -> logging.Logger: ...

def ray_logger(function: Callable[..., Any]) -> Callable[..., Any]: ...
def ray_remote(function: Callable[..., Any]) -> Callable[..., Any]: ...
def ray_remote(**options: Any) -> Callable[..., Any]: ...
def ray_get(futures: List[Any]) -> Any: ...
def initialize_ray(
address: Optional[str] = None,
Expand Down
7 changes: 6 additions & 1 deletion awswrangler/distributed/ray/_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@ def register_ray() -> None:
_select_query,
_select_object_content,
_wait_object_batch,
]:
# Schedule for maximum concurrency
engine.register_func(func, ray_remote(scheduling_strategy="SPREAD")(func))

for func in [
_write_batch,
_write_df,
]:
engine.register_func(func, ray_remote(func))
engine.register_func(func, ray_remote()(func))

if memory_format.get() == MemoryFormatEnum.MODIN:
from awswrangler.distributed.ray.modin._data_types import pyarrow_types_from_pandas_distributed
Expand Down
Loading