1use bytes::Bytes;
88use futures_util::future::{CatchUnwind, FutureExt};
89use http::{HeaderValue, Request, Response, StatusCode};
90use http_body::Body;
91use http_body_util::BodyExt;
92use pin_project_lite::pin_project;
93use std::{
94 any::Any,
95 future::Future,
96 panic::AssertUnwindSafe,
97 pin::Pin,
98 task::{ready, Context, Poll},
99};
100use tower_layer::Layer;
101use tower_service::Service;
102
103use crate::{
104 body::{Full, UnsyncBoxBody},
105 BoxError,
106};
107
/// [`Layer`] that wraps services in [`CatchPanic`], so panics raised by the
/// wrapped service are turned into HTTP responses instead of unwinding further.
///
/// `T` is the panic handler; it is cloned into every service this layer builds.
#[derive(Clone, Copy, Debug, Default)]
pub struct CatchPanicLayer<T> {
    panic_handler: T,
}
116
117impl CatchPanicLayer<DefaultResponseForPanic> {
118 pub fn new() -> Self {
120 CatchPanicLayer {
121 panic_handler: DefaultResponseForPanic,
122 }
123 }
124}
125
126impl<T> CatchPanicLayer<T> {
127 pub fn custom(panic_handler: T) -> Self
129 where
130 T: ResponseForPanic,
131 {
132 Self { panic_handler }
133 }
134}
135
136impl<T, S> Layer<S> for CatchPanicLayer<T>
137where
138 T: Clone,
139{
140 type Service = CatchPanic<S, T>;
141
142 fn layer(&self, inner: S) -> Self::Service {
143 CatchPanic {
144 inner,
145 panic_handler: self.panic_handler.clone(),
146 }
147 }
148}
149
/// Middleware that catches panics from the inner service. It covers panics
/// raised while calling it and panics raised while polling the future it returns.
/// Each panic is converted into a response produced by the handler `T`.
#[derive(Clone, Copy, Debug)]
pub struct CatchPanic<S, T> {
    inner: S,
    panic_handler: T,
}
158
159impl<S> CatchPanic<S, DefaultResponseForPanic> {
160 pub fn new(inner: S) -> Self {
162 Self {
163 inner,
164 panic_handler: DefaultResponseForPanic,
165 }
166 }
167}
168
169impl<S, T> CatchPanic<S, T> {
170 define_inner_service_accessors!();
171
172 pub fn custom(inner: S, panic_handler: T) -> Self
174 where
175 T: ResponseForPanic,
176 {
177 Self {
178 inner,
179 panic_handler,
180 }
181 }
182}
183
184impl<S, T, ReqBody, ResBody> Service<Request<ReqBody>> for CatchPanic<S, T>
185where
186 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
187 ResBody: Body<Data = Bytes> + Send + 'static,
188 ResBody::Error: Into<BoxError>,
189 T: ResponseForPanic + Clone,
190 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
191 <T::ResponseBody as Body>::Error: Into<BoxError>,
192{
193 type Response = Response<UnsyncBoxBody<Bytes, BoxError>>;
194 type Error = S::Error;
195 type Future = ResponseFuture<S::Future, T>;
196
197 #[inline]
198 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
199 self.inner.poll_ready(cx)
200 }
201
202 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
203 match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req))) {
204 Ok(future) => ResponseFuture {
205 kind: Kind::Future {
206 future: AssertUnwindSafe(future).catch_unwind(),
207 panic_handler: Some(self.panic_handler.clone()),
208 },
209 },
210 Err(panic_err) => ResponseFuture {
211 kind: Kind::Panicked {
212 panic_err: Some(panic_err),
213 panic_handler: Some(self.panic_handler.clone()),
214 },
215 },
216 }
217 }
218}
219
pin_project! {
    /// Response future for [`CatchPanic`].
    ///
    /// Resolves to the inner service's response with its body boxed. If the inner
    /// service panicked, it resolves to the response built by the panic handler.
    pub struct ResponseFuture<F, T> {
        // Pinned because the `Future` variant holds a pinned inner future.
        #[pin]
        kind: Kind<F, T>,
    }
}
227
pin_project! {
    // Internal state of `ResponseFuture`. The handler and payload are wrapped in
    // `Option` so `poll` can move them out through the pinned projection with
    // `take()`; a second poll then hits "future polled after completion".
    #[project = KindProj]
    enum Kind<F, T> {
        // `inner.call` itself panicked; holds the payload to turn into a response.
        Panicked {
            panic_err: Option<Box<dyn Any + Send + 'static>>,
            panic_handler: Option<T>,
        },
        // `inner.call` returned a future; panics while polling it are caught here.
        Future {
            #[pin]
            future: CatchUnwind<AssertUnwindSafe<F>>,
            panic_handler: Option<T>,
        }
    }
}
242
243impl<F, ResBody, E, T> Future for ResponseFuture<F, T>
244where
245 F: Future<Output = Result<Response<ResBody>, E>>,
246 ResBody: Body<Data = Bytes> + Send + 'static,
247 ResBody::Error: Into<BoxError>,
248 T: ResponseForPanic,
249 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
250 <T::ResponseBody as Body>::Error: Into<BoxError>,
251{
252 type Output = Result<Response<UnsyncBoxBody<Bytes, BoxError>>, E>;
253
254 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
255 match self.project().kind.project() {
256 KindProj::Panicked {
257 panic_err,
258 panic_handler,
259 } => {
260 let panic_handler = panic_handler
261 .take()
262 .expect("future polled after completion");
263 let panic_err = panic_err.take().expect("future polled after completion");
264 Poll::Ready(Ok(response_for_panic(panic_handler, panic_err)))
265 }
266 KindProj::Future {
267 future,
268 panic_handler,
269 } => match ready!(future.poll(cx)) {
270 Ok(Ok(res)) => {
271 Poll::Ready(Ok(res.map(|body| {
272 UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync())
273 })))
274 }
275 Ok(Err(svc_err)) => Poll::Ready(Err(svc_err)),
276 Err(panic_err) => Poll::Ready(Ok(response_for_panic(
277 panic_handler
278 .take()
279 .expect("future polled after completion"),
280 panic_err,
281 ))),
282 },
283 }
284 }
285}
286
287fn response_for_panic<T>(
288 mut panic_handler: T,
289 err: Box<dyn Any + Send + 'static>,
290) -> Response<UnsyncBoxBody<Bytes, BoxError>>
291where
292 T: ResponseForPanic,
293 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
294 <T::ResponseBody as Body>::Error: Into<BoxError>,
295{
296 panic_handler
297 .response_for_panic(err)
298 .map(|body| UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync()))
299}
300
301pub trait ResponseForPanic: Clone {
303 type ResponseBody;
305
306 fn response_for_panic(
308 &mut self,
309 err: Box<dyn Any + Send + 'static>,
310 ) -> Response<Self::ResponseBody>;
311}
312
313impl<F, B> ResponseForPanic for F
314where
315 F: FnMut(Box<dyn Any + Send + 'static>) -> Response<B> + Clone,
316{
317 type ResponseBody = B;
318
319 fn response_for_panic(
320 &mut self,
321 err: Box<dyn Any + Send + 'static>,
322 ) -> Response<Self::ResponseBody> {
323 self(err)
324 }
325}
326
/// Panic handler used by [`CatchPanicLayer::new`] and [`CatchPanic::new`].
///
/// It logs the panic message with `tracing::error!` and responds with
/// `500 Internal Server Error` and a `text/plain` body reading `Service panicked`.
#[derive(Clone, Copy, Debug, Default)]
#[non_exhaustive]
pub struct DefaultResponseForPanic;
334
335impl ResponseForPanic for DefaultResponseForPanic {
336 type ResponseBody = Full;
337
338 fn response_for_panic(
339 &mut self,
340 err: Box<dyn Any + Send + 'static>,
341 ) -> Response<Self::ResponseBody> {
342 if let Some(s) = err.downcast_ref::<String>() {
343 tracing::error!("Service panicked: {}", s);
344 } else if let Some(s) = err.downcast_ref::<&str>() {
345 tracing::error!("Service panicked: {}", s);
346 } else {
347 tracing::error!(
348 "Service panicked but `CatchPanic` was unable to downcast the panic info"
349 );
350 };
351
352 let mut res = Response::new(Full::new(http_body_util::Full::from("Service panicked")));
353 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
354
355 #[allow(clippy::declare_interior_mutable_const)]
356 const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
357 res.headers_mut()
358 .insert(http::header::CONTENT_TYPE, TEXT_PLAIN);
359
360 res
361 }
362}
363
#[cfg(test)]
mod tests {
    // The services below panic before their `Ok(...)` value, which is
    // intentionally unreachable.
    #![allow(unreachable_code)]

    use super::*;
    use crate::test_helpers::Body;
    use http::Response;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    #[tokio::test]
    async fn panic_before_returning_future() {
        let svc = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| {
                panic!("service panic");
                async { Ok::<_, Infallible>(Response::new(Body::empty())) }
            });

        let req = Request::new(Body::empty());

        let res = svc.oneshot(req).await.unwrap();

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let body = crate::test_helpers::to_bytes(res).await.unwrap();
        assert_eq!(&body[..], b"Service panicked");
    }

    #[tokio::test]
    async fn panic_in_future() {
        let svc = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| async {
                panic!("future panic");
                Ok::<_, Infallible>(Response::new(Body::empty()))
            });

        let req = Request::new(Body::empty());

        let res = svc.oneshot(req).await.unwrap();

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let body = crate::test_helpers::to_bytes(res).await.unwrap();
        assert_eq!(&body[..], b"Service panicked");
    }

    // Exercises the "unable to downcast" branch of the default handler: a
    // non-string payload must still produce the standard 500 response.
    #[tokio::test]
    async fn panic_with_non_string_payload() {
        let svc = ServiceBuilder::new()
            .layer(CatchPanicLayer::new())
            .service_fn(|_: Request<Body>| async {
                std::panic::panic_any(42_u32);
                Ok::<_, Infallible>(Response::new(Body::empty()))
            });

        let req = Request::new(Body::empty());

        let res = svc.oneshot(req).await.unwrap();

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let body = crate::test_helpers::to_bytes(res).await.unwrap();
        assert_eq!(&body[..], b"Service panicked");
    }

    // Exercises `CatchPanicLayer::custom` with a plain function, which relies on
    // the blanket `ResponseForPanic` impl for `FnMut` handlers, and checks that
    // the handler receives the actual panic payload.
    #[tokio::test]
    async fn custom_panic_handler_receives_payload() {
        fn handle_panic(err: Box<dyn Any + Send + 'static>) -> Response<Full> {
            let details = err
                .downcast_ref::<String>()
                .cloned()
                .unwrap_or_else(|| "unknown".to_owned());
            let mut res = Response::new(Full::new(http_body_util::Full::from(details)));
            *res.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
            res
        }

        let svc = ServiceBuilder::new()
            .layer(CatchPanicLayer::custom(handle_panic))
            .service_fn(|_: Request<Body>| async {
                panic!("code {}", 7);
                Ok::<_, Infallible>(Response::new(Body::empty()))
            });

        let req = Request::new(Body::empty());

        let res = svc.oneshot(req).await.unwrap();

        assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
        let body = crate::test_helpers::to_bytes(res).await.unwrap();
        assert_eq!(&body[..], b"code 7");
    }
}