compio_io/read/ext.rs

1#[cfg(feature = "allocator_api")]
2use std::alloc::Allocator;
3
4use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut, t_alloc};
5
6use crate::{AsyncRead, AsyncReadAt, IoResult, util::Take};
7
/// Expands to two `AsyncReadExt` methods that read a single `$t` from the
/// reader:
/// * `read_<t>` decodes the bytes with `$t::$be` (big endian);
/// * `read_<t>_le` decodes the bytes with `$t::$le` (little endian).
///
/// Both methods fill a fixed-capacity `ArrayVec` of exactly
/// `size_of::<$t>()` bytes through `read_exact` before decoding.
macro_rules! read_scalar {
    ($t:ty, $be:ident, $le:ident) => {
        ::paste::paste! {
            #[doc = concat!("Read a big endian `", stringify!($t), "` from the underlying reader.")]
            async fn [< read_ $t >](&mut self) -> IoResult<$t> {
                use ::compio_buf::{arrayvec::ArrayVec, BufResult};

                const WIDTH: usize = ::std::mem::size_of::<$t>();
                let raw = {
                    let BufResult(outcome, filled) =
                        self.read_exact(ArrayVec::<u8, WIDTH>::new()).await;
                    outcome?;
                    // SAFETY: `read_exact` only succeeds once all `WIDTH` slots of
                    // the initially empty `ArrayVec` have been filled.
                    unsafe { filled.into_inner_unchecked() }
                };
                Ok($t::$be(raw))
            }

            #[doc = concat!("Read a little endian `", stringify!($t), "` from the underlying reader.")]
            async fn [< read_ $t _le >](&mut self) -> IoResult<$t> {
                use ::compio_buf::{arrayvec::ArrayVec, BufResult};

                const WIDTH: usize = ::std::mem::size_of::<$t>();
                let raw = {
                    let BufResult(outcome, filled) =
                        self.read_exact(ArrayVec::<u8, WIDTH>::new()).await;
                    outcome?;
                    // SAFETY: `read_exact` only succeeds once all `WIDTH` slots of
                    // the initially empty `ArrayVec` have been filled.
                    unsafe { filled.into_inner_unchecked() }
                };
                Ok($t::$le(raw))
            }
        }
    };
}
36
/// Drives a read loop until `$len` bytes in total have been read.
///
/// * `$buf` – the caller's buffer variable. It is moved into `$read_expr` and
///   reassigned from the returned `BufResult` (via `into_inner()`) after every
///   attempt.
/// * `$len` – evaluated once; the number of bytes to read.
/// * `$tracker` – name of the running byte count. It is visible to
///   `$read_expr`, which uses it to slice the buffer and/or offset a position.
///
/// Expands to statements that always `return` from the enclosing function:
/// * `Ok(())` once `$len` bytes have been read;
/// * an `UnexpectedEof` error if a read yields 0 bytes first;
/// * otherwise the first error that is not `Interrupted` (those are retried).
macro_rules! loop_read_exact {
    ($buf:ident, $len:expr, $tracker:ident, loop $read_expr:expr) => {
        let mut $tracker = 0;
        let target = $len;

        loop {
            if $tracker >= target {
                return BufResult(Ok(()), $buf);
            }

            let BufResult(outcome, returned) = $read_expr.await.into_inner();
            $buf = returned;

            match outcome {
                Ok(0) => {
                    return BufResult(
                        Err(::std::io::Error::new(
                            ::std::io::ErrorKind::UnexpectedEof,
                            "failed to fill whole buffer",
                        )),
                        $buf,
                    );
                }
                Ok(n) => $tracker += n,
                // Interrupted reads are simply attempted again.
                Err(e) if e.kind() == ::std::io::ErrorKind::Interrupted => {}
                Err(e) => return BufResult(Err(e), $buf),
            }
        }
    };
}
67
/// Walks the individual buffers of a vectored buffer, reading into each.
///
/// Two forms are accepted.
///
/// `loop_read_vectored!(buf, tracker: Ty, iter, loop expr)`:
/// * Runs `expr` for every buffer with non-zero capacity, where `iter` names
///   the owned iterator positioned at the current buffer.
/// * Adds that buffer's capacity to `tracker` after each successful `expr`;
///   `expr` is assumed to fill the whole buffer, as `read_exact*` does.
/// * Returns `Ok(())` after the last buffer, or the first error together with
///   the reassembled vectored buffer.
///
/// `loop_read_vectored!(buf, iter, expr)`:
/// * Skips zero-capacity buffers.
/// * Returns the result of a single `expr` on the first non-empty buffer, or
///   `Ok(0)` when there is none.
macro_rules! loop_read_vectored {
    ($buf:ident, $tracker:ident : $tracker_ty:ty, $iter:ident, loop $read_expr:expr) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(first) => first,
            Err(empty) => return BufResult(Ok(()), empty),
        };
        let mut $tracker: $tracker_ty = 0;

        loop {
            let chunk = $iter.buf_capacity();
            if chunk != 0 {
                let BufResult(outcome, returned) = $read_expr.await;
                $iter = returned;
                if let Err(e) = outcome {
                    return BufResult(Err(e), $iter.into_inner());
                }
                $tracker += chunk as $tracker_ty;
            }

            $iter = match $iter.next() {
                Ok(next) => next,
                Err(whole) => return BufResult(Ok(()), whole),
            };
        }
    }};
    ($buf:ident, $iter:ident, $read_expr:expr) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(first) => first,
            Err(empty) => return BufResult(Ok(0), empty),
        };

        loop {
            if $iter.buf_capacity() > 0 {
                return $read_expr.await.into_inner();
            }

            $iter = match $iter.next() {
                Ok(next) => next,
                Err(whole) => return BufResult(Ok(0), whole),
            };
        }
    }};
}
117
/// Reads into the `Vec`-like `$buf` until `$read_expr` reports EOF (`Ok(0)`).
///
/// * Before every attempt, 32 more bytes of capacity are reserved if `$buf`
///   is full.
/// * `$tracker` (of type `$tracker_ty`) counts the bytes read and is visible
///   to `$read_expr`, which uses it to slice `$buf` and, for positional
///   reads, to offset the position.
/// * `Interrupted` errors are retried. Any other error is returned
///   immediately together with the buffer.
///
/// Evaluates to `BufResult(Ok(bytes_read), $buf)` once EOF is reached.
macro_rules! loop_read_to_end {
    ($buf:ident, $tracker:ident : $tracker_ty:ty, loop $read_expr:expr) => {{
        let mut $tracker: $tracker_ty = 0;
        loop {
            // Guarantee the next read has spare capacity to write into.
            if $buf.capacity() == $buf.len() {
                $buf.reserve(32);
            }

            let BufResult(outcome, returned) = $read_expr.await.into_inner();
            $buf = returned;

            match outcome {
                Ok(0) => break,
                Ok(n) => $tracker += n as $tracker_ty,
                Err(e) if e.kind() == ::std::io::ErrorKind::Interrupted => {}
                Err(e) => return BufResult(Err(e), $buf),
            }
        }
        BufResult(Ok($tracker as usize), $buf)
    }};
}
143
144/// Implemented as an extension trait, adding utility methods to all
145/// [`AsyncRead`] types. Callers will tend to import this trait instead of
146/// [`AsyncRead`].
147pub trait AsyncReadExt: AsyncRead {
148    /// Creates a "by reference" adaptor for this instance of [`AsyncRead`].
149    ///
150    /// The returned adapter also implements [`AsyncRead`] and will simply
151    /// borrow this current reader.
152    fn by_ref(&mut self) -> &mut Self
153    where
154        Self: Sized,
155    {
156        self
157    }
158
159    /// Read the exact number of bytes required to fill the buf.
160    async fn read_exact<T: IoBufMut>(&mut self, mut buf: T) -> BufResult<(), T> {
161        loop_read_exact!(buf, buf.buf_capacity(), read, loop self.read(buf.slice(read..)));
162    }
163
164    /// Read all bytes until underlying reader reaches `EOF`.
165    async fn read_to_end<#[cfg(feature = "allocator_api")] A: Allocator + 'static>(
166        &mut self,
167        mut buf: t_alloc!(Vec, u8, A),
168    ) -> BufResult<usize, t_alloc!(Vec, u8, A)> {
169        loop_read_to_end!(buf, total: usize, loop self.read(buf.slice(total..)))
170    }
171
172    /// Read the exact number of bytes required to fill the vectored buf.
173    async fn read_vectored_exact<T: IoVectoredBufMut>(&mut self, buf: T) -> BufResult<(), T> {
174        loop_read_vectored!(buf, _total: usize, iter, loop self.read_exact(iter))
175    }
176
177    /// Creates an adaptor which reads at most `limit` bytes from it.
178    ///
179    /// This function returns a new instance of `AsyncRead` which will read
180    /// at most `limit` bytes, after which it will always return EOF
181    /// (`Ok(0)`). Any read errors will not count towards the number of
182    /// bytes read and future calls to [`read()`] may succeed.
183    ///
184    /// [`read()`]: AsyncRead::read
185    fn take(self, limit: u64) -> Take<Self>
186    where
187        Self: Sized,
188    {
189        Take::new(self, limit)
190    }
191
192    read_scalar!(u8, from_be_bytes, from_le_bytes);
193    read_scalar!(u16, from_be_bytes, from_le_bytes);
194    read_scalar!(u32, from_be_bytes, from_le_bytes);
195    read_scalar!(u64, from_be_bytes, from_le_bytes);
196    read_scalar!(u128, from_be_bytes, from_le_bytes);
197    read_scalar!(i8, from_be_bytes, from_le_bytes);
198    read_scalar!(i16, from_be_bytes, from_le_bytes);
199    read_scalar!(i32, from_be_bytes, from_le_bytes);
200    read_scalar!(i64, from_be_bytes, from_le_bytes);
201    read_scalar!(i128, from_be_bytes, from_le_bytes);
202    read_scalar!(f32, from_be_bytes, from_le_bytes);
203    read_scalar!(f64, from_be_bytes, from_le_bytes);
204}
205
206impl<A: AsyncRead + ?Sized> AsyncReadExt for A {}
207
208/// Implemented as an extension trait, adding utility methods to all
209/// [`AsyncReadAt`] types. Callers will tend to import this trait instead of
210/// [`AsyncReadAt`].
211pub trait AsyncReadAtExt: AsyncReadAt {
212    /// Read the exact number of bytes required to fill `buffer`.
213    ///
214    /// This function reads as many bytes as necessary to completely fill the
215    /// uninitialized space of specified `buffer`.
216    ///
217    /// # Errors
218    ///
219    /// If this function encounters an "end of file" before completely filling
220    /// the buffer, it returns an error of the kind
221    /// [`ErrorKind::UnexpectedEof`]. The contents of `buffer` are unspecified
222    /// in this case.
223    ///
224    /// If any other read error is encountered then this function immediately
225    /// returns. The contents of `buffer` are unspecified in this case.
226    ///
227    /// If this function returns an error, it is unspecified how many bytes it
228    /// has read, but it will never read more than would be necessary to
229    /// completely fill the buffer.
230    ///
231    /// [`ErrorKind::UnexpectedEof`]: std::io::ErrorKind::UnexpectedEof
232    async fn read_exact_at<T: IoBufMut>(&self, mut buf: T, pos: u64) -> BufResult<(), T> {
233        loop_read_exact!(
234            buf,
235            buf.buf_capacity(),
236            read,
237            loop self.read_at(buf.slice(read..), pos + read as u64)
238        );
239    }
240
241    /// Read all bytes until EOF in this source, placing them into `buffer`.
242    ///
243    /// All bytes read from this source will be appended to the specified buffer
244    /// `buffer`. This function will continuously call [`read_at()`] to append
245    /// more data to `buffer` until [`read_at()`] returns [`Ok(0)`].
246    ///
247    /// If successful, this function will return the total number of bytes read.
248    ///
249    /// [`read_at()`]: AsyncReadAt::read_at
250    async fn read_to_end_at<#[cfg(feature = "allocator_api")] A: Allocator + 'static>(
251        &self,
252        mut buffer: t_alloc!(Vec, u8, A),
253        pos: u64,
254    ) -> BufResult<usize, t_alloc!(Vec, u8, A)> {
255        loop_read_to_end!(buffer, total: u64, loop self.read_at(buffer.slice(total as usize..), pos + total))
256    }
257
258    /// Like [`AsyncReadExt::read_vectored_exact`], expect that it reads at a
259    /// specified position.
260    async fn read_vectored_exact_at<T: IoVectoredBufMut>(
261        &self,
262        buf: T,
263        pos: u64,
264    ) -> BufResult<(), T> {
265        loop_read_vectored!(buf, total: u64, iter, loop self.read_exact_at(iter, pos + total))
266    }
267}
268
269impl<A: AsyncReadAt + ?Sized> AsyncReadAtExt for A {}