1use std::{
3 future::Future,
4 io::SeekFrom,
5 pin::Pin,
6 task::{Context, Poll},
7};
8use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf, Result};
9
/// A value that holds exactly one of two types, `L` or `R`.
///
/// The trait implementations for `Either` in this module ([`Future`], [`AsyncRead`],
/// [`AsyncBufRead`], [`AsyncSeek`], [`AsyncWrite`], `Stream` and `Sink`) forward every
/// call to whichever variant is currently stored. Code that must produce one of two
/// concrete types can therefore expose a single type that still implements the shared
/// trait.
#[derive(Debug, Clone)]
pub enum Either<L, R> {
    /// Holds a value of the first type, `L`.
    Left(L),
    /// Holds a value of the second type, `R`.
    Right(R),
}
68
/// Forwards a `self: Pin<&mut Self>` method call to whichever variant is active.
///
/// It re-pins the inner value before making the call. The macro expands to a `match`
/// on `Self::Left` / `Self::Right`, so it must be invoked inside an `impl` for a type
/// with variants of those names. The forwarded method needs at least one argument
/// besides `self`.
macro_rules! delegate_call {
    ($self:ident.$method:ident($($args:ident),+)) => {{
        // SAFETY: the pin is projected structurally onto the stored payload. This is
        // sound only while the payload is never moved out from behind a `Pin`.
        // `Either` must therefore not get a manual `Unpin` impl, a `Drop` impl that
        // moves its fields, or any `Pin<&mut Self>` method that exposes `&mut L` /
        // `&mut R` unpinned. The `&mut Self` produced here is used only to re-pin the
        // payload below; it is never used to move it.
        let this = unsafe { $self.get_unchecked_mut() };
        match this {
            Self::Left(l) => {
                // SAFETY: `l` is only reachable through the pinned `self` and is not
                // moved (projection invariant above).
                let pinned = unsafe { Pin::new_unchecked(l) };
                pinned.$method($($args),+)
            }
            Self::Right(r) => {
                // SAFETY: same reasoning as the `Left` arm.
                let pinned = unsafe { Pin::new_unchecked(r) };
                pinned.$method($($args),+)
            }
        }
    }};
}
82
83impl<L, R, O> Future for Either<L, R>
84where
85 L: Future<Output = O>,
86 R: Future<Output = O>,
87{
88 type Output = O;
89
90 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
91 delegate_call!(self.poll(cx))
92 }
93}
94
95impl<L, R> AsyncRead for Either<L, R>
96where
97 L: AsyncRead,
98 R: AsyncRead,
99{
100 fn poll_read(
101 self: Pin<&mut Self>,
102 cx: &mut Context<'_>,
103 buf: &mut ReadBuf<'_>,
104 ) -> Poll<Result<()>> {
105 delegate_call!(self.poll_read(cx, buf))
106 }
107}
108
109impl<L, R> AsyncBufRead for Either<L, R>
110where
111 L: AsyncBufRead,
112 R: AsyncBufRead,
113{
114 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> {
115 delegate_call!(self.poll_fill_buf(cx))
116 }
117
118 fn consume(self: Pin<&mut Self>, amt: usize) {
119 delegate_call!(self.consume(amt));
120 }
121}
122
123impl<L, R> AsyncSeek for Either<L, R>
124where
125 L: AsyncSeek,
126 R: AsyncSeek,
127{
128 fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()> {
129 delegate_call!(self.start_seek(position))
130 }
131
132 fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>> {
133 delegate_call!(self.poll_complete(cx))
134 }
135}
136
137impl<L, R> AsyncWrite for Either<L, R>
138where
139 L: AsyncWrite,
140 R: AsyncWrite,
141{
142 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
143 delegate_call!(self.poll_write(cx, buf))
144 }
145
146 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
147 delegate_call!(self.poll_flush(cx))
148 }
149
150 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
151 delegate_call!(self.poll_shutdown(cx))
152 }
153
154 fn poll_write_vectored(
155 self: Pin<&mut Self>,
156 cx: &mut Context<'_>,
157 bufs: &[std::io::IoSlice<'_>],
158 ) -> Poll<std::result::Result<usize, std::io::Error>> {
159 delegate_call!(self.poll_write_vectored(cx, bufs))
160 }
161
162 fn is_write_vectored(&self) -> bool {
163 match self {
164 Self::Left(l) => l.is_write_vectored(),
165 Self::Right(r) => r.is_write_vectored(),
166 }
167 }
168}
169
170impl<L, R> futures_core::stream::Stream for Either<L, R>
171where
172 L: futures_core::stream::Stream,
173 R: futures_core::stream::Stream<Item = L::Item>,
174{
175 type Item = L::Item;
176
177 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
178 delegate_call!(self.poll_next(cx))
179 }
180}
181
182impl<L, R, Item, Error> futures_sink::Sink<Item> for Either<L, R>
183where
184 L: futures_sink::Sink<Item, Error = Error>,
185 R: futures_sink::Sink<Item, Error = Error>,
186{
187 type Error = Error;
188
189 fn poll_ready(
190 self: Pin<&mut Self>,
191 cx: &mut Context<'_>,
192 ) -> Poll<std::result::Result<(), Self::Error>> {
193 delegate_call!(self.poll_ready(cx))
194 }
195
196 fn start_send(self: Pin<&mut Self>, item: Item) -> std::result::Result<(), Self::Error> {
197 delegate_call!(self.start_send(item))
198 }
199
200 fn poll_flush(
201 self: Pin<&mut Self>,
202 cx: &mut Context<'_>,
203 ) -> Poll<std::result::Result<(), Self::Error>> {
204 delegate_call!(self.poll_flush(cx))
205 }
206
207 fn poll_close(
208 self: Pin<&mut Self>,
209 cx: &mut Context<'_>,
210 ) -> Poll<std::result::Result<(), Self::Error>> {
211 delegate_call!(self.poll_close(cx))
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::{ready, Ready};
    use tokio::io::{repeat, sink, AsyncReadExt, AsyncWriteExt, Repeat, Sink};
    use tokio_stream::{once, Once, StreamExt};

    #[tokio::test]
    async fn either_is_stream() {
        let mut either: Either<Once<u32>, Once<u32>> = Either::Left(once(1));

        assert_eq!(Some(1u32), either.next().await);
    }

    #[tokio::test]
    async fn either_is_async_read() {
        let mut buffer = [0; 3];
        let mut either: Either<Repeat, Repeat> = Either::Right(repeat(0b101));

        either.read_exact(&mut buffer).await.unwrap();
        assert_eq!(buffer, [0b101, 0b101, 0b101]);
    }

    // Covers the `Future` impl for both variants.
    #[tokio::test]
    async fn either_is_future() {
        let left: Either<Ready<u32>, Ready<u32>> = Either::Left(ready(1));
        let right: Either<Ready<u32>, Ready<u32>> = Either::Right(ready(2));

        assert_eq!(left.await, 1);
        assert_eq!(right.await, 2);
    }

    // Covers the `AsyncWrite` impl: write, flush and shutdown are all forwarded.
    #[tokio::test]
    async fn either_is_async_write() {
        let mut either: Either<Sink, Sink> = Either::Left(sink());

        either.write_all(b"hello").await.unwrap();
        either.flush().await.unwrap();
        either.shutdown().await.unwrap();
    }
}