// tokio_util/sync/cancellation_token.rs

1//! An asynchronously awaitable `CancellationToken`.
//! The token allows signaling a cancellation request to one or more tasks.
3pub(crate) mod guard;
4mod tree_node;
5
6use crate::loom::sync::Arc;
7use crate::util::MaybeDangling;
8use core::future::Future;
9use core::pin::Pin;
10use core::task::{Context, Poll};
11
12use guard::DropGuard;
13use pin_project_lite::pin_project;
14
15/// A token which can be used to signal a cancellation request to one or more
16/// tasks.
17///
18/// Tasks can call [`CancellationToken::cancelled()`] in order to
19/// obtain a Future which will be resolved when cancellation is requested.
20///
21/// Cancellation can be requested through the [`CancellationToken::cancel`] method.
22///
23/// # Examples
24///
25/// ```no_run
26/// use tokio::select;
27/// use tokio_util::sync::CancellationToken;
28///
29/// #[tokio::main]
30/// async fn main() {
31///     let token = CancellationToken::new();
32///     let cloned_token = token.clone();
33///
34///     let join_handle = tokio::spawn(async move {
35///         // Wait for either cancellation or a very long time
36///         select! {
37///             _ = cloned_token.cancelled() => {
38///                 // The token was cancelled
39///                 5
40///             }
41///             _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
42///                 99
43///             }
44///         }
45///     });
46///
47///     tokio::spawn(async move {
48///         tokio::time::sleep(std::time::Duration::from_millis(10)).await;
49///         token.cancel();
50///     });
51///
52///     assert_eq!(5, join_handle.await.unwrap());
53/// }
54/// ```
55pub struct CancellationToken {
56    inner: Arc<tree_node::TreeNode>,
57}
58
59impl std::panic::UnwindSafe for CancellationToken {}
60impl std::panic::RefUnwindSafe for CancellationToken {}
61
pin_project! {
    /// A Future that is resolved once the corresponding [`CancellationToken`]
    /// is cancelled.
    ///
    /// Returned by [`CancellationToken::cancelled`]; it borrows the token for
    /// its whole lifetime.
    #[must_use = "futures do nothing unless polled"]
    pub struct WaitForCancellationFuture<'a> {
        // The token whose cancelled flag is checked on every poll.
        cancellation_token: &'a CancellationToken,
        // Notification listener obtained from the token's tree node. `poll`
        // replaces it with a fresh listener whenever it completes.
        #[pin]
        future: tokio::sync::futures::Notified<'a>,
    }
}
72
pin_project! {
    /// A Future that is resolved once the corresponding [`CancellationToken`]
    /// is cancelled.
    ///
    /// This is the counterpart to [`WaitForCancellationFuture`] that takes
    /// [`CancellationToken`] by value instead of using a reference.
    #[must_use = "futures do nothing unless polled"]
    pub struct WaitForCancellationFutureOwned {
        // This field internally has a reference to the cancellation token, but camouflages
        // the relationship with `'static`. To avoid Undefined Behavior, we must ensure
        // that the reference is only used while the cancellation token is still alive. To
        // do that, we ensure that the future is the first field, so that it is dropped
        // before the cancellation token.
        //
        // We use `MaybeDangling` here because without it, the compiler could assert
        // the reference inside `future` to be valid even after the destructor of that
        // field runs. (Specifically, when the `WaitForCancellationFutureOwned` is passed
        // as an argument to a function, the reference can be asserted to be valid for the
        // rest of that function.) To avoid that, we use `MaybeDangling` which tells the
        // compiler that the reference stored inside it might not be valid.
        //
        // See <https://users.rust-lang.org/t/unsafe-code-review-semi-owning-weak-rwlock-t-guard/95706>
        // for more info.
        //
        // DO NOT reorder these two fields: struct fields are dropped in
        // declaration order, and soundness depends on `future` being dropped
        // before `cancellation_token`.
        #[pin]
        future: MaybeDangling<tokio::sync::futures::Notified<'static>>,
        // The owned token; it keeps the node referenced by `future` alive.
        cancellation_token: CancellationToken,
    }
}
101
102// ===== impl CancellationToken =====
103
104impl core::fmt::Debug for CancellationToken {
105    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
106        f.debug_struct("CancellationToken")
107            .field("is_cancelled", &self.is_cancelled())
108            .finish()
109    }
110}
111
112impl Clone for CancellationToken {
113    /// Creates a clone of the `CancellationToken` which will get cancelled
114    /// whenever the current token gets cancelled, and vice versa.
115    fn clone(&self) -> Self {
116        tree_node::increase_handle_refcount(&self.inner);
117        CancellationToken {
118            inner: self.inner.clone(),
119        }
120    }
121}
122
123impl Drop for CancellationToken {
124    fn drop(&mut self) {
125        tree_node::decrease_handle_refcount(&self.inner);
126    }
127}
128
129impl Default for CancellationToken {
130    fn default() -> CancellationToken {
131        CancellationToken::new()
132    }
133}
134
135impl CancellationToken {
136    /// Creates a new `CancellationToken` in the non-cancelled state.
137    pub fn new() -> CancellationToken {
138        CancellationToken {
139            inner: Arc::new(tree_node::TreeNode::new()),
140        }
141    }
142
143    /// Creates a `CancellationToken` which will get cancelled whenever the
144    /// current token gets cancelled. Unlike a cloned `CancellationToken`,
145    /// cancelling a child token does not cancel the parent token.
146    ///
147    /// If the current token is already cancelled, the child token will get
148    /// returned in cancelled state.
149    ///
150    /// # Examples
151    ///
152    /// ```no_run
153    /// use tokio::select;
154    /// use tokio_util::sync::CancellationToken;
155    ///
156    /// #[tokio::main]
157    /// async fn main() {
158    ///     let token = CancellationToken::new();
159    ///     let child_token = token.child_token();
160    ///
161    ///     let join_handle = tokio::spawn(async move {
162    ///         // Wait for either cancellation or a very long time
163    ///         select! {
164    ///             _ = child_token.cancelled() => {
165    ///                 // The token was cancelled
166    ///                 5
167    ///             }
168    ///             _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
169    ///                 99
170    ///             }
171    ///         }
172    ///     });
173    ///
174    ///     tokio::spawn(async move {
175    ///         tokio::time::sleep(std::time::Duration::from_millis(10)).await;
176    ///         token.cancel();
177    ///     });
178    ///
179    ///     assert_eq!(5, join_handle.await.unwrap());
180    /// }
181    /// ```
182    pub fn child_token(&self) -> CancellationToken {
183        CancellationToken {
184            inner: tree_node::child_node(&self.inner),
185        }
186    }
187
188    /// Cancel the [`CancellationToken`] and all child tokens which had been
189    /// derived from it.
190    ///
191    /// This will wake up all tasks which are waiting for cancellation.
192    ///
193    /// Be aware that cancellation is not an atomic operation. It is possible
194    /// for another thread running in parallel with a call to `cancel` to first
195    /// receive `true` from `is_cancelled` on one child node, and then receive
196    /// `false` from `is_cancelled` on another child node. However, once the
197    /// call to `cancel` returns, all child nodes have been fully cancelled.
198    pub fn cancel(&self) {
199        tree_node::cancel(&self.inner);
200    }
201
202    /// Returns `true` if the `CancellationToken` is cancelled.
203    pub fn is_cancelled(&self) -> bool {
204        tree_node::is_cancelled(&self.inner)
205    }
206
207    /// Returns a `Future` that gets fulfilled when cancellation is requested.
208    ///
209    /// The future will complete immediately if the token is already cancelled
210    /// when this method is called.
211    ///
212    /// # Cancel safety
213    ///
214    /// This method is cancel safe.
215    pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
216        WaitForCancellationFuture {
217            cancellation_token: self,
218            future: self.inner.notified(),
219        }
220    }
221
222    /// Returns a `Future` that gets fulfilled when cancellation is requested.
223    ///
224    /// The future will complete immediately if the token is already cancelled
225    /// when this method is called.
226    ///
227    /// The function takes self by value and returns a future that owns the
228    /// token.
229    ///
230    /// # Cancel safety
231    ///
232    /// This method is cancel safe.
233    pub fn cancelled_owned(self) -> WaitForCancellationFutureOwned {
234        WaitForCancellationFutureOwned::new(self)
235    }
236
237    /// Creates a `DropGuard` for this token.
238    ///
239    /// Returned guard will cancel this token (and all its children) on drop
240    /// unless disarmed.
241    pub fn drop_guard(self) -> DropGuard {
242        DropGuard { inner: Some(self) }
243    }
244
245    /// Runs a future to completion and returns its result wrapped inside of an `Option`
246    /// unless the `CancellationToken` is cancelled. In that case the function returns
247    /// `None` and the future gets dropped.
248    ///
249    /// # Cancel safety
250    ///
251    /// This method is only cancel safe if `fut` is cancel safe.
252    pub async fn run_until_cancelled<F>(&self, fut: F) -> Option<F::Output>
253    where
254        F: Future,
255    {
256        pin_project! {
257            /// A Future that is resolved once the corresponding [`CancellationToken`]
258            /// is cancelled or a given Future gets resolved. It is biased towards the
259            /// Future completion.
260            #[must_use = "futures do nothing unless polled"]
261            struct RunUntilCancelledFuture<'a, F: Future> {
262                #[pin]
263                cancellation: WaitForCancellationFuture<'a>,
264                #[pin]
265                future: F,
266            }
267        }
268
269        impl<'a, F: Future> Future for RunUntilCancelledFuture<'a, F> {
270            type Output = Option<F::Output>;
271
272            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
273                let this = self.project();
274                if let Poll::Ready(res) = this.future.poll(cx) {
275                    Poll::Ready(Some(res))
276                } else if this.cancellation.poll(cx).is_ready() {
277                    Poll::Ready(None)
278                } else {
279                    Poll::Pending
280                }
281            }
282        }
283
284        RunUntilCancelledFuture {
285            cancellation: self.cancelled(),
286            future: fut,
287        }
288        .await
289    }
290}
291
292// ===== impl WaitForCancellationFuture =====
293
294impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
295    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
296        f.debug_struct("WaitForCancellationFuture").finish()
297    }
298}
299
300impl<'a> Future for WaitForCancellationFuture<'a> {
301    type Output = ();
302
303    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
304        let mut this = self.project();
305        loop {
306            if this.cancellation_token.is_cancelled() {
307                return Poll::Ready(());
308            }
309
310            // No wakeups can be lost here because there is always a call to
311            // `is_cancelled` between the creation of the future and the call to
312            // `poll`, and the code that sets the cancelled flag does so before
313            // waking the `Notified`.
314            if this.future.as_mut().poll(cx).is_pending() {
315                return Poll::Pending;
316            }
317
318            this.future.set(this.cancellation_token.inner.notified());
319        }
320    }
321}
322
323// ===== impl WaitForCancellationFutureOwned =====
324
325impl core::fmt::Debug for WaitForCancellationFutureOwned {
326    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
327        f.debug_struct("WaitForCancellationFutureOwned").finish()
328    }
329}
330
331impl WaitForCancellationFutureOwned {
332    fn new(cancellation_token: CancellationToken) -> Self {
333        WaitForCancellationFutureOwned {
334            // cancellation_token holds a heap allocation and is guaranteed to have a
335            // stable deref, thus it would be ok to move the cancellation_token while
336            // the future holds a reference to it.
337            //
338            // # Safety
339            //
340            // cancellation_token is dropped after future due to the field ordering.
341            future: MaybeDangling::new(unsafe { Self::new_future(&cancellation_token) }),
342            cancellation_token,
343        }
344    }
345
346    /// # Safety
347    /// The returned future must be destroyed before the cancellation token is
348    /// destroyed.
349    unsafe fn new_future(
350        cancellation_token: &CancellationToken,
351    ) -> tokio::sync::futures::Notified<'static> {
352        let inner_ptr = Arc::as_ptr(&cancellation_token.inner);
353        // SAFETY: The `Arc::as_ptr` method guarantees that `inner_ptr` remains
354        // valid until the strong count of the Arc drops to zero, and the caller
355        // guarantees that they will drop the future before that happens.
356        (*inner_ptr).notified()
357    }
358}
359
impl Future for WaitForCancellationFutureOwned {
    type Output = ();

    /// Resolves once the owned token reports cancellation. Otherwise it waits
    /// on the token's notification, creating a fresh listener each time one
    /// completes.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let mut this = self.project();

        loop {
            // An already-cancelled token completes immediately.
            if this.cancellation_token.is_cancelled() {
                return Poll::Ready(());
            }

            // No wakeups can be lost here because there is always a call to
            // `is_cancelled` between the creation of the future and the call to
            // `poll`, and the code that sets the cancelled flag does so before
            // waking the `Notified`.
            if this.future.as_mut().poll(cx).is_pending() {
                return Poll::Pending;
            }

            // The listener completed: replace it with a new one and loop back
            // to re-check the cancelled flag.
            //
            // SAFETY: `new_future` requires the returned future to be dropped
            // before `cancellation_token`. The `future` field is declared
            // before `cancellation_token`, so it is dropped first.
            this.future.set(MaybeDangling::new(unsafe {
                Self::new_future(this.cancellation_token)
            }));
        }
    }
}