actix_web/middleware/
err_handlers.rs

1//! For middleware documentation, see [`ErrorHandlers`].
2
3use std::{
4    future::Future,
5    pin::Pin,
6    rc::Rc,
7    task::{Context, Poll},
8};
9
10use actix_service::{Service, Transform};
11use foldhash::HashMap as FoldHashMap;
12use futures_core::{future::LocalBoxFuture, ready};
13use pin_project_lite::pin_project;
14
15use crate::{
16    body::EitherBody,
17    dev::{ServiceRequest, ServiceResponse},
18    http::StatusCode,
19    Error, Result,
20};
21
22/// Return type for [`ErrorHandlers`] custom handlers.
23pub enum ErrorHandlerResponse<B> {
24    /// Immediate HTTP response.
25    Response(ServiceResponse<EitherBody<B>>),
26
27    /// A future that resolves to an HTTP response.
28    Future(LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>),
29}
30
31type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>;
32
33type DefaultHandler<B> = Option<Rc<ErrorHandler<B>>>;
34
35/// Middleware for registering custom status code based error handlers.
36///
/// Use the [`ErrorHandlers::handler()`] method to register a custom error handler for a given
/// status code. Handlers can modify existing responses or create completely new ones.
39///
40/// To register a default handler, use the [`ErrorHandlers::default_handler()`] method. This
41/// handler will be used only if a response has an error status code (400-599) that isn't covered by
42/// a more specific handler (set with the [`handler()`][ErrorHandlers::handler] method). See examples
43/// below.
44///
45/// To register a default for only client errors (400-499) or only server errors (500-599), use the
46/// [`ErrorHandlers::default_handler_client()`] and [`ErrorHandlers::default_handler_server()`]
47/// methods, respectively.
48///
/// Any response with a status code that isn't covered by a specific handler or a default handler
/// will pass through this middleware unchanged.
51///
52/// # Examples
53///
54/// Adding a header:
55///
56/// ```
57/// use actix_web::{
58///     dev::ServiceResponse,
59///     http::{header, StatusCode},
60///     middleware::{ErrorHandlerResponse, ErrorHandlers},
61///     web, App, HttpResponse, Result,
62/// };
63///
64/// fn add_error_header<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
65///     res.response_mut().headers_mut().insert(
66///         header::CONTENT_TYPE,
67///         header::HeaderValue::from_static("Error"),
68///     );
69///
70///     // body is unchanged, map to "left" slot
71///     Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
72/// }
73///
74/// let app = App::new()
75///     .wrap(ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, add_error_header))
76///     .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError)));
77/// ```
78///
79/// Modifying response body:
80///
81/// ```
82/// use actix_web::{
83///     dev::ServiceResponse,
84///     http::{header, StatusCode},
85///     middleware::{ErrorHandlerResponse, ErrorHandlers},
86///     web, App, HttpResponse, Result,
87/// };
88///
89/// fn add_error_body<B>(res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
90///     // split service response into request and response components
91///     let (req, res) = res.into_parts();
92///
93///     // set body of response to modified body
94///     let res = res.set_body("An error occurred.");
95///
96///     // modified bodies need to be boxed and placed in the "right" slot
97///     let res = ServiceResponse::new(req, res)
98///         .map_into_boxed_body()
99///         .map_into_right_body();
100///
101///     Ok(ErrorHandlerResponse::Response(res))
102/// }
103///
104/// let app = App::new()
105///     .wrap(ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, add_error_body))
106///     .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError)));
107/// ```
108///
109/// Registering default handler:
110///
111/// ```
112/// # use actix_web::{
113/// #     dev::ServiceResponse,
114/// #     http::{header, StatusCode},
115/// #     middleware::{ErrorHandlerResponse, ErrorHandlers},
116/// #     web, App, HttpResponse, Result,
117/// # };
118/// fn add_error_header<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
119///     res.response_mut().headers_mut().insert(
120///         header::CONTENT_TYPE,
121///         header::HeaderValue::from_static("Error"),
122///     );
123///
124///     // body is unchanged, map to "left" slot
125///     Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
126/// }
127///
128/// fn handle_bad_request<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
129///     res.response_mut().headers_mut().insert(
130///         header::CONTENT_TYPE,
131///         header::HeaderValue::from_static("Bad Request Error"),
132///     );
133///
134///     // body is unchanged, map to "left" slot
135///     Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
136/// }
137///
138/// // Bad Request errors will hit `handle_bad_request()`, while all other errors will hit
139/// // `add_error_header()`. The order in which the methods are called is not meaningful.
140/// let app = App::new()
141///     .wrap(
142///         ErrorHandlers::new()
143///             .default_handler(add_error_header)
144///             .handler(StatusCode::BAD_REQUEST, handle_bad_request)
145///     )
146///     .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError)));
147/// ```
148///
149/// You can set default handlers for all client (4xx) or all server (5xx) errors:
150///
151/// ```
152/// # use actix_web::{
153/// #     dev::ServiceResponse,
154/// #     http::{header, StatusCode},
155/// #     middleware::{ErrorHandlerResponse, ErrorHandlers},
156/// #     web, App, HttpResponse, Result,
157/// # };
158/// # fn add_error_header<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
159/// #     res.response_mut().headers_mut().insert(
160/// #         header::CONTENT_TYPE,
161/// #         header::HeaderValue::from_static("Error"),
162/// #     );
163/// #     Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
164/// # }
165/// # fn handle_bad_request<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
166/// #     res.response_mut().headers_mut().insert(
167/// #         header::CONTENT_TYPE,
168/// #         header::HeaderValue::from_static("Bad Request Error"),
169/// #     );
170/// #     Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
171/// # }
172/// // Bad request errors will hit `handle_bad_request()`, other client errors will hit
173/// // `add_error_header()`, and server errors will pass through unchanged
174/// let app = App::new()
175///     .wrap(
176///         ErrorHandlers::new()
177///             .default_handler_client(add_error_header) // or .default_handler_server
178///             .handler(StatusCode::BAD_REQUEST, handle_bad_request)
179///     )
180///     .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError)));
181/// ```
182pub struct ErrorHandlers<B> {
183    default_client: DefaultHandler<B>,
184    default_server: DefaultHandler<B>,
185    handlers: Handlers<B>,
186}
187
188type Handlers<B> = Rc<FoldHashMap<StatusCode, Box<ErrorHandler<B>>>>;
189
190impl<B> Default for ErrorHandlers<B> {
191    fn default() -> Self {
192        ErrorHandlers {
193            default_client: Default::default(),
194            default_server: Default::default(),
195            handlers: Default::default(),
196        }
197    }
198}
199
200impl<B> ErrorHandlers<B> {
201    /// Construct new `ErrorHandlers` instance.
202    pub fn new() -> Self {
203        ErrorHandlers::default()
204    }
205
206    /// Register error handler for specified status code.
207    pub fn handler<F>(mut self, status: StatusCode, handler: F) -> Self
208    where
209        F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
210    {
211        Rc::get_mut(&mut self.handlers)
212            .unwrap()
213            .insert(status, Box::new(handler));
214        self
215    }
216
217    /// Register a default error handler.
218    ///
219    /// Any request with a status code that hasn't been given a specific other handler (by calling
220    /// [`.handler()`][ErrorHandlers::handler]) will fall back on this.
221    ///
222    /// Note that this will overwrite any default handlers previously set by calling
223    /// [`default_handler_client()`] or [`.default_handler_server()`], but not any set by calling
224    /// [`.handler()`].
225    ///
226    /// [`default_handler_client()`]: ErrorHandlers::default_handler_client
227    /// [`.default_handler_server()`]: ErrorHandlers::default_handler_server
228    /// [`.handler()`]: ErrorHandlers::handler
229    pub fn default_handler<F>(self, handler: F) -> Self
230    where
231        F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
232    {
233        let handler = Rc::new(handler);
234        let handler2 = Rc::clone(&handler);
235        Self {
236            default_server: Some(handler2),
237            default_client: Some(handler),
238            ..self
239        }
240    }
241
242    /// Register a handler on which to fall back for client error status codes (400-499).
243    pub fn default_handler_client<F>(self, handler: F) -> Self
244    where
245        F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
246    {
247        Self {
248            default_client: Some(Rc::new(handler)),
249            ..self
250        }
251    }
252
253    /// Register a handler on which to fall back for server error status codes (500-599).
254    pub fn default_handler_server<F>(self, handler: F) -> Self
255    where
256        F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
257    {
258        Self {
259            default_server: Some(Rc::new(handler)),
260            ..self
261        }
262    }
263
264    /// Selects the most appropriate handler for the given status code.
265    ///
266    /// If the `handlers` map has an entry for that status code, that handler is returned.
267    /// Otherwise, fall back on the appropriate default handler.
268    fn get_handler<'a>(
269        status: &StatusCode,
270        default_client: Option<&'a ErrorHandler<B>>,
271        default_server: Option<&'a ErrorHandler<B>>,
272        handlers: &'a Handlers<B>,
273    ) -> Option<&'a ErrorHandler<B>> {
274        handlers
275            .get(status)
276            .map(|h| h.as_ref())
277            .or_else(|| status.is_client_error().then_some(default_client).flatten())
278            .or_else(|| status.is_server_error().then_some(default_server).flatten())
279    }
280}
281
282impl<S, B> Transform<S, ServiceRequest> for ErrorHandlers<B>
283where
284    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
285    S::Future: 'static,
286    B: 'static,
287{
288    type Response = ServiceResponse<EitherBody<B>>;
289    type Error = Error;
290    type Transform = ErrorHandlersMiddleware<S, B>;
291    type InitError = ();
292    type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>;
293
294    fn new_transform(&self, service: S) -> Self::Future {
295        let handlers = Rc::clone(&self.handlers);
296        let default_client = self.default_client.clone();
297        let default_server = self.default_server.clone();
298        Box::pin(async move {
299            Ok(ErrorHandlersMiddleware {
300                service,
301                default_client,
302                default_server,
303                handlers,
304            })
305        })
306    }
307}
308
309#[doc(hidden)]
310pub struct ErrorHandlersMiddleware<S, B> {
311    service: S,
312    default_client: DefaultHandler<B>,
313    default_server: DefaultHandler<B>,
314    handlers: Handlers<B>,
315}
316
317impl<S, B> Service<ServiceRequest> for ErrorHandlersMiddleware<S, B>
318where
319    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
320    S::Future: 'static,
321    B: 'static,
322{
323    type Response = ServiceResponse<EitherBody<B>>;
324    type Error = Error;
325    type Future = ErrorHandlersFuture<S::Future, B>;
326
327    actix_service::forward_ready!(service);
328
329    fn call(&self, req: ServiceRequest) -> Self::Future {
330        let handlers = Rc::clone(&self.handlers);
331        let default_client = self.default_client.clone();
332        let default_server = self.default_server.clone();
333        let fut = self.service.call(req);
334        ErrorHandlersFuture::ServiceFuture {
335            fut,
336            default_client,
337            default_server,
338            handlers,
339        }
340    }
341}
342
pin_project! {
    // Future returned by `ErrorHandlersMiddleware::call`.
    //
    // Plain `//` comments are used inside this macro invocation deliberately: they are not
    // tokens, so they cannot interfere with the `pin_project!` input grammar.
    #[project = ErrorHandlersProj]
    pub enum ErrorHandlersFuture<Fut, B>
    where
        Fut: Future,
    {
        // Waiting on the wrapped service. The handler handles are carried along so one can be
        // selected once the response status is known.
        ServiceFuture {
            #[pin]
            fut: Fut,
            default_client: DefaultHandler<B>,
            default_server: DefaultHandler<B>,
            handlers: Handlers<B>,
        },
        // Waiting on the future returned by an asynchronous error handler
        // (`ErrorHandlerResponse::Future`). `LocalBoxFuture` is already a pinned box, so this
        // field is not structurally pinned.
        ErrorHandlerFuture {
            fut: LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>,
        },
    }
}
361
362impl<Fut, B> Future for ErrorHandlersFuture<Fut, B>
363where
364    Fut: Future<Output = Result<ServiceResponse<B>, Error>>,
365{
366    type Output = Result<ServiceResponse<EitherBody<B>>, Error>;
367
368    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
369        match self.as_mut().project() {
370            ErrorHandlersProj::ServiceFuture {
371                fut,
372                default_client,
373                default_server,
374                handlers,
375            } => {
376                let res = ready!(fut.poll(cx))?;
377                let status = res.status();
378
379                let handler = ErrorHandlers::get_handler(
380                    &status,
381                    default_client.as_mut().map(|f| Rc::as_ref(f)),
382                    default_server.as_mut().map(|f| Rc::as_ref(f)),
383                    handlers,
384                );
385                match handler {
386                    Some(handler) => match handler(res)? {
387                        ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)),
388                        ErrorHandlerResponse::Future(fut) => {
389                            self.as_mut()
390                                .set(ErrorHandlersFuture::ErrorHandlerFuture { fut });
391
392                            self.poll(cx)
393                        }
394                    },
395                    None => Poll::Ready(Ok(res.map_into_left_body())),
396                }
397            }
398
399            ErrorHandlersProj::ErrorHandlerFuture { fut } => fut.as_mut().poll(cx),
400        }
401    }
402}
403
#[cfg(test)]
mod tests {
    use actix_service::IntoService;
    use actix_utils::future::ok;
    use bytes::Bytes;
    use futures_util::FutureExt as _;

    use super::*;
    use crate::{
        body,
        http::header::{HeaderValue, CONTENT_TYPE},
        test::{self, TestRequest},
    };

    #[actix_rt::test]
    async fn add_header_error_handler() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    #[actix_rt::test]
    async fn add_header_error_handler_async() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B: 'static>(
            mut res: ServiceResponse<B>,
        ) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(ErrorHandlerResponse::Future(
                ok(res.map_into_left_body()).boxed_local(),
            ))
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    #[actix_rt::test]
    async fn changes_body_type() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B>(res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            let (req, res) = res.into_parts();
            let res = res.set_body(Bytes::from("sorry, that's no bueno"));

            let res = ServiceResponse::new(req, res)
                .map_into_boxed_body()
                .map_into_right_body();

            Ok(ErrorHandlerResponse::Response(res))
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let res = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(test::read_body(res).await, "sorry, that's no bueno");
    }

    #[actix_rt::test]
    async fn error_thrown() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B>(_res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            Err(crate::error::ErrorInternalServerError(
                "error in error handler",
            ))
        }

        let srv = test::status_service(StatusCode::BAD_REQUEST);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::BAD_REQUEST, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let err = mw
            .call(TestRequest::default().to_srv_request())
            .await
            .unwrap_err();
        let res = err.error_response();

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body::to_bytes(res.into_body()).await.unwrap(),
            "error in error handler"
        );
    }

    #[actix_rt::test]
    async fn default_error_handler() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let make_mw = |status| async move {
            ErrorHandlers::new()
                .default_handler(error_handler)
                .new_transform(test::status_service(status).into_service())
                .await
                .unwrap()
        };
        let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await;
        let mw_client = make_mw(StatusCode::BAD_REQUEST).await;

        let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

        let resp = test::call_service(&mw_server, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    #[actix_rt::test]
    async fn default_handlers_separate_client_server() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_client<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_server<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0002"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let make_mw = |status| async move {
            ErrorHandlers::new()
                .default_handler_server(error_handler_server)
                .default_handler_client(error_handler_client)
                .new_transform(test::status_service(status).into_service())
                .await
                .unwrap()
        };
        let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await;
        let mw_client = make_mw(StatusCode::BAD_REQUEST).await;

        let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

        let resp = test::call_service(&mw_server, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002");
    }

    #[actix_rt::test]
    async fn default_handlers_specialization() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_client<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_specific<B>(
            mut res: ServiceResponse<B>,
        ) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0003"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let make_mw = |status| async move {
            ErrorHandlers::new()
                .default_handler_client(error_handler_client)
                .handler(StatusCode::UNPROCESSABLE_ENTITY, error_handler_specific)
                .new_transform(test::status_service(status).into_service())
                .await
                .unwrap()
        };
        let mw_client = make_mw(StatusCode::BAD_REQUEST).await;
        let mw_specific = make_mw(StatusCode::UNPROCESSABLE_ENTITY).await;

        let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

        let resp = test::call_service(&mw_specific, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0003");
    }

    #[actix_rt::test]
    async fn unhandled_statuses_pass_through() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_client<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let make_mw = |status| async move {
            ErrorHandlers::new()
                .default_handler_client(error_handler_client)
                .new_transform(test::status_service(status).into_service())
                .await
                .unwrap()
        };

        // Non-error statuses are never handed to a handler, even when a default is registered.
        let mw_ok = make_mw(StatusCode::OK).await;
        let resp = test::call_service(&mw_ok, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.headers().get(CONTENT_TYPE).is_none());

        // Server errors pass through when only a client-error default is registered.
        let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await;
        let resp = test::call_service(&mw_server, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(resp.headers().get(CONTENT_TYPE).is_none());
    }

    #[actix_rt::test]
    async fn later_default_handler_overrides_earlier() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_any<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_server<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0002"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        // `default_handler_server` replaces only the server half of `default_handler`.
        let make_mw = |status| async move {
            ErrorHandlers::new()
                .default_handler(error_handler_any)
                .default_handler_server(error_handler_server)
                .new_transform(test::status_service(status).into_service())
                .await
                .unwrap()
        };
        let mw_client = make_mw(StatusCode::BAD_REQUEST).await;
        let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await;

        let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

        let resp = test::call_service(&mw_server, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002");
    }
}