futures_concurrency/future/race_ok/array/
mod.rsuse super::RaceOk as RaceOkTrait;
use crate::utils::array_assume_init;
use crate::utils::iter_pin_mut;
use core::array;
use core::fmt;
use core::future::{Future, IntoFuture};
use core::mem::{self, MaybeUninit};
use core::pin::Pin;
use core::task::{Context, Poll};
use pin_project::pin_project;
mod error;
pub use error::AggregateError;
/// A future which resolves with the output of the first future in an array to
/// complete with `Ok`, or — if every future completes with `Err` — with an
/// [`AggregateError`] holding all of the errors in array order.
///
/// Returned by `race_ok` on an array `[Fut; N]`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[pin_project]
pub struct RaceOk<Fut, T, E, const N: usize>
where
    Fut: Future<Output = Result<T, E>>,
{
    /// The futures being raced.
    #[pin]
    futures: [Fut; N],
    /// `errors[i]` is `Some` once `futures[i]` has completed with `Err`, and
    /// `None` while it is still pending. An occupied slot therefore also marks
    /// a future that must not be polled again. Using `Option` (instead of
    /// uninitialized storage) ensures recorded errors are dropped if the race
    /// ends with `Ok` or this future is dropped before completing.
    errors: [Option<E>; N],
    /// How many futures have completed with `Err` so far.
    completed: usize,
}

/// Formats the pending futures as a list.
impl<Fut, T, E, const N: usize> fmt::Debug for RaceOk<Fut, T, E, N>
where
    Fut: Future<Output = Result<T, E>> + fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_list().entries(self.futures.iter()).finish()
    }
}

impl<Fut, T, E, const N: usize> Future for RaceOk<Fut, T, E, N>
where
    Fut: Future<Output = Result<T, E>>,
{
    type Output = Result<T, AggregateError<E, N>>;

    /// Polls every future that has not yet completed.
    ///
    /// Returns `Ready(Ok(_))` as soon as any future yields `Ok`. Once every
    /// future has yielded `Err`, returns `Ready(Err(_))` with the errors in
    /// the same order as the input array. Otherwise returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let futures = iter_pin_mut(this.futures);

        for (fut, slot) in futures.zip(this.errors.iter_mut()) {
            // This future already completed with `Err`; polling it again would
            // violate the `Future` contract (e.g. `core::future::Ready` panics
            // when polled after completion) and would double-count `completed`.
            if slot.is_some() {
                continue;
            }
            if let Poll::Ready(output) = fut.poll(cx) {
                match output {
                    Ok(ok) => return Poll::Ready(Ok(ok)),
                    Err(err) => {
                        *slot = Some(err);
                        *this.completed += 1;
                    }
                }
            }
        }

        if *this.completed == N {
            // Move the recorded errors out, leaving empty slots behind.
            let errors = mem::replace(this.errors, array::from_fn(|_| None));
            // Each increment of `completed` filled a distinct, previously empty
            // slot, so `completed == N` means every slot is occupied.
            let errors = errors.map(|slot| slot.expect("every future completed with an error"));
            Poll::Ready(Err(AggregateError::new(errors)))
        } else {
            Poll::Pending
        }
    }
}

impl<Fut, T, E, const N: usize> RaceOkTrait for [Fut; N]
where
    Fut: IntoFuture<Output = Result<T, E>>,
{
    type Output = T;
    type Error = AggregateError<E, N>;
    type Future = RaceOk<Fut::IntoFuture, T, E, N>;

    /// Converts each element into its future and races them, starting with no
    /// recorded errors.
    fn race_ok(self) -> Self::Future {
        RaceOk {
            futures: self.map(|fut| fut.into_future()),
            errors: array::from_fn(|_| None),
            completed: 0,
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use core::future::{self, Future};
    use core::pin::Pin;
    use core::task::{Context, Poll};

    /// A test future that returns `Pending` `remaining` times (waking itself
    /// each time) before resolving, and panics if polled after resolving.
    struct Delayed {
        remaining: usize,
        output: Option<Result<&'static str, &'static str>>,
    }

    impl Delayed {
        fn new(remaining: usize, output: Result<&'static str, &'static str>) -> Self {
            Self {
                remaining,
                output: Some(output),
            }
        }
    }

    impl Future for Delayed {
        type Output = Result<&'static str, &'static str>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.remaining == 0 {
                Poll::Ready(
                    self.output
                        .take()
                        .expect("`Delayed` polled after completion"),
                )
            } else {
                self.remaining -= 1;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }

    #[test]
    fn all_ok() {
        futures_lite::future::block_on(async {
            let res: Result<&str, AggregateError<(), 2>> =
                [future::ready(Ok("hello")), future::ready(Ok("world"))]
                    .race_ok()
                    .await;
            assert!(res.is_ok());
        })
    }

    #[test]
    fn one_err() {
        futures_lite::future::block_on(async {
            let res: Result<&str, AggregateError<_, 2>> =
                [future::ready(Ok("hello")), future::ready(Err("oh no"))]
                    .race_ok()
                    .await;
            assert_eq!(res.unwrap(), "hello");
        });
    }

    #[test]
    fn all_err() {
        futures_lite::future::block_on(async {
            let res: Result<&str, AggregateError<_, 2>> =
                [future::ready(Err("oops")), future::ready(Err("oh no"))]
                    .race_ok()
                    .await;
            let errs = res.unwrap_err();
            assert_eq!(errs[0], "oops");
            assert_eq!(errs[1], "oh no");
        });
    }

    /// An early `Err` must not be polled again while a later future is pending.
    #[test]
    fn err_then_delayed_ok() {
        futures_lite::future::block_on(async {
            let res: Result<&str, AggregateError<&str, 2>> =
                [Delayed::new(0, Err("oops")), Delayed::new(1, Ok("hello"))]
                    .race_ok()
                    .await;
            assert_eq!(res.unwrap(), "hello");
        });
    }

    /// Errors arriving on different polls are each recorded exactly once.
    #[test]
    fn staggered_errs() {
        futures_lite::future::block_on(async {
            let res: Result<&str, AggregateError<&str, 2>> =
                [Delayed::new(0, Err("oops")), Delayed::new(2, Err("oh no"))]
                    .race_ok()
                    .await;
            let errs = res.unwrap_err();
            assert_eq!(errs[0], "oops");
            assert_eq!(errs[1], "oh no");
        });
    }
}