use std::io;
use std::pin::Pin;
use std::sync::{Arc, OnceLock};
use std::task::Poll;
use crate::object_store::ObjectStore as LanceObjectStore;
use async_trait::async_trait;
use bytes::Bytes;
use futures::future::BoxFuture;
use futures::FutureExt;
use object_store::MultipartUpload;
use object_store::{path::Path, Error as OSError, ObjectStore, Result as OSResult};
use rand::Rng;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::task::JoinSet;
use lance_core::{Error, Result};
use crate::traits::Writer;
use snafu::{location, Location};
/// Default (and minimum allowed) part size: 5 MiB. `initial_upload_size()`
/// enforces this as the lower bound for the configurable part size.
const INITIAL_UPLOAD_STEP: usize = 1024 * 1024 * 5;
/// Maximum number of part uploads allowed in flight at once.
///
/// Read once from the `LANCE_UPLOAD_CONCURRENCY` environment variable and
/// cached for the lifetime of the process; unset or unparseable values fall
/// back to 10.
fn max_upload_parallelism() -> usize {
    static MAX_UPLOAD_PARALLELISM: OnceLock<usize> = OnceLock::new();
    *MAX_UPLOAD_PARALLELISM.get_or_init(|| {
        match std::env::var("LANCE_UPLOAD_CONCURRENCY") {
            Ok(raw) => raw.parse::<usize>().unwrap_or(10),
            Err(_) => 10,
        }
    })
}
/// Total number of times a part upload may be retried after a
/// "connection reset by peer" error.
///
/// Read once from the `LANCE_CONN_RESET_RETRIES` environment variable and
/// cached; unset or unparseable values fall back to 20.
fn max_conn_reset_retries() -> u16 {
    static MAX_CONN_RESET_RETRIES: OnceLock<u16> = OnceLock::new();
    *MAX_CONN_RESET_RETRIES.get_or_init(|| {
        match std::env::var("LANCE_CONN_RESET_RETRIES") {
            Ok(raw) => raw.parse::<u16>().unwrap_or(20),
            Err(_) => 20,
        }
    })
}
/// Size in bytes of the first upload part (and the capacity of the staging
/// buffer).
///
/// Read once from the `LANCE_INITIAL_UPLOAD_SIZE` environment variable and
/// cached. A configured value outside the 5 MiB ..= 5 GiB range aborts the
/// process with a panic; unset or unparseable values fall back to
/// `INITIAL_UPLOAD_STEP`.
fn initial_upload_size() -> usize {
    static LANCE_INITIAL_UPLOAD_SIZE: OnceLock<usize> = OnceLock::new();
    *LANCE_INITIAL_UPLOAD_SIZE.get_or_init(|| {
        let configured = std::env::var("LANCE_INITIAL_UPLOAD_SIZE")
            .ok()
            .and_then(|s| s.parse::<usize>().ok());
        match configured {
            Some(size) if size < INITIAL_UPLOAD_STEP => {
                panic!("LANCE_INITIAL_UPLOAD_SIZE must be at least 5MB");
            }
            Some(size) if size > 1024 * 1024 * 1024 * 5 => {
                panic!("LANCE_INITIAL_UPLOAD_SIZE must be at most 5GB");
            }
            Some(size) => size,
            None => INITIAL_UPLOAD_STEP,
        }
    })
}
/// An `AsyncWrite` adapter that streams data to an object store path,
/// using a single `put` for payloads that never fill the first buffer and
/// switching to a multipart upload once the buffer overflows.
pub struct ObjectWriter {
    // Current phase of the upload state machine; see `UploadState`.
    state: UploadState,
    // Destination path; shared with the futures spawned for uploads.
    path: Arc<Path>,
    // Total bytes accepted so far (reported by `Writer::tell`).
    cursor: usize,
    // Connection-reset retries consumed so far, compared against
    // `max_conn_reset_retries()`.
    connection_resets: u16,
    // Staging buffer for the next part; submitted when it reaches capacity.
    buffer: Vec<u8>,
    // When true, every part uses `initial_upload_size()` instead of the
    // gradually growing size computed in `next_part_buffer`.
    use_constant_size_upload_parts: bool,
}
/// State machine covering the lifetime of one upload.
enum UploadState {
    /// Nothing sent yet; holds the store so an upload can be started lazily.
    Started(Arc<dyn ObjectStore>),
    /// Waiting on the store to create a multipart upload.
    CreatingUpload(BoxFuture<'static, OSResult<Box<dyn MultipartUpload>>>),
    /// Multipart upload running; individual parts upload on the `JoinSet`.
    InProgress {
        /// Index to assign to the next submitted part.
        part_idx: u16,
        upload: Box<dyn MultipartUpload>,
        futures: JoinSet<std::result::Result<(), UploadPutError>>,
    },
    /// Small payload: one `put` of the entire buffer is in flight.
    PuttingSingle(BoxFuture<'static, OSResult<()>>),
    /// Multipart `complete()` call is in flight.
    Completing(BoxFuture<'static, OSResult<()>>),
    /// Upload finished (or abandoned).
    Done,
}
impl UploadState {
    /// Transition `Started` -> `PuttingSingle`: issue a single `put` of the
    /// whole buffer (used when the payload never overflowed the first buffer).
    ///
    /// Panics (`unreachable!`) if called in any other state.
    fn started_to_completing(&mut self, path: Arc<Path>, buffer: Vec<u8>) {
        // Temporarily park in `Done` so we can take ownership of the store.
        match std::mem::replace(self, Self::Done) {
            Self::Started(store) => {
                *self = Self::PuttingSingle(Box::pin(async move {
                    store.put(&path, buffer.into()).await?;
                    Ok(())
                }));
            }
            _ => unreachable!(),
        }
    }

    /// Transition `InProgress` -> `Completing`: finalize the multipart upload.
    /// All part futures must already have finished.
    ///
    /// Panics (`unreachable!`) if called in any other state.
    fn in_progress_to_completing(&mut self) {
        match std::mem::replace(self, Self::Done) {
            Self::InProgress {
                mut upload,
                futures,
                ..
            } => {
                debug_assert!(futures.is_empty());
                *self = Self::Completing(Box::pin(async move {
                    upload.complete().await?;
                    Ok(())
                }));
            }
            _ => unreachable!(),
        }
    }
}
impl ObjectWriter {
    /// Create a writer targeting `path` on the given store.
    ///
    /// No request is issued here: the state starts as `UploadState::Started`
    /// and the first request happens when the initial buffer fills
    /// (multipart) or at shutdown (single put).
    pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result<Self> {
        Ok(Self {
            state: UploadState::Started(object_store.inner.clone()),
            cursor: 0,
            path: Arc::new(path.clone()),
            connection_resets: 0,
            buffer: Vec::with_capacity(initial_upload_size()),
            use_constant_size_upload_parts: object_store.use_constant_size_upload_parts,
        })
    }

    /// Swap the staging buffer for a fresh one and return the old contents as
    /// the payload for part `part_idx`.
    ///
    /// When `constant_upload_size` is false the new capacity grows by one
    /// `INITIAL_UPLOAD_STEP` every 100 parts — presumably to keep very large
    /// uploads under store part-count limits (TODO confirm against the
    /// target store's limits).
    fn next_part_buffer(buffer: &mut Vec<u8>, part_idx: u16, constant_upload_size: bool) -> Bytes {
        let new_capacity = if constant_upload_size {
            initial_upload_size()
        } else {
            initial_upload_size().max(((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_STEP)
        };
        let new_buffer = Vec::with_capacity(new_capacity);
        let part = std::mem::replace(buffer, new_buffer);
        Bytes::from(part)
    }

    /// Build a future that uploads `buffer` as part `part_idx`, optionally
    /// sleeping first (used as a backoff when retrying a failed part).
    ///
    /// On failure the part index and payload travel back in `UploadPutError`
    /// so the caller can resubmit the identical part.
    fn put_part(
        upload: &mut dyn MultipartUpload,
        buffer: Bytes,
        part_idx: u16,
        sleep: Option<std::time::Duration>,
    ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> {
        log::debug!(
            "MultipartUpload submitting part with {} bytes",
            buffer.len()
        );
        // `Bytes::clone` is a refcount bump, not a copy; the clone is kept so
        // the payload can be resubmitted if this part fails.
        let fut = upload.put_part(buffer.clone().into());
        Box::pin(async move {
            if let Some(sleep) = sleep {
                tokio::time::sleep(sleep).await;
            }
            fut.await.map_err(|source| UploadPutError {
                part_idx,
                buffer,
                source,
            })?;
            Ok(())
        })
    }

    /// Drive all in-flight background work (upload creation, part uploads,
    /// completion) as far as possible without blocking.
    ///
    /// Transitions performed: `CreatingUpload` -> `InProgress` (submitting
    /// the full initial buffer as part 0), and `PuttingSingle`/`Completing`
    /// -> `Done`. Parts failing with "connection reset by peer" are
    /// resubmitted with a randomized 2–8 s delay, up to
    /// `max_conn_reset_retries()` retries total across the writer.
    fn poll_tasks(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::result::Result<(), io::Error> {
        let mut_self = &mut *self;
        loop {
            match &mut mut_self.state {
                // Nothing running yet, or everything already finished.
                UploadState::Started(_) | UploadState::Done => break,
                UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) {
                    Poll::Ready(Ok(mut upload)) => {
                        // Upload created: submit the (full) staging buffer as
                        // part 0 and start numbering subsequent parts at 1.
                        let mut futures = JoinSet::new();
                        let data = Self::next_part_buffer(
                            &mut mut_self.buffer,
                            0,
                            mut_self.use_constant_size_upload_parts,
                        );
                        futures.spawn(Self::put_part(upload.as_mut(), data, 0, None));
                        mut_self.state = UploadState::InProgress {
                            part_idx: 1,
                            futures,
                            upload,
                        };
                    }
                    Poll::Ready(Err(e)) => {
                        return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
                    }
                    Poll::Pending => break,
                },
                UploadState::InProgress {
                    upload, futures, ..
                } => {
                    // Reap every part future that has already finished.
                    while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) {
                        match res {
                            Ok(Ok(())) => {}
                            // The spawned task itself panicked or was aborted.
                            Err(err) => {
                                return Err(std::io::Error::new(std::io::ErrorKind::Other, err))
                            }
                            // Connection reset: retry this exact part (with
                            // backoff) unless the retry budget is exhausted.
                            Ok(Err(UploadPutError {
                                source: OSError::Generic { source, .. },
                                part_idx,
                                buffer,
                            })) if source
                                .to_string()
                                .to_lowercase()
                                .contains("connection reset by peer") =>
                            {
                                if mut_self.connection_resets < max_conn_reset_retries() {
                                    mut_self.connection_resets += 1;
                                    // Randomized delay so concurrent writers
                                    // don't retry in lockstep.
                                    let sleep_time_ms = rand::thread_rng().gen_range(2_000..8_000);
                                    let sleep_time =
                                        std::time::Duration::from_millis(sleep_time_ms);
                                    futures.spawn(Self::put_part(
                                        upload.as_mut(),
                                        buffer,
                                        part_idx,
                                        Some(sleep_time),
                                    ));
                                } else {
                                    return Err(io::Error::new(
                                        io::ErrorKind::ConnectionReset,
                                        Box::new(ConnectionResetError {
                                            message: format!(
                                                "Hit max retries ({}) for connection reset",
                                                max_conn_reset_retries()
                                            ),
                                            source,
                                        }),
                                    ));
                                }
                            }
                            // Any other upload error is fatal.
                            Ok(Err(err)) => return Err(err.source.into()),
                        }
                    }
                    break;
                }
                UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => {
                    match fut.poll_unpin(cx) {
                        Poll::Ready(Ok(())) => mut_self.state = UploadState::Done,
                        Poll::Ready(Err(e)) => {
                            return Err(std::io::Error::new(std::io::ErrorKind::Other, e))
                        }
                        Poll::Pending => break,
                    }
                }
            }
        }
        Ok(())
    }

    /// Flush all buffered data and finish the upload.
    ///
    /// Wraps `AsyncWriteExt::shutdown` (which drives `poll_shutdown`),
    /// converting the I/O error into a `lance_core::Error` that names the
    /// destination path.
    pub async fn shutdown(&mut self) -> Result<()> {
        AsyncWriteExt::shutdown(self).await.map_err(|e| {
            Error::io(
                format!("failed to shutdown object writer for {}: {}", self.path, e),
                location!(),
            )
        })
    }
}
impl Drop for ObjectWriter {
    /// Abort an unfinished multipart upload so the store does not accumulate
    /// orphaned parts when the writer is dropped mid-upload.
    fn drop(&mut self) {
        if matches!(self.state, UploadState::InProgress { .. }) {
            let state = std::mem::replace(&mut self.state, UploadState::Done);
            if let UploadState::InProgress { mut upload, .. } = state {
                // `tokio::task::spawn` panics when called outside a runtime,
                // and a panic inside `drop` can abort the process. Only spawn
                // the abort if a runtime is actually available; otherwise the
                // abort is skipped (best-effort cleanup, as before).
                if let Ok(handle) = tokio::runtime::Handle::try_current() {
                    handle.spawn(async move {
                        let _ = upload.abort().await;
                    });
                }
            }
        }
    }
}
/// Error from uploading a single part, carrying the part index and its
/// payload so the caller can resubmit the identical part on retry.
struct UploadPutError {
    part_idx: u16,
    buffer: Bytes,
    source: OSError,
}
/// Error reported after exhausting all retries for
/// "connection reset by peer" failures during a multipart upload.
#[derive(Debug)]
struct ConnectionResetError {
    message: String,
    source: Box<dyn std::error::Error + Send + Sync>,
}
impl std::error::Error for ConnectionResetError {
    // Expose the wrapped error so callers can walk the error chain via
    // `Error::source()` (the previous empty impl returned `None`).
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        Some(self.source.as_ref())
    }
}
impl std::fmt::Display for ConnectionResetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.message, self.source)
    }
}
impl AsyncWrite for ObjectWriter {
    /// Buffer as much of `buf` as fits in the staging buffer.
    ///
    /// When the buffer reaches capacity, either kick off the multipart
    /// upload (first fill) or, if a part slot is free, submit the buffer as
    /// the next part. Returns `Pending` only when zero bytes were accepted
    /// (buffer full and all part slots busy).
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
        // Make progress on background work first; this also surfaces any
        // upload errors before we accept more data.
        self.as_mut().poll_tasks(cx)?;
        // Fill the staging buffer without growing it past its capacity.
        let remaining_capacity = self.buffer.capacity() - self.buffer.len();
        let bytes_to_write = std::cmp::min(remaining_capacity, buf.len());
        self.buffer.extend_from_slice(&buf[..bytes_to_write]);
        self.cursor += bytes_to_write;
        let mut_self = &mut *self;
        // Buffer full: hand it off to the upload machinery.
        if mut_self.buffer.capacity() == mut_self.buffer.len() {
            match &mut mut_self.state {
                UploadState::Started(store) => {
                    // First fill — begin creating a multipart upload. The
                    // buffer itself is submitted later (in `poll_tasks`) once
                    // the upload exists.
                    let path = mut_self.path.clone();
                    let store = store.clone();
                    let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await });
                    self.state = UploadState::CreatingUpload(fut);
                }
                UploadState::InProgress {
                    upload,
                    part_idx,
                    futures,
                    ..
                } => {
                    // Submit the buffer as the next part, but only while
                    // below the parallelism cap; otherwise keep the data
                    // buffered until a slot frees up.
                    if futures.len() < max_upload_parallelism() {
                        let data = Self::next_part_buffer(
                            &mut mut_self.buffer,
                            *part_idx,
                            mut_self.use_constant_size_upload_parts,
                        );
                        futures.spawn(Self::put_part(upload.as_mut(), data, *part_idx, None));
                        *part_idx += 1;
                    }
                }
                _ => {}
            }
        }
        self.poll_tasks(cx)?;
        match bytes_to_write {
            // Nothing accepted: the futures polled in `poll_tasks` have
            // registered wakers, so we'll be re-woken when a part finishes.
            0 => Poll::Pending,
            _ => Poll::Ready(Ok(bytes_to_write)),
        }
    }

    /// Ready once no background work is in flight. Note this does NOT push
    /// buffered-but-unsubmitted data to the store; that happens in
    /// `poll_shutdown`.
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        self.as_mut().poll_tasks(cx)?;
        match &self.state {
            UploadState::Started(_) | UploadState::Done => Poll::Ready(Ok(())),
            UploadState::CreatingUpload(_)
            | UploadState::Completing(_)
            | UploadState::PuttingSingle(_) => Poll::Pending,
            UploadState::InProgress { futures, .. } => {
                if futures.is_empty() {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Pending
                }
            }
        }
    }

    /// Drain the remaining buffer and drive the upload to completion.
    ///
    /// If the multipart upload was never started, the whole buffer goes out
    /// as a single `put`; otherwise the remainder becomes the final part and
    /// the multipart upload is completed.
    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        loop {
            self.as_mut().poll_tasks(cx)?;
            let mut_self = &mut *self;
            match &mut mut_self.state {
                UploadState::Done => return Poll::Ready(Ok(())),
                // In-flight request: `poll_tasks` registered a waker.
                UploadState::CreatingUpload(_)
                | UploadState::PuttingSingle(_)
                | UploadState::Completing(_) => return Poll::Pending,
                UploadState::Started(_) => {
                    // Multipart never started — send everything as one put.
                    let part = std::mem::take(&mut mut_self.buffer);
                    let path = mut_self.path.clone();
                    self.state.started_to_completing(path, part);
                }
                UploadState::InProgress {
                    upload,
                    futures,
                    part_idx,
                } => {
                    // Submit any leftover bytes as the final (short) part.
                    if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() {
                        let data = Bytes::from(std::mem::take(&mut mut_self.buffer));
                        futures.spawn(Self::put_part(upload.as_mut(), data, *part_idx, None));
                        continue;
                    }
                    if futures.is_empty() {
                        self.state.in_progress_to_completing();
                    } else {
                        return Poll::Pending;
                    }
                }
            }
        }
    }
}
#[async_trait]
impl Writer for ObjectWriter {
    /// Return the current write position: the total number of bytes accepted
    /// by this writer so far.
    async fn tell(&mut self) -> Result<usize> {
        Ok(self.cursor)
    }
}
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    // Small writes never fill the initial buffer, so this exercises the
    // single-put path: data is flushed to the store only at shutdown.
    #[tokio::test]
    async fn test_write() {
        let store = LanceObjectStore::memory();
        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
            .await
            .unwrap();
        assert_eq!(object_writer.tell().await.unwrap(), 0);
        let buf = vec![0; 256];
        // Each write is accepted in full and advances the cursor.
        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256);
        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 512);
        assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256);
        assert_eq!(object_writer.tell().await.unwrap(), 256 * 3);
        object_writer.shutdown().await.unwrap();
    }
}