async_graphql/http/
multipart.rs

1use std::{
2    collections::HashMap,
3    io,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use futures_util::{io::AsyncRead, stream::Stream};
9use multer::{Constraints, Multipart, SizeLimit};
10use pin_project_lite::pin_project;
11
12use crate::{BatchRequest, ParseRequestError, UploadValue};
13
/// Options for `receive_multipart`.
///
/// Both limits default to `None`, meaning no limit is applied. Start from
/// `MultipartOptions::default()` and chain the builder methods, e.g.
/// `MultipartOptions::default().max_file_size(1 << 20).max_num_files(4)`.
#[derive(Default, Debug, Clone, Copy)]
#[non_exhaustive]
pub struct MultipartOptions {
    /// The maximum file size, in bytes.
    ///
    /// It is enforced per multipart field, so it also applies to the `operations`
    /// and `map` fields.
    pub max_file_size: Option<usize>,
    /// The maximum number of files.
    ///
    /// NOTE(review): within this module this value is only multiplied by
    /// `max_file_size` to bound the total request size. The number of uploaded
    /// files is not itself counted. It has no effect unless `max_file_size` is
    /// also set.
    pub max_num_files: Option<usize>,
}

impl MultipartOptions {
    /// Returns these options with the maximum file size set to `size` bytes.
    #[must_use]
    pub fn max_file_size(self, size: usize) -> Self {
        Self {
            max_file_size: Some(size),
            ..self
        }
    }

    /// Returns these options with the maximum number of files set to `n`.
    #[must_use]
    pub fn max_num_files(self, n: usize) -> Self {
        Self {
            max_num_files: Some(n),
            ..self
        }
    }
}
43
44pub(super) async fn receive_batch_multipart(
45    body: impl AsyncRead + Send,
46    boundary: impl Into<String>,
47    opts: MultipartOptions,
48) -> Result<BatchRequest, ParseRequestError> {
49    let mut multipart = Multipart::with_constraints(
50        ReaderStream::new(body),
51        boundary,
52        Constraints::new().size_limit({
53            let mut limit = SizeLimit::new();
54            if let (Some(max_file_size), Some(max_num_files)) =
55                (opts.max_file_size, opts.max_num_files)
56            {
57                limit = limit.whole_stream((max_file_size * max_num_files) as u64);
58            }
59            if let Some(max_file_size) = opts.max_file_size {
60                limit = limit.per_field(max_file_size as u64);
61            }
62            limit
63        }),
64    );
65
66    let mut request = None;
67    let mut map = None;
68    let mut files = Vec::new();
69
70    while let Some(field) = multipart.next_field().await? {
71        // in multipart, each field / file can actually have a own Content-Type.
72        // We use this to determine the encoding of the graphql query
73        let content_type = field
74            .content_type()
75            // default to json
76            .unwrap_or(&mime::APPLICATION_JSON)
77            .clone();
78        match field.name() {
79            Some("operations") => {
80                let body = field.bytes().await?;
81                request = Some(
82                    super::receive_batch_body_no_multipart(&content_type, body.as_ref()).await?,
83                )
84            }
85            Some("map") => {
86                let map_bytes = field.bytes().await?;
87
88                match (content_type.type_(), content_type.subtype()) {
89                    // cbor is in application/octet-stream.
90                    // TODO: wait for mime to add application/cbor and match against that too
91                    // Note: we actually differ here from the inoffical spec for this:
92                    // (https://github.com/jaydenseric/graphql-multipart-request-spec#multipart-form-field-structure)
93                    // It says: "map: A JSON encoded map of where files occurred in the operations.
94                    // For each file, the key is the file multipart form field name and the value is
95                    // an array of operations paths." However, I think, that
96                    // since we accept CBOR as operation, which is valid, we should also accept it
97                    // as the mapping for the files.
98                    #[cfg(feature = "cbor")]
99                    (mime::OCTET_STREAM, _) | (mime::APPLICATION, mime::OCTET_STREAM) => {
100                        map = Some(
101                            serde_cbor::from_slice::<HashMap<String, Vec<String>>>(&map_bytes)
102                                .map_err(|e| ParseRequestError::InvalidFilesMap(Box::new(e)))?,
103                        );
104                    }
105                    // default to json
106                    _ => {
107                        map = Some(
108                            serde_json::from_slice::<HashMap<String, Vec<String>>>(&map_bytes)
109                                .map_err(|e| ParseRequestError::InvalidFilesMap(Box::new(e)))?,
110                        );
111                    }
112                }
113            }
114            _ => {
115                if let Some(name) = field.name().map(ToString::to_string) {
116                    if let Some(filename) = field.file_name().map(ToString::to_string) {
117                        let content_type = field.content_type().map(ToString::to_string);
118
119                        #[cfg(feature = "tempfile")]
120                        let content = {
121                            let mut field = field;
122
123                            #[cfg(feature = "unblock")]
124                            {
125                                use std::io::SeekFrom;
126
127                                use blocking::Unblock;
128                                use futures_util::{AsyncSeekExt, AsyncWriteExt};
129
130                                let mut file = Unblock::new(
131                                    tempfile::tempfile().map_err(ParseRequestError::Io)?,
132                                );
133                                while let Some(chunk) = field.chunk().await? {
134                                    file.write_all(&chunk)
135                                        .await
136                                        .map_err(ParseRequestError::Io)?;
137                                }
138                                file.seek(SeekFrom::Start(0))
139                                    .await
140                                    .map_err(ParseRequestError::Io)?;
141                                file.into_inner().await
142                            }
143
144                            #[cfg(not(feature = "unblock"))]
145                            {
146                                use std::io::{Seek, Write};
147
148                                let mut file =
149                                    tempfile::tempfile().map_err(ParseRequestError::Io)?;
150                                while let Some(chunk) = field.chunk().await? {
151                                    file.write_all(&chunk).map_err(ParseRequestError::Io)?;
152                                }
153                                file.rewind()?;
154                                file
155                            }
156                        };
157
158                        #[cfg(not(feature = "tempfile"))]
159                        let content = field.bytes().await?;
160
161                        files.push((name, filename, content_type, content));
162                    }
163                }
164            }
165        }
166    }
167
168    let mut request: BatchRequest = request.ok_or(ParseRequestError::MissingOperatorsPart)?;
169    let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?;
170
171    for (name, filename, content_type, file) in files {
172        if let Some(var_paths) = map.remove(&name) {
173            let upload = UploadValue {
174                filename,
175                content_type,
176                content: file,
177            };
178
179            for var_path in var_paths {
180                match &mut request {
181                    BatchRequest::Single(request) => {
182                        request.set_upload(&var_path, upload.try_clone()?);
183                    }
184                    BatchRequest::Batch(requests) => {
185                        let mut s = var_path.splitn(2, '.');
186                        let idx = s.next().and_then(|idx| idx.parse::<usize>().ok());
187                        let path = s.next();
188
189                        if let (Some(idx), Some(path)) = (idx, path) {
190                            if let Some(request) = requests.get_mut(idx) {
191                                request.set_upload(path, upload.try_clone()?);
192                            }
193                        }
194                    }
195                }
196            }
197        }
198    }
199
200    if !map.is_empty() {
201        return Err(ParseRequestError::MissingFiles);
202    }
203
204    Ok(request)
205}
206
207pin_project! {
208    pub(crate) struct ReaderStream<T> {
209        buf: [u8; 2048],
210        #[pin]
211        reader: T,
212    }
213}
214
215impl<T> ReaderStream<T> {
216    pub(crate) fn new(reader: T) -> Self {
217        Self {
218            buf: [0; 2048],
219            reader,
220        }
221    }
222}
223
224impl<T: AsyncRead> Stream for ReaderStream<T> {
225    type Item = io::Result<Vec<u8>>;
226
227    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
228        let this = self.project();
229
230        Poll::Ready(
231            match futures_util::ready!(this.reader.poll_read(cx, this.buf)?) {
232                0 => None,
233                size => Some(Ok(this.buf[..size].to_vec())),
234            },
235        )
236    }
237}