stream_cancel/
combinator.rs

1use crate::Trigger;
2use futures_core::{ready, stream::Stream};
3use pin_project::pin_project;
4use std::fmt;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tokio::sync::watch;
9
10/// A stream combinator which takes elements from a stream until a future resolves.
11///
12/// This structure is produced by the [`StreamExt::take_until_if`] method.
13#[pin_project]
14#[derive(Clone, Debug)]
15pub struct TakeUntilIf<S, F> {
16    #[pin]
17    stream: S,
18    #[pin]
19    until: F,
20    free: bool,
21}
22
23impl<S, F> TakeUntilIf<S, F> {
24    /// Consumes this combinator, returning the underlying stream.
25    pub fn into_inner(self) -> S {
26        self.stream
27    }
28}
29
30/// This `Stream` extension trait provides a `take_until_if` method that terminates the stream once
31/// the given future resolves.
32pub trait StreamExt: Stream {
33    /// Take elements from this stream until the given future resolves.
34    ///
35    /// This function takes elements from this stream until the given future resolves with
36    /// `true`. Once it resolves, the stream yields `None`, and produces no further elements.
37    ///
38    /// If the future resolves with `false`, the stream is allowed to continue indefinitely.
39    ///
40    /// This method is essentially a wrapper around `futures_util::stream::StreamExt::take_until`
41    /// that ascribes particular semantics to the output of the provided future.
42    ///
43    /// ```
44    /// use stream_cancel::StreamExt;
45    /// use futures::prelude::*;
46    /// use tokio_stream::wrappers::TcpListenerStream;
47    ///
48    /// #[tokio::main]
49    /// async fn main() {
50    ///     let mut listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
51    ///     let (tx, rx) = tokio::sync::oneshot::channel();
52    ///
53    ///     tokio::spawn(async move {
54    ///         let mut incoming = TcpListenerStream::new(listener).take_until_if(rx.map(|_| true));
55    ///         while let Some(mut s) = incoming.next().await.transpose().unwrap() {
56    ///             tokio::spawn(async move {
57    ///                 let (mut r, mut w) = s.split();
58    ///                 println!("copied {} bytes", tokio::io::copy(&mut r, &mut w).await.unwrap());
59    ///             });
60    ///         }
61    ///     });
62    ///
63    ///     // tell the listener to stop accepting new connections
64    ///     tx.send(()).unwrap();
65    ///     // the spawned async block will terminate cleanly, allowing main to return
66    /// }
67    /// ```
68    fn take_until_if<U>(self, until: U) -> TakeUntilIf<Self, U>
69    where
70        U: Future<Output = bool>,
71        Self: Sized,
72    {
73        TakeUntilIf {
74            stream: self,
75            until,
76            free: false,
77        }
78    }
79}
80
81impl<S> StreamExt for S where S: Stream {}
82
83impl<S, F> Stream for TakeUntilIf<S, F>
84where
85    S: Stream,
86    F: Future<Output = bool>,
87{
88    type Item = S::Item;
89
90    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
91        let this = self.project();
92        if !*this.free {
93            if let Poll::Ready(terminate) = this.until.poll(cx) {
94                if terminate {
95                    // future resolved -- terminate stream
96                    return Poll::Ready(None);
97                }
98                // to provide a mechanism for the creator to let the stream run forever,
99                // we interpret this as "run forever".
100                *this.free = true;
101            }
102        }
103
104        this.stream.poll_next(cx)
105    }
106}
107
108/// A `Tripwire` is a convenient mechanism for implementing graceful shutdown over many
109/// asynchronous streams. A `Tripwire` is a `Future` that is `Clone`, and that can be passed to
110/// [`StreamExt::take_until_if`]. All `Tripwire` clones are associated with a single [`Trigger`],
111/// which is then used to signal that all the associated streams should be terminated.
112///
113/// The `Tripwire` future resolves to `true` if the stream should be considered closed, and `false`
114/// if the `Trigger` has been disabled.
115///
116/// `Tripwire` is internally implemented using a `Shared<oneshot::Receiver<()>>`, with the
117/// `Trigger` holding the associated `oneshot::Sender`. There is very little magic.
118#[pin_project]
119pub struct Tripwire {
120    watch: watch::Receiver<bool>,
121
122    // TODO: existential type
123    #[pin]
124    fut: Option<Pin<Box<dyn Future<Output = bool> + Send + Sync>>>,
125}
126
// Compile-time check (test builds only) that `Tripwire` may be moved to and shared between
// threads; this depends on the boxed inner future being declared `Send + Sync`.
#[cfg(test)]
static_assertions::assert_impl_all!(Tripwire: Send, Sync);
129
130impl Clone for Tripwire {
131    fn clone(&self) -> Self {
132        Self {
133            watch: self.watch.clone(),
134            fut: None,
135        }
136    }
137}
138
139impl fmt::Debug for Tripwire {
140    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141        f.debug_tuple("Tripwire").field(&self.watch).finish()
142    }
143}
144
145impl Tripwire {
146    /// Make a new `Tripwire` and an associated [`Trigger`].
147    pub fn new() -> (Trigger, Self) {
148        let (tx, rx) = watch::channel(false);
149        (
150            Trigger(Some(tx)),
151            Tripwire {
152                watch: rx,
153                fut: None,
154            },
155        )
156    }
157}
158
159impl Future for Tripwire {
160    type Output = bool;
161    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
162        let mut this = self.project();
163        if this.fut.is_none() {
164            let mut watch = this.watch.clone();
165            this.fut.set(Some(Box::pin(async move {
166                while !*watch.borrow() {
167                    // value is currently false; wait for it to change
168                    if let Err(_) = watch.changed().await {
169                        // channel was closed -- we return whatever the latest value was
170                        return *watch.borrow();
171                    }
172                }
173                // value change to true, and we should exit
174                true
175            })));
176        }
177
178        // Safety: we never move the value inside the option.
179        // If the Tripwire is pinned, the Option is pinned, and the future inside is as well.
180        unsafe { this.fut.map_unchecked_mut(|f| f.as_mut().unwrap()) }
181            .as_mut()
182            .poll(cx)
183    }
184}
185
186/// Map any Future<Output = Result<T, E>> to a Future<Output = bool>
187///
188/// The output is `true` if the `Result` was `Ok`, and `false` otherwise.
189#[pin_project]
190struct ResultTrueFalse<F>(#[pin] F);
191
192impl<F, T, E> Future for ResultTrueFalse<F>
193where
194    F: Future<Output = Result<T, E>>,
195{
196    type Output = bool;
197    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
198        ready!(self.project().0.poll(cx)).is_ok().into()
199    }
200}