actix_multipart/form/text.rs

1//! Deserializes a field from plain text.
2
3use std::{str, sync::Arc};
4
5use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError};
6use derive_more::{Deref, DerefMut, Display, Error};
7use futures_core::future::LocalBoxFuture;
8use serde::de::DeserializeOwned;
9
10use super::FieldErrorHandler;
11use crate::{
12    form::{bytes::Bytes, FieldReader, Limits},
13    Field, MultipartError,
14};
15
16/// Deserialize from plain text.
17///
18/// Internally this uses [`serde_plain`] for deserialization, which supports primitive types
19/// including strings, numbers, and simple enums.
20#[derive(Debug, Deref, DerefMut)]
21pub struct Text<T: DeserializeOwned>(pub T);
22
23impl<T: DeserializeOwned> Text<T> {
24    /// Unwraps into inner value.
25    pub fn into_inner(self) -> T {
26        self.0
27    }
28}
29
30impl<'t, T> FieldReader<'t> for Text<T>
31where
32    T: DeserializeOwned + 'static,
33{
34    type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>;
35
36    fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future {
37        Box::pin(async move {
38            let config = TextConfig::from_req(req);
39
40            if config.validate_content_type {
41                let valid = if let Some(mime) = field.content_type() {
42                    mime.subtype() == mime::PLAIN || mime.suffix() == Some(mime::PLAIN)
43                } else {
44                    // https://datatracker.ietf.org/doc/html/rfc7578#section-4.4
45                    // content type defaults to text/plain, so None should be considered valid
46                    true
47                };
48
49                if !valid {
50                    return Err(MultipartError::Field {
51                        name: field.form_field_name,
52                        source: config.map_error(req, TextError::ContentType),
53                    });
54                }
55            }
56
57            let form_field_name = field.form_field_name.clone();
58
59            let bytes = Bytes::read_field(req, field, limits).await?;
60
61            let text = str::from_utf8(&bytes.data).map_err(|err| MultipartError::Field {
62                name: form_field_name.clone(),
63                source: config.map_error(req, TextError::Utf8Error(err)),
64            })?;
65
66            Ok(Text(serde_plain::from_str(text).map_err(|err| {
67                MultipartError::Field {
68                    name: form_field_name,
69                    source: config.map_error(req, TextError::Deserialize(err)),
70                }
71            })?))
72        })
73    }
74}
75
76#[derive(Debug, Display, Error)]
77#[non_exhaustive]
78pub enum TextError {
79    /// UTF-8 decoding error.
80    #[display(fmt = "UTF-8 decoding error: {}", _0)]
81    Utf8Error(str::Utf8Error),
82
83    /// Deserialize error.
84    #[display(fmt = "Plain text deserialize error: {}", _0)]
85    Deserialize(serde_plain::Error),
86
87    /// Content type error.
88    #[display(fmt = "Content type error")]
89    ContentType,
90}
91
92impl ResponseError for TextError {
93    fn status_code(&self) -> StatusCode {
94        StatusCode::BAD_REQUEST
95    }
96}
97
98/// Configuration for the [`Text`] field reader.
99#[derive(Clone)]
100pub struct TextConfig {
101    err_handler: FieldErrorHandler<TextError>,
102    validate_content_type: bool,
103}
104
105impl TextConfig {
106    /// Sets custom error handler.
107    pub fn error_handler<F>(mut self, f: F) -> Self
108    where
109        F: Fn(TextError, &HttpRequest) -> Error + Send + Sync + 'static,
110    {
111        self.err_handler = Some(Arc::new(f));
112        self
113    }
114
115    /// Extracts payload config from app data. Check both `T` and `Data<T>`, in that order, and fall
116    /// back to the default payload config.
117    fn from_req(req: &HttpRequest) -> &Self {
118        req.app_data::<Self>()
119            .or_else(|| req.app_data::<web::Data<Self>>().map(|d| d.as_ref()))
120            .unwrap_or(&DEFAULT_CONFIG)
121    }
122
123    fn map_error(&self, req: &HttpRequest, err: TextError) -> Error {
124        if let Some(ref err_handler) = self.err_handler {
125            (err_handler)(err, req)
126        } else {
127            err.into()
128        }
129    }
130
131    /// Sets whether or not the field must have a valid `Content-Type` header to be parsed.
132    ///
133    /// Note that an empty `Content-Type` is also accepted, as the multipart specification defines
134    /// `text/plain` as the default for text fields.
135    pub fn validate_content_type(mut self, validate_content_type: bool) -> Self {
136        self.validate_content_type = validate_content_type;
137        self
138    }
139}
140
141const DEFAULT_CONFIG: TextConfig = TextConfig {
142    err_handler: None,
143    validate_content_type: true,
144};
145
impl Default for TextConfig {
    /// Returns `DEFAULT_CONFIG`: content-type validation enabled and no custom error handler.
    fn default() -> Self {
        DEFAULT_CONFIG
    }
}
151
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use actix_multipart_rfc7578::client::multipart;
    use actix_web::{http::StatusCode, web, App, HttpResponse, Responder};

    use crate::form::{
        tests::send_form,
        text::{Text, TextConfig},
        MultipartForm,
    };

    #[derive(MultipartForm)]
    struct TextForm {
        number: Text<i32>,
    }

    /// Succeeds only when the `number` field deserialized to `1025`.
    async fn test_text_route(form: MultipartForm<TextForm>) -> impl Responder {
        assert_eq!(*form.number, 1025);
        HttpResponse::Ok().finish()
    }

    #[actix_rt::test]
    async fn test_content_type_validation() {
        let srv = actix_test::start(|| {
            App::new()
                .route("/", web::post().to(test_text_route))
                .app_data(TextConfig::default().validate_content_type(true))
        });

        // Deny because wrong content type
        let bytes = Cursor::new("1025");
        let mut form = multipart::Form::default();
        form.add_reader_file_with_mime("number", bytes, "", mime::APPLICATION_OCTET_STREAM);
        let response = send_form(&srv, form, "/").await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        // Allow because correct content type
        let bytes = Cursor::new("1025");
        let mut form = multipart::Form::default();
        form.add_reader_file_with_mime("number", bytes, "", mime::TEXT_PLAIN);
        let response = send_form(&srv, form, "/").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[actix_rt::test]
    async fn test_content_type_validation_disabled() {
        let srv = actix_test::start(|| {
            App::new()
                .route("/", web::post().to(test_text_route))
                .app_data(TextConfig::default().validate_content_type(false))
        });

        // Allow despite the non-text content type, because validation is turned off
        let bytes = Cursor::new("1025");
        let mut form = multipart::Form::default();
        form.add_reader_file_with_mime("number", bytes, "", mime::APPLICATION_OCTET_STREAM);
        let response = send_form(&srv, form, "/").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[actix_rt::test]
    async fn test_invalid_body() {
        // No `TextConfig` registered, so the default (validation on) applies.
        let srv = actix_test::start(|| App::new().route("/", web::post().to(test_text_route)));

        // Deny because the text is not a valid `i32`
        let bytes = Cursor::new("not a number");
        let mut form = multipart::Form::default();
        form.add_reader_file_with_mime("number", bytes, "", mime::TEXT_PLAIN);
        let response = send_form(&srv, form, "/").await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        // Deny because the body is not valid UTF-8
        let bytes = Cursor::new(&b"\xff\xfe"[..]);
        let mut form = multipart::Form::default();
        form.add_reader_file_with_mime("number", bytes, "", mime::TEXT_PLAIN);
        let response = send_form(&srv, form, "/").await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }
}