
Commit 10369cb

feat: Flotilla into partitions (#4963)
## Changes Made

Implements `into_partitions` for Flotilla. The operator first collects all input tasks, then either coalesces or splits them so that exactly `num_partitions` partitions are emitted from this operator.

## Checklist

- [ ] Documented in API Docs (if applicable)
- [ ] Documented in User Guide (if applicable)
- [ ] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation
- [ ] Documentation builds and is formatted properly (tag @/ccmao1130 for docs review)
1 parent fe0d4bf commit 10369cb
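For orientation, the decision described in the summary above can be sketched as a small standalone function. This is an illustrative sketch only: `Plan` and `plan_into_partitions` are hypothetical names, not part of this commit; the real node in the diff below operates on streams of `SubmittableTask<SwordfishTask>`.

```rust
use std::cmp::Ordering;

/// Hypothetical stand-in for the choice made in `execute_into_partitions`.
enum Plan {
    /// Exactly the requested number of inputs: emit the tasks unchanged.
    PassThrough,
    /// Too many inputs: group them so each output gets an even share.
    Coalesce { base_inputs_per_output: usize, outputs_with_extra_input: usize },
    /// Too few inputs: split each input into several outputs.
    Split { base_outputs_per_input: usize, inputs_with_extra_output: usize },
}

fn plan_into_partitions(num_input_tasks: usize, num_partitions: usize) -> Plan {
    match num_input_tasks.cmp(&num_partitions) {
        Ordering::Equal => Plan::PassThrough,
        Ordering::Greater => Plan::Coalesce {
            base_inputs_per_output: num_input_tasks / num_partitions,
            outputs_with_extra_input: num_input_tasks % num_partitions,
        },
        Ordering::Less => Plan::Split {
            base_outputs_per_input: num_partitions / num_input_tasks,
            inputs_with_extra_output: num_partitions % num_input_tasks,
        },
    }
}
```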

File tree

13 files changed: +594 -46 lines

Lines changed: 305 additions & 0 deletions
@@ -0,0 +1,305 @@
use std::sync::Arc;

use common_display::{tree::TreeDisplay, DisplayLevel};
use common_error::DaftResult;
use daft_local_plan::LocalPhysicalPlan;
use daft_logical_plan::{partitioning::UnknownClusteringConfig, stats::StatsState};
use daft_schema::schema::SchemaRef;
use futures::StreamExt;

use super::{DistributedPipelineNode, SubmittableTaskStream};
use crate::{
    pipeline_node::{
        append_plan_to_existing_task, make_in_memory_task_from_materialized_outputs,
        make_new_task_from_materialized_outputs, NodeID, NodeName, PipelineNodeConfig,
        PipelineNodeContext,
    },
    scheduling::{
        scheduler::{SchedulerHandle, SubmittableTask},
        task::{SwordfishTask, TaskContext},
    },
    stage::{StageConfig, StageExecutionContext, TaskIDCounter},
    utils::{
        channel::{create_channel, Sender},
        joinset::OrderedJoinSet,
    },
};

#[derive(Clone)]
pub(crate) struct IntoPartitionsNode {
    config: PipelineNodeConfig,
    context: PipelineNodeContext,
    num_partitions: usize,
    child: Arc<dyn DistributedPipelineNode>,
}

impl IntoPartitionsNode {
    const NODE_NAME: NodeName = "IntoPartitions";

    pub fn new(
        node_id: NodeID,
        logical_node_id: Option<NodeID>,
        stage_config: &StageConfig,
        num_partitions: usize,
        schema: SchemaRef,
        child: Arc<dyn DistributedPipelineNode>,
    ) -> Self {
        let context = PipelineNodeContext::new(
            stage_config,
            node_id,
            Self::NODE_NAME,
            vec![child.node_id()],
            vec![child.name()],
            logical_node_id,
        );
        let config = PipelineNodeConfig::new(
            schema,
            stage_config.config.clone(),
            Arc::new(UnknownClusteringConfig::new(num_partitions).into()),
        );

        Self {
            config,
            context,
            num_partitions,
            child,
        }
    }

    pub fn arced(self) -> Arc<dyn DistributedPipelineNode> {
        Arc::new(self)
    }

    fn multiline_display(&self) -> Vec<String> {
        vec![
            "IntoPartitions".to_string(),
            format!("Num partitions = {}", self.num_partitions),
        ]
    }

    async fn coalesce_tasks(
        self: Arc<Self>,
        tasks: Vec<SubmittableTask<SwordfishTask>>,
        scheduler_handle: &SchedulerHandle<SwordfishTask>,
        task_id_counter: &TaskIDCounter,
        result_tx: Sender<SubmittableTask<SwordfishTask>>,
    ) -> DaftResult<()> {
        assert!(
            tasks.len() >= self.num_partitions,
            "Cannot coalesce from {} to {} partitions.",
            tasks.len(),
            self.num_partitions
        );

        // Coalesce partitions evenly with remainder handling
        // Example: 10 inputs, 3 partitions = 4, 3, 3

        // Base inputs per partition: 10 / 3 = 3 (all tasks get at least 3 inputs)
        let base_inputs_per_partition = tasks.len() / self.num_partitions;
        // Remainder: 10 % 3 = 1 (one task gets an extra input)
        let num_partitions_with_extra_input = tasks.len() % self.num_partitions;

        let mut output_futures = OrderedJoinSet::new();

        let mut task_iter = tasks.into_iter();
        for partition_idx in 0..self.num_partitions {
            let mut chunk_size = base_inputs_per_partition;
            // This partition needs an extra input, i.e. partition_idx == 0 and remainder == 1
            if partition_idx < num_partitions_with_extra_input {
                chunk_size += 1;
            }

            // Submit all the tasks for this partition
            let submitted_tasks = task_iter
                .by_ref()
                .take(chunk_size)
                .map(|task| task.submit(scheduler_handle))
                .collect::<DaftResult<Vec<_>>>()?;
            output_futures.spawn(async move {
                let materialized_output = futures::future::try_join_all(submitted_tasks)
                    .await?
                    .into_iter()
                    .flatten()
                    .collect::<Vec<_>>();
                DaftResult::Ok(materialized_output)
            });
        }

        while let Some(result) = output_futures.join_next().await {
            // Collect all the outputs from this task and coalesce them into a single task.
            let materialized_outputs = result??;
            let self_arc = self.clone();
            let task = make_new_task_from_materialized_outputs(
                TaskContext::from((&self.context, task_id_counter.next())),
                materialized_outputs,
                &(self_arc as Arc<dyn DistributedPipelineNode>),
                move |input| {
                    LocalPhysicalPlan::into_partitions(input, 1, StatsState::NotMaterialized)
                },
            )?;
            if result_tx.send(task).await.is_err() {
                break;
            }
        }

        Ok(())
    }

    async fn split_tasks(
        self: Arc<Self>,
        tasks: Vec<SubmittableTask<SwordfishTask>>,
        scheduler_handle: &SchedulerHandle<SwordfishTask>,
        task_id_counter: &TaskIDCounter,
        result_tx: Sender<SubmittableTask<SwordfishTask>>,
    ) -> DaftResult<()> {
        assert!(
            tasks.len() <= self.num_partitions,
            "Cannot split from {} to {} partitions.",
            tasks.len(),
            self.num_partitions
        );

        // Split partitions evenly with remainder handling
        // Example: 3 inputs, 10 partitions = 4, 3, 3

        // Base outputs per partition: 10 / 3 = 3 (all partitions will split into at least 3 outputs)
        let base_splits_per_partition = self.num_partitions / tasks.len();
        // Remainder: 10 % 3 = 1 (one partition will split into 4 outputs)
        let num_partitions_with_extra_output = self.num_partitions % tasks.len();

        let mut output_futures = OrderedJoinSet::new();

        for (input_partition_idx, task) in tasks.into_iter().enumerate() {
            let mut num_outputs = base_splits_per_partition;
            // This partition will split into one more output, i.e. input_partition_idx == 0 and remainder == 1
            if input_partition_idx < num_partitions_with_extra_output {
                num_outputs += 1;
            }
            let into_partitions_task = append_plan_to_existing_task(
                task,
                &(self.clone() as Arc<dyn DistributedPipelineNode>),
                &move |plan| {
                    LocalPhysicalPlan::into_partitions(
                        plan,
                        num_outputs,
                        StatsState::NotMaterialized,
                    )
                },
            );
            let submitted_task = into_partitions_task.submit(scheduler_handle)?;
            output_futures.spawn(submitted_task);
        }

        // Collect all the outputs and emit a new task for each output.
        while let Some(result) = output_futures.join_next().await {
            let materialized_outputs = result??;
            if let Some(output) = materialized_outputs {
                for output in output.split_into_materialized_outputs() {
                    let self_arc = self.clone();
                    let task = make_in_memory_task_from_materialized_outputs(
                        TaskContext::from((&self.context, task_id_counter.next())),
                        vec![output],
                        &(self_arc as Arc<dyn DistributedPipelineNode>),
                    )?;
                    if result_tx.send(task).await.is_err() {
                        break;
                    }
                }
            }
        }

        Ok(())
    }

    async fn execute_into_partitions(
        self: Arc<Self>,
        input_stream: SubmittableTaskStream,
        task_id_counter: TaskIDCounter,
        result_tx: Sender<SubmittableTask<SwordfishTask>>,
        scheduler_handle: SchedulerHandle<SwordfishTask>,
    ) -> DaftResult<()> {
        // Collect all input tasks without materializing to count them
        let input_tasks: Vec<SubmittableTask<SwordfishTask>> = input_stream.collect().await;
        let num_input_tasks = input_tasks.len();

        match num_input_tasks.cmp(&self.num_partitions) {
            std::cmp::Ordering::Equal => {
                // Exact match - pass through as-is
                for task in input_tasks {
                    let _ = result_tx.send(task).await;
                }
            }
            std::cmp::Ordering::Greater => {
                // Too many tasks - coalesce
                self.coalesce_tasks(input_tasks, &scheduler_handle, &task_id_counter, result_tx)
                    .await?;
            }
            std::cmp::Ordering::Less => {
                // Too few tasks - split
                self.split_tasks(input_tasks, &scheduler_handle, &task_id_counter, result_tx)
                    .await?;
            }
        };
        Ok(())
    }
}

impl TreeDisplay for IntoPartitionsNode {
    fn display_as(&self, level: DisplayLevel) -> String {
        use std::fmt::Write;
        let mut display = String::new();
        match level {
            DisplayLevel::Compact => {
                writeln!(display, "{}", self.context.node_name).unwrap();
            }
            _ => {
                let multiline_display = self.multiline_display().join("\n");
                writeln!(display, "{}", multiline_display).unwrap();
            }
        }
        display
    }

    fn get_children(&self) -> Vec<&dyn TreeDisplay> {
        vec![self.child.as_tree_display()]
    }

    fn get_name(&self) -> String {
        self.context.node_name.to_string()
    }
}

impl DistributedPipelineNode for IntoPartitionsNode {
    fn context(&self) -> &PipelineNodeContext {
        &self.context
    }

    fn config(&self) -> &PipelineNodeConfig {
        &self.config
    }

    fn children(&self) -> Vec<Arc<dyn DistributedPipelineNode>> {
        vec![self.child.clone()]
    }

    fn produce_tasks(
        self: Arc<Self>,
        stage_context: &mut StageExecutionContext,
    ) -> SubmittableTaskStream {
        let input_stream = self.child.clone().produce_tasks(stage_context);
        let (result_tx, result_rx) = create_channel(1);

        stage_context.spawn(self.execute_into_partitions(
            input_stream,
            stage_context.task_id_counter(),
            result_tx,
            stage_context.scheduler_handle(),
        ));

        SubmittableTaskStream::from(result_rx)
    }

    fn as_tree_display(&self) -> &dyn TreeDisplay {
        self
    }
}
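To make the remainder handling in `coalesce_tasks` and `split_tasks` above concrete, the same arithmetic can be reproduced in isolation. This is an editor's sketch only; `chunk_sizes` is a hypothetical helper and is not part of the commit.

```rust
// Distribute `larger` units over `smaller` buckets as evenly as possible;
// the first `larger % smaller` buckets receive one extra unit.
fn chunk_sizes(larger: usize, smaller: usize) -> Vec<usize> {
    let base = larger / smaller;
    let remainder = larger % smaller;
    (0..smaller)
        .map(|i| if i < remainder { base + 1 } else { base })
        .collect()
}

fn main() {
    // coalesce_tasks: 10 input tasks into 3 partitions -> input groups of 4, 3, 3.
    assert_eq!(chunk_sizes(10, 3), vec![4, 3, 3]);
    // split_tasks: 2 input tasks into 7 partitions -> the inputs split into 4 and 3 outputs.
    assert_eq!(chunk_sizes(7, 2), vec![4, 3]);
}
```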

src/daft-distributed/src/pipeline_node/mod.rs

Lines changed: 3 additions & 0 deletions
@@ -42,6 +42,7 @@ mod filter;
 mod gather;
 mod in_memory_source;
 mod into_batches;
+mod into_partitions;
 mod join;
 mod limit;
 pub(crate) mod materialize;
@@ -116,6 +117,7 @@ impl MaterializedOutput {
     }
 }
 
+#[derive(Clone)]
 pub(super) struct PipelineNodeConfig {
     pub schema: SchemaRef,
     pub execution_config: Arc<DaftExecutionConfig>,
@@ -136,6 +138,7 @@ impl PipelineNodeConfig {
     }
 }
 
+#[derive(Clone)]
 pub(super) struct PipelineNodeContext {
     pub plan_id: PlanID,
     pub stage_id: StageID,

src/daft-distributed/src/pipeline_node/translate.rs

Lines changed: 29 additions & 16 deletions
@@ -16,10 +16,11 @@ use crate::{
     pipeline_node::{
         concat::ConcatNode, distinct::DistinctNode, explode::ExplodeNode, filter::FilterNode,
         gather::GatherNode, in_memory_source::InMemorySourceNode, into_batches::IntoBatchesNode,
-        limit::LimitNode, monotonically_increasing_id::MonotonicallyIncreasingIdNode,
-        project::ProjectNode, repartition::RepartitionNode, sample::SampleNode,
-        scan_source::ScanSourceNode, sink::SinkNode, sort::SortNode, top_n::TopNNode, udf::UDFNode,
-        unpivot::UnpivotNode, window::WindowNode, DistributedPipelineNode, NodeID,
+        into_partitions::IntoPartitionsNode, limit::LimitNode,
+        monotonically_increasing_id::MonotonicallyIncreasingIdNode, project::ProjectNode,
+        repartition::RepartitionNode, sample::SampleNode, scan_source::ScanSourceNode,
+        sink::SinkNode, sort::SortNode, top_n::TopNNode, udf::UDFNode, unpivot::UnpivotNode,
+        window::WindowNode, DistributedPipelineNode, NodeID,
     },
     stage::StageConfig,
 };
@@ -300,26 +301,38 @@ impl TreeNodeVisitor for LogicalPlanToPipelineNodeTranslator {
                 self.curr_node.pop().unwrap(), // Child
             )
             .arced(),
-            LogicalPlan::Repartition(repartition) => {
-                match &repartition.repartition_spec {
-                    RepartitionSpec::Hash(repart_spec) => {
-                        assert!(!repart_spec.by.is_empty());
-                    }
-                    RepartitionSpec::Random(_) => {}
-                    RepartitionSpec::IntoPartitions(_) => {
-                        todo!("FLOTILLA_MS3: Support other types of repartition");
-                    }
+            LogicalPlan::Repartition(repartition) => match &repartition.repartition_spec {
+                RepartitionSpec::Hash(repart_spec) => {
+                    assert!(!repart_spec.by.is_empty());
+                    RepartitionNode::new(
+                        self.get_next_pipeline_node_id(),
+                        logical_node_id,
+                        &self.stage_config,
+                        repartition.repartition_spec.clone(),
+                        node.schema(),
+                        self.curr_node.pop().unwrap(),
+                    )
+                    .arced()
                 }
-                RepartitionNode::new(
+                RepartitionSpec::Random(_) => RepartitionNode::new(
                     self.get_next_pipeline_node_id(),
                     logical_node_id,
                     &self.stage_config,
                     repartition.repartition_spec.clone(),
                     node.schema(),
                     self.curr_node.pop().unwrap(),
                 )
-                .arced()
-            }
+                .arced(),
+                RepartitionSpec::IntoPartitions(into_partitions_spec) => IntoPartitionsNode::new(
+                    self.get_next_pipeline_node_id(),
+                    logical_node_id,
+                    &self.stage_config,
+                    into_partitions_spec.num_partitions,
+                    node.schema(),
+                    self.curr_node.pop().unwrap(),
+                )
+                .arced(),
+            },
             LogicalPlan::Aggregate(aggregate) => {
                 let input_schema = aggregate.input.schema();
                 let group_by = BoundExpr::bind_all(&aggregate.groupby, &input_schema)?;
