tokio_util/task/
join_map.rs

1use hashbrown::hash_map::RawEntryMut;
2use hashbrown::HashMap;
3use std::borrow::Borrow;
4use std::collections::hash_map::RandomState;
5use std::fmt;
6use std::future::Future;
7use std::hash::{BuildHasher, Hash, Hasher};
8use std::marker::PhantomData;
9use tokio::runtime::Handle;
10use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
11
12/// A collection of tasks spawned on a Tokio runtime, associated with hash map
13/// keys.
14///
15/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the
/// addition of a set of keys associated with each task. These keys allow
/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the
/// `JoinMap` based on their keys, or [test whether a task corresponding to a
19/// given key exists][contains] in the `JoinMap`.
20///
21/// In addition, when tasks in the `JoinMap` complete, they will return the
22/// associated key along with the value returned by the task, if any.
23///
24/// A `JoinMap` can be used to await the completion of some or all of the tasks
25/// in the map. The map is not ordered, and the tasks will be returned in the
26/// order they complete.
27///
28/// All of the tasks must have the same return type `V`.
29///
30/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted.
31///
32/// **Note**: This type depends on Tokio's [unstable API][unstable]. See [the
33/// documentation on unstable features][unstable] for details on how to enable
34/// Tokio's unstable features.
35///
36/// # Examples
37///
38/// Spawn multiple tasks and wait for them:
39///
40/// ```
41/// use tokio_util::task::JoinMap;
42///
43/// #[tokio::main]
44/// async fn main() {
45///     let mut map = JoinMap::new();
46///
47///     for i in 0..10 {
48///         // Spawn a task on the `JoinMap` with `i` as its key.
49///         map.spawn(i, async move { /* ... */ });
50///     }
51///
52///     let mut seen = [false; 10];
53///
54///     // When a task completes, `join_next` returns the task's key along
55///     // with its output.
56///     while let Some((key, res)) = map.join_next().await {
57///         seen[key] = true;
58///         assert!(res.is_ok(), "task {} completed successfully!", key);
59///     }
60///
61///     for i in 0..10 {
62///         assert!(seen[i]);
63///     }
64/// }
65/// ```
66///
67/// Cancel tasks based on their keys:
68///
69/// ```
70/// use tokio_util::task::JoinMap;
71///
72/// #[tokio::main]
73/// async fn main() {
74///     let mut map = JoinMap::new();
75///
76///     map.spawn("hello world", async move { /* ... */ });
77///     map.spawn("goodbye world", async move { /* ... */});
78///
79///     // Look up the "goodbye world" task in the map and abort it.
80///     let aborted = map.abort("goodbye world");
81///
82///     // `JoinMap::abort` returns `true` if a task existed for the
83///     // provided key.
84///     assert!(aborted);
85///
86///     while let Some((key, res)) = map.join_next().await {
87///         if key == "goodbye world" {
88///             // The aborted task should complete with a cancelled `JoinError`.
89///             assert!(res.unwrap_err().is_cancelled());
90///         } else {
91///             // Other tasks should complete normally.
92///             assert!(res.is_ok());
93///         }
94///     }
95/// }
96/// ```
97///
98/// [`JoinSet`]: tokio::task::JoinSet
99/// [unstable]: tokio#unstable-features
100/// [abort]: fn@Self::abort
101/// [abort_matching]: fn@Self::abort_matching
102/// [contains]: fn@Self::contains_key
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
pub struct JoinMap<K, V, S = RandomState> {
    /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
    /// indexed by their keys and task IDs.
    ///
    /// The [`Key`] type contains both the task's `K`-typed key provided when
    /// spawning tasks, and the task's ID. The IDs are stored here to resolve
    /// hash collisions when looking up tasks based on their pre-computed hash
    /// (as stored in the `hashes_by_task` map).
    ///
    /// This map is kept in sync with `hashes_by_task`: both are expected to
    /// hold the same number of entries, which `len` checks in debug builds.
    tasks_by_key: HashMap<Key<K>, AbortHandle, S>,

    /// A map from task IDs to the hash of the key associated with that task.
    ///
    /// This map is used to perform reverse lookups of tasks in the
    /// `tasks_by_key` map based on their task IDs. When a task terminates, the
    /// ID is provided to us by the `JoinSet`, so we can look up the hash value
    /// of that task's key, and then remove it from the `tasks_by_key` map using
    /// the raw hash code, resolving collisions by comparing task IDs.
    ///
    /// The stored `u64` is the value produced by hashing the `K`-typed key
    /// with `tasks_by_key`'s hasher, so it can be passed straight to that
    /// map's raw-entry API.
    hashes_by_task: HashMap<Id, u64, S>,

    /// The [`JoinSet`] that awaits the completion of tasks spawned on this
    /// `JoinMap`.
    ///
    /// Every `spawn*` method spawns through this set and then records the
    /// returned [`AbortHandle`] in the two maps above.
    tasks: JoinSet<V>,
}
127
128/// A [`JoinMap`] key.
129///
130/// This holds both a `K`-typed key (the actual key as seen by the user), _and_
131/// a task ID, so that hash collisions between `K`-typed keys can be resolved
132/// using either `K`'s `Eq` impl *or* by checking the task IDs.
133///
134/// This allows looking up a task using either an actual key (such as when the
135/// user queries the map with a key), *or* using a task ID and a hash (such as
136/// when removing completed tasks from the map).
#[derive(Debug)]
struct Key<K> {
    /// The user-provided key. This is the only field that takes part in
    /// `Hash` and `PartialEq` for `Key`.
    key: K,
    /// The ID of the task spawned for `key`. It is used to pick out one
    /// specific entry when looking a task up by a pre-computed hash.
    id: Id,
}
142
143impl<K, V> JoinMap<K, V> {
144    /// Creates a new empty `JoinMap`.
145    ///
146    /// The `JoinMap` is initially created with a capacity of 0, so it will not
147    /// allocate until a task is first spawned on it.
148    ///
149    /// # Examples
150    ///
151    /// ```
152    /// use tokio_util::task::JoinMap;
153    /// let map: JoinMap<&str, i32> = JoinMap::new();
154    /// ```
155    #[inline]
156    #[must_use]
157    pub fn new() -> Self {
158        Self::with_hasher(RandomState::new())
159    }
160
161    /// Creates an empty `JoinMap` with the specified capacity.
162    ///
163    /// The `JoinMap` will be able to hold at least `capacity` tasks without
164    /// reallocating.
165    ///
166    /// # Examples
167    ///
168    /// ```
169    /// use tokio_util::task::JoinMap;
170    /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10);
171    /// ```
172    #[inline]
173    #[must_use]
174    pub fn with_capacity(capacity: usize) -> Self {
175        JoinMap::with_capacity_and_hasher(capacity, Default::default())
176    }
177}
178
179impl<K, V, S: Clone> JoinMap<K, V, S> {
180    /// Creates an empty `JoinMap` which will use the given hash builder to hash
181    /// keys.
182    ///
183    /// The created map has the default initial capacity.
184    ///
185    /// Warning: `hash_builder` is normally randomly generated, and
186    /// is designed to allow `JoinMap` to be resistant to attacks that
187    /// cause many collisions and very poor performance. Setting it
188    /// manually using this function can expose a DoS attack vector.
189    ///
190    /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
191    /// the `JoinMap` to be useful, see its documentation for details.
192    #[inline]
193    #[must_use]
194    pub fn with_hasher(hash_builder: S) -> Self {
195        Self::with_capacity_and_hasher(0, hash_builder)
196    }
197
198    /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder`
199    /// to hash the keys.
200    ///
201    /// The `JoinMap` will be able to hold at least `capacity` elements without
202    /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate.
203    ///
204    /// Warning: `hash_builder` is normally randomly generated, and
205    /// is designed to allow HashMaps to be resistant to attacks that
206    /// cause many collisions and very poor performance. Setting it
207    /// manually using this function can expose a DoS attack vector.
208    ///
209    /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
210    /// the `JoinMap`to be useful, see its documentation for details.
211    ///
212    /// # Examples
213    ///
214    /// ```
215    /// # #[tokio::main]
216    /// # async fn main() {
217    /// use tokio_util::task::JoinMap;
218    /// use std::collections::hash_map::RandomState;
219    ///
220    /// let s = RandomState::new();
221    /// let mut map = JoinMap::with_capacity_and_hasher(10, s);
222    /// map.spawn(1, async move { "hello world!" });
223    /// # }
224    /// ```
225    #[inline]
226    #[must_use]
227    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
228        Self {
229            tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()),
230            hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
231            tasks: JoinSet::new(),
232        }
233    }
234
235    /// Returns the number of tasks currently in the `JoinMap`.
236    pub fn len(&self) -> usize {
237        let len = self.tasks_by_key.len();
238        debug_assert_eq!(len, self.hashes_by_task.len());
239        len
240    }
241
242    /// Returns whether the `JoinMap` is empty.
243    pub fn is_empty(&self) -> bool {
244        let empty = self.tasks_by_key.is_empty();
245        debug_assert_eq!(empty, self.hashes_by_task.is_empty());
246        empty
247    }
248
249    /// Returns the number of tasks the map can hold without reallocating.
250    ///
251    /// This number is a lower bound; the `JoinMap` might be able to hold
252    /// more, but is guaranteed to be able to hold at least this many.
253    ///
254    /// # Examples
255    ///
256    /// ```
257    /// use tokio_util::task::JoinMap;
258    ///
259    /// let map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
260    /// assert!(map.capacity() >= 100);
261    /// ```
262    #[inline]
263    pub fn capacity(&self) -> usize {
264        let capacity = self.tasks_by_key.capacity();
265        debug_assert_eq!(capacity, self.hashes_by_task.capacity());
266        capacity
267    }
268}
269
270impl<K, V, S> JoinMap<K, V, S>
271where
272    K: Hash + Eq,
273    V: 'static,
274    S: BuildHasher,
275{
276    /// Spawn the provided task and store it in this `JoinMap` with the provided
277    /// key.
278    ///
279    /// If a task previously existed in the `JoinMap` for this key, that task
280    /// will be cancelled and replaced with the new one. The previous task will
281    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
282    /// *not* return a cancelled [`JoinError`] for that task.
283    ///
284    /// # Panics
285    ///
286    /// This method panics if called outside of a Tokio runtime.
287    ///
288    /// [`join_next`]: Self::join_next
289    #[track_caller]
290    pub fn spawn<F>(&mut self, key: K, task: F)
291    where
292        F: Future<Output = V>,
293        F: Send + 'static,
294        V: Send,
295    {
296        let task = self.tasks.spawn(task);
297        self.insert(key, task)
298    }
299
300    /// Spawn the provided task on the provided runtime and store it in this
301    /// `JoinMap` with the provided key.
302    ///
303    /// If a task previously existed in the `JoinMap` for this key, that task
304    /// will be cancelled and replaced with the new one. The previous task will
305    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
306    /// *not* return a cancelled [`JoinError`] for that task.
307    ///
308    /// [`join_next`]: Self::join_next
309    #[track_caller]
310    pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle)
311    where
312        F: Future<Output = V>,
313        F: Send + 'static,
314        V: Send,
315    {
316        let task = self.tasks.spawn_on(task, handle);
317        self.insert(key, task);
318    }
319
320    /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
321    /// key.
322    ///
323    /// If a task previously existed in the `JoinMap` for this key, that task
324    /// will be cancelled and replaced with the new one. The previous task will
325    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
326    /// *not* return a cancelled [`JoinError`] for that task.
327    ///
328    /// Note that blocking tasks cannot be cancelled after execution starts.
329    /// Replaced blocking tasks will still run to completion if the task has begun
330    /// to execute when it is replaced. A blocking task which is replaced before
331    /// it has been scheduled on a blocking worker thread will be cancelled.
332    ///
333    /// # Panics
334    ///
335    /// This method panics if called outside of a Tokio runtime.
336    ///
337    /// [`join_next`]: Self::join_next
338    #[track_caller]
339    pub fn spawn_blocking<F>(&mut self, key: K, f: F)
340    where
341        F: FnOnce() -> V,
342        F: Send + 'static,
343        V: Send,
344    {
345        let task = self.tasks.spawn_blocking(f);
346        self.insert(key, task)
347    }
348
349    /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
350    /// `JoinMap` with the provided key.
351    ///
352    /// If a task previously existed in the `JoinMap` for this key, that task
353    /// will be cancelled and replaced with the new one. The previous task will
354    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
355    /// *not* return a cancelled [`JoinError`] for that task.
356    ///
357    /// Note that blocking tasks cannot be cancelled after execution starts.
358    /// Replaced blocking tasks will still run to completion if the task has begun
359    /// to execute when it is replaced. A blocking task which is replaced before
360    /// it has been scheduled on a blocking worker thread will be cancelled.
361    ///
362    /// [`join_next`]: Self::join_next
363    #[track_caller]
364    pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle)
365    where
366        F: FnOnce() -> V,
367        F: Send + 'static,
368        V: Send,
369    {
370        let task = self.tasks.spawn_blocking_on(f, handle);
371        self.insert(key, task);
372    }
373
374    /// Spawn the provided task on the current [`LocalSet`] and store it in this
375    /// `JoinMap` with the provided key.
376    ///
377    /// If a task previously existed in the `JoinMap` for this key, that task
378    /// will be cancelled and replaced with the new one. The previous task will
379    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
380    /// *not* return a cancelled [`JoinError`] for that task.
381    ///
382    /// # Panics
383    ///
384    /// This method panics if it is called outside of a `LocalSet`.
385    ///
386    /// [`LocalSet`]: tokio::task::LocalSet
387    /// [`join_next`]: Self::join_next
388    #[track_caller]
389    pub fn spawn_local<F>(&mut self, key: K, task: F)
390    where
391        F: Future<Output = V>,
392        F: 'static,
393    {
394        let task = self.tasks.spawn_local(task);
395        self.insert(key, task);
396    }
397
398    /// Spawn the provided task on the provided [`LocalSet`] and store it in
399    /// this `JoinMap` with the provided key.
400    ///
401    /// If a task previously existed in the `JoinMap` for this key, that task
402    /// will be cancelled and replaced with the new one. The previous task will
403    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
404    /// *not* return a cancelled [`JoinError`] for that task.
405    ///
406    /// [`LocalSet`]: tokio::task::LocalSet
407    /// [`join_next`]: Self::join_next
408    #[track_caller]
409    pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet)
410    where
411        F: Future<Output = V>,
412        F: 'static,
413    {
414        let task = self.tasks.spawn_local_on(task, local_set);
415        self.insert(key, task)
416    }
417
    /// Records a freshly spawned task under `key`, replacing and aborting any
    /// task previously stored for an equal key.
    ///
    /// `abort` is the [`AbortHandle`] of the new task. Its task ID is stored
    /// next to `key` in `tasks_by_key`, and the key's hash is recorded under
    /// that ID in `hashes_by_task`.
    fn insert(&mut self, key: K, abort: AbortHandle) {
        // Hash the user key once: the same value drives the raw-entry lookup
        // below and is stored for later reverse lookups by task ID.
        let hash = self.hash(&key);
        let id = abort.id();
        let map_key = Key { id, key };

        // Insert the new key into the map of tasks by keys.
        // The lookup compares only the `K`-typed keys, so an existing task
        // with an equal key is found regardless of its task ID.
        let entry = self
            .tasks_by_key
            .raw_entry_mut()
            .from_hash(hash, |k| k.key == map_key.key);
        match entry {
            RawEntryMut::Occupied(mut occ) => {
                // There was a previous task spawned with the same key! Cancel
                // that task, and remove its ID from the map of hashes by task IDs.
                // Swapping in the new `Key` yields the old one, from which the
                // previous task's ID is taken.
                let Key { id: prev_id, .. } = occ.insert_key(map_key);
                // Swapping in the new handle yields the previous task's
                // handle, which is aborted immediately.
                occ.insert(abort).abort();
                // Dropping the old ID from `hashes_by_task` means
                // `remove_by_id` will return `None` for the replaced task.
                let _prev_hash = self.hashes_by_task.remove(&prev_id);
                debug_assert_eq!(Some(hash), _prev_hash);
            }
            RawEntryMut::Vacant(vac) => {
                // No task is stored for this key yet: add a new entry.
                vac.insert(map_key, abort);
            }
        };

        // Associate the key's hash with this task's ID, for looking up tasks by ID.
        let _prev = self.hashes_by_task.insert(id, hash);
        debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
    }
446
447    /// Waits until one of the tasks in the map completes and returns its
448    /// output, along with the key corresponding to that task.
449    ///
450    /// Returns `None` if the map is empty.
451    ///
452    /// # Cancel Safety
453    ///
454    /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`]
455    /// statement and some other branch completes first, it is guaranteed that no tasks were
456    /// removed from this `JoinMap`.
457    ///
458    /// # Returns
459    ///
460    /// This function returns:
461    ///
462    ///  * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has
463    ///    completed. The `value` is the return value of that ask, and `key` is
464    ///    the key associated with the task.
465    ///  * `Some((key, Err(err))` if one of the tasks in this `JoinMap` has
466    ///    panicked or been aborted. `key` is the key associated  with the task
467    ///    that panicked or was aborted.
468    ///  * `None` if the `JoinMap` is empty.
469    ///
470    /// [`tokio::select!`]: tokio::select
471    pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
472        loop {
473            let (res, id) = match self.tasks.join_next_with_id().await {
474                Some(Ok((id, output))) => (Ok(output), id),
475                Some(Err(e)) => {
476                    let id = e.id();
477                    (Err(e), id)
478                }
479                None => return None,
480            };
481            if let Some(key) = self.remove_by_id(id) {
482                break Some((key, res));
483            }
484        }
485    }
486
487    /// Aborts all tasks and waits for them to finish shutting down.
488    ///
489    /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
490    /// a loop until it returns `None`.
491    ///
492    /// This method ignores any panics in the tasks shutting down. When this call returns, the
493    /// `JoinMap` will be empty.
494    ///
495    /// [`abort_all`]: fn@Self::abort_all
496    /// [`join_next`]: fn@Self::join_next
497    pub async fn shutdown(&mut self) {
498        self.abort_all();
499        while self.join_next().await.is_some() {}
500    }
501
502    /// Abort the task corresponding to the provided `key`.
503    ///
504    /// If this `JoinMap` contains a task corresponding to `key`, this method
505    /// will abort that task and return `true`. Otherwise, if no task exists for
506    /// `key`, this method returns `false`.
507    ///
508    /// # Examples
509    ///
510    /// Aborting a task by key:
511    ///
512    /// ```
513    /// use tokio_util::task::JoinMap;
514    ///
515    /// # #[tokio::main]
516    /// # async fn main() {
517    /// let mut map = JoinMap::new();
518    ///
519    /// map.spawn("hello world", async move { /* ... */ });
520    /// map.spawn("goodbye world", async move { /* ... */});
521    ///
522    /// // Look up the "goodbye world" task in the map and abort it.
523    /// map.abort("goodbye world");
524    ///
525    /// while let Some((key, res)) = map.join_next().await {
526    ///     if key == "goodbye world" {
527    ///         // The aborted task should complete with a cancelled `JoinError`.
528    ///         assert!(res.unwrap_err().is_cancelled());
529    ///     } else {
530    ///         // Other tasks should complete normally.
531    ///         assert!(res.is_ok());
532    ///     }
533    /// }
534    /// # }
535    /// ```
536    ///
537    /// `abort` returns `true` if a task was aborted:
538    /// ```
539    /// use tokio_util::task::JoinMap;
540    ///
541    /// # #[tokio::main]
542    /// # async fn main() {
543    /// let mut map = JoinMap::new();
544    ///
545    /// map.spawn("hello world", async move { /* ... */ });
546    /// map.spawn("goodbye world", async move { /* ... */});
547    ///
548    /// // A task for the key "goodbye world" should exist in the map:
549    /// assert!(map.abort("goodbye world"));
550    ///
551    /// // Aborting a key that does not exist will return `false`:
552    /// assert!(!map.abort("goodbye universe"));
553    /// # }
554    /// ```
555    pub fn abort<Q: ?Sized>(&mut self, key: &Q) -> bool
556    where
557        Q: Hash + Eq,
558        K: Borrow<Q>,
559    {
560        match self.get_by_key(key) {
561            Some((_, handle)) => {
562                handle.abort();
563                true
564            }
565            None => false,
566        }
567    }
568
569    /// Aborts all tasks with keys matching `predicate`.
570    ///
571    /// `predicate` is a function called with a reference to each key in the
572    /// map. If it returns `true` for a given key, the corresponding task will
573    /// be cancelled.
574    ///
575    /// # Examples
576    /// ```
577    /// use tokio_util::task::JoinMap;
578    ///
579    /// # // use the current thread rt so that spawned tasks don't
580    /// # // complete in the background before they can be aborted.
581    /// # #[tokio::main(flavor = "current_thread")]
582    /// # async fn main() {
583    /// let mut map = JoinMap::new();
584    ///
585    /// map.spawn("hello world", async move {
586    ///     // ...
587    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
588    /// });
589    /// map.spawn("goodbye world", async move {
590    ///     // ...
591    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
592    /// });
593    /// map.spawn("hello san francisco", async move {
594    ///     // ...
595    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
596    /// });
597    /// map.spawn("goodbye universe", async move {
598    ///     // ...
599    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
600    /// });
601    ///
602    /// // Abort all tasks whose keys begin with "goodbye"
603    /// map.abort_matching(|key| key.starts_with("goodbye"));
604    ///
605    /// let mut seen = 0;
606    /// while let Some((key, res)) = map.join_next().await {
607    ///     seen += 1;
608    ///     if key.starts_with("goodbye") {
609    ///         // The aborted task should complete with a cancelled `JoinError`.
610    ///         assert!(res.unwrap_err().is_cancelled());
611    ///     } else {
612    ///         // Other tasks should complete normally.
613    ///         assert!(key.starts_with("hello"));
614    ///         assert!(res.is_ok());
615    ///     }
616    /// }
617    ///
618    /// // All spawned tasks should have completed.
619    /// assert_eq!(seen, 4);
620    /// # }
621    /// ```
622    pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) {
623        // Note: this method iterates over the tasks and keys *without* removing
624        // any entries, so that the keys from aborted tasks can still be
625        // returned when calling `join_next` in the future.
626        for (Key { ref key, .. }, task) in &self.tasks_by_key {
627            if predicate(key) {
628                task.abort();
629            }
630        }
631    }
632
633    /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order.
634    ///
635    /// If a task has completed, but its output hasn't yet been consumed by a
636    /// call to [`join_next`], this method will still return its key.
637    ///
638    /// [`join_next`]: fn@Self::join_next
639    pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
640        JoinMapKeys {
641            iter: self.tasks_by_key.keys(),
642            _value: PhantomData,
643        }
644    }
645
646    /// Returns `true` if this `JoinMap` contains a task for the provided key.
647    ///
648    /// If the task has completed, but its output hasn't yet been consumed by a
649    /// call to [`join_next`], this method will still return `true`.
650    ///
651    /// [`join_next`]: fn@Self::join_next
652    pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool
653    where
654        Q: Hash + Eq,
655        K: Borrow<Q>,
656    {
657        self.get_by_key(key).is_some()
658    }
659
660    /// Returns `true` if this `JoinMap` contains a task with the provided
661    /// [task ID].
662    ///
663    /// If the task has completed, but its output hasn't yet been consumed by a
664    /// call to [`join_next`], this method will still return `true`.
665    ///
666    /// [`join_next`]: fn@Self::join_next
667    /// [task ID]: tokio::task::Id
668    pub fn contains_task(&self, task: &Id) -> bool {
669        self.get_by_id(task).is_some()
670    }
671
672    /// Reserves capacity for at least `additional` more tasks to be spawned
673    /// on this `JoinMap` without reallocating for the map of task keys. The
674    /// collection may reserve more space to avoid frequent reallocations.
675    ///
676    /// Note that spawning a task will still cause an allocation for the task
677    /// itself.
678    ///
679    /// # Panics
680    ///
681    /// Panics if the new allocation size overflows [`usize`].
682    ///
683    /// # Examples
684    ///
685    /// ```
686    /// use tokio_util::task::JoinMap;
687    ///
688    /// let mut map: JoinMap<&str, i32> = JoinMap::new();
689    /// map.reserve(10);
690    /// ```
691    #[inline]
692    pub fn reserve(&mut self, additional: usize) {
693        self.tasks_by_key.reserve(additional);
694        self.hashes_by_task.reserve(additional);
695    }
696
697    /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop
698    /// down as much as possible while maintaining the internal rules
699    /// and possibly leaving some space in accordance with the resize policy.
700    ///
701    /// # Examples
702    ///
703    /// ```
704    /// # #[tokio::main]
705    /// # async fn main() {
706    /// use tokio_util::task::JoinMap;
707    ///
708    /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
709    /// map.spawn(1, async move { 2 });
710    /// map.spawn(3, async move { 4 });
711    /// assert!(map.capacity() >= 100);
712    /// map.shrink_to_fit();
713    /// assert!(map.capacity() >= 2);
714    /// # }
715    /// ```
716    #[inline]
717    pub fn shrink_to_fit(&mut self) {
718        self.hashes_by_task.shrink_to_fit();
719        self.tasks_by_key.shrink_to_fit();
720    }
721
722    /// Shrinks the capacity of the map with a lower limit. It will drop
723    /// down no lower than the supplied limit while maintaining the internal rules
724    /// and possibly leaving some space in accordance with the resize policy.
725    ///
726    /// If the current capacity is less than the lower limit, this is a no-op.
727    ///
728    /// # Examples
729    ///
730    /// ```
731    /// # #[tokio::main]
732    /// # async fn main() {
733    /// use tokio_util::task::JoinMap;
734    ///
735    /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
736    /// map.spawn(1, async move { 2 });
737    /// map.spawn(3, async move { 4 });
738    /// assert!(map.capacity() >= 100);
739    /// map.shrink_to(10);
740    /// assert!(map.capacity() >= 10);
741    /// map.shrink_to(0);
742    /// assert!(map.capacity() >= 2);
743    /// # }
744    /// ```
745    #[inline]
746    pub fn shrink_to(&mut self, min_capacity: usize) {
747        self.hashes_by_task.shrink_to(min_capacity);
748        self.tasks_by_key.shrink_to(min_capacity)
749    }
750
751    /// Look up a task in the map by its key, returning the key and abort handle.
752    fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)>
753    where
754        Q: Hash + Eq,
755        K: Borrow<Q>,
756    {
757        let hash = self.hash(key);
758        self.tasks_by_key
759            .raw_entry()
760            .from_hash(hash, |k| k.key.borrow() == key)
761    }
762
763    /// Look up a task in the map by its task ID, returning the key and abort handle.
764    fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> {
765        let hash = self.hashes_by_task.get(id)?;
766        self.tasks_by_key
767            .raw_entry()
768            .from_hash(*hash, |k| &k.id == id)
769    }
770
771    /// Remove a task from the map by ID, returning the key for that task.
772    fn remove_by_id(&mut self, id: Id) -> Option<K> {
773        // Get the hash for the given ID.
774        let hash = self.hashes_by_task.remove(&id)?;
775
776        // Remove the entry for that hash.
777        let entry = self
778            .tasks_by_key
779            .raw_entry_mut()
780            .from_hash(hash, |k| k.id == id);
781        let (Key { id: _key_id, key }, handle) = match entry {
782            RawEntryMut::Occupied(entry) => entry.remove_entry(),
783            _ => return None,
784        };
785        debug_assert_eq!(_key_id, id);
786        debug_assert_eq!(id, handle.id());
787        self.hashes_by_task.remove(&id);
788        Some(key)
789    }
790
791    /// Returns the hash for a given key.
792    #[inline]
793    fn hash<Q: ?Sized>(&self, key: &Q) -> u64
794    where
795        Q: Hash,
796    {
797        let mut hasher = self.tasks_by_key.hasher().build_hasher();
798        key.hash(&mut hasher);
799        hasher.finish()
800    }
801}
802
impl<K, V, S> JoinMap<K, V, S>
where
    V: 'static,
{
    /// Aborts all tasks on this `JoinMap`.
    ///
    /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete
    /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty.
    pub fn abort_all(&mut self) {
        self.tasks.abort_all()
    }

    /// Removes all tasks from this `JoinMap` without aborting them.
    ///
    /// The tasks removed by this call will continue to run in the background even if the `JoinMap`
    /// is dropped.
    ///
    /// The key and task ID bookkeeping is cleared as well, so detached tasks can no longer be
    /// looked up or aborted by key through this `JoinMap`, and the `JoinMap` is empty afterwards.
    pub fn detach_all(&mut self) {
        self.tasks.detach_all();
        // Forget the keys and IDs of the detached tasks so that both lookup
        // tables stay consistent with the now-empty `JoinSet`.
        self.tasks_by_key.clear();
        self.hashes_by_task.clear();
    }
}
825
826// Hand-written `fmt::Debug` implementation in order to avoid requiring `V:
827// Debug`, since no value is ever actually stored in the map.
828impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
829    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
830        // format the task keys and abort handles a little nicer by just
831        // printing the key and task ID pairs, without format the `Key` struct
832        // itself or the `AbortHandle`, which would just format the task's ID
833        // again.
834        struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>);
835        impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> {
836            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
837                f.debug_map()
838                    .entries(self.0.keys().map(|Key { key, id }| (key, id)))
839                    .finish()
840            }
841        }
842
843        f.debug_struct("JoinMap")
844            // The `tasks_by_key` map is the only one that contains information
845            // that's really worth formatting for the user, since it contains
846            // the tasks' keys and IDs. The other fields are basically
847            // implementation details.
848            .field("tasks", &KeySet(&self.tasks_by_key))
849            .finish()
850    }
851}
852
853impl<K, V> Default for JoinMap<K, V> {
854    fn default() -> Self {
855        Self::new()
856    }
857}
858
859// === impl Key ===
860
861impl<K: Hash> Hash for Key<K> {
862    // Don't include the task ID in the hash.
863    #[inline]
864    fn hash<H: Hasher>(&self, hasher: &mut H) {
865        self.key.hash(hasher);
866    }
867}
868
869// Because we override `Hash` for this type, we must also override the
870// `PartialEq` impl, so that all instances with the same hash are equal.
871impl<K: PartialEq> PartialEq for Key<K> {
872    #[inline]
873    fn eq(&self, other: &Self) -> bool {
874        self.key == other.key
875    }
876}
877
// Marker impl: `Key` equality compares only the `K`-typed keys, so it is a
// full equivalence relation whenever `K: Eq`.
impl<K: Eq> Eq for Key<K> {}
879
/// An iterator over the keys of a [`JoinMap`].
///
/// Values of this type are created by [`JoinMap::keys`]. The iterator yields
/// `&K` references to the user-provided keys, in arbitrary order.
#[derive(Debug, Clone)]
pub struct JoinMapKeys<'a, K, V> {
    /// The iterator over the inner map's [`Key`]s; each item is narrowed to
    /// its `K`-typed key before being yielded.
    iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>,
    /// To make it easier to change `JoinMap` in the future, keep V as a generic
    /// parameter.
    _value: PhantomData<&'a V>,
}
888
889impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
890    type Item = &'a K;
891
892    fn next(&mut self) -> Option<&'a K> {
893        self.iter.next().map(|key| &key.key)
894    }
895
896    fn size_hint(&self) -> (usize, Option<usize>) {
897        self.iter.size_hint()
898    }
899}
900
901impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
902    fn len(&self) -> usize {
903        self.iter.len()
904    }
905}
906
// NOTE(review): `next` delegates straight to the wrapped `hashbrown` `Keys`
// iterator, so this marker presumably relies on that iterator continuing to
// return `None` once exhausted — confirm against hashbrown's documentation.
impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}