use crate::Trigger;
use futures_core::{ready, stream::Stream};
use pin_project::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::watch;

/// A stream combinator which takes elements from a stream until a future resolves.
///
/// This structure is produced by the [`StreamExt::take_until`] method.
#[pin_project]
#[derive(Clone, Debug)]
pub struct TakeUntil<S, F> {
    #[pin]
    stream: S,
    #[pin]
    until: F,
    free: bool,
}

/// This `Stream` extension trait provides a `take_until` method that terminates the stream once
/// the given future resolves.
pub trait StreamExt: Stream {
    /// Take elements from this stream until the given future resolves.
    ///
    /// This function will take elements from this stream until the given future resolves with
    /// `true`. Once it resolves, the stream will yield `None`, and produce no further elements.
    ///
    /// If the future resolves with `false`, the stream will be allowed to continue indefinitely.
    ///
    /// ```
    /// use stream_cancel::StreamExt;
    /// use futures::prelude::*;
    /// use tokio::prelude::*;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let mut listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
    ///     let (tx, rx) = tokio::sync::oneshot::channel();
    ///
    ///     tokio::spawn(async move {
    ///         let mut incoming = listener.incoming().take_until(rx.map(|_| true));
    ///         while let Some(mut s) = incoming.next().await.transpose().unwrap() {
    ///             tokio::spawn(async move {
    ///                 let (mut r, mut w) = s.split();
    ///                 println!("copied {} bytes", tokio::io::copy(&mut r, &mut w).await.unwrap());
    ///             });
    ///         }
    ///     });
    ///
    ///     // tell the listener to stop accepting new connections
    ///     tx.send(()).unwrap();
    ///     // the spawned async block will terminate cleanly, allowing main to return
    /// }
    /// ```
    fn take_until<U>(self, until: U) -> TakeUntil<Self, U>
    where
        U: Future<Output = bool>,
        Self: Sized,
    {
        TakeUntil {
            stream: self,
            until,
            free: false,
        }
    }
}

impl<S> StreamExt for S where S: Stream {}

impl<S, F> Stream for TakeUntil<S, F>
where
    S: Stream,
    F: Future<Output = bool>,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        if !*this.free {
            if let Poll::Ready(terminate) = this.until.poll(cx) {
                if terminate {
                    // future resolved with `true` -- terminate stream
                    return Poll::Ready(None);
                }
                // the future resolved with `false`, which the creator uses to signal that the
                // stream should run forever, so we stop polling `until` from now on.
                *this.free = true;
            }
        }

        this.stream.poll_next(cx)
    }
}

/// A `Tripwire` is a convenient mechanism for implementing graceful shutdown over many
/// asynchronous streams. A `Tripwire` is a `Future` that is `Clone`, and that can be passed to
/// [`StreamExt::take_until`]. All `Tripwire` clones are associated with a single [`Trigger`],
/// which is then used to signal that all the associated streams should be terminated.
///
/// The `Tripwire` future resolves to `true` if the stream should be considered closed, and `false`
/// if the `Trigger` has been disabled.
///
/// `Tripwire` is internally implemented using a `watch::Receiver<bool>`, with the `Trigger`
/// holding the associated `watch::Sender`. There is very little magic.
#[pin_project]
#[derive(Clone, Debug)]
pub struct Tripwire(#[pin] watch::Receiver<bool>);

impl Tripwire {
    /// Make a new `Tripwire` and an associated [`Trigger`].
    pub fn new() -> (Trigger, Self) {
        let (tx, rx) = watch::channel(false);
        (Trigger(Some(tx)), Tripwire(rx))
    }
}

impl Future for Tripwire {
    type Output = bool;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project().0;
        loop {
            match this.as_mut().poll_recv_ref(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => {
                    // channel was closed -- fall through and return whatever the latest value was
                }
                Poll::Ready(Some(v)) if *v => {
                    // value changed to true, so we should exit
                    return Poll::Ready(true);
                }
                Poll::Ready(Some(_)) => {
                    // value is currently false, so we need to poll again
                    continue;
                }
            }
            return Poll::Ready(*this.borrow());
        }
    }
}

/// Map any `Future<Output = Result<T, E>>` to a `Future<Output = bool>`.
///
/// The output is `true` if the `Result` was `Ok`, and `false` otherwise.
#[pin_project]
struct ResultTrueFalse<F>(#[pin] F);

impl<F, T, E> Future for ResultTrueFalse<F>
where
    F: Future<Output = Result<T, E>>,
{
    type Output = bool;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        ready!(self.project().0.poll(cx)).is_ok().into()
    }
}