// rama_http/layer/validate_request/accept_header.rs

1use super::ValidateRequest;
2use crate::{
3    Request, Response, StatusCode,
4    dep::mime::{Mime, MimeIter},
5    header,
6};
7use rama_core::Context;
8use std::{fmt, marker::PhantomData, sync::Arc};
9
10/// Type that performs validation of the Accept header.
11pub struct AcceptHeader<ResBody = crate::Body> {
12    header_value: Arc<Mime>,
13    _ty: PhantomData<fn() -> ResBody>,
14}
15
16impl<ResBody> AcceptHeader<ResBody> {
17    /// Create a new `AcceptHeader`.
18    ///
19    /// # Panics
20    ///
21    /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json`
22    pub(super) fn new(header_value: &str) -> Self
23    where
24        ResBody: Default,
25    {
26        Self {
27            header_value: Arc::new(
28                header_value
29                    .parse::<Mime>()
30                    .expect("value is not a valid header value"),
31            ),
32            _ty: PhantomData,
33        }
34    }
35}
36
37impl<ResBody> Clone for AcceptHeader<ResBody> {
38    fn clone(&self) -> Self {
39        Self {
40            header_value: self.header_value.clone(),
41            _ty: PhantomData,
42        }
43    }
44}
45
46impl<ResBody> fmt::Debug for AcceptHeader<ResBody> {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        f.debug_struct("AcceptHeader")
49            .field("header_value", &self.header_value)
50            .finish()
51    }
52}
53
54impl<S, B, ResBody> ValidateRequest<S, B> for AcceptHeader<ResBody>
55where
56    S: Clone + Send + Sync + 'static,
57    B: Send + Sync + 'static,
58    ResBody: Default + Send + 'static,
59{
60    type ResponseBody = ResBody;
61
62    async fn validate(
63        &self,
64        ctx: Context<S>,
65        req: Request<B>,
66    ) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
67        if !req.headers().contains_key(header::ACCEPT) {
68            return Ok((ctx, req));
69        }
70        if req
71            .headers()
72            .get_all(header::ACCEPT)
73            .into_iter()
74            .filter_map(|header| header.to_str().ok())
75            .any(|h| {
76                MimeIter::new(h)
77                    .map(|mim| {
78                        if let Ok(mim) = mim {
79                            let typ = self.header_value.type_();
80                            let subtype = self.header_value.subtype();
81                            match (mim.type_(), mim.subtype()) {
82                                (t, s) if t == typ && s == subtype => true,
83                                (t, mime::STAR) if t == typ => true,
84                                (mime::STAR, mime::STAR) => true,
85                                _ => false,
86                            }
87                        } else {
88                            false
89                        }
90                    })
91                    .reduce(|acc, mim| acc || mim)
92                    .unwrap_or(false)
93            })
94        {
95            return Ok((ctx, req));
96        }
97        let mut res = Response::new(ResBody::default());
98        *res.status_mut() = StatusCode::NOT_ACCEPTABLE;
99        Err(res)
100    }
101}