tower_http/
catch_panic.rs

1//! Convert panics into responses.
2//!
3//! Note that using panics for error handling is _not_ recommended. Prefer instead to use `Result`
4//! whenever possible.
5//!
6//! # Example
7//!
8//! ```rust
9//! use http::{Request, Response, header::HeaderName};
10//! use std::convert::Infallible;
11//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
12//! use tower_http::catch_panic::CatchPanicLayer;
13//! use http_body_util::Full;
14//! use bytes::Bytes;
15//!
16//! # #[tokio::main]
17//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
18//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
19//!     panic!("something went wrong...")
20//! }
21//!
22//! let mut svc = ServiceBuilder::new()
23//!     // Catch panics and convert them into responses.
24//!     .layer(CatchPanicLayer::new())
25//!     .service_fn(handle);
26//!
27//! // Call the service.
28//! let request = Request::new(Full::default());
29//!
30//! let response = svc.ready().await?.call(request).await?;
31//!
32//! assert_eq!(response.status(), 500);
33//! #
34//! # Ok(())
35//! # }
36//! ```
37//!
38//! Using a custom panic handler:
39//!
40//! ```rust
41//! use http::{Request, StatusCode, Response, header::{self, HeaderName}};
42//! use std::{any::Any, convert::Infallible};
43//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
44//! use tower_http::catch_panic::CatchPanicLayer;
45//! use bytes::Bytes;
46//! use http_body_util::Full;
47//!
48//! # #[tokio::main]
49//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
50//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
51//!     panic!("something went wrong...")
52//! }
53//!
54//! fn handle_panic(err: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
55//!     let details = if let Some(s) = err.downcast_ref::<String>() {
56//!         s.clone()
57//!     } else if let Some(s) = err.downcast_ref::<&str>() {
58//!         s.to_string()
59//!     } else {
60//!         "Unknown panic message".to_string()
61//!     };
62//!
63//!     let body = serde_json::json!({
64//!         "error": {
65//!             "kind": "panic",
66//!             "details": details,
67//!         }
68//!     });
69//!     let body = serde_json::to_string(&body).unwrap();
70//!
71//!     Response::builder()
72//!         .status(StatusCode::INTERNAL_SERVER_ERROR)
73//!         .header(header::CONTENT_TYPE, "application/json")
74//!         .body(Full::from(body))
75//!         .unwrap()
76//! }
77//!
78//! let svc = ServiceBuilder::new()
79//!     // Use `handle_panic` to create the response.
80//!     .layer(CatchPanicLayer::custom(handle_panic))
81//!     .service_fn(handle);
82//! #
83//! # Ok(())
84//! # }
85//! ```
86
87use bytes::Bytes;
88use futures_util::future::{CatchUnwind, FutureExt};
89use http::{HeaderValue, Request, Response, StatusCode};
90use http_body::Body;
91use http_body_util::BodyExt;
92use pin_project_lite::pin_project;
93use std::{
94    any::Any,
95    future::Future,
96    panic::AssertUnwindSafe,
97    pin::Pin,
98    task::{ready, Context, Poll},
99};
100use tower_layer::Layer;
101use tower_service::Service;
102
103use crate::{
104    body::{Full, UnsyncBoxBody},
105    BoxError,
106};
107
/// Layer that applies the [`CatchPanic`] middleware, which catches panics and converts them into
/// responses.
///
/// With the default handler ([`DefaultResponseForPanic`]) the response is a
/// `500 Internal Server Error`. A custom [`ResponseForPanic`] can be supplied with
/// [`CatchPanicLayer::custom`].
///
/// See the [module docs](self) for an example.
#[derive(Debug, Clone, Copy, Default)]
pub struct CatchPanicLayer<T> {
    panic_handler: T,
}
116
117impl CatchPanicLayer<DefaultResponseForPanic> {
118    /// Create a new `CatchPanicLayer` with the default panic handler.
119    pub fn new() -> Self {
120        CatchPanicLayer {
121            panic_handler: DefaultResponseForPanic,
122        }
123    }
124}
125
126impl<T> CatchPanicLayer<T> {
127    /// Create a new `CatchPanicLayer` with a custom panic handler.
128    pub fn custom(panic_handler: T) -> Self
129    where
130        T: ResponseForPanic,
131    {
132        Self { panic_handler }
133    }
134}
135
136impl<T, S> Layer<S> for CatchPanicLayer<T>
137where
138    T: Clone,
139{
140    type Service = CatchPanic<S, T>;
141
142    fn layer(&self, inner: S) -> Self::Service {
143        CatchPanic {
144            inner,
145            panic_handler: self.panic_handler.clone(),
146        }
147    }
148}
149
/// Middleware that catches panics and converts them into responses.
///
/// Panics are caught both while calling the inner service and while polling the future that
/// `call` returns. The response is built by the configured [`ResponseForPanic`] handler. With
/// the default handler ([`DefaultResponseForPanic`]) this is a `500 Internal Server Error`.
///
/// See the [module docs](self) for an example.
#[derive(Debug, Clone, Copy)]
pub struct CatchPanic<S, T> {
    inner: S,
    panic_handler: T,
}
158
159impl<S> CatchPanic<S, DefaultResponseForPanic> {
160    /// Create a new `CatchPanic` with the default panic handler.
161    pub fn new(inner: S) -> Self {
162        Self {
163            inner,
164            panic_handler: DefaultResponseForPanic,
165        }
166    }
167}
168
169impl<S, T> CatchPanic<S, T> {
170    define_inner_service_accessors!();
171
172    /// Create a new `CatchPanic` with a custom panic handler.
173    pub fn custom(inner: S, panic_handler: T) -> Self
174    where
175        T: ResponseForPanic,
176    {
177        Self {
178            inner,
179            panic_handler,
180        }
181    }
182}
183
184impl<S, T, ReqBody, ResBody> Service<Request<ReqBody>> for CatchPanic<S, T>
185where
186    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
187    ResBody: Body<Data = Bytes> + Send + 'static,
188    ResBody::Error: Into<BoxError>,
189    T: ResponseForPanic + Clone,
190    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
191    <T::ResponseBody as Body>::Error: Into<BoxError>,
192{
193    type Response = Response<UnsyncBoxBody<Bytes, BoxError>>;
194    type Error = S::Error;
195    type Future = ResponseFuture<S::Future, T>;
196
197    #[inline]
198    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
199        self.inner.poll_ready(cx)
200    }
201
202    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
203        match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req))) {
204            Ok(future) => ResponseFuture {
205                kind: Kind::Future {
206                    future: AssertUnwindSafe(future).catch_unwind(),
207                    panic_handler: Some(self.panic_handler.clone()),
208                },
209            },
210            Err(panic_err) => ResponseFuture {
211                kind: Kind::Panicked {
212                    panic_err: Some(panic_err),
213                    panic_handler: Some(self.panic_handler.clone()),
214                },
215            },
216        }
217    }
218}
219
220pin_project! {
221    /// Response future for [`CatchPanic`].
222    pub struct ResponseFuture<F, T> {
223        #[pin]
224        kind: Kind<F, T>,
225    }
226}
227
228pin_project! {
229    #[project = KindProj]
230    enum Kind<F, T> {
231        Panicked {
232            panic_err: Option<Box<dyn Any + Send + 'static>>,
233            panic_handler: Option<T>,
234        },
235        Future {
236            #[pin]
237            future: CatchUnwind<AssertUnwindSafe<F>>,
238            panic_handler: Option<T>,
239        }
240    }
241}
242
243impl<F, ResBody, E, T> Future for ResponseFuture<F, T>
244where
245    F: Future<Output = Result<Response<ResBody>, E>>,
246    ResBody: Body<Data = Bytes> + Send + 'static,
247    ResBody::Error: Into<BoxError>,
248    T: ResponseForPanic,
249    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
250    <T::ResponseBody as Body>::Error: Into<BoxError>,
251{
252    type Output = Result<Response<UnsyncBoxBody<Bytes, BoxError>>, E>;
253
254    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
255        match self.project().kind.project() {
256            KindProj::Panicked {
257                panic_err,
258                panic_handler,
259            } => {
260                let panic_handler = panic_handler
261                    .take()
262                    .expect("future polled after completion");
263                let panic_err = panic_err.take().expect("future polled after completion");
264                Poll::Ready(Ok(response_for_panic(panic_handler, panic_err)))
265            }
266            KindProj::Future {
267                future,
268                panic_handler,
269            } => match ready!(future.poll(cx)) {
270                Ok(Ok(res)) => {
271                    Poll::Ready(Ok(res.map(|body| {
272                        UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync())
273                    })))
274                }
275                Ok(Err(svc_err)) => Poll::Ready(Err(svc_err)),
276                Err(panic_err) => Poll::Ready(Ok(response_for_panic(
277                    panic_handler
278                        .take()
279                        .expect("future polled after completion"),
280                    panic_err,
281                ))),
282            },
283        }
284    }
285}
286
287fn response_for_panic<T>(
288    mut panic_handler: T,
289    err: Box<dyn Any + Send + 'static>,
290) -> Response<UnsyncBoxBody<Bytes, BoxError>>
291where
292    T: ResponseForPanic,
293    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
294    <T::ResponseBody as Body>::Error: Into<BoxError>,
295{
296    panic_handler
297        .response_for_panic(err)
298        .map(|body| UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync()))
299}
300
301/// Trait for creating responses from panics.
302pub trait ResponseForPanic: Clone {
303    /// The body type used for responses to panics.
304    type ResponseBody;
305
306    /// Create a response from the panic error.
307    fn response_for_panic(
308        &mut self,
309        err: Box<dyn Any + Send + 'static>,
310    ) -> Response<Self::ResponseBody>;
311}
312
313impl<F, B> ResponseForPanic for F
314where
315    F: FnMut(Box<dyn Any + Send + 'static>) -> Response<B> + Clone,
316{
317    type ResponseBody = B;
318
319    fn response_for_panic(
320        &mut self,
321        err: Box<dyn Any + Send + 'static>,
322    ) -> Response<Self::ResponseBody> {
323        self(err)
324    }
325}
326
/// The default `ResponseForPanic` used by `CatchPanic`.
///
/// It logs the panic with `tracing::error!`, including the panic message when the payload is a
/// `String` or `&str`. It then returns a `500 Internal Server Error` response whose body is the
/// plain text `Service panicked`, with `content-type: text/plain; charset=utf-8`.
#[derive(Debug, Default, Clone, Copy)]
#[non_exhaustive]
pub struct DefaultResponseForPanic;
334
335impl ResponseForPanic for DefaultResponseForPanic {
336    type ResponseBody = Full;
337
338    fn response_for_panic(
339        &mut self,
340        err: Box<dyn Any + Send + 'static>,
341    ) -> Response<Self::ResponseBody> {
342        if let Some(s) = err.downcast_ref::<String>() {
343            tracing::error!("Service panicked: {}", s);
344        } else if let Some(s) = err.downcast_ref::<&str>() {
345            tracing::error!("Service panicked: {}", s);
346        } else {
347            tracing::error!(
348                "Service panicked but `CatchPanic` was unable to downcast the panic info"
349            );
350        };
351
352        let mut res = Response::new(Full::new(http_body_util::Full::from("Service panicked")));
353        *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
354
355        #[allow(clippy::declare_interior_mutable_const)]
356        const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
357        res.headers_mut()
358            .insert(http::header::CONTENT_TYPE, TEXT_PLAIN);
359
360        res
361    }
362}
363
#[cfg(test)]
mod tests {
    // The test services `panic!` before producing a value, so the code that follows (needed
    // only to fix the return type) is unreachable.
    #![allow(unreachable_code)]

    use super::*;
    use crate::test_helpers::Body;
    use http::Response;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// Assert that `res` is the response built by [`DefaultResponseForPanic`].
    async fn assert_default_panic_response(res: Response<UnsyncBoxBody<Bytes, BoxError>>) {
        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            res.headers()[http::header::CONTENT_TYPE],
            "text/plain; charset=utf-8"
        );
        let body = crate::test_helpers::to_bytes(res).await.unwrap();
        assert_eq!(&body[..], b"Service panicked");
    }

    #[tokio::test]
    async fn panic_before_returning_future() {
        let svc = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| {
                panic!("service panic");
                async { Ok::<_, Infallible>(Response::new(Body::empty())) }
            });

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_default_panic_response(res).await;
    }

    #[tokio::test]
    async fn panic_in_future() {
        let svc = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| async {
                panic!("future panic");
                Ok::<_, Infallible>(Response::new(Body::empty()))
            });

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_default_panic_response(res).await;
    }

    #[tokio::test]
    async fn no_panic_passes_response_through() {
        let svc = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| async {
                Ok::<_, Infallible>(Response::new(Body::empty()))
            });

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
        let body = crate::test_helpers::to_bytes(res).await.unwrap();
        assert!(body.is_empty());
    }

    #[tokio::test]
    async fn custom_panic_handler_builds_response() {
        let svc = ServiceBuilder::new()
            .layer(CatchPanicLayer::custom(
                |err: Box<dyn Any + Send + 'static>| {
                    let message = err
                        .downcast_ref::<String>()
                        .cloned()
                        .or_else(|| err.downcast_ref::<&str>().map(|s| s.to_string()))
                        .unwrap_or_default();
                    let mut res = Response::new(http_body_util::Full::<Bytes>::from(message));
                    *res.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
                    res
                },
            ))
            .service_fn(|_: Request<Body>| async {
                panic!("custom panic");
                Ok::<_, Infallible>(Response::new(Body::empty()))
            });

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
        let body = crate::test_helpers::to_bytes(res).await.unwrap();
        assert_eq!(&body[..], b"custom panic");
    }
}