// actix_web/middleware/from_fn.rs

1use std::{future::Future, marker::PhantomData, rc::Rc};
2
3use actix_service::boxed::{self, BoxFuture, RcService};
4use actix_utils::future::{ready, Ready};
5use futures_core::future::LocalBoxFuture;
6
7use crate::{
8    body::MessageBody,
9    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
10    Error, FromRequest,
11};
12
13/// Wraps an async function to be used as a middleware.
14///
15/// # Examples
16///
17/// The wrapped function should have the following form:
18///
19/// ```
20/// # use actix_web::{
21/// #     App, Error,
22/// #     body::MessageBody,
23/// #     dev::{ServiceRequest, ServiceResponse, Service as _},
24/// # };
25/// use actix_web::middleware::{self, Next};
26///
27/// async fn my_mw(
28///     req: ServiceRequest,
29///     next: Next<impl MessageBody>,
30/// ) -> Result<ServiceResponse<impl MessageBody>, Error> {
31///     // pre-processing
32///     next.call(req).await
33///     // post-processing
34/// }
35/// # App::new().wrap(middleware::from_fn(my_mw));
36/// ```
37///
38/// Then use in an app builder like this:
39///
40/// ```
41/// use actix_web::{
42///     App, Error,
43///     dev::{ServiceRequest, ServiceResponse, Service as _},
44/// };
45/// use actix_web::middleware::from_fn;
46/// # use actix_web::middleware::Next;
47/// # async fn my_mw<B>(req: ServiceRequest, next: Next<B>) -> Result<ServiceResponse<B>, Error> {
48/// #     next.call(req).await
49/// # }
50///
51/// App::new()
52///     .wrap(from_fn(my_mw))
53/// # ;
54/// ```
55///
56/// It is also possible to write a middleware that automatically uses extractors, similar to request
57/// handlers, by declaring them as the first parameters. As usual, **take care with extractors that
58/// consume the body stream**, since handlers will no longer be able to read it again without
59/// putting the body "back" into the request object within your middleware.
60///
61/// ```
62/// # use std::collections::HashMap;
63/// # use actix_web::{
64/// #     App, Error,
65/// #     body::MessageBody,
66/// #     dev::{ServiceRequest, ServiceResponse},
67/// #     http::header::{Accept, Date},
68/// #     web::{Header, Query},
69/// # };
70/// use actix_web::middleware::Next;
71///
72/// async fn my_extracting_mw(
73///     accept: Header<Accept>,
74///     query: Query<HashMap<String, String>>,
75///     req: ServiceRequest,
76///     next: Next<impl MessageBody>,
77/// ) -> Result<ServiceResponse<impl MessageBody>, Error> {
78///     // pre-processing
79///     next.call(req).await
80///     // post-processing
81/// }
82/// # App::new().wrap(actix_web::middleware::from_fn(my_extracting_mw));
83pub fn from_fn<F, Es>(mw_fn: F) -> MiddlewareFn<F, Es> {
84    MiddlewareFn {
85        mw_fn: Rc::new(mw_fn),
86        _phantom: PhantomData,
87    }
88}
89
/// Middleware transform for [`from_fn`].
///
/// `Es` records the extractor types taken by the wrapped function: `()` when it takes only
/// `(ServiceRequest, Next<B>)`, otherwise a tuple of the extractor types.
// Not `Debug`: the wrapped function type `F` (typically a closure or fn item) generally isn't.
#[allow(missing_debug_implementations)]
pub struct MiddlewareFn<F, Es> {
    /// The wrapped middleware function, shared (not cloned) with each service it builds.
    mw_fn: Rc<F>,
    _phantom: PhantomData<Es>,
}

97impl<S, F, Fut, B, B2> Transform<S, ServiceRequest> for MiddlewareFn<F, ()>
98where
99    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
100    F: Fn(ServiceRequest, Next<B>) -> Fut + 'static,
101    Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
102    B2: MessageBody,
103{
104    type Response = ServiceResponse<B2>;
105    type Error = Error;
106    type Transform = MiddlewareFnService<F, B, ()>;
107    type InitError = ();
108    type Future = Ready<Result<Self::Transform, Self::InitError>>;
109
110    fn new_transform(&self, service: S) -> Self::Future {
111        ready(Ok(MiddlewareFnService {
112            service: boxed::rc_service(service),
113            mw_fn: Rc::clone(&self.mw_fn),
114            _phantom: PhantomData,
115        }))
116    }
117}
118
119/// Middleware service for [`from_fn`].
120#[allow(missing_debug_implementations)]
121pub struct MiddlewareFnService<F, B, Es> {
122    service: RcService<ServiceRequest, ServiceResponse<B>, Error>,
123    mw_fn: Rc<F>,
124    _phantom: PhantomData<(B, Es)>,
125}
126
127impl<F, Fut, B, B2> Service<ServiceRequest> for MiddlewareFnService<F, B, ()>
128where
129    F: Fn(ServiceRequest, Next<B>) -> Fut,
130    Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
131    B2: MessageBody,
132{
133    type Response = ServiceResponse<B2>;
134    type Error = Error;
135    type Future = Fut;
136
137    forward_ready!(service);
138
139    fn call(&self, req: ServiceRequest) -> Self::Future {
140        (self.mw_fn)(
141            req,
142            Next::<B> {
143                service: Rc::clone(&self.service),
144            },
145        )
146    }
147}
148
/// Implements [`Transform`] and [`Service`] for middleware functions whose leading parameters are
/// the listed extractor types, followed by `(ServiceRequest, Next<B>)`.
///
/// The extractors are resolved together as a tuple before the wrapped function is called; if
/// extraction fails, that error is returned and the function is never invoked.
macro_rules! impl_middleware_fn_service {
    ($($ext_type:ident),*) => {
        impl<S, F, Fut, B, B2, $($ext_type),*> Transform<S, ServiceRequest> for MiddlewareFn<F, ($($ext_type),*,)>
        where
            S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
            F: Fn($($ext_type),*, ServiceRequest, Next<B>) -> Fut + 'static,
            $($ext_type: FromRequest + 'static,)*
            Fut: Future<Output = Result<ServiceResponse<B2>, Error>> + 'static,
            B: MessageBody + 'static,
            B2: MessageBody + 'static,
        {
            type Response = ServiceResponse<B2>;
            type Error = Error;
            type Transform = MiddlewareFnService<F, B, ($($ext_type,)*)>;
            type InitError = ();
            type Future = Ready<Result<Self::Transform, Self::InitError>>;

            fn new_transform(&self, service: S) -> Self::Future {
                ready(Ok(MiddlewareFnService {
                    service: boxed::rc_service(service),
                    mw_fn: Rc::clone(&self.mw_fn),
                    _phantom: PhantomData,
                }))
            }
        }

        impl<F, $($ext_type),*, Fut, B: 'static, B2> Service<ServiceRequest>
            for MiddlewareFnService<F, B, ($($ext_type),*,)>
        where
            F: Fn(
                $($ext_type),*,
                ServiceRequest,
                Next<B>
            ) -> Fut + 'static,
            $($ext_type: FromRequest + 'static,)*
            Fut: Future<Output = Result<ServiceResponse<B2>, Error>> + 'static,
            B2: MessageBody + 'static,
        {
            type Response = ServiceResponse<B2>;
            type Error = Error;
            // The extract-then-call future is an `async` block, whose type cannot be named,
            // so it is boxed.
            type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

            forward_ready!(service);

            // The extractor type idents double as local variable names below.
            #[allow(nonstandard_style)]
            fn call(&self, mut req: ServiceRequest) -> Self::Future {
                // Clone the shared handles so the `'static` future owns them.
                let mw_fn = Rc::clone(&self.mw_fn);
                let service = Rc::clone(&self.service);

                Box::pin(async move {
                    let ($($ext_type,)*) = req.extract::<($($ext_type,)*)>().await?;

                    (mw_fn)($($ext_type),*, req, Next::<B> { service }).await
                })
            }
        }
    };
}

208impl_middleware_fn_service!(E1);
209impl_middleware_fn_service!(E1, E2);
210impl_middleware_fn_service!(E1, E2, E3);
211impl_middleware_fn_service!(E1, E2, E3, E4);
212impl_middleware_fn_service!(E1, E2, E3, E4, E5);
213impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6);
214impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7);
215impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7, E8);
216impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7, E8, E9);
217
218/// Wraps the "next" service in the middleware chain.
219#[allow(missing_debug_implementations)]
220pub struct Next<B> {
221    service: RcService<ServiceRequest, ServiceResponse<B>, Error>,
222}
223
224impl<B> Next<B> {
225    /// Equivalent to `Service::call(self, req)`.
226    pub fn call(&self, req: ServiceRequest) -> <Self as Service<ServiceRequest>>::Future {
227        Service::call(self, req)
228    }
229}
230
231impl<B> Service<ServiceRequest> for Next<B> {
232    type Response = ServiceResponse<B>;
233    type Error = Error;
234    type Future = BoxFuture<Result<Self::Response, Self::Error>>;
235
236    forward_ready!(service);
237
238    fn call(&self, req: ServiceRequest) -> Self::Future {
239        self.service.call(req)
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        http::{
            header::{self, HeaderValue},
            Method,
        },
        middleware::{Compat, Logger},
        test, web, App, HttpResponse,
    };

    /// Pass-through middleware: forwards the request unchanged.
    async fn noop<B>(req: ServiceRequest, next: Next<B>) -> Result<ServiceResponse<B>, Error> {
        next.call(req).await
    }

    /// Post-processing middleware: inserts `Warning: 42` into the downstream response.
    async fn add_res_header<B>(
        req: ServiceRequest,
        next: Next<B>,
    ) -> Result<ServiceResponse<B>, Error> {
        let mut res = next.call(req).await?;
        res.headers_mut()
            .insert(header::WARNING, HeaderValue::from_static("42"));
        Ok(res)
    }

    /// Returns a response whose body type differs from the downstream body type.
    async fn mutate_body_type(
        req: ServiceRequest,
        next: Next<impl MessageBody + 'static>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        let res = next.call(req).await?;
        Ok(res.map_into_left_body::<()>())
    }

    /// Extractor-taking middleware: echoes the extracted request method in a `Warning` header.
    async fn add_method_header<B>(
        method: Method,
        req: ServiceRequest,
        next: Next<B>,
    ) -> Result<ServiceResponse<B>, Error> {
        let mut res = next.call(req).await?;
        let value = HeaderValue::from_str(method.as_str())
            .expect("HTTP method should be a valid header value");
        res.headers_mut().insert(header::WARNING, value);
        Ok(res)
    }

    /// Stateful middleware built from a closure that captures `self`; short-circuits when `.0`.
    struct MyMw(bool);

    impl MyMw {
        async fn mw_cb(
            &self,
            req: ServiceRequest,
            next: Next<impl MessageBody + 'static>,
        ) -> Result<ServiceResponse<impl MessageBody>, Error> {
            let mut res = match self.0 {
                true => req.into_response("short-circuited").map_into_right_body(),
                false => next.call(req).await?.map_into_left_body(),
            };
            res.headers_mut()
                .insert(header::WARNING, HeaderValue::from_static("42"));
            Ok(res)
        }

        pub fn into_middleware<S, B>(
            self,
        ) -> impl Transform<
            S,
            ServiceRequest,
            Response = ServiceResponse<impl MessageBody>,
            Error = Error,
            InitError = (),
        >
        where
            S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
            B: MessageBody + 'static,
        {
            let this = Rc::new(self);
            from_fn(move |req, next| {
                let this = Rc::clone(&this);
                async move { Self::mw_cb(&this, req, next).await }
            })
        }
    }

    /// `Compat` accepts `from_fn` middleware, including ones that change the body type.
    #[actix_rt::test]
    async fn compat_compat() {
        let _ = App::new().wrap(Compat::new(from_fn(noop)));
        let _ = App::new().wrap(Compat::new(from_fn(mutate_body_type)));
    }

    /// `from_fn` middleware can be stacked with body-type-changing and built-in middleware.
    #[actix_rt::test]
    async fn permits_different_in_and_out_body_types() {
        let app = test::init_service(
            App::new()
                .wrap(from_fn(mutate_body_type))
                .wrap(from_fn(add_res_header))
                .wrap(Logger::default())
                .wrap(from_fn(noop))
                .default_service(web::to(HttpResponse::NotFound)),
        )
        .await;

        let req = test::TestRequest::default().to_request();
        let res = test::call_service(&app, req).await;
        assert!(res.headers().contains_key(header::WARNING));
    }

    /// A closure-based middleware returned from a method still post-processes its response.
    #[actix_rt::test]
    async fn closure_capture_and_return_from_fn() {
        let app = test::init_service(
            App::new()
                .wrap(Logger::default())
                .wrap(MyMw(true).into_middleware())
                .wrap(Logger::default()),
        )
        .await;

        let req = test::TestRequest::default().to_request();
        let res = test::call_service(&app, req).await;
        assert!(res.headers().contains_key(header::WARNING));
    }

    /// Extractor parameters are resolved and passed to the middleware function.
    #[actix_rt::test]
    async fn passes_extractors_to_middleware() {
        let app = test::init_service(
            App::new()
                .wrap(from_fn(add_method_header))
                .default_service(web::to(HttpResponse::NotFound)),
        )
        .await;

        let req = test::TestRequest::post().to_request();
        let res = test::call_service(&app, req).await;
        assert_eq!(res.headers().get(header::WARNING).unwrap(), "POST");
    }
}