1use std::{
4 future::Future,
5 pin::Pin,
6 rc::Rc,
7 task::{Context, Poll},
8};
9
10use actix_service::{Service, Transform};
11use foldhash::HashMap as FoldHashMap;
12use futures_core::{future::LocalBoxFuture, ready};
13use pin_project_lite::pin_project;
14
15use crate::{
16 body::EitherBody,
17 dev::{ServiceRequest, ServiceResponse},
18 http::StatusCode,
19 Error, Result,
20};
21
/// Value returned by a custom handler registered with [`ErrorHandlers`].
///
/// A handler signals failure by returning `Err` from its `Result`; otherwise it yields one of
/// these variants describing the replacement response.
pub enum ErrorHandlerResponse<B> {
    /// Response that the middleware returns immediately.
    Response(ServiceResponse<EitherBody<B>>),

    /// Future that the middleware polls to completion; its output becomes the middleware's output.
    Future(LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>),
}
30
/// Signature of a user-supplied error handler: it consumes the original response and returns a
/// replacement (immediately or as a future), or an error.
type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>;

/// Optional fallback handler. Wrapped in `Rc` so the builder, each middleware instance and each
/// in-flight future can hold a cheap clone of it.
type DefaultHandler<B> = Option<Rc<ErrorHandler<B>>>;
34
/// Middleware for registering custom, status-code-based error handlers.
///
/// Register a handler for a specific status code with [`ErrorHandlers::handler`], or register
/// fallbacks for all client (4xx) and/or server (5xx) error responses with
/// [`ErrorHandlers::default_handler`], [`ErrorHandlers::default_handler_client`] and
/// [`ErrorHandlers::default_handler_server`].
///
/// A status-specific handler takes precedence over a default one. Responses with no applicable
/// handler are passed through unchanged in the left variant of [`EitherBody`].
pub struct ErrorHandlers<B> {
    // Fallback for 4xx responses that have no status-specific handler.
    default_client: DefaultHandler<B>,
    // Fallback for 5xx responses that have no status-specific handler.
    default_server: DefaultHandler<B>,
    // Status-specific handlers; the `Rc` is shared with every middleware built from this value.
    handlers: Handlers<B>,
}
187
/// Map from an exact status code to its handler, shared (via `Rc`) between the builder and the
/// middleware and futures it produces.
type Handlers<B> = Rc<FoldHashMap<StatusCode, Box<ErrorHandler<B>>>>;
189
190impl<B> Default for ErrorHandlers<B> {
191 fn default() -> Self {
192 ErrorHandlers {
193 default_client: Default::default(),
194 default_server: Default::default(),
195 handlers: Default::default(),
196 }
197 }
198}
199
200impl<B> ErrorHandlers<B> {
201 pub fn new() -> Self {
203 ErrorHandlers::default()
204 }
205
206 pub fn handler<F>(mut self, status: StatusCode, handler: F) -> Self
208 where
209 F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
210 {
211 Rc::get_mut(&mut self.handlers)
212 .unwrap()
213 .insert(status, Box::new(handler));
214 self
215 }
216
217 pub fn default_handler<F>(self, handler: F) -> Self
230 where
231 F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
232 {
233 let handler = Rc::new(handler);
234 let handler2 = Rc::clone(&handler);
235 Self {
236 default_server: Some(handler2),
237 default_client: Some(handler),
238 ..self
239 }
240 }
241
242 pub fn default_handler_client<F>(self, handler: F) -> Self
244 where
245 F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
246 {
247 Self {
248 default_client: Some(Rc::new(handler)),
249 ..self
250 }
251 }
252
253 pub fn default_handler_server<F>(self, handler: F) -> Self
255 where
256 F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
257 {
258 Self {
259 default_server: Some(Rc::new(handler)),
260 ..self
261 }
262 }
263
264 fn get_handler<'a>(
269 status: &StatusCode,
270 default_client: Option<&'a ErrorHandler<B>>,
271 default_server: Option<&'a ErrorHandler<B>>,
272 handlers: &'a Handlers<B>,
273 ) -> Option<&'a ErrorHandler<B>> {
274 handlers
275 .get(status)
276 .map(|h| h.as_ref())
277 .or_else(|| status.is_client_error().then_some(default_client).flatten())
278 .or_else(|| status.is_server_error().then_some(default_server).flatten())
279 }
280}
281
282impl<S, B> Transform<S, ServiceRequest> for ErrorHandlers<B>
283where
284 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
285 S::Future: 'static,
286 B: 'static,
287{
288 type Response = ServiceResponse<EitherBody<B>>;
289 type Error = Error;
290 type Transform = ErrorHandlersMiddleware<S, B>;
291 type InitError = ();
292 type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>;
293
294 fn new_transform(&self, service: S) -> Self::Future {
295 let handlers = Rc::clone(&self.handlers);
296 let default_client = self.default_client.clone();
297 let default_server = self.default_server.clone();
298 Box::pin(async move {
299 Ok(ErrorHandlersMiddleware {
300 service,
301 default_client,
302 default_server,
303 handlers,
304 })
305 })
306 }
307}
308
/// Service produced by the [`ErrorHandlers`] transform.
///
/// It wraps the inner `service` and applies the configured handlers to its responses.
#[doc(hidden)]
pub struct ErrorHandlersMiddleware<S, B> {
    service: S,
    default_client: DefaultHandler<B>,
    default_server: DefaultHandler<B>,
    handlers: Handlers<B>,
}
316
317impl<S, B> Service<ServiceRequest> for ErrorHandlersMiddleware<S, B>
318where
319 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
320 S::Future: 'static,
321 B: 'static,
322{
323 type Response = ServiceResponse<EitherBody<B>>;
324 type Error = Error;
325 type Future = ErrorHandlersFuture<S::Future, B>;
326
327 actix_service::forward_ready!(service);
328
329 fn call(&self, req: ServiceRequest) -> Self::Future {
330 let handlers = Rc::clone(&self.handlers);
331 let default_client = self.default_client.clone();
332 let default_server = self.default_server.clone();
333 let fut = self.service.call(req);
334 ErrorHandlersFuture::ServiceFuture {
335 fut,
336 default_client,
337 default_server,
338 handlers,
339 }
340 }
341}
342
pin_project! {
    // Future returned by `ErrorHandlersMiddleware::call`.
    //
    // It starts in `ServiceFuture`, awaiting the wrapped service. If the selected handler
    // returns `ErrorHandlerResponse::Future`, it switches to `ErrorHandlerFuture` and awaits
    // that instead.
    #[project = ErrorHandlersProj]
    pub enum ErrorHandlersFuture<Fut, B>
    where
        Fut: Future,
    {
        // Awaiting the inner service's response. Carries the handler configuration needed once
        // the response is available.
        ServiceFuture {
            #[pin]
            fut: Fut,
            default_client: DefaultHandler<B>,
            default_server: DefaultHandler<B>,
            handlers: Handlers<B>,
        },
        // Awaiting an asynchronous error handler. The boxed future (`Pin<Box<..>>`) is already
        // pinned on the heap, so this field is not structurally pinned.
        ErrorHandlerFuture {
            fut: LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>,
        },
    }
}
361
362impl<Fut, B> Future for ErrorHandlersFuture<Fut, B>
363where
364 Fut: Future<Output = Result<ServiceResponse<B>, Error>>,
365{
366 type Output = Result<ServiceResponse<EitherBody<B>>, Error>;
367
368 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
369 match self.as_mut().project() {
370 ErrorHandlersProj::ServiceFuture {
371 fut,
372 default_client,
373 default_server,
374 handlers,
375 } => {
376 let res = ready!(fut.poll(cx))?;
377 let status = res.status();
378
379 let handler = ErrorHandlers::get_handler(
380 &status,
381 default_client.as_mut().map(|f| Rc::as_ref(f)),
382 default_server.as_mut().map(|f| Rc::as_ref(f)),
383 handlers,
384 );
385 match handler {
386 Some(handler) => match handler(res)? {
387 ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)),
388 ErrorHandlerResponse::Future(fut) => {
389 self.as_mut()
390 .set(ErrorHandlersFuture::ErrorHandlerFuture { fut });
391
392 self.poll(cx)
393 }
394 },
395 None => Poll::Ready(Ok(res.map_into_left_body())),
396 }
397 }
398
399 ErrorHandlersProj::ErrorHandlerFuture { fut } => fut.as_mut().poll(cx),
400 }
401 }
402}
403
#[cfg(test)]
mod tests {
    // The handlers below return `Result` only because the `ErrorHandler` signature requires it,
    // which is why they carry `#[allow(clippy::unnecessary_wraps)]`.

    use actix_service::IntoService;
    use actix_utils::future::ok;
    use bytes::Bytes;
    use futures_util::FutureExt as _;

    use super::*;
    use crate::{
        body,
        http::header::{HeaderValue, CONTENT_TYPE},
        test::{self, TestRequest},
    };

    #[actix_rt::test]
    async fn add_header_error_handler() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    #[actix_rt::test]
    async fn add_header_error_handler_async() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B: 'static>(
            mut res: ServiceResponse<B>,
        ) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(ErrorHandlerResponse::Future(
                ok(res.map_into_left_body()).boxed_local(),
            ))
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    #[actix_rt::test]
    async fn changes_body_type() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B>(res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            let (req, res) = res.into_parts();
            let res = res.set_body(Bytes::from("sorry, that's no bueno"));

            let res = ServiceResponse::new(req, res)
                .map_into_boxed_body()
                .map_into_right_body();

            Ok(ErrorHandlerResponse::Response(res))
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let res = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(test::read_body(res).await, "sorry, that's no bueno");
    }

    #[actix_rt::test]
    async fn error_thrown() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B>(_res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            Err(crate::error::ErrorInternalServerError(
                "error in error handler",
            ))
        }

        let srv = test::status_service(StatusCode::BAD_REQUEST);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::BAD_REQUEST, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let err = mw
            .call(TestRequest::default().to_srv_request())
            .await
            .unwrap_err();
        let res = err.error_response();

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body::to_bytes(res.into_body()).await.unwrap(),
            "error in error handler"
        );
    }

    #[actix_rt::test]
    async fn default_error_handler() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let make_mw = |status| async move {
            ErrorHandlers::new()
                .default_handler(error_handler)
                .new_transform(test::status_service(status).into_service())
                .await
                .unwrap()
        };
        let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await;
        let mw_client = make_mw(StatusCode::BAD_REQUEST).await;

        let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

        let resp = test::call_service(&mw_server, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    #[actix_rt::test]
    async fn default_handlers_separate_client_server() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_client<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_server<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0002"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let make_mw = |status| async move {
            ErrorHandlers::new()
                .default_handler_server(error_handler_server)
                .default_handler_client(error_handler_client)
                .new_transform(test::status_service(status).into_service())
                .await
                .unwrap()
        };
        let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await;
        let mw_client = make_mw(StatusCode::BAD_REQUEST).await;

        let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

        let resp = test::call_service(&mw_server, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002");
    }

    #[actix_rt::test]
    async fn default_handlers_specialization() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_client<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_specific<B>(
            mut res: ServiceResponse<B>,
        ) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0003"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let make_mw = |status| async move {
            ErrorHandlers::new()
                .default_handler_client(error_handler_client)
                .handler(StatusCode::UNPROCESSABLE_ENTITY, error_handler_specific)
                .new_transform(test::status_service(status).into_service())
                .await
                .unwrap()
        };
        let mw_client = make_mw(StatusCode::BAD_REQUEST).await;
        let mw_specific = make_mw(StatusCode::UNPROCESSABLE_ENTITY).await;

        let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

        let resp = test::call_service(&mw_specific, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0003");
    }

    /// Statuses that are neither client nor server errors never reach a default handler and
    /// are returned untouched.
    #[actix_rt::test]
    async fn non_error_responses_pass_through() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let make_mw = |status| async move {
            ErrorHandlers::new()
                .default_handler(error_handler)
                .new_transform(test::status_service(status).into_service())
                .await
                .unwrap()
        };
        let mw_ok = make_mw(StatusCode::OK).await;
        let mw_redirect = make_mw(StatusCode::FOUND).await;

        let resp = test::call_service(&mw_ok, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.headers().get(CONTENT_TYPE).is_none());

        let resp = test::call_service(&mw_redirect, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.status(), StatusCode::FOUND);
        assert!(resp.headers().get(CONTENT_TYPE).is_none());
    }

    /// A status-specific handler wins over the default server handler. Other 5xx statuses
    /// still fall back to the default.
    #[actix_rt::test]
    async fn specific_handler_overrides_default_server() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_default<B>(
            mut res: ServiceResponse<B>,
        ) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0002"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_specific<B>(
            mut res: ServiceResponse<B>,
        ) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0003"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let make_mw = |status| async move {
            ErrorHandlers::new()
                .default_handler(error_handler_default)
                .handler(StatusCode::SERVICE_UNAVAILABLE, error_handler_specific)
                .new_transform(test::status_service(status).into_service())
                .await
                .unwrap()
        };
        let mw_default = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await;
        let mw_specific = make_mw(StatusCode::SERVICE_UNAVAILABLE).await;

        let resp = test::call_service(&mw_default, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002");

        let resp = test::call_service(&mw_specific, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0003");
    }
}
624}