Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
366 changes: 366 additions & 0 deletions tokio-util/src/task/join_deque.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,366 @@
use super::AbortOnDropHandle;
use std::{
collections::VecDeque,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::{
runtime::Handle,
task::{AbortHandle, Id, JoinError, JoinHandle, LocalSet},
};

/// A collection of tasks spawned on a Tokio runtime.
///
/// A `JoinDeque` can be used to await the completion of the tasks in FIFO
/// order. That is, if tasks are spawned in the order A, B, C, then
/// awaiting the next completed task will always return A first, then B,
/// then C, regardless of the order in which the tasks actually complete.
///
/// All of the tasks must have the same return type `T`.
///
/// When the `JoinDeque` is dropped, all tasks in the `JoinDeque` are
/// immediately aborted.
#[derive(Debug)]
pub struct JoinDeque<T>(VecDeque<AbortOnDropHandle<T>>);

impl<T> JoinDeque<T> {
/// Create a new empty `JoinDeque`.
pub const fn new() -> Self {
Self(VecDeque::new())
}

/// Creates an empty `JoinDeque` with space for at least `capacity` tasks.
pub fn with_capacity(capacity: usize) -> Self {
Self(VecDeque::with_capacity(capacity))
}

/// Returns the number of tasks currently in the `JoinDeque`.
///
/// This includes both tasks that are currently running and tasks that have
/// completed but not yet been removed from the queue because outputting of
/// them waits for FIFO order.
pub fn len(&self) -> usize {
self.0.len()
}

/// Returns whether the `JoinDeque` is empty.
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}

/// Spawn the provided task on the `JoinDeque`, returning an [`AbortHandle`]
/// that can be used to remotely cancel the task.
///
/// The provided future will start running in the background immediately
/// when this method is called, even if you don't await anything on this
/// `JoinDeque`.
///
/// # Panics
///
/// This method panics if called outside of a Tokio runtime.
///
/// [`AbortHandle`]: tokio::task::AbortHandle
#[track_caller]
pub fn spawn<F>(&mut self, task: F) -> AbortHandle
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
self.insert(tokio::spawn(task))
}

/// Spawn the provided task on the provided runtime and store it in this
/// `JoinDeque` returning an [`AbortHandle`] that can be used to remotely
/// cancel the task.
///
/// The provided future will start running in the background immediately
/// when this method is called, even if you don't await anything on this
/// `JoinDeque`.
///
/// [`AbortHandle`]: tokio::task::AbortHandle
#[track_caller]
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
self.insert(handle.spawn(task))
}

/// Spawn the provided task on the current [`LocalSet`] and store it in this
/// `JoinDeque`, returning an [`AbortHandle`] that can be used to remotely
/// cancel the task.
///
/// The provided future will start running in the background immediately
/// when this method is called, even if you don't await anything on this
/// `JoinDeque`.
///
/// # Panics
///
/// This method panics if it is called outside of a `LocalSet`.
///
/// [`LocalSet`]: tokio::task::LocalSet
/// [`AbortHandle`]: tokio::task::AbortHandle
#[track_caller]
pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
where
F: Future<Output = T> + 'static,
T: 'static,
{
self.insert(tokio::task::spawn_local(task))
}

/// Spawn the provided task on the provided [`LocalSet`] and store it in
/// this `JoinDeque`, returning an [`AbortHandle`] that can be used to
/// remotely cancel the task.
///
/// Unlike the [`spawn_local`] method, this method may be used to spawn local
/// tasks on a `LocalSet` that is _not_ currently running. The provided
/// future will start running whenever the `LocalSet` is next started.
///
/// [`LocalSet`]: tokio::task::LocalSet
/// [`AbortHandle`]: tokio::task::AbortHandle
/// [`spawn_local`]: Self::spawn_local
#[track_caller]
pub fn spawn_local_on<F>(&mut self, task: F, local_set: &LocalSet) -> AbortHandle
where
F: Future<Output = T> + 'static,
T: 'static,
{
self.insert(local_set.spawn_local(task))
}
Comment on lines +114 to +132
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since LocalSet might be deprecated in the future, we can exclude this method for now, see #6741.


/// Spawn the blocking code on the blocking threadpool and store
/// it in this `JoinDeque`, returning an [`AbortHandle`] that can be
/// used to remotely cancel the task.
///
/// # Panics
///
/// This method panics if called outside of a Tokio runtime.
///
/// [`AbortHandle`]: tokio::task::AbortHandle
#[track_caller]
pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
{
self.insert(tokio::task::spawn_blocking(f))
}

/// Spawn the blocking code on the blocking threadpool of the
/// provided runtime and store it in this `JoinDeque`, returning an
/// [`AbortHandle`] that can be used to remotely cancel the task.
///
/// [`AbortHandle`]: tokio::task::AbortHandle
#[track_caller]
pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
{
self.insert(handle.spawn_blocking(f))
}

fn insert(&mut self, jh: JoinHandle<T>) -> AbortHandle {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even this is an internal interface, I still prefer the push_back or similar name.

let join_handle = AbortOnDropHandle::new(jh);
let abort_handle = join_handle.abort_handle();
self.0.push_back(join_handle);
abort_handle
}

/// Waits until the next task in FIFO order completes and returns its output.
///
/// Returns `None` if the queue is empty.
///
/// # Cancel Safety
///
/// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!`
/// statement and some other branch completes first, it is guaranteed that no tasks were
/// removed from this `JoinDeque`.
pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
std::future::poll_fn(|cx| self.poll_join_next(cx)).await
}

/// Waits until the next task in FIFO order completes and returns its output,
/// along with the [task ID] of the completed task.
///
/// Returns `None` if the queue is empty.
///
/// When this method returns an error, then the id of the task that failed can be accessed
/// using the [`JoinError::id`] method.
///
/// # Cancel Safety
///
/// This method is cancel safe. If `join_next_with_id` is used as the event in a `tokio::select!`
/// statement and some other branch completes first, it is guaranteed that no tasks were
/// removed from this `JoinDeque`.
///
/// [task ID]: tokio::task::Id
/// [`JoinError::id`]: fn@tokio::task::JoinError::id
pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
}

/// Aborts all tasks and waits for them to finish shutting down.
///
/// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
/// a loop until it returns `None`.
///
/// This method ignores any panics in the tasks shutting down. When this call returns, the
/// `JoinDeque` will be empty.
///
/// [`abort_all`]: fn@Self::abort_all
/// [`join_next`]: fn@Self::join_next
pub async fn shutdown(&mut self) {
self.abort_all();
while self.join_next().await.is_some() {}
}

/// Awaits the completion of all tasks in this `JoinDeque`, returning a vector of their results.
///
/// The results will be stored in the order they were spawned, not the order they completed.
/// This is a convenience method that is equivalent to calling [`join_next`] in
/// a loop. If any tasks on the `JoinDeque` fail with an [`JoinError`], then this call
/// to `join_all` will panic and all remaining tasks on the `JoinDeque` are
/// cancelled. To handle errors in any other way, manually call [`join_next`]
/// in a loop.
///
/// [`join_next`]: fn@Self::join_next
/// [`JoinError::id`]: fn@tokio::task::JoinError::id
pub async fn join_all(mut self) -> Vec<T> {
let mut output = Vec::with_capacity(self.len());

while let Some(res) = self.join_next().await {
match res {
Ok(t) => output.push(t),
Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
Err(err) => panic!("{err}"),
}
}
output
}

/// Aborts all tasks on this `JoinDeque`.
///
/// This does not remove the tasks from the `JoinDeque`. To wait for the tasks to complete
/// cancellation, you should call `join_next` in a loop until the `JoinDeque` is empty.
pub fn abort_all(&mut self) {
self.0.iter().for_each(|jh| jh.abort());
}

/// Removes all tasks from this `JoinDeque` without aborting them.
///
/// The tasks removed by this call will continue to run in the background even if the `JoinDeque`
/// is dropped.
pub fn detach_all(&mut self) {
self.0.drain(..).for_each(|jh| drop(jh.detach()));
}

/// Polls for the next task in `JoinDeque` to complete.
///
/// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
///
/// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
/// to receive a wakeup when a task in the `JoinDeque` completes. Note that on multiple calls to
/// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
/// scheduled to receive a wakeup.
///
/// # Returns
///
/// This function returns:
///
/// * `Poll::Pending` if the `JoinDeque` is not empty but there is no task whose output is
/// available right now.
/// * `Poll::Ready(Some(Ok(value)))` if the next task in this `JoinDeque` has completed.
/// The `value` is the return value that task.
/// * `Poll::Ready(Some(Err(err)))` if the next task in this `JoinDeque` has panicked or been
/// aborted. The `err` is the `JoinError` from the panicked/aborted task.
/// * `Poll::Ready(None)` if the `JoinDeque` is empty.
pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, JoinError>>> {
let jh = match self.0.front_mut() {
None => return Poll::Ready(None),
Some(jh) => jh,
};
if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
drop(self.0.pop_front().unwrap().detach());
Poll::Ready(Some(res))
} else {
// A JoinHandle generally won't emit a wakeup without being ready unless
// the coop limit has been reached. We yield to the executor in this
// case.
cx.waker().wake_by_ref();
Poll::Pending
}
}

/// Polls for the next task in `JoinDeque` to complete.
///
/// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
///
/// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
/// to receive a wakeup when a task in the `JoinDeque` completes. Note that on multiple calls to
/// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
/// scheduled to receive a wakeup.
///
/// # Returns
///
/// This function returns:
///
/// * `Poll::Pending` if the `JoinDeque` is not empty but there is no task whose output is
/// available right now.
/// * `Poll::Ready(Some(Ok((id, value))))` if the next task in this `JoinDeque` has completed.
/// The `value` is the return value that task, and `id` is its [task ID].
/// * `Poll::Ready(Some(Err(err)))` if the next task in this `JoinDeque` has panicked or been
/// aborted. The `err` is the `JoinError` from the panicked/aborted task.
/// * `Poll::Ready(None)` if the `JoinDeque` is empty.
///
/// [task ID]: tokio::task::Id
pub fn poll_join_next_with_id(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(Id, T), JoinError>>> {
let jh = match self.0.front_mut() {
None => return Poll::Ready(None),
Some(jh) => jh,
};
if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
let jh = self.0.pop_front().unwrap().detach();
let id = jh.id();
drop(jh);
// If the task succeeded, add the task ID to the output. Otherwise, the
// `JoinError` will already have the task's ID.
Poll::Ready(Some(res.map(|output| (id, output))))
} else {
// A JoinHandle generally won't emit a wakeup without being ready unless
// the coop limit has been reached. We yield to the executor in this
// case.
cx.waker().wake_by_ref();
Poll::Pending
}
}
}

impl<T> Default for JoinDeque<T> {
fn default() -> Self {
Self::new()
}
}

/// Collect an iterator of futures into a [`JoinDeque`].
///
/// This is equivalent to calling [`JoinDeque::spawn`] on each element of the iterator.
impl<T, F> std::iter::FromIterator<F> for JoinDeque<T>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
fn from_iter<I: IntoIterator<Item = F>>(iter: I) -> Self {
let mut set = Self::new();
iter.into_iter().for_each(|task| {
set.spawn(task);
});
set
}
}
3 changes: 3 additions & 0 deletions tokio-util/src/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ cfg_rt! {

mod abort_on_drop;
pub use abort_on_drop::AbortOnDropHandle;

mod join_deque;
pub use join_deque::JoinDeque;
}

#[cfg(feature = "join-map")]
Expand Down
Loading