axum_extra/
typed_header.rs

1//! Extractor and response for typed headers.
2
3use axum::{
4    extract::{FromRequestParts, OptionalFromRequestParts},
5    response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
6};
7use headers::{Header, HeaderMapExt};
8use http::{request::Parts, StatusCode};
9use std::convert::Infallible;
10
/// Extractor and response that works with typed header values from [`headers`].
///
/// # As extractor
///
/// In general, it's recommended to extract only the needed headers via `TypedHeader` rather than
/// extracting the whole header map with the `HeaderMap` extractor.
///
/// If the header is absent, or present but fails to decode, the request is rejected with a
/// [`TypedHeaderRejection`], which is turned into a `400 Bad Request` response. Note that some
/// header types decode successfully even when no value is present (for example
/// `headers::Cookie` yields an empty cookie list), in which case no rejection happens.
///
/// Through its [`OptionalFromRequestParts`] implementation, `Option<TypedHeader<T>>` can be used
/// to make a header optional. A header that is absent and fails to decode yields `None`. A header
/// that is present but malformed still rejects the request.
///
/// ```rust,no_run
/// use axum::{
///     routing::get,
///     Router,
/// };
/// use headers::UserAgent;
/// use axum_extra::TypedHeader;
///
/// async fn users_teams_show(
///     TypedHeader(user_agent): TypedHeader<UserAgent>,
/// ) {
///     // ...
/// }
///
/// let app = Router::new().route("/users/{user_id}/team/{team_id}", get(users_teams_show));
/// # let _: Router = app;
/// ```
///
/// # As response
///
/// When returned from a handler (alone or as part of a tuple), the wrapped value is encoded and
/// inserted into the response headers.
///
/// ```rust
/// use axum::{
///     response::IntoResponse,
/// };
/// use headers::ContentType;
/// use axum_extra::TypedHeader;
///
/// async fn handler() -> (TypedHeader<ContentType>, &'static str) {
///     (
///         TypedHeader(ContentType::text_utf8()),
///         "Hello, World!",
///     )
/// }
/// ```
#[cfg(feature = "typed-header")]
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct TypedHeader<T>(pub T);
56
57impl<T, S> FromRequestParts<S> for TypedHeader<T>
58where
59    T: Header,
60    S: Send + Sync,
61{
62    type Rejection = TypedHeaderRejection;
63
64    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
65        let mut values = parts.headers.get_all(T::name()).iter();
66        let is_missing = values.size_hint() == (0, Some(0));
67        T::decode(&mut values)
68            .map(Self)
69            .map_err(|err| TypedHeaderRejection {
70                name: T::name(),
71                reason: if is_missing {
72                    // Report a more precise rejection for the missing header case.
73                    TypedHeaderRejectionReason::Missing
74                } else {
75                    TypedHeaderRejectionReason::Error(err)
76                },
77            })
78    }
79}
80
81impl<T, S> OptionalFromRequestParts<S> for TypedHeader<T>
82where
83    T: Header,
84    S: Send + Sync,
85{
86    type Rejection = TypedHeaderRejection;
87
88    async fn from_request_parts(
89        parts: &mut Parts,
90        _state: &S,
91    ) -> Result<Option<Self>, Self::Rejection> {
92        let mut values = parts.headers.get_all(T::name()).iter();
93        let is_missing = values.size_hint() == (0, Some(0));
94        match T::decode(&mut values) {
95            Ok(res) => Ok(Some(Self(res))),
96            Err(_) if is_missing => Ok(None),
97            Err(err) => Err(TypedHeaderRejection {
98                name: T::name(),
99                reason: TypedHeaderRejectionReason::Error(err),
100            }),
101        }
102    }
103}
104
// NOTE(review): `__impl_deref!` is an internal axum_core macro. Judging by its name it implements
// `Deref`/`DerefMut` from `TypedHeader<T>` to the wrapped `T`; confirm in axum_core.
axum_core::__impl_deref!(TypedHeader);
106
107impl<T> IntoResponseParts for TypedHeader<T>
108where
109    T: Header,
110{
111    type Error = Infallible;
112
113    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
114        res.headers_mut().typed_insert(self.0);
115        Ok(res)
116    }
117}
118
119impl<T> IntoResponse for TypedHeader<T>
120where
121    T: Header,
122{
123    fn into_response(self) -> Response {
124        let mut res = ().into_response();
125        res.headers_mut().typed_insert(self.0);
126        res
127    }
128}
129
/// Rejection used for [`TypedHeader`].
///
/// Produced when the requested header is missing or cannot be decoded. Converting it into a
/// response yields `400 Bad Request` with the [`Display`](std::fmt::Display) text as the body.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
pub struct TypedHeaderRejection {
    /// Name of the header the extractor attempted to decode (`T::name()`).
    name: &'static http::header::HeaderName,
    /// Why the extraction failed.
    reason: TypedHeaderRejectionReason,
}
137
138impl TypedHeaderRejection {
139    /// Name of the header that caused the rejection
140    pub fn name(&self) -> &http::header::HeaderName {
141        self.name
142    }
143
144    /// Reason why the header extraction has failed
145    pub fn reason(&self) -> &TypedHeaderRejectionReason {
146        &self.reason
147    }
148
149    /// Returns `true` if the typed header rejection reason is [`Missing`].
150    ///
151    /// [`Missing`]: TypedHeaderRejectionReason::Missing
152    #[must_use]
153    pub fn is_missing(&self) -> bool {
154        self.reason.is_missing()
155    }
156}
157
/// Additional information regarding a [`TypedHeaderRejection`].
///
/// This enum is `#[non_exhaustive]`: code outside this crate that matches on it must include a
/// wildcard arm.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
#[non_exhaustive]
pub enum TypedHeaderRejectionReason {
    /// The header was missing from the HTTP request (and the header type could not be decoded
    /// from zero values).
    Missing,
    /// An error occurred when parsing the header from the HTTP request.
    ///
    /// The wrapped error is also exposed via [`std::error::Error::source`] on the rejection.
    Error(headers::Error),
}
168
169impl TypedHeaderRejectionReason {
170    /// Returns `true` if the typed header rejection reason is [`Missing`].
171    ///
172    /// [`Missing`]: TypedHeaderRejectionReason::Missing
173    #[must_use]
174    pub fn is_missing(&self) -> bool {
175        matches!(self, Self::Missing)
176    }
177}
178
179impl IntoResponse for TypedHeaderRejection {
180    fn into_response(self) -> Response {
181        let status = StatusCode::BAD_REQUEST;
182        let body = self.to_string();
183        axum_core::__log_rejection!(rejection_type = Self, body_text = body, status = status,);
184        (status, body).into_response()
185    }
186}
187
188impl std::fmt::Display for TypedHeaderRejection {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        match &self.reason {
191            TypedHeaderRejectionReason::Missing => {
192                write!(f, "Header of type `{}` was missing", self.name)
193            }
194            TypedHeaderRejectionReason::Error(err) => {
195                write!(f, "{err} ({})", self.name)
196            }
197        }
198    }
199}
200
201impl std::error::Error for TypedHeaderRejection {
202    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
203        match &self.reason {
204            TypedHeaderRejectionReason::Error(err) => Some(err),
205            TypedHeaderRejectionReason::Missing => None,
206        }
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::get, Router};

    #[tokio::test]
    async fn typed_header() {
        async fn handle(
            TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
            TypedHeader(cookies): TypedHeader<headers::Cookie>,
        ) -> impl IntoResponse {
            let user_agent = user_agent.as_str();
            let cookies = cookies.iter().collect::<Vec<_>>();
            format!("User-Agent={user_agent:?}, Cookie={cookies:?}")
        }

        let app = Router::new().route("/", get(handle));

        let client = TestClient::new(app);

        let res = client
            .get("/")
            .header("user-agent", "foobar")
            .header("cookie", "a=1; b=2")
            .header("cookie", "c=3")
            .await;
        let body = res.text().await;
        assert_eq!(
            body,
            r#"User-Agent="foobar", Cookie=[("a", "1"), ("b", "2"), ("c", "3")]"#
        );

        let res = client.get("/").header("user-agent", "foobar").await;
        let body = res.text().await;
        assert_eq!(body, r#"User-Agent="foobar", Cookie=[]"#);

        let res = client.get("/").header("cookie", "a=1").await;
        let body = res.text().await;
        assert_eq!(body, "Header of type `user-agent` was missing");
    }

    /// A header that is present but cannot be decoded is rejected with `400` and the
    /// rejection message names the header.
    #[tokio::test]
    async fn typed_header_malformed_value_is_rejected() {
        async fn handle(TypedHeader(_): TypedHeader<headers::IfModifiedSince>) {}

        let client = TestClient::new(Router::new().route("/", get(handle)));

        let res = client
            .get("/")
            .header("if-modified-since", "not a date")
            .await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        let body = res.text().await;
        assert!(
            body.ends_with("(if-modified-since)"),
            "unexpected body: {body}"
        );
    }

    /// `Option<TypedHeader<T>>` yields `None` for a missing header and `Some` when present.
    #[tokio::test]
    async fn optional_typed_header() {
        async fn handle(user_agent: Option<TypedHeader<headers::UserAgent>>) -> String {
            match user_agent {
                Some(TypedHeader(user_agent)) => format!("Some({:?})", user_agent.as_str()),
                None => "None".to_owned(),
            }
        }

        let client = TestClient::new(Router::new().route("/", get(handle)));

        let res = client.get("/").header("user-agent", "foobar").await;
        assert_eq!(res.text().await, r#"Some("foobar")"#);

        let res = client.get("/").await;
        assert_eq!(res.text().await, "None");
    }

    /// The optional extractor still rejects a header that is present but malformed.
    #[tokio::test]
    async fn optional_typed_header_malformed_value_is_rejected() {
        async fn handle(_: Option<TypedHeader<headers::IfModifiedSince>>) {}

        let client = TestClient::new(Router::new().route("/", get(handle)));

        let res = client.get("/").await;
        assert_eq!(res.status(), StatusCode::OK);

        let res = client
            .get("/")
            .header("if-modified-since", "not a date")
            .await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }
}