1use std::{
10 fmt::{self, Debug},
11 future::Future,
12 pin::Pin,
13 task::{self, Poll},
14};
15
16use pin_project_lite::pin_project;
17use tokio::sync::watch;
18
/// A graceful shutdown coordinator for a set of server connections.
///
/// Connections are registered with [`GracefulShutdown::watch`]. Calling
/// [`GracefulShutdown::shutdown`] asks every watched connection to shut down
/// gracefully and then waits until all of their watch futures have been dropped.
pub struct GracefulShutdown {
    // Sending half of the shutdown signal. `watch` subscribes one receiver per
    // connection; `shutdown` sends on this sender and then waits for `closed()`,
    // which completes once every such receiver has been dropped.
    tx: watch::Sender<()>,
}
23
24impl GracefulShutdown {
25 pub fn new() -> Self {
27 let (tx, _) = watch::channel(());
28 Self { tx }
29 }
30
31 pub fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
33 let mut rx = self.tx.subscribe();
34 GracefulConnectionFuture::new(conn, async move {
35 let _ = rx.changed().await;
36 rx
38 })
39 }
40
41 pub async fn shutdown(self) {
46 let Self { tx } = self;
47
48 let _ = tx.send(());
50 tx.closed().await;
52 }
53}
54
55impl Debug for GracefulShutdown {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 f.debug_struct("GracefulShutdown").finish()
58 }
59}
60
61impl Default for GracefulShutdown {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
pin_project! {
    /// Future returned by [`GracefulShutdown::watch`]: drives `conn` to
    /// completion and starts its graceful shutdown once `cancel` resolves.
    struct GracefulConnectionFuture<C, F: Future> {
        // The watched connection; its output is this future's output.
        #[pin]
        conn: C,
        // Resolves when shutdown is requested; polled only until it yields a value.
        #[pin]
        cancel: F,
        // Value produced by `cancel`, kept alive until this future is dropped.
        // Fields drop in declaration order, so `conn` is dropped before this guard.
        #[pin]
        cancelled_guard: Option<F::Output>,
    }
}
78
79impl<C, F: Future> GracefulConnectionFuture<C, F> {
80 fn new(conn: C, cancel: F) -> Self {
81 Self {
82 conn,
83 cancel,
84 cancelled_guard: None,
85 }
86 }
87}
88
89impl<C, F: Future> Debug for GracefulConnectionFuture<C, F> {
90 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91 f.debug_struct("GracefulConnectionFuture").finish()
92 }
93}
94
95impl<C, F> Future for GracefulConnectionFuture<C, F>
96where
97 C: GracefulConnection,
98 F: Future,
99{
100 type Output = C::Output;
101
102 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
103 let mut this = self.project();
104 if this.cancelled_guard.is_none() {
105 if let Poll::Ready(guard) = this.cancel.poll(cx) {
106 this.cancelled_guard.set(Some(guard));
107 this.conn.as_mut().graceful_shutdown();
108 }
109 }
110 this.conn.poll(cx)
111 }
112}
113
/// A connection that can be asked to shut down gracefully.
///
/// Implemented for hyper's server connection types so they can be passed to
/// [`GracefulShutdown::watch`]. The trait is sealed through `private::Sealed`,
/// which cannot be named outside this module.
pub trait GracefulConnection: Future<Output = Result<(), Self::Error>> + private::Sealed {
    /// Error the connection resolves with when it does not finish cleanly.
    type Error;

    /// Starts a graceful shutdown of the connection.
    ///
    /// The connection is expected to keep being polled until it completes.
    fn graceful_shutdown(self: Pin<&mut Self>);
}
123
/// Lets an HTTP/1 connection be watched by [`GracefulShutdown`].
#[cfg(feature = "http1")]
impl<I, B, S> GracefulConnection for hyper::server::conn::http1::Connection<I, S>
where
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Delegates to hyper's inherent method; inherent items take precedence
        // over this trait's method in path resolution.
        hyper::server::conn::http1::Connection::graceful_shutdown(self);
    }
}

/// Lets an HTTP/1 connection with upgrade support (as returned by
/// `Connection::with_upgrades`) be watched by [`GracefulShutdown`].
#[cfg(feature = "http1")]
impl<I, B, S> GracefulConnection for hyper::server::conn::http1::UpgradeableConnection<I, S>
where
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    // `Send` is needed because hyper only implements `Future` for
    // `UpgradeableConnection` when the IO type is `Send`.
    I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        hyper::server::conn::http1::UpgradeableConnection::graceful_shutdown(self);
    }
}
139
/// Lets an HTTP/2 connection be watched by [`GracefulShutdown`].
#[cfg(feature = "http2")]
impl<I, B, S, E> GracefulConnection for hyper::server::conn::http2::Connection<I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Path call (not method syntax) so hyper's inherent method is selected.
        use hyper::server::conn::http2::Connection as Http2Connection;
        Http2Connection::graceful_shutdown(self);
    }
}
156
/// Lets an auto-detecting (HTTP/1 or HTTP/2) connection be watched by
/// [`GracefulShutdown`].
#[cfg(feature = "server-auto")]
impl<I, B, S, E> GracefulConnection for crate::server::conn::auto::Connection<'_, I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    // The auto connection resolves with a boxed, thread-safe error.
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Path call (not method syntax) so the connection's inherent method is selected.
        use crate::server::conn::auto::Connection as AutoConnection;
        AutoConnection::graceful_shutdown(self);
    }
}
174
/// Lets an auto-detecting connection with upgrade support be watched by
/// [`GracefulShutdown`].
#[cfg(feature = "server-auto")]
impl<I, B, S, E> GracefulConnection
    for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    // The auto connection resolves with a boxed, thread-safe error.
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        // Path call (not method syntax) so the connection's inherent method is selected.
        use crate::server::conn::auto::UpgradeableConnection as AutoUpgradeableConnection;
        AutoUpgradeableConnection::graceful_shutdown(self);
    }
}
193
// Sealing support for `GracefulConnection`. `Sealed` is `pub`, but it lives in a
// private module, so it cannot be named (and therefore not implemented) outside
// the enclosing module. This keeps the set of `GracefulConnection` types closed.
mod private {
    /// Supertrait that seals `GracefulConnection`.
    pub trait Sealed {}

    // Shorthand for the boxed, thread-safe error type used in the bounds below.
    // Gated like the impls that use it, so it is never an unused item.
    #[cfg(any(feature = "http1", feature = "http2", feature = "server-auto"))]
    type BoxError = Box<dyn std::error::Error + Send + Sync>;

    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::Connection<I, S>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<BoxError>,
        B: hyper::body::Body + 'static,
        B::Error: Into<BoxError>,
    {
    }

    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::UpgradeableConnection<I, S>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<BoxError>,
        B: hyper::body::Body + 'static,
        B::Error: Into<BoxError>,
    {
    }

    #[cfg(feature = "http2")]
    impl<I, B, S, E> Sealed for hyper::server::conn::http2::Connection<I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<BoxError>,
        B: hyper::body::Body + 'static,
        B::Error: Into<BoxError>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<I, B, S, E> Sealed for crate::server::conn::auto::Connection<'_, I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<BoxError>,
        S::Future: 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<BoxError>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<I, B, S, E> Sealed for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<BoxError>,
        S::Future: 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<BoxError>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }
}
263
#[cfg(test)]
mod test {
    use super::*;
    use pin_project_lite::pin_project;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    pin_project! {
        // Test double for a real connection: it completes when `future`
        // completes and counts how often `graceful_shutdown` was requested.
        #[derive(Debug)]
        struct DummyConnection<F> {
            #[pin]
            future: F,
            shutdown_counter: Arc<AtomicUsize>,
        }
    }

    impl<F> private::Sealed for DummyConnection<F> {}

    impl<F: Future> GracefulConnection for DummyConnection<F> {
        type Error = ();

        // Records the shutdown request instead of acting on it.
        fn graceful_shutdown(self: Pin<&mut Self>) {
            self.shutdown_counter.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl<F: Future> Future for DummyConnection<F> {
        type Output = Result<(), ()>;

        // Resolves `Ok(())` as soon as the inner future resolves, discarding its output.
        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
            match self.project().future.poll(cx) {
                Poll::Ready(_) => Poll::Ready(Ok(())),
                Poll::Pending => Poll::Pending,
            }
        }
    }

    // Three connections that each sleep, then wait for a broadcast message.
    // Once the message is sent, shutdown must finish well within 100ms and
    // every connection must have been asked to shut down.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));
        let (dummy_tx, _) = tokio::sync::broadcast::channel(1);

        for i in 1..=3 {
            let mut dummy_rx = dummy_tx.subscribe();
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                tokio::time::sleep(std::time::Duration::from_millis(i * 10)).await;
                // The message is sent before the sleeps finish; it is buffered
                // per subscriber, so `recv` still observes it afterwards.
                let _ = dummy_rx.recv().await;
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        // Spawned tasks have not run yet, so no shutdown has been requested.
        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);
        let _ = dummy_tx.send(());

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            }
        }
    }

    // Connections finish on their own after 50, 100 and 150ms; shutdown must
    // wait for all of them and still complete before the 200ms deadline.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_delayed_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                tokio::time::sleep(std::time::Duration::from_millis(i * 50)).await;
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            }
        }
    }

    // Several watched connections share one spawned task (1 + 2 + 3 = 6 in
    // total); each of them must be signalled and awaited.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_multi_per_watcher_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let mut futures = Vec::new();
            for u in 1..=i {
                let future = tokio::time::sleep(std::time::Duration::from_millis(u * 50));
                let dummy_conn = DummyConnection {
                    future,
                    shutdown_counter: shutdown_counter.clone(),
                };
                let conn = graceful.watch(dummy_conn);
                futures.push(conn);
            }
            tokio::spawn(async move {
                futures_util::future::join_all(futures).await;
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 6);
            }
        }
    }

    // Connection 1 never finishes, so shutdown must never complete. All three
    // connections must still have been asked to shut down. Connections 2 and 3
    // are ready immediately; they are only counted because (presumably, with
    // `#[tokio::test]`'s default current-thread runtime) the spawned tasks do
    // not run until this task yields inside `select!`, after the shutdown
    // signal has already been sent.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_timeout() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            let shutdown_counter = shutdown_counter.clone();

            let future = async move {
                if i == 1 {
                    std::future::pending::<()>().await
                } else {
                    std::future::ready(()).await
                }
            };
            let dummy_conn = DummyConnection {
                future,
                shutdown_counter,
            };
            let conn = graceful.watch(dummy_conn);
            tokio::spawn(async move {
                conn.await.unwrap();
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            },
            _ = graceful.shutdown() => {
                panic!("shutdown should not be completed: as not all our conns finish")
            }
        }
    }
}