Skip to content

Commit 2c34aef

Browse files
committed
util: add JoinDeque structure
1 parent 9f59c69 commit 2c34aef

File tree

2 files changed

+369
-0
lines changed

2 files changed

+369
-0
lines changed

tokio-util/src/task/join_deque.rs

Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
use super::AbortOnDropHandle;
2+
use std::{
3+
collections::VecDeque,
4+
future::Future,
5+
task::{Context, Poll},
6+
pin::Pin,
7+
};
8+
use tokio::{
9+
runtime::Handle,
10+
task::{AbortHandle, Id, JoinError, JoinHandle, LocalSet},
11+
};
12+
13+
/// A collection of tasks spawned on a Tokio runtime.
14+
///
15+
/// A `JoinDeque` can be used to await the completion of the tasks in FIFO
16+
/// order. That is, if tasks are spawned in the order A, B, C, then
17+
/// awaiting the next completed task will always return A first, then B,
18+
/// then C, regardless of the order in which the tasks actually complete.
19+
///
20+
/// All of the tasks must have the same return type `T`.
21+
///
22+
/// When the `JoinDeque` is dropped, all tasks in the `JoinDeque` are
23+
/// immediately aborted.
24+
#[derive(Debug)]
25+
pub struct JoinDeque<T>(VecDeque<AbortOnDropHandle<T>>);
26+
27+
impl<T> JoinDeque<T> {
28+
/// Create a new empty `JoinDeque`.
29+
pub const fn new() -> Self {
30+
Self(VecDeque::new())
31+
}
32+
33+
/// Creates an empty `JoinDeque` with space for at least `capacity` tasks.
34+
pub fn with_capacity(capacity: usize) -> Self {
35+
Self(VecDeque::with_capacity(capacity))
36+
}
37+
38+
/// Returns the number of tasks currently in the `JoinDeque`.
39+
///
40+
/// This includes both tasks that are currently running and tasks that have
41+
/// completed but not yet been removed from the queue because outputting of
42+
/// them waits for FIFO order.
43+
pub fn len(&self) -> usize {
44+
self.0.len()
45+
}
46+
47+
/// Returns whether the `JoinDeque` is empty.
48+
pub fn is_empty(&self) -> bool {
49+
self.0.is_empty()
50+
}
51+
52+
/// Spawn the provided task on the `JoinDeque`, returning an [`AbortHandle`]
53+
/// that can be used to remotely cancel the task.
54+
///
55+
/// The provided future will start running in the background immediately
56+
/// when this method is called, even if you don't await anything on this
57+
/// `JoinDeque`.
58+
///
59+
/// # Panics
60+
///
61+
/// This method panics if called outside of a Tokio runtime.
62+
///
63+
/// [`AbortHandle`]: tokio::task::AbortHandle
64+
#[track_caller]
65+
pub fn spawn<F>(&mut self, task: F) -> AbortHandle
66+
where
67+
F: Future<Output = T> + Send + 'static,
68+
T: Send + 'static,
69+
{
70+
self.insert(tokio::spawn(task))
71+
}
72+
73+
/// Spawn the provided task on the provided runtime and store it in this
74+
/// `JoinDeque` returning an [`AbortHandle`] that can be used to remotely
75+
/// cancel the task.
76+
///
77+
/// The provided future will start running in the background immediately
78+
/// when this method is called, even if you don't await anything on this
79+
/// `JoinDeque`.
80+
///
81+
/// [`AbortHandle`]: tokio::task::AbortHandle
82+
#[track_caller]
83+
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
84+
where
85+
F: Future<Output = T> + Send + 'static,
86+
T: Send + 'static,
87+
{
88+
self.insert(handle.spawn(task))
89+
}
90+
91+
/// Spawn the provided task on the current [`LocalSet`] and store it in this
92+
/// `JoinDeque`, returning an [`AbortHandle`] that can be used to remotely
93+
/// cancel the task.
94+
///
95+
/// The provided future will start running in the background immediately
96+
/// when this method is called, even if you don't await anything on this
97+
/// `JoinDeque`.
98+
///
99+
/// # Panics
100+
///
101+
/// This method panics if it is called outside of a `LocalSet`.
102+
///
103+
/// [`LocalSet`]: tokio::task::LocalSet
104+
/// [`AbortHandle`]: tokio::task::AbortHandle
105+
#[track_caller]
106+
pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
107+
where
108+
F: Future<Output = T> + 'static,
109+
T: 'static,
110+
{
111+
self.insert(tokio::task::spawn_local(task))
112+
}
113+
114+
/// Spawn the provided task on the provided [`LocalSet`] and store it in
115+
/// this `JoinDeque`, returning an [`AbortHandle`] that can be used to
116+
/// remotely cancel the task.
117+
///
118+
/// Unlike the [`spawn_local`] method, this method may be used to spawn local
119+
/// tasks on a `LocalSet` that is _not_ currently running. The provided
120+
/// future will start running whenever the `LocalSet` is next started.
121+
///
122+
/// [`LocalSet`]: tokio::task::LocalSet
123+
/// [`AbortHandle`]: tokio::task::AbortHandle
124+
/// [`spawn_local`]: Self::spawn_local
125+
#[track_caller]
126+
pub fn spawn_local_on<F>(&mut self, task: F, local_set: &LocalSet) -> AbortHandle
127+
where
128+
F: Future<Output = T> + 'static,
129+
T: 'static,
130+
{
131+
self.insert(local_set.spawn_local(task))
132+
}
133+
134+
/// Spawn the blocking code on the blocking threadpool and store
135+
/// it in this `JoinDeque`, returning an [`AbortHandle`] that can be
136+
/// used to remotely cancel the task.
137+
///
138+
/// # Panics
139+
///
140+
/// This method panics if called outside of a Tokio runtime.
141+
///
142+
/// [`AbortHandle`]: tokio::task::AbortHandle
143+
#[track_caller]
144+
pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
145+
where
146+
F: FnOnce() -> T + Send + 'static,
147+
T: Send + 'static,
148+
{
149+
self.insert(tokio::task::spawn_blocking(f))
150+
}
151+
152+
/// Spawn the blocking code on the blocking threadpool of the
153+
/// provided runtime and store it in this `JoinDeque`, returning an
154+
/// [`AbortHandle`] that can be used to remotely cancel the task.
155+
///
156+
/// [`AbortHandle`]: tokio::task::AbortHandle
157+
#[track_caller]
158+
pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
159+
where
160+
F: FnOnce() -> T + Send + 'static,
161+
T: Send + 'static,
162+
{
163+
self.insert(handle.spawn_blocking(f))
164+
}
165+
166+
fn insert(&mut self, jh: JoinHandle<T>) -> AbortHandle {
167+
let join_handle = AbortOnDropHandle::new(jh);
168+
let abort_handle = join_handle.abort_handle();
169+
self.0.push_back(join_handle);
170+
abort_handle
171+
}
172+
173+
/// Waits until the next task in FIFO order completes and returns its output.
174+
///
175+
/// Returns `None` if the queue is empty.
176+
///
177+
/// # Cancel Safety
178+
///
179+
/// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!`
180+
/// statement and some other branch completes first, it is guaranteed that no tasks were
181+
/// removed from this `JoinDeque`.
182+
pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
183+
std::future::poll_fn(|cx| self.poll_join_next(cx)).await
184+
}
185+
186+
/// Waits until the next task in FIFO order completes and returns its output,
187+
/// along with the [task ID] of the completed task.
188+
///
189+
/// Returns `None` if the queue is empty.
190+
///
191+
/// When this method returns an error, then the id of the task that failed can be accessed
192+
/// using the [`JoinError::id`] method.
193+
///
194+
/// # Cancel Safety
195+
///
196+
/// This method is cancel safe. If `join_next_with_id` is used as the event in a `tokio::select!`
197+
/// statement and some other branch completes first, it is guaranteed that no tasks were
198+
/// removed from this `JoinDeque`.
199+
///
200+
/// [task ID]: tokio::task::Id
201+
/// [`JoinError::id`]: fn@tokio::task::JoinError::id
202+
pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
203+
std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
204+
}
205+
206+
/// Aborts all tasks and waits for them to finish shutting down.
207+
///
208+
/// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
209+
/// a loop until it returns `None`.
210+
///
211+
/// This method ignores any panics in the tasks shutting down. When this call returns, the
212+
/// `JoinDeque` will be empty.
213+
///
214+
/// [`abort_all`]: fn@Self::abort_all
215+
/// [`join_next`]: fn@Self::join_next
216+
pub async fn shutdown(&mut self) {
217+
self.abort_all();
218+
while self.join_next().await.is_some() {}
219+
}
220+
221+
/// Awaits the completion of all tasks in this `JoinDeque`, returning a vector of their results.
222+
///
223+
/// The results will be stored in the order they were spawned, not the order they completed.
224+
/// This is a convenience method that is equivalent to calling [`join_next`] in
225+
/// a loop. If any tasks on the `JoinDeque` fail with an [`JoinError`], then this call
226+
/// to `join_all` will panic and all remaining tasks on the `JoinDeque` are
227+
/// cancelled. To handle errors in any other way, manually call [`join_next`]
228+
/// in a loop.
229+
///
230+
/// [`join_next`]: fn@Self::join_next
231+
/// [`JoinError::id`]: fn@tokio::task::JoinError::id
232+
pub async fn join_all(mut self) -> Vec<T> {
233+
let mut output = Vec::with_capacity(self.len());
234+
235+
while let Some(res) = self.join_next().await {
236+
match res {
237+
Ok(t) => output.push(t),
238+
Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
239+
Err(err) => panic!("{err}"),
240+
}
241+
}
242+
output
243+
}
244+
245+
/// Aborts all tasks on this `JoinDeque`.
246+
///
247+
/// This does not remove the tasks from the `JoinDeque`. To wait for the tasks to complete
248+
/// cancellation, you should call `join_next` in a loop until the `JoinDeque` is empty.
249+
pub fn abort_all(&mut self) {
250+
self.0.iter().for_each(|jh| jh.abort());
251+
}
252+
253+
/// Removes all tasks from this `JoinDeque` without aborting them.
254+
///
255+
/// The tasks removed by this call will continue to run in the background even if the `JoinDeque`
256+
/// is dropped.
257+
pub fn detach_all(&mut self) {
258+
self.0.drain(..).for_each(|jh| drop(jh.detach()));
259+
}
260+
261+
/// Polls for the next task in `JoinDeque` to complete.
262+
///
263+
/// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
264+
///
265+
/// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
266+
/// to receive a wakeup when a task in the `JoinDeque` completes. Note that on multiple calls to
267+
/// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
268+
/// scheduled to receive a wakeup.
269+
///
270+
/// # Returns
271+
///
272+
/// This function returns:
273+
///
274+
/// * `Poll::Pending` if the `JoinDeque` is not empty but there is no task whose output is
275+
/// available right now.
276+
/// * `Poll::Ready(Some(Ok(value)))` if the next task in this `JoinDeque` has completed.
277+
/// The `value` is the return value that task.
278+
/// * `Poll::Ready(Some(Err(err)))` if the next task in this `JoinDeque` has panicked or been
279+
/// aborted. The `err` is the `JoinError` from the panicked/aborted task.
280+
/// * `Poll::Ready(None)` if the `JoinDeque` is empty.
281+
pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, JoinError>>> {
282+
let jh = match self.0.front_mut() {
283+
None => return Poll::Ready(None),
284+
Some(jh) => jh,
285+
};
286+
if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
287+
drop(self.0.pop_front().unwrap().detach());
288+
Poll::Ready(Some(res))
289+
} else {
290+
// A JoinHandle generally won't emit a wakeup without being ready unless
291+
// the coop limit has been reached. We yield to the executor in this
292+
// case.
293+
cx.waker().wake_by_ref();
294+
Poll::Pending
295+
}
296+
}
297+
298+
/// Polls for the next task in `JoinDeque` to complete.
299+
///
300+
/// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
301+
///
302+
/// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
303+
/// to receive a wakeup when a task in the `JoinDeque` completes. Note that on multiple calls to
304+
/// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
305+
/// scheduled to receive a wakeup.
306+
///
307+
/// # Returns
308+
///
309+
/// This function returns:
310+
///
311+
/// * `Poll::Pending` if the `JoinDeque` is not empty but there is no task whose output is
312+
/// available right now.
313+
/// * `Poll::Ready(Some(Ok((id, value))))` if the next task in this `JoinDeque` has completed.
314+
/// The `value` is the return value that task, and `id` is its [task ID].
315+
/// * `Poll::Ready(Some(Err(err)))` if the next task in this `JoinDeque` has panicked or been
316+
/// aborted. The `err` is the `JoinError` from the panicked/aborted task.
317+
/// * `Poll::Ready(None)` if the `JoinDeque` is empty.
318+
///
319+
/// [task ID]: tokio::task::Id
320+
pub fn poll_join_next_with_id(
321+
&mut self,
322+
cx: &mut Context<'_>,
323+
) -> Poll<Option<Result<(Id, T), JoinError>>> {
324+
let jh = match self.0.front_mut() {
325+
None => return Poll::Ready(None),
326+
Some(jh) => jh,
327+
};
328+
if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
329+
let jh = self.0.pop_front().unwrap().detach();
330+
let id = jh.id();
331+
drop(jh);
332+
// If the task succeeded, add the task ID to the output. Otherwise, the
333+
// `JoinError` will already have the task's ID.
334+
Poll::Ready(Some(res.map(|output| (id, output))))
335+
} else {
336+
// A JoinHandle generally won't emit a wakeup without being ready unless
337+
// the coop limit has been reached. We yield to the executor in this
338+
// case.
339+
cx.waker().wake_by_ref();
340+
Poll::Pending
341+
}
342+
}
343+
}
344+
345+
impl<T> Default for JoinDeque<T> {
346+
fn default() -> Self {
347+
Self::new()
348+
}
349+
}
350+
351+
/// Collect an iterator of futures into a [`JoinDeque`].
352+
///
353+
/// This is equivalent to calling [`JoinDeque::spawn`] on each element of the iterator.
354+
impl<T, F> std::iter::FromIterator<F> for JoinDeque<T>
355+
where
356+
F: Future<Output = T> + Send + 'static,
357+
T: Send + 'static,
358+
{
359+
fn from_iter<I: IntoIterator<Item = F>>(iter: I) -> Self {
360+
let mut set = Self::new();
361+
iter.into_iter().for_each(|task| {
362+
set.spawn(task);
363+
});
364+
set
365+
}
366+
}

tokio-util/src/task/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ cfg_rt! {
1313

1414
mod abort_on_drop;
1515
pub use abort_on_drop::AbortOnDropHandle;
16+
17+
mod join_deque;
18+
pub use join_deque::JoinDeque;
1619
}
1720

1821
#[cfg(feature = "join-map")]

0 commit comments

Comments
 (0)