// broker_tokio/io/util/copy.rs

1use crate::io::{AsyncRead, AsyncWrite};
2
3use std::future::Future;
4use std::io;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8cfg_io_util! {
9    /// A future that asynchronously copies the entire contents of a reader into a
10    /// writer.
11    ///
12    /// This struct is generally created by calling [`copy`][copy]. Please
13    /// see the documentation of `copy()` for more details.
14    ///
15    /// [copy]: copy()
16    #[derive(Debug)]
17    #[must_use = "futures do nothing unless you `.await` or poll them"]
18    pub struct Copy<'a, R: ?Sized, W: ?Sized> {
19        reader: &'a mut R,
20        read_done: bool,
21        writer: &'a mut W,
22        pos: usize,
23        cap: usize,
24        amt: u64,
25        buf: Box<[u8]>,
26    }
27
28    /// Asynchronously copies the entire contents of a reader into a writer.
29    ///
30    /// This function returns a future that will continuously read data from
31    /// `reader` and then write it into `writer` in a streaming fashion until
32    /// `reader` returns EOF.
33    ///
34    /// On success, the total number of bytes that were copied from `reader` to
35    /// `writer` is returned.
36    ///
37    /// This is an asynchronous version of [`std::io::copy`][std].
38    ///
39    /// [std]: std::io::copy
40    ///
41    /// # Errors
42    ///
43    /// The returned future will finish with an error will return an error
44    /// immediately if any call to `poll_read` or `poll_write` returns an error.
45    ///
46    /// # Examples
47    ///
48    /// ```
49    /// use tokio::io;
50    ///
51    /// # async fn dox() -> std::io::Result<()> {
52    /// let mut reader: &[u8] = b"hello";
53    /// let mut writer: Vec<u8> = vec![];
54    ///
55    /// io::copy(&mut reader, &mut writer).await?;
56    ///
57    /// assert_eq!(&b"hello"[..], &writer[..]);
58    /// # Ok(())
59    /// # }
60    /// ```
61    pub fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> Copy<'a, R, W>
62    where
63        R: AsyncRead + Unpin + ?Sized,
64        W: AsyncWrite + Unpin + ?Sized,
65    {
66        Copy {
67            reader,
68            read_done: false,
69            writer,
70            amt: 0,
71            pos: 0,
72            cap: 0,
73            buf: Box::new([0; 2048]),
74        }
75    }
76}
77
78impl<R, W> Future for Copy<'_, R, W>
79where
80    R: AsyncRead + Unpin + ?Sized,
81    W: AsyncWrite + Unpin + ?Sized,
82{
83    type Output = io::Result<u64>;
84
85    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
86        loop {
87            // If our buffer is empty, then we need to read some data to
88            // continue.
89            if self.pos == self.cap && !self.read_done {
90                let me = &mut *self;
91                let n = ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?;
92                if n == 0 {
93                    self.read_done = true;
94                } else {
95                    self.pos = 0;
96                    self.cap = n;
97                }
98            }
99
100            // If our buffer has some data, let's write it out!
101            while self.pos < self.cap {
102                let me = &mut *self;
103                let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, &me.buf[me.pos..me.cap]))?;
104                if i == 0 {
105                    return Poll::Ready(Err(io::Error::new(
106                        io::ErrorKind::WriteZero,
107                        "write zero byte into writer",
108                    )));
109                } else {
110                    self.pos += i;
111                    self.amt += i as u64;
112                }
113            }
114
115            // If we've written all the data and we've seen EOF, flush out the
116            // data and finish the transfer.
117            if self.pos == self.cap && self.read_done {
118                let me = &mut *self;
119                ready!(Pin::new(&mut *me.writer).poll_flush(cx))?;
120                return Poll::Ready(Ok(self.amt));
121            }
122        }
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::Copy;
    use std::marker::PhantomPinned;

    /// `Copy` holds only `&mut` references to its reader and writer, so it
    /// must remain `Unpin` even when those types are themselves `!Unpin`.
    #[test]
    fn assert_unpin() {
        crate::is_unpin::<Copy<'_, PhantomPinned, PhantomPinned>>();
    }
}