futures_concurrency/future/try_join/
tuple.rsuse super::TryJoin as TryJoinTrait;
use crate::utils::{PollArray, WakerArray};
use core::fmt::{self, Debug};
use core::future::{Future, IntoFuture};
use core::marker::PhantomData;
use core::mem::ManuallyDrop;
use core::mem::MaybeUninit;
use core::ops::DerefMut;
use core::pin::Pin;
use core::task::{Context, Poll};
use pin_project::{pin_project, pinned_drop};
// Expands to the code that polls exactly one of the joined futures: the one whose
// tuple position equals the runtime index `$iteration`.
//
// Invoked from `poll` as `unsafe_poll!(index, this, futures, cx, LEN, A, B, ...)`:
// - `$this` is the projected join future (fields `completed`, `outputs`, `state`,
//   `consumed`),
// - `$futures` is the projected `Futures` storage (one pinned `ManuallyDrop` field per
//   name),
// - `$cx` is the `Context` the selected future is polled with,
// - `$LEN` is accepted but not used by the expansion.
// Each recursion step pairs the next future name with the next index literal, so at
// most 12 futures are supported (literals 0..=11).
//
// Outcomes of polling the selected future:
// - `Ready(Ok(v))`: `completed` is bumped, `v` is written to `outputs.N`, slot N is
//   marked ready and the future is dropped in place;
// - `Ready(Err(e))`: `completed` is bumped, the join is marked consumed, slot N is set
//   to none, the future is dropped and the expansion `return`s `Poll::Ready(Err(e))`
//   from the enclosing function;
// - `Pending`: nothing changes.
macro_rules! unsafe_poll {
// Recursive step: emit the check for `$fut_name` at position `$fut_idx`, then recurse
// on the remaining names and index literals.
(@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, $fut_name:ident $($F:ident)* | $fut_idx:tt $($rest:tt)*) => {
if $fut_idx == $iteration {
// SAFETY: the caller must only reach this for a slot that is still pending, i.e.
// whose future has not been dropped yet: `poll` skips slots that are ready and
// refuses to run at all once `consumed` is set. The future is never moved out of
// its pinned field, so projecting through `ManuallyDrop` keeps it pinned.
if let Poll::Ready(value) = unsafe {
$futures.$fut_name.as_mut()
.map_unchecked_mut(|t| t.deref_mut())
.poll(&mut $cx)
} {
*$this.completed += 1;
match value {
Ok(value) => {
$this.outputs.$fut_idx.write(value);
$this.state[$fut_idx].set_ready();
// The slot is marked ready before the in-place drop, so the `PinnedDrop`
// cleanup (which only drops futures whose slot is still pending) never
// drops this future a second time.
unsafe { ManuallyDrop::drop($futures.$fut_name.as_mut().get_unchecked_mut()) };
}
Err(err) => {
// Short-circuit: the whole join resolves with this error and must not be
// polled again.
*$this.consumed = true;
$this.state[$fut_idx].set_none();
unsafe { ManuallyDrop::drop($futures.$fut_name.as_mut().get_unchecked_mut()) };
return Poll::Ready(Err(err));
}
}
}
}
unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($F)* | $($rest)*);
};
// Base case: no future names left; any unused index literals are discarded.
(@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, | $($rest:tt)*) => {};
// Entry point: pair the future names with the index literals 0..=11.
($iteration:ident, $this:ident, $futures:ident, $cx:ident, $LEN:ident, $($F:ident,)+) => {
unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($F)+ | 0 1 2 3 4 5 6 7 8 9 10 11);
};
}
// Drops every output slot whose state is `ready` (its future resolved with `Ok`
// and the value was never handed out), then marks that slot `none`.
//
// Usage: `drop_initialized_values!(a, b, c, | states)` where each name is a
// `MaybeUninit<_>` place (or a `&mut` to one) for tuple position 0, 1, 2, ...
// and `states[i]` provides `is_ready()` / `set_none()`. Up to 12 slots are
// supported (positions 0..=11).
macro_rules! drop_initialized_values {
    // Entry point: attach the position literals 0..=11 to the slot names.
    ($($slots:ident,)+ | $states:expr) => {
        drop_initialized_values!(@drop $($slots,)+ | $states, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,);
    };
    // Base case: every slot has been handled; leftover positions are ignored.
    (@drop | $states:expr, $($later_pos:tt,)*) => {};
    // One step: handle the slot at the head of the list, then recurse on the tail.
    (@drop $slot:ident, $($later_slots:ident,)* | $states:expr, $pos:tt, $($later_pos:tt,)*) => {
        if $states[$pos].is_ready() {
            // SAFETY: a `ready` slot was written when its future resolved and has
            // not been moved out yet, so it holds an initialized value we own.
            unsafe { $slot.assume_init_drop() };
            $states[$pos].set_none();
        }
        drop_initialized_values!(@drop $($later_slots,)* | $states, $($later_pos,)*);
    };
}
// Drops, in place, every joined future whose slot is still `pending`.
//
// Usage: `drop_pending_futures!(states, futures, A, B, C,)` where `futures` is a
// `Pin<&mut _>` to storage holding one `ManuallyDrop` field per listed name and
// `states[i]` provides `is_pending()` for tuple position i. Up to 12 futures are
// supported (positions 0..=11).
macro_rules! drop_pending_futures {
    // Entry point: attach the position literals 0..=11 to the field names.
    ($states:ident, $futures:ident, $($names:ident,)+) => {
        drop_pending_futures!(@inner $states, $futures, $($names)+ | 0 1 2 3 4 5 6 7 8 9 10 11);
    };
    // Base case: every name has been handled; leftover positions are ignored.
    (@inner $states:ident, $futures:ident, | $($unused:tt)*) => {};
    // One step: handle the field at the head of the list, then recurse on the tail.
    (@inner $states:ident, $futures:ident, $head:ident $($tail:ident)* | $pos:tt $($later:tt)*) => {
        if $states[$pos].is_pending() {
            // SAFETY: a pending future has not been dropped yet (futures are dropped
            // in place the moment they resolve). This is meant to run only while the
            // owner is being torn down, so the pinned value is never used again.
            let storage = unsafe { $futures.as_mut().get_unchecked_mut() };
            unsafe { ManuallyDrop::drop(&mut storage.$head) };
        }
        drop_pending_futures!(@inner $states, $futures, $($tail)* | $($later)*);
    };
}
// Generates the `TryJoin` future and the `TryJoin` trait impl for one tuple arity.
//
// - `impl_try_join_tuple! { mod_name StructName }` generates the zero-arity
//   future returned by `().try_join()`, which resolves immediately to `Ok(())`.
// - `impl_try_join_tuple! { mod_name StructName (A ResA) (B ResB) ... }`
//   generates a future over a tuple whose element `A` resolves to
//   `Result<ResA, Err>`, element `B` to `Result<ResB, Err>`, and so on; all
//   elements share one error type `Err`. A private module `mod_name` holds the
//   pinned storage of the futures and the tuple length.
//
// Each tuple position moves through these states (tracked in `state`):
// - pending: the future is alive in `futures`, its output slot uninitialized;
// - ready:   the future resolved with `Ok`, was dropped in place, and its value
//            sits in the matching `outputs` slot;
// - none:    neither the future nor the output is owned any more (after an
//            `Err`, after the outputs were returned, or after cleanup).
macro_rules! impl_try_join_tuple {
    ($mod_name:ident $StructName:ident) => {
        /// Future returned by `try_join` on the empty tuple `()`.
        ///
        /// It has nothing to await and resolves immediately to `Ok(())`.
        #[must_use = "futures do nothing unless you `.await` or poll them"]
        #[allow(non_snake_case)]
        pub struct $StructName {}

        impl fmt::Debug for $StructName {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_tuple("TryJoin").finish()
            }
        }

        impl Future for $StructName {
            type Output = Result<(), core::convert::Infallible>;

            fn poll(
                self: Pin<&mut Self>, _cx: &mut Context<'_>
            ) -> Poll<Self::Output> {
                Poll::Ready(Ok(()))
            }
        }

        impl TryJoinTrait for () {
            type Output = ();
            type Error = core::convert::Infallible;
            type Future = $StructName;

            fn try_join(self) -> Self::Future {
                $StructName {}
            }
        }
    };
    ($mod_name:ident $StructName:ident $(($F:ident $T:ident))+) => {
        // Per-arity helper items, kept in their own module so the generated
        // `Futures` / `Indexes` / `LEN` names of different arities do not collide.
        mod $mod_name {
            use core::mem::ManuallyDrop;

            // Pinned storage for the joined futures, one field per tuple
            // position. Each future is wrapped in `ManuallyDrop` so it can be
            // dropped in place as soon as it resolves.
            #[pin_project::pin_project]
            pub(super) struct Futures<$($F,)+> {$(
                #[pin]
                pub(super) $F: ManuallyDrop<$F>,
            )+}

            // One variant per future, in tuple order, so `Indexes::A as usize`
            // is the tuple position of future `A`.
            #[repr(u8)]
            pub(super) enum Indexes { $($F,)+ }

            // Number of futures in the tuple.
            pub(super) const LEN: usize = [$(Indexes::$F,)+].len();
        }

        /// Future returned by `try_join` on a tuple of futures.
        ///
        /// Resolves to `Ok` with a tuple of every future's `Ok` value once all of
        /// them have succeeded, or to `Err` as soon as any future resolves to an
        /// `Err`. Futures that have not resolved are dropped when this future is
        /// dropped.
        ///
        /// # Panics
        ///
        /// Panics if polled again after it has resolved.
        #[pin_project(PinnedDrop)]
        #[must_use = "futures do nothing unless you `.await` or poll them"]
        #[allow(non_snake_case)]
        pub struct $StructName<$($F, $T,)+ Err> {
            // The joined futures; each is dropped in place once it resolves.
            #[pin]
            futures: $mod_name::Futures<$($F,)+>,
            // Output slots; slot N is initialized only while `state[N]` is ready.
            outputs: ($(MaybeUninit<$T>,)+),
            // Per-position lifecycle state (pending / ready / none).
            state: PollArray<{$mod_name::LEN}>,
            // Per-position wakers and the readiness flags they drive.
            wakers: WakerArray<{$mod_name::LEN}>,
            // Number of futures that have resolved so far.
            completed: usize,
            // Set once this future has returned `Poll::Ready`.
            consumed: bool,
            _phantom: PhantomData<Err>,
        }

        impl<$($F, $T,)+ Err> Debug for $StructName<$($F, $T,)+ Err>
        where
            $( $F: Future + Debug, )+
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let mut tuple = f.debug_tuple("TryJoin");
                $(
                    // A future is destroyed with `ManuallyDrop::drop` the moment
                    // it resolves, so only futures whose slot is still pending may
                    // be formatted; reading any other slot would touch a value
                    // that has already been dropped.
                    if self.state[$mod_name::Indexes::$F as usize].is_pending() {
                        tuple.field(&self.futures.$F);
                    } else {
                        tuple.field(&format_args!("<completed>"));
                    }
                )+
                tuple.finish()
            }
        }

        #[allow(unused_mut)]
        #[allow(unused_parens)]
        #[allow(unused_variables)]
        impl<$($F, $T,)+ Err> Future for $StructName<$($F, $T,)+ Err>
        where $(
            $F: Future<Output = Result<$T, Err>>,
        )+ {
            type Output = Result<($($T,)+), Err>;

            fn poll(
                self: Pin<&mut Self>, cx: &mut Context<'_>
            ) -> Poll<Self::Output> {
                const LEN: usize = $mod_name::LEN;

                let mut this = self.project();
                assert!(!*this.consumed, "Futures must not be polled after completing");

                let mut futures = this.futures.project();

                let mut readiness = this.wakers.readiness();
                readiness.set_waker(cx.waker());

                for index in 0..LEN {
                    if !readiness.any_ready() {
                        // No position is flagged ready: wait for the next wake-up.
                        return Poll::Pending;
                    }
                    if !readiness.clear_ready(index) || this.state[index].is_ready() {
                        // Not flagged, or already resolved (and its future dropped).
                        continue;
                    }

                    // Release the readiness handle before polling the child
                    // future, which may wake its waker synchronously. The clippy
                    // allow covers builds where the handle is not `Drop`.
                    #[allow(clippy::drop_non_drop)]
                    drop(readiness);

                    // Poll with this position's own waker rather than the caller's.
                    let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

                    // Expands to one `if index == N { ... }` per tuple position;
                    // returns early with `Poll::Ready(Err(_))` on failure.
                    unsafe_poll!(index, this, futures, cx, LEN, $($F,)+);

                    if *this.completed == LEN {
                        // Every future succeeded (an `Err` would have returned
                        // above): move the outputs out, leaving fresh
                        // uninitialized slots behind.
                        let out = {
                            let mut out = ($(MaybeUninit::<$T>::uninit(),)+);
                            core::mem::swap(&mut out, this.outputs);
                            let ($($F,)+) = out;
                            // SAFETY: each slot was written when its future
                            // resolved with `Ok`, and all of them did.
                            unsafe { ($($F.assume_init(),)+) }
                        };

                        // The outputs now belong to the caller; mark every slot
                        // `none` so `PinnedDrop` does not drop anything again.
                        this.state.set_all_none();
                        *this.consumed = true;

                        return Poll::Ready(Ok(out));
                    }
                    // Re-acquire the readiness handle for the next position.
                    readiness = this.wakers.readiness();
                }

                Poll::Pending
            }
        }

        #[pinned_drop]
        impl<$($F, $T,)+ Err> PinnedDrop for $StructName<$($F, $T,)+ Err> {
            fn drop(self: Pin<&mut Self>) {
                let this = self.project();

                let ($(ref mut $F,)+) = this.outputs;

                let states = this.state;
                let mut futures = this.futures;
                // First release outputs that were produced but never returned
                // (slots in state `ready`), then futures still in flight (slots
                // in state `pending`). Slots in state `none` own nothing.
                drop_initialized_values!($($F,)+ | states);
                drop_pending_futures!(states, futures, $($F,)+);
            }
        }

        #[allow(unused_parens)]
        impl<$($F, $T,)+ Err> TryJoinTrait for ($($F,)+)
        where $(
            $F: IntoFuture<Output = Result<$T, Err>>,
        )+ {
            type Output = ($($T,)+);
            type Error = Err;
            type Future = $StructName<$($F::IntoFuture, $T,)+ Err>;

            fn try_join(self) -> Self::Future {
                // Destructure the tuple and convert each element into its future.
                let ($($F,)+): ($($F,)+) = self;
                $StructName {
                    futures: $mod_name::Futures {$(
                        $F: ManuallyDrop::new($F.into_future()),
                    )+},
                    // Every slot starts pending: futures alive, outputs uninitialized.
                    state: PollArray::new_pending(),
                    outputs: ($(MaybeUninit::<$T>::uninit(),)+),
                    wakers: WakerArray::new(),
                    completed: 0,
                    consumed: false,
                    _phantom: PhantomData,
                }
            }
        }
    };
}
// Instantiate `TryJoin` for tuples of arity 0 through 12. Twelve is the maximum
// because the helper macros (`unsafe_poll!`, `drop_initialized_values!`,
// `drop_pending_futures!`) only supply the position literals 0..=11.
impl_try_join_tuple! { try_join0 TryJoin0 }
impl_try_join_tuple! { try_join_1 TryJoin1 (A ResA) }
impl_try_join_tuple! { try_join_2 TryJoin2 (A ResA) (B ResB) }
impl_try_join_tuple! { try_join_3 TryJoin3 (A ResA) (B ResB) (C ResC) }
impl_try_join_tuple! { try_join_4 TryJoin4 (A ResA) (B ResB) (C ResC) (D ResD) }
impl_try_join_tuple! { try_join_5 TryJoin5 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) }
impl_try_join_tuple! { try_join_6 TryJoin6 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) }
impl_try_join_tuple! { try_join_7 TryJoin7 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) }
impl_try_join_tuple! { try_join_8 TryJoin8 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) }
impl_try_join_tuple! { try_join_9 TryJoin9 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) }
impl_try_join_tuple! { try_join_10 TryJoin10 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) }
impl_try_join_tuple! { try_join_11 TryJoin11 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) }
impl_try_join_tuple! { try_join_12 TryJoin12 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) (L ResL) }
#[cfg(test)]
mod test {
    use super::*;
    use core::convert::Infallible;
    use core::future;

    /// Every future succeeds, so the join yields all values in tuple order.
    #[test]
    fn all_ok() {
        let joined = futures_lite::future::block_on(async {
            let first = async { Ok::<_, Infallible>("aaaa") };
            let second = async { Ok::<_, Infallible>(1) };
            let third = async { Ok::<_, Infallible>('z') };
            (first, second, third).try_join().await
        });
        assert_eq!(joined, Ok(("aaaa", 1, 'z')));
    }

    /// A single failing future makes the whole join resolve to its error.
    #[test]
    fn one_err() {
        let outcome = futures_lite::future::block_on(async {
            let res: Result<(_, char), ()> =
                (future::ready(Ok("hello")), future::ready(Err(())))
                    .try_join()
                    .await;
            res
        });
        assert_eq!(outcome, Err(()));
    }

    /// Regression test: a future that already resolved must not be polled again
    /// while a sibling is still in progress.
    #[test]
    fn issue_135_resume_after_completion() {
        use futures_lite::future::yield_now;

        let joined = futures_lite::future::block_on(async {
            let immediate = async { Ok::<_, ()>(()) };
            let delayed = async {
                yield_now().await;
                Ok::<_, ()>(())
            };
            (immediate, delayed).try_join().await
        });
        assert_eq!(joined.unwrap(), ((), ()));
    }

    /// Outputs already stored inside an unfinished join are dropped with it.
    #[test]
    #[cfg(feature = "std")]
    fn does_not_leak_memory() {
        use core::cell::Cell;
        use futures_lite::future::pending;

        thread_local! {
            static DROPPED: Cell<bool> = const { Cell::new(false) };
        }

        // Raises `DROPPED` when destroyed.
        struct SetFlagOnDrop;

        impl Drop for SetFlagOnDrop {
            fn drop(&mut self) {
                DROPPED.with(|flag| flag.set(true));
            }
        }

        futures_lite::future::block_on(async {
            let text = future::ready(Result::Ok("memory leak".to_owned()));
            let marker = future::ready(Result::Ok(SetFlagOnDrop));
            let joined = (text, marker, pending::<Result<u8, ()>>()).try_join();
            // Poll once, then drop the join while the `pending` future is still
            // outstanding.
            _ = futures_lite::future::poll_once(joined).await;
        });

        assert!(DROPPED.with(|flag| flag.get()));
    }
}