tower_http/
validate_request.rs

1//! Middleware that validates requests.
2//!
3//! # Example
4//!
5//! ```
6//! use tower_http::validate_request::ValidateRequestHeaderLayer;
7//! use http::{Request, Response, StatusCode, header::ACCEPT};
8//! use http_body_util::Full;
9//! use bytes::Bytes;
10//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
11//!
12//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
13//!     Ok(Response::new(Full::default()))
14//! }
15//!
16//! # #[tokio::main]
17//! # async fn main() -> Result<(), BoxError> {
18//! let mut service = ServiceBuilder::new()
19//!     // Require the `Accept` header to be `application/json`, `*/*` or `application/*`
20//!     .layer(ValidateRequestHeaderLayer::accept("application/json"))
21//!     .service_fn(handle);
22//!
23//! // Requests with the correct value are allowed through
24//! let request = Request::builder()
25//!     .header(ACCEPT, "application/json")
26//!     .body(Full::default())
27//!     .unwrap();
28//!
29//! let response = service
30//!     .ready()
31//!     .await?
32//!     .call(request)
33//!     .await?;
34//!
35//! assert_eq!(StatusCode::OK, response.status());
36//!
37//! // Requests with an invalid value get a `406 Not Acceptable` response
38//! let request = Request::builder()
39//!     .header(ACCEPT, "text/strings")
40//!     .body(Full::default())
41//!     .unwrap();
42//!
43//! let response = service
44//!     .ready()
45//!     .await?
46//!     .call(request)
47//!     .await?;
48//!
49//! assert_eq!(StatusCode::NOT_ACCEPTABLE, response.status());
50//! # Ok(())
51//! # }
52//! ```
53//!
54//! Custom validation can be made by implementing [`ValidateRequest`]:
55//!
56//! ```
57//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest};
58//! use http::{Request, Response, StatusCode, header::ACCEPT};
59//! use http_body_util::Full;
60//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
61//! use bytes::Bytes;
62//!
63//! #[derive(Clone, Copy)]
64//! pub struct MyHeader { /* ...  */ }
65//!
66//! impl<B> ValidateRequest<B> for MyHeader {
67//!     type ResponseBody = Full<Bytes>;
68//!
69//!     fn validate(
70//!         &mut self,
71//!         request: &mut Request<B>,
72//!     ) -> Result<(), Response<Self::ResponseBody>> {
73//!         // validate the request...
74//!         # unimplemented!()
75//!     }
76//! }
77//!
78//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
79//!     Ok(Response::new(Full::default()))
80//! }
81//!
82//!
83//! # #[tokio::main]
84//! # async fn main() -> Result<(), BoxError> {
85//! let service = ServiceBuilder::new()
86//!     // Validate requests using `MyHeader`
87//!     .layer(ValidateRequestHeaderLayer::custom(MyHeader { /* ... */ }))
88//!     .service_fn(handle);
89//! # Ok(())
90//! # }
91//! ```
92//!
93//! Or using a closure:
94//!
95//! ```
96//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest};
97//! use http::{Request, Response, StatusCode, header::ACCEPT};
98//! use bytes::Bytes;
99//! use http_body_util::Full;
100//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
101//!
102//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
103//!     # todo!();
104//!     // ...
105//! }
106//!
107//! # #[tokio::main]
108//! # async fn main() -> Result<(), BoxError> {
109//! let service = ServiceBuilder::new()
110//!     .layer(ValidateRequestHeaderLayer::custom(|request: &mut Request<Full<Bytes>>| {
111//!         // Validate the request
112//!         # Ok::<_, Response<Full<Bytes>>>(())
113//!     }))
114//!     .service_fn(handle);
115//! # Ok(())
116//! # }
117//! ```
118
119use http::{header, Request, Response, StatusCode};
120use mime::{Mime, MimeIter};
121use pin_project_lite::pin_project;
122use std::{
123    fmt,
124    future::Future,
125    marker::PhantomData,
126    pin::Pin,
127    sync::Arc,
128    task::{Context, Poll},
129};
130use tower_layer::Layer;
131use tower_service::Service;
132
/// [`Layer`] that wraps services in [`ValidateRequestHeader`], so every request is
/// validated before it reaches the wrapped service.
///
/// See the [module docs](crate::validate_request) for an example.
#[derive(Clone, Debug)]
pub struct ValidateRequestHeaderLayer<T> {
    // The validator; cloned into each service produced by `layer`.
    validate: T,
}
140
141impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> {
142    /// Validate requests have the required Accept header.
143    ///
144    /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`,
145    /// as configured.
146    ///
147    /// # Panics
148    ///
149    /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json`
150    /// See `AcceptHeader::new` for when this method panics.
151    ///
152    /// # Example
153    ///
154    /// ```
155    /// use http_body_util::Full;
156    /// use bytes::Bytes;
157    /// use tower_http::validate_request::{AcceptHeader, ValidateRequestHeaderLayer};
158    ///
159    /// let layer = ValidateRequestHeaderLayer::<AcceptHeader<Full<Bytes>>>::accept("application/json");
160    /// ```
161    ///
162    /// [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
163    pub fn accept(value: &str) -> Self
164    where
165        ResBody: Default,
166    {
167        Self::custom(AcceptHeader::new(value))
168    }
169}
170
171impl<T> ValidateRequestHeaderLayer<T> {
172    /// Validate requests using a custom method.
173    pub fn custom(validate: T) -> ValidateRequestHeaderLayer<T> {
174        Self { validate }
175    }
176}
177
178impl<S, T> Layer<S> for ValidateRequestHeaderLayer<T>
179where
180    T: Clone,
181{
182    type Service = ValidateRequestHeader<S, T>;
183
184    fn layer(&self, inner: S) -> Self::Service {
185        ValidateRequestHeader::new(inner, self.validate.clone())
186    }
187}
188
/// Middleware that runs a validator on each request before forwarding it.
///
/// Requests that pass are sent to the inner service. Rejected requests are answered
/// with the validator's response, and the inner service is never called.
///
/// See the [module docs](crate::validate_request) for an example.
#[derive(Debug, Clone)]
pub struct ValidateRequestHeader<S, T> {
    // The wrapped service.
    inner: S,
    // The validator applied to every request.
    validate: T,
}
197
198impl<S, T> ValidateRequestHeader<S, T> {
199    fn new(inner: S, validate: T) -> Self {
200        Self::custom(inner, validate)
201    }
202
203    define_inner_service_accessors!();
204}
205
206impl<S, ResBody> ValidateRequestHeader<S, AcceptHeader<ResBody>> {
207    /// Validate requests have the required Accept header.
208    ///
209    /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`,
210    /// as configured.
211    ///
212    /// # Panics
213    ///
214    /// See `AcceptHeader::new` for when this method panics.
215    pub fn accept(inner: S, value: &str) -> Self
216    where
217        ResBody: Default,
218    {
219        Self::custom(inner, AcceptHeader::new(value))
220    }
221}
222
223impl<S, T> ValidateRequestHeader<S, T> {
224    /// Validate requests using a custom method.
225    pub fn custom(inner: S, validate: T) -> ValidateRequestHeader<S, T> {
226        Self { inner, validate }
227    }
228}
229
230impl<ReqBody, ResBody, S, V> Service<Request<ReqBody>> for ValidateRequestHeader<S, V>
231where
232    V: ValidateRequest<ReqBody, ResponseBody = ResBody>,
233    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
234{
235    type Response = Response<ResBody>;
236    type Error = S::Error;
237    type Future = ResponseFuture<S::Future, ResBody>;
238
239    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
240        self.inner.poll_ready(cx)
241    }
242
243    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
244        match self.validate.validate(&mut req) {
245            Ok(_) => ResponseFuture::future(self.inner.call(req)),
246            Err(res) => ResponseFuture::invalid_header_value(res),
247        }
248    }
249}
250
pin_project! {
    /// Response future for [`ValidateRequestHeader`].
    ///
    /// Resolves to the inner service's output when the request passed validation,
    /// or to `Ok` with the validator's rejection response otherwise.
    pub struct ResponseFuture<F, B> {
        // Which of the two outcomes this future represents. It is pinned because the
        // `Future` variant holds the inner service's future.
        #[pin]
        kind: Kind<F, B>,
    }
}
258
259impl<F, B> ResponseFuture<F, B> {
260    fn future(future: F) -> Self {
261        Self {
262            kind: Kind::Future { future },
263        }
264    }
265
266    fn invalid_header_value(res: Response<B>) -> Self {
267        Self {
268            kind: Kind::Error {
269                response: Some(res),
270            },
271        }
272    }
273}
274
pin_project! {
    // Internal state of `ResponseFuture`.
    #[project = KindProj]
    enum Kind<F, B> {
        // The request passed validation; the inner service's future is polled.
        Future {
            #[pin]
            future: F,
        },
        // The request was rejected. `response` is taken on the first poll, so a
        // later poll finds `None` and panics.
        Error {
            response: Option<Response<B>>,
        },
    }
}
287
288impl<F, B, E> Future for ResponseFuture<F, B>
289where
290    F: Future<Output = Result<Response<B>, E>>,
291{
292    type Output = F::Output;
293
294    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
295        match self.project().kind.project() {
296            KindProj::Future { future } => future.poll(cx),
297            KindProj::Error { response } => {
298                let response = response.take().expect("future polled after completion");
299                Poll::Ready(Ok(response))
300            }
301        }
302    }
303}
304
/// Trait for validating requests.
///
/// Implemented in this module for [`AcceptHeader`] and for every closure of the form
/// `FnMut(&mut Request<B>) -> Result<(), Response<ResBody>>`.
pub trait ValidateRequest<B> {
    /// Body type of the response returned when a request is rejected.
    type ResponseBody;

    /// Validate the request.
    ///
    /// Returning `Ok(())` lets the request through to the inner service. Returning
    /// `Err(response)` rejects it: the inner service is not called and `response` is
    /// returned as the middleware's response instead.
    ///
    /// The request is passed as `&mut`, so a validator is also able to modify it
    /// before it is forwarded.
    fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>>;
}
315
316impl<B, F, ResBody> ValidateRequest<B> for F
317where
318    F: FnMut(&mut Request<B>) -> Result<(), Response<ResBody>>,
319{
320    type ResponseBody = ResBody;
321
322    fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
323        self(request)
324    }
325}
326
/// Type that performs validation of the Accept header.
///
/// Built by `ValidateRequestHeaderLayer::accept` and `ValidateRequestHeader::accept`.
pub struct AcceptHeader<ResBody> {
    // The configured media type, parsed once and shared between clones via `Arc`.
    header_value: Arc<Mime>,
    // Records the rejection body type without storing one. `fn() -> ResBody` is always
    // `Send + Sync`, so this marker does not restrict those auto traits.
    _ty: PhantomData<fn() -> ResBody>,
}
332
333impl<ResBody> AcceptHeader<ResBody> {
334    /// Create a new `AcceptHeader`.
335    ///
336    /// # Panics
337    ///
338    /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json`
339    fn new(header_value: &str) -> Self
340    where
341        ResBody: Default,
342    {
343        Self {
344            header_value: Arc::new(
345                header_value
346                    .parse::<Mime>()
347                    .expect("value is not a valid header value"),
348            ),
349            _ty: PhantomData,
350        }
351    }
352}
353
354impl<ResBody> Clone for AcceptHeader<ResBody> {
355    fn clone(&self) -> Self {
356        Self {
357            header_value: self.header_value.clone(),
358            _ty: PhantomData,
359        }
360    }
361}
362
363impl<ResBody> fmt::Debug for AcceptHeader<ResBody> {
364    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365        f.debug_struct("AcceptHeader")
366            .field("header_value", &self.header_value)
367            .finish()
368    }
369}
370
371impl<B, ResBody> ValidateRequest<B> for AcceptHeader<ResBody>
372where
373    ResBody: Default,
374{
375    type ResponseBody = ResBody;
376
377    fn validate(&mut self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
378        if !req.headers().contains_key(header::ACCEPT) {
379            return Ok(());
380        }
381        if req
382            .headers()
383            .get_all(header::ACCEPT)
384            .into_iter()
385            .filter_map(|header| header.to_str().ok())
386            .any(|h| {
387                MimeIter::new(h)
388                    .map(|mim| {
389                        if let Ok(mim) = mim {
390                            let typ = self.header_value.type_();
391                            let subtype = self.header_value.subtype();
392                            match (mim.type_(), mim.subtype()) {
393                                (t, s) if t == typ && s == subtype => true,
394                                (t, mime::STAR) if t == typ => true,
395                                (mime::STAR, mime::STAR) => true,
396                                _ => false,
397                            }
398                        } else {
399                            false
400                        }
401                    })
402                    .reduce(|acc, mim| acc || mim)
403                    .unwrap_or(false)
404            })
405        {
406            return Ok(());
407        }
408        let mut res = Response::new(ResBody::default());
409        *res.status_mut() = StatusCode::NOT_ACCEPTABLE;
410        Err(res)
411    }
412}
413
#[cfg(test)]
mod tests {
    // Glob import of the parent module (`Request`, `Response`, `StatusCode`, the
    // `Service` trait for `call`, and the validators). The allow keeps an unused
    // glob from warning.
    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use http::header;
    use tower::{BoxError, ServiceBuilder, ServiceExt};

    /// Sends `GET /` with one `Accept` header per entry of `accept` through a service
    /// that only accepts `allowed`, and returns the response status.
    async fn status_for(allowed: &str, accept: &[&str]) -> StatusCode {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept(allowed))
            .service_fn(echo);

        let request = accept
            .iter()
            .fold(Request::get("/"), |builder, value| {
                builder.header(header::ACCEPT, *value)
            })
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        res.status()
    }

    #[tokio::test]
    async fn valid_accept_header() {
        let status = status_for("application/json", &["application/json"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all_json() {
        let status = status_for("application/json", &["application/*"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all() {
        let status = status_for("application/json", &["*/*"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn missing_accept_header_is_allowed() {
        let status = status_for("application/json", &[]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_accept_header() {
        let status = status_for("application/json", &["invalid"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header_subtype() {
        let status = status_for("application/json", &["application/strings"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header() {
        let status = status_for("application/json", &["text/strings"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_multiple_header_values() {
        let status = status_for("application/json", &["text/strings", "invalid"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn accepted_multiple_header_value() {
        let status = status_for(
            "application/json",
            &["text/strings", "invalid, application/json"],
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_inner_header_value() {
        let status = status_for(
            "application/json",
            &["text/strings, invalid, application/json"],
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_valid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\", application/*";
        let status = status_for("application/xml", &[value]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_invalid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\"";
        let status = status_for("text/html", &[value]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}