embassy_futures/
join.rs

1//! Wait for multiple futures to complete.
2
3use core::future::Future;
4use core::mem::MaybeUninit;
5use core::pin::Pin;
6use core::task::{Context, Poll};
7use core::{fmt, mem};
8
/// Completion slot for one child future of a join combinator.
///
/// A slot starts as [`MaybeDone::Future`], switches to [`MaybeDone::Done`] once
/// the wrapped future resolves, and becomes [`MaybeDone::Gone`] after its output
/// has been moved out with [`take_output`](MaybeDone::take_output).
#[derive(Debug)]
enum MaybeDone<Fut: Future> {
    /// The wrapped future has not resolved yet (it is pinned in place while polled).
    Future(/* #[pin] */ Fut),
    /// The wrapped future resolved with this output.
    Done(Fut::Output),
    /// The output was already moved out by [`take_output`](MaybeDone::take_output).
    Gone,
}

impl<Fut: Future> MaybeDone<Fut> {
    /// Polls the wrapped future if it is still running.
    ///
    /// Returns `true` when no more polling is needed, meaning the slot is
    /// `Done` or `Gone` after this call. Returns `false` while the wrapped
    /// future is still pending.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool {
        // SAFETY: nothing is moved out of `self`. The inner future is only
        // re-pinned in place, and overwriting `*state` drops it in place.
        let state = unsafe { self.get_unchecked_mut() };

        let fut = match state {
            Self::Future(fut) => fut,
            Self::Done(_) | Self::Gone => return true,
        };

        // SAFETY: `fut` is a field of the pinned `MaybeDone` and is never moved.
        match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
            Poll::Ready(output) => {
                *state = Self::Done(output);
                true
            }
            Poll::Pending => false,
        }
    }

    /// Moves the finished output out of the slot and leaves `Gone` behind.
    ///
    /// # Panics
    ///
    /// Panics if the slot is not `Done`, i.e. the future has not finished yet
    /// or its output was already taken. In that case the slot is left unchanged.
    fn take_output(&mut self) -> Fut::Output {
        if !matches!(*self, Self::Done(_)) {
            panic!("take_output when MaybeDone is not done.");
        }
        match mem::replace(self, Self::Gone) {
            Self::Done(output) => output,
            Self::Future(_) | Self::Gone => unreachable!(),
        }
    }
}

// Only the wrapped future is ever pinned. The output is just moved out by
// `take_output`, so the slot is `Unpin` whenever the future itself is.
impl<Fut> Unpin for MaybeDone<Fut> where Fut: Future + Unpin {}
48
49macro_rules! generate {
50    ($(
51        $(#[$doc:meta])*
52        ($Join:ident, <$($Fut:ident),*>),
53    )*) => ($(
54        $(#[$doc])*
55        #[must_use = "futures do nothing unless you `.await` or poll them"]
56        #[allow(non_snake_case)]
57        pub struct $Join<$($Fut: Future),*> {
58            $(
59                $Fut: MaybeDone<$Fut>,
60            )*
61        }
62
63        impl<$($Fut),*> fmt::Debug for $Join<$($Fut),*>
64        where
65            $(
66                $Fut: Future + fmt::Debug,
67                $Fut::Output: fmt::Debug,
68            )*
69        {
70            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71                f.debug_struct(stringify!($Join))
72                    $(.field(stringify!($Fut), &self.$Fut))*
73                    .finish()
74            }
75        }
76
77        impl<$($Fut: Future),*> $Join<$($Fut),*> {
78            #[allow(non_snake_case)]
79            fn new($($Fut: $Fut),*) -> Self {
80                Self {
81                    $($Fut: MaybeDone::Future($Fut)),*
82                }
83            }
84        }
85
86        impl<$($Fut: Future),*> Future for $Join<$($Fut),*> {
87            type Output = ($($Fut::Output),*);
88
89            fn poll(
90                self: Pin<&mut Self>, cx: &mut Context<'_>
91            ) -> Poll<Self::Output> {
92                let this = unsafe { self.get_unchecked_mut() };
93                let mut all_done = true;
94                $(
95                    all_done &= unsafe { Pin::new_unchecked(&mut this.$Fut) }.poll(cx);
96                )*
97
98                if all_done {
99                    Poll::Ready(($(this.$Fut.take_output()), *))
100                } else {
101                    Poll::Pending
102                }
103            }
104        }
105    )*)
106}
107
108generate! {
109    /// Future for the [`join`](join()) function.
110    (Join, <Fut1, Fut2>),
111
112    /// Future for the [`join3`] function.
113    (Join3, <Fut1, Fut2, Fut3>),
114
115    /// Future for the [`join4`] function.
116    (Join4, <Fut1, Fut2, Fut3, Fut4>),
117
118    /// Future for the [`join5`] function.
119    (Join5, <Fut1, Fut2, Fut3, Fut4, Fut5>),
120}
121
122/// Joins the result of two futures, waiting for them both to complete.
123///
124/// This function will return a new future which awaits both futures to
125/// complete. The returned future will finish with a tuple of both results.
126///
127/// Note that this function consumes the passed futures and returns a
128/// wrapped version of it.
129///
130/// # Examples
131///
132/// ```
133/// # embassy_futures::block_on(async {
134///
135/// let a = async { 1 };
136/// let b = async { 2 };
137/// let pair = embassy_futures::join::join(a, b).await;
138///
139/// assert_eq!(pair, (1, 2));
140/// # });
141/// ```
142pub fn join<Fut1, Fut2>(future1: Fut1, future2: Fut2) -> Join<Fut1, Fut2>
143where
144    Fut1: Future,
145    Fut2: Future,
146{
147    Join::new(future1, future2)
148}
149
150/// Joins the result of three futures, waiting for them all to complete.
151///
152/// This function will return a new future which awaits all futures to
153/// complete. The returned future will finish with a tuple of all results.
154///
155/// Note that this function consumes the passed futures and returns a
156/// wrapped version of it.
157///
158/// # Examples
159///
160/// ```
161/// # embassy_futures::block_on(async {
162///
163/// let a = async { 1 };
164/// let b = async { 2 };
165/// let c = async { 3 };
166/// let res = embassy_futures::join::join3(a, b, c).await;
167///
168/// assert_eq!(res, (1, 2, 3));
169/// # });
170/// ```
171pub fn join3<Fut1, Fut2, Fut3>(future1: Fut1, future2: Fut2, future3: Fut3) -> Join3<Fut1, Fut2, Fut3>
172where
173    Fut1: Future,
174    Fut2: Future,
175    Fut3: Future,
176{
177    Join3::new(future1, future2, future3)
178}
179
180/// Joins the result of four futures, waiting for them all to complete.
181///
182/// This function will return a new future which awaits all futures to
183/// complete. The returned future will finish with a tuple of all results.
184///
185/// Note that this function consumes the passed futures and returns a
186/// wrapped version of it.
187///
188/// # Examples
189///
190/// ```
191/// # embassy_futures::block_on(async {
192///
193/// let a = async { 1 };
194/// let b = async { 2 };
195/// let c = async { 3 };
196/// let d = async { 4 };
197/// let res = embassy_futures::join::join4(a, b, c, d).await;
198///
199/// assert_eq!(res, (1, 2, 3, 4));
200/// # });
201/// ```
202pub fn join4<Fut1, Fut2, Fut3, Fut4>(
203    future1: Fut1,
204    future2: Fut2,
205    future3: Fut3,
206    future4: Fut4,
207) -> Join4<Fut1, Fut2, Fut3, Fut4>
208where
209    Fut1: Future,
210    Fut2: Future,
211    Fut3: Future,
212    Fut4: Future,
213{
214    Join4::new(future1, future2, future3, future4)
215}
216
217/// Joins the result of five futures, waiting for them all to complete.
218///
219/// This function will return a new future which awaits all futures to
220/// complete. The returned future will finish with a tuple of all results.
221///
222/// Note that this function consumes the passed futures and returns a
223/// wrapped version of it.
224///
225/// # Examples
226///
227/// ```
228/// # embassy_futures::block_on(async {
229///
230/// let a = async { 1 };
231/// let b = async { 2 };
232/// let c = async { 3 };
233/// let d = async { 4 };
234/// let e = async { 5 };
235/// let res = embassy_futures::join::join5(a, b, c, d, e).await;
236///
237/// assert_eq!(res, (1, 2, 3, 4, 5));
238/// # });
239/// ```
240pub fn join5<Fut1, Fut2, Fut3, Fut4, Fut5>(
241    future1: Fut1,
242    future2: Fut2,
243    future3: Fut3,
244    future4: Fut4,
245    future5: Fut5,
246) -> Join5<Fut1, Fut2, Fut3, Fut4, Fut5>
247where
248    Fut1: Future,
249    Fut2: Future,
250    Fut3: Future,
251    Fut4: Future,
252    Fut5: Future,
253{
254    Join5::new(future1, future2, future3, future4, future5)
255}
256
257// =====================================================
258
259/// Future for the [`join_array`] function.
260#[must_use = "futures do nothing unless you `.await` or poll them"]
261pub struct JoinArray<Fut: Future, const N: usize> {
262    futures: [MaybeDone<Fut>; N],
263}
264
265impl<Fut: Future, const N: usize> fmt::Debug for JoinArray<Fut, N>
266where
267    Fut: Future + fmt::Debug,
268    Fut::Output: fmt::Debug,
269{
270    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
271        f.debug_struct("JoinArray").field("futures", &self.futures).finish()
272    }
273}
274
275impl<Fut: Future, const N: usize> Future for JoinArray<Fut, N> {
276    type Output = [Fut::Output; N];
277    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
278        let this = unsafe { self.get_unchecked_mut() };
279        let mut all_done = true;
280        for f in this.futures.iter_mut() {
281            all_done &= unsafe { Pin::new_unchecked(f) }.poll(cx);
282        }
283
284        if all_done {
285            let mut array: [MaybeUninit<Fut::Output>; N] = unsafe { MaybeUninit::uninit().assume_init() };
286            for i in 0..N {
287                array[i].write(this.futures[i].take_output());
288            }
289            Poll::Ready(unsafe { (&array as *const _ as *const [Fut::Output; N]).read() })
290        } else {
291            Poll::Pending
292        }
293    }
294}
295
296/// Joins the result of an array of futures, waiting for them all to complete.
297///
298/// This function will return a new future which awaits all futures to
299/// complete. The returned future will finish with a tuple of all results.
300///
301/// Note that this function consumes the passed futures and returns a
302/// wrapped version of it.
303///
304/// # Examples
305///
306/// ```
307/// # embassy_futures::block_on(async {
308///
309/// async fn foo(n: u32) -> u32 { n }
310/// let a = foo(1);
311/// let b = foo(2);
312/// let c = foo(3);
313/// let res = embassy_futures::join::join_array([a, b, c]).await;
314///
315/// assert_eq!(res, [1, 2, 3]);
316/// # });
317/// ```
318pub fn join_array<Fut: Future, const N: usize>(futures: [Fut; N]) -> JoinArray<Fut, N> {
319    JoinArray {
320        futures: futures.map(MaybeDone::Future),
321    }
322}