// tokio_util/io/inspect.rs

1use pin_project_lite::pin_project;
2use std::io::{IoSlice, Result};
3use std::pin::Pin;
4use std::task::{ready, Context, Poll};
5
6use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
7
8pin_project! {
9    /// An adapter that lets you inspect the data that's being read.
10    ///
11    /// This is useful for things like hashing data as it's read in.
12    pub struct InspectReader<R, F> {
13        #[pin]
14        reader: R,
15        f: F,
16    }
17}
18
19impl<R, F> InspectReader<R, F> {
20    /// Create a new `InspectReader`, wrapping `reader` and calling `f` for the
21    /// new data supplied by each read call.
22    ///
23    /// The closure will only be called with an empty slice if the inner reader
24    /// returns without reading data into the buffer. This happens at EOF, or if
25    /// `poll_read` is called with a zero-size buffer.
26    pub fn new(reader: R, f: F) -> InspectReader<R, F>
27    where
28        R: AsyncRead,
29        F: FnMut(&[u8]),
30    {
31        InspectReader { reader, f }
32    }
33
34    /// Consumes the `InspectReader`, returning the wrapped reader
35    pub fn into_inner(self) -> R {
36        self.reader
37    }
38}
39
40impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> {
41    fn poll_read(
42        self: Pin<&mut Self>,
43        cx: &mut Context<'_>,
44        buf: &mut ReadBuf<'_>,
45    ) -> Poll<Result<()>> {
46        let me = self.project();
47        let filled_length = buf.filled().len();
48        ready!(me.reader.poll_read(cx, buf))?;
49        (me.f)(&buf.filled()[filled_length..]);
50        Poll::Ready(Ok(()))
51    }
52}
53
54impl<R: AsyncWrite, F> AsyncWrite for InspectReader<R, F> {
55    fn poll_write(
56        self: Pin<&mut Self>,
57        cx: &mut Context<'_>,
58        buf: &[u8],
59    ) -> Poll<std::result::Result<usize, std::io::Error>> {
60        self.project().reader.poll_write(cx, buf)
61    }
62
63    fn poll_flush(
64        self: Pin<&mut Self>,
65        cx: &mut Context<'_>,
66    ) -> Poll<std::result::Result<(), std::io::Error>> {
67        self.project().reader.poll_flush(cx)
68    }
69
70    fn poll_shutdown(
71        self: Pin<&mut Self>,
72        cx: &mut Context<'_>,
73    ) -> Poll<std::result::Result<(), std::io::Error>> {
74        self.project().reader.poll_shutdown(cx)
75    }
76
77    fn poll_write_vectored(
78        self: Pin<&mut Self>,
79        cx: &mut Context<'_>,
80        bufs: &[IoSlice<'_>],
81    ) -> Poll<Result<usize>> {
82        self.project().reader.poll_write_vectored(cx, bufs)
83    }
84
85    fn is_write_vectored(&self) -> bool {
86        self.reader.is_write_vectored()
87    }
88}
89
90pin_project! {
91    /// An adapter that lets you inspect the data that's being written.
92    ///
93    /// This is useful for things like hashing data as it's written out.
94    pub struct InspectWriter<W, F> {
95        #[pin]
96        writer: W,
97        f: F,
98    }
99}
100
101impl<W, F> InspectWriter<W, F> {
102    /// Create a new `InspectWriter`, wrapping `write` and calling `f` for the
103    /// data successfully written by each write call.
104    ///
105    /// The closure `f` will never be called with an empty slice. A vectored
106    /// write can result in multiple calls to `f` - at most one call to `f` per
107    /// buffer supplied to `poll_write_vectored`.
108    pub fn new(writer: W, f: F) -> InspectWriter<W, F>
109    where
110        W: AsyncWrite,
111        F: FnMut(&[u8]),
112    {
113        InspectWriter { writer, f }
114    }
115
116    /// Consumes the `InspectWriter`, returning the wrapped writer
117    pub fn into_inner(self) -> W {
118        self.writer
119    }
120}
121
122impl<W: AsyncWrite, F: FnMut(&[u8])> AsyncWrite for InspectWriter<W, F> {
123    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
124        let me = self.project();
125        let res = me.writer.poll_write(cx, buf);
126        if let Poll::Ready(Ok(count)) = res {
127            if count != 0 {
128                (me.f)(&buf[..count]);
129            }
130        }
131        res
132    }
133
134    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
135        let me = self.project();
136        me.writer.poll_flush(cx)
137    }
138
139    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
140        let me = self.project();
141        me.writer.poll_shutdown(cx)
142    }
143
144    fn poll_write_vectored(
145        self: Pin<&mut Self>,
146        cx: &mut Context<'_>,
147        bufs: &[IoSlice<'_>],
148    ) -> Poll<Result<usize>> {
149        let me = self.project();
150        let res = me.writer.poll_write_vectored(cx, bufs);
151        if let Poll::Ready(Ok(mut count)) = res {
152            for buf in bufs {
153                if count == 0 {
154                    break;
155                }
156                let size = count.min(buf.len());
157                if size != 0 {
158                    (me.f)(&buf[..size]);
159                    count -= size;
160                }
161            }
162        }
163        res
164    }
165
166    fn is_write_vectored(&self) -> bool {
167        self.writer.is_write_vectored()
168    }
169}
170
171impl<W: AsyncRead, F> AsyncRead for InspectWriter<W, F> {
172    fn poll_read(
173        self: Pin<&mut Self>,
174        cx: &mut Context<'_>,
175        buf: &mut ReadBuf<'_>,
176    ) -> Poll<std::io::Result<()>> {
177        self.project().writer.poll_read(cx, buf)
178    }
179}