tokio_io_utility/
write_bytes.rs

1use std::{io, mem, pin::Pin, vec::IntoIter};
2
3use bytes::Bytes;
4use tokio::io::{AsyncWrite, AsyncWriteExt};
5
6use crate::{init_maybeuninit_io_slices_mut, ReusableIoSlices};
7
8/// * `buffer` - must not contain empty `Bytes`s.
9#[cfg_attr(docsrs, doc(cfg(feature = "bytes")))]
10pub async fn write_all_bytes(
11    writer: Pin<&mut (dyn AsyncWrite + Send)>,
12    buffer: &mut Vec<Bytes>,
13    reusable_io_slices: &mut ReusableIoSlices,
14) -> io::Result<()> {
15    // `buffer` does not contain any empty `Bytes`s, so:
16    //  - We can check for `io::ErrorKind::WriteZero` error easily
17    //  - It won't occupy slots in `reusable_io_slices` so that
18    //    we can group as many non-zero IoSlice in one write.
19    //  - Avoid conserion from/to `VecDeque` unless necessary,
20    //    which might allocate.
21    //  - Simplify the loop in write_all_bytes_inner.
22
23    if buffer.is_empty() {
24        return Ok(());
25    }
26
27    // This is O(1)
28    let mut iter = mem::take(buffer).into_iter();
29
30    let res = write_all_bytes_inner(writer, &mut iter, reusable_io_slices).await;
31
32    // This is O(1) because of the specailization in std
33    *buffer = Vec::from_iter(iter);
34
35    res
36}
37
38/// * `buffer` - contains at least one element and must not contain empty
39///   `Bytes`
40async fn write_all_bytes_inner(
41    mut writer: Pin<&mut (dyn AsyncWrite + Send)>,
42    iter: &mut IntoIter<Bytes>,
43    reusable_io_slices: &mut ReusableIoSlices,
44) -> io::Result<()> {
45    // do-while style loop, because on the first iteration
46    // iter must not be empty
47    'outer: loop {
48        let uninit_io_slices = reusable_io_slices.get_mut();
49
50        // iter must not be empty
51        // io_slices.is_empty() == false since uninit_io_slices also must not
52        // be empty
53        let io_slices = init_maybeuninit_io_slices_mut(
54            uninit_io_slices,
55            // Do not consume the iter yet since write_vectored might
56            // do partial write.
57            iter.as_slice().iter().map(|bytes| io::IoSlice::new(bytes)),
58        );
59
60        debug_assert!(!io_slices.is_empty());
61
62        let mut n = writer.write_vectored(io_slices).await?;
63
64        if n == 0 {
65            // Since io_slices is not empty and it does not contain empty
66            // `Bytes`, it must be WriteZero error.
67            return Err(io::Error::from(io::ErrorKind::WriteZero));
68        }
69
70        // On first iteration, iter cannot be empty
71        while n >= iter.as_slice()[0].len() {
72            n -= iter.as_slice()[0].len();
73
74            // Release `Bytes` so that the memory they occupied
75            // can be reused in `BytesMut`.
76            iter.next().unwrap();
77
78            if iter.as_slice().is_empty() {
79                debug_assert_eq!(n, 0);
80                break 'outer;
81            }
82        }
83
84        if n != 0 {
85            // iter must not be empty
86            let first = &mut iter.as_mut_slice()[0];
87            // n < buffer[start].len()
88            *first = first.slice(n..);
89        }
90    }
91
92    Ok(())
93}
94
#[cfg(test)]
mod tests {
    use std::{
        future::Future,
        iter,
        num::NonZeroUsize,
        task::{Context, Poll},
    };

    use bytes::BytesMut;
    use tokio::io::AsyncWrite;

    use super::*;
    use crate::IoSliceExt;

    /// Limit number of bytes that can be sent for each write.
    ///
    /// Field `0` is the per-write limit; field `1` collects everything written.
    struct WriterRateLimit(usize, Vec<u8>);

    impl AsyncWrite for WriterRateLimit {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let n = buf.len().min(self.0);
            let buf = &buf[..n];

            Pin::new(&mut self.1).poll_write(cx, buf)
        }
        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.1).poll_flush(cx)
        }
        fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.1).poll_shutdown(cx)
        }

        fn poll_write_vectored(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            bufs: &[io::IoSlice<'_>],
        ) -> Poll<io::Result<usize>> {
            let mut cnt = 0;
            let n = self.0;
            bufs.iter()
                .copied()
                .filter_map(|io_slice| {
                    (n > cnt).then(|| {
                        let n = io_slice.len().min(n - cnt);
                        cnt += n;
                        &io_slice.into_inner()[..n]
                    })
                })
                .for_each(|slice| {
                    self.1.extend(slice);
                });

            Poll::Ready(Ok(cnt))
        }

        fn is_write_vectored(&self) -> bool {
            true
        }
    }

    /// Runs `fut` to completion on a fresh single-threaded runtime.
    fn block_on<F: Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(fut)
    }

    /// Scratch space for three `IoSlice`s per vectored write.
    fn reusable_io_slices() -> ReusableIoSlices {
        ReusableIoSlices::new(NonZeroUsize::new(3).unwrap())
    }

    /// 20 copies of the same 255-byte `Bytes`, holding the values 0..=254.
    fn sample_buffer() -> Vec<Bytes> {
        let bytes: BytesMut = (0..255).collect();
        iter::repeat(bytes.freeze()).take(20).collect()
    }

    #[test]
    fn test() {
        block_on(async {
            let expected_bytes: Vec<u8> = sample_buffer().into_iter().flatten().collect();

            let mut reusable_io_slices = reusable_io_slices();

            let writer = WriterRateLimit(0, Vec::new());
            tokio::pin!(writer);

            // Each limit emulates a pipe that accepts at most that many bytes
            // per write.
            for &limit in &[
                // Half of one `Bytes`.
                255 / 2,
                // Exactly one `Bytes`.
                255,
                // One and a half `Bytes`.
                255 + 255 / 2,
                // One `Bytes` and a little bit of the next one.
                255 + 5,
            ] {
                writer.0 = limit;
                writer.1.clear();

                let mut buffer = sample_buffer();
                write_all_bytes(writer.as_mut(), &mut buffer, &mut reusable_io_slices)
                    .await
                    .unwrap();

                assert_eq!(writer.1, expected_bytes, "limit = {}", limit);
                assert!(buffer.is_empty(), "limit = {}", limit);
            }
        });
    }

    #[test]
    fn write_zero_keeps_unwritten_bytes() {
        block_on(async {
            let mut reusable_io_slices = reusable_io_slices();

            // A pipe that never accepts a single byte.
            let writer = WriterRateLimit(0, Vec::new());
            tokio::pin!(writer);

            let original = sample_buffer();
            let mut buffer = original.clone();

            let err = write_all_bytes(writer.as_mut(), &mut buffer, &mut reusable_io_slices)
                .await
                .unwrap_err();

            assert_eq!(err.kind(), io::ErrorKind::WriteZero);
            assert!(writer.1.is_empty());
            // Nothing was written, so every `Bytes` must be handed back.
            assert_eq!(buffer, original);
        });
    }

    #[test]
    fn empty_buffer_is_a_no_op() {
        block_on(async {
            let mut reusable_io_slices = reusable_io_slices();

            // Even a writer that accepts nothing must not be touched, since
            // there is nothing to write.
            let writer = WriterRateLimit(0, Vec::new());
            tokio::pin!(writer);

            let mut buffer = Vec::new();
            write_all_bytes(writer.as_mut(), &mut buffer, &mut reusable_io_slices)
                .await
                .unwrap();

            assert!(buffer.is_empty());
            assert!(writer.1.is_empty());
        });
    }
}