partial_io/async_read.rs
1// Copyright (c) The partial-io Contributors
2// SPDX-License-Identifier: MIT
3
//! This module contains an `AsyncRead` wrapper that breaks its inputs up
//! according to a provided iterator.
//!
//! This is separate from `PartialRead` because on `WouldBlock` errors, it
//! causes `futures` to try reading again.
9
10use crate::{futures_util::FuturesOps, PartialOp};
11use futures::prelude::*;
12use pin_project::pin_project;
13use std::{
14 fmt, io,
15 pin::Pin,
16 task::{Context, Poll},
17};
18
19/// A wrapper that breaks inner `AsyncRead` instances up according to the
20/// provided iterator.
21///
22/// Available with the `futures03` feature for `futures` traits, and with the `tokio1` feature for
23/// `tokio` traits.
24///
25/// # Examples
26///
27/// This example uses `tokio`.
28///
29/// ```rust
30/// # #[cfg(feature = "tokio1")]
31/// use partial_io::{PartialAsyncRead, PartialOp};
32/// # #[cfg(feature = "tokio1")]
33/// use std::io::{self, Cursor};
34/// # #[cfg(feature = "tokio1")]
35/// use tokio::io::AsyncReadExt;
36///
37/// # #[cfg(feature = "tokio1")]
38/// #[tokio::main]
39/// async fn main() -> io::Result<()> {
40/// let reader = Cursor::new(vec![1, 2, 3, 4]);
41/// // Sequential calls to `poll_read()` and the other `poll_` methods simulate the following behavior:
42/// let iter = vec![
43/// PartialOp::Err(io::ErrorKind::WouldBlock), // A not-ready state.
44/// PartialOp::Limited(2), // Only allow 2 bytes to be read.
45/// PartialOp::Err(io::ErrorKind::InvalidData), // Error from the underlying stream.
46/// PartialOp::Unlimited, // Allow as many bytes to be read as possible.
47/// ];
48/// let mut partial_reader = PartialAsyncRead::new(reader, iter);
49/// let mut out = vec![0; 256];
50///
51/// // This causes poll_read to be called twice, yielding after the first call (WouldBlock).
52/// assert_eq!(partial_reader.read(&mut out).await?, 2, "first read with Limited(2)");
53/// assert_eq!(&out[..4], &[1, 2, 0, 0]);
54///
55/// // This next call returns an error.
56/// assert_eq!(
57/// partial_reader.read(&mut out[2..]).await.unwrap_err().kind(),
58/// io::ErrorKind::InvalidData,
59/// );
60///
61/// // And this one causes the last two bytes to be written.
62/// assert_eq!(partial_reader.read(&mut out[2..]).await?, 2, "second read with Unlimited");
63/// assert_eq!(&out[..4], &[1, 2, 3, 4]);
64///
65/// Ok(())
66/// }
67///
68/// # #[cfg(not(feature = "tokio1"))]
69/// # fn main() {
70/// # assert!(true, "dummy test");
71/// # }
72/// ```
73#[pin_project]
74pub struct PartialAsyncRead<R> {
75 #[pin]
76 inner: R,
77 ops: FuturesOps,
78}
79
80impl<R> PartialAsyncRead<R> {
81 /// Creates a new `PartialAsyncRead` wrapper over the reader with the specified `PartialOp`s.
82 pub fn new<I>(inner: R, iter: I) -> Self
83 where
84 I: IntoIterator<Item = PartialOp> + 'static,
85 I::IntoIter: Send,
86 {
87 PartialAsyncRead {
88 inner,
89 ops: FuturesOps::new(iter),
90 }
91 }
92
93 /// Sets the `PartialOp`s for this reader.
94 pub fn set_ops<I>(&mut self, iter: I) -> &mut Self
95 where
96 I: IntoIterator<Item = PartialOp> + 'static,
97 I::IntoIter: Send,
98 {
99 self.ops.replace(iter);
100 self
101 }
102
103 /// Sets the `PartialOp`s for this reader in a pinned context.
104 pub fn pin_set_ops<I>(self: Pin<&mut Self>, iter: I) -> Pin<&mut Self>
105 where
106 I: IntoIterator<Item = PartialOp> + 'static,
107 I::IntoIter: Send,
108 {
109 let mut this = self;
110 this.as_mut().project().ops.replace(iter);
111 this
112 }
113
114 /// Returns a shared reference to the underlying reader.
115 pub fn get_ref(&self) -> &R {
116 &self.inner
117 }
118
119 /// Returns a mutable reference to the underlying reader.
120 pub fn get_mut(&mut self) -> &mut R {
121 &mut self.inner
122 }
123
124 /// Returns a pinned mutable reference to the underlying reader.
125 pub fn pin_get_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
126 self.project().inner
127 }
128
129 /// Consumes this wrapper, returning the underlying reader.
130 pub fn into_inner(self) -> R {
131 self.inner
132 }
133}
134
135// ---
136// Futures impls
137// ---
138
139impl<R> AsyncRead for PartialAsyncRead<R>
140where
141 R: AsyncRead,
142{
143 #[inline]
144 fn poll_read(
145 self: Pin<&mut Self>,
146 cx: &mut Context,
147 buf: &mut [u8],
148 ) -> Poll<io::Result<usize>> {
149 let this = self.project();
150 let inner = this.inner;
151 let len = buf.len();
152
153 this.ops.poll_impl(
154 cx,
155 |cx, len| match len {
156 Some(len) => inner.poll_read(cx, &mut buf[..len]),
157 None => inner.poll_read(cx, buf),
158 },
159 len,
160 "error during poll_read, generated by partial-io",
161 )
162 }
163
164 // TODO: do we need to implement poll_read_vectored? It's a bit tricky to do.
165}
166
167impl<R> AsyncBufRead for PartialAsyncRead<R>
168where
169 R: AsyncBufRead,
170{
171 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
172 let this = self.project();
173 let inner = this.inner;
174
175 this.ops.poll_impl_no_limit(
176 cx,
177 |cx| inner.poll_fill_buf(cx),
178 "error during poll_read, generated by partial-io",
179 )
180 }
181
182 #[inline]
183 fn consume(self: Pin<&mut Self>, amt: usize) {
184 self.project().inner.consume(amt)
185 }
186}
187
188/// This is a forwarding impl to support duplex structs.
189impl<R> AsyncWrite for PartialAsyncRead<R>
190where
191 R: AsyncWrite,
192{
193 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
194 self.project().inner.poll_write(cx, buf)
195 }
196
197 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
198 self.project().inner.poll_flush(cx)
199 }
200
201 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
202 self.project().inner.poll_close(cx)
203 }
204}
205
206/// This is a forwarding impl to support duplex structs.
207impl<R> AsyncSeek for PartialAsyncRead<R>
208where
209 R: AsyncSeek,
210{
211 #[inline]
212 fn poll_seek(
213 self: Pin<&mut Self>,
214 cx: &mut Context,
215 pos: io::SeekFrom,
216 ) -> Poll<io::Result<u64>> {
217 self.project().inner.poll_seek(cx, pos)
218 }
219}
220
221// ---
222// Tokio impls
223// ---
224
#[cfg(feature = "tokio1")]
pub(crate) mod tokio_impl {
    //! `tokio` trait implementations for `PartialAsyncRead`, plus the `ReadBufExt` helper
    //! used to hand a size-limited view of a `ReadBuf` to the wrapped reader.

    use super::PartialAsyncRead;
    use std::{
        io::{self, SeekFrom},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};

    impl<R> AsyncRead for PartialAsyncRead<R>
    where
        R: AsyncRead,
    {
        /// Polls the wrapped reader, letting the next `PartialOp` decide how much of `buf`
        /// (if any) the wrapped reader gets to see.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let this = self.project();
            let inner = this.inner;
            let capacity = buf.capacity();

            // NOTE(review): `with_limited` applies the limit to the buffer's *total* capacity,
            // so bytes the caller has already filled count against it. Confirm that this is
            // the intended meaning of a limited operation for partially filled buffers.
            this.ops.poll_impl(
                cx,
                |cx, len| match len {
                    Some(len) => {
                        buf.with_limited(len, |limited_buf| inner.poll_read(cx, limited_buf))
                    }
                    None => inner.poll_read(cx, buf),
                },
                capacity,
                "error during poll_read, generated by partial-io",
            )
        }
    }

    /// Extensions to `tokio`'s `ReadBuf`.
    ///
    /// Requires the `tokio1` feature to be enabled.
    pub trait ReadBufExt {
        /// Convert this `ReadBuf` into a limited one backed by the same storage, then
        /// call the callback with this limited instance.
        ///
        /// `limit` is an upper bound on the limited buffer's capacity, initialized length and
        /// filled length.
        ///
        /// Any changes to the `ReadBuf` made by the callback are reflected in the original
        /// `ReadBuf`. The callback's return value is returned unchanged.
        fn with_limited<F, T>(&mut self, limit: usize, callback: F) -> T
        where
            F: FnOnce(&mut ReadBuf<'_>) -> T;
    }

    impl<'a> ReadBufExt for ReadBuf<'a> {
        fn with_limited<F, T>(&mut self, limit: usize, callback: F) -> T
        where
            F: FnOnce(&mut ReadBuf<'_>) -> T,
        {
            // Use limit to set upper limits on the capacity and both cursors.
            let capacity_limit = self.capacity().min(limit);
            let old_initialized_len = self.initialized().len().min(limit);
            let old_filled_len = self.filled().len().min(limit);

            // SAFETY: We assume that the input buf's initialized length is trustworthy.
            let mut limited_buf = unsafe {
                let inner_mut = &mut self.inner_mut()[..capacity_limit];
                let mut limited_buf = ReadBuf::uninit(inner_mut);
                // Note: assume_init adds the passed-in value to self.filled, but for a freshly created
                // uninitialized buffer, self.filled is 0. The value of filled is updated below
                // with the set_filled() call.
                limited_buf.assume_init(old_initialized_len);
                limited_buf
            };
            limited_buf.set_filled(old_filled_len);

            // Call the callback.
            let ret = callback(&mut limited_buf);

            // The callback may have modified the cursors in `limited_buf` -- if so, port them back to
            // the original.
            let new_initialized_len = limited_buf.initialized().len();
            let new_filled_len = limited_buf.filled().len();

            if new_initialized_len > old_initialized_len {
                // SAFETY: We assume that if new_initialized_len > old_initialized_len, that
                // the extra bytes were initialized by the callback.
                unsafe {
                    // Note: assume_init adds the passed-in value to buf.filled.len().
                    self.assume_init(new_initialized_len - self.filled().len());
                }
            }

            if new_filled_len != old_filled_len {
                // This can happen if either:
                // * old_filled_len < limit, and the callback filled some more bytes into buf ->
                //   reflect that in the original buffer.
                // * old_filled_len <= limit, and the callback *shortened* the filled bytes -> reflect
                //   that in the original buffer as well.
                //
                // (Note if old_filled_len == limit, then new_filled_len cannot be greater than
                // old_filled_len since it's at the limit already.)
                self.set_filled(new_filled_len);
            }

            ret
        }
    }

    impl<R> AsyncBufRead for PartialAsyncRead<R>
    where
        R: AsyncBufRead,
    {
        /// Polls the wrapped reader's internal buffer, routed through the configured
        /// `PartialOp`s via `poll_impl_no_limit`.
        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
            let this = self.project();
            let inner = this.inner;

            this.ops.poll_impl_no_limit(
                cx,
                |cx| inner.poll_fill_buf(cx),
                "error during poll_fill_buf, generated by partial-io",
            )
        }

        /// Forwards directly to the wrapped reader; the `PartialOp`s are not consulted.
        fn consume(self: Pin<&mut Self>, amt: usize) {
            self.project().inner.consume(amt)
        }
    }

    /// This is a forwarding impl to support duplex structs.
    impl<R> AsyncWrite for PartialAsyncRead<R>
    where
        R: AsyncWrite,
    {
        #[inline]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.project().inner.poll_write(cx, buf)
        }

        #[inline]
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
            self.project().inner.poll_flush(cx)
        }

        #[inline]
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
            self.project().inner.poll_shutdown(cx)
        }
    }

    /// This is a forwarding impl to support duplex structs.
    impl<R> AsyncSeek for PartialAsyncRead<R>
    where
        R: AsyncSeek,
    {
        #[inline]
        fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
            self.project().inner.start_seek(position)
        }

        #[inline]
        fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            self.project().inner.poll_complete(cx)
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use itertools::Itertools;
        use std::mem::MaybeUninit;

        // with_limited is pretty complex: test that it works properly.
        #[test]
        fn test_with_limited() {
            const CAPACITY: usize = 256;

            let inputs = vec![
                // Columns are (filled, initialized). The capacity is always 256.

                // Fully filled, fully initialized buffer.
                (256, 256),
                // Partly filled, fully initialized buffer.
                (64, 256),
                // Unfilled, fully initialized buffer.
                (0, 256),
                // Fully filled, partly initialized buffer.
                (128, 128),
                // Partly filled, partly initialized buffer.
                (64, 128),
                // Unfilled, partly initialized buffer.
                (0, 128),
                // Unfilled, uninitialized buffer.
                (0, 0),
            ];
            // Test a series of limits for every possible case.
            let limits = vec![0, 32, 64, 128, 192, 256, 384];

            for ((filled, initialized), limit) in inputs.into_iter().cartesian_product(limits) {
                // Create an array of uninitialized `MaybeUninit<u8>` for storage. No `unsafe`
                // is needed: `MaybeUninit<u8>` is `Copy`, so a repeat expression works.
                let mut storage = [MaybeUninit::<u8>::uninit(); CAPACITY];
                let mut buf = ReadBuf::uninit(&mut storage);
                buf.initialize_unfilled_to(initialized);
                buf.set_filled(filled);

                println!("*** limit = {}, original buf = {:?}", limit, buf);

                // ---
                // Test that making no changes to the limited buffer causes no changes to the
                // original buffer.
                // ---
                buf.with_limited(limit, |limited_buf| {
                    println!("  * do-nothing: limited buf = {:?}", limited_buf);
                    assert!(
                        limited_buf.capacity() <= limit,
                        "limit is applied to capacity"
                    );
                    assert!(
                        limited_buf.initialized().len() <= limit,
                        "limit is applied to initialized len"
                    );
                    assert!(
                        limited_buf.filled().len() <= limit,
                        "limit is applied to filled len"
                    );
                });

                assert_eq!(
                    buf.filled().len(),
                    filled,
                    "do-nothing -> filled is the same as before"
                );
                assert_eq!(
                    buf.initialized().len(),
                    initialized,
                    "do-nothing -> initialized is the same as before"
                );

                // ---
                // Test that set_filled with a smaller value is reflected in the original buffer.
                // ---
                let new_filled = buf.with_limited(limit, |limited_buf| {
                    println!("  * halve-filled: limited buf = {:?}", limited_buf);
                    let new_filled = limited_buf.filled().len() / 2;
                    limited_buf.set_filled(new_filled);
                    println!("  * halve-filled: after = {:?}", limited_buf);
                    new_filled
                });

                match new_filled.cmp(&limit) {
                    std::cmp::Ordering::Less => {
                        assert_eq!(
                            buf.filled().len(),
                            new_filled,
                            "halve-filled, new filled < limit -> filled is updated"
                        );
                    }
                    std::cmp::Ordering::Equal => {
                        assert_eq!(limit, 0, "halve-filled, new filled == limit -> limit = 0");
                        assert_eq!(
                            buf.filled().len(),
                            filled,
                            "halve-filled, new filled == limit -> filled stays the same"
                        );
                    }
                    std::cmp::Ordering::Greater => {
                        panic!("new_filled {} must be <= limit {}", new_filled, limit);
                    }
                }

                assert_eq!(
                    buf.initialized().len(),
                    initialized,
                    "halve-filled -> initialized is same as before"
                );

                // ---
                // Test that pushing a single byte is reflected in the original buffer.
                // ---
                if filled < limit.min(CAPACITY) {
                    // Reset the ReadBuf (the halve-filled step above modified it).
                    let mut storage = [MaybeUninit::<u8>::uninit(); CAPACITY];
                    let mut buf = ReadBuf::uninit(&mut storage);
                    buf.initialize_unfilled_to(initialized);
                    buf.set_filled(filled);

                    buf.with_limited(limit, |limited_buf| {
                        println!("  * push-one-byte: limited buf = {:?}", limited_buf);
                        limited_buf.put_slice(&[42]);
                        println!("  * push-one-byte: after = {:?}", limited_buf);
                    });

                    assert_eq!(
                        buf.filled().len(),
                        filled + 1,
                        "push-one-byte, filled incremented by 1"
                    );
                    assert_eq!(
                        buf.filled()[filled],
                        42,
                        "push-one-byte, correct byte was pushed"
                    );
                    if filled == initialized {
                        assert_eq!(
                            buf.initialized().len(),
                            initialized + 1,
                            "push-one-byte, filled == initialized -> initialized incremented by 1"
                        );
                    } else {
                        assert_eq!(
                            buf.initialized().len(),
                            initialized,
                            "push-one-byte, filled < initialized -> initialized stays the same"
                        );
                    }
                }

                // ---
                // Test that initializing unfilled bytes is reflected in the original buffer.
                // ---
                if initialized <= limit.min(CAPACITY) {
                    // Reset the ReadBuf.
                    let mut storage = [MaybeUninit::<u8>::uninit(); CAPACITY];
                    let mut buf = ReadBuf::uninit(&mut storage);
                    buf.initialize_unfilled_to(initialized);
                    buf.set_filled(filled);

                    buf.with_limited(limit, |limited_buf| {
                        println!("  * initialize-unfilled: limited buf = {:?}", limited_buf);
                        limited_buf.initialize_unfilled();
                        println!("  * initialize-unfilled: after = {:?}", limited_buf);
                    });

                    assert_eq!(
                        buf.filled().len(),
                        filled,
                        "initialize-unfilled, filled stays the same"
                    );
                    assert_eq!(
                        buf.initialized().len(),
                        limit.min(CAPACITY),
                        "initialize-unfilled, initialized is capped at the limit"
                    );
                    // Actually access the bytes and ensure this doesn't crash.
                    assert_eq!(
                        buf.initialized(),
                        vec![0; buf.initialized().len()],
                        "initialize-unfilled, bytes are correct"
                    );
                }
            }
        }
    }
}
585
586impl<R> fmt::Debug for PartialAsyncRead<R>
587where
588 R: fmt::Debug,
589{
590 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
591 f.debug_struct("PartialAsyncRead")
592 .field("inner", &self.inner)
593 .finish()
594 }
595}
596
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::assert_send;
    use std::fs::File;

    /// Compile-time style check: `PartialAsyncRead<File>` must satisfy `assert_send`.
    #[test]
    fn test_sendable() {
        assert_send::<PartialAsyncRead<File>>();
    }
}