lance_io/object_store.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Extend [object_store::ObjectStore] functionality

use std::collections::HashMap;
use std::ops::Range;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use async_trait::async_trait;
use aws_config::default_provider::credentials::DefaultCredentialsChain;
use aws_credential_types::provider::ProvideCredentials;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use deepsize::DeepSizeOf;
use futures::{future, stream::BoxStream, StreamExt, TryStreamExt};
use lance_core::utils::parse::str_is_truthy;
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use object_store::aws::{
    AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential, AwsCredentialProvider,
};
use object_store::gcp::{GcpCredential, GoogleCloudStorageBuilder};
use object_store::{
    aws::AmazonS3Builder, azure::AzureConfigKey, gcp::GoogleConfigKey, local::LocalFileSystem,
    memory::InMemory, CredentialProvider, Error as ObjectStoreError, Result as ObjectStoreResult,
};
use object_store::{
    parse_url_opts, ClientOptions, DynObjectStore, RetryConfig, StaticCredentialProvider,
};
use object_store::{path::Path, ObjectMeta, ObjectStore as OSObjectStore};
use shellexpand::tilde;
use snafu::location;
use tokio::io::AsyncWriteExt;
use tokio::sync::RwLock;
use url::Url;

use super::local::LocalObjectReader;
mod tracing;
use self::tracing::ObjectStoreTracingExt;
use crate::{object_reader::CloudObjectReader, object_writer::ObjectWriter, traits::Reader};
use lance_core::{Error, Result};

// Local disks tend to do fine with a few threads.
// Note: the number of threads here also impacts the number of files
// we need to read in some situations, so keeping this at 8 keeps the
// RAM usage of our scanner down.
pub const DEFAULT_LOCAL_IO_PARALLELISM: usize = 8;
// Cloud disks often need many threads to saturate the network.
pub const DEFAULT_CLOUD_IO_PARALLELISM: usize = 64;

pub const DEFAULT_DOWNLOAD_RETRY_COUNT: usize = 3;

#[async_trait]
pub trait ObjectStoreExt {
    /// Returns true if the file exists.
    async fn exists(&self, path: &Path) -> Result<bool>;

    /// Read all files recursively, starting from the base directory.
    ///
    /// `unmodified_since` can be specified to only return files that have not been modified since the given time.
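    ///
    /// # Examples
    ///
    /// A minimal sketch (hypothetical `store` and `dir` values; assumes the
    /// `futures::TryStreamExt` import above for `try_next`):
    ///
    /// ```ignore
    /// let cutoff = Utc::now() - chrono::Duration::hours(1);
    /// let mut files = store.read_dir_all(&dir, Some(cutoff)).await?;
    /// while let Some(meta) = files.try_next().await? {
    ///     println!("{} ({} bytes)", meta.location, meta.size);
    /// }
    /// ```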
    async fn read_dir_all<'a>(
        &'a self,
        dir_path: impl Into<&Path> + Send,
        unmodified_since: Option<DateTime<Utc>>,
    ) -> Result<BoxStream<'a, Result<ObjectMeta>>>;
}

#[async_trait]
impl<O: OSObjectStore + ?Sized> ObjectStoreExt for O {
    async fn read_dir_all<'a>(
        &'a self,
        dir_path: impl Into<&Path> + Send,
        unmodified_since: Option<DateTime<Utc>>,
    ) -> Result<BoxStream<'a, Result<ObjectMeta>>> {
        let mut output = self.list(Some(dir_path.into()));
        if let Some(unmodified_since_val) = unmodified_since {
            output = output
                .try_filter(move |file| future::ready(file.last_modified < unmodified_since_val))
                .boxed();
        }
        Ok(output.map_err(|e| e.into()).boxed())
    }

    async fn exists(&self, path: &Path) -> Result<bool> {
        match self.head(path).await {
            Ok(_) => Ok(true),
            Err(object_store::Error::NotFound { path: _, source: _ }) => Ok(false),
            Err(e) => Err(e.into()),
        }
    }
}

/// Wraps [ObjectStore](object_store::ObjectStore)
#[derive(Debug, Clone)]
pub struct ObjectStore {
    // Inner object store
    pub inner: Arc<dyn OSObjectStore>,
    scheme: String,
    block_size: usize,
    /// Whether to use constant size upload parts for multipart uploads. This
    /// is only necessary for Cloudflare R2.
    pub use_constant_size_upload_parts: bool,
    /// Whether we can assume that the list of files is lexically ordered. This
    /// is true for object stores, but not for local filesystems.
    pub list_is_lexically_ordered: bool,
    io_parallelism: usize,
    /// Number of times to retry a failed download
    download_retry_count: usize,
}

impl DeepSizeOf for ObjectStore {
    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
        // We aren't counting `inner` here, which is problematic, but an ObjectStore
        // shouldn't be too big.  The only exception might be the write cache, but if
        // the write cache has data, it means we're using it somewhere else that isn't
        // a cache, and so that doesn't really count.
        self.scheme.deep_size_of_children(context) + self.block_size.deep_size_of_children(context)
    }
}

impl std::fmt::Display for ObjectStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "ObjectStore({})", self.scheme)
    }
}

pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send {
    fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result<ObjectStore>;
}

#[derive(Default, Debug)]
pub struct ObjectStoreRegistry {
    providers: HashMap<String, Arc<dyn ObjectStoreProvider>>,
}

impl ObjectStoreRegistry {
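    /// Register a provider for a custom URI scheme.
    ///
    /// A minimal sketch (hypothetical `MyProvider` implementing
    /// [ObjectStoreProvider]):
    ///
    /// ```ignore
    /// let mut registry = ObjectStoreRegistry::default();
    /// registry.insert("mystore", Arc::new(MyProvider::default()));
    /// ```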
    pub fn insert(&mut self, scheme: &str, provider: Arc<dyn ObjectStoreProvider>) {
        self.providers.insert(scheme.into(), provider);
    }
}

const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";

/// Adapt an AWS SDK credential provider into an object_store credential provider
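///
/// A minimal sketch wiring the AWS default chain into object_store (assumes
/// the `DefaultCredentialsChain` import above):
///
/// ```ignore
/// let chain = DefaultCredentialsChain::builder().build().await;
/// let provider = AwsCredentialAdapter::new(Arc::new(chain), Duration::from_secs(60));
/// ```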
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    pub inner: Arc<dyn ProvideCredentials>,

    // RefCell can't be shared across threads, so we use HashMap
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}

impl AwsCredentialAdapter {
    pub fn new(
        provider: Arc<dyn ProvideCredentials>,
        credentials_refresh_offset: Duration,
    ) -> Self {
        Self {
            inner: provider,
            cache: Arc::new(RwLock::new(HashMap::new())),
            credentials_refresh_offset,
        }
    }
}

#[async_trait]
impl CredentialProvider for AwsCredentialAdapter {
    type Credential = ObjectStoreAwsCredential;

    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
        let cached_creds = {
            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
            let expired = cache_value
                .clone()
                .map(|cred| {
                    cred.expiry()
                        .map(|exp| {
                            exp.checked_sub(self.credentials_refresh_offset)
                                .expect("this time should always be valid")
                                < SystemTime::now()
                        })
                        // credentials with no expiry never expire
                        .unwrap_or(false)
                })
                .unwrap_or(true); // a missing credential is treated as expired
            if expired {
                None
            } else {
                cache_value.clone()
            }
        };

        if let Some(creds) = cached_creds {
            Ok(Arc::new(Self::Credential {
                key_id: creds.access_key_id().to_string(),
                secret_key: creds.secret_access_key().to_string(),
                token: creds.session_token().map(|s| s.to_string()),
            }))
        } else {
            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
                |e| Error::Internal {
                    message: format!("Failed to get AWS credentials: {}", e),
                    location: location!(),
                },
            )?);

            self.cache
                .write()
                .await
                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());

            Ok(Arc::new(Self::Credential {
                key_id: refreshed_creds.access_key_id().to_string(),
                secret_key: refreshed_creds.secret_access_key().to_string(),
                token: refreshed_creds.session_token().map(|s| s.to_string()),
            }))
        }
    }
}

/// Figure out the S3 region of the bucket.
///
/// This resolves, in order of precedence:
/// 1. The region provided in the storage options
/// 2. If the endpoint is not set, the region returned by the S3 API for the bucket
///
/// It can return None if no region is provided and the endpoint is set.
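///
/// A minimal sketch (the bucket name is an assumption; resolving via the S3
/// API requires network access):
///
/// ```ignore
/// let url = Url::parse("s3://my-bucket/my-table.lance")?;
/// let region = resolve_s3_region(&url, &HashMap::new()).await?;
/// ```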
async fn resolve_s3_region(
    url: &Url,
    storage_options: &HashMap<AmazonS3ConfigKey, String>,
) -> Result<Option<String>> {
    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
        Ok(Some(region.clone()))
    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
        // If no endpoint is set, we can assume this is AWS S3 and the region
        // can be resolved from the bucket.
        let bucket = url.host_str().ok_or_else(|| {
            Error::invalid_input(
                format!("Could not parse bucket from url: {}", url),
                location!(),
            )
        })?;

        let mut client_options = ClientOptions::default();
        for (key, value) in storage_options {
            if let AmazonS3ConfigKey::Client(client_key) = key {
                client_options = client_options.with_config(*client_key, value.clone());
            }
        }

        let bucket_region =
            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
        Ok(Some(bucket_region))
    } else {
        Ok(None)
    }
}

/// Build AWS credentials
///
/// This resolves credentials from the following sources, in order:
/// 1. An explicit `credentials` provider
/// 2. Explicit credentials in storage_options (as in `aws_access_key_id`,
///    `aws_secret_access_key`, `aws_session_token`)
/// 3. The default credential provider chain from the AWS SDK.
///
/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
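///
/// A minimal sketch using only the default provider chain (no explicit
/// provider, no static keys, no region hint):
///
/// ```ignore
/// let (provider, region) =
///     build_aws_credential(Duration::from_secs(60), None, None, None).await?;
/// ```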
pub async fn build_aws_credential(
    credentials_refresh_offset: Duration,
    credentials: Option<AwsCredentialProvider>,
    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
    region: Option<String>,
) -> Result<(AwsCredentialProvider, String)> {
    // TODO: make this return no credential provider when not using AWS
    use aws_config::meta::region::RegionProviderChain;
    const DEFAULT_REGION: &str = "us-west-2";

    let region = if let Some(region) = region {
        region
    } else {
        RegionProviderChain::default_provider()
            .or_else(DEFAULT_REGION)
            .region()
            .await
            .map(|r| r.as_ref().to_string())
            .unwrap_or(DEFAULT_REGION.to_string())
    };

    if let Some(creds) = credentials {
        Ok((creds, region))
    } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
        Ok((Arc::new(creds), region))
    } else {
        let credentials_provider = DefaultCredentialsChain::builder().build().await;

        Ok((
            Arc::new(AwsCredentialAdapter::new(
                Arc::new(credentials_provider),
                credentials_refresh_offset,
            )),
            region,
        ))
    }
}

fn extract_static_s3_credentials(
    options: &HashMap<AmazonS3ConfigKey, String>,
) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
    let key_id = options
        .get(&AmazonS3ConfigKey::AccessKeyId)
        .map(|s| s.to_string());
    let secret_key = options
        .get(&AmazonS3ConfigKey::SecretAccessKey)
        .map(|s| s.to_string());
    let token = options
        .get(&AmazonS3ConfigKey::Token)
        .map(|s| s.to_string());
    match (key_id, secret_key, token) {
        (Some(key_id), Some(secret_key), token) => {
            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
                key_id,
                secret_key,
                token,
            }))
        }
        _ => None,
    }
}

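/// Wrap an inner [OSObjectStore], layering on extra behavior such as tracing
/// or caching.
///
/// A minimal pass-through sketch (hypothetical `NoopWrapper`):
///
/// ```ignore
/// #[derive(Debug)]
/// struct NoopWrapper;
///
/// impl WrappingObjectStore for NoopWrapper {
///     fn wrap(&self, original: Arc<dyn OSObjectStore>) -> Arc<dyn OSObjectStore> {
///         original
///     }
/// }
/// ```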
pub trait WrappingObjectStore: std::fmt::Debug + Send + Sync {
    fn wrap(&self, original: Arc<dyn OSObjectStore>) -> Arc<dyn OSObjectStore>;
}

/// Parameters to create an [ObjectStore]
#[derive(Debug, Clone)]
pub struct ObjectStoreParams {
    pub block_size: Option<usize>,
    pub object_store: Option<(Arc<DynObjectStore>, Url)>,
    pub s3_credentials_refresh_offset: Duration,
    pub aws_credentials: Option<AwsCredentialProvider>,
    pub object_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
    pub storage_options: Option<HashMap<String, String>>,
    /// Use constant size upload parts for multipart uploads. Only necessary
    /// for Cloudflare R2, which doesn't support variable size parts. When this
    /// is false, max upload size is 2.5TB. When this is true, the max size is
    /// 50GB.
    pub use_constant_size_upload_parts: bool,
    pub list_is_lexically_ordered: Option<bool>,
}

impl Default for ObjectStoreParams {
    fn default() -> Self {
        Self {
            object_store: None,
            block_size: None,
            s3_credentials_refresh_offset: Duration::from_secs(60),
            aws_credentials: None,
            object_store_wrapper: None,
            storage_options: None,
            use_constant_size_upload_parts: false,
            list_is_lexically_ordered: None,
        }
    }
}

impl ObjectStoreParams {
    /// Create a new instance of [`ObjectStoreParams`] based on AWS credentials.
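    ///
    /// A minimal sketch pinning the region while deferring credentials to the
    /// default chain (the region value is an assumption):
    ///
    /// ```ignore
    /// let params = ObjectStoreParams::with_aws_credentials(None, Some("us-east-1".into()));
    /// ```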
    pub fn with_aws_credentials(
        aws_credentials: Option<AwsCredentialProvider>,
        region: Option<String>,
    ) -> Self {
        Self {
            aws_credentials,
            storage_options: region
                .map(|region| [("region".into(), region)].iter().cloned().collect()),
            ..Default::default()
        }
    }
}

impl ObjectStore {
    /// Parse from a string URI.
    ///
    /// Returns the ObjectStore instance and the absolute path to the object.
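    ///
    /// A minimal sketch (assumes the local path exists):
    ///
    /// ```ignore
    /// let (store, path) = ObjectStore::from_uri("/tmp/my-table.lance").await?;
    /// assert!(store.is_local());
    /// ```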
    pub async fn from_uri(uri: &str) -> Result<(Self, Path)> {
        let registry = Arc::new(ObjectStoreRegistry::default());

        Self::from_uri_and_params(registry, uri, &ObjectStoreParams::default()).await
    }

    /// Parse from a string URI.
    ///
    /// Returns the ObjectStore instance and the absolute path to the object.
    pub async fn from_uri_and_params(
        registry: Arc<ObjectStoreRegistry>,
        uri: &str,
        params: &ObjectStoreParams,
    ) -> Result<(Self, Path)> {
        if let Some((store, path)) = params.object_store.as_ref() {
            let mut inner = store.clone();
            if let Some(wrapper) = params.object_store_wrapper.as_ref() {
                inner = wrapper.wrap(inner);
            }
            let store = Self {
                inner,
                scheme: path.scheme().to_string(),
                block_size: params.block_size.unwrap_or(64 * 1024),
                use_constant_size_upload_parts: params.use_constant_size_upload_parts,
                list_is_lexically_ordered: params.list_is_lexically_ordered.unwrap_or_default(),
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
            };
            let path = Path::from(path.path());
            return Ok((store, path));
        }
        let (object_store, path) = match Url::parse(uri) {
            Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
                // On Windows, the drive is parsed as a scheme
                Self::from_path(uri)
            }
            Ok(url) => {
                let store = Self::new_from_url(registry, url.clone(), params.clone()).await?;
                Ok((store, Path::from(url.path())))
            }
            Err(_) => Self::from_path(uri),
        }?;

        Ok((
            Self {
                inner: params
                    .object_store_wrapper
                    .as_ref()
                    .map(|w| w.wrap(object_store.inner.clone()))
                    .unwrap_or(object_store.inner),
                ..object_store
            },
            path,
        ))
    }

    pub fn from_path_with_scheme(str_path: &str, scheme: &str) -> Result<(Self, Path)> {
        let expanded = tilde(str_path).to_string();

        let mut expanded_path = path_abs::PathAbs::new(expanded)
            .unwrap()
            .as_path()
            .to_path_buf();
        // path_abs::PathAbs::new(".") returns an empty string.
        if let Some(s) = expanded_path.as_path().to_str() {
            if s.is_empty() {
                expanded_path = std::env::current_dir()?;
            }
        }
        Ok((
            Self {
                inner: Arc::new(LocalFileSystem::new()).traced(),
                scheme: String::from(scheme),
                block_size: 4 * 1024, // 4KB block size
                use_constant_size_upload_parts: false,
                list_is_lexically_ordered: false,
                io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM,
                download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
            },
            Path::from_absolute_path(expanded_path.as_path())?,
        ))
    }

    pub fn from_path(str_path: &str) -> Result<(Self, Path)> {
        Self::from_path_with_scheme(str_path, "file")
    }

    async fn new_from_url(
        registry: Arc<ObjectStoreRegistry>,
        url: Url,
        params: ObjectStoreParams,
    ) -> Result<Self> {
        configure_store(registry, url.as_str(), params).await
    }

    /// Local object store.
    pub fn local() -> Self {
        Self {
            inner: Arc::new(LocalFileSystem::new()).traced(),
            scheme: String::from("file"),
            block_size: 4 * 1024, // 4KB block size
            use_constant_size_upload_parts: false,
            list_is_lexically_ordered: false,
            io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM,
            download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
        }
    }

    /// Create an in-memory object store, directly, for testing.
    pub fn memory() -> Self {
        Self {
            inner: Arc::new(InMemory::new()).traced(),
            scheme: String::from("memory"),
            block_size: 4 * 1024,
            use_constant_size_upload_parts: false,
            list_is_lexically_ordered: true,
            io_parallelism: get_num_compute_intensive_cpus(),
            download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
        }
    }

    /// Returns true if the object store points to a local file system.
    pub fn is_local(&self) -> bool {
        self.scheme == "file"
    }

    pub fn is_cloud(&self) -> bool {
        self.scheme != "file" && self.scheme != "memory"
    }

    pub fn block_size(&self) -> usize {
        self.block_size
    }

    pub fn set_block_size(&mut self, new_size: usize) {
        self.block_size = new_size;
    }

    pub fn set_io_parallelism(&mut self, io_parallelism: usize) {
        self.io_parallelism = io_parallelism;
    }

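    /// Number of parallel IO operations to use.
    ///
    /// The `LANCE_IO_THREADS` environment variable, when set, overrides the
    /// configured value; note that an unparsable value will panic. A minimal
    /// sketch (hypothetical `store` value):
    ///
    /// ```ignore
    /// std::env::set_var("LANCE_IO_THREADS", "16");
    /// assert_eq!(store.io_parallelism(), 16);
    /// ```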
    pub fn io_parallelism(&self) -> usize {
        std::env::var("LANCE_IO_THREADS")
            .map(|val| val.parse::<usize>().unwrap())
            .unwrap_or(self.io_parallelism)
    }

    /// Open a file at the given path.
    ///
    /// Parameters
    /// - ``path``: Absolute path to the file.
    pub async fn open(&self, path: &Path) -> Result<Box<dyn Reader>> {
        match self.scheme.as_str() {
            "file" => LocalObjectReader::open(path, self.block_size, None).await,
            _ => Ok(Box::new(CloudObjectReader::new(
                self.inner.clone(),
                path.clone(),
                self.block_size,
                None,
                self.download_retry_count,
            )?)),
        }
    }

    /// Open a reader for a file with known size.
    ///
    /// This size may either have been retrieved from a list operation or
    /// cached metadata. By passing in the known size, we can skip a HEAD / metadata
    /// call.
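    ///
    /// A hedged sketch (assumes `meta: ObjectMeta` came from an earlier list
    /// operation):
    ///
    /// ```ignore
    /// let reader = store.open_with_size(&meta.location, meta.size).await?;
    /// ```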
    pub async fn open_with_size(&self, path: &Path, known_size: usize) -> Result<Box<dyn Reader>> {
        match self.scheme.as_str() {
            "file" => LocalObjectReader::open(path, self.block_size, Some(known_size)).await,
            _ => Ok(Box::new(CloudObjectReader::new(
                self.inner.clone(),
                path.clone(),
                self.block_size,
                Some(known_size),
                self.download_retry_count,
            )?)),
        }
    }

    /// Create an [ObjectWriter] from a local [std::path::Path]
    pub async fn create_local_writer(path: &std::path::Path) -> Result<ObjectWriter> {
        let object_store = Self::local();
        let os_path = Path::from(path.to_str().unwrap());
        object_store.create(&os_path).await
    }

    /// Open a [Reader] from a local [std::path::Path]
    pub async fn open_local(path: &std::path::Path) -> Result<Box<dyn Reader>> {
        let object_store = Self::local();
        let os_path = Path::from(path.to_str().unwrap());
        object_store.open(&os_path).await
    }

    /// Create a new file.
    pub async fn create(&self, path: &Path) -> Result<ObjectWriter> {
        ObjectWriter::new(self, path).await
    }

    /// A helper function to create a file and write content to it.
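    ///
    /// A minimal sketch (the path is an assumption):
    ///
    /// ```ignore
    /// store.put(&Path::from("foo/bar.bin"), b"hello").await?;
    /// assert_eq!(store.read_one_all(&Path::from("foo/bar.bin")).await?.as_ref(), b"hello");
    /// ```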
    pub async fn put(&self, path: &Path, content: &[u8]) -> Result<()> {
        let mut writer = self.create(path).await?;
        writer.write_all(content).await?;
        writer.shutdown().await
    }

    pub async fn delete(&self, path: &Path) -> Result<()> {
        self.inner.delete(path).await?;
        Ok(())
    }

    pub async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
        Ok(self.inner.copy(from, to).await?)
    }

    /// Read a directory (non-recursively) and return all sub-paths in the directory.
    pub async fn read_dir(&self, dir_path: impl Into<Path>) -> Result<Vec<String>> {
        let path = dir_path.into();
        let path = Path::parse(&path)?;
        let output = self.inner.list_with_delimiter(Some(&path)).await?;
        Ok(output
            .common_prefixes
            .iter()
            .chain(output.objects.iter().map(|o| &o.location))
            .map(|s| s.filename().unwrap().to_string())
            .collect())
    }

    /// Read all files recursively, starting from the base directory.
    ///
    /// `unmodified_since` can be specified to only return files that have not been modified since the given time.
    pub async fn read_dir_all(
        &self,
        dir_path: impl Into<&Path> + Send,
        unmodified_since: Option<DateTime<Utc>>,
    ) -> Result<BoxStream<Result<ObjectMeta>>> {
        self.inner.read_dir_all(dir_path, unmodified_since).await
    }

    /// Remove a directory recursively.
    pub async fn remove_dir_all(&self, dir_path: impl Into<Path>) -> Result<()> {
        let path = dir_path.into();
        let path = Path::parse(&path)?;

        if self.is_local() {
            // Local file system needs to delete directories as well.
            return super::local::remove_dir_all(&path);
        }
        let sub_entries = self
            .inner
            .list(Some(&path))
            .map(|m| m.map(|meta| meta.location))
            .boxed();
        self.inner
            .delete_stream(sub_entries)
            .try_collect::<Vec<_>>()
            .await?;
        Ok(())
    }

    pub fn remove_stream<'a>(
        &'a self,
        locations: BoxStream<'a, Result<Path>>,
    ) -> BoxStream<'a, Result<Path>> {
        self.inner
            .delete_stream(locations.err_into::<ObjectStoreError>().boxed())
            .err_into::<Error>()
            .boxed()
    }

    /// Check if a file exists.
    pub async fn exists(&self, path: &Path) -> Result<bool> {
        match self.inner.head(path).await {
            Ok(_) => Ok(true),
            Err(object_store::Error::NotFound { path: _, source: _ }) => Ok(false),
            Err(e) => Err(e.into()),
        }
    }

    /// Get file size.
    pub async fn size(&self, path: &Path) -> Result<usize> {
        Ok(self.inner.head(path).await?.size)
    }

    /// Convenience function to open a reader and read all the bytes
    pub async fn read_one_all(&self, path: &Path) -> Result<Bytes> {
        let reader = self.open(path).await?;
        Ok(reader.get_all().await?)
    }

    /// Convenience function to open a reader and make a single request
    ///
    /// If you will be making multiple requests to the path it is more efficient to call [`Self::open`]
    /// and then call [`Reader::get_range`] multiple times.
    pub async fn read_one_range(&self, path: &Path, range: Range<usize>) -> Result<Bytes> {
        let reader = self.open(path).await?;
        Ok(reader.get_range(range).await?)
    }
}

/// Options that can be set for multiple object stores
#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy)]
pub enum LanceConfigKey {
    /// Number of times to retry a download that fails
    DownloadRetryCount,
}

impl FromStr for LanceConfigKey {
    type Err = Error;

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s.to_ascii_lowercase().as_str() {
            "download_retry_count" => Ok(Self::DownloadRetryCount),
            _ => Err(Error::InvalidInput {
                source: format!("Invalid LanceConfigKey: {}", s).into(),
                location: location!(),
            }),
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct StorageOptions(pub HashMap<String, String>);

impl StorageOptions {
    /// Create a new instance of [`StorageOptions`]
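    ///
    /// A minimal sketch (`allow_http` is also picked up from the
    /// AWS_ALLOW_HTTP / AZURE_STORAGE_ALLOW_HTTP environment variables):
    ///
    /// ```ignore
    /// let options = StorageOptions::new(HashMap::from([
    ///     ("allow_http".to_string(), "true".to_string()),
    /// ]));
    /// assert!(options.allow_http());
    /// ```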
    pub fn new(options: HashMap<String, String>) -> Self {
        let mut options = options;
        if let Ok(value) = std::env::var("AZURE_STORAGE_ALLOW_HTTP") {
            options.insert("allow_http".into(), value);
        }
        if let Ok(value) = std::env::var("AZURE_STORAGE_USE_HTTP") {
            options.insert("allow_http".into(), value);
        }
        if let Ok(value) = std::env::var("AWS_ALLOW_HTTP") {
            options.insert("allow_http".into(), value);
        }
        Self(options)
    }

    /// Add values from the environment to storage options
    pub fn with_env_azure(&mut self) {
        for (os_key, os_value) in std::env::vars_os() {
            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
                if let Ok(config_key) = AzureConfigKey::from_str(&key.to_ascii_lowercase()) {
                    if !self.0.contains_key(config_key.as_ref()) {
                        self.0
                            .insert(config_key.as_ref().to_string(), value.to_string());
                    }
                }
            }
        }
    }

    /// Add values from the environment to storage options
    pub fn with_env_gcs(&mut self) {
        for (os_key, os_value) in std::env::vars_os() {
            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
                if let Ok(config_key) = GoogleConfigKey::from_str(&key.to_ascii_lowercase()) {
                    if !self.0.contains_key(config_key.as_ref()) {
                        self.0
                            .insert(config_key.as_ref().to_string(), value.to_string());
                    }
                }
            }
        }
    }

    /// Add values from the environment to storage options
    pub fn with_env_s3(&mut self) {
        for (os_key, os_value) in std::env::vars_os() {
            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
                    if !self.0.contains_key(config_key.as_ref()) {
                        self.0
                            .insert(config_key.as_ref().to_string(), value.to_string());
                    }
                }
            }
        }
    }

    /// Denotes whether insecure connections via HTTP are allowed
    pub fn allow_http(&self) -> bool {
        self.0.iter().any(|(key, value)| {
            key.to_ascii_lowercase().contains("allow_http") && str_is_truthy(value)
        })
    }

    /// Number of times to retry a download that fails
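    ///
    /// A minimal sketch: falls back to 3 when the key is unset or unparsable.
    ///
    /// ```ignore
    /// assert_eq!(StorageOptions::default().download_retry_count(), 3);
    /// ```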
    pub fn download_retry_count(&self) -> usize {
        self.0
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case("download_retry_count"))
            .map(|(_, value)| value.parse::<usize>().unwrap_or(3))
            .unwrap_or(3)
    }

    /// Max retry times to set in RetryConfig for the S3 client
    pub fn client_max_retries(&self) -> usize {
        self.0
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case("client_max_retries"))
            .and_then(|(_, value)| value.parse::<usize>().ok())
            .unwrap_or(10)
    }

    /// Seconds of timeout to set in RetryConfig for the S3 client
    pub fn client_retry_timeout(&self) -> u64 {
        self.0
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case("client_retry_timeout"))
            .and_then(|(_, value)| value.parse::<u64>().ok())
            .unwrap_or(180)
    }

    /// Subset of options relevant for Azure storage
    pub fn as_azure_options(&self) -> HashMap<AzureConfigKey, String> {
        self.0
            .iter()
            .filter_map(|(key, value)| {
                let az_key = AzureConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
                Some((az_key, value.clone()))
            })
            .collect()
    }

    /// Subset of options relevant for S3 storage
    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
        self.0
            .iter()
            .filter_map(|(key, value)| {
                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
                Some((s3_key, value.clone()))
            })
            .collect()
    }

    /// Subset of options relevant for GCS storage
    pub fn as_gcs_options(&self) -> HashMap<GoogleConfigKey, String> {
        self.0
            .iter()
            .filter_map(|(key, value)| {
                let gcs_key = GoogleConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
                Some((gcs_key, value.clone()))
            })
            .collect()
    }

    pub fn get(&self, key: &str) -> Option<&String> {
        self.0.get(key)
    }
}

impl From<HashMap<String, String>> for StorageOptions {
    fn from(value: HashMap<String, String>) -> Self {
        Self::new(value)
    }
}

async fn configure_store(
    registry: Arc<ObjectStoreRegistry>,
    url: &str,
    options: ObjectStoreParams,
) -> Result<ObjectStore> {
    let mut storage_options = StorageOptions(options.storage_options.clone().unwrap_or_default());
    let download_retry_count = storage_options.download_retry_count();
    let mut url = ensure_table_uri(url)?;
    // Block size: On local file systems, we use 4KB block size. On cloud
    // object stores, we use 64KB block size. This is generally the largest
    // block size where we don't see a latency penalty.
    let file_block_size = options.block_size.unwrap_or(4 * 1024);
    let cloud_block_size = options.block_size.unwrap_or(64 * 1024);
    match url.scheme() {
        "s3" | "s3+ddb" => {
            storage_options.with_env_s3();

            // if url.scheme() == "s3+ddb" && options.commit_handler.is_some() {
            //     return Err(Error::InvalidInput {
            //         source: "`s3+ddb://` scheme and custom commit handler are mutually exclusive"
            //             .into(),
            //         location: location!(),
            //     });
            // }

            let max_retries = storage_options.client_max_retries();
            let retry_timeout = storage_options.client_retry_timeout();
            let retry_config = RetryConfig {
                backoff: Default::default(),
                max_retries,
                retry_timeout: Duration::from_secs(retry_timeout),
            };
            let mut storage_options = storage_options.as_s3_options();
            let region = resolve_s3_region(&url, &storage_options).await?;
            let (aws_creds, region) = build_aws_credential(
                options.s3_credentials_refresh_offset,
                options.aws_credentials.clone(),
                Some(&storage_options),
                region,
            )
            .await?;

            // This will be the default in the next version of object_store.
            // https://github.com/apache/arrow-rs/pull/7181
            storage_options
                .entry(AmazonS3ConfigKey::ConditionalPut)
                .or_insert_with(|| "etag".to_string());

            // Cloudflare does not support varying part sizes.
            let use_constant_size_upload_parts = storage_options
                .get(&AmazonS3ConfigKey::Endpoint)
                .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
                .unwrap_or(false);

            // Before creating the OSObjectStore we need to rewrite the URL to drop the DDB-related parts.
            url.set_scheme("s3").map_err(|()| Error::Internal {
                message: "could not set scheme".into(),
                location: location!(),
            })?;

            url.set_query(None);

            // We can't use parse_url_opts here because we need to manually set the credentials provider.
            let mut builder = AmazonS3Builder::new();
            for (key, value) in storage_options {
                builder = builder.with_config(key, value);
            }
            builder = builder
                .with_url(url.as_ref())
                .with_credentials(aws_creds)
                .with_retry(retry_config)
                .with_region(region);
            let store = builder.build()?;

            Ok(ObjectStore {
                inner: Arc::new(store).traced(),
                scheme: String::from(url.scheme()),
                block_size: cloud_block_size,
                use_constant_size_upload_parts,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        "gs" => {
            storage_options.with_env_gcs();
            let mut builder = GoogleCloudStorageBuilder::new().with_url(url.as_ref());
            for (key, value) in storage_options.as_gcs_options() {
                builder = builder.with_config(key, value);
            }
            let token_key = "google_storage_token";
            if let Some(storage_token) = storage_options.get(token_key) {
                let credential = GcpCredential {
                    bearer: storage_token.to_string(),
                };
                let credential_provider = Arc::new(StaticCredentialProvider::new(credential)) as _;
                builder = builder.with_credentials(credential_provider);
            }
            let store = builder.build()?;
            let store = Arc::new(store).traced();

            Ok(ObjectStore {
                inner: store,
                scheme: String::from("gs"),
                block_size: cloud_block_size,
                use_constant_size_upload_parts: false,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        "az" => {
            storage_options.with_env_azure();
            let (store, _) = parse_url_opts(&url, storage_options.as_azure_options())?;
            let store = Arc::new(store).traced();

            Ok(ObjectStore {
                inner: store,
                scheme: String::from("az"),
                block_size: cloud_block_size,
                use_constant_size_upload_parts: false,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        // We have bypass logic that uses `tokio::fs` directly to lower overhead;
        // however, this makes testing harder, as we can't exercise the same code path.
        // The "file-object-store" scheme forces a local file system dataset to use
        // the same code path as cloud object stores.
985        "file" => {
986            let mut object_store = ObjectStore::from_path(url.path())?.0;
987            object_store.set_block_size(file_block_size);
988            Ok(object_store)
989        }
990        "file-object-store" => {
991            let mut object_store =
992                ObjectStore::from_path_with_scheme(url.path(), "file-object-store")?.0;
993            object_store.set_block_size(file_block_size);
994            Ok(object_store)
995        }
996        "memory" => Ok(ObjectStore {
997            inner: Arc::new(InMemory::new()).traced(),
998            scheme: String::from("memory"),
999            block_size: file_block_size,
1000            use_constant_size_upload_parts: false,
1001            list_is_lexically_ordered: true,
1002            io_parallelism: get_num_compute_intensive_cpus(),
1003            download_retry_count,
1004        }),
1005        unknown_scheme => {
1006            if let Some(provider) = registry.providers.get(unknown_scheme) {
1007                provider.new_store(url, &options)
1008            } else {
1009                let err = lance_core::Error::from(object_store::Error::NotSupported {
1010                    source: format!("Unsupported URI scheme: {} in url {}", unknown_scheme, url)
1011                        .into(),
1012                });
1013                Err(err)
1014            }
1015        }
1016    }
1017}
1018
impl ObjectStore {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        store: Arc<DynObjectStore>,
        location: Url,
        block_size: Option<usize>,
        wrapper: Option<Arc<dyn WrappingObjectStore>>,
        use_constant_size_upload_parts: bool,
        list_is_lexically_ordered: bool,
        io_parallelism: usize,
        download_retry_count: usize,
    ) -> Self {
        let scheme = location.scheme();
        let block_size = block_size.unwrap_or_else(|| infer_block_size(scheme));

        let store = match wrapper {
            Some(wrapper) => wrapper.wrap(store),
            None => store,
        };

        Self {
            inner: store,
            scheme: scheme.into(),
            block_size,
            use_constant_size_upload_parts,
            list_is_lexically_ordered,
            io_parallelism,
            download_retry_count,
        }
    }
}

fn infer_block_size(scheme: &str) -> usize {
    // Block size: On local file systems, we use 4KB block size. On cloud
    // object stores, we use 64KB block size. This is generally the largest
    // block size where we don't see a latency penalty.
    match scheme {
        "file" => 4 * 1024,
        _ => 64 * 1024,
    }
}

/// Attempt to create a Url from a given table location.
///
/// The location could be:
///  * A valid URL, which will be parsed and returned
///  * A path to a directory, which will be canonicalized and converted to a URL.
///
/// Extra slashes will be trimmed from the end of the path.
///
/// Will return an error if the location is not valid, or if a local path does not exist.
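///
/// A minimal sketch (the bucket/table names are assumptions):
///
/// ```ignore
/// let url = ensure_table_uri("s3://bucket/table/")?;
/// assert_eq!(url.as_str(), "s3://bucket/table");
/// ```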
pub fn ensure_table_uri(table_uri: impl AsRef<str>) -> Result<Url> {
    let table_uri = table_uri.as_ref();

    enum UriType {
        LocalPath(PathBuf),
        Url(Url),
    }
    let uri_type: UriType = if let Ok(url) = Url::parse(table_uri) {
        if url.scheme() == "file" {
            UriType::LocalPath(url.to_file_path().map_err(|err| {
                let msg = format!("Invalid table location: {}\nError: {:?}", table_uri, err);
                Error::InvalidTableLocation { message: msg }
            })?)
        // NOTE this check is required to support absolute windows paths which may properly parse as url
        } else {
            UriType::Url(url)
        }
    } else {
        UriType::LocalPath(PathBuf::from(table_uri))
    };

    // If it is a local path, canonicalize it; this fails if it does not exist.
    let mut url = match uri_type {
        UriType::LocalPath(path) => {
            let path = std::fs::canonicalize(path).map_err(|err| Error::DatasetNotFound {
                path: table_uri.to_string(),
                source: Box::new(err),
                location: location!(),
            })?;
            Url::from_directory_path(path).map_err(|_| {
                let msg = format!(
                    "Could not construct a URL from canonicalized path: {}.\n\
                  Something must be very wrong with the table path.",
                    table_uri
                );
                Error::InvalidTableLocation { message: msg }
            })?
        }
        UriType::Url(url) => url,
    };

    let trimmed_path = url.path().trim_end_matches('/').to_owned();
    url.set_path(&trimmed_path);
    Ok(url)
}

lazy_static::lazy_static! {
  static ref KNOWN_SCHEMES: Vec<&'static str> =
      Vec::from([
        "s3",
        "s3+ddb",
        "gs",
        "az",
        "file",
        "file-object-store",
        "memory"
      ]);
}

#[cfg(test)]
mod tests {
    use super::*;
    use parquet::data_type::AsBytes;
    use rstest::rstest;
    use std::env::set_current_dir;
    use std::fs::{create_dir_all, write};
    use std::path::Path as StdPath;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Write test content to a file.
    fn write_to_file(path_str: &str, contents: &str) -> std::io::Result<()> {
        let expanded = tilde(path_str).to_string();
        let path = StdPath::new(&expanded);
        std::fs::create_dir_all(path.parent().unwrap())?;
        write(path, contents)
    }

    async fn read_from_store(store: ObjectStore, path: &Path) -> Result<String> {
        let test_file_store = store.open(path).await.unwrap();
        let size = test_file_store.size().await.unwrap();
        let bytes = test_file_store.get_range(0..size).await.unwrap();
        let contents = String::from_utf8(bytes.to_vec()).unwrap();
        Ok(contents)
    }

    #[tokio::test]
    async fn test_absolute_paths() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        write_to_file(
            &format!("{tmp_path}/bar/foo.lance/test_file"),
            "TEST_CONTENT",
        )
        .unwrap();

        // test a few variations of the same path
        for uri in &[
            format!("{tmp_path}/bar/foo.lance"),
            format!("{tmp_path}/./bar/foo.lance"),
            format!("{tmp_path}/bar/foo.lance/../foo.lance"),
        ] {
            let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
            let contents = read_from_store(store, &path.child("test_file"))
                .await
                .unwrap();
            assert_eq!(contents, "TEST_CONTENT");
        }
    }

    #[tokio::test]
    async fn test_cloud_paths() {
        let uri = "s3://bucket/foo.lance";
        let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
        assert_eq!(store.scheme, "s3");
        assert_eq!(path.to_string(), "foo.lance");

        let (store, path) = ObjectStore::from_uri("s3+ddb://bucket/foo.lance")
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
        assert_eq!(path.to_string(), "foo.lance");

        let (store, path) = ObjectStore::from_uri("gs://bucket/foo.lance")
            .await
            .unwrap();
        assert_eq!(store.scheme, "gs");
        assert_eq!(path.to_string(), "foo.lance");
    }

    async fn test_block_size_used_test_helper(
        uri: &str,
        storage_options: Option<HashMap<String, String>>,
        default_expected_block_size: usize,
    ) {
        // Test the default
        let registry = Arc::new(ObjectStoreRegistry::default());
        let params = ObjectStoreParams {
            storage_options: storage_options.clone(),
            ..ObjectStoreParams::default()
        };
        let (store, _) = ObjectStore::from_uri_and_params(registry, uri, &params)
            .await
            .unwrap();
        assert_eq!(store.block_size, default_expected_block_size);

        // Ensure param is used
        let registry = Arc::new(ObjectStoreRegistry::default());
        let params = ObjectStoreParams {
            block_size: Some(1024),
            storage_options: storage_options.clone(),
            ..ObjectStoreParams::default()
        };
        let (store, _) = ObjectStore::from_uri_and_params(registry, uri, &params)
            .await
            .unwrap();
        assert_eq!(store.block_size, 1024);
    }

    #[rstest]
    #[case("s3://bucket/foo.lance", None)]
    #[case("gs://bucket/foo.lance", None)]
    #[case("az://account/bucket/foo.lance",
      Some(HashMap::from([
            (String::from("account_name"), String::from("account")),
            (String::from("container_name"), String::from("container"))
           ])))]
    #[tokio::test]
    async fn test_block_size_used_cloud(
        #[case] uri: &str,
        #[case] storage_options: Option<HashMap<String, String>>,
    ) {
        test_block_size_used_test_helper(uri, storage_options, 64 * 1024).await;
    }

    #[rstest]
    #[case("file")]
    #[case("file-object-store")]
    #[case("memory:///bucket/foo.lance")]
    #[tokio::test]
    async fn test_block_size_used_file(#[case] prefix: &str) {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        let path = format!("{tmp_path}/bar/foo.lance/test_file");
        write_to_file(&path, "URL").unwrap();
        let uri = format!("{prefix}:///{path}");
        test_block_size_used_test_helper(&uri, None, 4 * 1024).await;
    }

    #[tokio::test]
    async fn test_relative_paths() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        write_to_file(
            &format!("{tmp_path}/bar/foo.lance/test_file"),
            "RELATIVE_URL",
        )
        .unwrap();

        set_current_dir(StdPath::new(&tmp_path)).expect("Error changing current dir");
        let (store, path) = ObjectStore::from_uri("./bar/foo.lance").await.unwrap();

        let contents = read_from_store(store, &path.child("test_file"))
            .await
            .unwrap();
        assert_eq!(contents, "RELATIVE_URL");
    }

    #[tokio::test]
    async fn test_tilde_expansion() {
        let uri = "~/foo.lance";
        write_to_file(&format!("{uri}/test_file"), "TILDE").unwrap();
        let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
        let contents = read_from_store(store, &path.child("test_file"))
            .await
            .unwrap();
        assert_eq!(contents, "TILDE");
    }

    #[tokio::test]
    async fn test_read_directory() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path();
        create_dir_all(path.join("foo").join("bar")).unwrap();
        create_dir_all(path.join("foo").join("zoo")).unwrap();
        create_dir_all(path.join("foo").join("zoo").join("abc")).unwrap();
        write_to_file(
            path.join("foo").join("test_file").to_str().unwrap(),
            "read_dir",
        )
        .unwrap();
        let (store, base) = ObjectStore::from_uri(path.to_str().unwrap()).await.unwrap();

        let sub_dirs = store.read_dir(base.child("foo")).await.unwrap();
        assert_eq!(sub_dirs, vec!["bar", "zoo", "test_file"]);
    }

    #[tokio::test]
    async fn test_delete_directory() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path();
        create_dir_all(path.join("foo").join("bar")).unwrap();
        create_dir_all(path.join("foo").join("zoo")).unwrap();
        create_dir_all(path.join("foo").join("zoo").join("abc")).unwrap();
        write_to_file(
            path.join("foo")
                .join("bar")
                .join("test_file")
                .to_str()
                .unwrap(),
            "delete",
        )
        .unwrap();
        write_to_file(path.join("foo").join("top").to_str().unwrap(), "delete_top").unwrap();
        let (store, base) = ObjectStore::from_uri(path.to_str().unwrap()).await.unwrap();
        store.remove_dir_all(base.child("foo")).await.unwrap();

        assert!(!path.join("foo").exists());
    }

    #[derive(Debug)]
    struct TestWrapper {
        called: AtomicBool,

        return_value: Arc<dyn OSObjectStore>,
    }

    impl WrappingObjectStore for TestWrapper {
        fn wrap(&self, _original: Arc<dyn OSObjectStore>) -> Arc<dyn OSObjectStore> {
            self.called.store(true, Ordering::Relaxed);

            // return a mocked value so we can check if the final store is the one we expect
            self.return_value.clone()
        }
    }

    impl TestWrapper {
        fn called(&self) -> bool {
            self.called.load(Ordering::Relaxed)
        }
    }

    #[tokio::test]
    async fn test_wrapping_object_store_option_is_used() {
        // Make a store for the inner store first
        let mock_inner_store: Arc<dyn OSObjectStore> = Arc::new(InMemory::new());
        let registry = Arc::new(ObjectStoreRegistry::default());

        assert_eq!(Arc::strong_count(&mock_inner_store), 1);

        let wrapper = Arc::new(TestWrapper {
            called: AtomicBool::new(false),
            return_value: mock_inner_store.clone(),
        });

        let params = ObjectStoreParams {
            object_store_wrapper: Some(wrapper.clone()),
            ..ObjectStoreParams::default()
        };

        // not called yet
        assert!(!wrapper.called());

        let _ = ObjectStore::from_uri_and_params(registry, "memory:///", &params)
            .await
            .unwrap();

        // called after construction
        assert!(wrapper.called());

        // It's hard to compare two trait pointers, as they point to vtables;
        // use the ref count as a proxy to make sure that the store is correctly kept.
        assert_eq!(Arc::strong_count(&mock_inner_store), 2);
    }

    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // fails, but we don't care
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // Now it has been called
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn test_local_paths() {
        let temp_dir = tempfile::tempdir().unwrap();

        let file_path = temp_dir.path().join("test_file");
        let mut writer = ObjectStore::create_local_writer(file_path.as_path())
            .await
            .unwrap();
        writer.write_all(b"LOCAL").await.unwrap();
        writer.shutdown().await.unwrap();

        let reader = ObjectStore::open_local(file_path.as_path()).await.unwrap();
        let buf = reader.get_range(0..5).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");
    }

    #[tokio::test]
    async fn test_read_one() {
        let temp_dir = tempfile::tempdir().unwrap();

        let file_path = temp_dir.path().join("test_file");
        let mut writer = ObjectStore::create_local_writer(file_path.as_path())
            .await
            .unwrap();
        writer.write_all(b"LOCAL").await.unwrap();
        writer.shutdown().await.unwrap();

        let file_path_os = object_store::path::Path::parse(file_path.to_str().unwrap()).unwrap();
        let obj_store = ObjectStore::local();
        let buf = obj_store.read_one_all(&file_path_os).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");

        let buf = obj_store.read_one_range(&file_path_os, 0..5).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");
    }

    #[tokio::test]
    #[cfg(windows)]
    async fn test_windows_paths() {
        use std::path::Component;
        use std::path::Prefix;
        use std::path::Prefix::*;

        fn get_path_prefix(path: &StdPath) -> Prefix {
            match path.components().next().unwrap() {
                Component::Prefix(prefix_component) => prefix_component.kind(),
                _ => panic!(),
            }
        }

        fn get_drive_letter(prefix: Prefix) -> String {
            match prefix {
                Disk(bytes) => String::from_utf8(vec![bytes]).unwrap(),
                _ => panic!(),
            }
        }

        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path();
        let prefix = get_path_prefix(tmp_path);
        let drive_letter = get_drive_letter(prefix);

        write_to_file(
            &(format!("{drive_letter}:/test_folder/test.lance") + "/test_file"),
            "WINDOWS",
        )
        .unwrap();

        for uri in &[
            format!("{drive_letter}:/test_folder/test.lance"),
            format!("{drive_letter}:\\test_folder\\test.lance"),
        ] {
            let (store, base) = ObjectStore::from_uri(uri).await.unwrap();
            let contents = read_from_store(store, &base.child("test_file"))
                .await
                .unwrap();
            assert_eq!(contents, "WINDOWS");
        }
    }
}