partial_io/
async_read.rs

1// Copyright (c) The partial-io Contributors
2// SPDX-License-Identifier: MIT
3
4//! This module contains an `AsyncRead` wrapper that breaks its inputs up
5//! according to a provided iterator.
6//!
//! This is separate from `PartialRead` because on `WouldBlock` errors, it
//! causes `futures` to try reading again.
9
10use crate::{futures_util::FuturesOps, PartialOp};
11use futures::prelude::*;
12use pin_project::pin_project;
13use std::{
14    fmt, io,
15    pin::Pin,
16    task::{Context, Poll},
17};
18
19/// A wrapper that breaks inner `AsyncRead` instances up according to the
20/// provided iterator.
21///
22/// Available with the `futures03` feature for `futures` traits, and with the `tokio1` feature for
23/// `tokio` traits.
24///
25/// # Examples
26///
27/// This example uses `tokio`.
28///
29/// ```rust
30/// # #[cfg(feature = "tokio1")]
31/// use partial_io::{PartialAsyncRead, PartialOp};
32/// # #[cfg(feature = "tokio1")]
33/// use std::io::{self, Cursor};
34/// # #[cfg(feature = "tokio1")]
35/// use tokio::io::AsyncReadExt;
36///
37/// # #[cfg(feature = "tokio1")]
38/// #[tokio::main]
39/// async fn main() -> io::Result<()> {
40///     let reader = Cursor::new(vec![1, 2, 3, 4]);
41///     // Sequential calls to `poll_read()` and the other `poll_` methods simulate the following behavior:
42///     let iter = vec![
43///         PartialOp::Err(io::ErrorKind::WouldBlock),   // A not-ready state.
44///         PartialOp::Limited(2),                       // Only allow 2 bytes to be read.
45///         PartialOp::Err(io::ErrorKind::InvalidData),  // Error from the underlying stream.
46///         PartialOp::Unlimited,                        // Allow as many bytes to be read as possible.
47///     ];
48///     let mut partial_reader = PartialAsyncRead::new(reader, iter);
49///     let mut out = vec![0; 256];
50///
51///     // This causes poll_read to be called twice, yielding after the first call (WouldBlock).
52///     assert_eq!(partial_reader.read(&mut out).await?, 2, "first read with Limited(2)");
53///     assert_eq!(&out[..4], &[1, 2, 0, 0]);
54///
55///     // This next call returns an error.
56///     assert_eq!(
57///         partial_reader.read(&mut out[2..]).await.unwrap_err().kind(),
58///         io::ErrorKind::InvalidData,
59///     );
60///
///     // And this one causes the last two bytes to be read.
62///     assert_eq!(partial_reader.read(&mut out[2..]).await?, 2, "second read with Unlimited");
63///     assert_eq!(&out[..4], &[1, 2, 3, 4]);
64///
65///     Ok(())
66/// }
67///
68/// # #[cfg(not(feature = "tokio1"))]
69/// # fn main() {
70/// #     assert!(true, "dummy test");
71/// # }
72/// ```
#[pin_project]
pub struct PartialAsyncRead<R> {
    // The wrapped reader. Structurally pinned so the `poll_*` trait methods
    // can be forwarded to it through `Pin::project`.
    #[pin]
    inner: R,
    // The queue of `PartialOp`s that decides, per poll, whether to error,
    // limit the byte count, or pass the call through unchanged.
    ops: FuturesOps,
}
79
80impl<R> PartialAsyncRead<R> {
81    /// Creates a new `PartialAsyncRead` wrapper over the reader with the specified `PartialOp`s.
82    pub fn new<I>(inner: R, iter: I) -> Self
83    where
84        I: IntoIterator<Item = PartialOp> + 'static,
85        I::IntoIter: Send,
86    {
87        PartialAsyncRead {
88            inner,
89            ops: FuturesOps::new(iter),
90        }
91    }
92
93    /// Sets the `PartialOp`s for this reader.
94    pub fn set_ops<I>(&mut self, iter: I) -> &mut Self
95    where
96        I: IntoIterator<Item = PartialOp> + 'static,
97        I::IntoIter: Send,
98    {
99        self.ops.replace(iter);
100        self
101    }
102
103    /// Sets the `PartialOp`s for this reader in a pinned context.
104    pub fn pin_set_ops<I>(self: Pin<&mut Self>, iter: I) -> Pin<&mut Self>
105    where
106        I: IntoIterator<Item = PartialOp> + 'static,
107        I::IntoIter: Send,
108    {
109        let mut this = self;
110        this.as_mut().project().ops.replace(iter);
111        this
112    }
113
114    /// Returns a shared reference to the underlying reader.
115    pub fn get_ref(&self) -> &R {
116        &self.inner
117    }
118
119    /// Returns a mutable reference to the underlying reader.
120    pub fn get_mut(&mut self) -> &mut R {
121        &mut self.inner
122    }
123
124    /// Returns a pinned mutable reference to the underlying reader.
125    pub fn pin_get_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
126        self.project().inner
127    }
128
129    /// Consumes this wrapper, returning the underlying reader.
130    pub fn into_inner(self) -> R {
131        self.inner
132    }
133}
134
135// ---
136// Futures impls
137// ---
138
impl<R> AsyncRead for PartialAsyncRead<R>
where
    R: AsyncRead,
{
    /// Reads into `buf`, delegating to the inner reader under the control of
    /// the next `PartialOp` (error, limited read, or unrestricted read).
    #[inline]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.project();
        let inner = this.inner;
        // Full length of the caller's buffer, handed to `poll_impl` —
        // presumably used to bound any generated limit; confirm in
        // `futures_util::FuturesOps`.
        let len = buf.len();

        this.ops.poll_impl(
            cx,
            // `len` inside the closure is the per-op cap: `Some(n)` restricts
            // the read to the first `n` bytes, `None` means no restriction.
            |cx, len| match len {
                Some(len) => inner.poll_read(cx, &mut buf[..len]),
                None => inner.poll_read(cx, buf),
            },
            len,
            "error during poll_read, generated by partial-io",
        )
    }

    // TODO: do we need to implement poll_read_vectored? It's a bit tricky to do.
}
166
167impl<R> AsyncBufRead for PartialAsyncRead<R>
168where
169    R: AsyncBufRead,
170{
171    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
172        let this = self.project();
173        let inner = this.inner;
174
175        this.ops.poll_impl_no_limit(
176            cx,
177            |cx| inner.poll_fill_buf(cx),
178            "error during poll_read, generated by partial-io",
179        )
180    }
181
182    #[inline]
183    fn consume(self: Pin<&mut Self>, amt: usize) {
184        self.project().inner.consume(amt)
185    }
186}
187
188/// This is a forwarding impl to support duplex structs.
189impl<R> AsyncWrite for PartialAsyncRead<R>
190where
191    R: AsyncWrite,
192{
193    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
194        self.project().inner.poll_write(cx, buf)
195    }
196
197    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
198        self.project().inner.poll_flush(cx)
199    }
200
201    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
202        self.project().inner.poll_close(cx)
203    }
204}
205
/// This is a forwarding impl to support duplex structs.
impl<R> AsyncSeek for PartialAsyncRead<R>
where
    R: AsyncSeek,
{
    // Seeks pass straight through; `PartialOp`s are only consumed by the
    // read-side poll methods.
    #[inline]
    fn poll_seek(
        self: Pin<&mut Self>,
        cx: &mut Context,
        pos: io::SeekFrom,
    ) -> Poll<io::Result<u64>> {
        self.project().inner.poll_seek(cx, pos)
    }
}
220
221// ---
222// Tokio impls
223// ---
224
225#[cfg(feature = "tokio1")]
226pub(crate) mod tokio_impl {
227    use super::PartialAsyncRead;
228    use std::{
229        io::{self, SeekFrom},
230        pin::Pin,
231        task::{Context, Poll},
232    };
233    use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
234
    impl<R> AsyncRead for PartialAsyncRead<R>
    where
        R: AsyncRead,
    {
        /// Reads into `buf`, delegating to the inner reader under the control
        /// of the next `PartialOp` (error, limited read, or unrestricted read).
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let this = self.project();
            let inner = this.inner;
            // Total capacity of the caller's buffer, handed to `poll_impl` —
            // presumably used to bound any generated limit; confirm in
            // `futures_util::FuturesOps`.
            let capacity = buf.capacity();

            this.ops.poll_impl(
                cx,
                |cx, len| match len {
                    Some(len) => {
                        // Present the inner reader with a view of `buf` capped
                        // at `len` bytes; `with_limited` (below) mirrors any
                        // cursor changes back into `buf`.
                        buf.with_limited(len, |limited_buf| inner.poll_read(cx, limited_buf))
                    }
                    None => inner.poll_read(cx, buf),
                },
                capacity,
                "error during poll_read, generated by partial-io",
            )
        }
    }
261
    /// Extensions to `tokio`'s `ReadBuf`.
    ///
    /// Requires the `tokio1` feature to be enabled.
    pub trait ReadBufExt {
        /// Converts this `ReadBuf` into a limited one backed by the same storage,
        /// then calls the callback with this limited instance.
        ///
        /// Any changes to the `ReadBuf` made by the callback are reflected in the original
        /// `ReadBuf`.
        fn with_limited<F, T>(&mut self, limit: usize, callback: F) -> T
        where
            F: FnOnce(&mut ReadBuf<'_>) -> T;
    }
275
    impl<'a> ReadBufExt for ReadBuf<'a> {
        fn with_limited<F, T>(&mut self, limit: usize, callback: F) -> T
        where
            F: FnOnce(&mut ReadBuf<'_>) -> T,
        {
            // Use limit to set upper limits on the capacity and both cursors.
            let capacity_limit = self.capacity().min(limit);
            let old_initialized_len = self.initialized().len().min(limit);
            let old_filled_len = self.filled().len().min(limit);

            // Build a fresh `ReadBuf` over the first `capacity_limit` bytes of
            // our own storage, then restore the (clamped) cursors on it.
            // SAFETY: We assume that the input buf's initialized length is trustworthy.
            let mut limited_buf = unsafe {
                let inner_mut = &mut self.inner_mut()[..capacity_limit];
                let mut limited_buf = ReadBuf::uninit(inner_mut);
                // Note: assume_init adds the passed-in value to self.filled, but for a freshly created
                // uninitialized buffer, self.filled is 0. The value of filled is updated below
                // with the set_filled() call.
                limited_buf.assume_init(old_initialized_len);
                limited_buf
            };
            limited_buf.set_filled(old_filled_len);

            // Call the callback.
            let ret = callback(&mut limited_buf);

            // The callback may have modified the cursors in `limited_buf` -- if so, port them back to
            // the original.
            let new_initialized_len = limited_buf.initialized().len();
            let new_filled_len = limited_buf.filled().len();

            if new_initialized_len > old_initialized_len {
                // SAFETY: We assume that if new_initialized_len > old_initialized_len, that
                // the extra bytes were initialized by the callback.
                unsafe {
                    // Note: assume_init adds the passed-in value to buf.filled.len().
                    self.assume_init(new_initialized_len - self.filled().len());
                }
            }

            if new_filled_len != old_filled_len {
                // This can happen if either:
                // * old_filled_len < limit, and the callback filled some more bytes into buf ->
                //   reflect that in the original buffer.
                // * old_filled_len <= limit, and the callback *shortened* the filled bytes -> reflect
                //   that in the original buffer as well.
                //
                // (Note if old_filled_len == limit, then new_filled_len cannot be greater than
                // old_filled_len since it's at the limit already.)
                self.set_filled(new_filled_len);
            }

            ret
        }
    }
330
    impl<R> AsyncBufRead for PartialAsyncRead<R>
    where
        R: AsyncBufRead,
    {
        /// Fills the internal buffer by delegating to the inner reader under
        /// the control of the next `PartialOp`. Byte limits cannot apply here
        /// (the inner reader owns the buffer), hence the `no_limit` variant.
        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
            let this = self.project();
            let inner = this.inner;

            this.ops.poll_impl_no_limit(
                cx,
                |cx| inner.poll_fill_buf(cx),
                "error during poll_fill_buf, generated by partial-io",
            )
        }

        /// Forwards `consume` untouched to the inner reader.
        fn consume(self: Pin<&mut Self>, amt: usize) {
            self.project().inner.consume(amt)
        }
    }
350
    /// This is a forwarding impl to support duplex structs.
    impl<R> AsyncWrite for PartialAsyncRead<R>
    where
        R: AsyncWrite,
    {
        // `PartialOp`s only affect the read path; writes pass straight through.
        #[inline]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.project().inner.poll_write(cx, buf)
        }

        #[inline]
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
            self.project().inner.poll_flush(cx)
        }

        #[inline]
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
            self.project().inner.poll_shutdown(cx)
        }
    }
375
    /// This is a forwarding impl to support duplex structs.
    impl<R> AsyncSeek for PartialAsyncRead<R>
    where
        R: AsyncSeek,
    {
        // Both halves of tokio's two-phase seek (start_seek + poll_complete)
        // are forwarded untouched; `PartialOp`s are only consumed by reads.
        #[inline]
        fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
            self.project().inner.start_seek(position)
        }

        #[inline]
        fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            self.project().inner.poll_complete(cx)
        }
    }
391
    #[cfg(test)]
    mod tests {
        use super::*;
        use itertools::Itertools;
        use std::mem::MaybeUninit;

        // with_limited is pretty complex: test that it works properly.
        //
        // Strategy: for the cartesian product of (filled, initialized) buffer
        // states and limits, run a do-nothing callback, a cursor-shrinking
        // callback, a byte-pushing callback, and an initializing callback,
        // checking that each change (or non-change) to the limited view is
        // mirrored correctly into the original buffer.
        #[test]
        fn test_with_limited() {
            const CAPACITY: usize = 256;

            let inputs = vec![
                // Columns are (filled, initialized). The capacity is always 256.

                // Fully filled, fully initialized buffer.
                (256, 256),
                // Partly filled, fully initialized buffer.
                (64, 256),
                // Unfilled, fully initialized buffer.
                (0, 256),
                // Fully filled, partly initialized buffer.
                (128, 128),
                // Partly filled, partly initialized buffer.
                (64, 128),
                // Unfilled, partly initialized buffer.
                (0, 128),
                // Unfilled, uninitialized buffer.
                (0, 0),
            ];
            // Test a series of limits for every possible case.
            let limits = vec![0, 32, 64, 128, 192, 256, 384];

            for ((filled, initialized), limit) in inputs.into_iter().cartesian_product(limits) {
                // Create an uninitialized array of `MaybeUninit` for storage. The `assume_init` is
                // safe because the type we are claiming to have initialized here is a
                // bunch of `MaybeUninit`s, which do not require initialization.
                let mut storage: [MaybeUninit<u8>; CAPACITY] =
                    unsafe { MaybeUninit::uninit().assume_init() };
                let mut buf = ReadBuf::uninit(&mut storage);
                buf.initialize_unfilled_to(initialized);
                buf.set_filled(filled);

                println!("*** limit = {}, original buf = {:?}", limit, buf);

                // ---
                // Test that making no changes to the limited buffer causes no changes to the
                // original buffer.
                // ---
                buf.with_limited(limit, |limited_buf| {
                    println!("  * do-nothing: limited buf = {:?}", limited_buf);
                    assert!(
                        limited_buf.capacity() <= limit,
                        "limit is applied to capacity"
                    );
                    assert!(
                        limited_buf.initialized().len() <= limit,
                        "limit is applied to initialized len"
                    );
                    assert!(
                        limited_buf.filled().len() <= limit,
                        "limit is applied to filled len"
                    );
                });

                assert_eq!(
                    buf.filled().len(),
                    filled,
                    "do-nothing -> filled is the same as before"
                );
                assert_eq!(
                    buf.initialized().len(),
                    initialized,
                    "do-nothing -> initialized is the same as before"
                );

                // ---
                // Test that set_filled with a smaller value is reflected in the original buffer.
                // ---
                let new_filled = buf.with_limited(limit, |limited_buf| {
                    println!("  * halve-filled: limited buf = {:?}", limited_buf);
                    let new_filled = limited_buf.filled().len() / 2;
                    limited_buf.set_filled(new_filled);
                    println!("  * halve-filled: after = {:?}", limited_buf);
                    new_filled
                });

                match new_filled.cmp(&limit) {
                    std::cmp::Ordering::Less => {
                        assert_eq!(
                            buf.filled().len(),
                            new_filled,
                            "halve-filled, new filled < limit -> filled is updated"
                        );
                    }
                    std::cmp::Ordering::Equal => {
                        // Only possible when the limited view was empty to begin with.
                        assert_eq!(limit, 0, "halve-filled, new filled == limit -> limit = 0");
                        assert_eq!(
                            buf.filled().len(),
                            filled,
                            "halve-filled, new filled == limit -> filled stays the same"
                        );
                    }
                    std::cmp::Ordering::Greater => {
                        panic!("new_filled {} must be <= limit {}", new_filled, limit);
                    }
                }

                assert_eq!(
                    buf.initialized().len(),
                    initialized,
                    "halve-filled -> initialized is same as before"
                );

                // ---
                // Test that pushing a single byte is reflected in the original buffer.
                // ---
                if filled < limit.min(CAPACITY) {
                    // Reset the ReadBuf.
                    let mut storage: [MaybeUninit<u8>; CAPACITY] =
                        unsafe { MaybeUninit::uninit().assume_init() };
                    let mut buf = ReadBuf::uninit(&mut storage);
                    buf.initialize_unfilled_to(initialized);
                    buf.set_filled(filled);

                    buf.with_limited(limit, |limited_buf| {
                        println!("  * push-one-byte: limited buf = {:?}", limited_buf);
                        limited_buf.put_slice(&[42]);
                        println!("  * push-one-byte: after = {:?}", limited_buf);
                    });

                    assert_eq!(
                        buf.filled().len(),
                        filled + 1,
                        "push-one-byte, filled incremented by 1"
                    );
                    assert_eq!(
                        buf.filled()[filled],
                        42,
                        "push-one-byte, correct byte was pushed"
                    );
                    if filled == initialized {
                        assert_eq!(
                            buf.initialized().len(),
                            initialized + 1,
                            "push-one-byte, filled == initialized -> initialized incremented by 1"
                        );
                    } else {
                        assert_eq!(
                            buf.initialized().len(),
                            initialized,
                            "push-one-byte, filled < initialized -> initialized stays the same"
                        );
                    }
                }

                // ---
                // Test that initializing unfilled bytes is reflected in the original buffer.
                // ---
                if initialized <= limit.min(CAPACITY) {
                    // Reset the ReadBuf.
                    let mut storage: [MaybeUninit<u8>; CAPACITY] =
                        unsafe { MaybeUninit::uninit().assume_init() };
                    let mut buf = ReadBuf::uninit(&mut storage);
                    buf.initialize_unfilled_to(initialized);
                    buf.set_filled(filled);

                    buf.with_limited(limit, |limited_buf| {
                        println!("  * initialize-unfilled: limited buf = {:?}", limited_buf);
                        limited_buf.initialize_unfilled();
                        println!("  * initialize-unfilled: after = {:?}", limited_buf);
                    });

                    assert_eq!(
                        buf.filled().len(),
                        filled,
                        "initialize-unfilled, filled stays the same"
                    );
                    assert_eq!(
                        buf.initialized().len(),
                        limit.min(CAPACITY),
                        "initialize-unfilled, initialized is capped at the limit"
                    );
                    // Actually access the bytes and ensure this doesn't crash.
                    assert_eq!(
                        buf.initialized(),
                        vec![0; buf.initialized().len()],
                        "initialize-unfilled, bytes are correct"
                    );
                }
            }
        }
    }
584}
585
impl<R> fmt::Debug for PartialAsyncRead<R>
where
    R: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Hand-written rather than derived so only `inner` is shown;
        // presumably `FuturesOps` does not implement `Debug` — confirm
        // against futures_util.
        f.debug_struct("PartialAsyncRead")
            .field("inner", &self.inner)
            .finish()
    }
}
596
#[cfg(test)]
mod tests {
    use super::*;

    use std::fs::File;

    use crate::tests::assert_send;

    // `PartialAsyncRead` should be `Send` whenever the wrapped reader is,
    // so it can be moved across executor threads.
    #[test]
    fn test_sendable() {
        assert_send::<PartialAsyncRead<File>>();
    }
}