axum_extra/
typed_header.rs1use axum::{
4 extract::{FromRequestParts, OptionalFromRequestParts},
5 response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
6};
7use headers::{Header, HeaderMapExt};
8use http::{request::Parts, StatusCode};
9use std::convert::Infallible;
10
/// Extractor and response for typed header values from the [`headers`] crate.
///
/// When used as an extractor (with `T: Header`), the request header named by
/// [`Header::name`] is decoded into `T`. A missing or undecodable header turns
/// into a [`TypedHeaderRejection`]. Through `OptionalFromRequestParts`, a
/// missing header instead produces `None`, while an undecodable one is still
/// rejected.
///
/// When used as a response or response part, the wrapped value is encoded and
/// inserted into the response headers.
#[cfg(feature = "typed-header")]
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct TypedHeader<T>(pub T);
56
57impl<T, S> FromRequestParts<S> for TypedHeader<T>
58where
59 T: Header,
60 S: Send + Sync,
61{
62 type Rejection = TypedHeaderRejection;
63
64 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
65 let mut values = parts.headers.get_all(T::name()).iter();
66 let is_missing = values.size_hint() == (0, Some(0));
67 T::decode(&mut values)
68 .map(Self)
69 .map_err(|err| TypedHeaderRejection {
70 name: T::name(),
71 reason: if is_missing {
72 TypedHeaderRejectionReason::Missing
74 } else {
75 TypedHeaderRejectionReason::Error(err)
76 },
77 })
78 }
79}
80
81impl<T, S> OptionalFromRequestParts<S> for TypedHeader<T>
82where
83 T: Header,
84 S: Send + Sync,
85{
86 type Rejection = TypedHeaderRejection;
87
88 async fn from_request_parts(
89 parts: &mut Parts,
90 _state: &S,
91 ) -> Result<Option<Self>, Self::Rejection> {
92 let mut values = parts.headers.get_all(T::name()).iter();
93 let is_missing = values.size_hint() == (0, Some(0));
94 match T::decode(&mut values) {
95 Ok(res) => Ok(Some(Self(res))),
96 Err(_) if is_missing => Ok(None),
97 Err(err) => Err(TypedHeaderRejection {
98 name: T::name(),
99 reason: TypedHeaderRejectionReason::Error(err),
100 }),
101 }
102 }
103}
104
105axum_core::__impl_deref!(TypedHeader);
106
107impl<T> IntoResponseParts for TypedHeader<T>
108where
109 T: Header,
110{
111 type Error = Infallible;
112
113 fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
114 res.headers_mut().typed_insert(self.0);
115 Ok(res)
116 }
117}
118
119impl<T> IntoResponse for TypedHeader<T>
120where
121 T: Header,
122{
123 fn into_response(self) -> Response {
124 let mut res = ().into_response();
125 res.headers_mut().typed_insert(self.0);
126 res
127 }
128}
129
/// Rejection produced by the [`TypedHeader`] extractors when a header is
/// missing (required extraction only) or cannot be decoded.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
pub struct TypedHeaderRejection {
    /// Name of the header that was being extracted (taken from `Header::name`).
    name: &'static http::header::HeaderName,
    /// Why extraction failed.
    reason: TypedHeaderRejectionReason,
}
137
138impl TypedHeaderRejection {
139 pub fn name(&self) -> &http::header::HeaderName {
141 self.name
142 }
143
144 pub fn reason(&self) -> &TypedHeaderRejectionReason {
146 &self.reason
147 }
148
149 #[must_use]
153 pub fn is_missing(&self) -> bool {
154 self.reason.is_missing()
155 }
156}
157
/// Detailed cause of a [`TypedHeaderRejection`].
#[cfg(feature = "typed-header")]
#[derive(Debug)]
#[non_exhaustive]
pub enum TypedHeaderRejectionReason {
    /// No value for the header was present in the request.
    Missing,
    /// A value was present, but [`Header::decode`] returned this error.
    Error(headers::Error),
}
168
169impl TypedHeaderRejectionReason {
170 #[must_use]
174 pub fn is_missing(&self) -> bool {
175 matches!(self, Self::Missing)
176 }
177}
178
179impl IntoResponse for TypedHeaderRejection {
180 fn into_response(self) -> Response {
181 let status = StatusCode::BAD_REQUEST;
182 let body = self.to_string();
183 axum_core::__log_rejection!(rejection_type = Self, body_text = body, status = status,);
184 (status, body).into_response()
185 }
186}
187
188impl std::fmt::Display for TypedHeaderRejection {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 match &self.reason {
191 TypedHeaderRejectionReason::Missing => {
192 write!(f, "Header of type `{}` was missing", self.name)
193 }
194 TypedHeaderRejectionReason::Error(err) => {
195 write!(f, "{err} ({})", self.name)
196 }
197 }
198 }
199}
200
201impl std::error::Error for TypedHeaderRejection {
202 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
203 match &self.reason {
204 TypedHeaderRejectionReason::Error(err) => Some(err),
205 TypedHeaderRejectionReason::Missing => None,
206 }
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::get, Router};

    #[tokio::test]
    async fn typed_header() {
        async fn handle(
            TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
            TypedHeader(cookies): TypedHeader<headers::Cookie>,
        ) -> impl IntoResponse {
            let user_agent = user_agent.as_str();
            let cookies = cookies.iter().collect::<Vec<_>>();
            format!("User-Agent={user_agent:?}, Cookie={cookies:?}")
        }

        let app = Router::new().route("/", get(handle));

        let client = TestClient::new(app);

        // Multiple `cookie` header lines are combined by the typed decoder.
        let res = client
            .get("/")
            .header("user-agent", "foobar")
            .header("cookie", "a=1; b=2")
            .header("cookie", "c=3")
            .await;
        let body = res.text().await;
        assert_eq!(
            body,
            r#"User-Agent="foobar", Cookie=[("a", "1"), ("b", "2"), ("c", "3")]"#
        );

        let res = client.get("/").header("user-agent", "foobar").await;
        let body = res.text().await;
        assert_eq!(body, r#"User-Agent="foobar", Cookie=[]"#);

        // A required header that is absent is rejected with the "missing" message.
        let res = client.get("/").header("cookie", "a=1").await;
        let body = res.text().await;
        assert_eq!(body, "Header of type `user-agent` was missing");
    }

    /// Covers the `OptionalFromRequestParts` path. An absent header maps to
    /// `None` and a valid one to `Some`. A present-but-undecodable header is
    /// still rejected.
    #[tokio::test]
    async fn optional_typed_header() {
        async fn handle(
            if_modified_since: Option<TypedHeader<headers::IfModifiedSince>>,
        ) -> &'static str {
            match if_modified_since {
                Some(_) => "present",
                None => "absent",
            }
        }

        let app = Router::new().route("/", get(handle));
        let client = TestClient::new(app);

        let res = client.get("/").await;
        assert_eq!(res.text().await, "absent");

        let res = client
            .get("/")
            .header("if-modified-since", "Sun, 06 Nov 1994 08:49:37 GMT")
            .await;
        assert_eq!(res.text().await, "present");

        let res = client
            .get("/")
            .header("if-modified-since", "not a date")
            .await;
        assert_eq!(
            res.text().await,
            "invalid HTTP header (if-modified-since)"
        );
    }

    /// Checks the rejection's accessors, `Display` text and error source for
    /// the "missing" case without going through HTTP.
    #[test]
    fn missing_rejection_accessors() {
        let rejection = TypedHeaderRejection {
            name: headers::UserAgent::name(),
            reason: TypedHeaderRejectionReason::Missing,
        };

        assert!(rejection.is_missing());
        assert!(rejection.reason().is_missing());
        assert_eq!(rejection.name().as_str(), "user-agent");
        assert_eq!(
            rejection.to_string(),
            "Header of type `user-agent` was missing"
        );
        assert!(std::error::Error::source(&rejection).is_none());
    }
}
251}