1use crate::io::{AsyncRead, AsyncWrite};
8
9use bytes::{Buf, BufMut};
10use std::cell::UnsafeCell;
11use std::fmt;
12use std::io;
13use std::pin::Pin;
14use std::sync::atomic::AtomicBool;
15use std::sync::atomic::Ordering::{Acquire, Release};
16use std::sync::Arc;
17use std::task::{Context, Poll};
18
19cfg_io_util! {
20 pub struct ReadHalf<T> {
22 inner: Arc<Inner<T>>,
23 }
24
25 pub struct WriteHalf<T> {
27 inner: Arc<Inner<T>>,
28 }
29
30 pub fn split<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>)
36 where
37 T: AsyncRead + AsyncWrite,
38 {
39 let inner = Arc::new(Inner {
40 locked: AtomicBool::new(false),
41 stream: UnsafeCell::new(stream),
42 });
43
44 let rd = ReadHalf {
45 inner: inner.clone(),
46 };
47
48 let wr = WriteHalf { inner };
49
50 (rd, wr)
51 }
52}
53
/// State shared between a `ReadHalf` / `WriteHalf` pair.
struct Inner<T> {
    /// Set while one half holds a `Guard`; acts as a try-lock around `stream`.
    locked: AtomicBool,
    /// The split stream. It is mutably accessed through `Guard::stream_pin`
    /// while `locked` is held. It is moved out by `ReadHalf::unsplit` once
    /// both halves have been consumed.
    stream: UnsafeCell<T>,
}
58
/// Evidence that `inner.locked` is held; returned by `Inner::poll_lock`.
/// Dropping the guard releases the lock.
struct Guard<'a, T> {
    inner: &'a Inner<T>,
}
62
63impl<T> ReadHalf<T> {
64 pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool {
67 other.is_pair_of(&self)
68 }
69
70 pub fn unsplit(self, wr: WriteHalf<T>) -> T {
79 if self.is_pair_of(&wr) {
80 drop(wr);
81
82 let inner = Arc::try_unwrap(self.inner)
83 .ok()
84 .expect("Arc::try_unwrap failed");
85
86 inner.stream.into_inner()
87 } else {
88 panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.")
89 }
90 }
91}
92
93impl<T> WriteHalf<T> {
94 pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool {
97 Arc::ptr_eq(&self.inner, &other.inner)
98 }
99}
100
101impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
102 fn poll_read(
103 self: Pin<&mut Self>,
104 cx: &mut Context<'_>,
105 buf: &mut [u8],
106 ) -> Poll<io::Result<usize>> {
107 let mut inner = ready!(self.inner.poll_lock(cx));
108 inner.stream_pin().poll_read(cx, buf)
109 }
110
111 fn poll_read_buf<B: BufMut>(
112 self: Pin<&mut Self>,
113 cx: &mut Context<'_>,
114 buf: &mut B,
115 ) -> Poll<io::Result<usize>> {
116 let mut inner = ready!(self.inner.poll_lock(cx));
117 inner.stream_pin().poll_read_buf(cx, buf)
118 }
119}
120
121impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
122 fn poll_write(
123 self: Pin<&mut Self>,
124 cx: &mut Context<'_>,
125 buf: &[u8],
126 ) -> Poll<Result<usize, io::Error>> {
127 let mut inner = ready!(self.inner.poll_lock(cx));
128 inner.stream_pin().poll_write(cx, buf)
129 }
130
131 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
132 let mut inner = ready!(self.inner.poll_lock(cx));
133 inner.stream_pin().poll_flush(cx)
134 }
135
136 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
137 let mut inner = ready!(self.inner.poll_lock(cx));
138 inner.stream_pin().poll_shutdown(cx)
139 }
140
141 fn poll_write_buf<B: Buf>(
142 self: Pin<&mut Self>,
143 cx: &mut Context<'_>,
144 buf: &mut B,
145 ) -> Poll<Result<usize, io::Error>> {
146 let mut inner = ready!(self.inner.poll_lock(cx));
147 inner.stream_pin().poll_write_buf(cx, buf)
148 }
149}
150
151impl<T> Inner<T> {
152 fn poll_lock(&self, cx: &mut Context<'_>) -> Poll<Guard<'_, T>> {
153 if !self.locked.compare_and_swap(false, true, Acquire) {
154 Poll::Ready(Guard { inner: self })
155 } else {
156 std::thread::yield_now();
159 cx.waker().wake_by_ref();
160
161 Poll::Pending
162 }
163 }
164}
165
166impl<T> Guard<'_, T> {
167 fn stream_pin(&mut self) -> Pin<&mut T> {
168 unsafe { Pin::new_unchecked(&mut *self.inner.stream.get()) }
171 }
172}
173
174impl<T> Drop for Guard<'_, T> {
175 fn drop(&mut self) {
176 self.inner.locked.store(false, Release);
177 }
178}
179
180unsafe impl<T: Send> Send for ReadHalf<T> {}
181unsafe impl<T: Send> Send for WriteHalf<T> {}
182unsafe impl<T: Sync> Sync for ReadHalf<T> {}
183unsafe impl<T: Sync> Sync for WriteHalf<T> {}
184
185impl<T: fmt::Debug> fmt::Debug for ReadHalf<T> {
186 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
187 fmt.debug_struct("split::ReadHalf").finish()
188 }
189}
190
191impl<T: fmt::Debug> fmt::Debug for WriteHalf<T> {
192 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
193 fmt.debug_struct("split::WriteHalf").finish()
194 }
195}