1use std::{
2 pin::Pin,
3 str,
4 task::{Context, Poll},
5};
6
7use bytes::{Buf, Bytes, BytesMut};
8use futures_core::{stream::BoxStream, Stream};
9use futures_util::StreamExt;
10use mediatype::{MediaType, ReadParams};
11
12use crate::{error::MultipartError, multipart_type::MultipartType};
13
/// Position of the multipart parser within the body it is reading.
#[derive(Debug, PartialEq)]
enum InnerState {
    /// Skipping preamble lines until the first delimiter line is found.
    FirstBoundary,

    /// Reading the current part's header lines, up to the blank line that ends them.
    Headers,

    /// Reading the current part's body lines until the next delimiter line.
    Boundary,

    /// Finished: no further parts will be produced.
    Eof,
}
28
29pub struct MultipartItem {
30 pub headers: Vec<(String, String)>,
32
33 pub data: BytesMut,
35}
36
37impl MultipartItem {
38 pub fn get_mime_type(&self) -> Result<MediaType, MultipartError> {
39 let content_type = self
40 .headers
41 .iter()
42 .find(|(key, _)| key.to_lowercase() == "content-type");
43
44 if content_type.is_none() {
45 return Err(MultipartError::InvalidContentType);
46 }
47
48 let ct = MediaType::parse(content_type.unwrap().1.as_str())
49 .map_err(MultipartError::ContentTypeParsingError)?;
50
51 Ok(ct)
52 }
53
54 pub fn get_file_name(&self) -> Option<String> {
55 let content_disposition = self
56 .headers
57 .iter()
58 .find(|(key, _)| key.to_lowercase() == "content-disposition")?;
59
60 let cd = &content_disposition.1;
61 let parts: Vec<&str> = cd.split(";").collect();
62 let filename = parts
63 .iter()
64 .find(|p| p.trim().starts_with("filename="))
65 .map(|p| p.trim().split("=").collect::<Vec<&str>>()[1].to_string());
66
67 filename
68 }
69}
70
71pub struct MultipartReader<'a, E> {
72 pub boundary: String,
73 pub multipart_type: MultipartType,
74 state: InnerState,
76 stream: BoxStream<'a, Result<Bytes, E>>,
77 buf: BytesMut,
78 pending_item: Option<MultipartItem>,
79}
80
81impl<'a, E> MultipartReader<'a, E> {
82 pub fn from_stream_with_boundary_and_type<S>(
83 stream: S,
84 boundary: &str,
85 multipart_type: MultipartType,
86 ) -> Result<MultipartReader<'a, E>, MultipartError>
87 where
88 S: Stream<Item = Result<Bytes, E>> + 'a + Send,
89 {
90 Ok(MultipartReader {
91 stream: stream.boxed(),
92 boundary: boundary.to_string(),
93 multipart_type,
94 state: InnerState::FirstBoundary,
95 pending_item: None,
96 buf: BytesMut::new(),
97 })
98 }
99
100 pub fn from_data_with_boundary_and_type(
101 data: &[u8],
102 boundary: &str,
103 multipart_type: MultipartType,
104 ) -> Result<MultipartReader<'a, E>, MultipartError>
105 where
106 E: std::error::Error + 'a + Send,
107 {
108 let stream = futures_util::stream::iter(vec![Ok(Bytes::copy_from_slice(data))]);
109 MultipartReader::from_stream_with_boundary_and_type(stream, boundary, multipart_type)
110 }
111
112 pub fn from_stream_with_headers<S>(
113 stream: S,
114 headers: &[(String, String)],
115 ) -> Result<MultipartReader<'a, E>, MultipartError>
116 where
117 S: Stream<Item = Result<Bytes, E>> + 'a + Send,
118 E: std::error::Error,
119 {
120 let content_type = headers
122 .iter()
123 .find(|(key, _)| key.to_lowercase() == "content-type");
124
125 if content_type.is_none() {
126 return Err(MultipartError::NoContentType);
127 }
128 let ct = MediaType::parse(content_type.unwrap().1.as_str())
129 .map_err(MultipartError::ContentTypeParsingError)?;
130
131 let boundary = ct
132 .get_param(mediatype::names::BOUNDARY)
133 .ok_or(MultipartError::InvalidBoundary)?;
134
135 if ct.ty != mediatype::names::MULTIPART {
136 return Err(MultipartError::InvalidContentType);
137 }
138
139 let multipart_type = ct
140 .subty
141 .as_str()
142 .parse::<MultipartType>()
143 .map_err(|_| MultipartError::InvalidMultipartType)?;
144
145 Ok(MultipartReader {
146 stream: stream.boxed(),
147 boundary: boundary.to_string(),
148 multipart_type,
149 state: InnerState::FirstBoundary,
150 pending_item: None,
151 buf: BytesMut::new(),
152 })
153 }
154
155 pub fn from_data_with_headers(
156 data: &[u8],
157 headers: &[(String, String)],
158 ) -> Result<MultipartReader<'a, E>, MultipartError>
159 where
160 E: std::error::Error + 'a + Send,
161 {
162 let stream = futures_util::stream::iter(vec![Ok(Bytes::copy_from_slice(data))]);
163 MultipartReader::from_stream_with_headers(stream, headers)
164 }
165
166 fn is_final_boundary(&self, data: &[u8]) -> bool {
167 let boundary = format!("--{}--", self.boundary);
168 data.starts_with(boundary.as_bytes())
169 }
170
171 fn is_boundary(&self, data: &[u8]) -> bool {
173 let boundary = format!("--{}", self.boundary);
174 data.starts_with(boundary.as_bytes())
175 }
176}
177
178impl<'a, E> Stream for MultipartReader<'a, E> {
179 type Item = Result<MultipartItem, MultipartError>;
180
181 fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
182 let this = self.get_mut();
183 let finder = memchr::memmem::Finder::new("\r\n");
184
185 loop {
186 while let Some(idx) = finder.find(&this.buf) {
187 match this.state {
188 InnerState::FirstBoundary => {
189 if this.is_boundary(&this.buf[..idx]) {
191 this.state = InnerState::Headers;
192 };
193 }
194 InnerState::Boundary => {
195 if this.is_boundary(&this.buf[..idx]) {
197 let final_boundary = this.is_final_boundary(&this.buf[..idx]);
198
199 if let Some(mut item) = this.pending_item.take() {
201 item.data.truncate(item.data.len() - 2);
203 this.buf.advance(2 + idx);
205 if final_boundary {
206 this.state = InnerState::Eof;
207 } else {
208 this.state = InnerState::Headers;
209 }
210 return std::task::Poll::Ready(Some(Ok(item)));
211 }
212
213 this.state = InnerState::Headers;
214 this.pending_item = Some(MultipartItem {
215 headers: vec![],
216 data: BytesMut::new(),
217 });
218 };
219
220 this.pending_item
222 .as_mut()
223 .unwrap()
224 .data
225 .extend(&this.buf[..idx + 2])
226 }
227 InnerState::Headers => {
228 if this.pending_item.is_none() {
230 this.pending_item = Some(MultipartItem {
231 headers: vec![],
232 data: BytesMut::new(),
233 });
234 }
235
236 let header = match str::from_utf8(&this.buf[..idx]) {
238 Ok(h) => h,
239 Err(_) => {
240 this.state = InnerState::Eof;
241 return std::task::Poll::Ready(Some(Err(
242 MultipartError::InvalidItemHeader,
243 )));
244 }
245 };
246
247 if header.trim().is_empty() {
249 this.buf.advance(2 + idx);
250 this.state = InnerState::Boundary;
251 continue;
252 }
253
254 let header_parts: Vec<&str> = header.split(": ").collect();
255 if header_parts.len() != 2 {
256 this.state = InnerState::Eof;
257 return std::task::Poll::Ready(Some(Err(
258 MultipartError::InvalidItemHeader,
259 )));
260 }
261
262 this.pending_item
264 .as_mut()
265 .unwrap()
266 .headers
267 .push((header_parts[0].to_string(), header_parts[1].to_string()));
268 }
269 InnerState::Eof => {
270 return std::task::Poll::Ready(None);
271 }
272 }
273
274 this.buf.advance(2 + idx);
276 }
277
278 match Pin::new(&mut this.stream).poll_next(cx) {
280 Poll::Ready(Some(Ok(data))) => {
281 this.buf.extend_from_slice(&data);
282 }
283 Poll::Ready(None) => {
284 this.state = InnerState::Eof;
285 return std::task::Poll::Ready(None);
286 }
287 Poll::Ready(Some(Err(_e))) => {
288 this.state = InnerState::Eof;
289 return std::task::Poll::Ready(Some(Err(MultipartError::PollingDataFailed)));
290 }
291 Poll::Pending => {
292 return std::task::Poll::Pending;
293 }
294 };
295 }
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Boundary used by every fixture in this module.
    const BOUNDARY: &str = "974767299852498929531610575";

    /// Three-part body: a text field followed by two file uploads.
    ///
    /// Both file parts end with an extra blank line, so their contents keep one trailing CRLF.
    const BODY: &[u8] = b"--974767299852498929531610575\r
Content-Disposition: form-data; name=\"text\"\r
\r
text default\r
--974767299852498929531610575\r
Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\"\r
Content-Type: text/plain\r
\r
Content of a.txt.\r
\r\n--974767299852498929531610575\r
Content-Disposition: form-data; name=\"file2\"; filename=\"a.html\"\r
Content-Type: text/html\r
\r
<!DOCTYPE html><title>Content of a.html.</title>\r
\r
--974767299852498929531610575--\r\n";

    /// Drains `reader`, panicking on the first error it yields.
    async fn collect_items(mut reader: MultipartReader<'_, std::io::Error>) -> Vec<MultipartItem> {
        let mut items = vec![];
        while let Some(next) = reader.next().await {
            match next {
                Ok(item) => items.push(item),
                Err(e) => panic!("Error: {:?}", e),
            }
        }
        items
    }

    /// Asserts that `items` are exactly the three parts encoded in `BODY`.
    fn assert_body_items(items: &[MultipartItem]) {
        assert_eq!(items.len(), 3);

        assert_eq!(&items[0].data[..], b"text default");
        assert_eq!(&items[1].data[..], b"Content of a.txt.\r\n");
        assert_eq!(
            &items[2].data[..],
            b"<!DOCTYPE html><title>Content of a.html.</title>\r\n"
        );

        assert_eq!(
            items[0].headers,
            vec![(
                "Content-Disposition".to_string(),
                "form-data; name=\"text\"".to_string()
            )]
        );
        assert_eq!(
            items[1].headers,
            vec![
                (
                    "Content-Disposition".to_string(),
                    "form-data; name=\"file1\"; filename=\"a.txt\"".to_string()
                ),
                ("Content-Type".to_string(), "text/plain".to_string()),
            ]
        );

        assert!(items[0].get_mime_type().is_err());
        assert_eq!(items[1].get_mime_type().unwrap().to_string(), "text/plain");
        assert_eq!(items[2].get_mime_type().unwrap().to_string(), "text/html");
    }

    #[futures_test::test]
    async fn valid_request() {
        let headermap = vec![(
            "Content-Type".to_string(),
            "multipart/form-data; boundary=974767299852498929531610575".to_string(),
        )];

        assert!(
            MultipartReader::<std::io::Error>::from_data_with_headers(BODY, &headermap).is_ok()
        );
        assert!(
            MultipartReader::<std::io::Error>::from_data_with_boundary_and_type(
                BODY,
                BOUNDARY,
                MultipartType::FormData
            )
            .is_ok()
        );

        let reader =
            MultipartReader::<std::io::Error>::from_data_with_headers(BODY, &headermap).unwrap();
        assert_eq!(reader.multipart_type, MultipartType::FormData);

        assert_body_items(&collect_items(reader).await);
    }

    #[futures_test::test]
    async fn valid_request_extra_type() {
        let headermap = vec![(
            "Content-Type".to_string(),
            "multipart/related; type=\"application/dicom\"; boundary=974767299852498929531610575"
                .to_string(),
        )];

        MultipartReader::<std::io::Error>::from_data_with_headers(BODY, &headermap).unwrap();
        MultipartReader::<std::io::Error>::from_data_with_boundary_and_type(
            BODY,
            BOUNDARY,
            MultipartType::FormData,
        )
        .unwrap();

        let reader =
            MultipartReader::<std::io::Error>::from_data_with_headers(BODY, &headermap).unwrap();
        assert_eq!(reader.multipart_type, MultipartType::Related);

        assert_body_items(&collect_items(reader).await);
    }

    #[futures_test::test]
    async fn chunked_stream() {
        // Small chunks make CRLFs and delimiter lines straddle chunk boundaries.
        let chunks: Vec<Result<Bytes, std::io::Error>> = BODY
            .chunks(7)
            .map(|chunk| Ok(Bytes::copy_from_slice(chunk)))
            .collect();

        let reader = MultipartReader::from_stream_with_boundary_and_type(
            futures_util::stream::iter(chunks),
            BOUNDARY,
            MultipartType::FormData,
        )
        .unwrap();

        assert_body_items(&collect_items(reader).await);
    }
}