tokio_util/
compat.rs

1//! Compatibility between the `tokio::io` and `futures-io` versions of the
2//! `AsyncRead` and `AsyncWrite` traits.
3use pin_project_lite::pin_project;
4use std::io;
5use std::pin::Pin;
6use std::task::{ready, Context, Poll};
7
pin_project! {
    /// A compatibility layer that allows conversion between the
    /// `tokio::io` and `futures-io` `AsyncRead` and `AsyncWrite` traits.
    ///
    /// The wrapper owns a single I/O object and forwards trait calls to it,
    /// translating between the two trait families where their signatures
    /// differ.
    #[derive(Copy, Clone, Debug)]
    pub struct Compat<T> {
        // The wrapped I/O object. It is structurally pinned so that the
        // trait impls can hand a `Pin<&mut T>` to the inner poll methods.
        #[pin]
        inner: T,
        // Target of a seek that has been requested but has not yet completed;
        // `None` when no seek is pending. Only the `AsyncSeek` impls in this
        // file read or write it.
        seek_pos: Option<io::SeekFrom>,
    }
}
18
19/// Extension trait that allows converting a type implementing
20/// `futures_io::AsyncRead` to implement `tokio::io::AsyncRead`.
21pub trait FuturesAsyncReadCompatExt: futures_io::AsyncRead {
22    /// Wraps `self` with a compatibility layer that implements
23    /// `tokio_io::AsyncRead`.
24    fn compat(self) -> Compat<Self>
25    where
26        Self: Sized,
27    {
28        Compat::new(self)
29    }
30}
31
32impl<T: futures_io::AsyncRead> FuturesAsyncReadCompatExt for T {}
33
34/// Extension trait that allows converting a type implementing
35/// `futures_io::AsyncWrite` to implement `tokio::io::AsyncWrite`.
36pub trait FuturesAsyncWriteCompatExt: futures_io::AsyncWrite {
37    /// Wraps `self` with a compatibility layer that implements
38    /// `tokio::io::AsyncWrite`.
39    fn compat_write(self) -> Compat<Self>
40    where
41        Self: Sized,
42    {
43        Compat::new(self)
44    }
45}
46
47impl<T: futures_io::AsyncWrite> FuturesAsyncWriteCompatExt for T {}
48
49/// Extension trait that allows converting a type implementing
50/// `tokio::io::AsyncRead` to implement `futures_io::AsyncRead`.
51pub trait TokioAsyncReadCompatExt: tokio::io::AsyncRead {
52    /// Wraps `self` with a compatibility layer that implements
53    /// `futures_io::AsyncRead`.
54    fn compat(self) -> Compat<Self>
55    where
56        Self: Sized,
57    {
58        Compat::new(self)
59    }
60}
61
62impl<T: tokio::io::AsyncRead> TokioAsyncReadCompatExt for T {}
63
64/// Extension trait that allows converting a type implementing
65/// `tokio::io::AsyncWrite` to implement `futures_io::AsyncWrite`.
66pub trait TokioAsyncWriteCompatExt: tokio::io::AsyncWrite {
67    /// Wraps `self` with a compatibility layer that implements
68    /// `futures_io::AsyncWrite`.
69    fn compat_write(self) -> Compat<Self>
70    where
71        Self: Sized,
72    {
73        Compat::new(self)
74    }
75}
76
77impl<T: tokio::io::AsyncWrite> TokioAsyncWriteCompatExt for T {}
78
79// === impl Compat ===
80
81impl<T> Compat<T> {
82    fn new(inner: T) -> Self {
83        Self {
84            inner,
85            seek_pos: None,
86        }
87    }
88
89    /// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
90    /// contained within.
91    pub fn get_ref(&self) -> &T {
92        &self.inner
93    }
94
95    /// Get a mutable reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
96    /// contained within.
97    pub fn get_mut(&mut self) -> &mut T {
98        &mut self.inner
99    }
100
101    /// Returns the wrapped item.
102    pub fn into_inner(self) -> T {
103        self.inner
104    }
105}
106
107impl<T> tokio::io::AsyncRead for Compat<T>
108where
109    T: futures_io::AsyncRead,
110{
111    fn poll_read(
112        self: Pin<&mut Self>,
113        cx: &mut Context<'_>,
114        buf: &mut tokio::io::ReadBuf<'_>,
115    ) -> Poll<io::Result<()>> {
116        // We can't trust the inner type to not peak at the bytes,
117        // so we must defensively initialize the buffer.
118        let slice = buf.initialize_unfilled();
119        let n = ready!(futures_io::AsyncRead::poll_read(
120            self.project().inner,
121            cx,
122            slice
123        ))?;
124        buf.advance(n);
125        Poll::Ready(Ok(()))
126    }
127}
128
129impl<T> futures_io::AsyncRead for Compat<T>
130where
131    T: tokio::io::AsyncRead,
132{
133    fn poll_read(
134        self: Pin<&mut Self>,
135        cx: &mut Context<'_>,
136        slice: &mut [u8],
137    ) -> Poll<io::Result<usize>> {
138        let mut buf = tokio::io::ReadBuf::new(slice);
139        ready!(tokio::io::AsyncRead::poll_read(
140            self.project().inner,
141            cx,
142            &mut buf
143        ))?;
144        Poll::Ready(Ok(buf.filled().len()))
145    }
146}
147
148impl<T> tokio::io::AsyncBufRead for Compat<T>
149where
150    T: futures_io::AsyncBufRead,
151{
152    fn poll_fill_buf<'a>(
153        self: Pin<&'a mut Self>,
154        cx: &mut Context<'_>,
155    ) -> Poll<io::Result<&'a [u8]>> {
156        futures_io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
157    }
158
159    fn consume(self: Pin<&mut Self>, amt: usize) {
160        futures_io::AsyncBufRead::consume(self.project().inner, amt)
161    }
162}
163
164impl<T> futures_io::AsyncBufRead for Compat<T>
165where
166    T: tokio::io::AsyncBufRead,
167{
168    fn poll_fill_buf<'a>(
169        self: Pin<&'a mut Self>,
170        cx: &mut Context<'_>,
171    ) -> Poll<io::Result<&'a [u8]>> {
172        tokio::io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
173    }
174
175    fn consume(self: Pin<&mut Self>, amt: usize) {
176        tokio::io::AsyncBufRead::consume(self.project().inner, amt)
177    }
178}
179
180impl<T> tokio::io::AsyncWrite for Compat<T>
181where
182    T: futures_io::AsyncWrite,
183{
184    fn poll_write(
185        self: Pin<&mut Self>,
186        cx: &mut Context<'_>,
187        buf: &[u8],
188    ) -> Poll<io::Result<usize>> {
189        futures_io::AsyncWrite::poll_write(self.project().inner, cx, buf)
190    }
191
192    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
193        futures_io::AsyncWrite::poll_flush(self.project().inner, cx)
194    }
195
196    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
197        futures_io::AsyncWrite::poll_close(self.project().inner, cx)
198    }
199}
200
201impl<T> futures_io::AsyncWrite for Compat<T>
202where
203    T: tokio::io::AsyncWrite,
204{
205    fn poll_write(
206        self: Pin<&mut Self>,
207        cx: &mut Context<'_>,
208        buf: &[u8],
209    ) -> Poll<io::Result<usize>> {
210        tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
211    }
212
213    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
214        tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
215    }
216
217    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
218        tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
219    }
220}
221
222impl<T: tokio::io::AsyncSeek> futures_io::AsyncSeek for Compat<T> {
223    fn poll_seek(
224        mut self: Pin<&mut Self>,
225        cx: &mut Context<'_>,
226        pos: io::SeekFrom,
227    ) -> Poll<io::Result<u64>> {
228        if self.seek_pos != Some(pos) {
229            // Ensure previous seeks have finished before starting a new one
230            ready!(self.as_mut().project().inner.poll_complete(cx))?;
231            self.as_mut().project().inner.start_seek(pos)?;
232            *self.as_mut().project().seek_pos = Some(pos);
233        }
234        let res = ready!(self.as_mut().project().inner.poll_complete(cx));
235        *self.as_mut().project().seek_pos = None;
236        Poll::Ready(res)
237    }
238}
239
240impl<T: futures_io::AsyncSeek> tokio::io::AsyncSeek for Compat<T> {
241    fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> {
242        *self.as_mut().project().seek_pos = Some(pos);
243        Ok(())
244    }
245
246    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
247        let pos = match self.seek_pos {
248            None => {
249                // tokio 1.x AsyncSeek recommends calling poll_complete before start_seek.
250                // We don't have to guarantee that the value returned by
251                // poll_complete called without start_seek is correct,
252                // so we'll return 0.
253                return Poll::Ready(Ok(0));
254            }
255            Some(pos) => pos,
256        };
257        let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos));
258        *self.as_mut().project().seek_pos = None;
259        Poll::Ready(res)
260    }
261}
262
263#[cfg(unix)]
264impl<T: std::os::unix::io::AsRawFd> std::os::unix::io::AsRawFd for Compat<T> {
265    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
266        self.inner.as_raw_fd()
267    }
268}
269
#[cfg(windows)]
impl<T> std::os::windows::io::AsRawHandle for Compat<T>
where
    T: std::os::windows::io::AsRawHandle,
{
    /// Returns the raw handle reported by the wrapped object.
    fn as_raw_handle(&self) -> std::os::windows::io::RawHandle {
        std::os::windows::io::AsRawHandle::as_raw_handle(&self.inner)
    }
}