1#![warn(
101 clippy::all,
102 clippy::dbg_macro,
103 clippy::todo,
104 clippy::empty_enum,
105 clippy::enum_glob_use,
106 clippy::mem_forget,
107 clippy::unused_self,
108 clippy::filter_map_next,
109 clippy::needless_continue,
110 clippy::needless_borrow,
111 clippy::match_wildcard_for_single_variants,
112 clippy::if_let_mutex,
113 clippy::mismatched_target_os,
114 clippy::await_holding_lock,
115 clippy::match_on_vec_items,
116 clippy::imprecise_flops,
117 clippy::suboptimal_flops,
118 clippy::lossy_float_literal,
119 clippy::rest_pat_in_fully_bound_structs,
120 clippy::fn_params_excessive_bools,
121 clippy::exit,
122 clippy::inefficient_to_string,
123 clippy::linkedlist,
124 clippy::macro_use_imports,
125 clippy::option_option,
126 clippy::verbose_file_reads,
127 clippy::unnested_or_patterns,
128 rust_2018_idioms,
129 future_incompatible,
130 nonstandard_style,
131 missing_debug_implementations,
132 missing_docs
133)]
134#![deny(unreachable_pub, private_in_public)]
135#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
136#![forbid(unsafe_code)]
137#![cfg_attr(docsrs, feature(doc_cfg))]
138#![cfg_attr(test, allow(clippy::float_cmp))]
139
140use axum::{
141 body::{box_body, BoxBody, Bytes, Full, HttpBody},
142 extract::{FromRequest, RequestParts},
143 http::{Request, Response, StatusCode},
144 response::IntoResponse,
145 BoxError,
146};
147use pin_project_lite::pin_project;
148use std::{
149 convert::Infallible,
150 fmt,
151 future::Future,
152 marker::PhantomData,
153 pin::Pin,
154 task::{Context, Poll},
155};
156use tower::ServiceExt;
157use tower_layer::Layer;
158use tower_service::Service;
159
/// [`Layer`] that wraps a service in [`HandleError`].
///
/// `F` is the error-handling callback. `T` is the tuple of extractor types the
/// callback receives ahead of the error (`()` when it receives only the error).
pub struct HandleErrorLayer<F, T> {
    /// Error-handling callback; cloned into every [`HandleError`] this layer builds.
    f: F,
    /// Marker that ties `T` to the type; no value of type `T` is stored.
    _extractor: PhantomData<fn() -> T>,
}
168
169impl<F, T> HandleErrorLayer<F, T> {
170 pub fn new(f: F) -> Self {
172 Self {
173 f,
174 _extractor: PhantomData,
175 }
176 }
177}
178
179impl<F, T> Clone for HandleErrorLayer<F, T>
180where
181 F: Clone,
182{
183 fn clone(&self) -> Self {
184 Self {
185 f: self.f.clone(),
186 _extractor: PhantomData,
187 }
188 }
189}
190
191impl<F, E> fmt::Debug for HandleErrorLayer<F, E> {
192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193 f.debug_struct("HandleErrorLayer")
194 .field("f", &format_args!("{}", std::any::type_name::<F>()))
195 .finish()
196 }
197}
198
199impl<S, F, T> Layer<S> for HandleErrorLayer<F, T>
200where
201 F: Clone,
202{
203 type Service = HandleError<S, F, T>;
204
205 fn layer(&self, inner: S) -> Self::Service {
206 HandleError::new(inner, self.f.clone())
207 }
208}
209
/// A [`Service`] adapter that turns errors of the wrapped service into responses
/// by calling an error-handling callback.
///
/// `S` is the wrapped service and `F` the callback. `T` is the tuple of
/// extractor types passed to the callback ahead of the error (`()` when the
/// callback receives only the error).
pub struct HandleError<S, F, T> {
    /// The wrapped service.
    inner: S,
    /// Error-handling callback invoked when `inner` returns an error.
    f: F,
    /// Marker that ties `T` to the type; no value of type `T` is stored.
    _extractor: PhantomData<fn() -> T>,
}
218
219impl<S, F, T> HandleError<S, F, T> {
220 pub fn new(inner: S, f: F) -> Self {
222 Self {
223 inner,
224 f,
225 _extractor: PhantomData,
226 }
227 }
228}
229
230impl<S, F, T> Clone for HandleError<S, F, T>
231where
232 S: Clone,
233 F: Clone,
234{
235 fn clone(&self) -> Self {
236 Self {
237 inner: self.inner.clone(),
238 f: self.f.clone(),
239 _extractor: PhantomData,
240 }
241 }
242}
243
244impl<S, F, E> fmt::Debug for HandleError<S, F, E>
245where
246 S: fmt::Debug,
247{
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 f.debug_struct("HandleError")
250 .field("inner", &self.inner)
251 .field("f", &format_args!("{}", std::any::type_name::<F>()))
252 .finish()
253 }
254}
255
256impl<S, F, ReqBody, ResBody, Fut, Res> Service<Request<ReqBody>> for HandleError<S, F, ()>
257where
258 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
259 S::Error: Send,
260 S::Future: Send,
261 F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
262 Fut: Future<Output = Res> + Send,
263 Res: IntoResponse,
264 ReqBody: Send + 'static,
265 ResBody: HttpBody<Data = Bytes> + Send + 'static,
266 ResBody::Error: Into<BoxError>,
267{
268 type Response = Response<BoxBody>;
269 type Error = Infallible;
270 type Future = ResponseFuture;
271
272 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
273 Poll::Ready(Ok(()))
274 }
275
276 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
277 let f = self.f.clone();
278
279 let clone = self.inner.clone();
280 let inner = std::mem::replace(&mut self.inner, clone);
281
282 let future = Box::pin(async move {
283 match inner.oneshot(req).await {
284 Ok(res) => Ok(res.map(box_body)),
285 Err(err) => Ok(f(err).await.into_response().map(box_body)),
286 }
287 });
288
289 ResponseFuture { future }
290 }
291}
292
// Generates a `Service` impl for `HandleError` whose callback receives the
// listed extractor types, in order, followed by the inner service's error.
#[allow(unused_macros)]
macro_rules! impl_service {
    ( $($ty:ident),* $(,)? ) => {
        impl<S, F, ReqBody, ResBody, Res, Fut, $($ty,)*> Service<Request<ReqBody>>
            for HandleError<S, F, ($($ty,)*)>
        where
            S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
            S::Error: Send,
            S::Future: Send,
            F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static,
            Fut: Future<Output = Res> + Send,
            Res: IntoResponse,
            $( $ty: FromRequest<ReqBody> + Send,)*
            ReqBody: Send + 'static,
            ResBody: HttpBody<Data = Bytes> + Send + 'static,
            ResBody::Error: Into<BoxError>,
        {
            type Response = Response<BoxBody>;
            type Error = Infallible;

            type Future = ResponseFuture;

            // Always reports ready without polling the wrapped service; the
            // wrapped service is driven through `oneshot` inside `call`.
            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            // Extracted values are bound to variables named after the type
            // parameters (`T1`, `T2`, ...), hence the lint exemption.
            #[allow(non_snake_case)]
            fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
                let on_error = self.f.clone();

                // Leave a fresh clone in `self` and move the current service
                // into the future.
                let fresh = self.inner.clone();
                let service = std::mem::replace(&mut self.inner, fresh);

                let future = Box::pin(async move {
                    let mut parts = RequestParts::new(req);

                    // Run every extractor, in declaration order, before the
                    // wrapped service is called. The first rejection becomes
                    // the response.
                    $(
                        let $ty = match $ty::from_request(&mut parts).await {
                            Ok(extracted) => extracted,
                            Err(rejection) => {
                                return Ok(rejection.into_response().map(box_body));
                            }
                        };
                    )*

                    // Reassemble the request for the wrapped service. If that
                    // fails, answer 500 with the error's text as the body.
                    // NOTE(review): presumably this fails when an extractor
                    // consumed part of the request — confirm against axum's
                    // `RequestParts` documentation.
                    let req = match parts.try_into_request() {
                        Ok(req) => req,
                        Err(err) => {
                            let body = box_body(Full::from(err.to_string()));
                            return Ok(Response::builder()
                                .status(StatusCode::INTERNAL_SERVER_ERROR)
                                .body(body)
                                .unwrap());
                        }
                    };

                    let response = match service.oneshot(req).await {
                        Ok(res) => res.map(box_body),
                        Err(err) => on_error($($ty),*, err).await.into_response().map(box_body),
                    };
                    Ok(response)
                });

                ResponseFuture { future }
            }
        }
    }
}
357
// Implement `Service` for `HandleError` with callbacks that take between 1 and
// 16 extractors ahead of the error argument.
impl_service!(T1);
impl_service!(T1, T2);
impl_service!(T1, T2, T3);
impl_service!(T1, T2, T3, T4);
impl_service!(T1, T2, T3, T4, T5);
impl_service!(T1, T2, T3, T4, T5, T6);
impl_service!(T1, T2, T3, T4, T5, T6, T7);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
374
375pin_project! {
376 pub struct ResponseFuture {
378 #[pin]
379 future: Pin<Box<dyn Future<Output = Result<Response<BoxBody>, Infallible>> + Send + 'static>>,
380 }
381}
382
383impl Future for ResponseFuture {
384 type Output = Result<Response<BoxBody>, Infallible>;
385
386 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
387 self.project().future.poll(cx)
388 }
389}
390
391pub trait HandleErrorExt<B>: Service<Request<B>> + Sized {
396 fn handle_error<F>(self, f: F) -> HandleError<Self, F, B> {
398 HandleError::new(self, f)
399 }
400}
401
// Blanket impl: every `Service<Request<B>>` gets `handle_error`.
impl<B, S> HandleErrorExt<B> for S where S: Service<Request<B>> {}