axum_extra/extract/
json_deserializer.rs1use axum::extract::{FromRequest, Request};
2use axum_core::__composite_rejection as composite_rejection;
3use axum_core::__define_rejection as define_rejection;
4use axum_core::extract::rejection::BytesRejection;
5use bytes::Bytes;
6use http::{header, HeaderMap};
7use serde::Deserialize;
8use std::marker::PhantomData;
9
10#[derive(Debug, Clone, Default)]
80#[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
81pub struct JsonDeserializer<T> {
82 bytes: Bytes,
83 _marker: PhantomData<T>,
84}
85
86impl<T, S> FromRequest<S> for JsonDeserializer<T>
87where
88 T: Deserialize<'static>,
89 S: Send + Sync,
90{
91 type Rejection = JsonDeserializerRejection;
92
93 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
94 if json_content_type(req.headers()) {
95 let bytes = Bytes::from_request(req, state).await?;
96 Ok(Self {
97 bytes,
98 _marker: PhantomData,
99 })
100 } else {
101 Err(MissingJsonContentType.into())
102 }
103 }
104}
105
106impl<'de, 'a: 'de, T> JsonDeserializer<T>
107where
108 T: Deserialize<'de>,
109{
110 pub fn deserialize(&'a self) -> Result<T, JsonDeserializerRejection> {
113 let deserializer = &mut serde_json::Deserializer::from_slice(&self.bytes);
114
115 let value = match serde_path_to_error::deserialize(deserializer) {
116 Ok(value) => value,
117 Err(err) => {
118 let rejection = match err.inner().classify() {
119 serde_json::error::Category::Data => JsonDataError::from_err(err).into(),
120 serde_json::error::Category::Syntax | serde_json::error::Category::Eof => {
121 JsonSyntaxError::from_err(err).into()
122 }
123 serde_json::error::Category::Io => {
124 if cfg!(debug_assertions) {
125 unreachable!()
128 } else {
129 JsonSyntaxError::from_err(err).into()
130 }
131 }
132 };
133 return Err(rejection);
134 }
135 };
136
137 Ok(value)
138 }
139}
140
define_rejection! {
    #[status = UNPROCESSABLE_ENTITY]
    #[body = "Failed to deserialize the JSON body into the target type"]
    #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
    /// Rejection type for [`JsonDeserializer`].
    ///
    /// Returned by [`JsonDeserializer::deserialize`] when the body is
    /// syntactically valid JSON that could not be deserialized into the target
    /// type. Responds with `422 Unprocessable Entity`.
    pub struct JsonDataError(Error);
}
151
define_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Failed to parse the request body as JSON"]
    #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
    /// Rejection type for [`JsonDeserializer`].
    ///
    /// Returned by [`JsonDeserializer::deserialize`] when the body is not
    /// syntactically valid JSON, including input that ends prematurely.
    /// Responds with `400 Bad Request`.
    pub struct JsonSyntaxError(Error);
}
161
define_rejection! {
    #[status = UNSUPPORTED_MEDIA_TYPE]
    #[body = "Expected request with `Content-Type: application/json`"]
    #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
    /// Rejection type for [`JsonDeserializer`].
    ///
    /// Returned during extraction when the request's `Content-Type` header is
    /// missing or is not a JSON media type. Responds with
    /// `415 Unsupported Media Type`.
    pub struct MissingJsonContentType;
}
170
composite_rejection! {
    #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
    /// Rejection used for [`JsonDeserializer`].
    ///
    /// Contains one variant for each way extracting or deserializing with
    /// [`JsonDeserializer`] can fail. `BytesRejection` is produced when the
    /// request body could not be buffered during extraction.
    pub enum JsonDeserializerRejection {
        JsonDataError,
        JsonSyntaxError,
        MissingJsonContentType,
        BytesRejection,
    }
}
184
185fn json_content_type(headers: &HeaderMap) -> bool {
186 let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
187 content_type
188 } else {
189 return false;
190 };
191
192 let content_type = if let Ok(content_type) = content_type.to_str() {
193 content_type
194 } else {
195 return false;
196 };
197
198 let mime = if let Ok(mime) = content_type.parse::<mime::Mime>() {
199 mime
200 } else {
201 return false;
202 };
203
204 let is_json_content_type = mime.type_() == "application"
205 && (mime.subtype() == "json" || mime.suffix().is_some_and(|name| name == "json"));
206
207 is_json_content_type
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{
        response::{IntoResponse, Response},
        routing::post,
        Router,
    };
    use http::StatusCode;
    use serde::Deserialize;
    use serde_json::{json, Value};
    use std::borrow::Cow;

    #[tokio::test]
    async fn deserialize_body() {
        #[derive(Debug, Deserialize)]
        struct Input<'a> {
            #[serde(borrow)]
            foo: Cow<'a, str>,
        }

        async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response {
            match deserializer.deserialize() {
                Ok(input) => {
                    // Strings without escapes are borrowed straight from the buffered body.
                    assert!(matches!(input.foo, Cow::Borrowed(_)));
                    input.foo.into_owned().into_response()
                }
                Err(e) => e.into_response(),
            }
        }

        let app = Router::new().route("/", post(handler));

        let client = TestClient::new(app);
        let res = client.post("/").json(&json!({ "foo": "bar" })).await;
        let body = res.text().await;

        assert_eq!(body, "bar");
    }

    #[tokio::test]
    async fn deserialize_body_escaped_to_cow() {
        #[derive(Debug, Deserialize)]
        struct Input<'a> {
            #[serde(borrow)]
            foo: Cow<'a, str>,
        }

        async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response {
            match deserializer.deserialize() {
                Ok(Input { foo }) => {
                    let Cow::Owned(foo) = foo else {
                        panic!("Deserializer is expected to fallback to Cow::Owned when encountering escaped characters")
                    };

                    foo.into_response()
                }
                Err(e) => e.into_response(),
            }
        }

        let app = Router::new().route("/", post(handler));

        let client = TestClient::new(app);

        // The embedded quotes are escaped in the JSON text, so they cannot be borrowed.
        let res = client.post("/").json(&json!({ "foo": "\"bar\"" })).await;

        let body = res.text().await;

        assert_eq!(body, r#""bar""#);
    }

    #[tokio::test]
    async fn deserialize_body_escaped_to_str() {
        #[derive(Debug, Deserialize)]
        struct Input<'a> {
            // No `#[serde(borrow)]` needed: serde always borrows `&str` fields.
            foo: &'a str,
        }

        async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response {
            match deserializer.deserialize() {
                Ok(Input { foo }) => foo.to_owned().into_response(),
                Err(e) => e.into_response(),
            }
        }

        let app = Router::new().route("/", post(handler));

        let client = TestClient::new(app);

        let res = client.post("/").json(&json!({ "foo": "good" })).await;
        let body = res.text().await;
        assert_eq!(body, "good");

        // A `&str` cannot hold an unescaped copy, so escaped input must be rejected.
        let res = client.post("/").json(&json!({ "foo": "\"bad\"" })).await;
        assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
        let body_text = res.text().await;
        assert_eq!(
            body_text,
            "Failed to deserialize the JSON body into the target type: foo: invalid type: string \"\\\"bad\\\"\", expected a borrowed string at line 1 column 16"
        );
    }

    #[tokio::test]
    async fn consume_body_to_json_requires_json_content_type() {
        #[derive(Debug, Deserialize)]
        struct Input<'a> {
            // Only used as a deserialization target; the field is never read.
            #[allow(dead_code)]
            foo: Cow<'a, str>,
        }

        async fn handler(_deserializer: JsonDeserializer<Input<'_>>) -> Response {
            panic!("This handler should not be called")
        }

        let app = Router::new().route("/", post(handler));

        let client = TestClient::new(app);
        let res = client.post("/").body(r#"{ "foo": "bar" }"#).await;

        let status = res.status();

        assert_eq!(status, StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn json_content_types() {
        async fn valid_json_content_type(content_type: &str) -> bool {
            println!("testing {content_type:?}");

            async fn handler(_deserializer: JsonDeserializer<Value>) -> Response {
                StatusCode::OK.into_response()
            }

            let app = Router::new().route("/", post(handler));

            let res = TestClient::new(app)
                .post("/")
                .header("content-type", content_type)
                .body("{}")
                .await;

            res.status() == StatusCode::OK
        }

        assert!(valid_json_content_type("application/json").await);
        assert!(valid_json_content_type("application/json; charset=utf-8").await);
        assert!(valid_json_content_type("application/json;charset=utf-8").await);
        assert!(valid_json_content_type("application/cloudevents+json").await);
        assert!(!valid_json_content_type("text/json").await);
    }

    #[tokio::test]
    async fn invalid_json_syntax() {
        async fn handler(deserializer: JsonDeserializer<Value>) -> Response {
            match deserializer.deserialize() {
                Ok(_) => panic!("Should have matched `Err`"),
                Err(e) => e.into_response(),
            }
        }

        let app = Router::new().route("/", post(handler));

        let client = TestClient::new(app);
        // Truncated input: serde_json classifies this as `Category::Eof`.
        let res = client
            .post("/")
            .body("{")
            .header("content-type", "application/json")
            .await;

        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn invalid_json_syntax_unexpected_token() {
        async fn handler(deserializer: JsonDeserializer<Value>) -> Response {
            match deserializer.deserialize() {
                Ok(_) => panic!("Should have matched `Err`"),
                Err(e) => e.into_response(),
            }
        }

        let app = Router::new().route("/", post(handler));

        let client = TestClient::new(app);
        // A missing value is classified as `Category::Syntax`, unlike the
        // truncated body in `invalid_json_syntax`.
        let res = client
            .post("/")
            .body("{\"a\": }")
            .header("content-type", "application/json")
            .await;

        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        let body_text = res.text().await;
        assert!(body_text.starts_with("Failed to parse the request body as JSON"));
    }

    // Fields below are only deserialization targets; they are never read.
    #[derive(Deserialize)]
    struct Foo {
        #[allow(dead_code)]
        a: i32,
        #[allow(dead_code)]
        b: Vec<Bar>,
    }

    #[derive(Deserialize)]
    struct Bar {
        #[allow(dead_code)]
        x: i32,
        #[allow(dead_code)]
        y: i32,
    }

    #[tokio::test]
    async fn invalid_json_data() {
        async fn handler(deserializer: JsonDeserializer<Foo>) -> Response {
            match deserializer.deserialize() {
                Ok(_) => panic!("Should have matched `Err`"),
                Err(e) => e.into_response(),
            }
        }

        let app = Router::new().route("/", post(handler));

        let client = TestClient::new(app);
        let res = client
            .post("/")
            .body("{\"a\": 1, \"b\": [{\"x\": 2}]}")
            .header("content-type", "application/json")
            .await;

        assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
        let body_text = res.text().await;
        assert_eq!(
            body_text,
            "Failed to deserialize the JSON body into the target type: b[0]: missing field `y` at line 1 column 23"
        );
    }
}