tokio_util/task/join_map.rs
1use hashbrown::hash_map::RawEntryMut;
2use hashbrown::HashMap;
3use std::borrow::Borrow;
4use std::collections::hash_map::RandomState;
5use std::fmt;
6use std::future::Future;
7use std::hash::{BuildHasher, Hash, Hasher};
8use std::marker::PhantomData;
9use tokio::runtime::Handle;
10use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
11
12/// A collection of tasks spawned on a Tokio runtime, associated with hash map
13/// keys.
14///
15/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the
16/// addition of a set of keys associated with each task. These keys allow
17/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the
18/// `JoinMap` based on their keys, or [test whether a task corresponding to a
19/// given key exists][contains] in the `JoinMap`.
20///
21/// In addition, when tasks in the `JoinMap` complete, they will return the
22/// associated key along with the value returned by the task, if any.
23///
24/// A `JoinMap` can be used to await the completion of some or all of the tasks
25/// in the map. The map is not ordered, and the tasks will be returned in the
26/// order they complete.
27///
28/// All of the tasks must have the same return type `V`.
29///
30/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted.
31///
32/// **Note**: This type depends on Tokio's [unstable API][unstable]. See [the
33/// documentation on unstable features][unstable] for details on how to enable
34/// Tokio's unstable features.
35///
36/// # Examples
37///
38/// Spawn multiple tasks and wait for them:
39///
40/// ```
41/// use tokio_util::task::JoinMap;
42///
43/// #[tokio::main]
44/// async fn main() {
45/// let mut map = JoinMap::new();
46///
47/// for i in 0..10 {
48/// // Spawn a task on the `JoinMap` with `i` as its key.
49/// map.spawn(i, async move { /* ... */ });
50/// }
51///
52/// let mut seen = [false; 10];
53///
54/// // When a task completes, `join_next` returns the task's key along
55/// // with its output.
56/// while let Some((key, res)) = map.join_next().await {
57/// seen[key] = true;
58/// assert!(res.is_ok(), "task {} completed successfully!", key);
59/// }
60///
61/// for i in 0..10 {
62/// assert!(seen[i]);
63/// }
64/// }
65/// ```
66///
67/// Cancel tasks based on their keys:
68///
69/// ```
70/// use tokio_util::task::JoinMap;
71///
72/// #[tokio::main]
73/// async fn main() {
74/// let mut map = JoinMap::new();
75///
76/// map.spawn("hello world", async move { /* ... */ });
77/// map.spawn("goodbye world", async move { /* ... */});
78///
79/// // Look up the "goodbye world" task in the map and abort it.
80/// let aborted = map.abort("goodbye world");
81///
82/// // `JoinMap::abort` returns `true` if a task existed for the
83/// // provided key.
84/// assert!(aborted);
85///
86/// while let Some((key, res)) = map.join_next().await {
87/// if key == "goodbye world" {
88/// // The aborted task should complete with a cancelled `JoinError`.
89/// assert!(res.unwrap_err().is_cancelled());
90/// } else {
91/// // Other tasks should complete normally.
92/// assert!(res.is_ok());
93/// }
94/// }
95/// }
96/// ```
97///
98/// [`JoinSet`]: tokio::task::JoinSet
99/// [unstable]: tokio#unstable-features
100/// [abort]: fn@Self::abort
101/// [abort_matching]: fn@Self::abort_matching
102/// [contains]: fn@Self::contains_key
103#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
104pub struct JoinMap<K, V, S = RandomState> {
105 /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
106 /// indexed by their keys and task IDs.
107 ///
108 /// The [`Key`] type contains both the task's `K`-typed key provided when
109 /// spawning tasks, and the task's IDs. The IDs are stored here to resolve
110 /// hash collisions when looking up tasks based on their pre-computed hash
111 /// (as stored in the `hashes_by_task` map).
112 tasks_by_key: HashMap<Key<K>, AbortHandle, S>,
113
114 /// A map from task IDs to the hash of the key associated with that task.
115 ///
116 /// This map is used to perform reverse lookups of tasks in the
117 /// `tasks_by_key` map based on their task IDs. When a task terminates, the
118 /// ID is provided to us by the `JoinSet`, so we can look up the hash value
119 /// of that task's key, and then remove it from the `tasks_by_key` map using
120 /// the raw hash code, resolving collisions by comparing task IDs.
121 hashes_by_task: HashMap<Id, u64, S>,
122
123 /// The [`JoinSet`] that awaits the completion of tasks spawned on this
124 /// `JoinMap`.
125 tasks: JoinSet<V>,
126}
127
/// A [`JoinMap`] key.
///
/// This holds both a `K`-typed key (the actual key as seen by the user), _and_
/// a task ID, so that hash collisions between `K`-typed keys can be resolved
/// using either `K`'s `Eq` impl *or* by checking the task IDs.
///
/// This allows looking up a task using either an actual key (such as when the
/// user queries the map with a key), *or* using a task ID and a hash (such as
/// when removing completed tasks from the map).
#[derive(Debug)]
struct Key<K> {
    /// The user-provided key the task was spawned with. This is the only
    /// field that takes part in `Key`'s `Hash` and `PartialEq` impls.
    key: K,
    /// The ID of the task spawned for `key`. It is ignored by `Hash` and
    /// `PartialEq`, and is instead compared explicitly by the lookups that
    /// search for a task by its ID.
    id: Id,
}
142
143impl<K, V> JoinMap<K, V> {
144 /// Creates a new empty `JoinMap`.
145 ///
146 /// The `JoinMap` is initially created with a capacity of 0, so it will not
147 /// allocate until a task is first spawned on it.
148 ///
149 /// # Examples
150 ///
151 /// ```
152 /// use tokio_util::task::JoinMap;
153 /// let map: JoinMap<&str, i32> = JoinMap::new();
154 /// ```
155 #[inline]
156 #[must_use]
157 pub fn new() -> Self {
158 Self::with_hasher(RandomState::new())
159 }
160
161 /// Creates an empty `JoinMap` with the specified capacity.
162 ///
163 /// The `JoinMap` will be able to hold at least `capacity` tasks without
164 /// reallocating.
165 ///
166 /// # Examples
167 ///
168 /// ```
169 /// use tokio_util::task::JoinMap;
170 /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10);
171 /// ```
172 #[inline]
173 #[must_use]
174 pub fn with_capacity(capacity: usize) -> Self {
175 JoinMap::with_capacity_and_hasher(capacity, Default::default())
176 }
177}
178
179impl<K, V, S: Clone> JoinMap<K, V, S> {
180 /// Creates an empty `JoinMap` which will use the given hash builder to hash
181 /// keys.
182 ///
183 /// The created map has the default initial capacity.
184 ///
185 /// Warning: `hash_builder` is normally randomly generated, and
186 /// is designed to allow `JoinMap` to be resistant to attacks that
187 /// cause many collisions and very poor performance. Setting it
188 /// manually using this function can expose a DoS attack vector.
189 ///
190 /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
191 /// the `JoinMap` to be useful, see its documentation for details.
192 #[inline]
193 #[must_use]
194 pub fn with_hasher(hash_builder: S) -> Self {
195 Self::with_capacity_and_hasher(0, hash_builder)
196 }
197
198 /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder`
199 /// to hash the keys.
200 ///
201 /// The `JoinMap` will be able to hold at least `capacity` elements without
202 /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate.
203 ///
204 /// Warning: `hash_builder` is normally randomly generated, and
205 /// is designed to allow HashMaps to be resistant to attacks that
206 /// cause many collisions and very poor performance. Setting it
207 /// manually using this function can expose a DoS attack vector.
208 ///
209 /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
210 /// the `JoinMap`to be useful, see its documentation for details.
211 ///
212 /// # Examples
213 ///
214 /// ```
215 /// # #[tokio::main]
216 /// # async fn main() {
217 /// use tokio_util::task::JoinMap;
218 /// use std::collections::hash_map::RandomState;
219 ///
220 /// let s = RandomState::new();
221 /// let mut map = JoinMap::with_capacity_and_hasher(10, s);
222 /// map.spawn(1, async move { "hello world!" });
223 /// # }
224 /// ```
225 #[inline]
226 #[must_use]
227 pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
228 Self {
229 tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()),
230 hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
231 tasks: JoinSet::new(),
232 }
233 }
234
235 /// Returns the number of tasks currently in the `JoinMap`.
236 pub fn len(&self) -> usize {
237 let len = self.tasks_by_key.len();
238 debug_assert_eq!(len, self.hashes_by_task.len());
239 len
240 }
241
242 /// Returns whether the `JoinMap` is empty.
243 pub fn is_empty(&self) -> bool {
244 let empty = self.tasks_by_key.is_empty();
245 debug_assert_eq!(empty, self.hashes_by_task.is_empty());
246 empty
247 }
248
249 /// Returns the number of tasks the map can hold without reallocating.
250 ///
251 /// This number is a lower bound; the `JoinMap` might be able to hold
252 /// more, but is guaranteed to be able to hold at least this many.
253 ///
254 /// # Examples
255 ///
256 /// ```
257 /// use tokio_util::task::JoinMap;
258 ///
259 /// let map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
260 /// assert!(map.capacity() >= 100);
261 /// ```
262 #[inline]
263 pub fn capacity(&self) -> usize {
264 let capacity = self.tasks_by_key.capacity();
265 debug_assert_eq!(capacity, self.hashes_by_task.capacity());
266 capacity
267 }
268}
269
270impl<K, V, S> JoinMap<K, V, S>
271where
272 K: Hash + Eq,
273 V: 'static,
274 S: BuildHasher,
275{
276 /// Spawn the provided task and store it in this `JoinMap` with the provided
277 /// key.
278 ///
279 /// If a task previously existed in the `JoinMap` for this key, that task
280 /// will be cancelled and replaced with the new one. The previous task will
281 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
282 /// *not* return a cancelled [`JoinError`] for that task.
283 ///
284 /// # Panics
285 ///
286 /// This method panics if called outside of a Tokio runtime.
287 ///
288 /// [`join_next`]: Self::join_next
289 #[track_caller]
290 pub fn spawn<F>(&mut self, key: K, task: F)
291 where
292 F: Future<Output = V>,
293 F: Send + 'static,
294 V: Send,
295 {
296 let task = self.tasks.spawn(task);
297 self.insert(key, task)
298 }
299
300 /// Spawn the provided task on the provided runtime and store it in this
301 /// `JoinMap` with the provided key.
302 ///
303 /// If a task previously existed in the `JoinMap` for this key, that task
304 /// will be cancelled and replaced with the new one. The previous task will
305 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
306 /// *not* return a cancelled [`JoinError`] for that task.
307 ///
308 /// [`join_next`]: Self::join_next
309 #[track_caller]
310 pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle)
311 where
312 F: Future<Output = V>,
313 F: Send + 'static,
314 V: Send,
315 {
316 let task = self.tasks.spawn_on(task, handle);
317 self.insert(key, task);
318 }
319
320 /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
321 /// key.
322 ///
323 /// If a task previously existed in the `JoinMap` for this key, that task
324 /// will be cancelled and replaced with the new one. The previous task will
325 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
326 /// *not* return a cancelled [`JoinError`] for that task.
327 ///
328 /// Note that blocking tasks cannot be cancelled after execution starts.
329 /// Replaced blocking tasks will still run to completion if the task has begun
330 /// to execute when it is replaced. A blocking task which is replaced before
331 /// it has been scheduled on a blocking worker thread will be cancelled.
332 ///
333 /// # Panics
334 ///
335 /// This method panics if called outside of a Tokio runtime.
336 ///
337 /// [`join_next`]: Self::join_next
338 #[track_caller]
339 pub fn spawn_blocking<F>(&mut self, key: K, f: F)
340 where
341 F: FnOnce() -> V,
342 F: Send + 'static,
343 V: Send,
344 {
345 let task = self.tasks.spawn_blocking(f);
346 self.insert(key, task)
347 }
348
349 /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
350 /// `JoinMap` with the provided key.
351 ///
352 /// If a task previously existed in the `JoinMap` for this key, that task
353 /// will be cancelled and replaced with the new one. The previous task will
354 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
355 /// *not* return a cancelled [`JoinError`] for that task.
356 ///
357 /// Note that blocking tasks cannot be cancelled after execution starts.
358 /// Replaced blocking tasks will still run to completion if the task has begun
359 /// to execute when it is replaced. A blocking task which is replaced before
360 /// it has been scheduled on a blocking worker thread will be cancelled.
361 ///
362 /// [`join_next`]: Self::join_next
363 #[track_caller]
364 pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle)
365 where
366 F: FnOnce() -> V,
367 F: Send + 'static,
368 V: Send,
369 {
370 let task = self.tasks.spawn_blocking_on(f, handle);
371 self.insert(key, task);
372 }
373
374 /// Spawn the provided task on the current [`LocalSet`] and store it in this
375 /// `JoinMap` with the provided key.
376 ///
377 /// If a task previously existed in the `JoinMap` for this key, that task
378 /// will be cancelled and replaced with the new one. The previous task will
379 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
380 /// *not* return a cancelled [`JoinError`] for that task.
381 ///
382 /// # Panics
383 ///
384 /// This method panics if it is called outside of a `LocalSet`.
385 ///
386 /// [`LocalSet`]: tokio::task::LocalSet
387 /// [`join_next`]: Self::join_next
388 #[track_caller]
389 pub fn spawn_local<F>(&mut self, key: K, task: F)
390 where
391 F: Future<Output = V>,
392 F: 'static,
393 {
394 let task = self.tasks.spawn_local(task);
395 self.insert(key, task);
396 }
397
398 /// Spawn the provided task on the provided [`LocalSet`] and store it in
399 /// this `JoinMap` with the provided key.
400 ///
401 /// If a task previously existed in the `JoinMap` for this key, that task
402 /// will be cancelled and replaced with the new one. The previous task will
403 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
404 /// *not* return a cancelled [`JoinError`] for that task.
405 ///
406 /// [`LocalSet`]: tokio::task::LocalSet
407 /// [`join_next`]: Self::join_next
408 #[track_caller]
409 pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet)
410 where
411 F: Future<Output = V>,
412 F: 'static,
413 {
414 let task = self.tasks.spawn_local_on(task, local_set);
415 self.insert(key, task)
416 }
417
418 fn insert(&mut self, key: K, abort: AbortHandle) {
419 let hash = self.hash(&key);
420 let id = abort.id();
421 let map_key = Key { id, key };
422
423 // Insert the new key into the map of tasks by keys.
424 let entry = self
425 .tasks_by_key
426 .raw_entry_mut()
427 .from_hash(hash, |k| k.key == map_key.key);
428 match entry {
429 RawEntryMut::Occupied(mut occ) => {
430 // There was a previous task spawned with the same key! Cancel
431 // that task, and remove its ID from the map of hashes by task IDs.
432 let Key { id: prev_id, .. } = occ.insert_key(map_key);
433 occ.insert(abort).abort();
434 let _prev_hash = self.hashes_by_task.remove(&prev_id);
435 debug_assert_eq!(Some(hash), _prev_hash);
436 }
437 RawEntryMut::Vacant(vac) => {
438 vac.insert(map_key, abort);
439 }
440 };
441
442 // Associate the key's hash with this task's ID, for looking up tasks by ID.
443 let _prev = self.hashes_by_task.insert(id, hash);
444 debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
445 }
446
447 /// Waits until one of the tasks in the map completes and returns its
448 /// output, along with the key corresponding to that task.
449 ///
450 /// Returns `None` if the map is empty.
451 ///
452 /// # Cancel Safety
453 ///
454 /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`]
455 /// statement and some other branch completes first, it is guaranteed that no tasks were
456 /// removed from this `JoinMap`.
457 ///
458 /// # Returns
459 ///
460 /// This function returns:
461 ///
462 /// * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has
463 /// completed. The `value` is the return value of that ask, and `key` is
464 /// the key associated with the task.
465 /// * `Some((key, Err(err))` if one of the tasks in this `JoinMap` has
466 /// panicked or been aborted. `key` is the key associated with the task
467 /// that panicked or was aborted.
468 /// * `None` if the `JoinMap` is empty.
469 ///
470 /// [`tokio::select!`]: tokio::select
471 pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
472 loop {
473 let (res, id) = match self.tasks.join_next_with_id().await {
474 Some(Ok((id, output))) => (Ok(output), id),
475 Some(Err(e)) => {
476 let id = e.id();
477 (Err(e), id)
478 }
479 None => return None,
480 };
481 if let Some(key) = self.remove_by_id(id) {
482 break Some((key, res));
483 }
484 }
485 }
486
487 /// Aborts all tasks and waits for them to finish shutting down.
488 ///
489 /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
490 /// a loop until it returns `None`.
491 ///
492 /// This method ignores any panics in the tasks shutting down. When this call returns, the
493 /// `JoinMap` will be empty.
494 ///
495 /// [`abort_all`]: fn@Self::abort_all
496 /// [`join_next`]: fn@Self::join_next
497 pub async fn shutdown(&mut self) {
498 self.abort_all();
499 while self.join_next().await.is_some() {}
500 }
501
502 /// Abort the task corresponding to the provided `key`.
503 ///
504 /// If this `JoinMap` contains a task corresponding to `key`, this method
505 /// will abort that task and return `true`. Otherwise, if no task exists for
506 /// `key`, this method returns `false`.
507 ///
508 /// # Examples
509 ///
510 /// Aborting a task by key:
511 ///
512 /// ```
513 /// use tokio_util::task::JoinMap;
514 ///
515 /// # #[tokio::main]
516 /// # async fn main() {
517 /// let mut map = JoinMap::new();
518 ///
519 /// map.spawn("hello world", async move { /* ... */ });
520 /// map.spawn("goodbye world", async move { /* ... */});
521 ///
522 /// // Look up the "goodbye world" task in the map and abort it.
523 /// map.abort("goodbye world");
524 ///
525 /// while let Some((key, res)) = map.join_next().await {
526 /// if key == "goodbye world" {
527 /// // The aborted task should complete with a cancelled `JoinError`.
528 /// assert!(res.unwrap_err().is_cancelled());
529 /// } else {
530 /// // Other tasks should complete normally.
531 /// assert!(res.is_ok());
532 /// }
533 /// }
534 /// # }
535 /// ```
536 ///
537 /// `abort` returns `true` if a task was aborted:
538 /// ```
539 /// use tokio_util::task::JoinMap;
540 ///
541 /// # #[tokio::main]
542 /// # async fn main() {
543 /// let mut map = JoinMap::new();
544 ///
545 /// map.spawn("hello world", async move { /* ... */ });
546 /// map.spawn("goodbye world", async move { /* ... */});
547 ///
548 /// // A task for the key "goodbye world" should exist in the map:
549 /// assert!(map.abort("goodbye world"));
550 ///
551 /// // Aborting a key that does not exist will return `false`:
552 /// assert!(!map.abort("goodbye universe"));
553 /// # }
554 /// ```
555 pub fn abort<Q: ?Sized>(&mut self, key: &Q) -> bool
556 where
557 Q: Hash + Eq,
558 K: Borrow<Q>,
559 {
560 match self.get_by_key(key) {
561 Some((_, handle)) => {
562 handle.abort();
563 true
564 }
565 None => false,
566 }
567 }
568
569 /// Aborts all tasks with keys matching `predicate`.
570 ///
571 /// `predicate` is a function called with a reference to each key in the
572 /// map. If it returns `true` for a given key, the corresponding task will
573 /// be cancelled.
574 ///
575 /// # Examples
576 /// ```
577 /// use tokio_util::task::JoinMap;
578 ///
579 /// # // use the current thread rt so that spawned tasks don't
580 /// # // complete in the background before they can be aborted.
581 /// # #[tokio::main(flavor = "current_thread")]
582 /// # async fn main() {
583 /// let mut map = JoinMap::new();
584 ///
585 /// map.spawn("hello world", async move {
586 /// // ...
587 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
588 /// });
589 /// map.spawn("goodbye world", async move {
590 /// // ...
591 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
592 /// });
593 /// map.spawn("hello san francisco", async move {
594 /// // ...
595 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
596 /// });
597 /// map.spawn("goodbye universe", async move {
598 /// // ...
599 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
600 /// });
601 ///
602 /// // Abort all tasks whose keys begin with "goodbye"
603 /// map.abort_matching(|key| key.starts_with("goodbye"));
604 ///
605 /// let mut seen = 0;
606 /// while let Some((key, res)) = map.join_next().await {
607 /// seen += 1;
608 /// if key.starts_with("goodbye") {
609 /// // The aborted task should complete with a cancelled `JoinError`.
610 /// assert!(res.unwrap_err().is_cancelled());
611 /// } else {
612 /// // Other tasks should complete normally.
613 /// assert!(key.starts_with("hello"));
614 /// assert!(res.is_ok());
615 /// }
616 /// }
617 ///
618 /// // All spawned tasks should have completed.
619 /// assert_eq!(seen, 4);
620 /// # }
621 /// ```
622 pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) {
623 // Note: this method iterates over the tasks and keys *without* removing
624 // any entries, so that the keys from aborted tasks can still be
625 // returned when calling `join_next` in the future.
626 for (Key { ref key, .. }, task) in &self.tasks_by_key {
627 if predicate(key) {
628 task.abort();
629 }
630 }
631 }
632
633 /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order.
634 ///
635 /// If a task has completed, but its output hasn't yet been consumed by a
636 /// call to [`join_next`], this method will still return its key.
637 ///
638 /// [`join_next`]: fn@Self::join_next
639 pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
640 JoinMapKeys {
641 iter: self.tasks_by_key.keys(),
642 _value: PhantomData,
643 }
644 }
645
646 /// Returns `true` if this `JoinMap` contains a task for the provided key.
647 ///
648 /// If the task has completed, but its output hasn't yet been consumed by a
649 /// call to [`join_next`], this method will still return `true`.
650 ///
651 /// [`join_next`]: fn@Self::join_next
652 pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool
653 where
654 Q: Hash + Eq,
655 K: Borrow<Q>,
656 {
657 self.get_by_key(key).is_some()
658 }
659
660 /// Returns `true` if this `JoinMap` contains a task with the provided
661 /// [task ID].
662 ///
663 /// If the task has completed, but its output hasn't yet been consumed by a
664 /// call to [`join_next`], this method will still return `true`.
665 ///
666 /// [`join_next`]: fn@Self::join_next
667 /// [task ID]: tokio::task::Id
668 pub fn contains_task(&self, task: &Id) -> bool {
669 self.get_by_id(task).is_some()
670 }
671
672 /// Reserves capacity for at least `additional` more tasks to be spawned
673 /// on this `JoinMap` without reallocating for the map of task keys. The
674 /// collection may reserve more space to avoid frequent reallocations.
675 ///
676 /// Note that spawning a task will still cause an allocation for the task
677 /// itself.
678 ///
679 /// # Panics
680 ///
681 /// Panics if the new allocation size overflows [`usize`].
682 ///
683 /// # Examples
684 ///
685 /// ```
686 /// use tokio_util::task::JoinMap;
687 ///
688 /// let mut map: JoinMap<&str, i32> = JoinMap::new();
689 /// map.reserve(10);
690 /// ```
691 #[inline]
692 pub fn reserve(&mut self, additional: usize) {
693 self.tasks_by_key.reserve(additional);
694 self.hashes_by_task.reserve(additional);
695 }
696
697 /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop
698 /// down as much as possible while maintaining the internal rules
699 /// and possibly leaving some space in accordance with the resize policy.
700 ///
701 /// # Examples
702 ///
703 /// ```
704 /// # #[tokio::main]
705 /// # async fn main() {
706 /// use tokio_util::task::JoinMap;
707 ///
708 /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
709 /// map.spawn(1, async move { 2 });
710 /// map.spawn(3, async move { 4 });
711 /// assert!(map.capacity() >= 100);
712 /// map.shrink_to_fit();
713 /// assert!(map.capacity() >= 2);
714 /// # }
715 /// ```
716 #[inline]
717 pub fn shrink_to_fit(&mut self) {
718 self.hashes_by_task.shrink_to_fit();
719 self.tasks_by_key.shrink_to_fit();
720 }
721
722 /// Shrinks the capacity of the map with a lower limit. It will drop
723 /// down no lower than the supplied limit while maintaining the internal rules
724 /// and possibly leaving some space in accordance with the resize policy.
725 ///
726 /// If the current capacity is less than the lower limit, this is a no-op.
727 ///
728 /// # Examples
729 ///
730 /// ```
731 /// # #[tokio::main]
732 /// # async fn main() {
733 /// use tokio_util::task::JoinMap;
734 ///
735 /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
736 /// map.spawn(1, async move { 2 });
737 /// map.spawn(3, async move { 4 });
738 /// assert!(map.capacity() >= 100);
739 /// map.shrink_to(10);
740 /// assert!(map.capacity() >= 10);
741 /// map.shrink_to(0);
742 /// assert!(map.capacity() >= 2);
743 /// # }
744 /// ```
745 #[inline]
746 pub fn shrink_to(&mut self, min_capacity: usize) {
747 self.hashes_by_task.shrink_to(min_capacity);
748 self.tasks_by_key.shrink_to(min_capacity)
749 }
750
751 /// Look up a task in the map by its key, returning the key and abort handle.
752 fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)>
753 where
754 Q: Hash + Eq,
755 K: Borrow<Q>,
756 {
757 let hash = self.hash(key);
758 self.tasks_by_key
759 .raw_entry()
760 .from_hash(hash, |k| k.key.borrow() == key)
761 }
762
763 /// Look up a task in the map by its task ID, returning the key and abort handle.
764 fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> {
765 let hash = self.hashes_by_task.get(id)?;
766 self.tasks_by_key
767 .raw_entry()
768 .from_hash(*hash, |k| &k.id == id)
769 }
770
771 /// Remove a task from the map by ID, returning the key for that task.
772 fn remove_by_id(&mut self, id: Id) -> Option<K> {
773 // Get the hash for the given ID.
774 let hash = self.hashes_by_task.remove(&id)?;
775
776 // Remove the entry for that hash.
777 let entry = self
778 .tasks_by_key
779 .raw_entry_mut()
780 .from_hash(hash, |k| k.id == id);
781 let (Key { id: _key_id, key }, handle) = match entry {
782 RawEntryMut::Occupied(entry) => entry.remove_entry(),
783 _ => return None,
784 };
785 debug_assert_eq!(_key_id, id);
786 debug_assert_eq!(id, handle.id());
787 self.hashes_by_task.remove(&id);
788 Some(key)
789 }
790
791 /// Returns the hash for a given key.
792 #[inline]
793 fn hash<Q: ?Sized>(&self, key: &Q) -> u64
794 where
795 Q: Hash,
796 {
797 let mut hasher = self.tasks_by_key.hasher().build_hasher();
798 key.hash(&mut hasher);
799 hasher.finish()
800 }
801}
802
803impl<K, V, S> JoinMap<K, V, S>
804where
805 V: 'static,
806{
807 /// Aborts all tasks on this `JoinMap`.
808 ///
809 /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete
810 /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty.
811 pub fn abort_all(&mut self) {
812 self.tasks.abort_all()
813 }
814
815 /// Removes all tasks from this `JoinMap` without aborting them.
816 ///
817 /// The tasks removed by this call will continue to run in the background even if the `JoinMap`
818 /// is dropped. They may still be aborted by key.
819 pub fn detach_all(&mut self) {
820 self.tasks.detach_all();
821 self.tasks_by_key.clear();
822 self.hashes_by_task.clear();
823 }
824}
825
826// Hand-written `fmt::Debug` implementation in order to avoid requiring `V:
827// Debug`, since no value is ever actually stored in the map.
828impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
829 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
830 // format the task keys and abort handles a little nicer by just
831 // printing the key and task ID pairs, without format the `Key` struct
832 // itself or the `AbortHandle`, which would just format the task's ID
833 // again.
834 struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>);
835 impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> {
836 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
837 f.debug_map()
838 .entries(self.0.keys().map(|Key { key, id }| (key, id)))
839 .finish()
840 }
841 }
842
843 f.debug_struct("JoinMap")
844 // The `tasks_by_key` map is the only one that contains information
845 // that's really worth formatting for the user, since it contains
846 // the tasks' keys and IDs. The other fields are basically
847 // implementation details.
848 .field("tasks", &KeySet(&self.tasks_by_key))
849 .finish()
850 }
851}
852
853impl<K, V> Default for JoinMap<K, V> {
854 fn default() -> Self {
855 Self::new()
856 }
857}
858
859// === impl Key ===
860
861impl<K: Hash> Hash for Key<K> {
862 // Don't include the task ID in the hash.
863 #[inline]
864 fn hash<H: Hasher>(&self, hasher: &mut H) {
865 self.key.hash(hasher);
866 }
867}
868
869// Because we override `Hash` for this type, we must also override the
870// `PartialEq` impl, so that all instances with the same hash are equal.
871impl<K: PartialEq> PartialEq for Key<K> {
872 #[inline]
873 fn eq(&self, other: &Self) -> bool {
874 self.key == other.key
875 }
876}
877
878impl<K: Eq> Eq for Key<K> {}
879
880/// An iterator over the keys of a [`JoinMap`].
881#[derive(Debug, Clone)]
882pub struct JoinMapKeys<'a, K, V> {
883 iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>,
884 /// To make it easier to change `JoinMap` in the future, keep V as a generic
885 /// parameter.
886 _value: PhantomData<&'a V>,
887}
888
889impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
890 type Item = &'a K;
891
892 fn next(&mut self) -> Option<&'a K> {
893 self.iter.next().map(|key| &key.key)
894 }
895
896 fn size_hint(&self) -> (usize, Option<usize>) {
897 self.iter.size_hint()
898 }
899}
900
901impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
902 fn len(&self) -> usize {
903 self.iter.len()
904 }
905}
906
907impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}