axum_handle_error_extract/
lib.rs

1//! Error handling layer for axum that supports extractors and async functions.
2//!
3//! This crate provides [`HandleErrorLayer`] which works similarly to
4//! [`axum::error_handling::HandleErrorLayer`] except that it supports
5//! extractors and async functions:
6//!
7//! ```rust
8//! use axum::{
9//!     Router,
10//!     BoxError,
11//!     response::IntoResponse,
12//!     http::{StatusCode, Method, Uri},
13//!     routing::get,
14//! };
15//! use tower::{ServiceBuilder, timeout::error::Elapsed};
16//! use std::time::Duration;
17//! use axum_handle_error_extract::HandleErrorLayer;
18//!
19//! let app = Router::new()
20//!     .route("/", get(|| async {}))
21//!     .layer(
22//!         ServiceBuilder::new()
//!             // timeouts produce errors, so we handle those with `handle_error`
24//!             .layer(HandleErrorLayer::new(handle_error))
25//!             .timeout(Duration::from_secs(10))
26//!     );
27//!
//! // our handler can take 0 to 16 extractors and the final argument must
29//! // always be the error produced by the middleware
30//! async fn handle_error(
31//!     method: Method,
32//!     uri: Uri,
33//!     error: BoxError,
34//! ) -> impl IntoResponse {
35//!     if error.is::<Elapsed>() {
36//!         (
37//!             StatusCode::REQUEST_TIMEOUT,
38//!             format!("{} {} took too long", method, uri),
39//!         )
40//!     } else {
41//!         (
42//!             StatusCode::INTERNAL_SERVER_ERROR,
43//!             format!("{} {} failed: {}", method, uri, error),
44//!         )
45//!     }
46//! }
47//! # async {
48//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
49//! # };
50//! ```
51//!
52//! Not running any extractors is also supported:
53//!
54//! ```rust
55//! use axum::{
56//!     Router,
57//!     BoxError,
58//!     response::IntoResponse,
59//!     http::StatusCode,
60//!     routing::get,
61//! };
62//! use tower::{ServiceBuilder, timeout::error::Elapsed};
63//! use std::time::Duration;
64//! use axum_handle_error_extract::HandleErrorLayer;
65//!
66//! let app = Router::new()
67//!     .route("/", get(|| async {}))
68//!     .layer(
69//!         ServiceBuilder::new()
70//!             .layer(HandleErrorLayer::new(handle_error))
71//!             .timeout(Duration::from_secs(10))
72//!     );
73//!
74//! // this function just takes the error
75//! async fn handle_error(error: BoxError) -> impl IntoResponse {
76//!     if error.is::<Elapsed>() {
77//!         (
78//!             StatusCode::REQUEST_TIMEOUT,
79//!             "Request timeout".to_string(),
80//!         )
81//!     } else {
82//!         (
83//!             StatusCode::INTERNAL_SERVER_ERROR,
84//!             format!("Unhandled internal error: {}", error),
85//!         )
86//!     }
87//! }
88//! # async {
89//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
90//! # };
91//! ```
92//!
93//! See [`axum::error_handling`] for more details on axum's error handling model and
94//! [`axum::extract`] for more details on extractors.
95//!
96//! # The future
97//!
98//! In axum 0.4 this will replace the current [`axum::error_handling::HandleErrorLayer`].
99
// Crate-wide lint policy: everything listed here is reported as a warning.
#![warn(
    clippy::all,
    clippy::dbg_macro,
    clippy::todo,
    clippy::empty_enum,
    clippy::enum_glob_use,
    clippy::mem_forget,
    clippy::unused_self,
    clippy::filter_map_next,
    clippy::needless_continue,
    clippy::needless_borrow,
    clippy::match_wildcard_for_single_variants,
    clippy::if_let_mutex,
    clippy::mismatched_target_os,
    clippy::await_holding_lock,
    clippy::match_on_vec_items,
    clippy::imprecise_flops,
    clippy::suboptimal_flops,
    clippy::lossy_float_literal,
    clippy::rest_pat_in_fully_bound_structs,
    clippy::fn_params_excessive_bools,
    clippy::exit,
    clippy::inefficient_to_string,
    clippy::linkedlist,
    clippy::macro_use_imports,
    clippy::option_option,
    clippy::verbose_file_reads,
    clippy::unnested_or_patterns,
    rust_2018_idioms,
    future_incompatible,
    nonstandard_style,
    missing_debug_implementations,
    missing_docs
)]
// Hard errors: `pub` items not reachable from the crate root, and private
// types leaking through public interfaces.
// NOTE(review): recent rustc releases report `private_in_public` as a removed
// lint (superseded by `private_interfaces` / `private_bounds`) — confirm
// against the toolchain this crate supports.
#![deny(unreachable_pub, private_in_public)]
// `elided_lifetimes_in_paths` belongs to the `rust_2018_idioms` group warned
// above and is opted back out of here.
// NOTE(review): `clippy::type_complexity` is presumably silenced for the boxed
// future type stored in `ResponseFuture` — confirm.
#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
// No `unsafe` code; `forbid` cannot be overridden by a nested `allow`.
#![forbid(unsafe_code)]
// `doc_cfg` is a nightly-only feature, enabled only when built with
// `--cfg docsrs` (presumably the docs.rs build).
#![cfg_attr(docsrs, feature(doc_cfg))]
// NOTE(review): no float comparisons are visible in this file; this test-only
// allowance may be vestigial.
#![cfg_attr(test, allow(clippy::float_cmp))]
139
140use axum::{
141    body::{box_body, BoxBody, Bytes, Full, HttpBody},
142    extract::{FromRequest, RequestParts},
143    http::{Request, Response, StatusCode},
144    response::IntoResponse,
145    BoxError,
146};
147use pin_project_lite::pin_project;
148use std::{
149    convert::Infallible,
150    fmt,
151    future::Future,
152    marker::PhantomData,
153    pin::Pin,
154    task::{Context, Poll},
155};
156use tower::ServiceExt;
157use tower_layer::Layer;
158use tower_service::Service;
159
/// A [`Layer`] that wraps services in [`HandleError`], the [`Service`] adapter
/// that turns errors into responses.
///
/// `F` is the error handler and `T` is the tuple of extractors it runs before
/// receiving the error (`()` when it takes only the error).
///
/// See [module docs](self) for more details on axum's error handling model.
pub struct HandleErrorLayer<F, T> {
    f: F,
    // Records the extractor tuple type without storing a `T`; using
    // `fn() -> T` keeps the auto traits (`Send`, `Sync`) independent of `T`.
    _marker: PhantomData<fn() -> T>,
}

impl<F, T> HandleErrorLayer<F, T> {
    /// Create a new `HandleErrorLayer` that handles errors with `f`.
    pub fn new(f: F) -> Self {
        HandleErrorLayer {
            f,
            _marker: PhantomData,
        }
    }
}

// A manual impl (rather than `#[derive(Clone)]`) avoids requiring `T: Clone`.
impl<F: Clone, T> Clone for HandleErrorLayer<F, T> {
    fn clone(&self) -> Self {
        Self::new(self.f.clone())
    }
}

// `F` is not required to implement `Debug`, so the handler is shown by its
// type name instead of its value.
impl<F, T> fmt::Debug for HandleErrorLayer<F, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let handler = std::any::type_name::<F>();
        f.debug_struct("HandleErrorLayer")
            .field("f", &format_args!("{}", handler))
            .finish()
    }
}
198
199impl<S, F, T> Layer<S> for HandleErrorLayer<F, T>
200where
201    F: Clone,
202{
203    type Service = HandleError<S, F, T>;
204
205    fn layer(&self, inner: S) -> Self::Service {
206        HandleError::new(inner, self.f.clone())
207    }
208}
209
/// A [`Service`] adapter that handles errors by converting them into responses.
///
/// `S` is the wrapped service, `F` the error handler and `T` the tuple of
/// extractors the handler runs before receiving the error (`()` for none).
///
/// See [module docs](self) for more details on axum's error handling model.
pub struct HandleError<S, F, T> {
    inner: S,
    f: F,
    // Records the extractor tuple type only; no `T` value is ever stored.
    _marker: PhantomData<fn() -> T>,
}

impl<S, F, T> HandleError<S, F, T> {
    /// Create a new `HandleError` wrapping `inner` and handling its errors
    /// with `f`.
    pub fn new(inner: S, f: F) -> Self {
        HandleError {
            inner,
            f,
            _marker: PhantomData,
        }
    }
}

// A manual impl (rather than `#[derive(Clone)]`) avoids requiring `T: Clone`.
impl<S: Clone, F: Clone, T> Clone for HandleError<S, F, T> {
    fn clone(&self) -> Self {
        Self::new(self.inner.clone(), self.f.clone())
    }
}

// `F` is not required to implement `Debug`, so the handler is shown by its
// type name instead of its value.
impl<S: fmt::Debug, F, T> fmt::Debug for HandleError<S, F, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let handler = std::any::type_name::<F>();
        f.debug_struct("HandleError")
            .field("inner", &self.inner)
            .field("f", &format_args!("{}", handler))
            .finish()
    }
}
255
256impl<S, F, ReqBody, ResBody, Fut, Res> Service<Request<ReqBody>> for HandleError<S, F, ()>
257where
258    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
259    S::Error: Send,
260    S::Future: Send,
261    F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
262    Fut: Future<Output = Res> + Send,
263    Res: IntoResponse,
264    ReqBody: Send + 'static,
265    ResBody: HttpBody<Data = Bytes> + Send + 'static,
266    ResBody::Error: Into<BoxError>,
267{
268    type Response = Response<BoxBody>;
269    type Error = Infallible;
270    type Future = ResponseFuture;
271
272    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
273        Poll::Ready(Ok(()))
274    }
275
276    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
277        let f = self.f.clone();
278
279        let clone = self.inner.clone();
280        let inner = std::mem::replace(&mut self.inner, clone);
281
282        let future = Box::pin(async move {
283            match inner.oneshot(req).await {
284                Ok(res) => Ok(res.map(box_body)),
285                Err(err) => Ok(f(err).await.into_response().map(box_body)),
286            }
287        });
288
289        ResponseFuture { future }
290    }
291}
292
// Implements `Service` for `HandleError<S, F, (T1, ..., Tn)>`, i.e. for error
// handlers that take `n` extractors followed by the error.
//
// The generated `call`:
//
// 1. runs each extractor in order against the request. The first rejection
//    becomes the response, and the inner service is never called;
// 2. turns the `RequestParts` back into a `Request`, answering with
//    `500 Internal Server Error` (body: the error's `Display` text) if that
//    fails. NOTE(review): presumably this happens when an extractor has taken
//    the body or other parts — see `RequestParts::try_into_request`;
// 3. drives the inner service with `oneshot` and, on error, passes the
//    extracted values followed by the error to the handler.
macro_rules! impl_service {
    ( $($ty:ident),* $(,)? ) => {
        impl<S, F, ReqBody, ResBody, Res, Fut, $($ty,)*> Service<Request<ReqBody>>
            for HandleError<S, F, ($($ty,)*)>
        where
            S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
            S::Error: Send,
            S::Future: Send,
            F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static,
            Fut: Future<Output = Res> + Send,
            Res: IntoResponse,
            $( $ty: FromRequest<ReqBody> + Send,)*
            ReqBody: Send + 'static,
            ResBody: HttpBody<Data = Bytes> + Send + 'static,
            ResBody::Error: Into<BoxError>,
        {
            type Response = Response<BoxBody>;
            type Error = Infallible;

            type Future = ResponseFuture;

            // Always ready: the inner service's readiness is driven by
            // `oneshot` inside the returned future.
            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            // Extracted values are bound to variables named after their type
            // parameters (`T1`, `T2`, ...), hence `non_snake_case`.
            #[allow(non_snake_case)]
            fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
                let f = self.f.clone();

                // Move the current inner service into the future and leave a
                // clone behind for subsequent calls.
                let clone = self.inner.clone();
                let inner = std::mem::replace(&mut self.inner, clone);

                let future = Box::pin(async move {
                    let mut req = RequestParts::new(req);

                    $(
                        let $ty = match $ty::from_request(&mut req).await {
                            Ok(value) => value,
                            Err(rejection) => return Ok(rejection.into_response().map(box_body)),
                        };
                    )*

                    let req = match req.try_into_request() {
                        Ok(req) => req,
                        Err(err) => {
                            // Built directly instead of via `Response::builder()`
                            // so this path has no `unwrap`; the resulting
                            // response is the same.
                            let mut res = Response::new(box_body(Full::from(err.to_string())));
                            *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
                            return Ok(res);
                        }
                    };

                    match inner.oneshot(req).await {
                        Ok(res) => Ok(res.map(box_body)),
                        Err(err) => Ok(f($($ty),*, err).await.into_response().map(box_body)),
                    }
                });

                ResponseFuture { future }
            }
        }
    }
}
357
// One `Service` impl per handler arity: handlers taking 1 through 16
// extractors before the error. The zero-extractor case
// (`HandleError<S, F, ()>`) is implemented by hand, because the macro's
// `FnOnce($($ty),*, S::Error)` bound would expand to invalid syntax for an
// empty list.
impl_service!(T1);
impl_service!(T1, T2);
impl_service!(T1, T2, T3);
impl_service!(T1, T2, T3, T4);
impl_service!(T1, T2, T3, T4, T5);
impl_service!(T1, T2, T3, T4, T5, T6);
impl_service!(T1, T2, T3, T4, T5, T6, T7);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
374
375pin_project! {
376    /// Response future for [`HandleError`].
377    pub struct ResponseFuture {
378        #[pin]
379        future: Pin<Box<dyn Future<Output = Result<Response<BoxBody>, Infallible>> + Send + 'static>>,
380    }
381}
382
383impl Future for ResponseFuture {
384    type Output = Result<Response<BoxBody>, Infallible>;
385
386    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
387        self.project().future.poll(cx)
388    }
389}
390
391/// Extension trait to [`Service`] for handling errors by mapping them to
392/// responses.
393///
394/// See [module docs](self) for more details on axum's error handling model.
395pub trait HandleErrorExt<B>: Service<Request<B>> + Sized {
396    /// Apply a [`HandleError`] middleware.
397    fn handle_error<F>(self, f: F) -> HandleError<Self, F, B> {
398        HandleError::new(self, f)
399    }
400}
401
402impl<B, S> HandleErrorExt<B> for S where S: Service<Request<B>> {}