actix_multipart/form/json.rs

1//! Deserializes a field as JSON.
2
3use std::sync::Arc;
4
5use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError};
6use derive_more::{Deref, DerefMut, Display, Error};
7use futures_core::future::LocalBoxFuture;
8use serde::de::DeserializeOwned;
9
10use super::FieldErrorHandler;
11use crate::{
12    form::{bytes::Bytes, FieldReader, Limits},
13    Field, MultipartError,
14};
15
16/// Deserialize from JSON.
17#[derive(Debug, Deref, DerefMut)]
18pub struct Json<T: DeserializeOwned>(pub T);
19
20impl<T: DeserializeOwned> Json<T> {
21    pub fn into_inner(self) -> T {
22        self.0
23    }
24}
25
26impl<'t, T> FieldReader<'t> for Json<T>
27where
28    T: DeserializeOwned + 'static,
29{
30    type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>;
31
32    fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future {
33        Box::pin(async move {
34            let config = JsonConfig::from_req(req);
35
36            if config.validate_content_type {
37                let valid = if let Some(mime) = field.content_type() {
38                    mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
39                } else {
40                    false
41                };
42
43                if !valid {
44                    return Err(MultipartError::Field {
45                        name: field.form_field_name,
46                        source: config.map_error(req, JsonFieldError::ContentType),
47                    });
48                }
49            }
50
51            let form_field_name = field.form_field_name.clone();
52
53            let bytes = Bytes::read_field(req, field, limits).await?;
54
55            Ok(Json(serde_json::from_slice(bytes.data.as_ref()).map_err(
56                |err| MultipartError::Field {
57                    name: form_field_name,
58                    source: config.map_error(req, JsonFieldError::Deserialize(err)),
59                },
60            )?))
61        })
62    }
63}
64
65#[derive(Debug, Display, Error)]
66#[non_exhaustive]
67pub enum JsonFieldError {
68    /// Deserialize error.
69    #[display(fmt = "Json deserialize error: {}", _0)]
70    Deserialize(serde_json::Error),
71
72    /// Content type error.
73    #[display(fmt = "Content type error")]
74    ContentType,
75}
76
77impl ResponseError for JsonFieldError {
78    fn status_code(&self) -> StatusCode {
79        StatusCode::BAD_REQUEST
80    }
81}
82
83/// Configuration for the [`Json`] field reader.
84#[derive(Clone)]
85pub struct JsonConfig {
86    err_handler: FieldErrorHandler<JsonFieldError>,
87    validate_content_type: bool,
88}
89
90const DEFAULT_CONFIG: JsonConfig = JsonConfig {
91    err_handler: None,
92    validate_content_type: true,
93};
94
95impl JsonConfig {
96    pub fn error_handler<F>(mut self, f: F) -> Self
97    where
98        F: Fn(JsonFieldError, &HttpRequest) -> Error + Send + Sync + 'static,
99    {
100        self.err_handler = Some(Arc::new(f));
101        self
102    }
103
104    /// Extract payload config from app data. Check both `T` and `Data<T>`, in that order, and fall
105    /// back to the default payload config.
106    fn from_req(req: &HttpRequest) -> &Self {
107        req.app_data::<Self>()
108            .or_else(|| req.app_data::<web::Data<Self>>().map(|d| d.as_ref()))
109            .unwrap_or(&DEFAULT_CONFIG)
110    }
111
112    fn map_error(&self, req: &HttpRequest, err: JsonFieldError) -> Error {
113        if let Some(err_handler) = self.err_handler.as_ref() {
114            (*err_handler)(err, req)
115        } else {
116            err.into()
117        }
118    }
119
120    /// Sets whether or not the field must have a valid `Content-Type` header to be parsed.
121    pub fn validate_content_type(mut self, validate_content_type: bool) -> Self {
122        self.validate_content_type = validate_content_type;
123        self
124    }
125}
126
127impl Default for JsonConfig {
128    fn default() -> Self {
129        DEFAULT_CONFIG
130    }
131}
132
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use actix_web::{http::StatusCode, web, web::Bytes, App, HttpResponse, Responder};

    use crate::form::{
        json::{Json, JsonConfig},
        MultipartForm,
    };

    #[derive(MultipartForm)]
    struct JsonForm {
        json: Json<HashMap<String, String>>,
    }

    async fn test_json_route(form: MultipartForm<JsonForm>) -> impl Responder {
        let mut expected = HashMap::new();
        expected.insert("key1".to_owned(), "value1".to_owned());
        expected.insert("key2".to_owned(), "value2".to_owned());
        assert_eq!(&*form.json, &expected);
        HttpResponse::Ok().finish()
    }

    const TEST_JSON: &str = r#"{"key1": "value1", "key2": "value2"}"#;

    /// Starts a test server whose `JsonConfig` has content type validation set to `validate`.
    fn start_server(validate: bool) -> actix_test::TestServer {
        actix_test::start(move || {
            App::new()
                .route("/", web::post().to(test_json_route))
                .app_data(JsonConfig::default().validate_content_type(validate))
        })
    }

    /// Posts `TEST_JSON` as the `json` form field with the given part content type.
    /// Returns the response status.
    async fn post_json_field(
        srv: &actix_test::TestServer,
        content_type: Option<mime::Mime>,
    ) -> StatusCode {
        let (body, headers) = crate::test::create_form_data_payload_and_headers(
            "json",
            None,
            content_type,
            Bytes::from_static(TEST_JSON.as_bytes()),
        );
        let mut req = srv.post("/");
        *req.headers_mut() = headers;
        req.send_body(body).await.unwrap().status()
    }

    #[actix_rt::test]
    async fn test_json_without_content_type() {
        let srv = start_server(false);
        assert_eq!(post_json_field(&srv, None).await, StatusCode::OK);
    }

    #[actix_rt::test]
    async fn test_content_type_validation() {
        let srv = start_server(true);

        // Deny because wrong content type
        assert_eq!(
            post_json_field(&srv, Some(mime::APPLICATION_OCTET_STREAM)).await,
            StatusCode::BAD_REQUEST
        );

        // Deny because the part declares no content type at all
        assert_eq!(post_json_field(&srv, None).await, StatusCode::BAD_REQUEST);

        // Allow because correct content type
        assert_eq!(
            post_json_field(&srv, Some(mime::APPLICATION_JSON)).await,
            StatusCode::OK
        );

        // Allow because of the `+json` structured syntax suffix
        let vendor_json: mime::Mime = "application/vnd.api+json".parse().unwrap();
        assert_eq!(post_json_field(&srv, Some(vendor_json)).await, StatusCode::OK);
    }
}