Skip to content

Commit 496168a

Browse files
authored
perf(flotilla): Use Worker Affinity with Pre-Shuffle Merge (#5112)
1 parent 58e9cb2 commit 496168a

File tree

10 files changed

+53
-10
lines changed

10 files changed

+53
-10
lines changed

src/daft-distributed/src/pipeline_node/into_batches.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ impl IntoBatchesNode {
114114
StatsState::NotMaterialized,
115115
)
116116
},
117+
None,
117118
)?;
118119
if result_tx.send(task).await.is_err() {
119120
break;
@@ -136,6 +137,7 @@ impl IntoBatchesNode {
136137
StatsState::NotMaterialized,
137138
)
138139
},
140+
None,
139141
)?;
140142
let _ = result_tx.send(task).await;
141143
}

src/daft-distributed/src/pipeline_node/into_partitions.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ impl IntoPartitionsNode {
9999
// Remainder: 10 % 3 = 1 (one task gets an extra input)
100100
let num_partitions_with_extra_input = tasks.len() % self.num_partitions;
101101

102-
let mut output_futures = OrderedJoinSet::new();
102+
let mut tasks_per_partition = Vec::new();
103103

104104
let mut task_iter = tasks.into_iter();
105105
for partition_idx in 0..self.num_partitions {
@@ -115,8 +115,13 @@ impl IntoPartitionsNode {
115115
.take(chunk_size)
116116
.map(|task| task.submit(scheduler_handle))
117117
.collect::<DaftResult<Vec<_>>>()?;
118+
tasks_per_partition.push(submitted_tasks);
119+
}
120+
121+
let mut output_futures = OrderedJoinSet::new();
122+
for tasks in tasks_per_partition {
118123
output_futures.spawn(async move {
119-
let materialized_output = futures::future::try_join_all(submitted_tasks)
124+
let materialized_output = futures::future::try_join_all(tasks)
120125
.await?
121126
.into_iter()
122127
.flatten()
@@ -136,6 +141,7 @@ impl IntoPartitionsNode {
136141
move |input| {
137142
LocalPhysicalPlan::into_partitions(input, 1, StatsState::NotMaterialized)
138143
},
144+
None,
139145
)?;
140146
if result_tx.send(task).await.is_err() {
141147
break;
@@ -167,7 +173,7 @@ impl IntoPartitionsNode {
167173
// Remainder: 10 % 3 = 1 (one partition will split into 4 outputs)
168174
let num_partitions_with_extra_output = self.num_partitions % tasks.len();
169175

170-
let mut output_futures = OrderedJoinSet::new();
176+
let mut submitted_tasks = Vec::new();
171177

172178
for (input_partition_idx, task) in tasks.into_iter().enumerate() {
173179
let mut num_outputs = base_splits_per_partition;
@@ -187,7 +193,12 @@ impl IntoPartitionsNode {
187193
},
188194
);
189195
let submitted_task = into_partitions_task.submit(scheduler_handle)?;
190-
output_futures.spawn(submitted_task);
196+
submitted_tasks.push(submitted_task);
197+
}
198+
199+
let mut output_futures = OrderedJoinSet::new();
200+
for task in submitted_tasks {
201+
output_futures.spawn(task);
191202
}
192203

193204
// Collect all the outputs and emit a new task for each output.
@@ -200,6 +211,7 @@ impl IntoPartitionsNode {
200211
TaskContext::from((&self.context, task_id_counter.next())),
201212
vec![output],
202213
&(self_arc as Arc<dyn DistributedPipelineNode>),
214+
None,
203215
)?;
204216
if result_tx.send(task).await.is_err() {
205217
break;

src/daft-distributed/src/pipeline_node/limit.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ impl LimitNode {
106106
input
107107
}
108108
},
109+
None,
109110
)?
110111
}
111112
Ordering::Greater => {
@@ -122,6 +123,7 @@ impl LimitNode {
122123
StatsState::NotMaterialized,
123124
)
124125
},
126+
None,
125127
)?;
126128
*remaining_take = 0;
127129
task

src/daft-distributed/src/pipeline_node/mod.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ fn make_new_task_from_materialized_outputs<F>(
317317
materialized_outputs: Vec<MaterializedOutput>,
318318
node: &Arc<dyn DistributedPipelineNode>,
319319
plan_builder: F,
320+
scheduling_strategy: Option<SchedulingStrategy>,
320321
) -> DaftResult<SubmittableTask<SwordfishTask>>
321322
where
322323
F: FnOnce(LocalPhysicalPlanRef) -> LocalPhysicalPlanRef + Send + Sync + 'static,
@@ -338,7 +339,7 @@ where
338339
plan,
339340
node.config().execution_config.clone(),
340341
psets,
341-
SchedulingStrategy::Spread,
342+
scheduling_strategy.unwrap_or(SchedulingStrategy::Spread),
342343
node.context().to_hashmap(),
343344
);
344345
Ok(SubmittableTask::new(task))
@@ -348,8 +349,15 @@ fn make_in_memory_task_from_materialized_outputs(
348349
task_context: TaskContext,
349350
materialized_outputs: Vec<MaterializedOutput>,
350351
node: &Arc<dyn DistributedPipelineNode>,
352+
scheduling_strategy: Option<SchedulingStrategy>,
351353
) -> DaftResult<SubmittableTask<SwordfishTask>> {
352-
make_new_task_from_materialized_outputs(task_context, materialized_outputs, node, |input| input)
354+
make_new_task_from_materialized_outputs(
355+
task_context,
356+
materialized_outputs,
357+
node,
358+
|input| input,
359+
scheduling_strategy,
360+
)
353361
}
354362

355363
fn append_plan_to_existing_task<F>(

src/daft-distributed/src/pipeline_node/shuffles/gather.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ impl GatherNode {
8383
TaskContext::from((&self_clone.context, task_id_counter.next())),
8484
materialized,
8585
&(self_clone as Arc<dyn DistributedPipelineNode>),
86+
None,
8687
)?;
8788

8889
let _ = result_tx.send(task).await;

src/daft-distributed/src/pipeline_node/shuffles/pre_shuffle_merge.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::{
1212
},
1313
scheduling::{
1414
scheduler::{SchedulerHandle, SubmittableTask},
15-
task::{SwordfishTask, TaskContext},
15+
task::{SchedulingStrategy, SwordfishTask, TaskContext},
1616
worker::WorkerId,
1717
},
1818
stage::{StageConfig, StageExecutionContext, TaskIDCounter},
@@ -172,6 +172,10 @@ impl PreShuffleMergeNode {
172172
TaskContext::from((self.context(), task_id_counter.next())),
173173
materialized_outputs,
174174
&(self_clone as Arc<dyn DistributedPipelineNode>),
175+
Some(SchedulingStrategy::WorkerAffinity {
176+
worker_id,
177+
soft: false,
178+
}),
175179
)?;
176180

177181
// Send the task directly to result_tx
@@ -183,13 +187,17 @@ impl PreShuffleMergeNode {
183187
}
184188

185189
// Handle any remaining buckets that haven't reached the threshold
186-
for (_, materialized_outputs) in worker_buckets {
190+
for (worker_id, materialized_outputs) in worker_buckets {
187191
if !materialized_outputs.is_empty() {
188192
let self_clone = self.clone();
189193
let task = make_in_memory_task_from_materialized_outputs(
190194
TaskContext::from((self.context(), task_id_counter.next())),
191195
materialized_outputs,
192196
&(self_clone as Arc<dyn DistributedPipelineNode>),
197+
Some(SchedulingStrategy::WorkerAffinity {
198+
worker_id,
199+
soft: false,
200+
}),
193201
)?;
194202

195203
if result_tx.send(task).await.is_err() {

src/daft-distributed/src/pipeline_node/shuffles/repartition.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ impl RepartitionNode {
9999
TaskContext::from((&self_clone.context, task_id_counter.next())),
100100
partition_group,
101101
&(self_clone as Arc<dyn DistributedPipelineNode>),
102+
None,
102103
)?;
103104

104105
let _ = result_tx.send(task).await;

src/daft-distributed/src/pipeline_node/sink.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ impl SinkNode {
133133
StatsState::NotMaterialized,
134134
)
135135
},
136+
None,
136137
)?;
137138
let _ = sender.send(task).await;
138139
Ok(())

src/daft-distributed/src/pipeline_node/sort.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ impl SortNode {
211211
StatsState::NotMaterialized,
212212
)
213213
},
214+
None,
214215
)?;
215216
let _ = result_tx.send(task).await;
216217
return Ok(());
@@ -242,6 +243,7 @@ impl SortNode {
242243
StatsState::NotMaterialized,
243244
)
244245
},
246+
None,
245247
)?;
246248
let submitted_task = task.submit(&scheduler_handle)?;
247249
Ok(submitted_task)
@@ -281,6 +283,7 @@ impl SortNode {
281283
StatsState::NotMaterialized,
282284
)
283285
},
286+
None,
284287
)?;
285288
let submitted_task = task.submit(&scheduler_handle)?;
286289
Ok(submitted_task)
@@ -312,6 +315,7 @@ impl SortNode {
312315
StatsState::NotMaterialized,
313316
)
314317
},
318+
None,
315319
)?;
316320
let _ = result_tx.send(task).await;
317321
}

src/daft-shuffles/src/client/flight_client.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,12 @@ impl ShuffleFlightClient {
4040
let channel = endpoint.connect().await.map_err(|e| {
4141
DaftError::External(format!("Failed to connect to endpoint: {}", e).into())
4242
})?;
43-
self.inner =
44-
ClientState::Initialized(std::mem::take(address), FlightClient::new(channel));
43+
let client = FlightClient::new(channel);
44+
let inner = client.into_inner().max_decoding_message_size(usize::MAX);
45+
self.inner = ClientState::Initialized(
46+
std::mem::take(address),
47+
FlightClient::new_from_inner(inner),
48+
);
4549
}
4650
match &mut self.inner {
4751
ClientState::Uninitialized(_) => unreachable!("Client should be initialized"),

0 commit comments

Comments
 (0)