broker_tokio/io/
split.rs

1//! Split a single value implementing `AsyncRead + AsyncWrite` into separate
2//! `AsyncRead` and `AsyncWrite` handles.
3//!
//! To restore this read/write object from its `split::ReadHalf` and
//! `split::WriteHalf`, use `ReadHalf::unsplit`.
6
7use crate::io::{AsyncRead, AsyncWrite};
8
9use bytes::{Buf, BufMut};
10use std::cell::UnsafeCell;
11use std::fmt;
12use std::io;
13use std::pin::Pin;
14use std::sync::atomic::AtomicBool;
15use std::sync::atomic::Ordering::{Acquire, Release};
16use std::sync::Arc;
17use std::task::{Context, Poll};
18
19cfg_io_util! {
20    /// The readable half of a value returned from [`split`](split()).
21    pub struct ReadHalf<T> {
22        inner: Arc<Inner<T>>,
23    }
24
25    /// The writable half of a value returned from [`split`](split()).
26    pub struct WriteHalf<T> {
27        inner: Arc<Inner<T>>,
28    }
29
30    /// Split a single value implementing `AsyncRead + AsyncWrite` into separate
31    /// `AsyncRead` and `AsyncWrite` handles.
32    ///
33    /// To restore this read/write object from its `ReadHalf` and
34    /// `WriteHalf` use [`unsplit`](ReadHalf::unsplit()).
35    pub fn split<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>)
36    where
37        T: AsyncRead + AsyncWrite,
38    {
39        let inner = Arc::new(Inner {
40            locked: AtomicBool::new(false),
41            stream: UnsafeCell::new(stream),
42        });
43
44        let rd = ReadHalf {
45            inner: inner.clone(),
46        };
47
48        let wr = WriteHalf { inner };
49
50        (rd, wr)
51    }
52}
53
/// State shared by one `ReadHalf` / `WriteHalf` pair.
struct Inner<T> {
    // Set by `poll_lock` when a half acquires access; cleared when the
    // resulting `Guard` is dropped.
    locked: AtomicBool,
    // The original stream. Mutated only through a `Guard`, or moved out by
    // `unsplit` once the `Arc` has a single owner.
    stream: UnsafeCell<T>,
}
58
/// Held while a half has exclusive access to `Inner::stream`; its `Drop`
/// impl clears `Inner::locked` again.
struct Guard<'a, T> {
    inner: &'a Inner<T>,
}
62
63impl<T> ReadHalf<T> {
64    /// Check if this `ReadHalf` and some `WriteHalf` were split from the same
65    /// stream.
66    pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool {
67        other.is_pair_of(&self)
68    }
69
70    /// Reunite with a previously split `WriteHalf`.
71    ///
72    /// # Panics
73    ///
74    /// If this `ReadHalf` and the given `WriteHalf` do not originate from the
75    /// same `split` operation this method will panic.
76    /// This can be checked ahead of time by comparing the stream ID
77    /// of the two halves.
78    pub fn unsplit(self, wr: WriteHalf<T>) -> T {
79        if self.is_pair_of(&wr) {
80            drop(wr);
81
82            let inner = Arc::try_unwrap(self.inner)
83                .ok()
84                .expect("Arc::try_unwrap failed");
85
86            inner.stream.into_inner()
87        } else {
88            panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.")
89        }
90    }
91}
92
93impl<T> WriteHalf<T> {
94    /// Check if this `WriteHalf` and some `ReadHalf` were split from the same
95    /// stream.
96    pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool {
97        Arc::ptr_eq(&self.inner, &other.inner)
98    }
99}
100
101impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
102    fn poll_read(
103        self: Pin<&mut Self>,
104        cx: &mut Context<'_>,
105        buf: &mut [u8],
106    ) -> Poll<io::Result<usize>> {
107        let mut inner = ready!(self.inner.poll_lock(cx));
108        inner.stream_pin().poll_read(cx, buf)
109    }
110
111    fn poll_read_buf<B: BufMut>(
112        self: Pin<&mut Self>,
113        cx: &mut Context<'_>,
114        buf: &mut B,
115    ) -> Poll<io::Result<usize>> {
116        let mut inner = ready!(self.inner.poll_lock(cx));
117        inner.stream_pin().poll_read_buf(cx, buf)
118    }
119}
120
121impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
122    fn poll_write(
123        self: Pin<&mut Self>,
124        cx: &mut Context<'_>,
125        buf: &[u8],
126    ) -> Poll<Result<usize, io::Error>> {
127        let mut inner = ready!(self.inner.poll_lock(cx));
128        inner.stream_pin().poll_write(cx, buf)
129    }
130
131    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
132        let mut inner = ready!(self.inner.poll_lock(cx));
133        inner.stream_pin().poll_flush(cx)
134    }
135
136    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
137        let mut inner = ready!(self.inner.poll_lock(cx));
138        inner.stream_pin().poll_shutdown(cx)
139    }
140
141    fn poll_write_buf<B: Buf>(
142        self: Pin<&mut Self>,
143        cx: &mut Context<'_>,
144        buf: &mut B,
145    ) -> Poll<Result<usize, io::Error>> {
146        let mut inner = ready!(self.inner.poll_lock(cx));
147        inner.stream_pin().poll_write_buf(cx, buf)
148    }
149}
150
151impl<T> Inner<T> {
152    fn poll_lock(&self, cx: &mut Context<'_>) -> Poll<Guard<'_, T>> {
153        if !self.locked.compare_and_swap(false, true, Acquire) {
154            Poll::Ready(Guard { inner: self })
155        } else {
156            // Spin... but investigate a better strategy
157
158            std::thread::yield_now();
159            cx.waker().wake_by_ref();
160
161            Poll::Pending
162        }
163    }
164}
165
impl<T> Guard<'_, T> {
    /// Returns a pinned mutable reference to the shared stream.
    fn stream_pin(&mut self) -> Pin<&mut T> {
        // SAFETY:
        // - Aliasing: a `Guard` exists only after `poll_lock` flipped
        //   `locked` from false to true, and `&mut self` ties the returned
        //   reference to this guard. So no other `&mut T` to the stream is
        //   live while it is in use.
        // - Pinning: the stream lives inside the `Arc` allocation and is not
        //   moved while either half exists.
        //   NOTE(review): `ReadHalf::unsplit` moves `T` out via `into_inner`
        //   after it may have been pinned here. That is only sound for
        //   `T: Unpin`; confirm or add that bound.
        unsafe { Pin::new_unchecked(&mut *self.inner.stream.get()) }
    }
}
173
174impl<T> Drop for Guard<'_, T> {
175    fn drop(&mut self) {
176        self.inner.locked.store(false, Release);
177    }
178}
179
// `Inner` contains an `UnsafeCell`, so `Arc<Inner<T>>` is neither `Send` nor
// `Sync` automatically; these impls opt the halves back in.
//
// SAFETY (Send): the halves reach the shared `T` only through
// `Inner::poll_lock`, whose `Guard` serializes access. Moving a half to
// another thread is therefore sound when `T` itself may cross threads.
unsafe impl<T: Send> Send for ReadHalf<T> {}
unsafe impl<T: Send> Send for WriteHalf<T> {}
// SAFETY (Sync): in this module, a shared `&ReadHalf` / `&WriteHalf` exposes
// only `is_pair_of` and `Debug`, neither of which touches `T`. The I/O
// methods require `Pin<&mut Self>`.
unsafe impl<T: Sync> Sync for ReadHalf<T> {}
unsafe impl<T: Sync> Sync for WriteHalf<T> {}
184
185impl<T: fmt::Debug> fmt::Debug for ReadHalf<T> {
186    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
187        fmt.debug_struct("split::ReadHalf").finish()
188    }
189}
190
191impl<T: fmt::Debug> fmt::Debug for WriteHalf<T> {
192    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
193        fmt.debug_struct("split::WriteHalf").finish()
194    }
195}