stream_cancel/
lib.rs

1//! This crate provides multiple mechanisms for interrupting a `Stream`.
2//!
3//! # Stream combinator
4//!
5//! The extension trait [`StreamExt`] provides a single new `Stream` combinator: `take_until_if`.
6//! [`StreamExt::take_until_if`] continues yielding elements from the underlying `Stream` until a
7//! `Future` resolves, and at that moment immediately yields `None` and stops producing further
8//! elements.
9//!
10//! For convenience, the crate also includes the [`Tripwire`] type, which produces a cloneable
11//! `Future` that can then be passed to `take_until_if`. When a new `Tripwire` is created, an
12//! associated [`Trigger`] is also returned, which interrupts the `Stream` when it is dropped.
13//!
14//!
15//! ```
16//! use stream_cancel::{StreamExt, Tripwire};
17//! use futures::prelude::*;
18//! use tokio_stream::wrappers::TcpListenerStream;
19//!
20//! #[tokio::main]
21//! async fn main() {
22//!     let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
23//!     let (trigger, tripwire) = Tripwire::new();
24//!
25//!     tokio::spawn(async move {
26//!         let mut incoming = TcpListenerStream::new(listener).take_until_if(tripwire);
27//!         while let Some(mut s) = incoming.next().await.transpose().unwrap() {
28//!             tokio::spawn(async move {
29//!                 let (mut r, mut w) = s.split();
30//!                 println!("copied {} bytes", tokio::io::copy(&mut r, &mut w).await.unwrap());
31//!             });
32//!         }
33//!     });
34//!
35//!     // tell the listener to stop accepting new connections
36//!     drop(trigger);
37//!     // the spawned async block will terminate cleanly, allowing main to return
38//! }
39//! ```
40//!
41//! # Stream wrapper
42//!
43//! Any stream can be wrapped in a [`Valved`], which enables it to be remotely terminated through
44//! an associated [`Trigger`]. This can be useful to implement graceful shutdown on "infinite"
45//! streams like a `TcpListener`. Once [`Trigger::cancel`] is called on the handle for a given
46//! stream's [`Valved`], the stream will yield `None` to indicate that it has terminated.
47//!
48//! ```
49//! use stream_cancel::Valved;
50//! use futures::prelude::*;
51//! use tokio_stream::wrappers::TcpListenerStream;
52//! use std::thread;
53//!
54//! #[tokio::main]
55//! async fn main() {
56//!     let (exit_tx, exit_rx) = tokio::sync::oneshot::channel();
57//!     let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
58//!
59//!     tokio::spawn(async move {
60//!         let (exit, mut incoming) = Valved::new(TcpListenerStream::new(listener));
61//!         exit_tx.send(exit).unwrap();
62//!         while let Some(mut s) = incoming.next().await.transpose().unwrap() {
63//!             tokio::spawn(async move {
64//!                 let (mut r, mut w) = s.split();
65//!                 println!("copied {} bytes", tokio::io::copy(&mut r, &mut w).await.unwrap());
66//!             });
67//!         }
68//!     });
69//!
70//!     let exit = exit_rx.await;
71//!
//!     // the server task will normally never exit, since more connections
//!     // can always arrive. however, with a Valved, we can turn off the
//!     // stream of incoming connections to initiate a graceful shutdown
75//!     drop(exit);
76//! }
77//! ```
78//!
//! You can share the same [`Trigger`] between multiple streams by first creating a [`Valve`],
//! and then wrapping multiple streams using [`Valve::wrap`]:
81//!
82//! ```
83//! use stream_cancel::Valve;
84//! use futures::prelude::*;
85//! use tokio_stream::wrappers::TcpListenerStream;
86//!
87//! #[tokio::main]
88//! async fn main() {
89//!     let (exit, valve) = Valve::new();
90//!     let listener1 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
91//!     let listener2 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
92//!
93//!     tokio::spawn(async move {
94//!         let incoming1 = valve.wrap(TcpListenerStream::new(listener1));
95//!         let incoming2 = valve.wrap(TcpListenerStream::new(listener2));
96//!
97//!         use futures_util::stream::select;
98//!         let mut incoming = select(incoming1, incoming2);
99//!         while let Some(mut s) = incoming.next().await.transpose().unwrap() {
100//!             tokio::spawn(async move {
101//!                 let (mut r, mut w) = s.split();
102//!                 println!("copied {} bytes", tokio::io::copy(&mut r, &mut w).await.unwrap());
103//!             });
104//!         }
105//!     });
106//!
107//!     // the runtime will not become idle until both incoming1 and incoming2 have stopped
108//!     // (due to the select). this checks that they are indeed both interrupted when the
109//!     // valve is closed.
110//!     drop(exit);
111//! }
112//! ```
113
114#![deny(missing_docs)]
115#![warn(rust_2018_idioms)]
116
117use tokio::sync::watch;
118
119mod combinator;
120mod wrapper;
121
122pub use crate::combinator::{StreamExt, TakeUntilIf, Tripwire};
123pub use crate::wrapper::{Valve, Valved};
124
125/// A handle to a set of cancellable streams.
126///
127/// If the `Trigger` is dropped, any streams associated with it are interrupted (this is equivalent
128/// to calling [`Trigger::cancel`]. To override this behavior, call [`Trigger::disable`].
129#[derive(Debug)]
130pub struct Trigger(Option<watch::Sender<bool>>);
131
132impl Trigger {
133    /// Cancel all associated streams, and make them immediately yield `None`.
134    pub fn cancel(self) {
135        drop(self);
136    }
137
138    /// Disable the `Trigger`, and leave all associated streams running to completion.
139    pub fn disable(mut self) {
140        let _ = self.0.take();
141        drop(self);
142    }
143}
144
145impl Drop for Trigger {
146    fn drop(&mut self) {
147        if let Some(tx) = self.0.take() {
148            // Send may fail when all associated rx'es are dropped already
149            // so code here cannot panic on error
150            let _ = tx.send(true);
151        }
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use futures::prelude::*;
    use futures_util::stream::select;
    use std::net::SocketAddr;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener, TcpStream};
    use tokio_stream::wrappers::TcpListenerStream;

    /// Binds a listener on an OS-assigned loopback port, panicking if binding fails.
    async fn bind_local() -> TcpListener {
        TcpListener::bind("127.0.0.1:0").await.unwrap()
    }

    /// Copies everything read from `conn` back into it, panicking if the copy fails.
    async fn echo(mut conn: TcpStream) {
        let (mut reader, mut writer) = conn.split();
        tokio::io::copy(&mut reader, &mut writer).await.unwrap();
    }

    /// Connects to `addr`, writes `payload`, and asserts that exactly the same bytes are read
    /// back. The connection is closed when this function returns.
    async fn assert_echoed(addr: SocketAddr, payload: &[u8]) {
        let mut client = TcpStream::connect(addr).await.unwrap();
        client.write_all(payload).await.unwrap();
        let mut reply = vec![0u8; payload.len()];
        client.read_exact(&mut reply).await.unwrap();
        assert_eq!(reply.as_slice(), payload);
    }

    /// Runs an echo server on a dedicated thread and checks that dropping the `Trigger`
    /// lets the accept loop finish so that the thread can be joined.
    #[test]
    fn tokio_run() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let listener = rt.block_on(bind_local());
        let (trigger_tx, trigger_rx) = tokio::sync::oneshot::channel();

        let server = std::thread::spawn(move || {
            let (done_tx, done_rx) = tokio::sync::oneshot::channel();

            // start a tokio echo server
            rt.block_on(async move {
                let (trigger, mut incoming) = Valved::new(TcpListenerStream::new(listener));
                trigger_tx.send(trigger).unwrap();
                while let Some(conn) = incoming.next().await.transpose().unwrap() {
                    tokio::spawn(echo(conn));
                }
                // only reached once the valved stream has ended
                done_tx.send(()).unwrap();
            });
            rt.block_on(done_rx).unwrap();
        });

        let trigger = futures::executor::block_on(trigger_rx);

        // the server thread will normally never exit, since more connections
        // can always arrive. however, with a Valved, we can turn off the
        // stream of incoming connections to initiate a graceful shutdown
        drop(trigger);
        server.join().unwrap();
    }

    /// Spawns an echo server task, receives its `Trigger`, and drops it.
    #[tokio::test]
    async fn tokio_rt_on_idle() {
        let (trigger_tx, trigger_rx) = tokio::sync::oneshot::channel();

        tokio::spawn(async move {
            let listener = bind_local().await;
            let (trigger, mut incoming) = Valved::new(TcpListenerStream::new(listener));
            trigger_tx.send(trigger).unwrap();
            while let Some(conn) = incoming.next().await.transpose().unwrap() {
                tokio::spawn(echo(conn));
            }
        });

        // whatever arrives on the channel (trigger or receive error) is dropped right away
        drop(trigger_rx.await);
    }

    /// Wraps two listeners with the same `Valve`, merges them, and drops the shared `Trigger`.
    #[tokio::test]
    async fn multi_interrupt() {
        let (trigger, valve) = Valve::new();
        tokio::spawn(async move {
            let first = valve.wrap(TcpListenerStream::new(bind_local().await));
            let second = valve.wrap(TcpListenerStream::new(bind_local().await));

            let mut incoming = select(first, second);
            while let Some(conn) = incoming.next().await.transpose().unwrap() {
                tokio::spawn(echo(conn));
            }
        });

        // the runtime will not become idle until both wrapped streams have stopped (due to
        // the select). this checks that they are indeed both interrupted when the valve is
        // closed.
        drop(trigger);
    }

    /// A single valved listener must keep yielding connections while the `Trigger` is alive.
    #[tokio::test]
    async fn yields_many() {
        let (trigger, valve) = Valve::new();
        let listener = bind_local().await;
        let addr = listener.local_addr().unwrap();

        // `accepted` is bumped by the server task; `observed` is the test's view of it
        let accepted = Arc::new(AtomicUsize::new(0));
        let observed = Arc::clone(&accepted);
        tokio::spawn(async move {
            let mut incoming = valve.wrap(TcpListenerStream::new(listener));
            while let Some(conn) = incoming.next().await.transpose().unwrap() {
                accepted.fetch_add(1, Ordering::SeqCst);
                tokio::spawn(echo(conn));
            }
        });

        assert_echoed(addr, b"hello").await;
        assert_echoed(addr, b"world").await;

        assert_eq!(observed.load(Ordering::SeqCst), 2);

        drop(trigger);
    }

    /// Two valved listeners merged with `select` must each yield connections while the shared
    /// `Trigger` is alive.
    #[tokio::test]
    async fn yields_some() {
        let (trigger, valve) = Valve::new();
        let listener1 = bind_local().await;
        let listener2 = bind_local().await;
        let addr1 = listener1.local_addr().unwrap();
        let addr2 = listener2.local_addr().unwrap();

        // `accepted` is bumped by the server task; `observed` is the test's view of it
        let accepted = Arc::new(AtomicUsize::new(0));
        let observed = Arc::clone(&accepted);

        tokio::spawn(async move {
            let first = valve.wrap(TcpListenerStream::new(listener1));
            let second = valve.wrap(TcpListenerStream::new(listener2));
            let mut incoming = select(first, second);
            while let Some(conn) = incoming.next().await.transpose().unwrap() {
                accepted.fetch_add(1, Ordering::SeqCst);
                tokio::spawn(echo(conn));
            }
        });

        assert_echoed(addr1, b"hello").await;
        assert_echoed(addr2, b"world").await;

        assert_eq!(observed.load(Ordering::SeqCst), 2);

        drop(trigger);
    }
}