stream_cancel/combinator.rs
1use crate::Trigger;
2use futures_core::{ready, stream::Stream};
3use pin_project::pin_project;
4use std::fmt;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tokio::sync::watch;
9
10/// A stream combinator which takes elements from a stream until a future resolves.
11///
12/// This structure is produced by the [`StreamExt::take_until_if`] method.
13#[pin_project]
14#[derive(Clone, Debug)]
15pub struct TakeUntilIf<S, F> {
16 #[pin]
17 stream: S,
18 #[pin]
19 until: F,
20 free: bool,
21}
22
23impl<S, F> TakeUntilIf<S, F> {
24 /// Consumes this combinator, returning the underlying stream.
25 pub fn into_inner(self) -> S {
26 self.stream
27 }
28}
29
30/// This `Stream` extension trait provides a `take_until_if` method that terminates the stream once
31/// the given future resolves.
32pub trait StreamExt: Stream {
33 /// Take elements from this stream until the given future resolves.
34 ///
35 /// This function takes elements from this stream until the given future resolves with
36 /// `true`. Once it resolves, the stream yields `None`, and produces no further elements.
37 ///
38 /// If the future resolves with `false`, the stream is allowed to continue indefinitely.
39 ///
40 /// This method is essentially a wrapper around `futures_util::stream::StreamExt::take_until`
41 /// that ascribes particular semantics to the output of the provided future.
42 ///
43 /// ```
44 /// use stream_cancel::StreamExt;
45 /// use futures::prelude::*;
46 /// use tokio_stream::wrappers::TcpListenerStream;
47 ///
48 /// #[tokio::main]
49 /// async fn main() {
50 /// let mut listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
51 /// let (tx, rx) = tokio::sync::oneshot::channel();
52 ///
53 /// tokio::spawn(async move {
54 /// let mut incoming = TcpListenerStream::new(listener).take_until_if(rx.map(|_| true));
55 /// while let Some(mut s) = incoming.next().await.transpose().unwrap() {
56 /// tokio::spawn(async move {
57 /// let (mut r, mut w) = s.split();
58 /// println!("copied {} bytes", tokio::io::copy(&mut r, &mut w).await.unwrap());
59 /// });
60 /// }
61 /// });
62 ///
63 /// // tell the listener to stop accepting new connections
64 /// tx.send(()).unwrap();
65 /// // the spawned async block will terminate cleanly, allowing main to return
66 /// }
67 /// ```
68 fn take_until_if<U>(self, until: U) -> TakeUntilIf<Self, U>
69 where
70 U: Future<Output = bool>,
71 Self: Sized,
72 {
73 TakeUntilIf {
74 stream: self,
75 until,
76 free: false,
77 }
78 }
79}
80
81impl<S> StreamExt for S where S: Stream {}
82
83impl<S, F> Stream for TakeUntilIf<S, F>
84where
85 S: Stream,
86 F: Future<Output = bool>,
87{
88 type Item = S::Item;
89
90 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
91 let this = self.project();
92 if !*this.free {
93 if let Poll::Ready(terminate) = this.until.poll(cx) {
94 if terminate {
95 // future resolved -- terminate stream
96 return Poll::Ready(None);
97 }
98 // to provide a mechanism for the creator to let the stream run forever,
99 // we interpret this as "run forever".
100 *this.free = true;
101 }
102 }
103
104 this.stream.poll_next(cx)
105 }
106}
107
108/// A `Tripwire` is a convenient mechanism for implementing graceful shutdown over many
109/// asynchronous streams. A `Tripwire` is a `Future` that is `Clone`, and that can be passed to
110/// [`StreamExt::take_until_if`]. All `Tripwire` clones are associated with a single [`Trigger`],
111/// which is then used to signal that all the associated streams should be terminated.
112///
113/// The `Tripwire` future resolves to `true` if the stream should be considered closed, and `false`
114/// if the `Trigger` has been disabled.
115///
116/// `Tripwire` is internally implemented using a `Shared<oneshot::Receiver<()>>`, with the
117/// `Trigger` holding the associated `oneshot::Sender`. There is very little magic.
118#[pin_project]
119pub struct Tripwire {
120 watch: watch::Receiver<bool>,
121
122 // TODO: existential type
123 #[pin]
124 fut: Option<Pin<Box<dyn Future<Output = bool> + Send + Sync>>>,
125}
126
// Compile-time check (test builds only) that `Tripwire` may be sent to and shared between threads.
#[cfg(test)]
static_assertions::assert_impl_all!(Tripwire: Send, Sync);
129
130impl Clone for Tripwire {
131 fn clone(&self) -> Self {
132 Self {
133 watch: self.watch.clone(),
134 fut: None,
135 }
136 }
137}
138
139impl fmt::Debug for Tripwire {
140 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141 f.debug_tuple("Tripwire").field(&self.watch).finish()
142 }
143}
144
145impl Tripwire {
146 /// Make a new `Tripwire` and an associated [`Trigger`].
147 pub fn new() -> (Trigger, Self) {
148 let (tx, rx) = watch::channel(false);
149 (
150 Trigger(Some(tx)),
151 Tripwire {
152 watch: rx,
153 fut: None,
154 },
155 )
156 }
157}
158
159impl Future for Tripwire {
160 type Output = bool;
161 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
162 let mut this = self.project();
163 if this.fut.is_none() {
164 let mut watch = this.watch.clone();
165 this.fut.set(Some(Box::pin(async move {
166 while !*watch.borrow() {
167 // value is currently false; wait for it to change
168 if let Err(_) = watch.changed().await {
169 // channel was closed -- we return whatever the latest value was
170 return *watch.borrow();
171 }
172 }
173 // value change to true, and we should exit
174 true
175 })));
176 }
177
178 // Safety: we never move the value inside the option.
179 // If the Tripwire is pinned, the Option is pinned, and the future inside is as well.
180 unsafe { this.fut.map_unchecked_mut(|f| f.as_mut().unwrap()) }
181 .as_mut()
182 .poll(cx)
183 }
184}
185
186/// Map any Future<Output = Result<T, E>> to a Future<Output = bool>
187///
188/// The output is `true` if the `Result` was `Ok`, and `false` otherwise.
189#[pin_project]
190struct ResultTrueFalse<F>(#[pin] F);
191
192impl<F, T, E> Future for ResultTrueFalse<F>
193where
194 F: Future<Output = Result<T, E>>,
195{
196 type Output = bool;
197 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
198 ready!(self.project().0.poll(cx)).is_ok().into()
199 }
200}