async_std/io/read/
take.rs

1use std::cmp;
2use std::pin::Pin;
3
4use pin_project_lite::pin_project;
5
6use crate::io::{self, BufRead, Read};
7use crate::task::{Context, Poll};
8
pin_project! {
    /// Reader adaptor which limits the bytes read from an underlying reader.
    ///
    /// This struct is generally created by calling [`take`] on a reader.
    /// Please see the documentation of [`take`] for more details.
    ///
    /// [`take`]: trait.Read.html#method.take
    #[derive(Debug)]
    pub struct Take<T> {
        // The wrapped reader. Marked `#[pin]` so `project()` yields a
        // `Pin<&mut T>` that can be polled from `Pin<&mut Take<T>>`.
        #[pin]
        pub(crate) inner: T,
        // Remaining byte budget: decremented as bytes are read/consumed, and
        // once it reaches 0 the adaptor reports EOF without polling `inner`.
        pub(crate) limit: u64,
    }
}
23
24impl<T> Take<T> {
25    /// Returns the number of bytes that can be read before this instance will
26    /// return EOF.
27    ///
28    /// # Note
29    ///
30    /// This instance may reach `EOF` after reading fewer bytes than indicated by
31    /// this method if the underlying [`Read`] instance reaches EOF.
32    ///
33    /// [`Read`]: trait.Read.html
34    ///
35    /// # Examples
36    ///
37    /// ```no_run
38    /// # fn main() -> async_std::io::Result<()> { async_std::task::block_on(async {
39    /// #
40    /// use async_std::prelude::*;
41    /// use async_std::fs::File;
42    ///
43    /// let f = File::open("foo.txt").await?;
44    ///
45    /// // read at most five bytes
46    /// let handle = f.take(5);
47    ///
48    /// println!("limit: {}", handle.limit());
49    /// #
50    /// #     Ok(()) }) }
51    /// ```
52    pub fn limit(&self) -> u64 {
53        self.limit
54    }
55
56    /// Sets the number of bytes that can be read before this instance will
57    /// return EOF. This is the same as constructing a new `Take` instance, so
58    /// the amount of bytes read and the previous limit value don't matter when
59    /// calling this method.
60    ///
61    /// # Examples
62    ///
63    /// ```no_run
64    /// # fn main() -> async_std::io::Result<()> { async_std::task::block_on(async {
65    /// #
66    /// use async_std::prelude::*;
67    /// use async_std::fs::File;
68    ///
69    /// let f = File::open("foo.txt").await?;
70    ///
71    /// // read at most five bytes
72    /// let mut handle = f.take(5);
73    /// handle.set_limit(10);
74    ///
75    /// assert_eq!(handle.limit(), 10);
76    /// #
77    /// # Ok(()) }) }
78    /// ```
79    pub fn set_limit(&mut self, limit: u64) {
80        self.limit = limit;
81    }
82
83    /// Consumes the `Take`, returning the wrapped reader.
84    ///
85    /// # Examples
86    ///
87    /// ```no_run
88    /// # fn main() -> async_std::io::Result<()> { async_std::task::block_on(async {
89    /// #
90    /// use async_std::prelude::*;
91    /// use async_std::fs::File;
92    ///
93    /// let file = File::open("foo.txt").await?;
94    ///
95    /// let mut buffer = [0; 5];
96    /// let mut handle = file.take(5);
97    /// handle.read(&mut buffer).await?;
98    ///
99    /// let file = handle.into_inner();
100    /// #
101    /// # Ok(()) }) }
102    /// ```
103    pub fn into_inner(self) -> T {
104        self.inner
105    }
106
107    /// Gets a reference to the underlying reader.
108    ///
109    /// # Examples
110    ///
111    /// ```no_run
112    /// # fn main() -> async_std::io::Result<()> { async_std::task::block_on(async {
113    /// #
114    /// use async_std::prelude::*;
115    /// use async_std::fs::File;
116    ///
117    /// let file = File::open("foo.txt").await?;
118    ///
119    /// let mut buffer = [0; 5];
120    /// let mut handle = file.take(5);
121    /// handle.read(&mut buffer).await?;
122    ///
123    /// let file = handle.get_ref();
124    /// #
125    /// # Ok(()) }) }
126    /// ```
127    pub fn get_ref(&self) -> &T {
128        &self.inner
129    }
130
131    /// Gets a mutable reference to the underlying reader.
132    ///
133    /// Care should be taken to avoid modifying the internal I/O state of the
134    /// underlying reader as doing so may corrupt the internal limit of this
135    /// `Take`.
136    ///
137    /// # Examples
138    ///
139    /// ```no_run
140    /// # fn main() -> async_std::io::Result<()> { async_std::task::block_on(async {
141    /// #
142    /// use async_std::prelude::*;
143    /// use async_std::fs::File;
144    ///
145    /// let file = File::open("foo.txt").await?;
146    ///
147    /// let mut buffer = [0; 5];
148    /// let mut handle = file.take(5);
149    /// handle.read(&mut buffer).await?;
150    ///
151    /// let file = handle.get_mut();
152    /// #
153    /// # Ok(()) }) }
154    /// ```
155    pub fn get_mut(&mut self) -> &mut T {
156        &mut self.inner
157    }
158}
159
160impl<T: Read> Read for Take<T> {
161    /// Attempt to read from the `AsyncRead` into `buf`.
162    fn poll_read(
163        self: Pin<&mut Self>,
164        cx: &mut Context<'_>,
165        buf: &mut [u8],
166    ) -> Poll<io::Result<usize>> {
167        let this = self.project();
168        take_read_internal(this.inner, cx, buf, this.limit)
169    }
170}
171
172pub fn take_read_internal<R: Read + ?Sized>(
173    mut rd: Pin<&mut R>,
174    cx: &mut Context<'_>,
175    buf: &mut [u8],
176    limit: &mut u64,
177) -> Poll<io::Result<usize>> {
178    // Don't call into inner reader at all at EOF because it may still block
179    if *limit == 0 {
180        return Poll::Ready(Ok(0));
181    }
182
183    let max = cmp::min(buf.len() as u64, *limit) as usize;
184
185    match futures_core::ready!(rd.as_mut().poll_read(cx, &mut buf[..max])) {
186        Ok(n) => {
187            *limit -= n as u64;
188            Poll::Ready(Ok(n))
189        }
190        Err(e) => Poll::Ready(Err(e)),
191    }
192}
193
194impl<T: BufRead> BufRead for Take<T> {
195    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
196        let this = self.project();
197
198        if *this.limit == 0 {
199            return Poll::Ready(Ok(&[]));
200        }
201
202        match futures_core::ready!(this.inner.poll_fill_buf(cx)) {
203            Ok(buf) => {
204                let cap = cmp::min(buf.len() as u64, *this.limit) as usize;
205                Poll::Ready(Ok(&buf[..cap]))
206            }
207            Err(e) => Poll::Ready(Err(e)),
208        }
209    }
210
211    fn consume(self: Pin<&mut Self>, amt: usize) {
212        let this = self.project();
213        // Don't let callers reset the limit by passing an overlarge value
214        let amt = cmp::min(amt as u64, *this.limit) as usize;
215        *this.limit -= amt as u64;
216
217        this.inner.consume(amt);
218    }
219}
220
#[cfg(all(test, not(target_os = "unknown")))]
mod tests {
    use crate::io;
    use crate::prelude::*;
    use crate::task;

    #[test]
    fn test_take_basics() -> std::io::Result<()> {
        let source: io::Cursor<Vec<u8>> = io::Cursor::new(vec![0, 1, 2, 3, 4, 5, 6, 7, 8]);

        task::block_on(async move {
            let mut buffer = [0u8; 5];

            // read at most five bytes
            let mut handle = source.take(5);

            assert_eq!(handle.read(&mut buffer).await?, 5);
            assert_eq!(buffer, [0, 1, 2, 3, 4]);

            // check that we are actually at the end
            assert_eq!(handle.read(&mut buffer).await.unwrap(), 0);

            Ok(())
        })
    }

    #[test]
    fn test_take_read_to_end_stops_at_limit() -> std::io::Result<()> {
        let source = io::Cursor::new(vec![0u8, 1, 2, 3, 4, 5, 6, 7, 8]);

        task::block_on(async move {
            let mut handle = source.take(4);
            let mut out = Vec::new();

            // The inner reader has more data, but only the first four bytes
            // may come through.
            assert_eq!(handle.read_to_end(&mut out).await?, 4);
            assert_eq!(out, [0, 1, 2, 3]);
            assert_eq!(handle.limit(), 0);

            Ok(())
        })
    }

    #[test]
    fn test_take_buf_read_respects_limit() -> std::io::Result<()> {
        let source = io::Cursor::new(b"hello\nworld\n".to_vec());

        task::block_on(async move {
            let mut handle = source.take(3);
            let mut line = String::new();

            // The limit cuts the line short before its newline is reached,
            // exercising the `poll_fill_buf` cap and `consume` bookkeeping.
            assert_eq!(handle.read_line(&mut line).await?, 3);
            assert_eq!(line, "hel");
            assert_eq!(handle.limit(), 0);

            Ok(())
        })
    }
}