1use std::{
5 collections::VecDeque,
6 sync::{Arc, Mutex},
7 task::Waker,
8};
9
10use futures::{stream::BoxStream, Stream, StreamExt};
11use tokio::sync::Semaphore;
12use tokio_util::sync::PollSemaphore;
13
/// Identifies which of the two halves of a [`SharedStream`] a handle is.
#[derive(Debug, PartialEq, Clone, Copy)]
enum Side {
    /// The first half returned by [`SharedStream::new`].
    Left,
    /// The second half returned by [`SharedStream::new`].
    Right,
}
19
/// Limit on how many items one half of a [`SharedStream`] may run ahead of the other.
///
/// The leading half leaves a clone of every item it reads in a shared buffer for the
/// lagging half; this bounds the size of that buffer.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Capacity {
    /// At most this many items are buffered for the lagging half. While the buffer is full,
    /// the leading half returns `Poll::Pending` until the lagging half reads an item.
    /// `Bounded(0)` leaves no room for even a single item and cannot be used.
    Bounded(u32),
    /// The buffer grows without limit, so the leading half never waits for a free slot.
    Unbounded,
}
26
/// State shared by the two halves of a [`SharedStream`], guarded by a mutex.
struct InnerState<'a, T> {
    /// The wrapped stream. It is `None` only while one half has taken it out to poll it
    /// without holding the lock. It stays `None` for good if that poll panicked.
    inner: Option<BoxStream<'a, T>>,
    /// Clones of items pulled by the leading half, waiting to be read by the lagging half.
    /// A half only pulls from `inner` when its own buffered count is zero, so at any time
    /// only one half has items here and this is a plain FIFO.
    buffer: VecDeque<T>,
    /// The half currently entitled to poll `inner`. It stays set after `inner` returns
    /// `Pending`, because only that half's waker is registered with `inner`.
    polling: Option<Side>,
    /// Waker of the half that is waiting for the other half to finish polling `inner`.
    waker: Option<Waker>,
    /// Set once `inner` has returned `Ready(None)`.
    exhausted: bool,
    /// Number of items in `buffer` still to be read by the left half.
    left_buffered: u32,
    /// Number of items in `buffer` still to be read by the right half.
    right_buffered: u32,
    /// Free buffer slots when the capacity is bounded; `None` for `Capacity::Unbounded`.
    available_buffer: Option<PollSemaphore>,
}
37
38pub struct SharedStream<'a, T: Clone> {
40 state: Arc<Mutex<InnerState<'a, T>>>,
41 side: Side,
42}
43
44impl<'a, T: Clone> SharedStream<'a, T> {
45 pub fn new(inner: BoxStream<'a, T>, capacity: Capacity) -> (Self, Self) {
46 let available_buffer = match capacity {
47 Capacity::Unbounded => None,
48 Capacity::Bounded(capacity) => Some(PollSemaphore::new(Arc::new(Semaphore::new(
49 capacity as usize,
50 )))),
51 };
52 let state = InnerState {
53 inner: Some(inner),
54 buffer: VecDeque::new(),
55 polling: None,
56 waker: None,
57 exhausted: false,
58 left_buffered: 0,
59 right_buffered: 0,
60 available_buffer,
61 };
62
63 let state = Arc::new(Mutex::new(state));
64
65 let left = Self {
66 state: state.clone(),
67 side: Side::Left,
68 };
69 let right = Self {
70 state,
71 side: Side::Right,
72 };
73 (left, right)
74 }
75}
76
77impl<T: Clone> Stream for SharedStream<'_, T> {
78 type Item = T;
79
80 fn poll_next(
81 self: std::pin::Pin<&mut Self>,
82 cx: &mut std::task::Context<'_>,
83 ) -> std::task::Poll<Option<Self::Item>> {
84 let mut inner_state = self.state.lock().unwrap();
85 let can_take_buffered = match self.side {
86 Side::Left => inner_state.left_buffered > 0,
87 Side::Right => inner_state.right_buffered > 0,
88 };
89 if can_take_buffered {
90 let item = inner_state.buffer.pop_front();
92 match self.side {
93 Side::Left => {
94 inner_state.left_buffered -= 1;
95 }
96 Side::Right => {
97 inner_state.right_buffered -= 1;
98 }
99 }
100 if let Some(available_buffer) = inner_state.available_buffer.as_mut() {
101 available_buffer.add_permits(1);
102 }
103 std::task::Poll::Ready(item)
104 } else {
105 if inner_state.exhausted {
106 return std::task::Poll::Ready(None);
107 }
108 let permit = if let Some(available_buffer) = inner_state.available_buffer.as_mut() {
110 match available_buffer.poll_acquire(cx) {
111 std::task::Poll::Ready(permit) => Some(permit.unwrap()),
114 std::task::Poll::Pending => {
115 return std::task::Poll::Pending;
116 }
117 }
118 } else {
119 None
120 };
121 if let Some(polling_side) = inner_state.polling.as_ref() {
122 if *polling_side != self.side {
123 inner_state.waker = Some(cx.waker().clone());
131 return std::task::Poll::Pending;
132 }
133 }
134 inner_state.polling = Some(self.side);
135 let mut to_poll = inner_state
137 .inner
138 .take()
139 .expect("Other half of shared stream panic'd while polling inner stream");
140 drop(inner_state);
141 let res = to_poll.poll_next_unpin(cx);
142 let mut inner_state = self.state.lock().unwrap();
143
144 let mut should_wake = true;
145 match &res {
146 std::task::Poll::Ready(None) => {
147 inner_state.exhausted = true;
148 inner_state.polling = None;
149 }
150 std::task::Poll::Ready(Some(item)) => {
151 if let Some(permit) = permit {
153 permit.forget();
154 }
155 inner_state.polling = None;
156 match self.side {
158 Side::Left => {
159 inner_state.right_buffered += 1;
160 }
161 Side::Right => {
162 inner_state.left_buffered += 1;
163 }
164 };
165 inner_state.buffer.push_back(item.clone());
166 }
167 std::task::Poll::Pending => {
168 should_wake = false;
169 }
170 };
171
172 inner_state.inner = Some(to_poll);
173
174 let to_wake = if should_wake {
176 inner_state.waker.take()
177 } else {
178 None
181 };
182 drop(inner_state);
183 if let Some(waker) = to_wake {
184 waker.wake();
185 }
186 res
187 }
188 }
189}
190
191pub trait SharedStreamExt<'a>: Stream + Send
192where
193 Self::Item: Clone,
194{
195 fn share(
208 self,
209 capacity: Capacity,
210 ) -> (SharedStream<'a, Self::Item>, SharedStream<'a, Self::Item>);
211}
212
213impl<'a, T: Clone> SharedStreamExt<'a> for BoxStream<'a, T> {
214 fn share(self, capacity: Capacity) -> (SharedStream<'a, T>, SharedStream<'a, T>) {
215 SharedStream::new(self, capacity)
216 }
217}
218
#[cfg(test)]
mod tests {
    // Tests for `SharedStream`. Several of them depend on the exact order in which the two
    // halves are polled, so keep statement order intact when editing.

    use futures::{FutureExt, StreamExt};
    use tokio_stream::wrappers::ReceiverStream;

    use crate::utils::futures::{Capacity, SharedStreamExt};

    /// Polls `fut` once with a no-op waker and reports whether it returned `Pending`.
    ///
    /// Any waker registered by this poll wakes nothing; later `.await`s in the tests re-poll
    /// the same future directly.
    fn is_pending(fut: &mut (impl std::future::Future + Unpin)) -> bool {
        let noop_waker = futures::task::noop_waker();
        let mut context = std::task::Context::from_waker(&noop_waker);
        fut.poll_unpin(&mut context).is_pending()
    }

    /// Walks through buffering, back-pressure, ownership hand-off and end-of-stream with a
    /// buffer capacity of 2.
    #[tokio::test]
    async fn test_shared_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
        let inner_stream = ReceiverStream::new(rx);

        // Preload 0, 1, 2 so the inner stream is immediately ready.
        for i in 0..3 {
            tx.send(i).await.unwrap();
        }

        let (mut left, mut right) = inner_stream.boxed().share(Capacity::Bounded(2));

        // Left runs ahead; 0 and 1 are now buffered for right, filling the buffer.
        assert_eq!(left.next().await.unwrap(), 0);
        assert_eq!(left.next().await.unwrap(), 1);

        let mut left_fut = left.next();

        // The buffer is full, so left cannot take a slot to pull item 2.
        assert!(is_pending(&mut left_fut));

        // Right reading 0 frees a slot, letting left pull 2.
        assert_eq!(right.next().await.unwrap(), 0);
        assert_eq!(left_fut.await.unwrap(), 2);

        // Right drains the rest of its buffered items.
        assert_eq!(right.next().await.unwrap(), 1);
        assert_eq!(right.next().await.unwrap(), 2);

        // Right is polled first and becomes the owner of the (now empty) inner stream;
        // left is then pending waiting on right rather than on the inner stream.
        let mut right_fut = right.next();
        let mut left_fut = left.next();
        assert!(is_pending(&mut right_fut));
        assert!(is_pending(&mut left_fut));

        tx.send(3).await.unwrap();

        // Right must be awaited first: it owns the inner stream, receives 3 and buffers a
        // copy for left, which left then reads from the buffer.
        assert_eq!(right_fut.await.unwrap(), 3);
        assert_eq!(left_fut.await.unwrap(), 3);

        // Closing the channel ends the inner stream for both halves...
        drop(tx);

        assert_eq!(left.next().await, None);
        assert_eq!(right.next().await, None);

        // ...and they keep returning `None` afterwards.
        assert_eq!(left.next().await, None);
        assert_eq!(right.next().await, None);
    }

    /// With an unbounded buffer, one half can consume the whole stream before the other
    /// half reads anything; the other half then replays every item from the buffer.
    #[tokio::test]
    async fn test_unbounded_shared_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
        let inner_stream = ReceiverStream::new(rx);

        for i in 0..10 {
            tx.send(i).await.unwrap();
        }
        drop(tx);

        let (mut left, mut right) = inner_stream.boxed().share(Capacity::Unbounded);

        for i in 0..10 {
            assert_eq!(left.next().await.unwrap(), i);
        }
        assert_eq!(left.next().await, None);

        for i in 0..10 {
            assert_eq!(right.next().await.unwrap(), i);
        }
        assert_eq!(right.next().await, None);
    }

    /// Consumes both halves from separate tasks on a multi-threaded runtime (capacity 2,
    /// 100 rounds of 1000 items) and checks each half sees every item in order.
    #[tokio::test(flavor = "multi_thread")]
    async fn stress_shared_stream() {
        for _ in 0..100 {
            let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
            let inner_stream = ReceiverStream::new(rx);
            let (mut left, mut right) = inner_stream.boxed().share(Capacity::Bounded(2));

            let left_handle = tokio::spawn(async move {
                let mut counter = 0;
                while let Some(item) = left.next().await {
                    assert_eq!(item, counter);
                    counter += 1;
                }
            });

            let right_handle = tokio::spawn(async move {
                let mut counter = 0;
                while let Some(item) = right.next().await {
                    assert_eq!(item, counter);
                    counter += 1;
                }
            });

            for i in 0..1000 {
                tx.send(i).await.unwrap();
            }
            drop(tx);
            left_handle.await.unwrap();
            right_handle.await.unwrap();
        }
    }
}