// rama_http/layer/validate_request/validate_request_header.rs

1use super::{AcceptHeader, BoxValidateRequestFn, ValidateRequest};
2use crate::{Request, Response};
3use rama_core::{Context, Layer, Service};
4use rama_utils::macros::define_inner_service_accessors;
5use std::fmt;
6
/// [`Layer`] that wraps a service in a [`ValidateRequestHeader`] middleware, so that
/// every incoming request is checked by the configured validator first.
///
/// See the [module docs](crate::layer::validate_request) for an example.
pub struct ValidateRequestHeaderLayer<T> {
    /// Validator handed to each [`ValidateRequestHeader`] this layer produces.
    pub(crate) validate: T,
}
13
14impl<T: fmt::Debug> fmt::Debug for ValidateRequestHeaderLayer<T> {
15    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
16        f.debug_struct("ValidateRequestHeaderLayer")
17            .field("validate", &self.validate)
18            .finish()
19    }
20}
21
22impl<T> Clone for ValidateRequestHeaderLayer<T>
23where
24    T: Clone,
25{
26    fn clone(&self) -> Self {
27        Self {
28            validate: self.validate.clone(),
29        }
30    }
31}
32
33impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> {
34    /// Validate requests have the required Accept header.
35    ///
36    /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`,
37    /// as configured.
38    ///
39    /// # Panics
40    ///
41    /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json`
42    /// See `AcceptHeader::new` for when this method panics.
43    ///
44    /// # Example
45    ///
46    /// ```
47    /// use rama_http::layer::validate_request::{AcceptHeader, ValidateRequestHeaderLayer};
48    ///
49    /// let layer = ValidateRequestHeaderLayer::<AcceptHeader>::accept("application/json");
50    /// ```
51    ///
52    /// [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
53    pub fn accept(value: &str) -> Self
54    where
55        ResBody: Default,
56    {
57        Self::custom(AcceptHeader::new(value))
58    }
59}
60
61impl<T> ValidateRequestHeaderLayer<T> {
62    /// Validate requests using a custom validator.
63    pub fn custom(validate: T) -> Self {
64        Self { validate }
65    }
66}
67
68impl<F, A> ValidateRequestHeaderLayer<BoxValidateRequestFn<F, A>> {
69    /// Validate requests using a custom validator Fn.
70    pub fn custom_fn(validate: F) -> Self {
71        Self {
72            validate: BoxValidateRequestFn::new(validate),
73        }
74    }
75}
76
77impl<S, T> Layer<S> for ValidateRequestHeaderLayer<T>
78where
79    T: Clone,
80{
81    type Service = ValidateRequestHeader<S, T>;
82
83    fn layer(&self, inner: S) -> Self::Service {
84        ValidateRequestHeader::new(inner, self.validate.clone())
85    }
86
87    fn into_layer(self, inner: S) -> Self::Service {
88        ValidateRequestHeader::new(inner, self.validate)
89    }
90}
91
/// Middleware that runs each request through a validator before handing it to the
/// wrapped service; rejected requests are answered with the validator's response.
///
/// See the [module docs](crate::layer::validate_request) for an example.
pub struct ValidateRequestHeader<S, T> {
    /// The wrapped service that handles requests which pass validation.
    inner: S,
    /// The validator applied to every incoming request.
    pub(crate) validate: T,
}
99
100impl<S: fmt::Debug, T: fmt::Debug> fmt::Debug for ValidateRequestHeader<S, T> {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        f.debug_struct("ValidateRequestHeader")
103            .field("inner", &self.inner)
104            .field("validate", &self.validate)
105            .finish()
106    }
107}
108
109impl<S, T> Clone for ValidateRequestHeader<S, T>
110where
111    S: Clone,
112    T: Clone,
113{
114    fn clone(&self) -> Self {
115        Self {
116            inner: self.inner.clone(),
117            validate: self.validate.clone(),
118        }
119    }
120}
121
122impl<S, T> ValidateRequestHeader<S, T> {
123    fn new(inner: S, validate: T) -> Self {
124        Self::custom(inner, validate)
125    }
126
127    define_inner_service_accessors!();
128}
129
130impl<S, ResBody> ValidateRequestHeader<S, AcceptHeader<ResBody>> {
131    /// Validate requests have the required Accept header.
132    ///
133    /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`,
134    /// as configured.
135    ///
136    /// # Panics
137    ///
138    /// See `AcceptHeader::new` for when this method panics.
139    pub fn accept(inner: S, value: &str) -> Self
140    where
141        ResBody: Default,
142    {
143        Self::custom(inner, AcceptHeader::new(value))
144    }
145}
146
147impl<S, T> ValidateRequestHeader<S, T> {
148    /// Validate requests using a custom validator.
149    pub fn custom(inner: S, validate: T) -> Self {
150        Self { inner, validate }
151    }
152}
153
154impl<S, F, A> ValidateRequestHeader<S, BoxValidateRequestFn<F, A>> {
155    /// Validate requests using a custom validator Fn.
156    pub fn custom_fn(inner: S, validate: F) -> Self {
157        Self {
158            inner,
159            validate: BoxValidateRequestFn::new(validate),
160        }
161    }
162}
163
164impl<ReqBody, ResBody, State, S, V> Service<State, Request<ReqBody>> for ValidateRequestHeader<S, V>
165where
166    ReqBody: Send + 'static,
167    ResBody: Send + 'static,
168    State: Clone + Send + Sync + 'static,
169    V: ValidateRequest<State, ReqBody, ResponseBody = ResBody>,
170    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
171{
172    type Response = Response<ResBody>;
173    type Error = S::Error;
174
175    async fn serve(
176        &self,
177        ctx: Context<State>,
178        req: Request<ReqBody>,
179    ) -> Result<Self::Response, Self::Error> {
180        match self.validate.validate(ctx, req).await {
181            Ok((ctx, req)) => self.inner.serve(ctx, req).await,
182            Err(res) => Ok(res),
183        }
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Body, StatusCode, header};
    use rama_core::{Layer, error::BoxError, service::service_fn};

    /// Serve a single `GET /` request through a service that only accepts `accepted`.
    ///
    /// One `Accept` header entry is added per element of `accept_values`. The
    /// function returns the status of the response.
    async fn status_for(accepted: &str, accept_values: &[&str]) -> StatusCode {
        let service = ValidateRequestHeaderLayer::accept(accepted).into_layer(service_fn(echo));

        let mut builder = Request::get("/");
        for value in accept_values {
            // `header` appends, so repeated values become separate header entries.
            builder = builder.header(header::ACCEPT, *value);
        }
        let request = builder.body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();
        res.status()
    }

    #[tokio::test]
    async fn valid_accept_header() {
        let status = status_for("application/json", &["application/json"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all_json() {
        let status = status_for("application/json", &["application/*"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all() {
        let status = status_for("application/json", &["*/*"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_accept_header() {
        let status = status_for("application/json", &["invalid"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header_subtype() {
        let status = status_for("application/json", &["application/strings"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header() {
        let status = status_for("application/json", &["text/strings"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn accepted_multiple_header_value() {
        let status = status_for(
            "application/json",
            &["text/strings", "invalid, application/json"],
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_inner_header_value() {
        let status = status_for(
            "application/json",
            &["text/strings, invalid, application/json"],
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_valid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\", application/*";
        let status = status_for("application/xml", &[value]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_invalid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\"";
        let status = status_for("text/html", &[value]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}
349}