tokio_util/task/join_map.rs
1use hashbrown::hash_map::RawEntryMut;
2use hashbrown::HashMap;
3use std::borrow::Borrow;
4use std::collections::hash_map::RandomState;
5use std::fmt;
6use std::future::Future;
7use std::hash::{BuildHasher, Hash, Hasher};
8use std::marker::PhantomData;
9use tokio::runtime::Handle;
10use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
11
/// A collection of tasks spawned on a Tokio runtime, associated with hash map
/// keys.
///
/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the
/// addition of a set of keys associated with each task. These keys allow
/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the
/// `JoinMap` based on their keys, or [test whether a task corresponding to a
/// given key exists][contains] in the `JoinMap`.
///
/// In addition, when tasks in the `JoinMap` complete, they will return the
/// associated key along with the value returned by the task, if any.
///
/// A `JoinMap` can be used to await the completion of some or all of the tasks
/// in the map. The map is not ordered, and the tasks will be returned in the
/// order they complete.
///
/// All of the tasks must have the same return type `V`.
///
/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted.
///
/// **Note**: This type depends on Tokio's [unstable API][unstable]. See [the
/// documentation on unstable features][unstable] for details on how to enable
/// Tokio's unstable features.
///
/// # Examples
///
/// Spawn multiple tasks and wait for them:
///
/// ```
/// use tokio_util::task::JoinMap;
///
/// #[tokio::main]
/// async fn main() {
///     let mut map = JoinMap::new();
///
///     for i in 0..10 {
///         // Spawn a task on the `JoinMap` with `i` as its key.
///         map.spawn(i, async move { /* ... */ });
///     }
///
///     let mut seen = [false; 10];
///
///     // When a task completes, `join_next` returns the task's key along
///     // with its output.
///     while let Some((key, res)) = map.join_next().await {
///         seen[key] = true;
///         assert!(res.is_ok(), "task {} did not complete successfully!", key);
///     }
///
///     for i in 0..10 {
///         assert!(seen[i]);
///     }
/// }
/// ```
///
/// Cancel tasks based on their keys:
///
/// ```
/// use tokio_util::task::JoinMap;
///
/// #[tokio::main]
/// async fn main() {
///     let mut map = JoinMap::new();
///
///     map.spawn("hello world", async move { /* ... */ });
///     map.spawn("goodbye world", async move { /* ... */});
///
///     // Look up the "goodbye world" task in the map and abort it.
///     let aborted = map.abort("goodbye world");
///
///     // `JoinMap::abort` returns `true` if a task existed for the
///     // provided key.
///     assert!(aborted);
///
///     while let Some((key, res)) = map.join_next().await {
///         if key == "goodbye world" {
///             // The aborted task should complete with a cancelled `JoinError`.
///             assert!(res.unwrap_err().is_cancelled());
///         } else {
///             // Other tasks should complete normally.
///             assert!(res.is_ok());
///         }
///     }
/// }
/// ```
///
/// [`JoinSet`]: tokio::task::JoinSet
/// [unstable]: tokio#unstable-features
/// [abort]: fn@Self::abort
/// [abort_matching]: fn@Self::abort_matching
/// [contains]: fn@Self::contains_key
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
pub struct JoinMap<K, V, S = RandomState> {
    /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
    /// indexed by their keys and task IDs.
    ///
    /// The [`Key`] type contains both the task's `K`-typed key provided when
    /// spawning tasks, and the task's ID. The IDs are stored here to resolve
    /// hash collisions when looking up tasks based on their pre-computed hash
    /// (as stored in the `hashes_by_task` map).
    tasks_by_key: HashMap<Key<K>, AbortHandle, S>,

    /// A map from task IDs to the hash of the key associated with that task.
    ///
    /// This map is used to perform reverse lookups of tasks in the
    /// `tasks_by_key` map based on their task IDs. When a task terminates, the
    /// ID is provided to us by the `JoinSet`, so we can look up the hash value
    /// of that task's key, and then remove it from the `tasks_by_key` map using
    /// the raw hash code, resolving collisions by comparing task IDs.
    hashes_by_task: HashMap<Id, u64, S>,

    /// The [`JoinSet`] that awaits the completion of tasks spawned on this
    /// `JoinMap`.
    ///
    /// Note: this set can also hold tasks that no longer appear in
    /// `tasks_by_key`/`hashes_by_task`. When a task is replaced by a new one
    /// spawned with an equal key, the old task is aborted and its ID is removed
    /// from both maps, but it stays in this `JoinSet` until the set yields it.
    tasks: JoinSet<V>,
}
126}
127
/// A [`JoinMap`] key.
///
/// This holds both a `K`-typed key (the actual key as seen by the user), _and_
/// a task ID, so that hash collisions between `K`-typed keys can be resolved
/// using either `K`'s `Eq` impl *or* by checking the task IDs.
///
/// This allows looking up a task using either an actual key (such as when the
/// user queries the map with a key), *or* using a task ID and a hash (such as
/// when removing completed tasks from the map).
///
/// The hand-written `Hash` and `PartialEq` impls for this type look only at
/// `key`; `id` is compared explicitly by the raw-entry lookups that need it.
#[derive(Debug)]
struct Key<K> {
    // The user-supplied key for the task.
    key: K,
    // The Tokio ID of the task currently stored under `key`.
    id: Id,
}
141}
142
143impl<K, V> JoinMap<K, V> {
144 /// Creates a new empty `JoinMap`.
145 ///
146 /// The `JoinMap` is initially created with a capacity of 0, so it will not
147 /// allocate until a task is first spawned on it.
148 ///
149 /// # Examples
150 ///
151 /// ```
152 /// use tokio_util::task::JoinMap;
153 /// let map: JoinMap<&str, i32> = JoinMap::new();
154 /// ```
155 #[inline]
156 #[must_use]
157 pub fn new() -> Self {
158 Self::with_hasher(RandomState::new())
159 }
160
161 /// Creates an empty `JoinMap` with the specified capacity.
162 ///
163 /// The `JoinMap` will be able to hold at least `capacity` tasks without
164 /// reallocating.
165 ///
166 /// # Examples
167 ///
168 /// ```
169 /// use tokio_util::task::JoinMap;
170 /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10);
171 /// ```
172 #[inline]
173 #[must_use]
174 pub fn with_capacity(capacity: usize) -> Self {
175 JoinMap::with_capacity_and_hasher(capacity, Default::default())
176 }
177}
178
179impl<K, V, S: Clone> JoinMap<K, V, S> {
180 /// Creates an empty `JoinMap` which will use the given hash builder to hash
181 /// keys.
182 ///
183 /// The created map has the default initial capacity.
184 ///
185 /// Warning: `hash_builder` is normally randomly generated, and
186 /// is designed to allow `JoinMap` to be resistant to attacks that
187 /// cause many collisions and very poor performance. Setting it
188 /// manually using this function can expose a DoS attack vector.
189 ///
190 /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
191 /// the `JoinMap` to be useful, see its documentation for details.
192 #[inline]
193 #[must_use]
194 pub fn with_hasher(hash_builder: S) -> Self {
195 Self::with_capacity_and_hasher(0, hash_builder)
196 }
197
198 /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder`
199 /// to hash the keys.
200 ///
201 /// The `JoinMap` will be able to hold at least `capacity` elements without
202 /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate.
203 ///
204 /// Warning: `hash_builder` is normally randomly generated, and
205 /// is designed to allow HashMaps to be resistant to attacks that
206 /// cause many collisions and very poor performance. Setting it
207 /// manually using this function can expose a DoS attack vector.
208 ///
209 /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
210 /// the `JoinMap`to be useful, see its documentation for details.
211 ///
212 /// # Examples
213 ///
214 /// ```
215 /// # #[tokio::main]
216 /// # async fn main() {
217 /// use tokio_util::task::JoinMap;
218 /// use std::collections::hash_map::RandomState;
219 ///
220 /// let s = RandomState::new();
221 /// let mut map = JoinMap::with_capacity_and_hasher(10, s);
222 /// map.spawn(1, async move { "hello world!" });
223 /// # }
224 /// ```
225 #[inline]
226 #[must_use]
227 pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
228 Self {
229 tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()),
230 hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
231 tasks: JoinSet::new(),
232 }
233 }
234
235 /// Returns the number of tasks currently in the `JoinMap`.
236 pub fn len(&self) -> usize {
237 let len = self.tasks_by_key.len();
238 debug_assert_eq!(len, self.hashes_by_task.len());
239 len
240 }
241
242 /// Returns whether the `JoinMap` is empty.
243 pub fn is_empty(&self) -> bool {
244 let empty = self.tasks_by_key.is_empty();
245 debug_assert_eq!(empty, self.hashes_by_task.is_empty());
246 empty
247 }
248
249 /// Returns the number of tasks the map can hold without reallocating.
250 ///
251 /// This number is a lower bound; the `JoinMap` might be able to hold
252 /// more, but is guaranteed to be able to hold at least this many.
253 ///
254 /// # Examples
255 ///
256 /// ```
257 /// use tokio_util::task::JoinMap;
258 ///
259 /// let map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
260 /// assert!(map.capacity() >= 100);
261 /// ```
262 #[inline]
263 pub fn capacity(&self) -> usize {
264 let capacity = self.tasks_by_key.capacity();
265 debug_assert_eq!(capacity, self.hashes_by_task.capacity());
266 capacity
267 }
268}
269
270impl<K, V, S> JoinMap<K, V, S>
271where
272 K: Hash + Eq,
273 V: 'static,
274 S: BuildHasher,
275{
276 /// Spawn the provided task and store it in this `JoinMap` with the provided
277 /// key.
278 ///
279 /// If a task previously existed in the `JoinMap` for this key, that task
280 /// will be cancelled and replaced with the new one. The previous task will
281 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
282 /// *not* return a cancelled [`JoinError`] for that task.
283 ///
284 /// # Panics
285 ///
286 /// This method panics if called outside of a Tokio runtime.
287 ///
288 /// [`join_next`]: Self::join_next
289 #[track_caller]
290 pub fn spawn<F>(&mut self, key: K, task: F)
291 where
292 F: Future<Output = V>,
293 F: Send + 'static,
294 V: Send,
295 {
296 let task = self.tasks.spawn(task);
297 self.insert(key, task)
298 }
299
300 /// Spawn the provided task on the provided runtime and store it in this
301 /// `JoinMap` with the provided key.
302 ///
303 /// If a task previously existed in the `JoinMap` for this key, that task
304 /// will be cancelled and replaced with the new one. The previous task will
305 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
306 /// *not* return a cancelled [`JoinError`] for that task.
307 ///
308 /// [`join_next`]: Self::join_next
309 #[track_caller]
310 pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle)
311 where
312 F: Future<Output = V>,
313 F: Send + 'static,
314 V: Send,
315 {
316 let task = self.tasks.spawn_on(task, handle);
317 self.insert(key, task);
318 }
319
320 /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
321 /// key.
322 ///
323 /// If a task previously existed in the `JoinMap` for this key, that task
324 /// will be cancelled and replaced with the new one. The previous task will
325 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
326 /// *not* return a cancelled [`JoinError`] for that task.
327 ///
328 /// Note that blocking tasks cannot be cancelled after execution starts.
329 /// Replaced blocking tasks will still run to completion if the task has begun
330 /// to execute when it is replaced. A blocking task which is replaced before
331 /// it has been scheduled on a blocking worker thread will be cancelled.
332 ///
333 /// # Panics
334 ///
335 /// This method panics if called outside of a Tokio runtime.
336 ///
337 /// [`join_next`]: Self::join_next
338 #[track_caller]
339 pub fn spawn_blocking<F>(&mut self, key: K, f: F)
340 where
341 F: FnOnce() -> V,
342 F: Send + 'static,
343 V: Send,
344 {
345 let task = self.tasks.spawn_blocking(f);
346 self.insert(key, task)
347 }
348
349 /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
350 /// `JoinMap` with the provided key.
351 ///
352 /// If a task previously existed in the `JoinMap` for this key, that task
353 /// will be cancelled and replaced with the new one. The previous task will
354 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
355 /// *not* return a cancelled [`JoinError`] for that task.
356 ///
357 /// Note that blocking tasks cannot be cancelled after execution starts.
358 /// Replaced blocking tasks will still run to completion if the task has begun
359 /// to execute when it is replaced. A blocking task which is replaced before
360 /// it has been scheduled on a blocking worker thread will be cancelled.
361 ///
362 /// [`join_next`]: Self::join_next
363 #[track_caller]
364 pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle)
365 where
366 F: FnOnce() -> V,
367 F: Send + 'static,
368 V: Send,
369 {
370 let task = self.tasks.spawn_blocking_on(f, handle);
371 self.insert(key, task);
372 }
373
374 /// Spawn the provided task on the current [`LocalSet`] and store it in this
375 /// `JoinMap` with the provided key.
376 ///
377 /// If a task previously existed in the `JoinMap` for this key, that task
378 /// will be cancelled and replaced with the new one. The previous task will
379 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
380 /// *not* return a cancelled [`JoinError`] for that task.
381 ///
382 /// # Panics
383 ///
384 /// This method panics if it is called outside of a `LocalSet`.
385 ///
386 /// [`LocalSet`]: tokio::task::LocalSet
387 /// [`join_next`]: Self::join_next
388 #[track_caller]
389 pub fn spawn_local<F>(&mut self, key: K, task: F)
390 where
391 F: Future<Output = V>,
392 F: 'static,
393 {
394 let task = self.tasks.spawn_local(task);
395 self.insert(key, task);
396 }
397
398 /// Spawn the provided task on the provided [`LocalSet`] and store it in
399 /// this `JoinMap` with the provided key.
400 ///
401 /// If a task previously existed in the `JoinMap` for this key, that task
402 /// will be cancelled and replaced with the new one. The previous task will
403 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
404 /// *not* return a cancelled [`JoinError`] for that task.
405 ///
406 /// [`LocalSet`]: tokio::task::LocalSet
407 /// [`join_next`]: Self::join_next
408 #[track_caller]
409 pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet)
410 where
411 F: Future<Output = V>,
412 F: 'static,
413 {
414 let task = self.tasks.spawn_local_on(task, local_set);
415 self.insert(key, task)
416 }
417
418 fn insert(&mut self, key: K, abort: AbortHandle) {
419 let hash = self.hash(&key);
420 let id = abort.id();
421 let map_key = Key { id, key };
422
423 // Insert the new key into the map of tasks by keys.
424 let entry = self
425 .tasks_by_key
426 .raw_entry_mut()
427 .from_hash(hash, |k| k.key == map_key.key);
428 match entry {
429 RawEntryMut::Occupied(mut occ) => {
430 // There was a previous task spawned with the same key! Cancel
431 // that task, and remove its ID from the map of hashes by task IDs.
432 let Key { id: prev_id, .. } = occ.insert_key(map_key);
433 occ.insert(abort).abort();
434 let _prev_hash = self.hashes_by_task.remove(&prev_id);
435 debug_assert_eq!(Some(hash), _prev_hash);
436 }
437 RawEntryMut::Vacant(vac) => {
438 vac.insert(map_key, abort);
439 }
440 };
441
442 // Associate the key's hash with this task's ID, for looking up tasks by ID.
443 let _prev = self.hashes_by_task.insert(id, hash);
444 debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
445 }
446
447 /// Waits until one of the tasks in the map completes and returns its
448 /// output, along with the key corresponding to that task.
449 ///
450 /// Returns `None` if the map is empty.
451 ///
452 /// # Cancel Safety
453 ///
454 /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`]
455 /// statement and some other branch completes first, it is guaranteed that no tasks were
456 /// removed from this `JoinMap`.
457 ///
458 /// # Returns
459 ///
460 /// This function returns:
461 ///
462 /// * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has
463 /// completed. The `value` is the return value of that ask, and `key` is
464 /// the key associated with the task.
465 /// * `Some((key, Err(err))` if one of the tasks in this `JoinMap` has
466 /// panicked or been aborted. `key` is the key associated with the task
467 /// that panicked or was aborted.
468 /// * `None` if the `JoinMap` is empty.
469 ///
470 /// [`tokio::select!`]: tokio::select
471 pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
472 let (res, id) = match self.tasks.join_next_with_id().await {
473 Some(Ok((id, output))) => (Ok(output), id),
474 Some(Err(e)) => {
475 let id = e.id();
476 (Err(e), id)
477 }
478 None => return None,
479 };
480 let key = self.remove_by_id(id)?;
481 Some((key, res))
482 }
483
484 /// Aborts all tasks and waits for them to finish shutting down.
485 ///
486 /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
487 /// a loop until it returns `None`.
488 ///
489 /// This method ignores any panics in the tasks shutting down. When this call returns, the
490 /// `JoinMap` will be empty.
491 ///
492 /// [`abort_all`]: fn@Self::abort_all
493 /// [`join_next`]: fn@Self::join_next
494 pub async fn shutdown(&mut self) {
495 self.abort_all();
496 while self.join_next().await.is_some() {}
497 }
498
499 /// Abort the task corresponding to the provided `key`.
500 ///
501 /// If this `JoinMap` contains a task corresponding to `key`, this method
502 /// will abort that task and return `true`. Otherwise, if no task exists for
503 /// `key`, this method returns `false`.
504 ///
505 /// # Examples
506 ///
507 /// Aborting a task by key:
508 ///
509 /// ```
510 /// use tokio_util::task::JoinMap;
511 ///
512 /// # #[tokio::main]
513 /// # async fn main() {
514 /// let mut map = JoinMap::new();
515 ///
516 /// map.spawn("hello world", async move { /* ... */ });
517 /// map.spawn("goodbye world", async move { /* ... */});
518 ///
519 /// // Look up the "goodbye world" task in the map and abort it.
520 /// map.abort("goodbye world");
521 ///
522 /// while let Some((key, res)) = map.join_next().await {
523 /// if key == "goodbye world" {
524 /// // The aborted task should complete with a cancelled `JoinError`.
525 /// assert!(res.unwrap_err().is_cancelled());
526 /// } else {
527 /// // Other tasks should complete normally.
528 /// assert!(res.is_ok());
529 /// }
530 /// }
531 /// # }
532 /// ```
533 ///
534 /// `abort` returns `true` if a task was aborted:
535 /// ```
536 /// use tokio_util::task::JoinMap;
537 ///
538 /// # #[tokio::main]
539 /// # async fn main() {
540 /// let mut map = JoinMap::new();
541 ///
542 /// map.spawn("hello world", async move { /* ... */ });
543 /// map.spawn("goodbye world", async move { /* ... */});
544 ///
545 /// // A task for the key "goodbye world" should exist in the map:
546 /// assert!(map.abort("goodbye world"));
547 ///
548 /// // Aborting a key that does not exist will return `false`:
549 /// assert!(!map.abort("goodbye universe"));
550 /// # }
551 /// ```
552 pub fn abort<Q: ?Sized>(&mut self, key: &Q) -> bool
553 where
554 Q: Hash + Eq,
555 K: Borrow<Q>,
556 {
557 match self.get_by_key(key) {
558 Some((_, handle)) => {
559 handle.abort();
560 true
561 }
562 None => false,
563 }
564 }
565
566 /// Aborts all tasks with keys matching `predicate`.
567 ///
568 /// `predicate` is a function called with a reference to each key in the
569 /// map. If it returns `true` for a given key, the corresponding task will
570 /// be cancelled.
571 ///
572 /// # Examples
573 /// ```
574 /// use tokio_util::task::JoinMap;
575 ///
576 /// # // use the current thread rt so that spawned tasks don't
577 /// # // complete in the background before they can be aborted.
578 /// # #[tokio::main(flavor = "current_thread")]
579 /// # async fn main() {
580 /// let mut map = JoinMap::new();
581 ///
582 /// map.spawn("hello world", async move {
583 /// // ...
584 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
585 /// });
586 /// map.spawn("goodbye world", async move {
587 /// // ...
588 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
589 /// });
590 /// map.spawn("hello san francisco", async move {
591 /// // ...
592 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
593 /// });
594 /// map.spawn("goodbye universe", async move {
595 /// // ...
596 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
597 /// });
598 ///
599 /// // Abort all tasks whose keys begin with "goodbye"
600 /// map.abort_matching(|key| key.starts_with("goodbye"));
601 ///
602 /// let mut seen = 0;
603 /// while let Some((key, res)) = map.join_next().await {
604 /// seen += 1;
605 /// if key.starts_with("goodbye") {
606 /// // The aborted task should complete with a cancelled `JoinError`.
607 /// assert!(res.unwrap_err().is_cancelled());
608 /// } else {
609 /// // Other tasks should complete normally.
610 /// assert!(key.starts_with("hello"));
611 /// assert!(res.is_ok());
612 /// }
613 /// }
614 ///
615 /// // All spawned tasks should have completed.
616 /// assert_eq!(seen, 4);
617 /// # }
618 /// ```
619 pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) {
620 // Note: this method iterates over the tasks and keys *without* removing
621 // any entries, so that the keys from aborted tasks can still be
622 // returned when calling `join_next` in the future.
623 for (Key { ref key, .. }, task) in &self.tasks_by_key {
624 if predicate(key) {
625 task.abort();
626 }
627 }
628 }
629
630 /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order.
631 ///
632 /// If a task has completed, but its output hasn't yet been consumed by a
633 /// call to [`join_next`], this method will still return its key.
634 ///
635 /// [`join_next`]: fn@Self::join_next
636 pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
637 JoinMapKeys {
638 iter: self.tasks_by_key.keys(),
639 _value: PhantomData,
640 }
641 }
642
643 /// Returns `true` if this `JoinMap` contains a task for the provided key.
644 ///
645 /// If the task has completed, but its output hasn't yet been consumed by a
646 /// call to [`join_next`], this method will still return `true`.
647 ///
648 /// [`join_next`]: fn@Self::join_next
649 pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool
650 where
651 Q: Hash + Eq,
652 K: Borrow<Q>,
653 {
654 self.get_by_key(key).is_some()
655 }
656
657 /// Returns `true` if this `JoinMap` contains a task with the provided
658 /// [task ID].
659 ///
660 /// If the task has completed, but its output hasn't yet been consumed by a
661 /// call to [`join_next`], this method will still return `true`.
662 ///
663 /// [`join_next`]: fn@Self::join_next
664 /// [task ID]: tokio::task::Id
665 pub fn contains_task(&self, task: &Id) -> bool {
666 self.get_by_id(task).is_some()
667 }
668
669 /// Reserves capacity for at least `additional` more tasks to be spawned
670 /// on this `JoinMap` without reallocating for the map of task keys. The
671 /// collection may reserve more space to avoid frequent reallocations.
672 ///
673 /// Note that spawning a task will still cause an allocation for the task
674 /// itself.
675 ///
676 /// # Panics
677 ///
678 /// Panics if the new allocation size overflows [`usize`].
679 ///
680 /// # Examples
681 ///
682 /// ```
683 /// use tokio_util::task::JoinMap;
684 ///
685 /// let mut map: JoinMap<&str, i32> = JoinMap::new();
686 /// map.reserve(10);
687 /// ```
688 #[inline]
689 pub fn reserve(&mut self, additional: usize) {
690 self.tasks_by_key.reserve(additional);
691 self.hashes_by_task.reserve(additional);
692 }
693
694 /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop
695 /// down as much as possible while maintaining the internal rules
696 /// and possibly leaving some space in accordance with the resize policy.
697 ///
698 /// # Examples
699 ///
700 /// ```
701 /// # #[tokio::main]
702 /// # async fn main() {
703 /// use tokio_util::task::JoinMap;
704 ///
705 /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
706 /// map.spawn(1, async move { 2 });
707 /// map.spawn(3, async move { 4 });
708 /// assert!(map.capacity() >= 100);
709 /// map.shrink_to_fit();
710 /// assert!(map.capacity() >= 2);
711 /// # }
712 /// ```
713 #[inline]
714 pub fn shrink_to_fit(&mut self) {
715 self.hashes_by_task.shrink_to_fit();
716 self.tasks_by_key.shrink_to_fit();
717 }
718
719 /// Shrinks the capacity of the map with a lower limit. It will drop
720 /// down no lower than the supplied limit while maintaining the internal rules
721 /// and possibly leaving some space in accordance with the resize policy.
722 ///
723 /// If the current capacity is less than the lower limit, this is a no-op.
724 ///
725 /// # Examples
726 ///
727 /// ```
728 /// # #[tokio::main]
729 /// # async fn main() {
730 /// use tokio_util::task::JoinMap;
731 ///
732 /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
733 /// map.spawn(1, async move { 2 });
734 /// map.spawn(3, async move { 4 });
735 /// assert!(map.capacity() >= 100);
736 /// map.shrink_to(10);
737 /// assert!(map.capacity() >= 10);
738 /// map.shrink_to(0);
739 /// assert!(map.capacity() >= 2);
740 /// # }
741 /// ```
742 #[inline]
743 pub fn shrink_to(&mut self, min_capacity: usize) {
744 self.hashes_by_task.shrink_to(min_capacity);
745 self.tasks_by_key.shrink_to(min_capacity)
746 }
747
748 /// Look up a task in the map by its key, returning the key and abort handle.
749 fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)>
750 where
751 Q: Hash + Eq,
752 K: Borrow<Q>,
753 {
754 let hash = self.hash(key);
755 self.tasks_by_key
756 .raw_entry()
757 .from_hash(hash, |k| k.key.borrow() == key)
758 }
759
760 /// Look up a task in the map by its task ID, returning the key and abort handle.
761 fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> {
762 let hash = self.hashes_by_task.get(id)?;
763 self.tasks_by_key
764 .raw_entry()
765 .from_hash(*hash, |k| &k.id == id)
766 }
767
768 /// Remove a task from the map by ID, returning the key for that task.
769 fn remove_by_id(&mut self, id: Id) -> Option<K> {
770 // Get the hash for the given ID.
771 let hash = self.hashes_by_task.remove(&id)?;
772
773 // Remove the entry for that hash.
774 let entry = self
775 .tasks_by_key
776 .raw_entry_mut()
777 .from_hash(hash, |k| k.id == id);
778 let (Key { id: _key_id, key }, handle) = match entry {
779 RawEntryMut::Occupied(entry) => entry.remove_entry(),
780 _ => return None,
781 };
782 debug_assert_eq!(_key_id, id);
783 debug_assert_eq!(id, handle.id());
784 self.hashes_by_task.remove(&id);
785 Some(key)
786 }
787
788 /// Returns the hash for a given key.
789 #[inline]
790 fn hash<Q: ?Sized>(&self, key: &Q) -> u64
791 where
792 Q: Hash,
793 {
794 let mut hasher = self.tasks_by_key.hasher().build_hasher();
795 key.hash(&mut hasher);
796 hasher.finish()
797 }
798}
799
800impl<K, V, S> JoinMap<K, V, S>
801where
802 V: 'static,
803{
804 /// Aborts all tasks on this `JoinMap`.
805 ///
806 /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete
807 /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty.
808 pub fn abort_all(&mut self) {
809 self.tasks.abort_all()
810 }
811
812 /// Removes all tasks from this `JoinMap` without aborting them.
813 ///
814 /// The tasks removed by this call will continue to run in the background even if the `JoinMap`
815 /// is dropped. They may still be aborted by key.
816 pub fn detach_all(&mut self) {
817 self.tasks.detach_all();
818 self.tasks_by_key.clear();
819 self.hashes_by_task.clear();
820 }
821}
822
823// Hand-written `fmt::Debug` implementation in order to avoid requiring `V:
824// Debug`, since no value is ever actually stored in the map.
825impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
826 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
827 // format the task keys and abort handles a little nicer by just
828 // printing the key and task ID pairs, without format the `Key` struct
829 // itself or the `AbortHandle`, which would just format the task's ID
830 // again.
831 struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>);
832 impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> {
833 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
834 f.debug_map()
835 .entries(self.0.keys().map(|Key { key, id }| (key, id)))
836 .finish()
837 }
838 }
839
840 f.debug_struct("JoinMap")
841 // The `tasks_by_key` map is the only one that contains information
842 // that's really worth formatting for the user, since it contains
843 // the tasks' keys and IDs. The other fields are basically
844 // implementation details.
845 .field("tasks", &KeySet(&self.tasks_by_key))
846 .finish()
847 }
848}
849
850impl<K, V> Default for JoinMap<K, V> {
851 fn default() -> Self {
852 Self::new()
853 }
854}
855
856// === impl Key ===
857
858impl<K: Hash> Hash for Key<K> {
859 // Don't include the task ID in the hash.
860 #[inline]
861 fn hash<H: Hasher>(&self, hasher: &mut H) {
862 self.key.hash(hasher);
863 }
864}
865
866// Because we override `Hash` for this type, we must also override the
867// `PartialEq` impl, so that all instances with the same hash are equal.
868impl<K: PartialEq> PartialEq for Key<K> {
869 #[inline]
870 fn eq(&self, other: &Self) -> bool {
871 self.key == other.key
872 }
873}
874
875impl<K: Eq> Eq for Key<K> {}
876
877/// An iterator over the keys of a [`JoinMap`].
878#[derive(Debug, Clone)]
879pub struct JoinMapKeys<'a, K, V> {
880 iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>,
881 /// To make it easier to change `JoinMap` in the future, keep V as a generic
882 /// parameter.
883 _value: PhantomData<&'a V>,
884}
885
886impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
887 type Item = &'a K;
888
889 fn next(&mut self) -> Option<&'a K> {
890 self.iter.next().map(|key| &key.key)
891 }
892
893 fn size_hint(&self) -> (usize, Option<usize>) {
894 self.iter.size_hint()
895 }
896}
897
898impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
899 fn len(&self) -> usize {
900 self.iter.len()
901 }
902}
903
904impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}