1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
29
30use futures::{prelude::*, ready};
31use std::{
32 io::{self, Read},
33 mem,
34 pin::Pin,
35 task::{Context, Poll},
36};
37
38static_assertions::const_assert!(mem::size_of::<usize>() <= mem::size_of::<u64>());
39
40#[pin_project::pin_project]
43pub struct RwStreamSink<S: TryStream> {
44 #[pin]
45 inner: S,
46 current_item: Option<std::io::Cursor<<S as TryStream>::Ok>>,
47}
48
49impl<S: TryStream> RwStreamSink<S> {
50 pub fn new(inner: S) -> Self {
52 RwStreamSink {
53 inner,
54 current_item: None,
55 }
56 }
57}
58
59impl<S> AsyncRead for RwStreamSink<S>
60where
61 S: TryStream<Error = io::Error>,
62 <S as TryStream>::Ok: AsRef<[u8]>,
63{
64 fn poll_read(
65 self: Pin<&mut Self>,
66 cx: &mut Context,
67 buf: &mut [u8],
68 ) -> Poll<io::Result<usize>> {
69 let mut this = self.project();
70
71 let item_to_copy = loop {
73 if let Some(ref mut i) = this.current_item {
74 if i.position() < i.get_ref().as_ref().len() as u64 {
75 break i;
76 }
77 }
78 *this.current_item = Some(match ready!(this.inner.as_mut().try_poll_next(cx)) {
79 Some(Ok(i)) => std::io::Cursor::new(i),
80 Some(Err(e)) => return Poll::Ready(Err(e)),
81 None => return Poll::Ready(Ok(0)), });
83 };
84
85 Poll::Ready(Ok(item_to_copy.read(buf)?))
87 }
88}
89
90impl<S> AsyncWrite for RwStreamSink<S>
91where
92 S: TryStream + Sink<<S as TryStream>::Ok, Error = io::Error>,
93 <S as TryStream>::Ok: for<'r> From<&'r [u8]>,
94{
95 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
96 let mut this = self.project();
97 ready!(this.inner.as_mut().poll_ready(cx)?);
98 let n = buf.len();
99 if let Err(e) = this.inner.start_send(buf.into()) {
100 return Poll::Ready(Err(e));
101 }
102 Poll::Ready(Ok(n))
103 }
104
105 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
106 let this = self.project();
107 this.inner.poll_flush(cx)
108 }
109
110 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
111 let this = self.project();
112 this.inner.poll_close(cx)
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::RwStreamSink;
    use async_std::task;
    use futures::{channel::mpsc, prelude::*, stream};
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    /// Glues an independent stream half and sink half into one value that
    /// implements both `Stream` and `Sink`.
    struct Wrapper<St, Si> {
        stream: St,
        sink: Si,
    }

    impl<St, Si> Stream for Wrapper<St, Si>
    where
        St: Stream + Unpin,
        Si: Unpin,
    {
        type Item = St::Item;

        fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
            let me = self.get_mut();
            me.stream.poll_next_unpin(cx)
        }
    }

    impl<St, Si, T> Sink<T> for Wrapper<St, Si>
    where
        St: Unpin,
        Si: Sink<T> + Unpin,
    {
        type Error = Si::Error;

        fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
            let me = self.get_mut();
            Pin::new(&mut me.sink).poll_ready(cx)
        }

        fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
            let me = self.get_mut();
            Pin::new(&mut me.sink).start_send(item)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
            let me = self.get_mut();
            Pin::new(&mut me.sink).poll_flush(cx)
        }

        fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
            let me = self.get_mut();
            Pin::new(&mut me.sink).poll_close(cx)
        }
    }

    #[test]
    fn basic_reading() {
        let (unused_sink, _) = mpsc::channel::<Vec<u8>>(10);
        let (mut feeder, source) = mpsc::channel(10);

        let mut rws = RwStreamSink::new(Wrapper {
            stream: source.map(Ok),
            sink: unused_sink,
        });

        task::block_on(async move {
            for chunk in ["hel", "lo wor", "ld"] {
                feeder.send(Vec::from(chunk)).await.unwrap();
            }
            feeder.close().await.unwrap();

            let mut collected = Vec::new();
            rws.read_to_end(&mut collected).await.unwrap();
            assert_eq!(collected, b"hello world");
        })
    }

    #[test]
    fn skip_empty_stream_items() {
        let items: Vec<&[u8]> = vec![b"", b"foo", b"", b"bar", b"", b"baz", b""];
        let mut rws = RwStreamSink::new(stream::iter(items).map(Ok));
        let mut buf = [0; 9];
        task::block_on(async move {
            // Each non-empty item lands in its own third of the buffer.
            for offset in [0usize, 3, 6] {
                assert_eq!(3, rws.read(&mut buf[offset..]).await.unwrap());
            }
            assert_eq!(0, rws.read(&mut buf).await.unwrap());
            assert_eq!(b"foobarbaz", &buf[..])
        })
    }

    #[test]
    fn partial_read() {
        let items: Vec<&[u8]> = vec![b"hell", b"o world"];
        let mut rws = RwStreamSink::new(stream::iter(items).map(Ok));
        let mut buf = [0; 3];
        // (bytes requested, bytes expected back)
        let script: [(usize, &str); 8] = [
            (3, "hel"),
            (0, ""),
            (3, "l"),
            (3, "o w"),
            (0, ""),
            (3, "orl"),
            (3, "d"),
            (3, ""),
        ];
        task::block_on(async move {
            for (limit, expected) in script {
                let got = rws.read(&mut buf[..limit]).await.unwrap();
                assert_eq!(expected.as_bytes(), &buf[..got]);
            }
        })
    }
}