// madsim_real_tokio/io/util/copy.rs

1use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
2
3use std::future::Future;
4use std::io;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8#[derive(Debug)]
9pub(super) struct CopyBuffer {
10    read_done: bool,
11    need_flush: bool,
12    pos: usize,
13    cap: usize,
14    amt: u64,
15    buf: Box<[u8]>,
16}
17
18impl CopyBuffer {
19    pub(super) fn new() -> Self {
20        Self {
21            read_done: false,
22            need_flush: false,
23            pos: 0,
24            cap: 0,
25            amt: 0,
26            buf: vec![0; super::DEFAULT_BUF_SIZE].into_boxed_slice(),
27        }
28    }
29
30    fn poll_fill_buf<R>(
31        &mut self,
32        cx: &mut Context<'_>,
33        reader: Pin<&mut R>,
34    ) -> Poll<io::Result<()>>
35    where
36        R: AsyncRead + ?Sized,
37    {
38        let me = &mut *self;
39        let mut buf = ReadBuf::new(&mut me.buf);
40        buf.set_filled(me.cap);
41
42        let res = reader.poll_read(cx, &mut buf);
43        if let Poll::Ready(Ok(())) = res {
44            let filled_len = buf.filled().len();
45            me.read_done = me.cap == filled_len;
46            me.cap = filled_len;
47        }
48        res
49    }
50
51    fn poll_write_buf<R, W>(
52        &mut self,
53        cx: &mut Context<'_>,
54        mut reader: Pin<&mut R>,
55        mut writer: Pin<&mut W>,
56    ) -> Poll<io::Result<usize>>
57    where
58        R: AsyncRead + ?Sized,
59        W: AsyncWrite + ?Sized,
60    {
61        let me = &mut *self;
62        match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
63            Poll::Pending => {
64                // Top up the buffer towards full if we can read a bit more
65                // data - this should improve the chances of a large write
66                if !me.read_done && me.cap < me.buf.len() {
67                    ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
68                }
69                Poll::Pending
70            }
71            res => res,
72        }
73    }
74
75    pub(super) fn poll_copy<R, W>(
76        &mut self,
77        cx: &mut Context<'_>,
78        mut reader: Pin<&mut R>,
79        mut writer: Pin<&mut W>,
80    ) -> Poll<io::Result<u64>>
81    where
82        R: AsyncRead + ?Sized,
83        W: AsyncWrite + ?Sized,
84    {
85        ready!(crate::trace::trace_leaf(cx));
86        #[cfg(any(
87            feature = "fs",
88            feature = "io-std",
89            feature = "net",
90            feature = "process",
91            feature = "rt",
92            feature = "signal",
93            feature = "sync",
94            feature = "time",
95        ))]
96        // Keep track of task budget
97        let coop = ready!(crate::runtime::coop::poll_proceed(cx));
98        loop {
99            // If our buffer is empty, then we need to read some data to
100            // continue.
101            if self.pos == self.cap && !self.read_done {
102                self.pos = 0;
103                self.cap = 0;
104
105                match self.poll_fill_buf(cx, reader.as_mut()) {
106                    Poll::Ready(Ok(())) => {
107                        #[cfg(any(
108                            feature = "fs",
109                            feature = "io-std",
110                            feature = "net",
111                            feature = "process",
112                            feature = "rt",
113                            feature = "signal",
114                            feature = "sync",
115                            feature = "time",
116                        ))]
117                        coop.made_progress();
118                    }
119                    Poll::Ready(Err(err)) => {
120                        #[cfg(any(
121                            feature = "fs",
122                            feature = "io-std",
123                            feature = "net",
124                            feature = "process",
125                            feature = "rt",
126                            feature = "signal",
127                            feature = "sync",
128                            feature = "time",
129                        ))]
130                        coop.made_progress();
131                        return Poll::Ready(Err(err));
132                    }
133                    Poll::Pending => {
134                        // Try flushing when the reader has no progress to avoid deadlock
135                        // when the reader depends on buffered writer.
136                        if self.need_flush {
137                            ready!(writer.as_mut().poll_flush(cx))?;
138                            #[cfg(any(
139                                feature = "fs",
140                                feature = "io-std",
141                                feature = "net",
142                                feature = "process",
143                                feature = "rt",
144                                feature = "signal",
145                                feature = "sync",
146                                feature = "time",
147                            ))]
148                            coop.made_progress();
149                            self.need_flush = false;
150                        }
151
152                        return Poll::Pending;
153                    }
154                }
155            }
156
157            // If our buffer has some data, let's write it out!
158            while self.pos < self.cap {
159                let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
160                #[cfg(any(
161                    feature = "fs",
162                    feature = "io-std",
163                    feature = "net",
164                    feature = "process",
165                    feature = "rt",
166                    feature = "signal",
167                    feature = "sync",
168                    feature = "time",
169                ))]
170                coop.made_progress();
171                if i == 0 {
172                    return Poll::Ready(Err(io::Error::new(
173                        io::ErrorKind::WriteZero,
174                        "write zero byte into writer",
175                    )));
176                } else {
177                    self.pos += i;
178                    self.amt += i as u64;
179                    self.need_flush = true;
180                }
181            }
182
183            // If pos larger than cap, this loop will never stop.
184            // In particular, user's wrong poll_write implementation returning
185            // incorrect written length may lead to thread blocking.
186            debug_assert!(
187                self.pos <= self.cap,
188                "writer returned length larger than input slice"
189            );
190
191            // If we've written all the data and we've seen EOF, flush out the
192            // data and finish the transfer.
193            if self.pos == self.cap && self.read_done {
194                ready!(writer.as_mut().poll_flush(cx))?;
195                #[cfg(any(
196                    feature = "fs",
197                    feature = "io-std",
198                    feature = "net",
199                    feature = "process",
200                    feature = "rt",
201                    feature = "signal",
202                    feature = "sync",
203                    feature = "time",
204                ))]
205                coop.made_progress();
206                return Poll::Ready(Ok(self.amt));
207            }
208        }
209    }
210}
211
212/// A future that asynchronously copies the entire contents of a reader into a
213/// writer.
214#[derive(Debug)]
215#[must_use = "futures do nothing unless you `.await` or poll them"]
216struct Copy<'a, R: ?Sized, W: ?Sized> {
217    reader: &'a mut R,
218    writer: &'a mut W,
219    buf: CopyBuffer,
220}
221
222cfg_io_util! {
223    /// Asynchronously copies the entire contents of a reader into a writer.
224    ///
225    /// This function returns a future that will continuously read data from
226    /// `reader` and then write it into `writer` in a streaming fashion until
227    /// `reader` returns EOF or fails.
228    ///
229    /// On success, the total number of bytes that were copied from `reader` to
230    /// `writer` is returned.
231    ///
232    /// This is an asynchronous version of [`std::io::copy`][std].
233    ///
234    /// A heap-allocated copy buffer with 8 KB is created to take data from the
235    /// reader to the writer, check [`copy_buf`] if you want an alternative for
236    /// [`AsyncBufRead`]. You can use `copy_buf` with [`BufReader`] to change the
237    /// buffer capacity.
238    ///
239    /// [std]: std::io::copy
240    /// [`copy_buf`]: crate::io::copy_buf
241    /// [`AsyncBufRead`]: crate::io::AsyncBufRead
242    /// [`BufReader`]: crate::io::BufReader
243    ///
244    /// # Errors
245    ///
246    /// The returned future will return an error immediately if any call to
247    /// `poll_read` or `poll_write` returns an error.
248    ///
249    /// # Examples
250    ///
251    /// ```
252    /// use tokio::io;
253    ///
254    /// # async fn dox() -> std::io::Result<()> {
255    /// let mut reader: &[u8] = b"hello";
256    /// let mut writer: Vec<u8> = vec![];
257    ///
258    /// io::copy(&mut reader, &mut writer).await?;
259    ///
260    /// assert_eq!(&b"hello"[..], &writer[..]);
261    /// # Ok(())
262    /// # }
263    /// ```
264    pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
265    where
266        R: AsyncRead + Unpin + ?Sized,
267        W: AsyncWrite + Unpin + ?Sized,
268    {
269        Copy {
270            reader,
271            writer,
272            buf: CopyBuffer::new()
273        }.await
274    }
275}
276
277impl<R, W> Future for Copy<'_, R, W>
278where
279    R: AsyncRead + Unpin + ?Sized,
280    W: AsyncWrite + Unpin + ?Sized,
281{
282    type Output = io::Result<u64>;
283
284    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
285        let me = &mut *self;
286
287        me.buf
288            .poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer))
289    }
290}