lance_io/
object_writer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::io;
5use std::pin::Pin;
6use std::sync::{Arc, OnceLock};
7use std::task::Poll;
8
9use crate::object_store::ObjectStore as LanceObjectStore;
10use async_trait::async_trait;
11use bytes::Bytes;
12use futures::future::BoxFuture;
13use futures::FutureExt;
14use object_store::MultipartUpload;
15use object_store::{path::Path, Error as OSError, ObjectStore, Result as OSResult};
16use rand::Rng;
17use tokio::io::{AsyncWrite, AsyncWriteExt};
18use tokio::task::JoinSet;
19
20use lance_core::{Error, Result};
21
22use crate::traits::Writer;
23use snafu::location;
24
/// Base part size for multipart uploads: 5MB, the minimum part size in GCS and
/// S3. It is the default for `initial_upload_size()` and the step by which
/// growing part sizes increase.
const INITIAL_UPLOAD_STEP: usize = 1024 * 1024 * 5;

/// Maximum number of part uploads a writer keeps in flight at once.
///
/// Read once from the `LANCE_UPLOAD_CONCURRENCY` environment variable and
/// cached for the life of the process; unset or unparsable values fall back
/// to 10. A value of 0 is clamped to 1, because a limit of zero would leave
/// the writer unable to dispatch any further part and it would stall.
fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        std::env::var("LANCE_UPLOAD_CONCURRENCY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(10)
            .max(1)
    })
}

/// Maximum number of "connection reset by peer" part-upload failures a
/// writer retries before giving up.
///
/// Read once from `LANCE_CONN_RESET_RETRIES` and cached; unset or unparsable
/// values fall back to 20. A value of 0 disables retries.
fn max_conn_reset_retries() -> u16 {
    static MAX_CONN_RESET_RETRIES: OnceLock<u16> = OnceLock::new();
    *MAX_CONN_RESET_RETRIES.get_or_init(|| {
        std::env::var("LANCE_CONN_RESET_RETRIES")
            .ok()
            .and_then(|s| s.parse::<u16>().ok())
            .unwrap_or(20)
    })
}

/// Capacity of the first part buffer, in bytes.
///
/// Read once from `LANCE_INITIAL_UPLOAD_SIZE` and cached; unset or unparsable
/// values fall back to `INITIAL_UPLOAD_STEP` (5MB).
///
/// # Panics
/// Panics if the variable is set to a value below 5MB or above 5GB, the part
/// size limits of GCS and S3.
fn initial_upload_size() -> usize {
    // 5GB kept as u64: evaluating this constant as `usize` overflows (a hard
    // compile error) on 32-bit targets.
    const MAX_PART_SIZE: u64 = 5 * 1024 * 1024 * 1024;
    static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
    *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
        std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .inspect(|size| {
                if *size < INITIAL_UPLOAD_STEP {
                    // Minimum part size in GCS and S3
                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
                } else if (*size as u64) > MAX_PART_SIZE {
                    // Maximum part size in GCS and S3 (usize -> u64 is lossless).
                    panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
                }
            })
            .unwrap_or(INITIAL_UPLOAD_STEP)
    })
}
66
/// Writer to an object in an object store.
///
/// If the object is small enough, the writer will upload the object in a single
/// PUT request. If the object is larger, the writer will create a multipart
/// upload and upload parts in parallel.
///
/// This implements the `AsyncWrite` trait.
pub struct ObjectWriter {
    /// Current phase of the upload state machine.
    state: UploadState,
    /// Destination path of the object.
    path: Arc<Path>,
    /// Total number of bytes accepted by `poll_write`; reported by `tell`.
    cursor: usize,
    /// Connection-reset retries performed so far. This counter is shared by
    /// all parts of this writer and capped by `max_conn_reset_retries()`.
    connection_resets: u16,
    /// Bytes not yet handed to an upload request. Its capacity is the size of
    /// the part being accumulated: once it is full, the contents are dispatched.
    buffer: Vec<u8>,
    /// When true, every part buffer gets `initial_upload_size()` capacity
    /// instead of growing with the part index, for stores that do not support
    /// variable part sizes (presumably e.g. Cloudflare R2 — TODO confirm).
    use_constant_size_upload_parts: bool,
}
83
/// Phases of an `ObjectWriter` upload.
///
/// Transitions: `Started` -> `CreatingUpload` (buffer filled) -> `InProgress`
/// -> `Completing` -> `Done`; or `Started` -> `PuttingSingle` -> `Done` when
/// the writer is shut down before the first buffer fills.
enum UploadState {
    /// The writer has been opened but no request has been sent yet; written
    /// bytes are only buffered. Will be in this state until the buffer is full
    /// or the writer is shut down. Holds the store to issue requests against.
    Started(Arc<dyn ObjectStore>),
    /// The writer is in the process of creating a multipart upload.
    CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
    /// The writer is in the process of uploading parts.
    InProgress {
        /// Index that will be assigned to the next part submitted.
        part_idx: u16,
        /// Handle to the open multipart upload.
        upload: Box<dyn MultipartUpload>,
        /// Part uploads that have been spawned and not yet reaped.
        futures: JoinSet<std::result::Result<(), UploadPutError>>,
    },
    /// The writer is in the process of uploading data in a single PUT request.
    /// This happens when shutdown is called before the buffer is full.
    PuttingSingle(BoxFuture<'static, OSResult<()>>),
    /// The writer is in the process of completing the multipart upload, after
    /// every part upload has finished.
    Completing(BoxFuture<'static, OSResult<()>>),
    /// The writer has been shut down and all data has been written.
    Done,
}
104
105/// Methods for state transitions.
106impl UploadState {
107    fn started_to_completing(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
108        // To get owned self, we temporarily swap with Done.
109        let this = std::mem::replace(self, Self::Done);
110        *self = match this {
111            Self::Started(store) => {
112                let fut = async move {
113                    store.put(&path, buffer.into()).await?;
114                    Ok(())
115                };
116                Self::PuttingSingle(Box::pin(fut))
117            }
118            _ => unreachable!(),
119        }
120    }
121
122    fn in_progress_to_completing(&mut self) {
123        // To get owned self, we temporarily swap with Done.
124        let this = std::mem::replace(self, Self::Done);
125        *self = match this {
126            Self::InProgress {
127                mut upload,
128                futures,
129                ..
130            } => {
131                debug_assert!(futures.is_empty());
132                let fut = async move {
133                    upload.complete().await?;
134                    Ok(())
135                };
136                Self::Completing(Box::pin(fut))
137            }
138            _ => unreachable!(),
139        };
140    }
141}
142
impl ObjectWriter {
    /// Creates a writer that will store its output at `path` in `object_store`.
    ///
    /// Nothing is sent to the store yet: the writer starts in the `Started`
    /// state with an empty buffer whose capacity is `initial_upload_size()`.
    /// The part-sizing policy is copied from the store's
    /// `use_constant_size_upload_parts` flag. The body has no failure path, so
    /// this always returns `Ok`.
    pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
        Ok(Self {
            state: UploadState::Started(object_store.inner.clone()),
            cursor: 0,
            path: Arc::new(path.clone()),
            connection_resets: 0,
            buffer: Vec::with_capacity(initial_upload_size()),
            use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
        })
    }

    /// Returns the contents of `buffer` as a `Bytes` object and resets `buffer`.
    /// The new capacity of `buffer` is determined by the current part index.
    ///
    /// * `part_idx` - index of the part being taken out of `buffer`.
    /// * `constant_upload_size` - when true, the new buffer always gets
    ///   `initial_upload_size()` capacity regardless of `part_idx`.
    fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
        let new_capacity = if constant_upload_size {
            // The store does not support variable part sizes, so use the initial size.
            initial_upload_size()
        } else {
            // Grow the part size by one INITIAL_UPLOAD_STEP every 100 parts (never
            // below initial_upload_size()). Parts reach ~500MB by part 10,000, so
            // assuming a 10,000-part limit (as on S3) the total object size tops
            // out around 2.5TB.
            initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
        };
        let new_buffer = Vec::with_capacity(new_capacity);
        let part = std::mem::replace(buffer, new_buffer);
        Bytes::from(part)
    }

    /// Starts uploading `buffer` as a part of `upload` and returns a future
    /// that resolves once that part's upload finishes.
    ///
    /// `upload.put_part` is called right away, when this function runs; the
    /// optional `sleep` (used to add jitter before a retry) only delays
    /// awaiting its result. On failure the error carries `part_idx` and the
    /// buffer so the caller can resubmit the same bytes.
    ///
    /// NOTE(review): `part_idx` is not passed to `upload.put_part`; it is only
    /// threaded through for retries. If object_store assigns part numbers in
    /// the order `put_part` is called, a resubmitted part would be placed after
    /// every part submitted before the retry — confirm against the
    /// `MultipartUpload` documentation.
    fn put_part(
        upload: &mut dyn MultipartUpload,
        buffer: Bytes,
        part_idx: u16,
        sleep: Option<std::time::Duration>,
    ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
        log::debug!(
            "MultipartUpload submitting part with {} bytes",
            buffer.len()
        );
        // `Bytes` clones are cheap (refcounted); the original is kept for retries.
        let fut = upload.put_part(buffer.clone().into());
        Box::pin(async move {
            if let Some(sleep) = sleep {
                tokio::time::sleep(sleep).await;
            }
            fut.await.map_err(|source| UploadPutError {
                part_idx,
                buffer,
                source,
            })?;
            Ok(())
        })
    }

    /// Advances whatever asynchronous work the current state owns without
    /// blocking, polling any pending work with `cx`.
    ///
    /// * `CreatingUpload`: once the multipart upload exists, the current buffer
    ///   is spawned as part 0 and the state becomes `InProgress`.
    /// * `InProgress`: reaps finished part uploads. A part that failed with a
    ///   "connection reset by peer" `Generic` error is resubmitted after a
    ///   random 2-8s delay, up to `max_conn_reset_retries()` times in total for
    ///   this writer; any other failure is returned.
    /// * `PuttingSingle` / `Completing`: moves to `Done` when the request finishes.
    /// * `Started` / `Done`: nothing to do.
    ///
    /// The loop re-examines the state after each transition, so several steps
    /// can complete in one call; it exits when nothing more can progress.
    ///
    /// # Errors
    /// Returns an `io::Error` for the first unrecoverable failure: upload
    /// creation, a part upload or its task, or the final PUT/complete request.
    fn poll_tasks(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::result::Result<(), io::Error> {
        let mut_self = &mut *self;
        loop {
            match &mut mut_self.state {
                UploadState::Started(_) | UploadState::Done => break,
                UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) {
                    Poll::Ready(Ok(mut upload)) => {
                        let mut futures = JoinSet::new();

                        // The buffer that triggered creation of the upload is full:
                        // submit it as the first part.
                        let data = Self::next_part_buffer(
                            &mut mut_self.buffer,
                            0,
                            mut_self.use_constant_size_upload_parts,
                        );
                        futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));

                        mut_self.state = UploadState::InProgress {
                            part_idx: 1, // We just used 0
                            futures,
                            upload,
                        };
                    }
                    Poll::Ready(Err(e)) => {
                        return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
                    }
                    Poll::Pending => break,
                },
                UploadState::InProgress {
                    upload, futures, ..
                } => {
                    // Reap every part upload that has already finished.
                    while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
                        match res {
                            Ok(Ok(())) => {}
                            // The spawned task itself failed (JoinError), not the request.
                            Err(err) => {
                                return Err(std::io::Error::new(std::io::ErrorKind::Other, err))
                            }
                            // Arm order matters: this guarded arm must come before the
                            // catch-all `Ok(Err(err))` arm below.
                            Ok(Err(UploadPutError {
                                source: OSError::Generic { source, .. },
                                part_idx,
                                buffer,
                            })) if source
                                .to_string()
                                .to_lowercase()
                                .contains("connection reset by peer") =>
                            {
                                if mut_self.connection_resets < max_conn_reset_retries() {
                                    // Retry, but only up to max_conn_reset_retries of them.
                                    mut_self.connection_resets += 1;

                                    // Resubmit with random jitter
                                    let sleep_time_ms = rand::thread_rng().gen_range(2_000..8_000);
                                    let sleep_time =
                                        std::time::Duration::from_millis(sleep_time_ms);

                                    futures.spawn(Self::put_part(
                                        upload.as_mut(),
                                        buffer,
                                        part_idx,
                                        Some(sleep_time),
                                    ));
                                } else {
                                    return Err(io::Error::new(
                                        io::ErrorKind::ConnectionReset,
                                        Box::new(ConnectionResetError {
                                            message: format!(
                                                "Hit max retries ({}) for connection reset",
                                                max_conn_reset_retries()
                                            ),
                                            source,
                                        }),
                                    ));
                                }
                            }
                            Ok(Err(err)) => return Err(err.source.into()),
                        }
                    }
                    // Remaining parts are still running; their completion wakes `cx`.
                    break;
                }
                UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => {
                    match fut.poll_unpin(cx) {
                        Poll::Ready(Ok(())) => mut_self.state = UploadState::Done,
                        Poll::Ready(Err(e)) => {
                            return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
                        }
                        Poll::Pending => break,
                    }
                }
            }
        }
        Ok(())
    }

    /// Uploads any buffered data and finishes the object (single PUT or
    /// multipart completion) by driving `AsyncWriteExt::shutdown` to completion.
    ///
    /// # Errors
    /// Any I/O error is converted into a lance `Error::io` whose message names
    /// the object's path and includes the underlying error's text.
    pub async fn shutdown(&mut self) -> Result<()> {
        AsyncWriteExt::shutdown(self).await.map_err(|e| {
            Error::io(
                format!("failed to shutdown object writer for {}: {}", self.path, e),
                // The underlying io::Error is preserved only as text in the message.
                location!(),
            )
        })
    }
}
299
300impl Drop for ObjectWriter {
301    fn drop(&mut self) {
302        // If there is a multipart upload started but not finished, we should abort it.
303        if matches!(self.state, UploadState::InProgress { .. }) {
304            // Take ownership of the state.
305            let state = std::mem::replace(&mut self.state, UploadState::Done);
306            if let UploadState::InProgress { mut upload, .. } = state {
307                tokio::task::spawn(async move {
308                    let _ = upload.abort().await;
309                });
310            }
311        }
312    }
313}
314
/// Error produced by a single part-upload future (see `ObjectWriter::put_part`).
///
/// Besides the store error, it carries the part's index and payload, so the
/// retry logic in `poll_tasks` can resubmit the same bytes.
struct UploadPutError {
    /// Index the writer associated with this part when it was submitted.
    part_idx: u16,
    /// The part's payload, kept so it can be re-uploaded on retry.
    buffer: Bytes,
    /// The error returned by the object store.
    source: OSError,
}
323
/// Error reported once a writer has exhausted its retries for part uploads
/// that failed with "connection reset by peer".
///
/// `Display` renders `"{message}: {source}"`. The underlying error is also
/// exposed through `std::error::Error::source`, so callers can walk or
/// downcast the error chain.
#[derive(Debug)]
struct ConnectionResetError {
    /// Human-readable summary, e.g. which retry limit was hit.
    message: String,
    /// Inner error of the object store's `Generic` error for the failed part.
    source: Box<dyn std::error::Error + Send + Sync>,
}

impl std::error::Error for ConnectionResetError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Drop the `Send + Sync` bounds to match the trait's return type.
        let source: &(dyn std::error::Error + 'static) = &*self.source;
        Some(source)
    }
}

impl std::fmt::Display for ConnectionResetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.message, self.source)
    }
}
337
338impl AsyncWrite for ObjectWriter {
339    fn poll_write(
340        mut self: std::pin::Pin<&mut Self>,
341        cx: &mut std::task::Context<'_>,
342        buf: &[u8],
343    ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
344        self.as_mut().poll_tasks(cx)?;
345
346        // Fill buffer up to remaining capacity.
347        let remaining_capacity = self.buffer.capacity() - self.buffer.len();
348        let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
349        self.buffer.extend_from_slice(&buf[..bytes_to_write]);
350        self.cursor += bytes_to_write;
351
352        // Rust needs a little help to borrow self mutably and immutably at the same time
353        // through a Pin.
354        let mut_self = &mut *self;
355
356        // Instantiate next request, if available.
357        if mut_self.buffer.capacity() == mut_self.buffer.len() {
358            match &mut mut_self.state {
359                UploadState::Started(store) => {
360                    let path = mut_self.path.clone();
361                    let store = store.clone();
362                    let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
363                    self.state = UploadState::CreatingUpload(fut);
364                }
365                UploadState::InProgress {
366                    upload,
367                    part_idx,
368                    futures,
369                    ..
370                } => {
371                    // TODO: Make max concurrency configurable from storage options.
372                    if futures.len() < max_upload_parallelism() {
373                        let data = Self::next_part_buffer(
374                            &mut mut_self.buffer,
375                            *part_idx,
376                            mut_self.use_constant_size_upload_parts,
377                        );
378                        futures.spawn(Self::put_part(upload.as_mut(), data, *part_idx, None));
379                        *part_idx += 1;
380                    }
381                }
382                _ => {}
383            }
384        }
385
386        self.poll_tasks(cx)?;
387
388        match bytes_to_write {
389            0 => Poll::Pending,
390            _ => Poll::Ready(Ok(bytes_to_write)),
391        }
392    }
393
394    fn poll_flush(
395        mut self: std::pin::Pin<&mut Self>,
396        cx: &mut std::task::Context<'_>,
397    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
398        self.as_mut().poll_tasks(cx)?;
399
400        match &self.state {
401            UploadState::Started(_) | UploadState::Done => Poll::Ready(Ok(())),
402            UploadState::CreatingUpload(_)
403            | UploadState::Completing(_)
404            | UploadState::PuttingSingle(_) => Poll::Pending,
405            UploadState::InProgress { futures, .. } => {
406                if futures.is_empty() {
407                    Poll::Ready(Ok(()))
408                } else {
409                    Poll::Pending
410                }
411            }
412        }
413    }
414
415    fn poll_shutdown(
416        mut self: std::pin::Pin<&mut Self>,
417        cx: &mut std::task::Context<'_>,
418    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
419        loop {
420            self.as_mut().poll_tasks(cx)?;
421
422            // Rust needs a little help to borrow self mutably and immutably at the same time
423            // through a Pin.
424            let mut_self = &mut *self;
425            match &mut mut_self.state {
426                UploadState::Done => return Poll::Ready(Ok(())),
427                UploadState::CreatingUpload(_)
428                | UploadState::PuttingSingle(_)
429                | UploadState::Completing(_) => return Poll::Pending,
430                UploadState::Started(_) => {
431                    // If we didn't start a multipart upload, we can just do a single put.
432                    let part = std::mem::take(&mut mut_self.buffer);
433                    let path = mut_self.path.clone();
434                    self.state.started_to_completing(path, part);
435                }
436                UploadState::InProgress {
437                    upload,
438                    futures,
439                    part_idx,
440                } => {
441                    // Flush final batch
442                    if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
443                        // We can just use `take` since we don't need the buffer anymore.
444                        let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
445                        futures.spawn(Self::put_part(upload.as_mut(), data, *part_idx, None));
446                        // We need to go back to beginning of loop to poll the
447                        // new feature and get the waker registered on the ctx.
448                        continue;
449                    }
450
451                    // We handle the transition from in progress to completing here.
452                    if futures.is_empty() {
453                        self.state.in_progress_to_completing();
454                    } else {
455                        return Poll::Pending;
456                    }
457                }
458            }
459        }
460    }
461}
462
463#[async_trait]
464impl Writer for ObjectWriter {
465    async fn tell(&mut self) -> Result<usize> {
466        Ok(self.cursor)
467    }
468}
469
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Small writes are buffered, `tell` tracks every accepted byte, and a
    /// shutdown afterwards succeeds against the in-memory store.
    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();

        let mut writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        assert_eq!(writer.tell().await.unwrap(), 0);

        let chunk = vec![0; 256];
        for round in 1..=3 {
            assert_eq!(writer.write(chunk.as_slice()).await.unwrap(), 256);
            assert_eq!(writer.tell().await.unwrap(), 256 * round);
        }

        writer.shutdown().await.unwrap();
    }
}