// async_graphql/http/multipart.rs

use std::{
    collections::HashMap,
    io,
    pin::Pin,
    task::{Context, Poll},
};

use futures_util::{io::AsyncRead, stream::Stream};
use multer::{Constraints, Multipart, SizeLimit};
use pin_project_lite::pin_project;

use crate::{BatchRequest, ParseRequestError, UploadValue};
13
/// Size limits applied while receiving a `multipart/form-data` GraphQL
/// request.
///
/// Both limits default to `None` (unlimited) via `Default`; use the builder
/// methods to set them.
#[derive(Default, Clone, Copy)]
#[non_exhaustive]
pub struct MultipartOptions {
    /// Maximum size allowed for a single uploaded file (enforced per
    /// multipart field, in bytes per multer's `SizeLimit`).
    pub max_file_size: Option<usize>,
    /// Maximum number of uploaded files. NOTE(review): as used in
    /// `receive_batch_multipart`, this only takes effect together with
    /// `max_file_size`, where it bounds the whole stream size — it does not
    /// cap the file count on its own.
    pub max_num_files: Option<usize>,
}
23
24impl MultipartOptions {
25 #[must_use]
27 pub fn max_file_size(self, size: usize) -> Self {
28 MultipartOptions {
29 max_file_size: Some(size),
30 ..self
31 }
32 }
33
34 #[must_use]
36 pub fn max_num_files(self, n: usize) -> Self {
37 MultipartOptions {
38 max_num_files: Some(n),
39 ..self
40 }
41 }
42}
43
/// Receives a (possibly batched) GraphQL request from a `multipart/form-data`
/// body, following the multipart request protocol (the "operations"/"map"
/// convention used for GraphQL file uploads):
///
/// - the `operations` field carries the GraphQL request body itself,
/// - the `map` field maps each file field name to the variable paths it fills,
/// - every other named field that has a filename is treated as an upload.
///
/// # Errors
///
/// Fails with `ParseRequestError::MissingOperatorsPart` /
/// `MissingMapPart` when the respective field is absent, `InvalidFilesMap`
/// when the map cannot be deserialized, `MissingFiles` when the map refers to
/// files that were never sent, and I/O / multer errors from reading the body.
pub(super) async fn receive_batch_multipart(
    body: impl AsyncRead + Send,
    boundary: impl Into<String>,
    opts: MultipartOptions,
) -> Result<BatchRequest, ParseRequestError> {
    let mut multipart = Multipart::with_constraints(
        ReaderStream::new(body),
        boundary,
        Constraints::new().size_limit({
            let mut limit = SizeLimit::new();
            // Whole-stream cap is only applied when BOTH limits are set;
            // `max_num_files` alone has no effect here.
            // NOTE(review): `max_file_size * max_num_files` is an unchecked
            // multiplication — could overflow for pathological option values.
            if let (Some(max_file_size), Some(max_num_files)) =
                (opts.max_file_size, opts.max_num_files)
            {
                limit = limit.whole_stream((max_file_size * max_num_files) as u64);
            }
            // Per-field cap enforces the size of each individual part.
            if let Some(max_file_size) = opts.max_file_size {
                limit = limit.per_field(max_file_size as u64);
            }
            limit
        }),
    );

    // Collected while walking the parts: the parsed request, the files map,
    // and the raw uploads (attached to the request in a second pass below).
    let mut request = None;
    let mut map = None;
    let mut files = Vec::new();

    while let Some(field) = multipart.next_field().await? {
        // Parts without an explicit content type are assumed to be JSON.
        let content_type = field
            .content_type()
            .unwrap_or(&mime::APPLICATION_JSON)
            .clone();
        match field.name() {
            // The GraphQL request (or batch of requests) itself.
            Some("operations") => {
                let body = field.bytes().await?;
                request = Some(
                    super::receive_batch_body_no_multipart(&content_type, body.as_ref()).await?,
                )
            }
            // file-field-name -> list of variable paths to fill with it.
            Some("map") => {
                let map_bytes = field.bytes().await?;

                match (content_type.type_(), content_type.subtype()) {
                    // With the `cbor` feature, an octet-stream map is decoded
                    // as CBOR; everything else falls through to JSON.
                    #[cfg(feature = "cbor")]
                    (mime::OCTET_STREAM, _) | (mime::APPLICATION, mime::OCTET_STREAM) => {
                        map = Some(
                            serde_cbor::from_slice::<HashMap<String, Vec<String>>>(&map_bytes)
                                .map_err(|e| ParseRequestError::InvalidFilesMap(Box::new(e)))?,
                        );
                    }
                    _ => {
                        map = Some(
                            serde_json::from_slice::<HashMap<String, Vec<String>>>(&map_bytes)
                                .map_err(|e| ParseRequestError::InvalidFilesMap(Box::new(e)))?,
                        );
                    }
                }
            }
            // Any other named part with a filename is an uploaded file.
            _ => {
                if let Some(name) = field.name().map(ToString::to_string) {
                    if let Some(filename) = field.file_name().map(ToString::to_string) {
                        let content_type = field.content_type().map(ToString::to_string);

                        // With `tempfile`, stream the part to disk instead of
                        // buffering it in memory, then rewind for reading.
                        #[cfg(feature = "tempfile")]
                        let content = {
                            let mut field = field;

                            // `unblock`: run the blocking file I/O on the
                            // `blocking` crate's thread pool.
                            #[cfg(feature = "unblock")]
                            {
                                use std::io::SeekFrom;

                                use blocking::Unblock;
                                use futures_util::{AsyncSeekExt, AsyncWriteExt};

                                let mut file = Unblock::new(
                                    tempfile::tempfile().map_err(ParseRequestError::Io)?,
                                );
                                while let Some(chunk) = field.chunk().await? {
                                    file.write_all(&chunk)
                                        .await
                                        .map_err(ParseRequestError::Io)?;
                                }
                                file.seek(SeekFrom::Start(0))
                                    .await
                                    .map_err(ParseRequestError::Io)?;
                                file.into_inner().await
                            }

                            // Without `unblock`: synchronous writes directly
                            // on the async task (blocking, but simple).
                            #[cfg(not(feature = "unblock"))]
                            {
                                use std::io::{Seek, Write};

                                let mut file =
                                    tempfile::tempfile().map_err(ParseRequestError::Io)?;
                                while let Some(chunk) = field.chunk().await? {
                                    file.write_all(&chunk).map_err(ParseRequestError::Io)?;
                                }
                                // NOTE(review): uses `?` (From<io::Error>)
                                // while the surrounding lines use
                                // `map_err(ParseRequestError::Io)` — works,
                                // but inconsistent error mapping.
                                file.rewind()?;
                                file
                            }
                        };

                        // Without `tempfile`, the whole file is buffered in
                        // memory.
                        #[cfg(not(feature = "tempfile"))]
                        let content = field.bytes().await?;

                        files.push((name, filename, content_type, content));
                    }
                }
            }
        }
    }

    // Both `operations` and `map` parts are mandatory.
    let mut request: BatchRequest = request.ok_or(ParseRequestError::MissingOperatorsPart)?;
    let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?;

    // Second pass: attach every received file to the variable path(s) the
    // map assigned to it. Entries are removed from `map` as they are used so
    // that leftovers can be detected afterwards.
    for (name, filename, content_type, file) in files {
        if let Some(var_paths) = map.remove(&name) {
            let upload = UploadValue {
                filename,
                content_type,
                content: file,
            };

            for var_path in var_paths {
                match &mut request {
                    BatchRequest::Single(request) => {
                        request.set_upload(&var_path, upload.try_clone()?);
                    }
                    BatchRequest::Batch(requests) => {
                        // Batched paths are "<index>.<path-within-request>";
                        // malformed or out-of-range entries are silently
                        // skipped.
                        let mut s = var_path.splitn(2, '.');
                        let idx = s.next().and_then(|idx| idx.parse::<usize>().ok());
                        let path = s.next();

                        if let (Some(idx), Some(path)) = (idx, path) {
                            if let Some(request) = requests.get_mut(idx) {
                                request.set_upload(path, upload.try_clone()?);
                            }
                        }
                    }
                }
            }
        }
    }

    // Any map entry that never matched a received file means the client
    // promised a file it did not send.
    if !map.is_empty() {
        return Err(ParseRequestError::MissingFiles);
    }

    Ok(request)
}
206
pin_project! {
    /// Adapts an `AsyncRead` into a `Stream` of owned byte chunks, reading
    /// through a fixed 2 KiB scratch buffer (see the `Stream` impl below).
    pub(crate) struct ReaderStream<T> {
        // Scratch buffer reused for every read; each yielded chunk is copied
        // out of it.
        buf: [u8; 2048],
        #[pin]
        reader: T,
    }
}
214
215impl<T> ReaderStream<T> {
216 pub(crate) fn new(reader: T) -> Self {
217 Self {
218 buf: [0; 2048],
219 reader,
220 }
221 }
222}
223
224impl<T: AsyncRead> Stream for ReaderStream<T> {
225 type Item = io::Result<Vec<u8>>;
226
227 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
228 let this = self.project();
229
230 Poll::Ready(
231 match futures_util::ready!(this.reader.poll_read(cx, this.buf)?) {
232 0 => None,
233 size => Some(Ok(this.buf[..size].to_vec())),
234 },
235 )
236 }
237}