actix_web/middleware/
condition.rs

1//! For middleware documentation, see [`Condition`].
2
3use std::{
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use futures_core::{future::LocalBoxFuture, ready};
10use futures_util::FutureExt as _;
11use pin_project_lite::pin_project;
12
13use crate::{
14    body::EitherBody,
15    dev::{Service, ServiceResponse, Transform},
16};
17
/// Middleware wrapper that applies another middleware only when a flag is set.
///
/// The decision is made when the middleware is built: if the flag is `true` the inner
/// transformer wraps the service, otherwise the service is used as-is. Because the two cases can
/// produce different body types, responses are returned with an [`EitherBody`] body (left when
/// the inner middleware ran, right when it was skipped).
///
/// # Examples
/// ```
/// use actix_web::middleware::{Condition, NormalizePath};
/// use actix_web::App;
///
/// let enable_normalize = std::env::var("NORMALIZE_PATH").is_ok();
/// let app = App::new()
///     .wrap(Condition::new(enable_normalize, NormalizePath::default()));
/// ```
pub struct Condition<T> {
    /// Whether `transformer` should be applied to the wrapped service.
    enable: bool,
    /// The inner middleware factory that is applied when `enable` is `true`.
    transformer: T,
}
33
34impl<T> Condition<T> {
35    pub fn new(enable: bool, transformer: T) -> Self {
36        Self {
37            transformer,
38            enable,
39        }
40    }
41}
42
43impl<S, T, Req, BE, BD, Err> Transform<S, Req> for Condition<T>
44where
45    S: Service<Req, Response = ServiceResponse<BD>, Error = Err> + 'static,
46    T: Transform<S, Req, Response = ServiceResponse<BE>, Error = Err>,
47    T::Future: 'static,
48    T::InitError: 'static,
49    T::Transform: 'static,
50{
51    type Response = ServiceResponse<EitherBody<BE, BD>>;
52    type Error = Err;
53    type Transform = ConditionMiddleware<T::Transform, S>;
54    type InitError = T::InitError;
55    type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>;
56
57    fn new_transform(&self, service: S) -> Self::Future {
58        if self.enable {
59            let fut = self.transformer.new_transform(service);
60            async move {
61                let wrapped_svc = fut.await?;
62                Ok(ConditionMiddleware::Enable(wrapped_svc))
63            }
64            .boxed_local()
65        } else {
66            async move { Ok(ConditionMiddleware::Disable(service)) }.boxed_local()
67        }
68    }
69}
70
/// Service produced by a conditional middleware wrapper.
///
/// Holds exactly one of two services, fixed at construction time.
pub enum ConditionMiddleware<E, D> {
    /// The condition was enabled: holds the service built by the inner transformer.
    Enable(E),
    /// The condition was disabled: holds the original, un-wrapped service.
    Disable(D),
}
75
76impl<E, D, Req, BE, BD, Err> Service<Req> for ConditionMiddleware<E, D>
77where
78    E: Service<Req, Response = ServiceResponse<BE>, Error = Err>,
79    D: Service<Req, Response = ServiceResponse<BD>, Error = Err>,
80{
81    type Response = ServiceResponse<EitherBody<BE, BD>>;
82    type Error = Err;
83    type Future = ConditionMiddlewareFuture<E::Future, D::Future>;
84
85    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86        match self {
87            ConditionMiddleware::Enable(service) => service.poll_ready(cx),
88            ConditionMiddleware::Disable(service) => service.poll_ready(cx),
89        }
90    }
91
92    fn call(&self, req: Req) -> Self::Future {
93        match self {
94            ConditionMiddleware::Enable(service) => ConditionMiddlewareFuture::Enabled {
95                fut: service.call(req),
96            },
97            ConditionMiddleware::Disable(service) => ConditionMiddlewareFuture::Disabled {
98                fut: service.call(req),
99            },
100        }
101    }
102}
103
104pin_project! {
105    #[doc(hidden)]
106    #[project = ConditionProj]
107    pub enum ConditionMiddlewareFuture<E, D> {
108        Enabled { #[pin] fut: E, },
109        Disabled { #[pin] fut: D, },
110    }
111}
112
113impl<E, D, BE, BD, Err> Future for ConditionMiddlewareFuture<E, D>
114where
115    E: Future<Output = Result<ServiceResponse<BE>, Err>>,
116    D: Future<Output = Result<ServiceResponse<BD>, Err>>,
117{
118    type Output = Result<ServiceResponse<EitherBody<BE, BD>>, Err>;
119
120    #[inline]
121    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
122        let res = match self.project() {
123            ConditionProj::Enabled { fut } => ready!(fut.poll(cx))?.map_into_left_body(),
124            ConditionProj::Disabled { fut } => ready!(fut.poll(cx))?.map_into_right_body(),
125        };
126
127        Poll::Ready(Ok(res))
128    }
129}
130
#[cfg(test)]
mod tests {
    use actix_service::IntoService as _;

    use super::*;
    use crate::{
        body::BoxBody,
        dev::ServiceRequest,
        error::Result,
        http::{
            header::{HeaderValue, CONTENT_TYPE},
            StatusCode,
        },
        middleware::{self, ErrorHandlerResponse, ErrorHandlers, Identity},
        test::{self, TestRequest},
        web::Bytes,
        HttpResponse,
    };

    /// Error handler used by the tests below: stamps the response with a recognisable
    /// `Content-Type` value so a test can tell whether the handler ran.
    // The `Result` return is never `Err` here; presumably the signature is what
    // `ErrorHandlers::handler` expects, hence the lint allowance.
    #[allow(clippy::unnecessary_wraps)]
    fn render_500<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
        let marker = HeaderValue::from_static("0001");
        res.response_mut().headers_mut().insert(CONTENT_TYPE, marker);

        Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
    }

    /// Smoke test: a `Condition` can be constructed around each of these middleware.
    #[test]
    fn compat_with_builtin_middleware() {
        let _ = Condition::new(true, middleware::Compat::new(Identity));
        let _ = Condition::new(true, middleware::Logger::default());
        let _ = Condition::new(true, middleware::Compress::default());
        let _ = Condition::new(true, middleware::NormalizePath::trim());
        let _ = Condition::new(true, middleware::DefaultHeaders::new());
        let _ = Condition::new(true, middleware::ErrorHandlers::<BoxBody>::new());
        let _ = Condition::new(true, middleware::ErrorHandlers::<Bytes>::new());
    }

    /// With the condition enabled, the error handler runs and its marker header is present.
    #[actix_rt::test]
    async fn test_handler_enabled() {
        // Service that always answers with a 500 and an empty string body.
        let failing_svc = |req: ServiceRequest| async move {
            let failure = HttpResponse::InternalServerError().message_body(String::new())?;
            Ok(req.into_response(failure))
        };

        let error_handlers =
            ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);

        let conditional = Condition::new(true, error_handlers)
            .new_transform(failing_svc.into_service())
            .await
            .unwrap();

        let req = TestRequest::default().to_srv_request();
        let res: ServiceResponse<EitherBody<EitherBody<_, _>, String>> =
            test::call_service(&conditional, req).await;

        assert_eq!(res.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    /// With the condition disabled, the error handler is skipped and no marker header is set.
    #[actix_rt::test]
    async fn test_handler_disabled() {
        // Service that always answers with a 500 and an empty string body.
        let failing_svc = |req: ServiceRequest| async move {
            let failure = HttpResponse::InternalServerError().message_body(String::new())?;
            Ok(req.into_response(failure))
        };

        let error_handlers =
            ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);

        let conditional = Condition::new(false, error_handlers)
            .new_transform(failing_svc.into_service())
            .await
            .unwrap();

        let req = TestRequest::default().to_srv_request();
        let res: ServiceResponse<EitherBody<EitherBody<_, _>, String>> =
            test::call_service(&conditional, req).await;

        assert_eq!(res.headers().get(CONTENT_TYPE), None);
    }
}
207}