tokio_util/
either.rs

1//! Module defining an Either type.
2use std::{
3    future::Future,
4    io::SeekFrom,
5    pin::Pin,
6    task::{Context, Poll},
7};
8use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf, Result};
9
/// Combines two different futures, streams, or sinks having the same associated types into a single type.
///
/// This type implements common asynchronous traits such as [`Future`] and those in Tokio.
///
/// [`Future`]: std::future::Future
///
/// # Example
///
/// The following code will not work:
///
/// ```compile_fail
/// # fn some_condition() -> bool { true }
/// # async fn some_async_function() -> u32 { 10 }
/// # async fn other_async_function() -> u32 { 20 }
/// #[tokio::main]
/// async fn main() {
///     let result = if some_condition() {
///         some_async_function()
///     } else {
///         other_async_function() // <- Will print: "`if` and `else` have incompatible types"
///     };
///
///     println!("Result is {}", result.await);
/// }
/// ```
///
/// This is because although the output types of both futures are the same, the exact future
/// types are different, but the compiler must be able to choose a single type for the
/// `result` variable.
///
/// When the output type is the same, we can wrap each future in `Either` to avoid the
/// issue:
///
/// ```
/// use tokio_util::either::Either;
/// # fn some_condition() -> bool { true }
/// # async fn some_async_function() -> u32 { 10 }
/// # async fn other_async_function() -> u32 { 20 }
///
/// #[tokio::main]
/// async fn main() {
///     let result = if some_condition() {
///         Either::Left(some_async_function())
///     } else {
///         Either::Right(other_async_function())
///     };
///
///     let value = result.await;
///     println!("Result is {}", value);
///     # assert_eq!(value, 10);
/// }
/// ```
#[allow(missing_docs)] // Doc-comments for variants in this particular case don't make much sense.
#[derive(Debug, Clone)]
pub enum Either<L, R> {
    Left(L),
    Right(R),
}
68
/// A small helper macro which reduces the amount of boilerplate in the actual trait method
/// implementations.
///
/// It takes a method invocation as its argument (e.g. `self.poll(cx)`, where `self` is a
/// `Pin<&mut Self>`) and forwards it to whichever enum variant `self` currently holds, pinning
/// that variant in place before the call.
///
/// It must be invoked inside an `impl` block whose `Self` is an enum with `Left` and `Right`
/// tuple variants, and every argument must be a plain identifier.
macro_rules! delegate_call {
    ($self:ident.$method:ident($($args:ident),+)) => {{
        // SAFETY: `Either` structurally pins both variants. Nothing here moves the value out
        // of the `&mut Self` obtained below; `Either` has no `Drop` impl that could move its
        // fields; and `Either` is only (automatically) `Unpin` when both `L` and `R` are.
        // Projecting the pin onto the active variant is therefore sound.
        let this = unsafe { $self.get_unchecked_mut() };
        match this {
            Self::Left(l) => {
                // SAFETY: `l` lives inside the pinned enum (see above) and is never moved.
                let pinned = unsafe { Pin::new_unchecked(l) };
                pinned.$method($($args),+)
            }
            Self::Right(r) => {
                // SAFETY: `r` lives inside the pinned enum (see above) and is never moved.
                let pinned = unsafe { Pin::new_unchecked(r) };
                pinned.$method($($args),+)
            }
        }
    }};
}
82
83impl<L, R, O> Future for Either<L, R>
84where
85    L: Future<Output = O>,
86    R: Future<Output = O>,
87{
88    type Output = O;
89
90    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
91        delegate_call!(self.poll(cx))
92    }
93}
94
95impl<L, R> AsyncRead for Either<L, R>
96where
97    L: AsyncRead,
98    R: AsyncRead,
99{
100    fn poll_read(
101        self: Pin<&mut Self>,
102        cx: &mut Context<'_>,
103        buf: &mut ReadBuf<'_>,
104    ) -> Poll<Result<()>> {
105        delegate_call!(self.poll_read(cx, buf))
106    }
107}
108
109impl<L, R> AsyncBufRead for Either<L, R>
110where
111    L: AsyncBufRead,
112    R: AsyncBufRead,
113{
114    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> {
115        delegate_call!(self.poll_fill_buf(cx))
116    }
117
118    fn consume(self: Pin<&mut Self>, amt: usize) {
119        delegate_call!(self.consume(amt));
120    }
121}
122
123impl<L, R> AsyncSeek for Either<L, R>
124where
125    L: AsyncSeek,
126    R: AsyncSeek,
127{
128    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()> {
129        delegate_call!(self.start_seek(position))
130    }
131
132    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>> {
133        delegate_call!(self.poll_complete(cx))
134    }
135}
136
137impl<L, R> AsyncWrite for Either<L, R>
138where
139    L: AsyncWrite,
140    R: AsyncWrite,
141{
142    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
143        delegate_call!(self.poll_write(cx, buf))
144    }
145
146    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
147        delegate_call!(self.poll_flush(cx))
148    }
149
150    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
151        delegate_call!(self.poll_shutdown(cx))
152    }
153
154    fn poll_write_vectored(
155        self: Pin<&mut Self>,
156        cx: &mut Context<'_>,
157        bufs: &[std::io::IoSlice<'_>],
158    ) -> Poll<std::result::Result<usize, std::io::Error>> {
159        delegate_call!(self.poll_write_vectored(cx, bufs))
160    }
161
162    fn is_write_vectored(&self) -> bool {
163        match self {
164            Self::Left(l) => l.is_write_vectored(),
165            Self::Right(r) => r.is_write_vectored(),
166        }
167    }
168}
169
170impl<L, R> futures_core::stream::Stream for Either<L, R>
171where
172    L: futures_core::stream::Stream,
173    R: futures_core::stream::Stream<Item = L::Item>,
174{
175    type Item = L::Item;
176
177    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
178        delegate_call!(self.poll_next(cx))
179    }
180}
181
182impl<L, R, Item, Error> futures_sink::Sink<Item> for Either<L, R>
183where
184    L: futures_sink::Sink<Item, Error = Error>,
185    R: futures_sink::Sink<Item, Error = Error>,
186{
187    type Error = Error;
188
189    fn poll_ready(
190        self: Pin<&mut Self>,
191        cx: &mut Context<'_>,
192    ) -> Poll<std::result::Result<(), Self::Error>> {
193        delegate_call!(self.poll_ready(cx))
194    }
195
196    fn start_send(self: Pin<&mut Self>, item: Item) -> std::result::Result<(), Self::Error> {
197        delegate_call!(self.start_send(item))
198    }
199
200    fn poll_flush(
201        self: Pin<&mut Self>,
202        cx: &mut Context<'_>,
203    ) -> Poll<std::result::Result<(), Self::Error>> {
204        delegate_call!(self.poll_flush(cx))
205    }
206
207    fn poll_close(
208        self: Pin<&mut Self>,
209        cx: &mut Context<'_>,
210    ) -> Poll<std::result::Result<(), Self::Error>> {
211        delegate_call!(self.poll_close(cx))
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{repeat, AsyncReadExt, Repeat};
    use tokio_stream::{once, Once, StreamExt};

    /// A left-hand stream is polled through `Either` and yields its single item.
    #[tokio::test]
    async fn either_is_stream() {
        let mut stream = Either::<Once<u32>, Once<u32>>::Left(once(1));

        let first = stream.next().await;
        assert_eq!(first, Some(1u32));
    }

    /// A right-hand reader is read through `Either` and fills the buffer with its byte.
    #[tokio::test]
    async fn either_is_async_read() {
        let mut reader = Either::<Repeat, Repeat>::Right(repeat(0b101));
        let mut buf = [0u8; 3];

        reader.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, [0b101; 3]);
    }
}