compio_io/write/
mod.rs

1#[cfg(feature = "allocator_api")]
2use std::alloc::Allocator;
3use std::io::Cursor;
4
5use compio_buf::{BufResult, IntoInner, IoBuf, IoVectoredBuf, OwnedIterator, buf_try, t_alloc};
6
7use crate::IoResult;
8
9mod buf;
10#[macro_use]
11mod ext;
12
13pub use buf::*;
14pub use ext::*;
15
16/// # AsyncWrite
17///
18/// Async write with a ownership of a buffer
19pub trait AsyncWrite {
20    /// Write some bytes from the buffer into this source and return a
21    /// [`BufResult`], consisting of the buffer and a [`usize`] indicating how
22    /// many bytes were written.
23    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T>;
24
25    /// Like `write`, except that it write bytes from a buffer implements
26    /// [`IoVectoredBuf`] into the source.
27    ///
28    /// The default implementation will try to write from the buffers in order
29    /// as if they're concatenated. It will stop whenever the writer returns
30    /// an error, `Ok(0)`, or a length less than the length of the buf passed
31    /// in, meaning it's possible that not all contents are written. If
32    /// guaranteed full write is desired, it is recommended to use
33    /// [`AsyncWriteExt::write_vectored_all`] instead.
34    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
35        loop_write_vectored!(buf, iter, self.write(iter))
36    }
37
38    /// Attempts to flush the object, ensuring that any buffered data reach
39    /// their destination.
40    async fn flush(&mut self) -> IoResult<()>;
41
42    /// Initiates or attempts to shut down this writer, returning success when
43    /// the I/O connection has completely shut down.
44    async fn shutdown(&mut self) -> IoResult<()>;
45}
46
47impl<A: AsyncWrite + ?Sized> AsyncWrite for &mut A {
48    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
49        (**self).write(buf).await
50    }
51
52    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
53        (**self).write_vectored(buf).await
54    }
55
56    async fn flush(&mut self) -> IoResult<()> {
57        (**self).flush().await
58    }
59
60    async fn shutdown(&mut self) -> IoResult<()> {
61        (**self).shutdown().await
62    }
63}
64
65impl<W: AsyncWrite + ?Sized, #[cfg(feature = "allocator_api")] A: Allocator> AsyncWrite
66    for t_alloc!(Box, W, A)
67{
68    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
69        (**self).write(buf).await
70    }
71
72    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
73        (**self).write_vectored(buf).await
74    }
75
76    async fn flush(&mut self) -> IoResult<()> {
77        (**self).flush().await
78    }
79
80    async fn shutdown(&mut self) -> IoResult<()> {
81        (**self).shutdown().await
82    }
83}
84
85/// Write is implemented for `Vec<u8>` by appending to the vector. The vector
86/// will grow as needed.
87impl<#[cfg(feature = "allocator_api")] A: Allocator> AsyncWrite for t_alloc!(Vec, u8, A) {
88    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
89        self.extend_from_slice(buf.as_slice());
90        BufResult(Ok(buf.buf_len()), buf)
91    }
92
93    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
94        let len = buf.iter_buf().map(|b| b.buf_len()).sum();
95        self.reserve(len - self.len());
96        for buf in buf.iter_buf() {
97            self.extend_from_slice(buf.as_slice());
98        }
99        BufResult(Ok(len), buf)
100    }
101
102    async fn flush(&mut self) -> IoResult<()> {
103        Ok(())
104    }
105
106    async fn shutdown(&mut self) -> IoResult<()> {
107        Ok(())
108    }
109}
110
111/// # AsyncWriteAt
112///
113/// Async write with a ownership of a buffer and a position
114pub trait AsyncWriteAt {
115    /// Like [`AsyncWrite::write`], except that it writes at a specified
116    /// position.
117    async fn write_at<T: IoBuf>(&mut self, buf: T, pos: u64) -> BufResult<usize, T>;
118
119    /// Like [`AsyncWrite::write_vectored`], except that it writes at a
120    /// specified position.
121    async fn write_vectored_at<T: IoVectoredBuf>(
122        &mut self,
123        buf: T,
124        pos: u64,
125    ) -> BufResult<usize, T> {
126        loop_write_vectored!(buf, iter, self.write_at(iter, pos))
127    }
128}
129
130impl<A: AsyncWriteAt + ?Sized> AsyncWriteAt for &mut A {
131    async fn write_at<T: IoBuf>(&mut self, buf: T, pos: u64) -> BufResult<usize, T> {
132        (**self).write_at(buf, pos).await
133    }
134
135    async fn write_vectored_at<T: IoVectoredBuf>(
136        &mut self,
137        buf: T,
138        pos: u64,
139    ) -> BufResult<usize, T> {
140        (**self).write_vectored_at(buf, pos).await
141    }
142}
143
144impl<W: AsyncWriteAt + ?Sized, #[cfg(feature = "allocator_api")] A: Allocator> AsyncWriteAt
145    for t_alloc!(Box, W, A)
146{
147    async fn write_at<T: IoBuf>(&mut self, buf: T, pos: u64) -> BufResult<usize, T> {
148        (**self).write_at(buf, pos).await
149    }
150
151    async fn write_vectored_at<T: IoVectoredBuf>(
152        &mut self,
153        buf: T,
154        pos: u64,
155    ) -> BufResult<usize, T> {
156        (**self).write_vectored_at(buf, pos).await
157    }
158}
159
160impl AsyncWrite for &mut [u8] {
161    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
162        let slice = buf.as_slice();
163        BufResult(std::io::Write::write(self, slice), buf)
164    }
165
166    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
167        let mut iter = match buf.owned_iter() {
168            Ok(buf) => buf,
169            Err(buf) => return BufResult(Ok(0), buf),
170        };
171        let mut total = 0;
172        loop {
173            let n = match std::io::Write::write(self, iter.as_slice()) {
174                Ok(n) => n,
175                // TODO: unlikely
176                Err(e) => return BufResult(Err(e), iter.into_inner()),
177            };
178            total += n;
179            if self.is_empty() {
180                return BufResult(Ok(total), iter.into_inner());
181            }
182            match iter.next() {
183                Ok(next) => iter = next,
184                Err(buf) => return BufResult(Ok(total), buf),
185            }
186        }
187    }
188
189    async fn flush(&mut self) -> IoResult<()> {
190        Ok(())
191    }
192
193    async fn shutdown(&mut self) -> IoResult<()> {
194        Ok(())
195    }
196}
197
198macro_rules! impl_write_at {
199    ($($(const $len:ident =>)? $ty:ty),*) => {
200        $(
201            impl<$(const $len: usize)?> AsyncWriteAt for $ty {
202                async fn write_at<T: IoBuf>(&mut self, buf: T, pos: u64) -> BufResult<usize, T> {
203                    let pos = (pos as usize).min(self.len());
204                    let slice = buf.as_slice();
205                    let n = slice.len().min(self.len() - pos);
206                    self[pos..pos + n].copy_from_slice(&slice[..n]);
207                    BufResult(Ok(n), buf)
208                }
209
210                async fn write_vectored_at<T: IoVectoredBuf>(&mut self, buf: T, pos: u64) -> BufResult<usize, T> {
211                    let mut iter = match buf.owned_iter() {
212                        Ok(buf) => buf,
213                        Err(buf) => return BufResult(Ok(0), buf),
214                    };
215                    let mut total = 0;
216                    loop {
217                        let n;
218                        (n, iter) = match self.write_at(iter, pos + total as u64).await {
219                            BufResult(Ok(n), iter) => (n, iter),
220                            // TODO: unlikely
221                            BufResult(Err(e), iter) => return BufResult(Err(e), iter.into_inner()),
222                        };
223                        total += n;
224                        if self.is_empty() {
225                            return BufResult(Ok(total), iter.into_inner());
226                        }
227                        match iter.next() {
228                            Ok(next) => iter = next,
229                            Err(buf) => return BufResult(Ok(total), buf),
230                        }
231                    }
232                }
233            }
234        )*
235    }
236}
237
238impl_write_at!([u8], const LEN => [u8; LEN]);
239
240/// This implementation aligns the behavior of files. If `pos` is larger than
241/// the vector length, the vectored will be extended, and the extended area will
242/// be filled with 0.
243impl<#[cfg(feature = "allocator_api")] A: Allocator> AsyncWriteAt for t_alloc!(Vec, u8, A) {
244    async fn write_at<T: IoBuf>(&mut self, buf: T, pos: u64) -> BufResult<usize, T> {
245        let pos = pos as usize;
246        let slice = buf.as_slice();
247        if pos <= self.len() {
248            let n = slice.len().min(self.len() - pos);
249            if n < slice.len() {
250                self.reserve(slice.len() - n);
251                self[pos..pos + n].copy_from_slice(&slice[..n]);
252                self.extend_from_slice(&slice[n..]);
253            } else {
254                self[pos..pos + n].copy_from_slice(slice);
255            }
256        } else {
257            self.reserve(pos - self.len() + slice.len());
258            self.resize(pos, 0);
259            self.extend_from_slice(slice);
260        }
261        BufResult(Ok(slice.len()), buf)
262    }
263
264    async fn write_vectored_at<T: IoVectoredBuf>(
265        &mut self,
266        buf: T,
267        pos: u64,
268    ) -> BufResult<usize, T> {
269        let mut pos = pos as usize;
270        let len = buf.iter_buf().map(|b| b.buf_len()).sum();
271        if pos <= self.len() {
272            self.reserve(len - (self.len() - pos));
273        } else {
274            self.reserve(pos - self.len() + len);
275            self.resize(pos, 0);
276        }
277        for buf in buf.iter_buf() {
278            let slice = buf.as_slice();
279            if pos <= self.len() {
280                let n = slice.len().min(self.len() - pos);
281                if n < slice.len() {
282                    self[pos..pos + n].copy_from_slice(&slice[..n]);
283                    self.extend_from_slice(&slice[n..]);
284                } else {
285                    self[pos..pos + n].copy_from_slice(slice);
286                }
287            } else {
288                self.extend_from_slice(slice);
289            }
290            pos += slice.len();
291        }
292        BufResult(Ok(len), buf)
293    }
294}
295
296impl<A: AsyncWriteAt> AsyncWrite for Cursor<A> {
297    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
298        let pos = self.position();
299        let (n, buf) = buf_try!(self.get_mut().write_at(buf, pos).await);
300        self.set_position(pos + n as u64);
301        BufResult(Ok(n), buf)
302    }
303
304    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
305        let pos = self.position();
306        let (n, buf) = buf_try!(self.get_mut().write_vectored_at(buf, pos).await);
307        self.set_position(pos + n as u64);
308        BufResult(Ok(n), buf)
309    }
310
311    async fn flush(&mut self) -> IoResult<()> {
312        Ok(())
313    }
314
315    async fn shutdown(&mut self) -> IoResult<()> {
316        Ok(())
317    }
318}