taskcluster_upload/
lib.rs

/*! Support for uploading data to the Taskcluster object service.

This crate provides a set of functions to perform an object-service upload.
These functions negotiate an upload method with the object service, and then perform the upload, following all of the Taskcluster recommended practices.

Each function takes the necessary metadata for the upload, a handle to the data to be uploaded, and a [taskcluster::Object] client.
The data to be uploaded can come in a variety of forms, described below.
The client must be configured with the necessary credentials to access the object service.

## Convenience Functions

Most users of this crate can rely on [upload_from_buf] or [upload_from_file], which take the data as an in-memory buffer or a [tokio::fs::File], respectively.

## Factories

An upload may be retried, in which case the upload function must have access to the object data from the beginning.
This is accomplished with the [`AsyncReaderFactory`](crate::AsyncReaderFactory) trait, which defines a `get_reader` method to generate a fresh [tokio::io::AsyncRead] for each attempt.
Users for whom the supplied convenience functions are inadequate can add their own implementation of this trait and pass it to [upload_with_factory].

 */
21use anyhow::{bail, Context as ErrorContext, Result};
22use base64::Engine;
23use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_LENGTH};
24use reqwest::Body;
25use serde::Deserialize;
26use serde_json::{json, Value};
27use std::collections::HashMap;
28use taskcluster::chrono::{DateTime, Utc};
29use taskcluster::retry::{Backoff, Retry};
30use taskcluster::Object;
31use tokio::fs::File;
32use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeekExt, SeekFrom};
33use tokio_util::codec::{BytesCodec, FramedRead};
34
35mod factory;
36mod hashing;
37mod service;
38
39pub use factory::{AsyncReaderFactory, CursorReaderFactory, FileReaderFactory};
40use service::ObjectService;
41
/// Size bound, in bytes, for proposing a `dataInline` upload: objects strictly shorter than
/// this are offered to the object service inline (alongside `putUrl`); objects of this size
/// or larger are offered only `putUrl`.
const DATA_INLINE_MAX_SIZE: u64 = 8 * 1024;

44/// Upload an object from an in-memory buffer.
45pub async fn upload_from_buf(
46    project_id: &str,
47    name: &str,
48    content_type: &str,
49    expires: &DateTime<Utc>,
50    data: &[u8],
51    retry: &Retry,
52    object_service: &Object,
53    upload_id: &str,
54) -> Result<()> {
55    upload_with_factory(
56        project_id,
57        name,
58        content_type,
59        data.len() as u64,
60        expires,
61        CursorReaderFactory::new(data),
62        retry,
63        object_service,
64        upload_id,
65    )
66    .await
67}
68
69/// Upload an object from a File.  The file must be open in read mode and must be clone-able (that
70/// is, [File::try_clone()] must succeed) in order to support retried uploads.
71pub async fn upload_from_file(
72    project_id: &str,
73    name: &str,
74    content_type: &str,
75    expires: &DateTime<Utc>,
76    mut file: File,
77    retry: &Retry,
78    object_service: &Object,
79    upload_id: &str,
80) -> Result<()> {
81    let content_length = file.seek(SeekFrom::End(0)).await?;
82    upload_with_factory(
83        project_id,
84        name,
85        content_type,
86        content_length,
87        expires,
88        FileReaderFactory::new(file),
89        retry,
90        object_service,
91        upload_id,
92    )
93    .await
94}
95
96/// Upload an object using an AsyncReaderFactory.  This is useful for advanced cases where one of
97/// the convenience functions is not adequate.
98pub async fn upload_with_factory<ARF: AsyncReaderFactory>(
99    project_id: &str,
100    name: &str,
101    content_type: &str,
102    content_length: u64,
103    expires: &DateTime<Utc>,
104    reader_factory: ARF,
105    retry: &Retry,
106    object_service: &Object,
107    upload_id: &str,
108) -> Result<()> {
109    upload_impl(
110        project_id,
111        name,
112        content_type,
113        content_length,
114        expires,
115        reader_factory,
116        retry,
117        object_service,
118        &upload_id,
119    )
120    .await
121}
122
123/// Internal implementation of downloads, using the ObjectService trait to allow
124/// injecting a fake dependency
125async fn upload_impl<O: ObjectService, ARF: AsyncReaderFactory>(
126    project_id: &str,
127    name: &str,
128    content_type: &str,
129    content_length: u64,
130    expires: &DateTime<Utc>,
131    reader_factory: ARF,
132    retry: &Retry,
133    object_service: &O,
134    upload_id: &str,
135) -> Result<()> {
136    let mut reader_factory = hashing::HasherAsyncReaderFactory::new(reader_factory);
137    let mut proposed_upload_methods = json!({});
138
139    // if the data is short enough, try a data-inline upload
140    if content_length < DATA_INLINE_MAX_SIZE {
141        let mut buf = vec![];
142        let mut reader = reader_factory.get_reader().await?;
143        reader.read_to_end(&mut buf).await?;
144        let data_b64 = base64::engine::general_purpose::STANDARD.encode(buf);
145        proposed_upload_methods["dataInline"] = json!({
146            "contentType": content_type,
147            "objectData": data_b64,
148        });
149    }
150
151    // in any case, try a put-url upload
152    proposed_upload_methods["putUrl"] = json!({
153        "contentType": content_type,
154        "contentLength": content_length,
155    });
156
157    // send the request to the object service
158    let create_upload_res = object_service
159        .createUpload(
160            name,
161            &json!({
162                "expires": expires,
163                "projectId": project_id,
164                "uploadId": upload_id,
165                "proposedUploadMethods": proposed_upload_methods,
166            }),
167        )
168        .await?;
169
170    let mut backoff = Backoff::new(retry);
171    let mut attempts = 0u32;
172    loop {
173        // actually upload the data
174        let res: Result<()> = if create_upload_res
175            .pointer("/uploadMethod/dataInline")
176            .is_some()
177        {
178            Ok(()) // nothing to do - data is already in place
179        } else if let Some(method) = create_upload_res.pointer("/uploadMethod/putUrl") {
180            let reader = reader_factory.get_reader().await?;
181            simple_upload(reader, content_length, method.clone()).await
182        } else {
183            bail!("Could not negotiate an upload method") // not retriable
184        };
185
186        attempts += 1;
187        match &res {
188            Ok(_) => break,
189            Err(err) => {
190                if let Some(reqerr) = err.downcast_ref::<reqwest::Error>() {
191                    if reqerr
192                        .status()
193                        .map(|s| s.is_client_error())
194                        .unwrap_or(false)
195                    {
196                        return res;
197                    }
198                }
199            }
200        }
201
202        match backoff.next_backoff() {
203            Some(duration) => tokio::time::sleep(duration).await,
204            None => return res.context(format!("Download failed after {} attempts", attempts)),
205        }
206    }
207
208    let hashes = reader_factory.hashes(content_length);
209
210    // finish the upload
211    object_service
212        .finishUpload(
213            name,
214            &json!({
215                "projectId": project_id,
216                "uploadId": upload_id,
217                "hashes": hashes,
218            }),
219        )
220        .await?;
221
222    Ok(())
223}
224
225/// Perform a simple upload, given the `method` property of the response from createUpload.
226async fn simple_upload(
227    reader: Box<dyn AsyncRead + Sync + Send + Unpin + 'static>,
228    content_length: u64,
229    upload_method: Value,
230) -> Result<()> {
231    #[derive(Deserialize)]
232    struct Method {
233        url: String,
234        headers: HashMap<String, String>,
235    }
236
237    let upload_method: Method = serde_json::from_value(upload_method.clone())?;
238    let client = reqwest::Client::new();
239
240    let mut req = client.put(&upload_method.url);
241
242    let mut req_headers = HeaderMap::new();
243    for (k, v) in upload_method.headers.iter() {
244        req_headers.insert(
245            HeaderName::from_bytes(k.as_bytes())?,
246            HeaderValue::from_str(v)?,
247        );
248    }
249
250    if !req_headers.contains_key(CONTENT_LENGTH) {
251        req_headers.insert(CONTENT_LENGTH, content_length.into());
252    }
253
254    req = req.headers(req_headers);
255
256    let stream = FramedRead::new(reader, BytesCodec::new());
257    req = req.body(Body::wrap_stream(stream));
258
259    req.send().await?.error_for_status()?;
260
261    Ok(())
262}
263
#[cfg(test)]
mod test {
    // Tests for the upload functions, run against fake `ObjectService` implementations and,
    // for `putUrl` uploads, an `httptest` server standing in for the upload destination.
    use super::*;
    use anyhow::Error;
    use async_trait::async_trait;
    use httptest::{
        matchers::{all_of, contains, request, ExecutionContext, Matcher},
        responders::status_code,
        Expectation,
    };
    use ring::rand::{SecureRandom, SystemRandom};
    use serde_json::json;
    use std::fmt;
    use std::sync::Mutex;
    use taskcluster::chrono::Duration;

    /// Event logger, used to record events in the fake ObjectService implementations so that
    /// tests can assert on the exact sequence of calls.
    #[derive(Default)]
    pub(crate) struct Logger {
        logged: Mutex<Vec<String>>,
    }

    impl Logger {
        /// Append `message` to the log.
        pub(crate) fn log<S: Into<String>>(&self, message: S) {
            self.logged.lock().unwrap().push(message.into())
        }

        /// Assert that the log contains exactly `expected`, in order.
        pub(crate) fn assert(&self, expected: Vec<String>) {
            assert_eq!(*self.logged.lock().unwrap(), expected);
        }
    }

    /// Log the matched value with `dbg!()` and always match.
    pub(crate) struct Dbg;
    impl<IN> Matcher<IN> for Dbg
    where
        IN: fmt::Debug + ?Sized,
    {
        fn matches(&mut self, input: &IN, _ctx: &mut ExecutionContext) -> bool {
            dbg!(input);
            true
        }

        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "Dbg()")
        }
    }

    /// Fake implementation of the Object service, that only supports DataInline
    #[derive(Default)]
    struct DataInlineOnly {
        logger: Logger,
    }

    #[async_trait]
    impl ObjectService for DataInlineOnly {
        async fn createUpload(
            &self,
            name: &str,
            payload: &Value,
        ) -> std::result::Result<Value, Error> {
            let expires: DateTime<Utc> =
                serde_json::from_value(payload["expires"].clone()).unwrap();
            self.logger.log(format!(
                "create {} {} {} {}",
                name,
                expires,
                payload["projectId"].as_str().unwrap(),
                payload["uploadId"].as_str().unwrap()
            ));
            if let Some(di) = payload.pointer("/proposedUploadMethods/dataInline") {
                self.logger.log(format!(
                    "dataInline {} {}",
                    di["contentType"].as_str().unwrap(),
                    di["objectData"].as_str().unwrap()
                ));
                Ok(json!({
                    "expires": payload["expires"],
                    "projectId": payload["projectId"],
                    "uploadId": payload["uploadId"],
                    "uploadMethod": {
                        "dataInline": true,
                    },
                }))
            } else {
                Ok(json!({
                    "expires": payload["expires"],
                    "projectId": payload["projectId"],
                    "uploadId": payload["uploadId"],
                    "uploadMethod": {},
                }))
            }
        }

        async fn finishUpload(
            &self,
            name: &str,
            payload: &Value,
        ) -> std::result::Result<(), Error> {
            assert_eq!(name, "some/object");
            self.logger.log(format!(
                "finish {} {} {}",
                name,
                payload["projectId"].as_str().unwrap(),
                payload["uploadId"].as_str().unwrap(),
            ));
            Ok(())
        }
    }

    /// Fake implementation of the Object service, that only supports PutUrl
    struct PutUrlOnly {
        logger: Logger,
        server: httptest::Server,
    }

    impl PutUrlOnly {
        fn new(server: httptest::Server) -> Self {
            Self {
                logger: Logger::default(),
                server,
            }
        }
    }

    #[async_trait]
    impl ObjectService for PutUrlOnly {
        async fn createUpload(
            &self,
            name: &str,
            payload: &Value,
        ) -> std::result::Result<Value, Error> {
            let expires: DateTime<Utc> =
                serde_json::from_value(payload["expires"].clone()).unwrap();
            self.logger.log(format!(
                "create {} {} {} {}",
                name,
                expires,
                payload["projectId"].as_str().unwrap(),
                payload["uploadId"].as_str().unwrap()
            ));
            if let Some(pu) = payload.pointer("/proposedUploadMethods/putUrl") {
                self.logger.log(format!(
                    "putUrl {} {}",
                    pu["contentType"].as_str().unwrap(),
                    pu["contentLength"]
                ));
                Ok(json!({
                    "expires": payload["expires"],
                    "projectId": payload["projectId"],
                    "uploadId": payload["uploadId"],
                    "uploadMethod": {
                        "putUrl": {
                            "expires": payload["expires"],
                            "url": self.server.url_str("/data"),
                            "headers": {
                                "Content-Type": pu["contentType"],
                                "Content-Length": pu["contentLength"].to_string(),
                                "X-Test-Header": "good",
                            },
                        },
                    },
                }))
            } else {
                Ok(json!({
                    "expires": payload["expires"],
                    "projectId": payload["projectId"],
                    "uploadId": payload["uploadId"],
                    "uploadMethod": {},
                }))
            }
        }

        async fn finishUpload(
            &self,
            name: &str,
            payload: &Value,
        ) -> std::result::Result<(), Error> {
            assert_eq!(name, "some/object");
            self.logger.log(format!(
                "finish {} {} {}",
                name,
                payload["projectId"].as_str().unwrap(),
                payload["uploadId"].as_str().unwrap(),
            ));
            Ok(())
        }
    }

    /// Upload `data` as "some/object" in project "proj" through `upload_impl`, using the
    /// default retry configuration.
    async fn upload<O: ObjectService>(
        object_service: &O,
        upload_id: String,
        expires: &DateTime<Utc>,
        data: &[u8],
    ) -> Result<()> {
        upload_impl(
            "proj",
            "some/object",
            "application/binary",
            data.len() as u64,
            expires,
            CursorReaderFactory::new(data),
            &Retry::default(),
            object_service,
            &upload_id,
        )
        .await
    }

    #[tokio::test]
    async fn small_data_inline_upload() -> Result<()> {
        let upload_id = slugid::v4();
        let expires = Utc::now() + Duration::hours(1);

        let object_service = DataInlineOnly::default();

        upload(&object_service, upload_id.clone(), &expires, b"hello world").await?;

        object_service.logger.assert(vec![
            format!("create some/object {} proj {}", expires, upload_id),
            format!(
                "dataInline application/binary {}",
                base64::engine::general_purpose::STANDARD.encode(b"hello world")
            ),
            format!("finish some/object proj {}", upload_id),
        ]);

        Ok(())
    }

    #[tokio::test]
    async fn large_data_inline_upload() -> Result<()> {
        let upload_id = slugid::v4();
        let expires = Utc::now() + Duration::hours(1);

        let object_service = DataInlineOnly::default();

        // larger than DATA_INLINE_MAX_SIZE, so only putUrl is proposed, which this fake
        // service does not support
        let mut data = vec![0u8; 10000];
        SystemRandom::new().fill(&mut data).unwrap();
        let err = upload(&object_service, upload_id.clone(), &expires, &data)
            .await
            .unwrap_err();

        // negotiation fails, for that specific reason, before anything is finished
        assert_eq!(err.to_string(), "Could not negotiate an upload method");
        object_service.logger.assert(vec![format!(
            "create some/object {} proj {}",
            expires, upload_id
        )]);

        Ok(())
    }

    #[tokio::test]
    async fn put_url() -> Result<()> {
        let upload_id = slugid::v4();
        let expires = Utc::now() + Duration::hours(1);

        let server = httptest::Server::run();
        server.expect(
            Expectation::matching(all_of![
                Dbg,
                request::method_path("PUT", "/data"),
                request::body("hello, world"),
                request::headers(all_of![
                    // reqwest normalizes header names to lower-case
                    contains(("content-type", "application/binary")),
                    contains(("content-length", "12")),
                    contains(("x-test-header", "good")),
                ]),
            ])
            .times(1)
            .respond_with(status_code(200)),
        );

        let object_service = PutUrlOnly::new(server);

        upload(
            &object_service,
            upload_id.clone(),
            &expires,
            b"hello, world",
        )
        .await?;

        object_service.logger.assert(vec![
            format!("create some/object {} proj {}", expires, upload_id),
            format!("putUrl application/binary {}", 12),
            format!("finish some/object proj {}", upload_id),
        ]);

        // Dropping the service drops the httptest server, which checks its expectations
        // (such as `.times(1)`) at that point.
        drop(object_service);

        Ok(())
    }

    #[tokio::test]
    async fn put_url_client_error_is_not_retried() -> Result<()> {
        let upload_id = slugid::v4();
        let expires = Utc::now() + Duration::hours(1);

        let server = httptest::Server::run();
        server.expect(
            Expectation::matching(request::method_path("PUT", "/data"))
                // a retry would be a second request, violating this expectation
                .times(1)
                .respond_with(status_code(403)),
        );

        let object_service = PutUrlOnly::new(server);

        let err = upload(
            &object_service,
            upload_id.clone(),
            &expires,
            b"hello, world",
        )
        .await
        .unwrap_err();

        // the HTTP error is surfaced as-is, so its status is still available
        let status = err
            .downcast_ref::<reqwest::Error>()
            .and_then(|reqerr| reqerr.status());
        assert_eq!(status, Some(reqwest::StatusCode::FORBIDDEN));

        // the upload was created but never finished
        object_service.logger.assert(vec![
            format!("create some/object {} proj {}", expires, upload_id),
            format!("putUrl application/binary {}", 12),
        ]);

        // Dropping the service drops the httptest server, which checks its expectations
        // (such as `.times(1)`) at that point.
        drop(object_service);

        Ok(())
    }
}