multipart_rs/
reader.rs

1use std::{
2    pin::Pin,
3    str,
4    task::{Context, Poll},
5};
6
7use bytes::{Buf, Bytes, BytesMut};
8use futures_core::{stream::BoxStream, Stream};
9use futures_util::StreamExt;
10use mediatype::{MediaType, ReadParams};
11
12use crate::{error::MultipartError, multipart_type::MultipartType};
13
/// Position of the reader's line-oriented state machine within the multipart body.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum InnerState {
    /// Input is exhausted or parsing has stopped (final boundary or error); nothing more is
    /// produced.
    Eof,

    /// Skipping the preamble until the first boundary line.
    FirstBoundary,

    /// Reading a part's body lines until the next boundary line.
    Boundary,

    /// Reading a part's header lines until the blank separator line.
    Headers,
}
28
29pub struct MultipartItem {
30    /// Headers
31    pub headers: Vec<(String, String)>,
32
33    /// Data
34    pub data: BytesMut,
35}
36
37impl MultipartItem {
38    pub fn get_mime_type(&self) -> Result<MediaType, MultipartError> {
39        let content_type = self
40            .headers
41            .iter()
42            .find(|(key, _)| key.to_lowercase() == "content-type");
43
44        if content_type.is_none() {
45            return Err(MultipartError::InvalidContentType);
46        }
47
48        let ct = MediaType::parse(content_type.unwrap().1.as_str())
49            .map_err(MultipartError::ContentTypeParsingError)?;
50
51        Ok(ct)
52    }
53
54    pub fn get_file_name(&self) -> Option<String> {
55        let content_disposition = self
56            .headers
57            .iter()
58            .find(|(key, _)| key.to_lowercase() == "content-disposition")?;
59
60        let cd = &content_disposition.1;
61        let parts: Vec<&str> = cd.split(";").collect();
62        let filename = parts
63            .iter()
64            .find(|p| p.trim().starts_with("filename="))
65            .map(|p| p.trim().split("=").collect::<Vec<&str>>()[1].to_string());
66
67        filename
68    }
69}
70
71pub struct MultipartReader<'a, E> {
72    pub boundary: String,
73    pub multipart_type: MultipartType,
74    /// Inner state
75    state: InnerState,
76    stream: BoxStream<'a, Result<Bytes, E>>,
77    buf: BytesMut,
78    pending_item: Option<MultipartItem>,
79}
80
81impl<'a, E> MultipartReader<'a, E> {
82    pub fn from_stream_with_boundary_and_type<S>(
83        stream: S,
84        boundary: &str,
85        multipart_type: MultipartType,
86    ) -> Result<MultipartReader<'a, E>, MultipartError>
87    where
88        S: Stream<Item = Result<Bytes, E>> + 'a + Send,
89    {
90        Ok(MultipartReader {
91            stream: stream.boxed(),
92            boundary: boundary.to_string(),
93            multipart_type,
94            state: InnerState::FirstBoundary,
95            pending_item: None,
96            buf: BytesMut::new(),
97        })
98    }
99
100    pub fn from_data_with_boundary_and_type(
101        data: &[u8],
102        boundary: &str,
103        multipart_type: MultipartType,
104    ) -> Result<MultipartReader<'a, E>, MultipartError>
105    where
106        E: std::error::Error + 'a + Send,
107    {
108        let stream = futures_util::stream::iter(vec![Ok(Bytes::copy_from_slice(data))]);
109        MultipartReader::from_stream_with_boundary_and_type(stream, boundary, multipart_type)
110    }
111
112    pub fn from_stream_with_headers<S>(
113        stream: S,
114        headers: &[(String, String)],
115    ) -> Result<MultipartReader<'a, E>, MultipartError>
116    where
117        S: Stream<Item = Result<Bytes, E>> + 'a + Send,
118        E: std::error::Error,
119    {
120        // Search for the content-type header
121        let content_type = headers
122            .iter()
123            .find(|(key, _)| key.to_lowercase() == "content-type");
124
125        if content_type.is_none() {
126            return Err(MultipartError::NoContentType);
127        }
128        let ct = MediaType::parse(content_type.unwrap().1.as_str())
129            .map_err(MultipartError::ContentTypeParsingError)?;
130
131        let boundary = ct
132            .get_param(mediatype::names::BOUNDARY)
133            .ok_or(MultipartError::InvalidBoundary)?;
134
135        if ct.ty != mediatype::names::MULTIPART {
136            return Err(MultipartError::InvalidContentType);
137        }
138
139        let multipart_type = ct
140            .subty
141            .as_str()
142            .parse::<MultipartType>()
143            .map_err(|_| MultipartError::InvalidMultipartType)?;
144
145        Ok(MultipartReader {
146            stream: stream.boxed(),
147            boundary: boundary.to_string(),
148            multipart_type,
149            state: InnerState::FirstBoundary,
150            pending_item: None,
151            buf: BytesMut::new(),
152        })
153    }
154
155    pub fn from_data_with_headers(
156        data: &[u8],
157        headers: &[(String, String)],
158    ) -> Result<MultipartReader<'a, E>, MultipartError>
159    where
160        E: std::error::Error + 'a + Send,
161    {
162        let stream = futures_util::stream::iter(vec![Ok(Bytes::copy_from_slice(data))]);
163        MultipartReader::from_stream_with_headers(stream, headers)
164    }
165
166    fn is_final_boundary(&self, data: &[u8]) -> bool {
167        let boundary = format!("--{}--", self.boundary);
168        data.starts_with(boundary.as_bytes())
169    }
170
171    // TODO: make this RFC compliant
172    fn is_boundary(&self, data: &[u8]) -> bool {
173        let boundary = format!("--{}", self.boundary);
174        data.starts_with(boundary.as_bytes())
175    }
176}
177
178impl<'a, E> Stream for MultipartReader<'a, E> {
179    type Item = Result<MultipartItem, MultipartError>;
180
181    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
182        let this = self.get_mut();
183        let finder = memchr::memmem::Finder::new("\r\n");
184
185        loop {
186            while let Some(idx) = finder.find(&this.buf) {
187                match this.state {
188                    InnerState::FirstBoundary => {
189                        // Check if the last line was a boundary
190                        if this.is_boundary(&this.buf[..idx]) {
191                            this.state = InnerState::Headers;
192                        };
193                    }
194                    InnerState::Boundary => {
195                        // Check if the last line was a boundary
196                        if this.is_boundary(&this.buf[..idx]) {
197                            let final_boundary = this.is_final_boundary(&this.buf[..idx]);
198
199                            // If we have a pending item, return it
200                            if let Some(mut item) = this.pending_item.take() {
201                                // Remove last 2 bytes from the data (which were a newline sequence)
202                                item.data.truncate(item.data.len() - 2);
203                                // Skip to the next line
204                                this.buf.advance(2 + idx);
205                                if final_boundary {
206                                    this.state = InnerState::Eof;
207                                } else {
208                                    this.state = InnerState::Headers;
209                                }
210                                return std::task::Poll::Ready(Some(Ok(item)));
211                            }
212
213                            this.state = InnerState::Headers;
214                            this.pending_item = Some(MultipartItem {
215                                headers: vec![],
216                                data: BytesMut::new(),
217                            });
218                        };
219
220                        // Add the data to the pending item
221                        this.pending_item
222                            .as_mut()
223                            .unwrap()
224                            .data
225                            .extend(&this.buf[..idx + 2])
226                    }
227                    InnerState::Headers => {
228                        // Check if we have a pending item or we should create one
229                        if this.pending_item.is_none() {
230                            this.pending_item = Some(MultipartItem {
231                                headers: vec![],
232                                data: BytesMut::new(),
233                            });
234                        }
235
236                        // Read the header line and split it into key and value
237                        let header = match str::from_utf8(&this.buf[..idx]) {
238                            Ok(h) => h,
239                            Err(_) => {
240                                this.state = InnerState::Eof;
241                                return std::task::Poll::Ready(Some(Err(
242                                    MultipartError::InvalidItemHeader,
243                                )));
244                            }
245                        };
246
247                        // This is no header anymore, we are at the end of the headers
248                        if header.trim().is_empty() {
249                            this.buf.advance(2 + idx);
250                            this.state = InnerState::Boundary;
251                            continue;
252                        }
253
254                        let header_parts: Vec<&str> = header.split(": ").collect();
255                        if header_parts.len() != 2 {
256                            this.state = InnerState::Eof;
257                            return std::task::Poll::Ready(Some(Err(
258                                MultipartError::InvalidItemHeader,
259                            )));
260                        }
261
262                        // Add header entry to the pending item
263                        this.pending_item
264                            .as_mut()
265                            .unwrap()
266                            .headers
267                            .push((header_parts[0].to_string(), header_parts[1].to_string()));
268                    }
269                    InnerState::Eof => {
270                        return std::task::Poll::Ready(None);
271                    }
272                }
273
274                // Skip to the next line
275                this.buf.advance(2 + idx);
276            }
277
278            // Read more data from the stream
279            match Pin::new(&mut this.stream).poll_next(cx) {
280                Poll::Ready(Some(Ok(data))) => {
281                    this.buf.extend_from_slice(&data);
282                }
283                Poll::Ready(None) => {
284                    this.state = InnerState::Eof;
285                    return std::task::Poll::Ready(None);
286                }
287                Poll::Ready(Some(Err(_e))) => {
288                    this.state = InnerState::Eof;
289                    return std::task::Poll::Ready(Some(Err(MultipartError::PollingDataFailed)));
290                }
291                Poll::Pending => {
292                    return std::task::Poll::Pending;
293                }
294            };
295        }
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    const BOUNDARY: &str = "974767299852498929531610575";

    // Three-part body shared by the tests. Lines must end with CRLF
    const BODY: &[u8] = b"--974767299852498929531610575\r
Content-Disposition: form-data; name=\"text\"\r
\r
text default\r
--974767299852498929531610575\r
Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\"\r
Content-Type: text/plain\r
\r
Content of a.txt.\r
\r\n--974767299852498929531610575\r
Content-Disposition: form-data; name=\"file2\"; filename=\"a.html\"\r
Content-Type: text/html\r
\r
<!DOCTYPE html><title>Content of a.html.</title>\r
\r
--974767299852498929531610575--\r\n";

    /// Polls every item out of `reader`, panicking on the first error.
    async fn collect_items(mut reader: MultipartReader<'_, std::io::Error>) -> Vec<MultipartItem> {
        let mut items = vec![];
        while let Some(next) = reader.next().await {
            match next {
                Ok(item) => items.push(item),
                Err(e) => panic!("Error: {:?}", e),
            }
        }
        items
    }

    /// Asserts that `items` are exactly the three parts encoded in `BODY`.
    fn assert_body_items(items: &[MultipartItem]) {
        assert_eq!(items.len(), 3);

        assert_eq!(
            items[0].headers,
            vec![(
                "Content-Disposition".to_string(),
                "form-data; name=\"text\"".to_string()
            )]
        );
        assert_eq!(&items[0].data[..], b"text default");

        assert_eq!(items[1].headers.len(), 2);
        // The empty line belongs to the body; only the CRLF right before the boundary
        // line is part of the delimiter and stripped.
        assert_eq!(&items[1].data[..], b"Content of a.txt.\r\n");

        assert_eq!(items[2].headers.len(), 2);
        assert_eq!(
            &items[2].data[..],
            b"<!DOCTYPE html><title>Content of a.html.</title>\r\n"
        );
    }

    #[futures_test::test]
    async fn valid_request() {
        let headermap = vec![(
            "Content-Type".to_string(),
            "multipart/form-data; boundary=974767299852498929531610575".to_string(),
        )];

        assert!(
            MultipartReader::<std::io::Error>::from_data_with_headers(BODY, &headermap).is_ok()
        );
        assert!(
            MultipartReader::<std::io::Error>::from_data_with_boundary_and_type(
                BODY,
                BOUNDARY,
                MultipartType::FormData
            )
            .is_ok()
        );

        // Poll all the items from the header-configured reader
        let reader =
            MultipartReader::<std::io::Error>::from_data_with_headers(BODY, &headermap).unwrap();
        assert_eq!(reader.multipart_type, MultipartType::FormData);
        assert_body_items(&collect_items(reader).await);

        // The explicitly configured reader must decode the same parts
        let reader = MultipartReader::<std::io::Error>::from_data_with_boundary_and_type(
            BODY,
            BOUNDARY,
            MultipartType::FormData,
        )
        .unwrap();
        assert_body_items(&collect_items(reader).await);
    }

    #[futures_test::test]
    async fn valid_request_extra_type() {
        let headermap = vec![(
            "Content-Type".to_string(),
            "multipart/related; type=\"application/dicom\"; boundary=974767299852498929531610575"
                .to_string(),
        )];

        MultipartReader::<std::io::Error>::from_data_with_headers(BODY, &headermap).unwrap();
        MultipartReader::<std::io::Error>::from_data_with_boundary_and_type(
            BODY,
            BOUNDARY,
            MultipartType::FormData,
        )
        .unwrap();

        // Poll all the items from the reader
        let reader =
            MultipartReader::<std::io::Error>::from_data_with_headers(BODY, &headermap).unwrap();
        assert_eq!(reader.multipart_type, MultipartType::Related);
        assert_body_items(&collect_items(reader).await);
    }
}