1use std::io;
5use std::pin::Pin;
6use std::sync::{Arc, OnceLock};
7use std::task::Poll;
8
9use crate::object_store::ObjectStore as LanceObjectStore;
10use async_trait::async_trait;
11use bytes::Bytes;
12use futures::future::BoxFuture;
13use futures::FutureExt;
14use object_store::MultipartUpload;
15use object_store::{path::Path, Error as OSError, ObjectStore, Result as OSResult};
16use rand::Rng;
17use tokio::io::{AsyncWrite, AsyncWriteExt};
18use tokio::task::JoinSet;
19
20use lance_core::{Error, Result};
21
22use crate::traits::Writer;
23use snafu::location;
24
/// Base unit for multipart part sizes: 5 MiB.
///
/// This is the default initial part size, the smallest value accepted for
/// `LANCE_INITIAL_UPLOAD_SIZE`, and the increment by which part buffers grow
/// as more parts are uploaded.
const INITIAL_UPLOAD_STEP: usize = 5 * 1024 * 1024;
27
/// Returns the maximum number of part uploads allowed in flight at once.
///
/// The value is read from the `LANCE_UPLOAD_CONCURRENCY` environment variable
/// the first time this function is called and is cached for the rest of the
/// process. If the variable is unset or does not parse as a `usize`, the
/// default of 10 is used.
///
/// A configured value of `0` is raised to `1`. Callers gate every new part
/// upload on `in_flight < max_upload_parallelism()`, so a limit of zero would
/// never allow a part to be submitted.
fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        std::env::var("LANCE_UPLOAD_CONCURRENCY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            // Zero concurrency can never make progress; use the smallest
            // limit that can.
            .map(|limit| limit.max(1))
            .unwrap_or(10)
    })
}
37
/// Returns how many connection-reset retries are allowed.
///
/// The limit comes from the `LANCE_CONN_RESET_RETRIES` environment variable,
/// read on the first call and cached afterwards. An unset variable, or one
/// that does not parse as a `u16`, yields the default of 20.
fn max_conn_reset_retries() -> u16 {
    const DEFAULT_RETRIES: u16 = 20;
    static MAX_CONN_RESET_RETRIES: OnceLock<u16> = OnceLock::new();

    let read_from_env = || match std::env::var("LANCE_CONN_RESET_RETRIES") {
        Ok(raw) => raw.parse::<u16>().unwrap_or(DEFAULT_RETRIES),
        Err(_) => DEFAULT_RETRIES,
    };
    *MAX_CONN_RESET_RETRIES.get_or_init(read_from_env)
}
47
48fn initial_upload_size() -> usize {
49 static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
50 *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
51 std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
52 .ok()
53 .and_then(|s| s.parse::<usize>().ok())
54 .inspect(|size| {
55 if *size < INITIAL_UPLOAD_STEP {
56 panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
58 } else if *size > 1024 * 1024 * 1024 * 5 {
59 panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
61 }
62 })
63 .unwrap_or(INITIAL_UPLOAD_STEP)
64 })
65}
66
67pub struct ObjectWriter {
75 state: UploadState,
76 path: Arc<Path>,
77 cursor: usize,
78 connection_resets: u16,
79 buffer: Vec<u8>,
80 use_constant_size_upload_parts: bool,
82}
83
84enum UploadState {
85 Started(Arc<dyn ObjectStore>),
88 CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
90 InProgress {
92 part_idx: u16,
93 upload: Box<dyn MultipartUpload>,
94 futures: JoinSet<std::result::Result<(), UploadPutError>>,
95 },
96 PuttingSingle(BoxFuture<'static, OSResult<()>>),
99 Completing(BoxFuture<'static, OSResult<()>>),
101 Done,
103}
104
105impl UploadState {
107 fn started_to_completing(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
108 let this = std::mem::replace(self, Self::Done);
110 *self = match this {
111 Self::Started(store) => {
112 let fut = async move {
113 store.put(&path, buffer.into()).await?;
114 Ok(())
115 };
116 Self::PuttingSingle(Box::pin(fut))
117 }
118 _ => unreachable!(),
119 }
120 }
121
122 fn in_progress_to_completing(&mut self) {
123 let this = std::mem::replace(self, Self::Done);
125 *self = match this {
126 Self::InProgress {
127 mut upload,
128 futures,
129 ..
130 } => {
131 debug_assert!(futures.is_empty());
132 let fut = async move {
133 upload.complete().await?;
134 Ok(())
135 };
136 Self::Completing(Box::pin(fut))
137 }
138 _ => unreachable!(),
139 };
140 }
141}
142
143impl ObjectWriter {
144 pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
145 Ok(Self {
146 state: UploadState::Started(object_store.inner.clone()),
147 cursor: 0,
148 path: Arc::new(path.clone()),
149 connection_resets: 0,
150 buffer: Vec::with_capacity(initial_upload_size()),
151 use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
152 })
153 }
154
155 fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
158 let new_capacity = if constant_upload_size {
159 initial_upload_size()
161 } else {
162 initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
164 };
165 let new_buffer = Vec::with_capacity(new_capacity);
166 let part = std::mem::replace(buffer, new_buffer);
167 Bytes::from(part)
168 }
169
170 fn put_part(
171 upload: &mut dyn MultipartUpload,
172 buffer: Bytes,
173 part_idx: u16,
174 sleep: Option<std::time::Duration>,
175 ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
176 log::debug!(
177 "MultipartUpload submitting part with {} bytes",
178 buffer.len()
179 );
180 let fut = upload.put_part(buffer.clone().into());
181 Box::pin(async move {
182 if let Some(sleep) = sleep {
183 tokio::time::sleep(sleep).await;
184 }
185 fut.await.map_err(|source| UploadPutError {
186 part_idx,
187 buffer,
188 source,
189 })?;
190 Ok(())
191 })
192 }
193
194 fn poll_tasks(
195 mut self: Pin<&mut Self>,
196 cx: &mut std::task::Context<'_>,
197 ) -> std::result::Result<(), io::Error> {
198 let mut_self = &mut *self;
199 loop {
200 match &mut mut_self.state {
201 UploadState::Started(_) | UploadState::Done => break,
202 UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) {
203 Poll::Ready(Ok(mut upload)) => {
204 let mut futures = JoinSet::new();
205
206 let data = Self::next_part_buffer(
207 &mut mut_self.buffer,
208 0,
209 mut_self.use_constant_size_upload_parts,
210 );
211 futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));
212
213 mut_self.state = UploadState::InProgress {
214 part_idx: 1, futures,
216 upload,
217 };
218 }
219 Poll::Ready(Err(e)) => {
220 return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
221 }
222 Poll::Pending => break,
223 },
224 UploadState::InProgress {
225 upload, futures, ..
226 } => {
227 while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
228 match res {
229 Ok(Ok(())) => {}
230 Err(err) => {
231 return Err(std::io::Error::new(std::io::ErrorKind::Other, err))
232 }
233 Ok(Err(UploadPutError {
234 source: OSError::Generic { source, .. },
235 part_idx,
236 buffer,
237 })) if source
238 .to_string()
239 .to_lowercase()
240 .contains("connection reset by peer") =>
241 {
242 if mut_self.connection_resets < max_conn_reset_retries() {
243 mut_self.connection_resets += 1;
245
246 let sleep_time_ms = rand::thread_rng().gen_range(2_000..8_000);
248 let sleep_time =
249 std::time::Duration::from_millis(sleep_time_ms);
250
251 futures.spawn(Self::put_part(
252 upload.as_mut(),
253 buffer,
254 part_idx,
255 Some(sleep_time),
256 ));
257 } else {
258 return Err(io::Error::new(
259 io::ErrorKind::ConnectionReset,
260 Box::new(ConnectionResetError {
261 message: format!(
262 "Hit max retries ({}) for connection reset",
263 max_conn_reset_retries()
264 ),
265 source,
266 }),
267 ));
268 }
269 }
270 Ok(Err(err)) => return Err(err.source.into()),
271 }
272 }
273 break;
274 }
275 UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => {
276 match fut.poll_unpin(cx) {
277 Poll::Ready(Ok(())) => mut_self.state = UploadState::Done,
278 Poll::Ready(Err(e)) => {
279 return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
280 }
281 Poll::Pending => break,
282 }
283 }
284 }
285 }
286 Ok(())
287 }
288
289 pub async fn shutdown(&mut self) -> Result<()> {
290 AsyncWriteExt::shutdown(self).await.map_err(|e| {
291 Error::io(
292 format!("failed to shutdown object writer for {}: {}", self.path, e),
293 location!(),
295 )
296 })
297 }
298}
299
300impl Drop for ObjectWriter {
301 fn drop(&mut self) {
302 if matches!(self.state, UploadState::InProgress { .. }) {
304 let state = std::mem::replace(&mut self.state, UploadState::Done);
306 if let UploadState::InProgress { mut upload, .. } = state {
307 tokio::task::spawn(async move {
308 let _ = upload.abort().await;
309 });
310 }
311 }
312 }
313}
314
315struct UploadPutError {
319 part_idx: u16,
320 buffer: Bytes,
321 source: OSError,
322}
323
/// Error reported when part uploads keep failing with connection resets.
#[derive(Debug)]
struct ConnectionResetError {
    /// Human-readable summary of the failure.
    message: String,
    /// The underlying error of the last failed attempt.
    source: Box<dyn std::error::Error + Send + Sync>,
}

impl std::error::Error for ConnectionResetError {}

impl std::fmt::Display for ConnectionResetError {
    /// Renders the error as `<message>: <source>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message, source } = self;
        write!(f, "{message}: {source}")
    }
}
337
338impl AsyncWrite for ObjectWriter {
339 fn poll_write(
340 mut self: std::pin::Pin<&mut Self>,
341 cx: &mut std::task::Context<'_>,
342 buf: &[u8],
343 ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
344 self.as_mut().poll_tasks(cx)?;
345
346 let remaining_capacity = self.buffer.capacity() - self.buffer.len();
348 let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
349 self.buffer.extend_from_slice(&buf[..bytes_to_write]);
350 self.cursor += bytes_to_write;
351
352 let mut_self = &mut *self;
355
356 if mut_self.buffer.capacity() == mut_self.buffer.len() {
358 match &mut mut_self.state {
359 UploadState::Started(store) => {
360 let path = mut_self.path.clone();
361 let store = store.clone();
362 let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
363 self.state = UploadState::CreatingUpload(fut);
364 }
365 UploadState::InProgress {
366 upload,
367 part_idx,
368 futures,
369 ..
370 } => {
371 if futures.len() < max_upload_parallelism() {
373 let data = Self::next_part_buffer(
374 &mut mut_self.buffer,
375 *part_idx,
376 mut_self.use_constant_size_upload_parts,
377 );
378 futures.spawn(Self::put_part(upload.as_mut(), data, *part_idx, None));
379 *part_idx += 1;
380 }
381 }
382 _ => {}
383 }
384 }
385
386 self.poll_tasks(cx)?;
387
388 match bytes_to_write {
389 0 => Poll::Pending,
390 _ => Poll::Ready(Ok(bytes_to_write)),
391 }
392 }
393
394 fn poll_flush(
395 mut self: std::pin::Pin<&mut Self>,
396 cx: &mut std::task::Context<'_>,
397 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
398 self.as_mut().poll_tasks(cx)?;
399
400 match &self.state {
401 UploadState::Started(_) | UploadState::Done => Poll::Ready(Ok(())),
402 UploadState::CreatingUpload(_)
403 | UploadState::Completing(_)
404 | UploadState::PuttingSingle(_) => Poll::Pending,
405 UploadState::InProgress { futures, .. } => {
406 if futures.is_empty() {
407 Poll::Ready(Ok(()))
408 } else {
409 Poll::Pending
410 }
411 }
412 }
413 }
414
415 fn poll_shutdown(
416 mut self: std::pin::Pin<&mut Self>,
417 cx: &mut std::task::Context<'_>,
418 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
419 loop {
420 self.as_mut().poll_tasks(cx)?;
421
422 let mut_self = &mut *self;
425 match &mut mut_self.state {
426 UploadState::Done => return Poll::Ready(Ok(())),
427 UploadState::CreatingUpload(_)
428 | UploadState::PuttingSingle(_)
429 | UploadState::Completing(_) => return Poll::Pending,
430 UploadState::Started(_) => {
431 let part = std::mem::take(&mut mut_self.buffer);
433 let path = mut_self.path.clone();
434 self.state.started_to_completing(path, part);
435 }
436 UploadState::InProgress {
437 upload,
438 futures,
439 part_idx,
440 } => {
441 if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
443 let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
445 futures.spawn(Self::put_part(upload.as_mut(), data, *part_idx, None));
446 continue;
449 }
450
451 if futures.is_empty() {
453 self.state.in_progress_to_completing();
454 } else {
455 return Poll::Pending;
456 }
457 }
458 }
459 }
460 }
461}
462
463#[async_trait]
464impl Writer for ObjectWriter {
465 async fn tell(&mut self) -> Result<usize> {
466 Ok(self.cursor)
467 }
468}
469
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Writes three 256-byte chunks to an in-memory store, checking after
    /// each write that the reported position advanced by the chunk size,
    /// then shuts the writer down.
    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();
        let destination = Path::from("/foo");

        let mut writer = ObjectWriter::new(&store, &destination).await.unwrap();
        assert_eq!(writer.tell().await.unwrap(), 0);

        let chunk = vec![0; 256];
        for chunks_written in 1..=3 {
            assert_eq!(writer.write(chunk.as_slice()).await.unwrap(), 256);
            assert_eq!(writer.tell().await.unwrap(), 256 * chunks_written);
        }

        writer.shutdown().await.unwrap();
    }
}