lance_core/utils/
futures.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{
5    collections::VecDeque,
6    sync::{Arc, Mutex},
7    task::Waker,
8};
9
10use futures::{stream::BoxStream, Stream, StreamExt};
11use tokio::sync::Semaphore;
12use tokio_util::sync::PollSemaphore;
13
/// Identifies which half of the pair returned by [`SharedStream::new`] a handle is.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Side {
    /// The first element of the returned pair.
    Left,
    /// The second element of the returned pair.
    Right,
}
19
/// How many items may be buffered for the slower half of a [`SharedStream`] pair; this
/// capacity may be unbounded.
///
/// `Bounded(n)` allows at most `n` items to be buffered at once. Note that `Bounded(0)` never
/// has a free slot, so neither half can ever pull an item. `Unbounded` imposes no limit.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Capacity {
    /// At most this many items may wait in the shared buffer.
    Bounded(u32),
    /// The shared buffer may grow without limit.
    Unbounded,
}
26
27struct InnerState<'a, T> {
28    inner: Option<BoxStream<'a, T>>,
29    buffer: VecDeque<T>,
30    polling: Option<Side>,
31    waker: Option<Waker>,
32    exhausted: bool,
33    left_buffered: u32,
34    right_buffered: u32,
35    available_buffer: Option<PollSemaphore>,
36}
37
38/// The stream returned by [`share`].
39pub struct SharedStream<'a, T: Clone> {
40    state: Arc<Mutex<InnerState<'a, T>>>,
41    side: Side,
42}
43
44impl<'a, T: Clone> SharedStream<'a, T> {
45    pub fn new(inner: BoxStream<'a, T>, capacity: Capacity) -> (Self, Self) {
46        let available_buffer = match capacity {
47            Capacity::Unbounded => None,
48            Capacity::Bounded(capacity) => Some(PollSemaphore::new(Arc::new(Semaphore::new(
49                capacity as usize,
50            )))),
51        };
52        let state = InnerState {
53            inner: Some(inner),
54            buffer: VecDeque::new(),
55            polling: None,
56            waker: None,
57            exhausted: false,
58            left_buffered: 0,
59            right_buffered: 0,
60            available_buffer,
61        };
62
63        let state = Arc::new(Mutex::new(state));
64
65        let left = Self {
66            state: state.clone(),
67            side: Side::Left,
68        };
69        let right = Self {
70            state,
71            side: Side::Right,
72        };
73        (left, right)
74    }
75}
76
77impl<T: Clone> Stream for SharedStream<'_, T> {
78    type Item = T;
79
80    fn poll_next(
81        self: std::pin::Pin<&mut Self>,
82        cx: &mut std::task::Context<'_>,
83    ) -> std::task::Poll<Option<Self::Item>> {
84        let mut inner_state = self.state.lock().unwrap();
85        let can_take_buffered = match self.side {
86            Side::Left => inner_state.left_buffered > 0,
87            Side::Right => inner_state.right_buffered > 0,
88        };
89        if can_take_buffered {
90            // Easy case, there is an item in the buffer.  Grab it, decrement the count, and return it.
91            let item = inner_state.buffer.pop_front();
92            match self.side {
93                Side::Left => {
94                    inner_state.left_buffered -= 1;
95                }
96                Side::Right => {
97                    inner_state.right_buffered -= 1;
98                }
99            }
100            if let Some(available_buffer) = inner_state.available_buffer.as_mut() {
101                available_buffer.add_permits(1);
102            }
103            std::task::Poll::Ready(item)
104        } else {
105            if inner_state.exhausted {
106                return std::task::Poll::Ready(None);
107            }
108            // No buffered items, if we have room in the buffer, then try and poll for one
109            let permit = if let Some(available_buffer) = inner_state.available_buffer.as_mut() {
110                match available_buffer.poll_acquire(cx) {
111                    // Can return None if the semaphore is closed but we never close the semaphore
112                    // so its safe to unwrap here
113                    std::task::Poll::Ready(permit) => Some(permit.unwrap()),
114                    std::task::Poll::Pending => {
115                        return std::task::Poll::Pending;
116                    }
117                }
118            } else {
119                None
120            };
121            if let Some(polling_side) = inner_state.polling.as_ref() {
122                if *polling_side != self.side {
123                    // Another task is already polling the inner stream, so we don't need to do anything
124
125                    // Per rust docs:
126                    //   Note that on multiple calls to poll, only the Waker from the Context
127                    //   passed to the most recent call should be scheduled to receive a wakeup.
128                    //
129                    // So it is safe to replace a potentially stale waker here.
130                    inner_state.waker = Some(cx.waker().clone());
131                    return std::task::Poll::Pending;
132                }
133            }
134            inner_state.polling = Some(self.side);
135            // Release the mutex here as polling the inner stream is potentially expensive
136            let mut to_poll = inner_state
137                .inner
138                .take()
139                .expect("Other half of shared stream panic'd while polling inner stream");
140            drop(inner_state);
141            let res = to_poll.poll_next_unpin(cx);
142            let mut inner_state = self.state.lock().unwrap();
143
144            let mut should_wake = true;
145            match &res {
146                std::task::Poll::Ready(None) => {
147                    inner_state.exhausted = true;
148                    inner_state.polling = None;
149                }
150                std::task::Poll::Ready(Some(item)) => {
151                    // We got an item, forget the permit to mark that we can take one fewer items
152                    if let Some(permit) = permit {
153                        permit.forget();
154                    }
155                    inner_state.polling = None;
156                    // Let the other side know an item is available
157                    match self.side {
158                        Side::Left => {
159                            inner_state.right_buffered += 1;
160                        }
161                        Side::Right => {
162                            inner_state.left_buffered += 1;
163                        }
164                    };
165                    inner_state.buffer.push_back(item.clone());
166                }
167                std::task::Poll::Pending => {
168                    should_wake = false;
169                }
170            };
171
172            inner_state.inner = Some(to_poll);
173
174            // If the other side was waiting for us to poll, wake them up, but only after we release the mutex
175            let to_wake = if should_wake {
176                inner_state.waker.take()
177            } else {
178                // If the inner stream is pending then the inner stream will wake us up and we will wake the
179                // other side up then.
180                None
181            };
182            drop(inner_state);
183            if let Some(waker) = to_wake {
184                waker.wake();
185            }
186            res
187        }
188    }
189}
190
191pub trait SharedStreamExt<'a>: Stream + Send
192where
193    Self::Item: Clone,
194{
195    /// Split a stream into two shared streams
196    ///
197    /// Each shared stream will return the full set of items from the underlying stream.
198    /// This works by buffering the items from the underlying stream and then replaying
199    /// them to the other side.
200    ///
201    /// The capacity parameter controls how many items can be buffered at once.  Be careful
202    /// with the capacity parameter as it can lead to deadlock if the two streams are not
203    /// polled evenly.
204    ///
205    /// If the capacity is unbounded then the stream could potentially buffer the entire
206    /// input stream in memory.
207    fn share(
208        self,
209        capacity: Capacity,
210    ) -> (SharedStream<'a, Self::Item>, SharedStream<'a, Self::Item>);
211}
212
213impl<'a, T: Clone> SharedStreamExt<'a> for BoxStream<'a, T> {
214    fn share(self, capacity: Capacity) -> (SharedStream<'a, T>, SharedStream<'a, T>) {
215        SharedStream::new(self, capacity)
216    }
217}
218
#[cfg(test)]
mod tests {
    use futures::{FutureExt, StreamExt};
    use tokio_stream::wrappers::ReceiverStream;

    use crate::utils::futures::{Capacity, SharedStreamExt};

    /// Polls `fut` once with a waker that does nothing and reports whether it is still pending.
    fn is_pending(fut: &mut (impl std::future::Future + Unpin)) -> bool {
        let waker = futures::task::noop_waker();
        let mut cx = std::task::Context::from_waker(&waker);
        fut.poll_unpin(&mut cx).is_pending()
    }

    #[tokio::test]
    async fn test_shared_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
        let source = ReceiverStream::new(rx);

        // Queue three items up front.
        for value in 0..3 {
            tx.send(value).await.unwrap();
        }

        let (mut left, mut right) = source.boxed().share(Capacity::Bounded(2));

        // The buffer holds two items, so the left half may run two items ahead of the right.
        for expected in 0..2 {
            assert_eq!(left.next().await.unwrap(), expected);
        }

        // A third item would not fit in the buffer: the right half has taken nothing yet.
        let mut blocked_left = left.next();
        assert!(is_pending(&mut blocked_left));

        // Taking one item on the right frees a slot, which lets the blocked left poll finish.
        assert_eq!(right.next().await.unwrap(), 0);
        assert_eq!(blocked_left.await.unwrap(), 2);

        // The right half catches up entirely from the buffer.
        for expected in 1..3 {
            assert_eq!(right.next().await.unwrap(), expected);
        }

        // The sender is still open, so both halves must wait for more data.
        let mut waiting_right = right.next();
        let mut waiting_left = left.next();
        assert!(is_pending(&mut waiting_right));
        assert!(is_pending(&mut waiting_left));

        tx.send(3).await.unwrap();

        // The new item reaches both halves.
        assert_eq!(waiting_right.await.unwrap(), 3);
        assert_eq!(waiting_left.await.unwrap(), 3);

        drop(tx);

        // Once the channel closes, both halves end and keep returning `None` (self-fused).
        for _ in 0..2 {
            assert_eq!(left.next().await, None);
            assert_eq!(right.next().await, None);
        }
    }

    #[tokio::test]
    async fn test_unbounded_shared_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
        let source = ReceiverStream::new(rx);

        for value in 0..10 {
            tx.send(value).await.unwrap();
        }
        drop(tx);

        let (mut left, mut right) = source.boxed().share(Capacity::Unbounded);

        // Without a capacity limit, one half can be drained completely before the other is
        // touched...
        for expected in 0..10 {
            assert_eq!(left.next().await.unwrap(), expected);
        }
        assert_eq!(left.next().await, None);

        // ...and the other half then replays everything from the buffer.
        for expected in 0..10 {
            assert_eq!(right.next().await.unwrap(), expected);
        }
        assert_eq!(right.next().await, None);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn stress_shared_stream() {
        for _round in 0..100 {
            let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
            let source = ReceiverStream::new(rx);
            let (left, right) = source.boxed().share(Capacity::Bounded(2));

            // Each half is drained on its own task and must observe 0, 1, 2, ... in order.
            let consumers = [left, right].map(|mut half| {
                tokio::spawn(async move {
                    let mut expected = 0;
                    while let Some(item) = half.next().await {
                        assert_eq!(item, expected);
                        expected += 1;
                    }
                })
            });

            for value in 0..1000 {
                tx.send(value).await.unwrap();
            }
            drop(tx);

            for consumer in consumers {
                consumer.await.unwrap();
            }
        }
    }
}