// madsim_real_tokio/macros/try_join.rs

/// Waits on multiple concurrent branches, returning when **all** branches
/// complete with `Ok(_)` or on the first `Err(_)`.
///
/// The `try_join!` macro must be used inside of async functions, closures, and
/// blocks.
///
/// Similar to [`join!`], the `try_join!` macro takes a list of async
/// expressions and evaluates them concurrently on the same task. Each async
/// expression evaluates to a future and the futures from each expression are
/// multiplexed on the current task. The `try_join!` macro returns when **all**
/// branches return with `Ok` or when the **first** branch returns with `Err`.
///
/// On success the macro evaluates to `Ok` of a tuple holding each branch's
/// value, in the order the branches were written. On failure it evaluates to
/// the `Err` of the branch that failed. Invoking `try_join!()` with no
/// branches evaluates to `Ok(())`.
///
/// [`join!`]: macro@join
///
/// # Notes
///
/// The supplied futures are stored inline and do not require allocating a
/// `Vec`.
///
/// ### Runtime characteristics
///
/// By running all async expressions on the current task, the expressions are
/// able to run **concurrently** but not in **parallel**. This means all
/// expressions are run on the same thread and if one branch blocks the thread,
/// all other expressions will be unable to continue. If parallelism is
/// required, spawn each async expression using [`tokio::spawn`] and pass the
/// join handle to `try_join!`.
///
/// Every time the combined future is polled, polling starts one branch later
/// than it did the previous time (wrapping around after the last branch), so
/// that no branch is always polled last.
///
/// [`tokio::spawn`]: crate::spawn
///
/// # Examples
///
/// Basic `try_join` with two branches.
///
/// ```
/// async fn do_stuff_async() -> Result<(), &'static str> {
///     // async work
/// # Ok(())
/// }
///
/// async fn more_async_work() -> Result<(), &'static str> {
///     // more here
/// # Ok(())
/// }
///
/// #[tokio::main]
/// async fn main() {
///     let res = tokio::try_join!(
///         do_stuff_async(),
///         more_async_work());
///
///     match res {
///          Ok((first, second)) => {
///              // do something with the values
///          }
///          Err(err) => {
///             println!("processing failed; error = {}", err);
///          }
///     }
/// }
/// ```
///
/// Using `try_join!` with spawned tasks.
///
/// ```
/// use tokio::task::JoinHandle;
///
/// async fn do_stuff_async() -> Result<(), &'static str> {
///     // async work
/// # Err("failed")
/// }
///
/// async fn more_async_work() -> Result<(), &'static str> {
///     // more here
/// # Ok(())
/// }
///
/// async fn flatten<T>(handle: JoinHandle<Result<T, &'static str>>) -> Result<T, &'static str> {
///     match handle.await {
///         Ok(Ok(result)) => Ok(result),
///         Ok(Err(err)) => Err(err),
///         Err(_) => Err("handling failed"),
///     }
/// }
///
/// #[tokio::main]
/// async fn main() {
///     let handle1 = tokio::spawn(do_stuff_async());
///     let handle2 = tokio::spawn(more_async_work());
///     match tokio::try_join!(flatten(handle1), flatten(handle2)) {
///         Ok(val) => {
///             // do something with the values
///         }
///         Err(err) => {
///             println!("Failed with {}.", err);
///             # assert_eq!(err, "failed");
///         }
///     }
/// }
/// ```
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! try_join {
    // ===== Implementation (reached once every branch has been normalized) =====
    (@ {
        // One `_` for each branch in the `try_join!` macro. This is not used once
        // normalization is complete.
        ( $($count:tt)* )

        // The expression `0+1+1+ ... +1` equal to the number of branches.
        ( $($total:tt)* )

        // Normalized try_join! branches. Each branch carries `$skip`, a list
        // of `_` tokens with one entry per branch that precedes it; it is
        // used below as a tuple pattern prefix to reach this branch's future.
        $( ( $($skip:tt)* ) $e:expr, )*

    }) => {{
        use $crate::macros::support::{maybe_done, poll_fn, Future, Pin};
        use $crate::macros::support::Poll::{Ready, Pending};

        // Safety: nothing must be moved out of `futures`. This is to satisfy
        // the requirement of `Pin::new_unchecked` called below.
        //
        // We can't use the `pin!` macro for this because `futures` is a tuple
        // and the standard library provides no way to pin-project to the fields
        // of a tuple.
        let mut futures = ( $( maybe_done($e), )* );

        // This assignment makes sure that the `poll_fn` closure only has a
        // reference to the futures, instead of taking ownership of them. This
        // mitigates the issue described in
        // <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
        let mut futures = &mut futures;

        // Each time the future created by poll_fn is polled, a different future will be polled first
        // to ensure every future passed to try_join! gets a chance to make progress even if
        // one of the futures consumes the whole budget.
        //
        // This is the number of futures that will be skipped in the first loop
        // iteration the next time.
        let mut skip_next_time: u32 = 0;

        poll_fn(move |cx| {
            // Number of branches, spelled by the normalizer as `0 + 1 + 1 + ...`.
            const COUNT: u32 = $($total)*;

            // Set as soon as any branch reports `Pending` during this poll.
            let mut is_pending = false;

            // Number of branches that still have to be polled during this poll.
            let mut to_run = COUNT;

            // The number of futures that will be skipped in the first loop iteration
            let mut skip = skip_next_time;

            // Rotate the starting branch by one for the next poll, wrapping
            // back to the first branch after the last one.
            skip_next_time = if skip + 1 == COUNT { 0 } else { skip + 1 };

            // This loop runs twice and the first `skip` futures
            // are not polled in the first iteration.
            loop {
            $(
                if skip == 0 {
                    if to_run == 0 {
                        // Every future has been polled
                        break;
                    }
                    to_run -= 1;

                    // Extract the future for this branch from the tuple.
                    let ( $($skip,)* fut, .. ) = &mut *futures;

                    // Safety: future is stored on the stack above
                    // and never moved.
                    let mut fut = unsafe { Pin::new_unchecked(fut) };

                    // Try polling
                    if fut.as_mut().poll(cx).is_pending() {
                        is_pending = true;
                    } else if fut.as_mut().output_mut().expect("expected completed future").is_err() {
                        // A completed branch holds an `Err`: stop right away
                        // and hand that error to the caller.
                        return Ready(Err(fut.take_output().expect("expected completed future").err().unwrap()))
                    }
                } else {
                    // Future skipped, one less future to skip in the next iteration
                    skip -= 1;
                }
            )*
            }

            if is_pending {
                Pending
            } else {
                // No branch is pending and none returned early with an error,
                // so collect every branch's `Ok` value into the result tuple.
                Ready(Ok(($({
                    // Extract the future for this branch from the tuple.
                    let ( $($skip,)* fut, .. ) = &mut futures;

                    // Safety: future is stored on the stack above
                    // and never moved.
                    let mut fut = unsafe { Pin::new_unchecked(fut) };

                    fut
                        .take_output()
                        .expect("expected completed future")
                        .ok()
                        .expect("expected Ok(_)")
                },)*)))
            }
        }).await
    }};

    // ===== Normalize =====

    // Consumes one branch expression per step: appends a `_` to the branch
    // counter, appends `+ 1` to the total, and records the branch together
    // with the counter as it was *before* this step (its skip list).
    (@ { ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
      $crate::try_join!(@{ ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
    };

    // ===== Entry point =====

    // One or more comma-separated expressions, with an optional trailing
    // comma. Normalization starts with an empty counter and a total of `0`.
    ( $($e:expr),+ $(,)?) => {
        $crate::try_join!(@{ () (0) } $($e,)*)
    };

    // No branches at all: there is nothing to wait for.
    () => { async { Ok(()) }.await }
}