// compio_io/read/ext.rs

1#[cfg(feature = "allocator_api")]
2use std::alloc::Allocator;
3
4use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut, t_alloc};
5
6use crate::{AsyncRead, AsyncReadAt, IoResult, util::Take};
7
/// Shared code for reading a scalar value from the underlying reader.
///
/// Each invocation generates two trait methods for `$t`:
/// `read_<t>` decodes the bytes with `$be` (big endian) and `read_<t>_le`
/// decodes them with `$le` (little endian). Both fill a stack buffer of exactly
/// `size_of::<$t>()` bytes through `read_exact` before decoding.
macro_rules! read_scalar {
    ($t:ty, $be:ident, $le:ident) => {
        ::paste::paste! {
            #[doc = concat!("Read a big endian `", stringify!($t), "` from the underlying reader.")]
            async fn [< read_ $t >](&mut self) -> IoResult<$t> {
                use ::compio_buf::{arrayvec::ArrayVec, BufResult};

                const WIDTH: usize = ::std::mem::size_of::<$t>();
                let scratch = ArrayVec::<u8, WIDTH>::new();
                let BufResult(outcome, filled) = self.read_exact(scratch).await;
                outcome?;
                // SAFETY: `read_exact` reported success, so it filled the whole
                // `WIDTH`-byte capacity of the array, which is exactly what
                // `into_inner_unchecked` requires.
                let bytes = unsafe { filled.into_inner_unchecked() };
                Ok(<$t>::$be(bytes))
            }

            #[doc = concat!("Read a little endian `", stringify!($t), "` from the underlying reader.")]
            async fn [< read_ $t _le >](&mut self) -> IoResult<$t> {
                use ::compio_buf::{arrayvec::ArrayVec, BufResult};

                const WIDTH: usize = ::std::mem::size_of::<$t>();
                let scratch = ArrayVec::<u8, WIDTH>::new();
                let BufResult(outcome, filled) = self.read_exact(scratch).await;
                outcome?;
                // SAFETY: `read_exact` reported success, so it filled the whole
                // `WIDTH`-byte capacity of the array, which is exactly what
                // `into_inner_unchecked` requires.
                let bytes = unsafe { filled.into_inner_unchecked() };
                Ok(<$t>::$le(bytes))
            }
        }
    };
}
36
/// Shared code for looping until exactly `$len` bytes have been read.
///
/// `$read_expr` is evaluated once per iteration; it moves `$buf` into the read
/// and the buffer handed back is stored into `$buf` again. `$tracker` counts
/// the bytes read so far, so `$read_expr` can use it to offset the next read.
/// A read of zero bytes before `$len` is reached yields an `UnexpectedEof`
/// error, `Interrupted` errors are retried, and any other error is returned
/// immediately. The expansion always `return`s from the enclosing function.
macro_rules! loop_read_exact {
    ($buf:ident, $len:expr, $tracker:ident,loop $read_expr:expr) => {
        let mut $tracker = 0;
        let target = $len;

        while $tracker < target {
            let BufResult(outcome, returned) = $read_expr.await.into_inner();
            $buf = returned;

            match outcome {
                // The source ran dry before the buffer was full.
                Ok(0) => {
                    return BufResult(
                        Err(::std::io::Error::new(
                            ::std::io::ErrorKind::UnexpectedEof,
                            "failed to fill whole buffer",
                        )),
                        $buf,
                    );
                }
                Ok(n) => $tracker += n,
                // Interrupted reads can be retried without losing progress.
                Err(e) if e.kind() == ::std::io::ErrorKind::Interrupted => {}
                Err(e) => return BufResult(Err(e), $buf),
            }
        }
        return BufResult(Ok(()), $buf)
    };
}
67
/// Shared code for looping over the buffers of a vectored buffer, reading
/// into each one in turn.
///
/// `$buf` is turned into an owned iterator bound to `$iter`. `$read_expr` is
/// evaluated once per buffer with non-zero capacity and must hand the iterator
/// back inside its `BufResult`. Buffers with zero capacity are skipped without
/// issuing a read. On error the original vectored buffer is recovered with
/// `into_inner` and returned alongside the error. Both arms `return` from the
/// enclosing function.
///
/// The first arm is for reads that resolve to `BufResult<(), _>`. Each
/// successful read adds that buffer's capacity to `$tracker`, and the macro
/// returns `BufResult(Ok(()), buf)` once every buffer has been visited.
///
/// The second arm is for reads that resolve to `BufResult<usize, _>`. The
/// buffer's capacity is bound to `$len`, and the byte count of each read is
/// bound to `$res` and added to `$tracker`. After each successful read
/// `$judge_expr` is evaluated. `Some(result)` stops early and returns `result`;
/// `None` moves on to the next buffer. When all buffers have been visited it
/// returns `Ok($tracker as usize)`.
macro_rules! loop_read_vectored {
    ($buf:ident, $tracker:ident : $tracker_ty:ty, $iter:ident,loop $read_expr:expr) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(buf) => buf,
            // No buffers at all: nothing to read.
            Err(buf) => return BufResult(Ok(()), buf),
        };
        let mut $tracker: $tracker_ty = 0;

        loop {
            let len = $iter.buf_capacity();
            // Only read into buffers that can hold data, but always fall through
            // to `next()`. Skipping with `continue` would re-examine the same
            // empty buffer forever.
            if len > 0 {
                match $read_expr.await {
                    BufResult(Ok(()), ret) => {
                        $iter = ret;
                        $tracker += len as $tracker_ty;
                    }
                    BufResult(Err(e), $iter) => return BufResult(Err(e), $iter.into_inner()),
                }
            }

            match $iter.next() {
                Ok(next) => $iter = next,
                // Iterator exhausted: every buffer has been handled.
                Err(buf) => return BufResult(Ok(()), buf),
            }
        }
    }};
    (
        $buf:ident,
        $len:ident,
        $tracker:ident :
        $tracker_ty:ty,
        $res:ident,
        $iter:ident,loop
        $read_expr:expr,break
        $judge_expr:expr
    ) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(buf) => buf,
            // No buffers at all: nothing was read.
            Err(buf) => return BufResult(Ok(0), buf),
        };
        let mut $tracker: $tracker_ty = 0;

        loop {
            let $len = $iter.buf_capacity();
            // As above: skip empty buffers without stalling on them.
            if $len > 0 {
                match $read_expr.await {
                    BufResult(Ok($res), ret) => {
                        $iter = ret;
                        $tracker += $res as $tracker_ty;
                        // Let the caller decide whether to stop early (e.g. on a
                        // short read) and with which result.
                        if let Some(res) = $judge_expr {
                            return BufResult(res, $iter.into_inner());
                        }
                    }
                    BufResult(Err(e), $iter) => return BufResult(Err(e), $iter.into_inner()),
                }
            }

            match $iter.next() {
                Ok(next) => $iter = next,
                Err(buf) => return BufResult(Ok($tracker as usize), buf),
            }
        }
    }};
}
140
/// Shared code for reading into a `Vec` until the reader reports EOF.
///
/// Before every read a full vector is given at least 32 more bytes of
/// capacity. `$tracker` (of type `$tracker_ty`) accumulates the number of bytes
/// read, and `Interrupted` errors are retried. Once a read returns `Ok(0)` the
/// expansion evaluates to `BufResult(Ok(total), buf)`. Any other error is
/// returned from the enclosing function together with the buffer.
macro_rules! loop_read_to_end {
    ($buf:ident, $tracker:ident : $tracker_ty:ty,loop $read_expr:expr) => {{
        let mut $tracker: $tracker_ty = 0;
        loop {
            let full = $buf.len() == $buf.capacity();
            if full {
                $buf.reserve(32);
            }

            let BufResult(outcome, returned) = $read_expr.await.into_inner();
            $buf = returned;

            match outcome {
                // EOF: stop and report what was collected.
                Ok(0) => break,
                Ok(n) => $tracker += n as $tracker_ty,
                // Interrupted reads can be retried without losing progress.
                Err(e) if e.kind() == ::std::io::ErrorKind::Interrupted => {}
                Err(e) => return BufResult(Err(e), $buf),
            }
        }
        BufResult(Ok($tracker as usize), $buf)
    }};
}
166
167/// Implemented as an extension trait, adding utility methods to all
168/// [`AsyncRead`] types. Callers will tend to import this trait instead of
169/// [`AsyncRead`].
170pub trait AsyncReadExt: AsyncRead {
171    /// Creates a "by reference" adaptor for this instance of [`AsyncRead`].
172    ///
173    /// The returned adapter also implements [`AsyncRead`] and will simply
174    /// borrow this current reader.
175    fn by_ref(&mut self) -> &mut Self
176    where
177        Self: Sized,
178    {
179        self
180    }
181
182    /// Read the exact number of bytes required to fill the buf.
183    async fn read_exact<T: IoBufMut>(&mut self, mut buf: T) -> BufResult<(), T> {
184        loop_read_exact!(buf, buf.buf_capacity(), read, loop self.read(buf.slice(read..)));
185    }
186
187    /// Read all bytes until underlying reader reaches `EOF`.
188    async fn read_to_end<#[cfg(feature = "allocator_api")] A: Allocator + 'static>(
189        &mut self,
190        mut buf: t_alloc!(Vec, u8, A),
191    ) -> BufResult<usize, t_alloc!(Vec, u8, A)> {
192        loop_read_to_end!(buf, total: usize, loop self.read(buf.slice(total..)))
193    }
194
195    /// Read the exact number of bytes required to fill the vectored buf.
196    async fn read_vectored_exact<T: IoVectoredBufMut>(&mut self, buf: T) -> BufResult<(), T> {
197        loop_read_vectored!(buf, _total: usize, iter, loop self.read_exact(iter))
198    }
199
200    /// Creates an adaptor which reads at most `limit` bytes from it.
201    ///
202    /// This function returns a new instance of `AsyncRead` which will read
203    /// at most `limit` bytes, after which it will always return EOF
204    /// (`Ok(0)`). Any read errors will not count towards the number of
205    /// bytes read and future calls to [`read()`] may succeed.
206    ///
207    /// [`read()`]: AsyncRead::read
208    fn take(self, limit: u64) -> Take<Self>
209    where
210        Self: Sized,
211    {
212        Take::new(self, limit)
213    }
214
215    read_scalar!(u8, from_be_bytes, from_le_bytes);
216    read_scalar!(u16, from_be_bytes, from_le_bytes);
217    read_scalar!(u32, from_be_bytes, from_le_bytes);
218    read_scalar!(u64, from_be_bytes, from_le_bytes);
219    read_scalar!(u128, from_be_bytes, from_le_bytes);
220    read_scalar!(i8, from_be_bytes, from_le_bytes);
221    read_scalar!(i16, from_be_bytes, from_le_bytes);
222    read_scalar!(i32, from_be_bytes, from_le_bytes);
223    read_scalar!(i64, from_be_bytes, from_le_bytes);
224    read_scalar!(i128, from_be_bytes, from_le_bytes);
225    read_scalar!(f32, from_be_bytes, from_le_bytes);
226    read_scalar!(f64, from_be_bytes, from_le_bytes);
227}
228
229impl<A: AsyncRead + ?Sized> AsyncReadExt for A {}
230
231/// Implemented as an extension trait, adding utility methods to all
232/// [`AsyncReadAt`] types. Callers will tend to import this trait instead of
233/// [`AsyncReadAt`].
234pub trait AsyncReadAtExt: AsyncReadAt {
235    /// Read the exact number of bytes required to fill `buffer`.
236    ///
237    /// This function reads as many bytes as necessary to completely fill the
238    /// uninitialized space of specified `buffer`.
239    ///
240    /// # Errors
241    ///
242    /// If this function encounters an "end of file" before completely filling
243    /// the buffer, it returns an error of the kind
244    /// [`ErrorKind::UnexpectedEof`]. The contents of `buffer` are unspecified
245    /// in this case.
246    ///
247    /// If any other read error is encountered then this function immediately
248    /// returns. The contents of `buffer` are unspecified in this case.
249    ///
250    /// If this function returns an error, it is unspecified how many bytes it
251    /// has read, but it will never read more than would be necessary to
252    /// completely fill the buffer.
253    ///
254    /// [`ErrorKind::UnexpectedEof`]: std::io::ErrorKind::UnexpectedEof
255    async fn read_exact_at<T: IoBufMut>(&self, mut buf: T, pos: u64) -> BufResult<(), T> {
256        loop_read_exact!(
257            buf,
258            buf.buf_capacity(),
259            read,
260            loop self.read_at(buf.slice(read..), pos + read as u64)
261        );
262    }
263
264    /// Read all bytes until EOF in this source, placing them into `buffer`.
265    ///
266    /// All bytes read from this source will be appended to the specified buffer
267    /// `buffer`. This function will continuously call [`read_at()`] to append
268    /// more data to `buffer` until [`read_at()`] returns [`Ok(0)`].
269    ///
270    /// If successful, this function will return the total number of bytes read.
271    ///
272    /// [`read_at()`]: AsyncReadAt::read_at
273    async fn read_to_end_at<#[cfg(feature = "allocator_api")] A: Allocator + 'static>(
274        &self,
275        mut buffer: t_alloc!(Vec, u8, A),
276        pos: u64,
277    ) -> BufResult<usize, t_alloc!(Vec, u8, A)> {
278        loop_read_to_end!(buffer, total: u64, loop self.read_at(buffer.slice(total as usize..), pos + total))
279    }
280
281    /// Like [`AsyncReadExt::read_vectored_exact`], expect that it reads at a
282    /// specified position.
283    async fn read_vectored_exact_at<T: IoVectoredBufMut>(
284        &self,
285        buf: T,
286        pos: u64,
287    ) -> BufResult<(), T> {
288        loop_read_vectored!(buf, total: u64, iter, loop self.read_exact_at(iter, pos + total))
289    }
290}
291
292impl<A: AsyncReadAt + ?Sized> AsyncReadAtExt for A {}