Skip to content

Commit 4325e50

Browse files
committed
Fix - pass axis
1 parent 0967ac6 commit 4325e50

File tree

1 file changed

+10
-3
lines changed
  • awswrangler/distributed/ray/modin

1 file changed

+10
-3
lines changed

awswrangler/distributed/ray/modin/_core.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,16 @@ def _validate_partition(df: pd.DataFrame, n_columns: int) -> bool:
3131
return len(df.columns) == n_columns
3232

3333
# Unwrap partitions as they are currently stored (axis=None)
34+
# Partitions are a 2D array because the data frame is split along both row and column axis
3435
partitions: List[List[ray.types.ObjectRef[pd.DataFrame]]] = unwrap_partitions(df, axis=None)
3536
return all(
36-
ray.get([_validate_partition.remote(partition, len(df.columns)) for row in partitions for partition in row])
37+
ray.get(
38+
[
39+
_validate_partition.remote(partition, len(df.columns))
40+
for partitions_row in partitions
41+
for partition in partitions_row
42+
]
43+
)
3744
)
3845

3946

@@ -74,8 +81,8 @@ def wrapper(
7481
"The dataframe will be automatically repartitioned along row axis to ensure "
7582
"each partition can be processed independently."
7683
)
77-
df = from_partitions(unwrap_partitions(df, axis=0), axis=axis, row_lengths=row_lengths)
78-
elif axis is not None:
84+
axis = 0
85+
if axis is not None:
7986
df = from_partitions(unwrap_partitions(df, axis=axis), axis=axis, row_lengths=row_lengths)
8087
return function(df, *args, **kwargs)
8188

0 commit comments

Comments
 (0)