// hyper_util/server/graceful.rs

//! Utility to gracefully shut down a server.
//!
//! This module provides a [`GracefulShutdown`] type,
//! which can be used to gracefully shut down a server.
//!
//! See <https://github.com/hyperium/hyper-util/blob/master/examples/server_graceful.rs>
//! for an example of how to use this.

use std::{
    fmt::{self, Debug},
    future::Future,
    pin::Pin,
    task::{self, Poll},
};

use pin_project_lite::pin_project;
use tokio::sync::watch;

19/// A graceful shutdown utility
20pub struct GracefulShutdown {
21    tx: watch::Sender<()>,
22}
23
24impl GracefulShutdown {
25    /// Create a new graceful shutdown helper.
26    pub fn new() -> Self {
27        let (tx, _) = watch::channel(());
28        Self { tx }
29    }
30
31    /// Wrap a future for graceful shutdown watching.
32    pub fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
33        let mut rx = self.tx.subscribe();
34        GracefulConnectionFuture::new(conn, async move {
35            let _ = rx.changed().await;
36            // hold onto the rx until the watched future is completed
37            rx
38        })
39    }
40
41    /// Signal shutdown for all watched connections.
42    ///
43    /// This returns a `Future` which will complete once all watched
44    /// connections have shutdown.
45    pub async fn shutdown(self) {
46        let Self { tx } = self;
47
48        // signal all the watched futures about the change
49        let _ = tx.send(());
50        // and then wait for all of them to complete
51        tx.closed().await;
52    }
53}
54
55impl Debug for GracefulShutdown {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        f.debug_struct("GracefulShutdown").finish()
58    }
59}
60
61impl Default for GracefulShutdown {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67pin_project! {
68    struct GracefulConnectionFuture<C, F: Future> {
69        #[pin]
70        conn: C,
71        #[pin]
72        cancel: F,
73        #[pin]
74        // If cancelled, this is held until the inner conn is done.
75        cancelled_guard: Option<F::Output>,
76    }
77}
78
79impl<C, F: Future> GracefulConnectionFuture<C, F> {
80    fn new(conn: C, cancel: F) -> Self {
81        Self {
82            conn,
83            cancel,
84            cancelled_guard: None,
85        }
86    }
87}
88
89impl<C, F: Future> Debug for GracefulConnectionFuture<C, F> {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91        f.debug_struct("GracefulConnectionFuture").finish()
92    }
93}
94
95impl<C, F> Future for GracefulConnectionFuture<C, F>
96where
97    C: GracefulConnection,
98    F: Future,
99{
100    type Output = C::Output;
101
102    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
103        let mut this = self.project();
104        if this.cancelled_guard.is_none() {
105            if let Poll::Ready(guard) = this.cancel.poll(cx) {
106                this.cancelled_guard.set(Some(guard));
107                this.conn.as_mut().graceful_shutdown();
108            }
109        }
110        this.conn.poll(cx)
111    }
112}
113
114/// An internal utility trait as an umbrella target for all (hyper) connection
115/// types that the [`GracefulShutdown`] can watch.
116pub trait GracefulConnection: Future<Output = Result<(), Self::Error>> + private::Sealed {
117    /// The error type returned by the connection when used as a future.
118    type Error;
119
120    /// Start a graceful shutdown process for this connection.
121    fn graceful_shutdown(self: Pin<&mut Self>);
122}
123
/// Lets [`GracefulShutdown`] watch hyper HTTP/1 server connections.
#[cfg(feature = "http1")]
impl<I, B, S> GracefulConnection for hyper::server::conn::http1::Connection<I, S>
where
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Calls hyper's inherent method. Inherent methods take precedence over
        // this trait method in path resolution, so this does not recurse.
        hyper::server::conn::http1::Connection::graceful_shutdown(self);
    }
}

/// Lets [`GracefulShutdown`] watch hyper HTTP/2 server connections.
#[cfg(feature = "http2")]
impl<I, B, S, E> GracefulConnection for hyper::server::conn::http2::Connection<I, S, E>
where
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Calls hyper's inherent method. Inherent methods take precedence over
        // this trait method in path resolution, so this does not recurse.
        hyper::server::conn::http2::Connection::graceful_shutdown(self);
    }
}

/// Lets [`GracefulShutdown`] watch protocol-auto-detecting connections.
#[cfg(feature = "server-auto")]
impl<I, B, S, E> GracefulConnection for crate::server::conn::auto::Connection<'_, I, S, E>
where
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Calls the inherent method on the auto connection. Inherent methods
        // take precedence over this trait method, so this does not recurse.
        crate::server::conn::auto::Connection::graceful_shutdown(self);
    }
}

/// Lets [`GracefulShutdown`] watch protocol-auto-detecting connections that
/// support HTTP upgrades.
#[cfg(feature = "server-auto")]
impl<I, B, S, E> GracefulConnection
    for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
where
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Calls the inherent method on the upgradeable connection. Inherent
        // methods take precedence over this trait method, so this does not recurse.
        crate::server::conn::auto::UpgradeableConnection::graceful_shutdown(self);
    }
}

// Sealing support for `GracefulConnection`: this module is not exported, so
// code outside the crate cannot name `private::Sealed` and therefore cannot
// implement `GracefulConnection`.
mod private {
    pub trait Sealed {}

    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::Connection<I, S>
    where
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
    }

    // NOTE(review): this type is sealed, but no `GracefulConnection` impl for
    // `http1::UpgradeableConnection` appears in this file, so it cannot be
    // passed to `GracefulShutdown::watch`. Confirm whether an impl is intended.
    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::UpgradeableConnection<I, S>
    where
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
    }

    #[cfg(feature = "http2")]
    impl<I, B, S, E> Sealed for hyper::server::conn::http2::Connection<I, S, E>
    where
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<I, B, S, E> Sealed for crate::server::conn::auto::Connection<'_, I, S, E>
    where
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S::Future: 'static,
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<I, B, S, E> Sealed for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
    where
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S::Future: 'static,
        I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use pin_project_lite::pin_project;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    pin_project! {
        /// Test double: completes when `future` does and counts how many
        /// times `graceful_shutdown` was requested.
        #[derive(Debug)]
        struct DummyConnection<F> {
            #[pin]
            future: F,
            shutdown_counter: Arc<AtomicUsize>,
        }
    }

    impl<F> private::Sealed for DummyConnection<F> {}

    impl<F: Future> GracefulConnection for DummyConnection<F> {
        type Error = ();

        fn graceful_shutdown(self: Pin<&mut Self>) {
            self.shutdown_counter.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl<F: Future> Future for DummyConnection<F> {
        type Output = Result<(), ()>;

        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
            // Discard the inner output; the dummy connection never fails.
            self.project().future.poll(cx).map(|_| Ok(()))
        }
    }

    /// All connections finish shortly after shutdown is signalled.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));
        let (dummy_tx, _) = tokio::sync::broadcast::channel(1);

        for i in 1..=3 {
            let mut dummy_rx = dummy_tx.subscribe();
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                tokio::time::sleep(std::time::Duration::from_millis(i * 10)).await;
                let _ = dummy_rx.recv().await;
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);
        let _ = dummy_tx.send(());

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            }
        }
    }

    /// Shutdown waits for connections that take a while to finish.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_delayed_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                tokio::time::sleep(std::time::Duration::from_millis(i * 50)).await;
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            }
        }
    }

    /// Several watched connections driven from the same task are all signalled.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_multi_per_watcher_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let mut futures = Vec::new();
            for u in 1..=i {
                let future = tokio::time::sleep(std::time::Duration::from_millis(u * 50));
                let dummy_conn = DummyConnection {
                    future,
                    shutdown_counter: shutdown_counter.clone(),
                };
                let conn = graceful.watch(dummy_conn);
                futures.push(conn);
            }
            tokio::spawn(async move {
                futures_util::future::join_all(futures).await;
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 6);
            }
        }
    }

    /// A connection that never finishes keeps `shutdown` pending.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_timeout() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                if i == 1 {
                    std::future::pending::<()>().await
                } else {
                    std::future::ready(()).await
                }
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            },
            _ = graceful.shutdown() => {
                panic!("shutdown should not be completed: as not all our conns finish")
            }
        }
    }
}