stream_cancel/lib.rs
1//! This crate provides multiple mechanisms for interrupting a `Stream`.
2//!
3//! # Stream combinator
4//!
5//! The extension trait [`StreamExt`] provides a single new `Stream` combinator: `take_until_if`.
6//! [`StreamExt::take_until_if`] continues yielding elements from the underlying `Stream` until a
7//! `Future` resolves, and at that moment immediately yields `None` and stops producing further
8//! elements.
9//!
10//! For convenience, the crate also includes the [`Tripwire`] type, which produces a cloneable
11//! `Future` that can then be passed to `take_until_if`. When a new `Tripwire` is created, an
12//! associated [`Trigger`] is also returned, which interrupts the `Stream` when it is dropped.
13//!
14//!
15//! ```
16//! use stream_cancel::{StreamExt, Tripwire};
17//! use futures::prelude::*;
18//! use tokio_stream::wrappers::TcpListenerStream;
19//!
20//! #[tokio::main]
21//! async fn main() {
22//! let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
23//! let (trigger, tripwire) = Tripwire::new();
24//!
25//! tokio::spawn(async move {
26//! let mut incoming = TcpListenerStream::new(listener).take_until_if(tripwire);
27//! while let Some(mut s) = incoming.next().await.transpose().unwrap() {
28//! tokio::spawn(async move {
29//! let (mut r, mut w) = s.split();
30//! println!("copied {} bytes", tokio::io::copy(&mut r, &mut w).await.unwrap());
31//! });
32//! }
33//! });
34//!
35//! // tell the listener to stop accepting new connections
36//! drop(trigger);
//! // the `take_until_if` stream now yields `None`, so the spawned async block terminates cleanly
38//! }
39//! ```
40//!
41//! # Stream wrapper
42//!
43//! Any stream can be wrapped in a [`Valved`], which enables it to be remotely terminated through
44//! an associated [`Trigger`]. This can be useful to implement graceful shutdown on "infinite"
//! streams like a `TcpListener`. Once the [`Trigger`] associated with a given [`Valved`] fires
//! (either through [`Trigger::cancel`] or simply by being dropped), the stream will yield `None`
//! to indicate that it has terminated.
47//!
48//! ```
49//! use stream_cancel::Valved;
50//! use futures::prelude::*;
51//! use tokio_stream::wrappers::TcpListenerStream;
52//! use std::thread;
53//!
54//! #[tokio::main]
55//! async fn main() {
56//! let (exit_tx, exit_rx) = tokio::sync::oneshot::channel();
57//! let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
58//!
59//! tokio::spawn(async move {
60//! let (exit, mut incoming) = Valved::new(TcpListenerStream::new(listener));
61//! exit_tx.send(exit).unwrap();
62//! while let Some(mut s) = incoming.next().await.transpose().unwrap() {
63//! tokio::spawn(async move {
64//! let (mut r, mut w) = s.split();
65//! println!("copied {} bytes", tokio::io::copy(&mut r, &mut w).await.unwrap());
66//! });
67//! }
68//! });
69//!
70//! let exit = exit_rx.await;
71//!
//! // the server task will normally never exit, since more connections
//! // can always arrive. However, with a Valved, we can turn off the
//! // stream of incoming connections to initiate a graceful shutdown
75//! drop(exit);
76//! }
77//! ```
78//!
79//! You can share the same [`Trigger`] between multiple streams by first creating a [`Valve`],
//! and then wrapping multiple streams using [`Valve::wrap`]:
81//!
82//! ```
83//! use stream_cancel::Valve;
84//! use futures::prelude::*;
85//! use tokio_stream::wrappers::TcpListenerStream;
86//!
87//! #[tokio::main]
88//! async fn main() {
89//! let (exit, valve) = Valve::new();
90//! let listener1 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
91//! let listener2 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
92//!
93//! tokio::spawn(async move {
94//! let incoming1 = valve.wrap(TcpListenerStream::new(listener1));
95//! let incoming2 = valve.wrap(TcpListenerStream::new(listener2));
96//!
97//! use futures_util::stream::select;
98//! let mut incoming = select(incoming1, incoming2);
99//! while let Some(mut s) = incoming.next().await.transpose().unwrap() {
100//! tokio::spawn(async move {
101//! let (mut r, mut w) = s.split();
102//! println!("copied {} bytes", tokio::io::copy(&mut r, &mut w).await.unwrap());
103//! });
104//! }
105//! });
106//!
//! // closing the valve interrupts both incoming1 and incoming2; since the select only
//! // ends once both of its inputs have ended, this lets the spawned task finish cleanly.
110//! drop(exit);
111//! }
112//! ```
113
114#![deny(missing_docs)]
115#![warn(rust_2018_idioms)]
116
117use tokio::sync::watch;
118
119mod combinator;
120mod wrapper;
121
122pub use crate::combinator::{StreamExt, TakeUntilIf, Tripwire};
123pub use crate::wrapper::{Valve, Valved};
124
/// A handle to a set of cancellable streams.
///
/// If the `Trigger` is dropped, any streams associated with it are interrupted (this is equivalent
/// to calling [`Trigger::cancel`]). To override this behavior, call [`Trigger::disable`].
// Implementation note: the field is the sending half of a `watch` channel. Both `disable` and
// `Drop` take it out (leaving `None`); `Drop` sends `true` through it if it was still present.
#[derive(Debug)]
pub struct Trigger(Option<watch::Sender<bool>>);
131
132impl Trigger {
133 /// Cancel all associated streams, and make them immediately yield `None`.
134 pub fn cancel(self) {
135 drop(self);
136 }
137
138 /// Disable the `Trigger`, and leave all associated streams running to completion.
139 pub fn disable(mut self) {
140 let _ = self.0.take();
141 drop(self);
142 }
143}
144
145impl Drop for Trigger {
146 fn drop(&mut self) {
147 if let Some(tx) = self.0.take() {
148 // Send may fail when all associated rx'es are dropped already
149 // so code here cannot panic on error
150 let _ = tx.send(true);
151 }
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use futures::prelude::*;
    use futures_util::stream::select;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;
    use tokio_stream::wrappers::TcpListenerStream;

    /// Spawns a task that echoes everything read from `s` back to it.
    ///
    /// Uses `tokio::spawn`, so it must be called from within a Tokio runtime context.
    fn spawn_echo(mut s: TcpStream) {
        tokio::spawn(async move {
            let (mut r, mut w) = s.split();
            tokio::io::copy(&mut r, &mut w).await.unwrap();
        });
    }

    /// Connects to the echo server at `addr`, sends `msg`, and asserts that it comes back
    /// unchanged. The connection is closed when this function returns.
    async fn assert_echoes(addr: std::net::SocketAddr, msg: &[u8]) {
        let mut s = TcpStream::connect(addr).await.unwrap();
        s.write_all(msg).await.unwrap();
        let mut buf = vec![0; msg.len()];
        s.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, msg);
    }

    #[test]
    fn tokio_run() {
        use std::thread;

        let rt = tokio::runtime::Runtime::new().unwrap();
        let listener = rt
            .block_on(tokio::net::TcpListener::bind("127.0.0.1:0"))
            .unwrap();
        let (exit_tx, exit_rx) = tokio::sync::oneshot::channel();
        let server = thread::spawn(move || {
            let (tx, rx) = tokio::sync::oneshot::channel();

            // start a tokio echo server
            rt.block_on(async move {
                let (exit, mut incoming) = Valved::new(TcpListenerStream::new(listener));
                exit_tx.send(exit).unwrap();
                while let Some(s) = incoming.next().await.transpose().unwrap() {
                    spawn_echo(s);
                }
                tx.send(()).unwrap();
            });
            rt.block_on(rx).unwrap();
        });

        let exit = futures::executor::block_on(exit_rx);

        // the server thread will normally never exit, since more connections
        // can always arrive. however, with a Valved, we can turn off the
        // stream of incoming connections to initiate a graceful shutdown
        drop(exit);
        server.join().unwrap();
    }

    #[tokio::test]
    async fn tokio_rt_on_idle() {
        let (exit_tx, exit_rx) = tokio::sync::oneshot::channel();

        let server = tokio::spawn(async move {
            let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            let (exit, mut incoming) = Valved::new(TcpListenerStream::new(listener));
            exit_tx.send(exit).unwrap();
            while let Some(s) = incoming.next().await.transpose().unwrap() {
                spawn_echo(s);
            }
        });

        let exit = exit_rx.await;
        drop(exit);

        // When this test body returns, `#[tokio::test]` drops the runtime, which would silently
        // cancel the server task. Wait for it explicitly: it can only finish if dropping the
        // trigger really interrupted `incoming`.
        server.await.unwrap();
    }

    #[tokio::test]
    async fn multi_interrupt() {
        let (exit, valve) = Valve::new();
        let server = tokio::spawn(async move {
            let listener1 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            let listener2 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            let incoming1 = valve.wrap(TcpListenerStream::new(listener1));
            let incoming2 = valve.wrap(TcpListenerStream::new(listener2));

            let mut incoming = select(incoming1, incoming2);
            while let Some(s) = incoming.next().await.transpose().unwrap() {
                spawn_echo(s);
            }
        });

        // The select only ends once both incoming1 and incoming2 have ended, so the server task
        // can only finish if closing the valve interrupts both of them.
        drop(exit);
        server.await.unwrap();
    }

    #[tokio::test]
    async fn yields_many() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        let (exit, valve) = Valve::new();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let reqs = Arc::new(AtomicUsize::new(0));
        let got = reqs.clone();
        let server = tokio::spawn(async move {
            let mut incoming = valve.wrap(TcpListenerStream::new(listener));
            while let Some(s) = incoming.next().await.transpose().unwrap() {
                reqs.fetch_add(1, Ordering::SeqCst);
                spawn_echo(s);
            }
        });

        assert_echoes(addr, b"hello").await;
        assert_echoes(addr, b"world").await;

        assert_eq!(got.load(Ordering::SeqCst), 2);

        // closing the valve must also stop a stream that has already yielded items
        drop(exit);
        server.await.unwrap();
    }

    #[tokio::test]
    async fn yields_some() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        let (exit, valve) = Valve::new();
        let listener1 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let listener2 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr1 = listener1.local_addr().unwrap();
        let addr2 = listener2.local_addr().unwrap();

        let reqs = Arc::new(AtomicUsize::new(0));
        let got = reqs.clone();

        let server = tokio::spawn(async move {
            let incoming1 = valve.wrap(TcpListenerStream::new(listener1));
            let incoming2 = valve.wrap(TcpListenerStream::new(listener2));
            let mut incoming = select(incoming1, incoming2);
            while let Some(s) = incoming.next().await.transpose().unwrap() {
                reqs.fetch_add(1, Ordering::SeqCst);
                spawn_echo(s);
            }
        });

        assert_echoes(addr1, b"hello").await;
        assert_echoes(addr2, b"world").await;

        assert_eq!(got.load(Ordering::SeqCst), 2);

        // both wrapped streams (and hence the select) must end once the valve is closed
        drop(exit);
        server.await.unwrap();
    }
}