axum_extra/extract/
json_deserializer.rs

1use axum::extract::{FromRequest, Request};
2use axum_core::__composite_rejection as composite_rejection;
3use axum_core::__define_rejection as define_rejection;
4use axum_core::extract::rejection::BytesRejection;
5use bytes::Bytes;
6use http::{header, HeaderMap};
7use serde::Deserialize;
8use std::marker::PhantomData;
9
10/// JSON Extractor for zero-copy deserialization.
11///
12/// Deserialize request bodies into some type that implements [`serde::Deserialize<'de>`][serde::Deserialize].
13/// Parsing JSON is delayed until [`deserialize`](JsonDeserializer::deserialize) is called.
14/// If the type implements [`serde::de::DeserializeOwned`], the [`Json`](axum::Json) extractor should
15/// be preferred.
16///
17/// The request will be rejected (and a [`JsonDeserializerRejection`] will be returned) if:
18///
19/// - The request doesn't have a `Content-Type: application/json` (or similar) header.
20/// - Buffering the request body fails.
21///
22/// Additionally, a `JsonRejection` error will be returned, when calling `deserialize` if:
23///
24/// - The body doesn't contain syntactically valid JSON.
25/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target type.
26/// - Attempting to deserialize escaped JSON into a type that must be borrowed (e.g. `&'a str`).
27///
28/// ⚠️ `serde` will implicitly try to borrow for `&str` and `&[u8]` types, but will error if the
29/// input contains escaped characters. Use `Cow<'a, str>` or `Cow<'a, [u8]>`, with the
30/// `#[serde(borrow)]` attribute, to allow serde to fall back to an owned type when encountering
31/// escaped characters.
32///
33/// ⚠️ Since parsing JSON requires consuming the request body, the `Json` extractor must be
34/// *last* if there are multiple extractors in a handler.
35/// See ["the order of extractors"][order-of-extractors]
36///
37/// [order-of-extractors]: axum::extract#the-order-of-extractors
38///
39/// See [`JsonDeserializerRejection`] for more details.
40///
41/// # Example
42///
43/// ```rust,no_run
44/// use axum::{
45///     routing::post,
46///     Router,
47///     response::{IntoResponse, Response}
48/// };
49/// use axum_extra::extract::JsonDeserializer;
50/// use serde::Deserialize;
51/// use std::borrow::Cow;
52/// use http::StatusCode;
53///
54/// #[derive(Deserialize)]
55/// struct Data<'a> {
56///     #[serde(borrow)]
57///     borrow_text: Cow<'a, str>,
58///     #[serde(borrow)]
59///     borrow_bytes: Cow<'a, [u8]>,
60///     borrow_dangerous: &'a str,
61///     not_borrowed: String,
62/// }
63///
64/// async fn upload(deserializer: JsonDeserializer<Data<'_>>) -> Response {
65///     let data = match deserializer.deserialize() {
66///         Ok(data) => data,
67///         Err(e) => return e.into_response(),
68///     };
69///
70///     // payload is a `Data` with borrowed data from `deserializer`,
71///     // which owns the request body (`Bytes`).
72///
73///     StatusCode::OK.into_response()
74/// }
75///
76/// let app = Router::new().route("/upload", post(upload));
77/// # let _: Router = app;
78/// ```
79#[derive(Debug, Clone, Default)]
80#[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
81pub struct JsonDeserializer<T> {
82    bytes: Bytes,
83    _marker: PhantomData<T>,
84}
85
86impl<T, S> FromRequest<S> for JsonDeserializer<T>
87where
88    T: Deserialize<'static>,
89    S: Send + Sync,
90{
91    type Rejection = JsonDeserializerRejection;
92
93    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
94        if json_content_type(req.headers()) {
95            let bytes = Bytes::from_request(req, state).await?;
96            Ok(Self {
97                bytes,
98                _marker: PhantomData,
99            })
100        } else {
101            Err(MissingJsonContentType.into())
102        }
103    }
104}
105
106impl<'de, 'a: 'de, T> JsonDeserializer<T>
107where
108    T: Deserialize<'de>,
109{
110    /// Deserialize the request body into the target type.
111    /// See [`JsonDeserializer`] for more details.
112    pub fn deserialize(&'a self) -> Result<T, JsonDeserializerRejection> {
113        let deserializer = &mut serde_json::Deserializer::from_slice(&self.bytes);
114
115        let value = match serde_path_to_error::deserialize(deserializer) {
116            Ok(value) => value,
117            Err(err) => {
118                let rejection = match err.inner().classify() {
119                    serde_json::error::Category::Data => JsonDataError::from_err(err).into(),
120                    serde_json::error::Category::Syntax | serde_json::error::Category::Eof => {
121                        JsonSyntaxError::from_err(err).into()
122                    }
123                    serde_json::error::Category::Io => {
124                        if cfg!(debug_assertions) {
125                            // we don't use `serde_json::from_reader` and instead always buffer
126                            // bodies first, so we shouldn't encounter any IO errors
127                            unreachable!()
128                        } else {
129                            JsonSyntaxError::from_err(err).into()
130                        }
131                    }
132                };
133                return Err(rejection);
134            }
135        };
136
137        Ok(value)
138    }
139}
140
define_rejection! {
    #[status = UNPROCESSABLE_ENTITY]
    #[body = "Failed to deserialize the JSON body into the target type"]
    #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
    /// Rejection type for [`JsonDeserializer`].
    ///
    /// This rejection is used if the request body is syntactically valid JSON but couldn't be
    /// deserialized into the target type.
    ///
    /// This includes a JSON string containing escape sequences that is deserialized into a type
    /// that must borrow from the body, such as `&str`.
    ///
    /// Responds with `422 Unprocessable Entity`.
    pub struct JsonDataError(Error);
}
151
define_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Failed to parse the request body as JSON"]
    #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
    /// Rejection type for [`JsonDeserializer`].
    ///
    /// This rejection is used if the request body didn't contain syntactically valid JSON,
    /// including a body that ends before the JSON value is complete.
    ///
    /// Responds with `400 Bad Request`.
    pub struct JsonSyntaxError(Error);
}
161
162define_rejection! {
163    #[status = UNSUPPORTED_MEDIA_TYPE]
164    #[body = "Expected request with `Content-Type: application/json`"]
165    #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
166    /// Rejection type for [`JsonDeserializer`] used if the `Content-Type`
167    /// header is missing.
168    pub struct MissingJsonContentType;
169}
170
composite_rejection! {
    /// Rejection used for [`JsonDeserializer`].
    ///
    /// Contains one variant for each way the [`JsonDeserializer`] extractor
    /// can fail.
    ///
    /// `MissingJsonContentType` and `BytesRejection` are produced while extracting the request.
    /// `JsonDataError` and `JsonSyntaxError` are produced later, by
    /// [`JsonDeserializer::deserialize`].
    #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))]
    pub enum JsonDeserializerRejection {
        // Variants must stay bare identifiers: the macro does not accept attributes
        // (including `///` doc comments) on them.
        JsonDataError,
        JsonSyntaxError,
        MissingJsonContentType,
        BytesRejection,
    }
}
184
185fn json_content_type(headers: &HeaderMap) -> bool {
186    let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
187        content_type
188    } else {
189        return false;
190    };
191
192    let content_type = if let Ok(content_type) = content_type.to_str() {
193        content_type
194    } else {
195        return false;
196    };
197
198    let mime = if let Ok(mime) = content_type.parse::<mime::Mime>() {
199        mime
200    } else {
201        return false;
202    };
203
204    let is_json_content_type = mime.type_() == "application"
205        && (mime.subtype() == "json" || mime.suffix().is_some_and(|name| name == "json"));
206
207    is_json_content_type
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{
        response::{IntoResponse, Response},
        routing::post,
        Router,
    };
    use http::StatusCode;
    use serde::Deserialize;
    use serde_json::{json, Value};
    use std::borrow::Cow;

    // Unescaped strings are borrowed directly from the buffered body.
    #[tokio::test]
    async fn deserialize_body() {
        #[derive(Debug, Deserialize)]
        struct Input<'a> {
            #[serde(borrow)]
            foo: Cow<'a, str>,
        }

        async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response {
            match deserializer.deserialize() {
                Ok(input) => {
                    assert!(matches!(input.foo, Cow::Borrowed(_)));
                    input.foo.into_owned().into_response()
                }
                Err(e) => e.into_response(),
            }
        }

        let app = Router::new().route("/", post(handler));

        let client = TestClient::new(app);
        let res = client.post("/").json(&json!({ "foo": "bar" })).await;
        let body = res.text().await;

        assert_eq!(body, "bar");
    }

    // Escaped strings can't be borrowed, so `Cow` must fall back to an owned value.
    #[tokio::test]
    async fn deserialize_body_escaped_to_cow() {
        #[derive(Debug, Deserialize)]
        struct Input<'a> {
            #[serde(borrow)]
            foo: Cow<'a, str>,
        }

        async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response {
            match deserializer.deserialize() {
                Ok(Input { foo }) => {
                    let Cow::Owned(foo) = foo else {
                        panic!("Deserializer is expected to fallback to Cow::Owned when encountering escaped characters")
                    };

                    foo.into_response()
                }
                Err(e) => e.into_response(),
            }
        }

        let app = Router::new().route("/", post(handler));

        let client = TestClient::new(app);

        // The escaped characters prevent serde_json from borrowing.
        let res = client.post("/").json(&json!({ "foo": "\"bar\"" })).await;

        let body = res.text().await;

        assert_eq!(body, r#""bar""#);
    }

    // `&str` borrows unescaped input but rejects escaped input with a data error (422).
    #[tokio::test]
    async fn deserialize_body_escaped_to_str() {
        #[derive(Debug, Deserialize)]
        struct Input<'a> {
            // Explicit `#[serde(borrow)]` attribute is not required for `&str` or `&[u8]`.
            // See: https://serde.rs/lifetimes.html#borrowing-data-in-a-derived-impl
            foo: &'a str,
        }

        async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response {
            match deserializer.deserialize() {
                Ok(Input { foo }) => foo.to_owned().into_response(),
                Err(e) => e.into_response(),
            }
        }

        let app = Router::new().route("/", post(handler));

        let client = TestClient::new(app);

        let res = client.post("/").json(&json!({ "foo": "good" })).await;
        let body = res.text().await;
        assert_eq!(body, "good");

        let res = client.post("/").json(&json!({ "foo": "\"bad\"" })).await;
        assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
        let body_text = res.text().await;
        assert_eq!(
            body_text,
            "Failed to deserialize the JSON body into the target type: foo: invalid type: string \"\\\"bad\\\"\", expected a borrowed string at line 1 column 16"
        );
    }

    // Without a `Content-Type` header the extractor rejects with 415 before the handler runs.
    #[tokio::test]
    async fn consume_body_to_json_requires_json_content_type() {
        #[derive(Debug, Deserialize)]
        struct Input<'a> {
            #[allow(dead_code)]
            foo: Cow<'a, str>,
        }

        async fn handler(_deserializer: JsonDeserializer<Input<'_>>) -> Response {
            panic!("This handler should not be called")
        }

        let app = Router::new().route("/", post(handler));

        let client = TestClient::new(app);
        let res = client.post("/").body(r#"{ "foo": "bar" }"#).await;

        let status = res.status();

        assert_eq!(status, StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    // Exercises the content-type predicate through a full request round trip.
    #[tokio::test]
    async fn json_content_types() {
        async fn valid_json_content_type(content_type: &str) -> bool {
            println!("testing {content_type:?}");

            async fn handler(_deserializer: JsonDeserializer<Value>) -> Response {
                StatusCode::OK.into_response()
            }

            let app = Router::new().route("/", post(handler));

            let res = TestClient::new(app)
                .post("/")
                .header("content-type", content_type)
                .body("{}")
                .await;

            res.status() == StatusCode::OK
        }

        assert!(valid_json_content_type("application/json").await);
        assert!(valid_json_content_type("application/json; charset=utf-8").await);
        assert!(valid_json_content_type("application/json;charset=utf-8").await);
        assert!(valid_json_content_type("application/cloudevents+json").await);
        assert!(!valid_json_content_type("text/json").await);
        // Other `application/*` types, look-alike subtypes and values that don't parse as a
        // MIME type at all must be rejected too.
        assert!(!valid_json_content_type("application/xml").await);
        assert!(!valid_json_content_type("application/json-seq").await);
        assert!(!valid_json_content_type("json").await);
    }

    // Malformed JSON maps to `JsonSyntaxError` (400).
    #[tokio::test]
    async fn invalid_json_syntax() {
        async fn handler(deserializer: JsonDeserializer<Value>) -> Response {
            match deserializer.deserialize() {
                Ok(_) => panic!("Should have matched `Err`"),
                Err(e) => e.into_response(),
            }
        }

        let app = Router::new().route("/", post(handler));

        let client = TestClient::new(app);
        let res = client
            .post("/")
            .body("{")
            .header("content-type", "application/json")
            .await;

        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }

    // Target types for `invalid_json_data`; `Bar` is nested so the error path is non-trivial.
    #[derive(Deserialize)]
    struct Foo {
        #[allow(dead_code)]
        a: i32,
        #[allow(dead_code)]
        b: Vec<Bar>,
    }

    #[derive(Deserialize)]
    struct Bar {
        #[allow(dead_code)]
        x: i32,
        #[allow(dead_code)]
        y: i32,
    }

    // Valid JSON missing a nested field maps to `JsonDataError` (422), and the message names
    // the path of the offending field.
    #[tokio::test]
    async fn invalid_json_data() {
        async fn handler(deserializer: JsonDeserializer<Foo>) -> Response {
            match deserializer.deserialize() {
                Ok(_) => panic!("Should have matched `Err`"),
                Err(e) => e.into_response(),
            }
        }

        let app = Router::new().route("/", post(handler));

        let client = TestClient::new(app);
        let res = client
            .post("/")
            .body("{\"a\": 1, \"b\": [{\"x\": 2}]}")
            .header("content-type", "application/json")
            .await;

        assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
        let body_text = res.text().await;
        assert_eq!(
            body_text,
            "Failed to deserialize the JSON body into the target type: b[0]: missing field `y` at line 1 column 23"
        );
    }
}