madsim_real_tokio/io/util/
mem.rs

1//! In-process memory IO types.
2
3use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
4use crate::loom::sync::Mutex;
5
6use bytes::{Buf, BytesMut};
7use std::{
8    pin::Pin,
9    sync::Arc,
10    task::{self, Poll, Waker},
11};
12
13/// A bidirectional pipe to read and write bytes in memory.
14///
15/// A pair of `DuplexStream`s are created together, and they act as a "channel"
16/// that can be used as in-memory IO types. Writing to one of the pairs will
17/// allow that data to be read from the other, and vice versa.
18///
19/// # Closing a `DuplexStream`
20///
21/// If one end of the `DuplexStream` channel is dropped, any pending reads on
22/// the other side will continue to read data until the buffer is drained, then
23/// they will signal EOF by returning 0 bytes. Any writes to the other side,
24/// including pending ones (that are waiting for free space in the buffer) will
25/// return `Err(BrokenPipe)` immediately.
26///
27/// # Example
28///
29/// ```
30/// # async fn ex() -> std::io::Result<()> {
31/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
32/// let (mut client, mut server) = tokio::io::duplex(64);
33///
34/// client.write_all(b"ping").await?;
35///
36/// let mut buf = [0u8; 4];
37/// server.read_exact(&mut buf).await?;
38/// assert_eq!(&buf, b"ping");
39///
40/// server.write_all(b"pong").await?;
41///
42/// client.read_exact(&mut buf).await?;
43/// assert_eq!(&buf, b"pong");
44/// # Ok(())
45/// # }
46/// ```
47#[derive(Debug)]
48#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
49pub struct DuplexStream {
50    read: Arc<Mutex<Pipe>>,
51    write: Arc<Mutex<Pipe>>,
52}
53
54/// A unidirectional IO over a piece of memory.
55///
56/// Data can be written to the pipe, and reading will return that data.
57#[derive(Debug)]
58struct Pipe {
59    /// The buffer storing the bytes written, also read from.
60    ///
61    /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
62    /// functionality already. Additionally, it can try to copy data in the
63    /// same buffer if there read index has advanced far enough.
64    buffer: BytesMut,
65    /// Determines if the write side has been closed.
66    is_closed: bool,
67    /// The maximum amount of bytes that can be written before returning
68    /// `Poll::Pending`.
69    max_buf_size: usize,
70    /// If the `read` side has been polled and is pending, this is the waker
71    /// for that parked task.
72    read_waker: Option<Waker>,
73    /// If the `write` side has filled the `max_buf_size` and returned
74    /// `Poll::Pending`, this is the waker for that parked task.
75    write_waker: Option<Waker>,
76}
77
78// ===== impl DuplexStream =====
79
80/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
81///
82/// The `max_buf_size` argument is the maximum amount of bytes that can be
83/// written to a side before the write returns `Poll::Pending`.
84#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
85pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
86    let one = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
87    let two = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
88
89    (
90        DuplexStream {
91            read: one.clone(),
92            write: two.clone(),
93        },
94        DuplexStream {
95            read: two,
96            write: one,
97        },
98    )
99}
100
101impl AsyncRead for DuplexStream {
102    // Previous rustc required this `self` to be `mut`, even though newer
103    // versions recognize it isn't needed to call `lock()`. So for
104    // compatibility, we include the `mut` and `allow` the lint.
105    //
106    // See https://github.com/rust-lang/rust/issues/73592
107    #[allow(unused_mut)]
108    fn poll_read(
109        mut self: Pin<&mut Self>,
110        cx: &mut task::Context<'_>,
111        buf: &mut ReadBuf<'_>,
112    ) -> Poll<std::io::Result<()>> {
113        Pin::new(&mut *self.read.lock()).poll_read(cx, buf)
114    }
115}
116
117impl AsyncWrite for DuplexStream {
118    #[allow(unused_mut)]
119    fn poll_write(
120        mut self: Pin<&mut Self>,
121        cx: &mut task::Context<'_>,
122        buf: &[u8],
123    ) -> Poll<std::io::Result<usize>> {
124        Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
125    }
126
127    fn poll_write_vectored(
128        self: Pin<&mut Self>,
129        cx: &mut task::Context<'_>,
130        bufs: &[std::io::IoSlice<'_>],
131    ) -> Poll<Result<usize, std::io::Error>> {
132        Pin::new(&mut *self.write.lock()).poll_write_vectored(cx, bufs)
133    }
134
135    fn is_write_vectored(&self) -> bool {
136        true
137    }
138
139    #[allow(unused_mut)]
140    fn poll_flush(
141        mut self: Pin<&mut Self>,
142        cx: &mut task::Context<'_>,
143    ) -> Poll<std::io::Result<()>> {
144        Pin::new(&mut *self.write.lock()).poll_flush(cx)
145    }
146
147    #[allow(unused_mut)]
148    fn poll_shutdown(
149        mut self: Pin<&mut Self>,
150        cx: &mut task::Context<'_>,
151    ) -> Poll<std::io::Result<()>> {
152        Pin::new(&mut *self.write.lock()).poll_shutdown(cx)
153    }
154}
155
156impl Drop for DuplexStream {
157    fn drop(&mut self) {
158        // notify the other side of the closure
159        self.write.lock().close_write();
160        self.read.lock().close_read();
161    }
162}
163
164// ===== impl Pipe =====
165
166impl Pipe {
167    fn new(max_buf_size: usize) -> Self {
168        Pipe {
169            buffer: BytesMut::new(),
170            is_closed: false,
171            max_buf_size,
172            read_waker: None,
173            write_waker: None,
174        }
175    }
176
177    fn close_write(&mut self) {
178        self.is_closed = true;
179        // needs to notify any readers that no more data will come
180        if let Some(waker) = self.read_waker.take() {
181            waker.wake();
182        }
183    }
184
185    fn close_read(&mut self) {
186        self.is_closed = true;
187        // needs to notify any writers that they have to abort
188        if let Some(waker) = self.write_waker.take() {
189            waker.wake();
190        }
191    }
192
193    fn poll_read_internal(
194        mut self: Pin<&mut Self>,
195        cx: &mut task::Context<'_>,
196        buf: &mut ReadBuf<'_>,
197    ) -> Poll<std::io::Result<()>> {
198        if self.buffer.has_remaining() {
199            let max = self.buffer.remaining().min(buf.remaining());
200            buf.put_slice(&self.buffer[..max]);
201            self.buffer.advance(max);
202            if max > 0 {
203                // The passed `buf` might have been empty, don't wake up if
204                // no bytes have been moved.
205                if let Some(waker) = self.write_waker.take() {
206                    waker.wake();
207                }
208            }
209            Poll::Ready(Ok(()))
210        } else if self.is_closed {
211            Poll::Ready(Ok(()))
212        } else {
213            self.read_waker = Some(cx.waker().clone());
214            Poll::Pending
215        }
216    }
217
218    fn poll_write_internal(
219        mut self: Pin<&mut Self>,
220        cx: &mut task::Context<'_>,
221        buf: &[u8],
222    ) -> Poll<std::io::Result<usize>> {
223        if self.is_closed {
224            return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
225        }
226        let avail = self.max_buf_size - self.buffer.len();
227        if avail == 0 {
228            self.write_waker = Some(cx.waker().clone());
229            return Poll::Pending;
230        }
231
232        let len = buf.len().min(avail);
233        self.buffer.extend_from_slice(&buf[..len]);
234        if let Some(waker) = self.read_waker.take() {
235            waker.wake();
236        }
237        Poll::Ready(Ok(len))
238    }
239
240    fn poll_write_vectored_internal(
241        mut self: Pin<&mut Self>,
242        cx: &mut task::Context<'_>,
243        bufs: &[std::io::IoSlice<'_>],
244    ) -> Poll<Result<usize, std::io::Error>> {
245        if self.is_closed {
246            return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
247        }
248        let avail = self.max_buf_size - self.buffer.len();
249        if avail == 0 {
250            self.write_waker = Some(cx.waker().clone());
251            return Poll::Pending;
252        }
253
254        let mut rem = avail;
255        for buf in bufs {
256            if rem == 0 {
257                break;
258            }
259
260            let len = buf.len().min(rem);
261            self.buffer.extend_from_slice(&buf[..len]);
262            rem -= len;
263        }
264
265        if let Some(waker) = self.read_waker.take() {
266            waker.wake();
267        }
268        Poll::Ready(Ok(avail - rem))
269    }
270}
271
272impl AsyncRead for Pipe {
273    cfg_coop! {
274        fn poll_read(
275            self: Pin<&mut Self>,
276            cx: &mut task::Context<'_>,
277            buf: &mut ReadBuf<'_>,
278        ) -> Poll<std::io::Result<()>> {
279            ready!(crate::trace::trace_leaf(cx));
280            let coop = ready!(crate::runtime::coop::poll_proceed(cx));
281
282            let ret = self.poll_read_internal(cx, buf);
283            if ret.is_ready() {
284                coop.made_progress();
285            }
286            ret
287        }
288    }
289
290    cfg_not_coop! {
291        fn poll_read(
292            self: Pin<&mut Self>,
293            cx: &mut task::Context<'_>,
294            buf: &mut ReadBuf<'_>,
295        ) -> Poll<std::io::Result<()>> {
296            ready!(crate::trace::trace_leaf(cx));
297            self.poll_read_internal(cx, buf)
298        }
299    }
300}
301
302impl AsyncWrite for Pipe {
303    cfg_coop! {
304        fn poll_write(
305            self: Pin<&mut Self>,
306            cx: &mut task::Context<'_>,
307            buf: &[u8],
308        ) -> Poll<std::io::Result<usize>> {
309            ready!(crate::trace::trace_leaf(cx));
310            let coop = ready!(crate::runtime::coop::poll_proceed(cx));
311
312            let ret = self.poll_write_internal(cx, buf);
313            if ret.is_ready() {
314                coop.made_progress();
315            }
316            ret
317        }
318    }
319
320    cfg_not_coop! {
321        fn poll_write(
322            self: Pin<&mut Self>,
323            cx: &mut task::Context<'_>,
324            buf: &[u8],
325        ) -> Poll<std::io::Result<usize>> {
326            ready!(crate::trace::trace_leaf(cx));
327            self.poll_write_internal(cx, buf)
328        }
329    }
330
331    cfg_coop! {
332        fn poll_write_vectored(
333            self: Pin<&mut Self>,
334            cx: &mut task::Context<'_>,
335            bufs: &[std::io::IoSlice<'_>],
336        ) -> Poll<Result<usize, std::io::Error>> {
337            ready!(crate::trace::trace_leaf(cx));
338            let coop = ready!(crate::runtime::coop::poll_proceed(cx));
339
340            let ret = self.poll_write_vectored_internal(cx, bufs);
341            if ret.is_ready() {
342                coop.made_progress();
343            }
344            ret
345        }
346    }
347
348    cfg_not_coop! {
349        fn poll_write_vectored(
350            self: Pin<&mut Self>,
351            cx: &mut task::Context<'_>,
352            bufs: &[std::io::IoSlice<'_>],
353        ) -> Poll<Result<usize, std::io::Error>> {
354            ready!(crate::trace::trace_leaf(cx));
355            self.poll_write_vectored_internal(cx, bufs)
356        }
357    }
358
359    fn is_write_vectored(&self) -> bool {
360        true
361    }
362
363    fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
364        Poll::Ready(Ok(()))
365    }
366
367    fn poll_shutdown(
368        mut self: Pin<&mut Self>,
369        _: &mut task::Context<'_>,
370    ) -> Poll<std::io::Result<()>> {
371        self.close_write();
372        Poll::Ready(Ok(()))
373    }
374}