tokio_io_utility/write_bytes.rs

use std::{io, mem, pin::Pin, vec::IntoIter};

use bytes::Bytes;
use tokio::io::{AsyncWrite, AsyncWriteExt};

use crate::{init_maybeuninit_io_slices_mut, ReusableIoSlices};

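/// Writes every `Bytes` in `buffer` to `writer` using vectored writes,
/// reusing the `IoSlice` storage in `reusable_io_slices`.
///
/// On success `buffer` is left empty; on error, the `Bytes` that were not
/// yet fully written are put back into `buffer`.
///
/// # Example
///
/// A minimal usage sketch, assuming `write_all_bytes` and `ReusableIoSlices`
/// are re-exported at the crate root; `Vec<u8>` serves as the sink here,
/// since tokio implements `AsyncWrite` for it:
///
/// ```no_run
/// # async fn example() -> std::io::Result<()> {
/// use std::num::NonZeroUsize;
///
/// use bytes::Bytes;
/// use tokio_io_utility::{write_all_bytes, ReusableIoSlices};
///
/// let mut buffer = vec![Bytes::from_static(b"hello, "), Bytes::from_static(b"world")];
/// // Storage for up to 8 `IoSlice`s, reusable across calls.
/// let mut reusable_io_slices = ReusableIoSlices::new(NonZeroUsize::new(8).unwrap());
///
/// let writer: Vec<u8> = Vec::new();
/// tokio::pin!(writer);
///
/// write_all_bytes(writer.as_mut(), &mut buffer, &mut reusable_io_slices).await?;
///
/// assert!(buffer.is_empty());
/// assert_eq!(&writer[..], &b"hello, world"[..]);
/// # Ok(())
/// # }
/// ```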
#[cfg_attr(docsrs, doc(cfg(feature = "bytes")))]
pub async fn write_all_bytes(
    writer: Pin<&mut (dyn AsyncWrite + Send)>,
    buffer: &mut Vec<Bytes>,
    reusable_io_slices: &mut ReusableIoSlices,
) -> io::Result<()> {
    if buffer.is_empty() {
        return Ok(());
    }

    let mut iter = mem::take(buffer).into_iter();

    let res = write_all_bytes_inner(writer, &mut iter, reusable_io_slices).await;

    *buffer = Vec::from_iter(iter);

    res
}

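// Performs one vectored write per iteration over the remaining `Bytes`:
// every `Bytes` that was written in full is popped from `iter`, and a
// partially written head is advanced with `Bytes::slice` so that the next
// iteration resumes exactly where the writer stopped.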
async fn write_all_bytes_inner(
    mut writer: Pin<&mut (dyn AsyncWrite + Send)>,
    iter: &mut IntoIter<Bytes>,
    reusable_io_slices: &mut ReusableIoSlices,
) -> io::Result<()> {
    'outer: loop {
        let uninit_io_slices = reusable_io_slices.get_mut();

        let io_slices = init_maybeuninit_io_slices_mut(
            uninit_io_slices,
            iter.as_slice().iter().map(|bytes| io::IoSlice::new(bytes)),
        );

        debug_assert!(!io_slices.is_empty());

        let mut n = writer.write_vectored(io_slices).await?;

        if n == 0 {
            return Err(io::Error::from(io::ErrorKind::WriteZero));
        }

        while n >= iter.as_slice()[0].len() {
            n -= iter.as_slice()[0].len();

            iter.next().unwrap();

            if iter.as_slice().is_empty() {
                debug_assert_eq!(n, 0);
                break 'outer;
            }
        }

        if n != 0 {
            let first = &mut iter.as_mut_slice()[0];
            *first = first.slice(n..);
        }
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use std::{
        iter,
        num::NonZeroUsize,
        task::{Context, Poll},
    };

    use bytes::BytesMut;
    use tokio::io::AsyncWrite;

    use super::*;
    use crate::IoSliceExt;

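    /// Test writer that accepts at most `self.0` bytes per `poll_write` /
    /// `poll_write_vectored` call and appends whatever it accepts to the
    /// `Vec<u8>` in `self.1`.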
    struct WriterRateLimit(usize, Vec<u8>);

    impl AsyncWrite for WriterRateLimit {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let n = buf.len().min(self.0);
            let buf = &buf[..n];

            Pin::new(&mut self.1).poll_write(cx, buf)
        }
        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.1).poll_flush(cx)
        }
        fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.1).poll_shutdown(cx)
        }

        fn poll_write_vectored(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            bufs: &[io::IoSlice<'_>],
        ) -> Poll<io::Result<usize>> {
            let mut cnt = 0;
            let n = self.0;
            bufs.iter()
                .copied()
                .filter_map(|io_slice| {
                    (n > cnt).then(|| {
                        let n = io_slice.len().min(n - cnt);
                        cnt += n;
                        &io_slice.into_inner()[..n]
                    })
                })
                .for_each(|slice| {
                    self.1.extend(slice);
                });

            Poll::Ready(Ok(cnt))
        }

        fn is_write_vectored(&self) -> bool {
            true
        }
    }

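    // The test pushes 20 copies of a 255-byte buffer through `WriterRateLimit`
    // at several per-write limits: half a buffer, exactly one buffer, one and a
    // half buffers, and slightly more than one buffer. Each run must produce
    // the same output as writing the buffers contiguously.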
    #[test]
    fn test() {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(async {
                let bytes: BytesMut = (0..255).collect();
                let bytes = bytes.freeze();

                let iter = iter::once(bytes).cycle().take(20);

                let expected_bytes: Vec<u8> = iter.clone().flatten().collect();

                let mut reusable_io_slices = ReusableIoSlices::new(NonZeroUsize::new(3).unwrap());

                let writer = WriterRateLimit(255 / 2, Vec::new());
                tokio::pin!(writer);

                write_all_bytes(
                    writer.as_mut(),
                    &mut iter.clone().collect(),
                    &mut reusable_io_slices,
                )
                .await
                .unwrap();

                assert_eq!(writer.1, expected_bytes);

                writer.0 = 255;
                writer.1.clear();

                write_all_bytes(
                    writer.as_mut(),
                    &mut iter.clone().collect(),
                    &mut reusable_io_slices,
                )
                .await
                .unwrap();

                assert_eq!(writer.1, expected_bytes);

                writer.0 = 255 + 255 / 2;
                writer.1.clear();

                write_all_bytes(
                    writer.as_mut(),
                    &mut iter.clone().collect(),
                    &mut reusable_io_slices,
                )
                .await
                .unwrap();

                assert_eq!(writer.1, expected_bytes);

                writer.0 = 255 + 5;
                writer.1.clear();

                write_all_bytes(
                    writer.as_mut(),
                    &mut iter.clone().collect(),
                    &mut reusable_io_slices,
                )
                .await
                .unwrap();

                assert_eq!(writer.1, expected_bytes);
            });
    }
}