tokio_util/task/
join_map.rs

1use hashbrown::hash_map::RawEntryMut;
2use hashbrown::HashMap;
3use std::borrow::Borrow;
4use std::collections::hash_map::RandomState;
5use std::fmt;
6use std::future::Future;
7use std::hash::{BuildHasher, Hash, Hasher};
8use std::marker::PhantomData;
9use tokio::runtime::Handle;
10use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
11
12/// A collection of tasks spawned on a Tokio runtime, associated with hash map
13/// keys.
14///
15/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the
16/// addition of a  set of keys associated with each task. These keys allow
17/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the
18/// `JoinMap` based on   their keys, or [test whether a task corresponding to a
19/// given key exists][contains] in the `JoinMap`.
20///
21/// In addition, when tasks in the `JoinMap` complete, they will return the
22/// associated key along with the value returned by the task, if any.
23///
24/// A `JoinMap` can be used to await the completion of some or all of the tasks
25/// in the map. The map is not ordered, and the tasks will be returned in the
26/// order they complete.
27///
28/// All of the tasks must have the same return type `V`.
29///
30/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted.
31///
32/// **Note**: This type depends on Tokio's [unstable API][unstable]. See [the
33/// documentation on unstable features][unstable] for details on how to enable
34/// Tokio's unstable features.
35///
36/// # Examples
37///
38/// Spawn multiple tasks and wait for them:
39///
40/// ```
41/// use tokio_util::task::JoinMap;
42///
43/// #[tokio::main]
44/// async fn main() {
45///     let mut map = JoinMap::new();
46///
47///     for i in 0..10 {
48///         // Spawn a task on the `JoinMap` with `i` as its key.
49///         map.spawn(i, async move { /* ... */ });
50///     }
51///
52///     let mut seen = [false; 10];
53///
54///     // When a task completes, `join_next` returns the task's key along
55///     // with its output.
56///     while let Some((key, res)) = map.join_next().await {
57///         seen[key] = true;
58///         assert!(res.is_ok(), "task {} completed successfully!", key);
59///     }
60///
61///     for i in 0..10 {
62///         assert!(seen[i]);
63///     }
64/// }
65/// ```
66///
67/// Cancel tasks based on their keys:
68///
69/// ```
70/// use tokio_util::task::JoinMap;
71///
72/// #[tokio::main]
73/// async fn main() {
74///     let mut map = JoinMap::new();
75///
76///     map.spawn("hello world", async move { /* ... */ });
77///     map.spawn("goodbye world", async move { /* ... */});
78///
79///     // Look up the "goodbye world" task in the map and abort it.
80///     let aborted = map.abort("goodbye world");
81///
82///     // `JoinMap::abort` returns `true` if a task existed for the
83///     // provided key.
84///     assert!(aborted);
85///
86///     while let Some((key, res)) = map.join_next().await {
87///         if key == "goodbye world" {
88///             // The aborted task should complete with a cancelled `JoinError`.
89///             assert!(res.unwrap_err().is_cancelled());
90///         } else {
91///             // Other tasks should complete normally.
92///             assert!(res.is_ok());
93///         }
94///     }
95/// }
96/// ```
97///
98/// [`JoinSet`]: tokio::task::JoinSet
99/// [unstable]: tokio#unstable-features
100/// [abort]: fn@Self::abort
101/// [abort_matching]: fn@Self::abort_matching
102/// [contains]: fn@Self::contains_key
103#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
104pub struct JoinMap<K, V, S = RandomState> {
105    /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
106    /// indexed by their keys and task IDs.
107    ///
108    /// The [`Key`] type contains both the task's `K`-typed key provided when
109    /// spawning tasks, and the task's IDs. The IDs are stored here to resolve
110    /// hash collisions when looking up tasks based on their pre-computed hash
111    /// (as stored in the `hashes_by_task` map).
112    tasks_by_key: HashMap<Key<K>, AbortHandle, S>,
113
114    /// A map from task IDs to the hash of the key associated with that task.
115    ///
116    /// This map is used to perform reverse lookups of tasks in the
117    /// `tasks_by_key` map based on their task IDs. When a task terminates, the
118    /// ID is provided to us by the `JoinSet`, so we can look up the hash value
119    /// of that task's key, and then remove it from the `tasks_by_key` map using
120    /// the raw hash code, resolving collisions by comparing task IDs.
121    hashes_by_task: HashMap<Id, u64, S>,
122
123    /// The [`JoinSet`] that awaits the completion of tasks spawned on this
124    /// `JoinMap`.
125    tasks: JoinSet<V>,
126}
127
/// A [`JoinMap`] key.
///
/// This holds both a `K`-typed key (the actual key as seen by the user), _and_
/// a task ID, so that hash collisions between `K`-typed keys can be resolved
/// using either `K`'s `Eq` impl *or* by checking the task IDs.
///
/// This allows looking up a task using either an actual key (such as when the
/// user queries the map with a key), *or* using a task ID and a hash (such as
/// when removing completed tasks from the map).
///
/// The hand-written `Hash` and `PartialEq` impls for this type consider only
/// `key`; the `id` is compared explicitly by the raw-entry lookups that need it.
#[derive(Debug)]
struct Key<K> {
    /// The user-provided key the task was spawned with.
    key: K,
    /// The ID of the task that was spawned with `key` (taken from its
    /// `AbortHandle` when the task is inserted).
    id: Id,
}
142
143impl<K, V> JoinMap<K, V> {
144    /// Creates a new empty `JoinMap`.
145    ///
146    /// The `JoinMap` is initially created with a capacity of 0, so it will not
147    /// allocate until a task is first spawned on it.
148    ///
149    /// # Examples
150    ///
151    /// ```
152    /// use tokio_util::task::JoinMap;
153    /// let map: JoinMap<&str, i32> = JoinMap::new();
154    /// ```
155    #[inline]
156    #[must_use]
157    pub fn new() -> Self {
158        Self::with_hasher(RandomState::new())
159    }
160
161    /// Creates an empty `JoinMap` with the specified capacity.
162    ///
163    /// The `JoinMap` will be able to hold at least `capacity` tasks without
164    /// reallocating.
165    ///
166    /// # Examples
167    ///
168    /// ```
169    /// use tokio_util::task::JoinMap;
170    /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10);
171    /// ```
172    #[inline]
173    #[must_use]
174    pub fn with_capacity(capacity: usize) -> Self {
175        JoinMap::with_capacity_and_hasher(capacity, Default::default())
176    }
177}
178
179impl<K, V, S: Clone> JoinMap<K, V, S> {
180    /// Creates an empty `JoinMap` which will use the given hash builder to hash
181    /// keys.
182    ///
183    /// The created map has the default initial capacity.
184    ///
185    /// Warning: `hash_builder` is normally randomly generated, and
186    /// is designed to allow `JoinMap` to be resistant to attacks that
187    /// cause many collisions and very poor performance. Setting it
188    /// manually using this function can expose a DoS attack vector.
189    ///
190    /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
191    /// the `JoinMap` to be useful, see its documentation for details.
192    #[inline]
193    #[must_use]
194    pub fn with_hasher(hash_builder: S) -> Self {
195        Self::with_capacity_and_hasher(0, hash_builder)
196    }
197
198    /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder`
199    /// to hash the keys.
200    ///
201    /// The `JoinMap` will be able to hold at least `capacity` elements without
202    /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate.
203    ///
204    /// Warning: `hash_builder` is normally randomly generated, and
205    /// is designed to allow HashMaps to be resistant to attacks that
206    /// cause many collisions and very poor performance. Setting it
207    /// manually using this function can expose a DoS attack vector.
208    ///
209    /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
210    /// the `JoinMap`to be useful, see its documentation for details.
211    ///
212    /// # Examples
213    ///
214    /// ```
215    /// # #[tokio::main]
216    /// # async fn main() {
217    /// use tokio_util::task::JoinMap;
218    /// use std::collections::hash_map::RandomState;
219    ///
220    /// let s = RandomState::new();
221    /// let mut map = JoinMap::with_capacity_and_hasher(10, s);
222    /// map.spawn(1, async move { "hello world!" });
223    /// # }
224    /// ```
225    #[inline]
226    #[must_use]
227    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
228        Self {
229            tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()),
230            hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
231            tasks: JoinSet::new(),
232        }
233    }
234
235    /// Returns the number of tasks currently in the `JoinMap`.
236    pub fn len(&self) -> usize {
237        let len = self.tasks_by_key.len();
238        debug_assert_eq!(len, self.hashes_by_task.len());
239        len
240    }
241
242    /// Returns whether the `JoinMap` is empty.
243    pub fn is_empty(&self) -> bool {
244        let empty = self.tasks_by_key.is_empty();
245        debug_assert_eq!(empty, self.hashes_by_task.is_empty());
246        empty
247    }
248
249    /// Returns the number of tasks the map can hold without reallocating.
250    ///
251    /// This number is a lower bound; the `JoinMap` might be able to hold
252    /// more, but is guaranteed to be able to hold at least this many.
253    ///
254    /// # Examples
255    ///
256    /// ```
257    /// use tokio_util::task::JoinMap;
258    ///
259    /// let map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
260    /// assert!(map.capacity() >= 100);
261    /// ```
262    #[inline]
263    pub fn capacity(&self) -> usize {
264        let capacity = self.tasks_by_key.capacity();
265        debug_assert_eq!(capacity, self.hashes_by_task.capacity());
266        capacity
267    }
268}
269
270impl<K, V, S> JoinMap<K, V, S>
271where
272    K: Hash + Eq,
273    V: 'static,
274    S: BuildHasher,
275{
276    /// Spawn the provided task and store it in this `JoinMap` with the provided
277    /// key.
278    ///
279    /// If a task previously existed in the `JoinMap` for this key, that task
280    /// will be cancelled and replaced with the new one. The previous task will
281    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
282    /// *not* return a cancelled [`JoinError`] for that task.
283    ///
284    /// # Panics
285    ///
286    /// This method panics if called outside of a Tokio runtime.
287    ///
288    /// [`join_next`]: Self::join_next
289    #[track_caller]
290    pub fn spawn<F>(&mut self, key: K, task: F)
291    where
292        F: Future<Output = V>,
293        F: Send + 'static,
294        V: Send,
295    {
296        let task = self.tasks.spawn(task);
297        self.insert(key, task)
298    }
299
300    /// Spawn the provided task on the provided runtime and store it in this
301    /// `JoinMap` with the provided key.
302    ///
303    /// If a task previously existed in the `JoinMap` for this key, that task
304    /// will be cancelled and replaced with the new one. The previous task will
305    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
306    /// *not* return a cancelled [`JoinError`] for that task.
307    ///
308    /// [`join_next`]: Self::join_next
309    #[track_caller]
310    pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle)
311    where
312        F: Future<Output = V>,
313        F: Send + 'static,
314        V: Send,
315    {
316        let task = self.tasks.spawn_on(task, handle);
317        self.insert(key, task);
318    }
319
320    /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
321    /// key.
322    ///
323    /// If a task previously existed in the `JoinMap` for this key, that task
324    /// will be cancelled and replaced with the new one. The previous task will
325    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
326    /// *not* return a cancelled [`JoinError`] for that task.
327    ///
328    /// Note that blocking tasks cannot be cancelled after execution starts.
329    /// Replaced blocking tasks will still run to completion if the task has begun
330    /// to execute when it is replaced. A blocking task which is replaced before
331    /// it has been scheduled on a blocking worker thread will be cancelled.
332    ///
333    /// # Panics
334    ///
335    /// This method panics if called outside of a Tokio runtime.
336    ///
337    /// [`join_next`]: Self::join_next
338    #[track_caller]
339    pub fn spawn_blocking<F>(&mut self, key: K, f: F)
340    where
341        F: FnOnce() -> V,
342        F: Send + 'static,
343        V: Send,
344    {
345        let task = self.tasks.spawn_blocking(f);
346        self.insert(key, task)
347    }
348
349    /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
350    /// `JoinMap` with the provided key.
351    ///
352    /// If a task previously existed in the `JoinMap` for this key, that task
353    /// will be cancelled and replaced with the new one. The previous task will
354    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
355    /// *not* return a cancelled [`JoinError`] for that task.
356    ///
357    /// Note that blocking tasks cannot be cancelled after execution starts.
358    /// Replaced blocking tasks will still run to completion if the task has begun
359    /// to execute when it is replaced. A blocking task which is replaced before
360    /// it has been scheduled on a blocking worker thread will be cancelled.
361    ///
362    /// [`join_next`]: Self::join_next
363    #[track_caller]
364    pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle)
365    where
366        F: FnOnce() -> V,
367        F: Send + 'static,
368        V: Send,
369    {
370        let task = self.tasks.spawn_blocking_on(f, handle);
371        self.insert(key, task);
372    }
373
374    /// Spawn the provided task on the current [`LocalSet`] and store it in this
375    /// `JoinMap` with the provided key.
376    ///
377    /// If a task previously existed in the `JoinMap` for this key, that task
378    /// will be cancelled and replaced with the new one. The previous task will
379    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
380    /// *not* return a cancelled [`JoinError`] for that task.
381    ///
382    /// # Panics
383    ///
384    /// This method panics if it is called outside of a `LocalSet`.
385    ///
386    /// [`LocalSet`]: tokio::task::LocalSet
387    /// [`join_next`]: Self::join_next
388    #[track_caller]
389    pub fn spawn_local<F>(&mut self, key: K, task: F)
390    where
391        F: Future<Output = V>,
392        F: 'static,
393    {
394        let task = self.tasks.spawn_local(task);
395        self.insert(key, task);
396    }
397
398    /// Spawn the provided task on the provided [`LocalSet`] and store it in
399    /// this `JoinMap` with the provided key.
400    ///
401    /// If a task previously existed in the `JoinMap` for this key, that task
402    /// will be cancelled and replaced with the new one. The previous task will
403    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
404    /// *not* return a cancelled [`JoinError`] for that task.
405    ///
406    /// [`LocalSet`]: tokio::task::LocalSet
407    /// [`join_next`]: Self::join_next
408    #[track_caller]
409    pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet)
410    where
411        F: Future<Output = V>,
412        F: 'static,
413    {
414        let task = self.tasks.spawn_local_on(task, local_set);
415        self.insert(key, task)
416    }
417
418    fn insert(&mut self, key: K, abort: AbortHandle) {
419        let hash = self.hash(&key);
420        let id = abort.id();
421        let map_key = Key { id, key };
422
423        // Insert the new key into the map of tasks by keys.
424        let entry = self
425            .tasks_by_key
426            .raw_entry_mut()
427            .from_hash(hash, |k| k.key == map_key.key);
428        match entry {
429            RawEntryMut::Occupied(mut occ) => {
430                // There was a previous task spawned with the same key! Cancel
431                // that task, and remove its ID from the map of hashes by task IDs.
432                let Key { id: prev_id, .. } = occ.insert_key(map_key);
433                occ.insert(abort).abort();
434                let _prev_hash = self.hashes_by_task.remove(&prev_id);
435                debug_assert_eq!(Some(hash), _prev_hash);
436            }
437            RawEntryMut::Vacant(vac) => {
438                vac.insert(map_key, abort);
439            }
440        };
441
442        // Associate the key's hash with this task's ID, for looking up tasks by ID.
443        let _prev = self.hashes_by_task.insert(id, hash);
444        debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
445    }
446
447    /// Waits until one of the tasks in the map completes and returns its
448    /// output, along with the key corresponding to that task.
449    ///
450    /// Returns `None` if the map is empty.
451    ///
452    /// # Cancel Safety
453    ///
454    /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`]
455    /// statement and some other branch completes first, it is guaranteed that no tasks were
456    /// removed from this `JoinMap`.
457    ///
458    /// # Returns
459    ///
460    /// This function returns:
461    ///
462    ///  * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has
463    ///    completed. The `value` is the return value of that ask, and `key` is
464    ///    the key associated with the task.
465    ///  * `Some((key, Err(err))` if one of the tasks in this `JoinMap` has
466    ///    panicked or been aborted. `key` is the key associated  with the task
467    ///    that panicked or was aborted.
468    ///  * `None` if the `JoinMap` is empty.
469    ///
470    /// [`tokio::select!`]: tokio::select
471    pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
472        let (res, id) = match self.tasks.join_next_with_id().await {
473            Some(Ok((id, output))) => (Ok(output), id),
474            Some(Err(e)) => {
475                let id = e.id();
476                (Err(e), id)
477            }
478            None => return None,
479        };
480        let key = self.remove_by_id(id)?;
481        Some((key, res))
482    }
483
484    /// Aborts all tasks and waits for them to finish shutting down.
485    ///
486    /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
487    /// a loop until it returns `None`.
488    ///
489    /// This method ignores any panics in the tasks shutting down. When this call returns, the
490    /// `JoinMap` will be empty.
491    ///
492    /// [`abort_all`]: fn@Self::abort_all
493    /// [`join_next`]: fn@Self::join_next
494    pub async fn shutdown(&mut self) {
495        self.abort_all();
496        while self.join_next().await.is_some() {}
497    }
498
499    /// Abort the task corresponding to the provided `key`.
500    ///
501    /// If this `JoinMap` contains a task corresponding to `key`, this method
502    /// will abort that task and return `true`. Otherwise, if no task exists for
503    /// `key`, this method returns `false`.
504    ///
505    /// # Examples
506    ///
507    /// Aborting a task by key:
508    ///
509    /// ```
510    /// use tokio_util::task::JoinMap;
511    ///
512    /// # #[tokio::main]
513    /// # async fn main() {
514    /// let mut map = JoinMap::new();
515    ///
516    /// map.spawn("hello world", async move { /* ... */ });
517    /// map.spawn("goodbye world", async move { /* ... */});
518    ///
519    /// // Look up the "goodbye world" task in the map and abort it.
520    /// map.abort("goodbye world");
521    ///
522    /// while let Some((key, res)) = map.join_next().await {
523    ///     if key == "goodbye world" {
524    ///         // The aborted task should complete with a cancelled `JoinError`.
525    ///         assert!(res.unwrap_err().is_cancelled());
526    ///     } else {
527    ///         // Other tasks should complete normally.
528    ///         assert!(res.is_ok());
529    ///     }
530    /// }
531    /// # }
532    /// ```
533    ///
534    /// `abort` returns `true` if a task was aborted:
535    /// ```
536    /// use tokio_util::task::JoinMap;
537    ///
538    /// # #[tokio::main]
539    /// # async fn main() {
540    /// let mut map = JoinMap::new();
541    ///
542    /// map.spawn("hello world", async move { /* ... */ });
543    /// map.spawn("goodbye world", async move { /* ... */});
544    ///
545    /// // A task for the key "goodbye world" should exist in the map:
546    /// assert!(map.abort("goodbye world"));
547    ///
548    /// // Aborting a key that does not exist will return `false`:
549    /// assert!(!map.abort("goodbye universe"));
550    /// # }
551    /// ```
552    pub fn abort<Q: ?Sized>(&mut self, key: &Q) -> bool
553    where
554        Q: Hash + Eq,
555        K: Borrow<Q>,
556    {
557        match self.get_by_key(key) {
558            Some((_, handle)) => {
559                handle.abort();
560                true
561            }
562            None => false,
563        }
564    }
565
566    /// Aborts all tasks with keys matching `predicate`.
567    ///
568    /// `predicate` is a function called with a reference to each key in the
569    /// map. If it returns `true` for a given key, the corresponding task will
570    /// be cancelled.
571    ///
572    /// # Examples
573    /// ```
574    /// use tokio_util::task::JoinMap;
575    ///
576    /// # // use the current thread rt so that spawned tasks don't
577    /// # // complete in the background before they can be aborted.
578    /// # #[tokio::main(flavor = "current_thread")]
579    /// # async fn main() {
580    /// let mut map = JoinMap::new();
581    ///
582    /// map.spawn("hello world", async move {
583    ///     // ...
584    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
585    /// });
586    /// map.spawn("goodbye world", async move {
587    ///     // ...
588    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
589    /// });
590    /// map.spawn("hello san francisco", async move {
591    ///     // ...
592    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
593    /// });
594    /// map.spawn("goodbye universe", async move {
595    ///     // ...
596    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
597    /// });
598    ///
599    /// // Abort all tasks whose keys begin with "goodbye"
600    /// map.abort_matching(|key| key.starts_with("goodbye"));
601    ///
602    /// let mut seen = 0;
603    /// while let Some((key, res)) = map.join_next().await {
604    ///     seen += 1;
605    ///     if key.starts_with("goodbye") {
606    ///         // The aborted task should complete with a cancelled `JoinError`.
607    ///         assert!(res.unwrap_err().is_cancelled());
608    ///     } else {
609    ///         // Other tasks should complete normally.
610    ///         assert!(key.starts_with("hello"));
611    ///         assert!(res.is_ok());
612    ///     }
613    /// }
614    ///
615    /// // All spawned tasks should have completed.
616    /// assert_eq!(seen, 4);
617    /// # }
618    /// ```
619    pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) {
620        // Note: this method iterates over the tasks and keys *without* removing
621        // any entries, so that the keys from aborted tasks can still be
622        // returned when calling `join_next` in the future.
623        for (Key { ref key, .. }, task) in &self.tasks_by_key {
624            if predicate(key) {
625                task.abort();
626            }
627        }
628    }
629
630    /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order.
631    ///
632    /// If a task has completed, but its output hasn't yet been consumed by a
633    /// call to [`join_next`], this method will still return its key.
634    ///
635    /// [`join_next`]: fn@Self::join_next
636    pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
637        JoinMapKeys {
638            iter: self.tasks_by_key.keys(),
639            _value: PhantomData,
640        }
641    }
642
643    /// Returns `true` if this `JoinMap` contains a task for the provided key.
644    ///
645    /// If the task has completed, but its output hasn't yet been consumed by a
646    /// call to [`join_next`], this method will still return `true`.
647    ///
648    /// [`join_next`]: fn@Self::join_next
649    pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool
650    where
651        Q: Hash + Eq,
652        K: Borrow<Q>,
653    {
654        self.get_by_key(key).is_some()
655    }
656
657    /// Returns `true` if this `JoinMap` contains a task with the provided
658    /// [task ID].
659    ///
660    /// If the task has completed, but its output hasn't yet been consumed by a
661    /// call to [`join_next`], this method will still return `true`.
662    ///
663    /// [`join_next`]: fn@Self::join_next
664    /// [task ID]: tokio::task::Id
665    pub fn contains_task(&self, task: &Id) -> bool {
666        self.get_by_id(task).is_some()
667    }
668
669    /// Reserves capacity for at least `additional` more tasks to be spawned
670    /// on this `JoinMap` without reallocating for the map of task keys. The
671    /// collection may reserve more space to avoid frequent reallocations.
672    ///
673    /// Note that spawning a task will still cause an allocation for the task
674    /// itself.
675    ///
676    /// # Panics
677    ///
678    /// Panics if the new allocation size overflows [`usize`].
679    ///
680    /// # Examples
681    ///
682    /// ```
683    /// use tokio_util::task::JoinMap;
684    ///
685    /// let mut map: JoinMap<&str, i32> = JoinMap::new();
686    /// map.reserve(10);
687    /// ```
688    #[inline]
689    pub fn reserve(&mut self, additional: usize) {
690        self.tasks_by_key.reserve(additional);
691        self.hashes_by_task.reserve(additional);
692    }
693
694    /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop
695    /// down as much as possible while maintaining the internal rules
696    /// and possibly leaving some space in accordance with the resize policy.
697    ///
698    /// # Examples
699    ///
700    /// ```
701    /// # #[tokio::main]
702    /// # async fn main() {
703    /// use tokio_util::task::JoinMap;
704    ///
705    /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
706    /// map.spawn(1, async move { 2 });
707    /// map.spawn(3, async move { 4 });
708    /// assert!(map.capacity() >= 100);
709    /// map.shrink_to_fit();
710    /// assert!(map.capacity() >= 2);
711    /// # }
712    /// ```
713    #[inline]
714    pub fn shrink_to_fit(&mut self) {
715        self.hashes_by_task.shrink_to_fit();
716        self.tasks_by_key.shrink_to_fit();
717    }
718
719    /// Shrinks the capacity of the map with a lower limit. It will drop
720    /// down no lower than the supplied limit while maintaining the internal rules
721    /// and possibly leaving some space in accordance with the resize policy.
722    ///
723    /// If the current capacity is less than the lower limit, this is a no-op.
724    ///
725    /// # Examples
726    ///
727    /// ```
728    /// # #[tokio::main]
729    /// # async fn main() {
730    /// use tokio_util::task::JoinMap;
731    ///
732    /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
733    /// map.spawn(1, async move { 2 });
734    /// map.spawn(3, async move { 4 });
735    /// assert!(map.capacity() >= 100);
736    /// map.shrink_to(10);
737    /// assert!(map.capacity() >= 10);
738    /// map.shrink_to(0);
739    /// assert!(map.capacity() >= 2);
740    /// # }
741    /// ```
742    #[inline]
743    pub fn shrink_to(&mut self, min_capacity: usize) {
744        self.hashes_by_task.shrink_to(min_capacity);
745        self.tasks_by_key.shrink_to(min_capacity)
746    }
747
748    /// Look up a task in the map by its key, returning the key and abort handle.
749    fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)>
750    where
751        Q: Hash + Eq,
752        K: Borrow<Q>,
753    {
754        let hash = self.hash(key);
755        self.tasks_by_key
756            .raw_entry()
757            .from_hash(hash, |k| k.key.borrow() == key)
758    }
759
760    /// Look up a task in the map by its task ID, returning the key and abort handle.
761    fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> {
762        let hash = self.hashes_by_task.get(id)?;
763        self.tasks_by_key
764            .raw_entry()
765            .from_hash(*hash, |k| &k.id == id)
766    }
767
768    /// Remove a task from the map by ID, returning the key for that task.
769    fn remove_by_id(&mut self, id: Id) -> Option<K> {
770        // Get the hash for the given ID.
771        let hash = self.hashes_by_task.remove(&id)?;
772
773        // Remove the entry for that hash.
774        let entry = self
775            .tasks_by_key
776            .raw_entry_mut()
777            .from_hash(hash, |k| k.id == id);
778        let (Key { id: _key_id, key }, handle) = match entry {
779            RawEntryMut::Occupied(entry) => entry.remove_entry(),
780            _ => return None,
781        };
782        debug_assert_eq!(_key_id, id);
783        debug_assert_eq!(id, handle.id());
784        self.hashes_by_task.remove(&id);
785        Some(key)
786    }
787
788    /// Returns the hash for a given key.
789    #[inline]
790    fn hash<Q: ?Sized>(&self, key: &Q) -> u64
791    where
792        Q: Hash,
793    {
794        let mut hasher = self.tasks_by_key.hasher().build_hasher();
795        key.hash(&mut hasher);
796        hasher.finish()
797    }
798}
799
800impl<K, V, S> JoinMap<K, V, S>
801where
802    V: 'static,
803{
804    /// Aborts all tasks on this `JoinMap`.
805    ///
806    /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete
807    /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty.
808    pub fn abort_all(&mut self) {
809        self.tasks.abort_all()
810    }
811
812    /// Removes all tasks from this `JoinMap` without aborting them.
813    ///
814    /// The tasks removed by this call will continue to run in the background even if the `JoinMap`
815    /// is dropped. They may still be aborted by key.
816    pub fn detach_all(&mut self) {
817        self.tasks.detach_all();
818        self.tasks_by_key.clear();
819        self.hashes_by_task.clear();
820    }
821}
822
823// Hand-written `fmt::Debug` implementation in order to avoid requiring `V:
824// Debug`, since no value is ever actually stored in the map.
825impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
826    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
827        // format the task keys and abort handles a little nicer by just
828        // printing the key and task ID pairs, without format the `Key` struct
829        // itself or the `AbortHandle`, which would just format the task's ID
830        // again.
831        struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>);
832        impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> {
833            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
834                f.debug_map()
835                    .entries(self.0.keys().map(|Key { key, id }| (key, id)))
836                    .finish()
837            }
838        }
839
840        f.debug_struct("JoinMap")
841            // The `tasks_by_key` map is the only one that contains information
842            // that's really worth formatting for the user, since it contains
843            // the tasks' keys and IDs. The other fields are basically
844            // implementation details.
845            .field("tasks", &KeySet(&self.tasks_by_key))
846            .finish()
847    }
848}
849
850impl<K, V> Default for JoinMap<K, V> {
851    fn default() -> Self {
852        Self::new()
853    }
854}
855
856// === impl Key ===
857
858impl<K: Hash> Hash for Key<K> {
859    // Don't include the task ID in the hash.
860    #[inline]
861    fn hash<H: Hasher>(&self, hasher: &mut H) {
862        self.key.hash(hasher);
863    }
864}
865
866// Because we override `Hash` for this type, we must also override the
867// `PartialEq` impl, so that all instances with the same hash are equal.
868impl<K: PartialEq> PartialEq for Key<K> {
869    #[inline]
870    fn eq(&self, other: &Self) -> bool {
871        self.key == other.key
872    }
873}
874
875impl<K: Eq> Eq for Key<K> {}
876
/// An iterator over the keys of a [`JoinMap`].
///
/// Created by [`JoinMap::keys`]; yields a `&K` for every task currently
/// tracked by the map, in arbitrary order.
#[derive(Debug, Clone)]
pub struct JoinMapKeys<'a, K, V> {
    /// Iterator over the internal `Key<K>` entries (user key plus task ID);
    /// only the user key is exposed by `next`.
    iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>,
    /// To make it easier to change `JoinMap` in the future, keep V as a generic
    /// parameter.
    _value: PhantomData<&'a V>,
}
885
886impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
887    type Item = &'a K;
888
889    fn next(&mut self) -> Option<&'a K> {
890        self.iter.next().map(|key| &key.key)
891    }
892
893    fn size_hint(&self) -> (usize, Option<usize>) {
894        self.iter.size_hint()
895    }
896}
897
898impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
899    fn len(&self) -> usize {
900        self.iter.len()
901    }
902}
903
// `JoinMapKeys::next` only forwards to the inner hashbrown `Keys` iterator.
// NOTE(review): this marker relies on hashbrown's `Keys` never yielding again
// after returning `None`; confirm when upgrading hashbrown.
impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}