// compio_io/write/ext.rs

1use compio_buf::{BufResult, IntoInner, IoBuf, IoVectoredBuf};
2
3use crate::{AsyncWrite, AsyncWriteAt, IoResult};
4
/// Expands to a pair of trait methods, `write_<ty>` and `write_<ty>_le`, that
/// encode a `$t` with `$be` / `$le` respectively and push the resulting bytes
/// out through `write_all`.
macro_rules! write_scalar {
    ($t:ty, $be:ident, $le:ident) => {
        ::paste::paste! {
            #[doc = concat!("Write a big endian `", stringify!($t), "` into the underlying writer.")]
            async fn [< write_ $t >](&mut self, num: $t) -> IoResult<()> {
                use ::compio_buf::arrayvec::ArrayVec;

                // Owned, fixed-capacity buffer sized exactly to the scalar.
                const WIDTH: usize = ::std::mem::size_of::<$t>();
                let encoded = ArrayVec::<u8, WIDTH>::from(num.$be());
                // Only the I/O outcome is reported; the returned buffer is dropped.
                self.write_all(encoded).await.0
            }

            #[doc = concat!("Write a little endian `", stringify!($t), "` into the underlying writer.")]
            async fn [< write_ $t _le >](&mut self, num: $t) -> IoResult<()> {
                use ::compio_buf::arrayvec::ArrayVec;

                const WIDTH: usize = ::std::mem::size_of::<$t>();
                let encoded = ArrayVec::<u8, WIDTH>::from(num.$le());
                self.write_all(encoded).await.0
            }
        }
    };
}
33
/// Drives a write loop until `$len` bytes have been accepted.
///
/// `$needle` is declared by the expansion and counts the bytes written so
/// far; `$expr_expr` (typically built from `$buf` and `$needle`, e.g.
/// `buf.slice(needle..)`) must be a future whose output, after
/// `into_inner()`, is a `BufResult<usize, _>` carrying the buffer back so it
/// can be stored into `$buf` again.
///
/// The expansion always `return`s from the enclosing function:
/// - `Ok(())` with the buffer once `$needle` reaches `$len`;
/// - a `WriteZero` error if a write reports 0 bytes before that;
/// - any other error unchanged. `Interrupted` errors are retried.
macro_rules! loop_write_all {
    ($buf:ident, $len:expr, $needle:ident,loop $expr_expr:expr) => {
        let total = $len;
        let mut $needle = 0;

        while $needle < total {
            // Every outcome hands the buffer back; put it back in place first.
            let BufResult(result, returned) = $expr_expr.await.into_inner();
            $buf = returned;
            match result {
                Ok(0) => {
                    return BufResult(
                        Err(::std::io::Error::new(
                            ::std::io::ErrorKind::WriteZero,
                            "failed to write whole buffer",
                        )),
                        $buf,
                    );
                }
                Ok(written) => $needle += written,
                Err(e) if e.kind() == ::std::io::ErrorKind::Interrupted => {}
                Err(e) => return BufResult(Err(e), $buf),
            }
        }

        return BufResult(Ok(()), $buf);
    };
}
65
/// Shared code for writing every buffer of a vectored buffer in turn.
///
/// The value bound to `$buf` is turned into an owned iterator with
/// `owned_iter()`; `$iter` names that iterator inside `$write_expr`. Buffers
/// whose `buf_len()` is zero are skipped without calling `$write_expr`, and
/// the iterator is then advanced with `OwnedIterator::next`. The expansion
/// always `return`s from the enclosing function, handing the buffer back.
///
/// Two forms are supported:
///
/// - `loop $write_expr`: `$write_expr` must resolve to `BufResult<(), _>`
///   (carrying the iterator) and is expected to write the whole current
///   buffer. After each successful write, `$tracker` grows by that buffer's
///   length, so `$write_expr` may use it, e.g. as a position offset. Returns
///   `Ok(())` after the last buffer, or the first error.
/// - `loop $write_expr, break $judge_expr`: `$write_expr` must resolve to
///   `BufResult<usize, _>`; the count is bound to `$res` and added to
///   `$tracker`. After each successful write, `$judge_expr` may evaluate to
///   `Some(result)` to return early with `result`. Otherwise returns
///   `Ok($tracker as usize)` after the last buffer, or the first error.
macro_rules! loop_write_vectored {
    ($buf:ident, $tracker:ident : $tracker_ty:ty, $iter:ident,loop $write_expr:expr) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(buf) => buf,
            // `owned_iter` handed the buffer back (presumably it holds no
            // buffers), so there is nothing to write.
            Err(buf) => return BufResult(Ok(()), buf),
        };
        let mut $tracker: $tracker_ty = 0;

        loop {
            let len = $iter.buf_len();
            // An empty buffer has nothing to write, but the iterator must still
            // be advanced below; `continue`-ing here would re-examine the same
            // empty buffer forever.
            if len > 0 {
                match $write_expr.await {
                    BufResult(Ok(()), ret) => {
                        $iter = ret;
                        $tracker += len as $tracker_ty;
                    }
                    BufResult(Err(e), $iter) => return BufResult(Err(e), $iter.into_inner()),
                };
            }

            match $iter.next() {
                Ok(next) => $iter = next,
                Err(buf) => return BufResult(Ok(()), buf),
            }
        }
    }};
    (
        $buf:ident,
        $tracker:ident :
        $tracker_ty:ty,
        $res:ident,
        $iter:ident,loop
        $write_expr:expr,break
        $judge_expr:expr
    ) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(buf) => buf,
            Err(buf) => return BufResult(Ok(0), buf),
        };
        let mut $tracker: $tracker_ty = 0;

        loop {
            // Same as above: skip empty buffers but always advance the iterator.
            if $iter.buf_len() > 0 {
                match $write_expr.await {
                    BufResult(Ok($res), ret) => {
                        $iter = ret;
                        $tracker += $res as $tracker_ty;
                        if let Some(res) = $judge_expr {
                            return BufResult(res, $iter.into_inner());
                        }
                    }
                    BufResult(Err(e), $iter) => return BufResult(Err(e), $iter.into_inner()),
                };
            }

            match $iter.next() {
                Ok(next) => $iter = next,
                Err(buf) => return BufResult(Ok($tracker as usize), buf),
            }
        }
    }};
}
136
137/// Implemented as an extension trait, adding utility methods to all
138/// [`AsyncWrite`] types. Callers will tend to import this trait instead of
139/// [`AsyncWrite`].
140pub trait AsyncWriteExt: AsyncWrite {
141    /// Creates a "by reference" adaptor for this instance of [`AsyncWrite`].
142    ///
143    /// The returned adapter also implements [`AsyncWrite`] and will simply
144    /// borrow this current writer.
145    fn by_ref(&mut self) -> &mut Self
146    where
147        Self: Sized,
148    {
149        self
150    }
151
152    /// Write the entire contents of a buffer into this writer.
153    async fn write_all<T: IoBuf>(&mut self, mut buf: T) -> BufResult<(), T> {
154        loop_write_all!(
155            buf,
156            buf.buf_len(),
157            needle,
158            loop self.write(buf.slice(needle..))
159        );
160    }
161
162    /// Write the entire contents of a buffer into this writer. Like
163    /// [`AsyncWrite::write_vectored`], except that it tries to write the entire
164    /// contents of the buffer into this writer.
165    async fn write_vectored_all<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<(), T> {
166        loop_write_vectored!(buf, _total: usize, iter, loop self.write_all(iter))
167    }
168
169    write_scalar!(u8, to_be_bytes, to_le_bytes);
170    write_scalar!(u16, to_be_bytes, to_le_bytes);
171    write_scalar!(u32, to_be_bytes, to_le_bytes);
172    write_scalar!(u64, to_be_bytes, to_le_bytes);
173    write_scalar!(u128, to_be_bytes, to_le_bytes);
174    write_scalar!(i8, to_be_bytes, to_le_bytes);
175    write_scalar!(i16, to_be_bytes, to_le_bytes);
176    write_scalar!(i32, to_be_bytes, to_le_bytes);
177    write_scalar!(i64, to_be_bytes, to_le_bytes);
178    write_scalar!(i128, to_be_bytes, to_le_bytes);
179    write_scalar!(f32, to_be_bytes, to_le_bytes);
180    write_scalar!(f64, to_be_bytes, to_le_bytes);
181}
182
183impl<A: AsyncWrite + ?Sized> AsyncWriteExt for A {}
184
185/// Implemented as an extension trait, adding utility methods to all
186/// [`AsyncWriteAt`] types. Callers will tend to import this trait instead of
187/// [`AsyncWriteAt`].
188pub trait AsyncWriteAtExt: AsyncWriteAt {
189    /// Like [`AsyncWriteAt::write_at`], except that it tries to write the
190    /// entire contents of the buffer into this writer.
191    async fn write_all_at<T: IoBuf>(&mut self, mut buf: T, pos: u64) -> BufResult<(), T> {
192        loop_write_all!(
193            buf,
194            buf.buf_len(),
195            needle,
196            loop self.write_at(buf.slice(needle..), pos + needle as u64)
197        );
198    }
199
200    /// Like [`AsyncWriteAt::write_vectored_at`], expect that it tries to write
201    /// the entire contents of the buffer into this writer.
202    async fn write_vectored_all_at<T: IoVectoredBuf>(
203        &mut self,
204        buf: T,
205        pos: u64,
206    ) -> BufResult<(), T> {
207        loop_write_vectored!(buf, total: u64, iter, loop self.write_all_at(iter, pos + total))
208    }
209}
210
211impl<A: AsyncWriteAt + ?Sized> AsyncWriteAtExt for A {}