partial_io/
async_write.rs

1// Copyright (c) The partial-io Contributors
2// SPDX-License-Identifier: MIT
3
4//! This module contains an `AsyncWrite` wrapper that breaks writes up
5//! according to a provided iterator.
6//!
7//! This is separate from `PartialWrite` because on `WouldBlock` errors, it
8//! causes `futures` to try writing or flushing again.
9
10use crate::{futures_util::FuturesOps, PartialOp};
11use futures::{io, prelude::*};
12use pin_project::pin_project;
13use std::{
14    fmt,
15    pin::Pin,
16    task::{Context, Poll},
17};
18
19/// A wrapper that breaks inner `AsyncWrite` instances up according to the
20/// provided iterator.
21///
22/// Available with the `futures03` feature for `futures` traits, and with the `tokio1` feature for
23/// `tokio` traits.
24///
25/// # Examples
26///
27/// This example uses `tokio`.
28///
29/// ```rust
30/// # #[cfg(feature = "tokio1")]
31/// use partial_io::{PartialAsyncWrite, PartialOp};
32/// # #[cfg(feature = "tokio1")]
33/// use std::io::{self, Cursor};
34/// # #[cfg(feature = "tokio1")]
35/// use tokio::io::AsyncWriteExt;
36///
37/// # #[cfg(feature = "tokio1")]
38/// #[tokio::main]
39/// async fn main() -> io::Result<()> {
40///     let writer = Cursor::new(Vec::new());
41///     // Sequential calls to `poll_write()` and the other `poll_` methods simulate the following behavior:
42///     let iter = vec![
43///         PartialOp::Err(io::ErrorKind::WouldBlock),   // A not-ready state.
44///         PartialOp::Limited(2),                       // Only allow 2 bytes to be written.
45///         PartialOp::Err(io::ErrorKind::InvalidData),  // Error from the underlying stream.
46///         PartialOp::Unlimited,                        // Allow as many bytes to be written as possible.
47///     ];
48///     let mut partial_writer = PartialAsyncWrite::new(writer, iter);
49///     let in_data = vec![1, 2, 3, 4];
50///
51///     // This causes poll_write to be called twice, yielding after the first call (WouldBlock).
52///     assert_eq!(partial_writer.write(&in_data).await?, 2);
53///     let cursor_ref = partial_writer.get_ref();
54///     let out = cursor_ref.get_ref();
55///     assert_eq!(&out[..], &[1, 2]);
56///
57///     // This next call returns an error.
58///     assert_eq!(
59///         partial_writer.write(&in_data[2..]).await.unwrap_err().kind(),
60///         io::ErrorKind::InvalidData,
61///     );
62///
63///     // And this one causes the last two bytes to be written.
64///     assert_eq!(partial_writer.write(&in_data[2..]).await?, 2);
65///     let cursor_ref = partial_writer.get_ref();
66///     let out = cursor_ref.get_ref();
67///     assert_eq!(&out[..], &[1, 2, 3, 4]);
68///
69///     Ok(())
70/// }
71///
72/// # #[cfg(not(feature = "tokio1"))]
73/// # fn main() {
74/// #     assert!(true, "dummy test");
75/// # }
76/// ```
77#[pin_project]
78pub struct PartialAsyncWrite<W> {
79    #[pin]
80    inner: W,
81    ops: FuturesOps,
82}
83
84impl<W> PartialAsyncWrite<W> {
85    /// Creates a new `PartialAsyncWrite` wrapper over the writer with the specified `PartialOp`s.
86    pub fn new<I>(inner: W, iter: I) -> Self
87    where
88        I: IntoIterator<Item = PartialOp> + 'static,
89        I::IntoIter: Send,
90    {
91        PartialAsyncWrite {
92            inner,
93            ops: FuturesOps::new(iter),
94        }
95    }
96
97    /// Sets the `PartialOp`s for this writer.
98    pub fn set_ops<I>(&mut self, iter: I) -> &mut Self
99    where
100        I: IntoIterator<Item = PartialOp> + 'static,
101        I::IntoIter: Send,
102    {
103        self.ops.replace(iter);
104        self
105    }
106
107    /// Sets the `PartialOp`s for this writer in a pinned context.
108    pub fn pin_set_ops<I>(self: Pin<&mut Self>, iter: I) -> Pin<&mut Self>
109    where
110        I: IntoIterator<Item = PartialOp> + 'static,
111        I::IntoIter: Send,
112    {
113        let mut this = self;
114        this.as_mut().project().ops.replace(iter);
115        this
116    }
117
118    /// Returns a shared reference to the underlying writer.
119    pub fn get_ref(&self) -> &W {
120        &self.inner
121    }
122
123    /// Returns a mutable reference to the underlying writer.
124    pub fn get_mut(&mut self) -> &mut W {
125        &mut self.inner
126    }
127
128    /// Returns a pinned mutable reference to the underlying writer.
129    pub fn pin_get_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
130        self.project().inner
131    }
132
133    /// Consumes this wrapper, returning the underlying writer.
134    pub fn into_inner(self) -> W {
135        self.inner
136    }
137}
138
139// ---
140// Futures impls
141// ---
142
143impl<W> AsyncWrite for PartialAsyncWrite<W>
144where
145    W: AsyncWrite,
146{
147    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
148        let this = self.project();
149        let inner = this.inner;
150
151        this.ops.poll_impl(
152            cx,
153            |cx, len| match len {
154                Some(len) => inner.poll_write(cx, &buf[..len]),
155                None => inner.poll_write(cx, buf),
156            },
157            buf.len(),
158            "error during poll_write, generated by partial-io",
159        )
160    }
161
162    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
163        let this = self.project();
164        let inner = this.inner;
165
166        this.ops.poll_impl_no_limit(
167            cx,
168            |cx| inner.poll_flush(cx),
169            "error during poll_flush, generated by partial-io",
170        )
171    }
172
173    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
174        let this = self.project();
175        let inner = this.inner;
176
177        this.ops.poll_impl_no_limit(
178            cx,
179            |cx| inner.poll_close(cx),
180            "error during poll_close, generated by partial-io",
181        )
182    }
183}
184
185/// This is a forwarding impl to support duplex structs.
186impl<W> AsyncRead for PartialAsyncWrite<W>
187where
188    W: AsyncRead,
189{
190    #[inline]
191    fn poll_read(
192        self: Pin<&mut Self>,
193        cx: &mut Context,
194        buf: &mut [u8],
195    ) -> Poll<io::Result<usize>> {
196        self.project().inner.poll_read(cx, buf)
197    }
198
199    #[inline]
200    fn poll_read_vectored(
201        self: Pin<&mut Self>,
202        cx: &mut Context,
203        bufs: &mut [io::IoSliceMut],
204    ) -> Poll<io::Result<usize>> {
205        self.project().inner.poll_read_vectored(cx, bufs)
206    }
207}
208
209/// This is a forwarding impl to support duplex structs.
210impl<W> AsyncBufRead for PartialAsyncWrite<W>
211where
212    W: AsyncBufRead,
213{
214    #[inline]
215    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
216        self.project().inner.poll_fill_buf(cx)
217    }
218
219    #[inline]
220    fn consume(self: Pin<&mut Self>, amt: usize) {
221        self.project().inner.consume(amt)
222    }
223}
224
225/// This is a forwarding impl to support duplex structs.
226impl<W> AsyncSeek for PartialAsyncWrite<W>
227where
228    W: AsyncSeek,
229{
230    #[inline]
231    fn poll_seek(
232        self: Pin<&mut Self>,
233        cx: &mut Context,
234        pos: io::SeekFrom,
235    ) -> Poll<io::Result<u64>> {
236        self.project().inner.poll_seek(cx, pos)
237    }
238}
239
240// ---
241// Tokio impls
242// ---
243
#[cfg(feature = "tokio1")]
mod tokio_impl {
    // `tokio` trait impls for `PartialAsyncWrite`, mirroring the `futures` ones
    // in the parent module.
    use super::PartialAsyncWrite;
    use std::{
        io::{self, SeekFrom},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};

    /// Write-side `tokio` impl: every call first consults the shared
    /// `PartialOp` schedule before (possibly) reaching the wrapped writer.
    impl<W> AsyncWrite for PartialAsyncWrite<W>
    where
        W: AsyncWrite,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let projected = self.project();
            let writer = projected.inner;
            let total = buf.len();

            projected.ops.poll_impl(
                cx,
                |cx, limit| {
                    // `Some(n)` caps the write at the first `n` bytes; `None`
                    // offers the entire buffer.
                    let chunk = match limit {
                        Some(n) => &buf[..n],
                        None => buf,
                    };
                    writer.poll_write(cx, chunk)
                },
                total,
                "error during poll_write, generated by partial-io",
            )
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
            let projected = self.project();
            let writer = projected.inner;

            projected.ops.poll_impl_no_limit(
                cx,
                |cx| AsyncWrite::poll_flush(writer, cx),
                "error during poll_flush, generated by partial-io",
            )
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
            let projected = self.project();
            let writer = projected.inner;

            projected.ops.poll_impl_no_limit(
                cx,
                |cx| AsyncWrite::poll_shutdown(writer, cx),
                "error during poll_shutdown, generated by partial-io",
            )
        }
    }

    /// Forwarding impl so a wrapped duplex stream keeps its read side; reads
    /// never consult the `PartialOp` schedule.
    impl<W> AsyncRead for PartialAsyncWrite<W>
    where
        W: AsyncRead,
    {
        #[inline]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            AsyncRead::poll_read(self.project().inner, cx, buf)
        }
    }

    /// Forwarding impl so a wrapped duplex stream keeps its buffered-read side.
    impl<W> AsyncBufRead for PartialAsyncWrite<W>
    where
        W: AsyncBufRead,
    {
        #[inline]
        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
            AsyncBufRead::poll_fill_buf(self.project().inner, cx)
        }

        #[inline]
        fn consume(self: Pin<&mut Self>, amt: usize) {
            AsyncBufRead::consume(self.project().inner, amt)
        }
    }

    /// Forwarding impl so a wrapped duplex stream stays seekable.
    impl<W> AsyncSeek for PartialAsyncWrite<W>
    where
        W: AsyncSeek,
    {
        #[inline]
        fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
            AsyncSeek::start_seek(self.project().inner, position)
        }

        #[inline]
        fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            AsyncSeek::poll_complete(self.project().inner, cx)
        }
    }
}
347
348impl<W> fmt::Debug for PartialAsyncWrite<W>
349where
350    W: fmt::Debug,
351{
352    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
353        f.debug_struct("PartialAsyncWrite")
354            .field("inner", &self.inner)
355            .finish()
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::PartialAsyncWrite;
    use crate::tests::assert_send;
    use std::fs::File;

    /// Wrapping a `Send` writer (here `File`) must yield a `Send` wrapper.
    #[test]
    fn test_sendable() {
        assert_send::<PartialAsyncWrite<File>>();
    }
}