1use std::{future::Future, marker::PhantomData, rc::Rc};
2
3use actix_service::boxed::{self, BoxFuture, RcService};
4use actix_utils::future::{ready, Ready};
5use futures_core::future::LocalBoxFuture;
6
7use crate::{
8 body::MessageBody,
9 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
10 Error, FromRequest,
11};
12
13pub fn from_fn<F, Es>(mw_fn: F) -> MiddlewareFn<F, Es> {
84 MiddlewareFn {
85 mw_fn: Rc::new(mw_fn),
86 _phantom: PhantomData,
87 }
88}
89
90#[allow(missing_debug_implementations)]
92pub struct MiddlewareFn<F, Es> {
93 mw_fn: Rc<F>,
94 _phantom: PhantomData<Es>,
95}
96
97impl<S, F, Fut, B, B2> Transform<S, ServiceRequest> for MiddlewareFn<F, ()>
98where
99 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
100 F: Fn(ServiceRequest, Next<B>) -> Fut + 'static,
101 Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
102 B2: MessageBody,
103{
104 type Response = ServiceResponse<B2>;
105 type Error = Error;
106 type Transform = MiddlewareFnService<F, B, ()>;
107 type InitError = ();
108 type Future = Ready<Result<Self::Transform, Self::InitError>>;
109
110 fn new_transform(&self, service: S) -> Self::Future {
111 ready(Ok(MiddlewareFnService {
112 service: boxed::rc_service(service),
113 mw_fn: Rc::clone(&self.mw_fn),
114 _phantom: PhantomData,
115 }))
116 }
117}
118
119#[allow(missing_debug_implementations)]
121pub struct MiddlewareFnService<F, B, Es> {
122 service: RcService<ServiceRequest, ServiceResponse<B>, Error>,
123 mw_fn: Rc<F>,
124 _phantom: PhantomData<(B, Es)>,
125}
126
127impl<F, Fut, B, B2> Service<ServiceRequest> for MiddlewareFnService<F, B, ()>
128where
129 F: Fn(ServiceRequest, Next<B>) -> Fut,
130 Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
131 B2: MessageBody,
132{
133 type Response = ServiceResponse<B2>;
134 type Error = Error;
135 type Future = Fut;
136
137 forward_ready!(service);
138
139 fn call(&self, req: ServiceRequest) -> Self::Future {
140 (self.mw_fn)(
141 req,
142 Next::<B> {
143 service: Rc::clone(&self.service),
144 },
145 )
146 }
147}
148
// Generates a `Transform` impl for `MiddlewareFn` and a `Service` impl for
// `MiddlewareFnService` for middleware functions that take one or more `FromRequest`
// extractors before the `(ServiceRequest, Next<B>)` parameters. It is invoked below once per
// extractor count, 1 through 9.
//
// The extractor type parameters (`E1`, `E2`, ...) are also reused as local binding names when
// destructuring the extracted tuple in `call`, which is why `nonstandard_style` is allowed there.
macro_rules! impl_middleware_fn_service {
    ($($ext_type:ident),*) => {
        // `($($ext_type),*,)` and `($($ext_type,)*)` both expand to `(E1, .., En,)`; the trailing
        // comma keeps the single-extractor case a 1-tuple rather than a parenthesized type.
        impl<S, F, Fut, B, B2, $($ext_type),*> Transform<S, ServiceRequest> for MiddlewareFn<F, ($($ext_type),*,)>
        where
            S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
            F: Fn($($ext_type),*, ServiceRequest, Next<B>) -> Fut + 'static,
            $($ext_type: FromRequest + 'static,)*
            Fut: Future<Output = Result<ServiceResponse<B2>, Error>> + 'static,
            B: MessageBody + 'static,
            B2: MessageBody + 'static,
        {
            type Response = ServiceResponse<B2>;
            type Error = Error;
            type Transform = MiddlewareFnService<F, B, ($($ext_type,)*)>;
            type InitError = ();
            type Future = Ready<Result<Self::Transform, Self::InitError>>;

            // Same construction as the extractor-free impl; only the `Es` marker type differs.
            fn new_transform(&self, service: S) -> Self::Future {
                ready(Ok(MiddlewareFnService {
                    service: boxed::rc_service(service),
                    mw_fn: Rc::clone(&self.mw_fn),
                    _phantom: PhantomData,
                }))
            }
        }

        impl<F, $($ext_type),*, Fut, B: 'static, B2> Service<ServiceRequest>
            for MiddlewareFnService<F, B, ($($ext_type),*,)>
        where
            F: Fn(
                $($ext_type),*,
                ServiceRequest,
                Next<B>
            ) -> Fut + 'static,
            // `'static` is required because extraction and the function's future both live
            // inside the `'static` boxed future returned by `call`.
            $($ext_type: FromRequest + 'static,)*
            Fut: Future<Output = Result<ServiceResponse<B2>, Error>> + 'static,
            B2: MessageBody + 'static,
        {
            type Response = ServiceResponse<B2>;
            type Error = Error;
            // Boxed: the extraction step and the function call are chained in one `async` block,
            // whose type cannot be named.
            type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

            // Readiness is delegated to the wrapped service.
            forward_ready!(service);

            // `E1`..`En` are used as variable names below.
            #[allow(nonstandard_style)]
            fn call(&self, mut req: ServiceRequest) -> Self::Future {
                // Clone the shared handles so the returned future owns them and does not
                // borrow `self`.
                let mw_fn = Rc::clone(&self.mw_fn);
                let service = Rc::clone(&self.service);

                Box::pin(async move {
                    // An extraction error is propagated by `?`; the middleware function is
                    // then never called.
                    let ($($ext_type,)*) = req.extract::<($($ext_type,)*)>().await?;

                    (mw_fn)($($ext_type),*, req, Next::<B> { service }).await
                })
            }
        }
    };
}

// One `Transform`/`Service` impl pair per supported extractor count (1 through 9).
impl_middleware_fn_service!(E1);
impl_middleware_fn_service!(E1, E2);
impl_middleware_fn_service!(E1, E2, E3);
impl_middleware_fn_service!(E1, E2, E3, E4);
impl_middleware_fn_service!(E1, E2, E3, E4, E5);
impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6);
impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7);
impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7, E8);
impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7, E8, E9);
217
218#[allow(missing_debug_implementations)]
220pub struct Next<B> {
221 service: RcService<ServiceRequest, ServiceResponse<B>, Error>,
222}
223
224impl<B> Next<B> {
225 pub fn call(&self, req: ServiceRequest) -> <Self as Service<ServiceRequest>>::Future {
227 Service::call(self, req)
228 }
229}
230
231impl<B> Service<ServiceRequest> for Next<B> {
232 type Response = ServiceResponse<B>;
233 type Error = Error;
234 type Future = BoxFuture<Result<Self::Response, Self::Error>>;
235
236 forward_ready!(service);
237
238 fn call(&self, req: ServiceRequest) -> Self::Future {
239 self.service.call(req)
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        http::header::{self, HeaderValue},
        middleware::{Compat, Logger},
        test, web, App, HttpResponse,
    };

    /// Pass-through middleware: forwards the request untouched.
    async fn noop<B>(req: ServiceRequest, next: Next<B>) -> Result<ServiceResponse<B>, Error> {
        next.call(req).await
    }

    /// Adds a `Warning: 42` header to whatever the inner services return.
    async fn add_res_header<B>(
        req: ServiceRequest,
        next: Next<B>,
    ) -> Result<ServiceResponse<B>, Error> {
        next.call(req).await.map(|mut res| {
            res.headers_mut()
                .insert(header::WARNING, HeaderValue::from_static("42"));
            res
        })
    }

    /// Re-wraps the body with `map_into_left_body`, so the outgoing body type differs from the
    /// incoming one.
    async fn mutate_body_type(
        req: ServiceRequest,
        next: Next<impl MessageBody + 'static>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        next.call(req)
            .await
            .map(|res| res.map_into_left_body::<()>())
    }

    /// Middleware driven by a method on a struct, exercising closure capture in `from_fn`.
    ///
    /// When the flag is `true`, the request is answered directly without calling `next`.
    struct MyMw(bool);

    impl MyMw {
        async fn mw_cb(
            &self,
            req: ServiceRequest,
            next: Next<impl MessageBody + 'static>,
        ) -> Result<ServiceResponse<impl MessageBody>, Error> {
            let short_circuit = self.0;

            let mut res = if short_circuit {
                req.into_response("short-circuited").map_into_right_body()
            } else {
                next.call(req).await?.map_into_left_body()
            };

            res.headers_mut()
                .insert(header::WARNING, HeaderValue::from_static("42"));

            Ok(res)
        }

        pub fn into_middleware<S, B>(
            self,
        ) -> impl Transform<
            S,
            ServiceRequest,
            Response = ServiceResponse<impl MessageBody>,
            Error = Error,
            InitError = (),
        >
        where
            S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
            B: MessageBody + 'static,
        {
            // Shared so each request's future can own a handle to `self`.
            let shared = Rc::new(self);

            from_fn(move |req, next| {
                let shared = Rc::clone(&shared);
                async move { Self::mw_cb(&shared, req, next).await }
            })
        }
    }

    #[actix_rt::test]
    async fn compat_compat() {
        // Only needs to type-check: both middleware must be accepted inside `Compat`.
        let _ = App::new().wrap(Compat::new(from_fn(noop)));
        let _ = App::new().wrap(Compat::new(from_fn(mutate_body_type)));
    }

    #[actix_rt::test]
    async fn permits_different_in_and_out_body_types() {
        let app = App::new()
            .wrap(from_fn(mutate_body_type))
            .wrap(from_fn(add_res_header))
            .wrap(Logger::default())
            .wrap(from_fn(noop))
            .default_service(web::to(HttpResponse::NotFound));
        let srv = test::init_service(app).await;

        let req = test::TestRequest::default().to_request();
        let res = test::call_service(&srv, req).await;

        assert!(res.headers().contains_key(header::WARNING));
    }

    #[actix_rt::test]
    async fn closure_capture_and_return_from_fn() {
        let app = App::new()
            .wrap(Logger::default())
            .wrap(MyMw(true).into_middleware())
            .wrap(Logger::default());
        let srv = test::init_service(app).await;

        let req = test::TestRequest::default().to_request();
        let res = test::call_service(&srv, req).await;

        assert!(res.headers().contains_key(header::WARNING));
    }
}