1use std::{
4 any::Any,
5 collections::HashMap,
6 future::{ready, Future},
7 sync::Arc,
8};
9
10use actix_web::{dev, error::PayloadError, web, Error, FromRequest, HttpRequest};
11use derive_more::{Deref, DerefMut};
12use futures_core::future::LocalBoxFuture;
13use futures_util::{TryFutureExt as _, TryStreamExt as _};
14
15use crate::{Field, Multipart, MultipartError};
16
17pub mod bytes;
18pub mod json;
19#[cfg(feature = "tempfile")]
20pub mod tempfile;
21pub mod text;
22
23#[cfg(feature = "derive")]
24pub use actix_multipart_derive::MultipartForm;
25
26type FieldErrorHandler<T> = Option<Arc<dyn Fn(T, &HttpRequest) -> Error + Send + Sync>>;
27
28pub trait FieldReader<'t>: Sized + Any {
32 type Future: Future<Output = Result<Self, MultipartError>>;
34
35 fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future;
45}
46
47#[doc(hidden)]
49#[derive(Default, Deref, DerefMut)]
50pub struct State(pub HashMap<String, Box<dyn Any>>);
51
52#[doc(hidden)]
54pub trait FieldGroupReader<'t>: Sized + Any {
55 type Future: Future<Output = Result<(), MultipartError>>;
56
57 fn handle_field(
59 req: &'t HttpRequest,
60 field: Field,
61 limits: &'t mut Limits,
62 state: &'t mut State,
63 duplicate_field: DuplicateField,
64 ) -> Self::Future;
65
66 fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError>;
68}
69
70impl<'t, T> FieldGroupReader<'t> for Option<T>
71where
72 T: FieldReader<'t>,
73{
74 type Future = LocalBoxFuture<'t, Result<(), MultipartError>>;
75
76 fn handle_field(
77 req: &'t HttpRequest,
78 field: Field,
79 limits: &'t mut Limits,
80 state: &'t mut State,
81 duplicate_field: DuplicateField,
82 ) -> Self::Future {
83 if state.contains_key(&field.form_field_name) {
84 match duplicate_field {
85 DuplicateField::Ignore => return Box::pin(ready(Ok(()))),
86
87 DuplicateField::Deny => {
88 return Box::pin(ready(Err(MultipartError::DuplicateField(
89 field.form_field_name,
90 ))))
91 }
92
93 DuplicateField::Replace => {}
94 }
95 }
96
97 Box::pin(async move {
98 let field_name = field.form_field_name.clone();
99 let t = T::read_field(req, field, limits).await?;
100 state.insert(field_name, Box::new(t));
101 Ok(())
102 })
103 }
104
105 fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> {
106 Ok(state.remove(name).map(|m| *m.downcast::<T>().unwrap()))
107 }
108}
109
110impl<'t, T> FieldGroupReader<'t> for Vec<T>
111where
112 T: FieldReader<'t>,
113{
114 type Future = LocalBoxFuture<'t, Result<(), MultipartError>>;
115
116 fn handle_field(
117 req: &'t HttpRequest,
118 field: Field,
119 limits: &'t mut Limits,
120 state: &'t mut State,
121 _duplicate_field: DuplicateField,
122 ) -> Self::Future {
123 Box::pin(async move {
124 let vec = state
127 .entry(field.form_field_name.clone())
128 .or_insert_with(|| Box::<Vec<T>>::default())
129 .downcast_mut::<Vec<T>>()
130 .unwrap();
131
132 let item = T::read_field(req, field, limits).await?;
133 vec.push(item);
134
135 Ok(())
136 })
137 }
138
139 fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> {
140 Ok(state
141 .remove(name)
142 .map(|m| *m.downcast::<Vec<T>>().unwrap())
143 .unwrap_or_default())
144 }
145}
146
147impl<'t, T> FieldGroupReader<'t> for T
148where
149 T: FieldReader<'t>,
150{
151 type Future = LocalBoxFuture<'t, Result<(), MultipartError>>;
152
153 fn handle_field(
154 req: &'t HttpRequest,
155 field: Field,
156 limits: &'t mut Limits,
157 state: &'t mut State,
158 duplicate_field: DuplicateField,
159 ) -> Self::Future {
160 if state.contains_key(&field.form_field_name) {
161 match duplicate_field {
162 DuplicateField::Ignore => return Box::pin(ready(Ok(()))),
163
164 DuplicateField::Deny => {
165 return Box::pin(ready(Err(MultipartError::DuplicateField(
166 field.form_field_name,
167 ))))
168 }
169
170 DuplicateField::Replace => {}
171 }
172 }
173
174 Box::pin(async move {
175 let field_name = field.form_field_name.clone();
176 let t = T::read_field(req, field, limits).await?;
177 state.insert(field_name, Box::new(t));
178 Ok(())
179 })
180 }
181
182 fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> {
183 state
184 .remove(name)
185 .map(|m| *m.downcast::<T>().unwrap())
186 .ok_or_else(|| MultipartError::MissingField(name.to_owned()))
187 }
188}
189
190pub trait MultipartCollect: Sized {
194 fn limit(field_name: &str) -> Option<usize>;
197
198 fn handle_field<'t>(
201 req: &'t HttpRequest,
202 field: Field,
203 limits: &'t mut Limits,
204 state: &'t mut State,
205 ) -> LocalBoxFuture<'t, Result<(), MultipartError>>;
206
207 fn from_state(state: State) -> Result<Self, MultipartError>;
210}
211
212#[doc(hidden)]
213pub enum DuplicateField {
214 Ignore,
216
217 Deny,
219
220 Replace,
222}
223
224pub struct Limits {
226 pub total_limit_remaining: usize,
227 pub memory_limit_remaining: usize,
228 pub field_limit_remaining: Option<usize>,
229}
230
231impl Limits {
232 pub fn new(total_limit: usize, memory_limit: usize) -> Self {
233 Self {
234 total_limit_remaining: total_limit,
235 memory_limit_remaining: memory_limit,
236 field_limit_remaining: None,
237 }
238 }
239
240 pub fn try_consume_limits(
248 &mut self,
249 bytes: usize,
250 in_memory: bool,
251 ) -> Result<(), MultipartError> {
252 self.total_limit_remaining = self
253 .total_limit_remaining
254 .checked_sub(bytes)
255 .ok_or(MultipartError::Payload(PayloadError::Overflow))?;
256
257 if in_memory {
258 self.memory_limit_remaining = self
259 .memory_limit_remaining
260 .checked_sub(bytes)
261 .ok_or(MultipartError::Payload(PayloadError::Overflow))?;
262 }
263
264 if let Some(field_limit) = self.field_limit_remaining {
265 self.field_limit_remaining = Some(
266 field_limit
267 .checked_sub(bytes)
268 .ok_or(MultipartError::Payload(PayloadError::Overflow))?,
269 );
270 }
271
272 Ok(())
273 }
274}
275
276#[derive(Deref, DerefMut)]
287pub struct MultipartForm<T: MultipartCollect>(pub T);
288
289impl<T: MultipartCollect> MultipartForm<T> {
290 pub fn into_inner(self) -> T {
292 self.0
293 }
294}
295
296impl<T> FromRequest for MultipartForm<T>
297where
298 T: MultipartCollect + 'static,
299{
300 type Error = Error;
301 type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
302
303 #[inline]
304 fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
305 let mut multipart = Multipart::from_req(req, payload);
306
307 let content_type = match multipart.content_type_or_bail() {
308 Ok(content_type) => content_type,
309 Err(err) => return Box::pin(ready(Err(err.into()))),
310 };
311
312 if content_type.subtype() != mime::FORM_DATA {
313 return Box::pin(ready(Err(MultipartError::ContentTypeIncompatible.into())));
315 };
316
317 let config = MultipartFormConfig::from_req(req);
318 let mut limits = Limits::new(config.total_limit, config.memory_limit);
319
320 let req = req.clone();
321 let req2 = req.clone();
322 let err_handler = config.err_handler.clone();
323
324 Box::pin(
325 async move {
326 let mut state = State::default();
327
328 let mut field_limits = HashMap::<String, Option<usize>>::new();
330
331 while let Some(field) = multipart.try_next().await? {
332 debug_assert!(
333 !field.form_field_name.is_empty(),
334 "multipart form fields should have names",
335 );
336
337 let entry = field_limits
339 .entry(field.form_field_name.clone())
340 .or_insert_with(|| T::limit(&field.form_field_name));
341
342 limits.field_limit_remaining.clone_from(entry);
343
344 T::handle_field(&req, field, &mut limits, &mut state).await?;
345
346 *entry = limits.field_limit_remaining;
348 }
349
350 let inner = T::from_state(state)?;
351 Ok(MultipartForm(inner))
352 }
353 .map_err(move |err| {
354 if let Some(handler) = err_handler {
355 (*handler)(err, &req2)
356 } else {
357 err.into()
358 }
359 }),
360 )
361 }
362}
363
364type MultipartFormErrorHandler =
365 Option<Arc<dyn Fn(MultipartError, &HttpRequest) -> Error + Send + Sync>>;
366
367#[derive(Clone)]
371pub struct MultipartFormConfig {
372 total_limit: usize,
373 memory_limit: usize,
374 err_handler: MultipartFormErrorHandler,
375}
376
377impl MultipartFormConfig {
378 pub fn total_limit(mut self, total_limit: usize) -> Self {
380 self.total_limit = total_limit;
381 self
382 }
383
384 pub fn memory_limit(mut self, memory_limit: usize) -> Self {
386 self.memory_limit = memory_limit;
387 self
388 }
389
390 pub fn error_handler<F>(mut self, f: F) -> Self
392 where
393 F: Fn(MultipartError, &HttpRequest) -> Error + Send + Sync + 'static,
394 {
395 self.err_handler = Some(Arc::new(f));
396 self
397 }
398
399 fn from_req(req: &HttpRequest) -> &Self {
402 req.app_data::<Self>()
403 .or_else(|| req.app_data::<web::Data<Self>>().map(|d| d.as_ref()))
404 .unwrap_or(&DEFAULT_CONFIG)
405 }
406}
407
408const DEFAULT_CONFIG: MultipartFormConfig = MultipartFormConfig {
409 total_limit: 52_428_800, memory_limit: 2_097_152, err_handler: None,
412};
413
414impl Default for MultipartFormConfig {
415 fn default() -> Self {
416 DEFAULT_CONFIG
417 }
418}
419
// End-to-end tests: each test starts an `actix_test` server, posts a multipart form to it and
// checks the response status (the route handlers assert on the extracted values).
#[cfg(test)]
mod tests {
    use actix_http::encoding::Decoder;
    use actix_multipart_rfc7578::client::multipart;
    use actix_test::TestServer;
    use actix_web::{
        dev::Payload, http::StatusCode, web, App, HttpRequest, HttpResponse, Resource, Responder,
    };
    use awc::{Client, ClientResponse};
    use futures_core::future::LocalBoxFuture;
    use futures_util::TryStreamExt as _;

    use super::MultipartForm;
    use crate::{
        form::{
            bytes::Bytes, tempfile::TempFile, text::Text, FieldReader, Limits, MultipartFormConfig,
        },
        Field, MultipartError,
    };

    // POSTs `form` to `uri` on `srv`, using the form's own multipart content type, and returns
    // the response. Panics if the request cannot be completed.
    pub async fn send_form(
        srv: &TestServer,
        form: multipart::Form<'static>,
        uri: &'static str,
    ) -> ClientResponse<Decoder<Payload>> {
        Client::default()
            .post(srv.url(uri))
            .content_type(form.content_type())
            .send_body(multipart::Body::from(form))
            .await
            .unwrap()
    }

    // `Option` members: a field that is sent becomes `Some`, a field that is not becomes `None`.
    #[derive(MultipartForm)]
    struct TestOptions {
        field1: Option<Text<String>>,
        field2: Option<Text<String>>,
    }

    async fn test_options_route(form: MultipartForm<TestOptions>) -> impl Responder {
        assert!(form.field1.is_some());
        assert!(form.field2.is_none());
        HttpResponse::Ok().finish()
    }

    #[actix_rt::test]
    async fn test_options() {
        let srv = actix_test::start(|| App::new().route("/", web::post().to(test_options_route)));

        let mut form = multipart::Form::default();
        form.add_text("field1", "value");

        let response = send_form(&srv, form, "/").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // `Vec` members: repeated fields are collected in the order sent; a field that is never
    // sent yields an empty `Vec`.
    #[derive(MultipartForm)]
    struct TestVec {
        list1: Vec<Text<String>>,
        list2: Vec<Text<String>>,
    }

    async fn test_vec_route(form: MultipartForm<TestVec>) -> impl Responder {
        let form = form.into_inner();
        let strings = form
            .list1
            .into_iter()
            .map(|s| s.into_inner())
            .collect::<Vec<_>>();
        assert_eq!(strings, vec!["value1", "value2", "value3"]);
        assert_eq!(form.list2.len(), 0);
        HttpResponse::Ok().finish()
    }

    #[actix_rt::test]
    async fn test_vec() {
        let srv = actix_test::start(|| App::new().route("/", web::post().to(test_vec_route)));

        let mut form = multipart::Form::default();
        form.add_text("list1", "value1");
        form.add_text("list1", "value2");
        form.add_text("list1", "value3");

        let response = send_form(&srv, form, "/").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // `#[multipart(rename = "...")]` binds a member to a differently named form field. Here
    // `field2` is bound to the form name "field1", which must not be confused with member
    // `field1`.
    #[derive(MultipartForm)]
    struct TestFieldRenaming {
        #[multipart(rename = "renamed")]
        field1: Text<String>,
        #[multipart(rename = "field1")]
        field2: Text<String>,
        field3: Text<String>,
    }

    async fn test_field_renaming_route(form: MultipartForm<TestFieldRenaming>) -> impl Responder {
        assert_eq!(&*form.field1, "renamed");
        assert_eq!(&*form.field2, "field1");
        assert_eq!(&*form.field3, "field3");
        HttpResponse::Ok().finish()
    }

    #[actix_rt::test]
    async fn test_field_renaming() {
        let srv =
            actix_test::start(|| App::new().route("/", web::post().to(test_field_renaming_route)));

        let mut form = multipart::Form::default();
        form.add_text("renamed", "renamed");
        form.add_text("field1", "field1");
        form.add_text("field3", "field3");

        let response = send_form(&srv, form, "/").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // `deny_unknown_fields` rejects a field with no matching member (400); without it the
    // unknown field is accepted.
    #[derive(MultipartForm)]
    #[multipart(deny_unknown_fields)]
    struct TestDenyUnknown {}

    #[derive(MultipartForm)]
    struct TestAllowUnknown {}

    async fn test_deny_unknown_route(_: MultipartForm<TestDenyUnknown>) -> impl Responder {
        HttpResponse::Ok().finish()
    }

    async fn test_allow_unknown_route(_: MultipartForm<TestAllowUnknown>) -> impl Responder {
        HttpResponse::Ok().finish()
    }

    #[actix_rt::test]
    async fn test_deny_unknown() {
        let srv = actix_test::start(|| {
            App::new()
                .route("/deny", web::post().to(test_deny_unknown_route))
                .route("/allow", web::post().to(test_allow_unknown_route))
        });

        let mut form = multipart::Form::default();
        form.add_text("unknown", "value");
        let response = send_form(&srv, form, "/deny").await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        let mut form = multipart::Form::default();
        form.add_text("unknown", "value");
        let response = send_form(&srv, form, "/allow").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // The three `duplicate_field` policies: "deny" fails with 400, "replace" keeps the last
    // value, "ignore" keeps the first.
    #[derive(MultipartForm)]
    #[multipart(duplicate_field = "deny")]
    struct TestDuplicateDeny {
        _field: Text<String>,
    }

    #[derive(MultipartForm)]
    #[multipart(duplicate_field = "replace")]
    struct TestDuplicateReplace {
        field: Text<String>,
    }

    #[derive(MultipartForm)]
    #[multipart(duplicate_field = "ignore")]
    struct TestDuplicateIgnore {
        field: Text<String>,
    }

    async fn test_duplicate_deny_route(_: MultipartForm<TestDuplicateDeny>) -> impl Responder {
        HttpResponse::Ok().finish()
    }

    async fn test_duplicate_replace_route(
        form: MultipartForm<TestDuplicateReplace>,
    ) -> impl Responder {
        assert_eq!(&*form.field, "second_value");
        HttpResponse::Ok().finish()
    }

    async fn test_duplicate_ignore_route(
        form: MultipartForm<TestDuplicateIgnore>,
    ) -> impl Responder {
        assert_eq!(&*form.field, "first_value");
        HttpResponse::Ok().finish()
    }

    #[actix_rt::test]
    async fn test_duplicate_field() {
        let srv = actix_test::start(|| {
            App::new()
                .route("/deny", web::post().to(test_duplicate_deny_route))
                .route("/replace", web::post().to(test_duplicate_replace_route))
                .route("/ignore", web::post().to(test_duplicate_ignore_route))
        });

        let mut form = multipart::Form::default();
        form.add_text("_field", "first_value");
        form.add_text("_field", "second_value");
        let response = send_form(&srv, form, "/deny").await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        let mut form = multipart::Form::default();
        form.add_text("field", "first_value");
        form.add_text("field", "second_value");
        let response = send_form(&srv, form, "/replace").await;
        assert_eq!(response.status(), StatusCode::OK);

        let mut form = multipart::Form::default();
        form.add_text("field", "first_value");
        form.add_text("field", "second_value");
        let response = send_form(&srv, form, "/ignore").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // Forms used by the limit tests: one with a `Bytes` member (buffered in memory), one with
    // a `TempFile` member.
    #[derive(MultipartForm)]
    struct TestMemoryUploadLimits {
        field: Bytes,
    }

    #[derive(MultipartForm)]
    struct TestFileUploadLimits {
        field: TempFile,
    }

    async fn test_upload_limits_memory(
        form: MultipartForm<TestMemoryUploadLimits>,
    ) -> impl Responder {
        assert!(!form.field.data.is_empty());
        HttpResponse::Ok().finish()
    }

    async fn test_upload_limits_file(form: MultipartForm<TestFileUploadLimits>) -> impl Responder {
        assert!(form.field.size > 0);
        HttpResponse::Ok().finish()
    }

    #[actix_rt::test]
    async fn test_memory_limits() {
        // 20-byte memory limit, effectively unlimited total.
        let srv = actix_test::start(|| {
            App::new()
                .route("/text", web::post().to(test_upload_limits_memory))
                .route("/file", web::post().to(test_upload_limits_file))
                .app_data(
                    MultipartFormConfig::default()
                        .memory_limit(20)
                        .total_limit(usize::MAX),
                )
        });

        // 28 bytes cannot be buffered in memory under a 20-byte memory limit.
        let mut form = multipart::Form::default();
        form.add_text("field", "this string is 28 bytes long");
        let response = send_form(&srv, form, "/text").await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        // The same 28 bytes are accepted by the `TempFile` member, which the memory limit
        // does not restrict.
        let mut form = multipart::Form::default();
        form.add_text("field", "this string is 28 bytes long");
        let response = send_form(&srv, form, "/file").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[actix_rt::test]
    async fn test_total_limit() {
        // 20-byte total limit, effectively unlimited memory.
        let srv = actix_test::start(|| {
            App::new()
                .route("/text", web::post().to(test_upload_limits_memory))
                .route("/file", web::post().to(test_upload_limits_file))
                .app_data(
                    MultipartFormConfig::default()
                        .memory_limit(usize::MAX)
                        .total_limit(20),
                )
        });

        // 7 bytes fit within the 20-byte total limit.
        let mut form = multipart::Form::default();
        form.add_text("field", "7 bytes");
        let response = send_form(&srv, form, "/text").await;
        assert_eq!(response.status(), StatusCode::OK);

        // 28 bytes exceed the total limit for an in-memory member...
        let mut form = multipart::Form::default();
        form.add_text("field", "this string is 28 bytes long");
        let response = send_form(&srv, form, "/text").await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        // ...and also for a temp-file member: the total limit applies to both.
        let mut form = multipart::Form::default();
        form.add_text("field", "this string is 28 bytes long");
        let response = send_form(&srv, form, "/file").await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    // A 30-byte field-level limit on a `Vec` member; all occurrences of the name share it.
    #[derive(MultipartForm)]
    struct TestFieldLevelLimits {
        #[multipart(limit = "30B")]
        field: Vec<Bytes>,
    }

    async fn test_field_level_limits_route(
        form: MultipartForm<TestFieldLevelLimits>,
    ) -> impl Responder {
        assert!(!form.field.is_empty());
        HttpResponse::Ok().finish()
    }

    #[actix_rt::test]
    async fn test_field_level_limits() {
        let srv = actix_test::start(|| {
            App::new()
                .route("/", web::post().to(test_field_level_limits_route))
                .app_data(
                    MultipartFormConfig::default()
                        .memory_limit(usize::MAX)
                        .total_limit(usize::MAX),
                )
        });

        // A single 28-byte field fits within 30 bytes.
        let mut form = multipart::Form::default();
        form.add_text("field", "this string is 28 bytes long");
        let response = send_form(&srv, form, "/").await;
        assert_eq!(response.status(), StatusCode::OK);

        // A single field over 30 bytes is rejected.
        let mut form = multipart::Form::default();
        form.add_text("field", "this string is more than 30 bytes long");
        let response = send_form(&srv, form, "/").await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        // Two 7-byte fields (14 bytes combined) stay within the shared limit.
        let mut form = multipart::Form::default();
        form.add_text("field", "7 bytes");
        form.add_text("field", "7 bytes");
        let response = send_form(&srv, form, "/").await;
        assert_eq!(response.status(), StatusCode::OK);

        // Two 28-byte fields (56 bytes combined) exceed it, even though each would fit alone.
        let mut form = multipart::Form::default();
        form.add_text("field", "this string is 28 bytes long");
        form.add_text("field", "this string is 28 bytes long");
        let response = send_form(&srv, form, "/").await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    // A multipart request whose subtype is not `form-data` is rejected with 415 before the
    // handler runs.
    #[actix_rt::test]
    async fn non_multipart_form_data() {
        #[derive(MultipartForm)]
        struct TestNonMultipartFormData {
            // Never read: the handler is unreachable.
            #[allow(unused)]
            #[multipart(limit = "30B")]
            foo: Text<String>,
        }

        async fn non_multipart_form_data_route(
            _form: MultipartForm<TestNonMultipartFormData>,
        ) -> String {
            unreachable!("request is sent with multipart/mixed");
        }

        let srv = actix_test::start(|| {
            App::new().route("/", web::post().to(non_multipart_form_data_route))
        });

        let mut form = multipart::Form::default();
        form.add_text("foo", "foo");

        // Keep the generated boundary but switch the subtype to `multipart/mixed`.
        let ct = form.content_type().replacen("/form-data", "/mixed", 1);

        let res = Client::default()
            .post(srv.url("/"))
            .content_type(ct)
            .send_body(multipart::Body::from(form))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    // Polls a `Field` again after its stream has been exhausted. The test expects the
    // client-side `unwrap` in `send_form` to panic with `Connect(Disconnected)`. That is, the
    // server-side handling of the extra poll makes the connection drop rather than produce a
    // response.
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: Connect(Disconnected)")]
    #[actix_web::test]
    async fn field_try_next_panic() {
        // Reader that discards the field's data.
        #[derive(Debug)]
        struct NullSink;

        impl<'t> FieldReader<'t> for NullSink {
            type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>;

            fn read_field(
                _: &'t HttpRequest,
                mut field: Field,
                _limits: &'t mut Limits,
            ) -> Self::Future {
                Box::pin(async move {
                    // Drain the field until its stream reports the end.
                    while let Some(_chunk) = field.try_next().await? {}

                    // Poll once more after the end; this is the behavior under test.
                    let _post = field.try_next().await;

                    Ok(Self)
                })
            }
        }

        // `foo` is never read; the form exists only to drive `NullSink`.
        #[allow(dead_code)]
        #[derive(MultipartForm)]
        struct NullSinkForm {
            foo: NullSink,
        }

        async fn null_sink(_form: MultipartForm<NullSinkForm>) -> impl Responder {
            "unreachable"
        }

        let srv = actix_test::start(|| App::new().service(Resource::new("/").post(null_sink)));

        let mut form = multipart::Form::default();
        form.add_text("foo", "data is not important to this test");

        let _res = send_form(&srv, form, "/").await;
    }
}