1use anyhow::{bail, Context as ErrorContext, Result};
22use base64::Engine;
23use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_LENGTH};
24use reqwest::Body;
25use serde::Deserialize;
26use serde_json::{json, Value};
27use std::collections::HashMap;
28use taskcluster::chrono::{DateTime, Utc};
29use taskcluster::retry::{Backoff, Retry};
30use taskcluster::Object;
31use tokio::fs::File;
32use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeekExt, SeekFrom};
33use tokio_util::codec::{BytesCodec, FramedRead};
34
35mod factory;
36mod hashing;
37mod service;
38
39pub use factory::{AsyncReaderFactory, CursorReaderFactory, FileReaderFactory};
40use service::ObjectService;
41
/// Size threshold (8 KiB) for the `dataInline` upload method: objects whose length is strictly
/// below this many bytes are additionally proposed for upload base64-encoded inside the
/// `createUpload` request itself.
const DATA_INLINE_MAX_SIZE: u64 = 8 * 1024;
43
44pub async fn upload_from_buf(
46 project_id: &str,
47 name: &str,
48 content_type: &str,
49 expires: &DateTime<Utc>,
50 data: &[u8],
51 retry: &Retry,
52 object_service: &Object,
53 upload_id: &str,
54) -> Result<()> {
55 upload_with_factory(
56 project_id,
57 name,
58 content_type,
59 data.len() as u64,
60 expires,
61 CursorReaderFactory::new(data),
62 retry,
63 object_service,
64 upload_id,
65 )
66 .await
67}
68
69pub async fn upload_from_file(
72 project_id: &str,
73 name: &str,
74 content_type: &str,
75 expires: &DateTime<Utc>,
76 mut file: File,
77 retry: &Retry,
78 object_service: &Object,
79 upload_id: &str,
80) -> Result<()> {
81 let content_length = file.seek(SeekFrom::End(0)).await?;
82 upload_with_factory(
83 project_id,
84 name,
85 content_type,
86 content_length,
87 expires,
88 FileReaderFactory::new(file),
89 retry,
90 object_service,
91 upload_id,
92 )
93 .await
94}
95
96pub async fn upload_with_factory<ARF: AsyncReaderFactory>(
99 project_id: &str,
100 name: &str,
101 content_type: &str,
102 content_length: u64,
103 expires: &DateTime<Utc>,
104 reader_factory: ARF,
105 retry: &Retry,
106 object_service: &Object,
107 upload_id: &str,
108) -> Result<()> {
109 upload_impl(
110 project_id,
111 name,
112 content_type,
113 content_length,
114 expires,
115 reader_factory,
116 retry,
117 object_service,
118 &upload_id,
119 )
120 .await
121}
122
123async fn upload_impl<O: ObjectService, ARF: AsyncReaderFactory>(
126 project_id: &str,
127 name: &str,
128 content_type: &str,
129 content_length: u64,
130 expires: &DateTime<Utc>,
131 reader_factory: ARF,
132 retry: &Retry,
133 object_service: &O,
134 upload_id: &str,
135) -> Result<()> {
136 let mut reader_factory = hashing::HasherAsyncReaderFactory::new(reader_factory);
137 let mut proposed_upload_methods = json!({});
138
139 if content_length < DATA_INLINE_MAX_SIZE {
141 let mut buf = vec![];
142 let mut reader = reader_factory.get_reader().await?;
143 reader.read_to_end(&mut buf).await?;
144 let data_b64 = base64::engine::general_purpose::STANDARD.encode(buf);
145 proposed_upload_methods["dataInline"] = json!({
146 "contentType": content_type,
147 "objectData": data_b64,
148 });
149 }
150
151 proposed_upload_methods["putUrl"] = json!({
153 "contentType": content_type,
154 "contentLength": content_length,
155 });
156
157 let create_upload_res = object_service
159 .createUpload(
160 name,
161 &json!({
162 "expires": expires,
163 "projectId": project_id,
164 "uploadId": upload_id,
165 "proposedUploadMethods": proposed_upload_methods,
166 }),
167 )
168 .await?;
169
170 let mut backoff = Backoff::new(retry);
171 let mut attempts = 0u32;
172 loop {
173 let res: Result<()> = if create_upload_res
175 .pointer("/uploadMethod/dataInline")
176 .is_some()
177 {
178 Ok(()) } else if let Some(method) = create_upload_res.pointer("/uploadMethod/putUrl") {
180 let reader = reader_factory.get_reader().await?;
181 simple_upload(reader, content_length, method.clone()).await
182 } else {
183 bail!("Could not negotiate an upload method") };
185
186 attempts += 1;
187 match &res {
188 Ok(_) => break,
189 Err(err) => {
190 if let Some(reqerr) = err.downcast_ref::<reqwest::Error>() {
191 if reqerr
192 .status()
193 .map(|s| s.is_client_error())
194 .unwrap_or(false)
195 {
196 return res;
197 }
198 }
199 }
200 }
201
202 match backoff.next_backoff() {
203 Some(duration) => tokio::time::sleep(duration).await,
204 None => return res.context(format!("Download failed after {} attempts", attempts)),
205 }
206 }
207
208 let hashes = reader_factory.hashes(content_length);
209
210 object_service
212 .finishUpload(
213 name,
214 &json!({
215 "projectId": project_id,
216 "uploadId": upload_id,
217 "hashes": hashes,
218 }),
219 )
220 .await?;
221
222 Ok(())
223}
224
225async fn simple_upload(
227 reader: Box<dyn AsyncRead + Sync + Send + Unpin + 'static>,
228 content_length: u64,
229 upload_method: Value,
230) -> Result<()> {
231 #[derive(Deserialize)]
232 struct Method {
233 url: String,
234 headers: HashMap<String, String>,
235 }
236
237 let upload_method: Method = serde_json::from_value(upload_method.clone())?;
238 let client = reqwest::Client::new();
239
240 let mut req = client.put(&upload_method.url);
241
242 let mut req_headers = HeaderMap::new();
243 for (k, v) in upload_method.headers.iter() {
244 req_headers.insert(
245 HeaderName::from_bytes(k.as_bytes())?,
246 HeaderValue::from_str(v)?,
247 );
248 }
249
250 if !req_headers.contains_key(CONTENT_LENGTH) {
251 req_headers.insert(CONTENT_LENGTH, content_length.into());
252 }
253
254 req = req.headers(req_headers);
255
256 let stream = FramedRead::new(reader, BytesCodec::new());
257 req = req.body(Body::wrap_stream(stream));
258
259 req.send().await?.error_for_status()?;
260
261 Ok(())
262}
263
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::Error;
    use async_trait::async_trait;
    use httptest::{
        matchers::{all_of, contains, request, ExecutionContext, Matcher},
        responders::status_code,
        Expectation,
    };
    use ring::rand::{SecureRandom, SystemRandom};
    use serde_json::json;
    use std::fmt;
    use std::sync::Mutex;
    use taskcluster::chrono::Duration;

    /// Append-only, mutex-protected log of the calls made to the fake object services, so tests
    /// can assert on the exact sequence of calls.
    #[derive(Default)]
    pub(crate) struct Logger {
        logged: Mutex<Vec<String>>,
    }

    impl Logger {
        /// Append `message` to the log.
        pub(crate) fn log<S: Into<String>>(&self, message: S) {
            self.logged.lock().unwrap().push(message.into())
        }

        /// Assert that the log contains exactly `expected`, in order.
        pub(crate) fn assert(&self, expected: Vec<String>) {
            assert_eq!(*self.logged.lock().unwrap(), expected);
        }
    }

    /// An `httptest` matcher that always matches and prints its input with `dbg!`; useful for
    /// seeing the actual request when an expectation does not match.
    pub(crate) struct Dbg;

    impl<IN> Matcher<IN> for Dbg
    where
        IN: fmt::Debug + ?Sized,
    {
        fn matches(&mut self, input: &IN, _ctx: &mut ExecutionContext) -> bool {
            dbg!(input);
            true
        }

        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "Dbg()")
        }
    }

    /// Fake object service that only supports `dataInline`. It selects `dataInline` when it was
    /// proposed. Otherwise it returns an empty `uploadMethod`, which `upload_impl` treats as a
    /// failed negotiation.
    #[derive(Default)]
    struct DataInlineOnly {
        logger: Logger,
    }

    #[async_trait]
    impl ObjectService for DataInlineOnly {
        async fn createUpload(
            &self,
            name: &str,
            payload: &Value,
        ) -> std::result::Result<Value, Error> {
            let expires: DateTime<Utc> =
                serde_json::from_value(payload["expires"].clone()).unwrap();
            self.logger.log(format!(
                "create {} {} {} {}",
                name,
                expires,
                payload["projectId"].as_str().unwrap(),
                payload["uploadId"].as_str().unwrap()
            ));
            if let Some(di) = payload.pointer("/proposedUploadMethods/dataInline") {
                self.logger.log(format!(
                    "dataInline {} {}",
                    di["contentType"].as_str().unwrap(),
                    di["objectData"].as_str().unwrap()
                ));
                Ok(json!({
                    "expires": payload["expires"],
                    "projectId": payload["projectId"],
                    "uploadId": payload["uploadId"],
                    "uploadMethod": {
                        "dataInline": true,
                    },
                }))
            } else {
                Ok(json!({
                    "expires": payload["expires"],
                    "projectId": payload["projectId"],
                    "uploadId": payload["uploadId"],
                    "uploadMethod": {},
                }))
            }
        }

        async fn finishUpload(
            &self,
            name: &str,
            payload: &Value,
        ) -> std::result::Result<(), Error> {
            assert_eq!(name, "some/object");
            self.logger.log(format!(
                "finish {} {} {}",
                name,
                payload["projectId"].as_str().unwrap(),
                payload["uploadId"].as_str().unwrap(),
            ));
            Ok(())
        }
    }

    /// Fake object service that only supports `putUrl`. It directs uploads to `/data` on the
    /// wrapped `httptest` server, with the proposed content type and length plus a fixed
    /// `X-Test-Header: good` header.
    struct PutUrlOnly {
        logger: Logger,
        server: httptest::Server,
    }

    impl PutUrlOnly {
        fn new(server: httptest::Server) -> Self {
            Self {
                logger: Logger::default(),
                server,
            }
        }
    }

    #[async_trait]
    impl ObjectService for PutUrlOnly {
        async fn createUpload(
            &self,
            name: &str,
            payload: &Value,
        ) -> std::result::Result<Value, Error> {
            let expires: DateTime<Utc> =
                serde_json::from_value(payload["expires"].clone()).unwrap();
            self.logger.log(format!(
                "create {} {} {} {}",
                name,
                expires,
                payload["projectId"].as_str().unwrap(),
                payload["uploadId"].as_str().unwrap()
            ));
            if let Some(pu) = payload.pointer("/proposedUploadMethods/putUrl") {
                self.logger.log(format!(
                    "putUrl {} {}",
                    pu["contentType"].as_str().unwrap(),
                    pu["contentLength"]
                ));
                Ok(json!({
                    "expires": payload["expires"],
                    "projectId": payload["projectId"],
                    "uploadId": payload["uploadId"],
                    "uploadMethod": {
                        "putUrl": {
                            "expires": payload["expires"],
                            "url": self.server.url_str("/data"),
                            "headers": {
                                "Content-Type": pu["contentType"],
                                "Content-Length": pu["contentLength"].to_string(),
                                "X-Test-Header": "good",
                            },
                        },
                    },
                }))
            } else {
                Ok(json!({
                    "expires": payload["expires"],
                    "projectId": payload["projectId"],
                    "uploadId": payload["uploadId"],
                    "uploadMethod": {},
                }))
            }
        }

        async fn finishUpload(
            &self,
            name: &str,
            payload: &Value,
        ) -> std::result::Result<(), Error> {
            assert_eq!(name, "some/object");
            self.logger.log(format!(
                "finish {} {} {}",
                name,
                payload["projectId"].as_str().unwrap(),
                payload["uploadId"].as_str().unwrap(),
            ));
            Ok(())
        }
    }

    /// Upload `data` as `some/object` in project `proj` through `upload_impl`, using the given
    /// fake service and the default retry configuration.
    async fn upload<O: ObjectService>(
        object_service: &O,
        upload_id: String,
        expires: &DateTime<Utc>,
        data: &[u8],
    ) -> Result<()> {
        upload_impl(
            "proj",
            "some/object",
            "application/binary",
            data.len() as u64,
            expires,
            CursorReaderFactory::new(data),
            &Retry::default(),
            object_service,
            &upload_id,
        )
        .await
    }

    #[tokio::test]
    async fn small_data_inline_upload() -> Result<()> {
        let upload_id = slugid::v4();
        let expires = Utc::now() + Duration::hours(1);

        let object_service = DataInlineOnly::default();

        upload(&object_service, upload_id.clone(), &expires, b"hello world").await?;

        object_service.logger.assert(vec![
            format!("create some/object {} proj {}", expires, upload_id),
            format!(
                "dataInline application/binary {}",
                base64::engine::general_purpose::STANDARD.encode(b"hello world")
            ),
            format!("finish some/object proj {}", upload_id),
        ]);

        Ok(())
    }

    #[tokio::test]
    async fn large_data_inline_upload() -> Result<()> {
        let upload_id = slugid::v4();
        let expires = Utc::now() + Duration::hours(1);

        let object_service = DataInlineOnly::default();

        // Too large to be proposed inline, and this service offers nothing else, so negotiation
        // must fail and `finishUpload` must never be called.
        let mut data = vec![0u8; 10000];
        SystemRandom::new().fill(&mut data).unwrap();
        let err = upload(&object_service, upload_id.clone(), &expires, &data)
            .await
            .expect_err("upload should fail when no upload method can be negotiated");

        assert_eq!(err.to_string(), "Could not negotiate an upload method");
        object_service.logger.assert(vec![format!(
            "create some/object {} proj {}",
            expires, upload_id
        )]);

        Ok(())
    }

    #[tokio::test]
    async fn put_url() -> Result<()> {
        let upload_id = slugid::v4();
        let expires = Utc::now() + Duration::hours(1);

        let server = httptest::Server::run();
        server.expect(
            Expectation::matching(all_of![
                Dbg,
                request::method_path("PUT", "/data"),
                request::body("hello, world"),
                request::headers(all_of![
                    contains(("content-type", "application/binary")),
                    contains(("content-length", "12")),
                    contains(("x-test-header", "good")),
                ]),
            ])
            .times(1)
            .respond_with(status_code(200)),
        );

        let object_service = PutUrlOnly::new(server);

        upload(
            &object_service,
            upload_id.clone(),
            &expires,
            b"hello, world",
        )
        .await?;

        object_service.logger.assert(vec![
            format!("create some/object {} proj {}", expires, upload_id),
            format!("putUrl application/binary {}", 12),
            format!("finish some/object proj {}", upload_id),
        ]);

        // Drop the service, and with it the httptest server, explicitly so the server checks
        // its expectations here, before the test returns.
        drop(object_service);
        Ok(())
    }
}