use super::Progress as SyncProgress;
use super::{RepoInfo, HF_ENDPOINT};
use crate::{Cache, Repo, RepoType};
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use indicatif::ProgressBar;
use rand::Rng;
use reqwest::{
header::{
HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION,
CONTENT_RANGE, LOCATION, RANGE, USER_AGENT,
},
redirect::Policy,
Client, Error as ReqwestError, RequestBuilder,
};
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::num::ParseIntError;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::io::AsyncReadExt;
use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
use tokio::sync::{AcquireError, Semaphore, TryAcquireError};
use tokio::task::JoinError;
const VERSION: &str = env!("CARGO_PKG_VERSION");
const NAME: &str = env!("CARGO_PKG_NAME");
const EXTENSION: &str = "sync.part";
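/// Async progress-reporting callbacks, mirroring the synchronous `Progress`
/// trait from the parent module (imported here as `SyncProgress`): every
/// callback returns a `Future`.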
pub trait Progress {
fn init(&mut self, size: usize, filename: &str)
-> impl std::future::Future<Output = ()> + Send;
fn update(&mut self, size: usize) -> impl std::future::Future<Output = ()> + Send;
fn finish(&mut self) -> impl std::future::Future<Output = ()> + Send;
}
impl Progress for ProgressBar {
async fn init(&mut self, size: usize, filename: &str) {
<ProgressBar as SyncProgress>::init(self, size, filename);
}
async fn finish(&mut self) {
<ProgressBar as SyncProgress>::finish(self);
}
async fn update(&mut self, size: usize) {
<ProgressBar as SyncProgress>::update(self, size);
}
}
impl Progress for () {
async fn init(&mut self, _size: usize, _filename: &str) {}
async fn finish(&mut self) {}
async fn update(&mut self, _size: usize) {}
}
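/// RAII guard around a lockfile: the file is removed when the guard is dropped.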
struct Handle {
_file: tokio::fs::File,
path: PathBuf,
}
impl Drop for Handle {
fn drop(&mut self) {
std::fs::remove_file(&self.path).expect("Removing lockfile")
}
}
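/// Acquires an advisory lockfile next to `path` (extension replaced with
/// `.lock`), retrying once per second for up to 30 attempts before giving up
/// with [`ApiError::LockAcquisition`].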
async fn lock_file(mut path: PathBuf) -> Result<Handle, ApiError> {
path.set_extension("lock");
let mut lock_handle = None;
for i in 0..30 {
match tokio::fs::File::create_new(path.clone()).await {
Ok(handle) => {
lock_handle = Some(handle);
break;
}
Err(_err) => {
if i == 0 {
log::warn!("Waiting for lock {path:?}");
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(path.clone()))?;
Ok(Handle { path, _file })
}
#[derive(Debug, Error)]
pub enum ApiError {
#[error("Header {0} is missing")]
MissingHeader(HeaderName),
#[error("Header {0} is invalid")]
InvalidHeader(HeaderName),
#[error("Invalid header value {0}")]
InvalidHeaderValue(#[from] InvalidHeaderValue),
#[error("header value is not a string")]
ToStr(#[from] ToStrError),
#[error("request error: {0}")]
RequestError(#[from] ReqwestError),
#[error("Cannot parse int")]
ParseIntError(#[from] ParseIntError),
#[error("I/O error {0}")]
IoError(#[from] std::io::Error),
#[error("Too many retries: {0}")]
TooManyRetries(Box<ApiError>),
#[error("Try acquire: {0}")]
TryAcquireError(#[from] TryAcquireError),
#[error("Acquire: {0}")]
AcquireError(#[from] AcquireError),
#[error("Join: {0}")]
Join(#[from] JoinError),
#[error("Lock acquisition failed: {0}")]
LockAcquisition(PathBuf),
}
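/// Helper to create [`Api`] with all the options.
///
/// A minimal sketch (module path assumed to be `hf_hub::api::tokio`):
///
/// ```no_run
/// use hf_hub::api::tokio::ApiBuilder;
///
/// let api = ApiBuilder::new()
///     .with_progress(false)
///     .build()
///     .unwrap();
/// ```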
#[derive(Debug)]
pub struct ApiBuilder {
endpoint: String,
cache: Cache,
token: Option<String>,
max_files: usize,
chunk_size: Option<usize>,
parallel_failures: usize,
max_retries: usize,
progress: bool,
}
impl Default for ApiBuilder {
fn default() -> Self {
Self::new()
}
}
impl ApiBuilder {
pub fn new() -> Self {
let cache = Cache::default();
Self::from_cache(cache)
}
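    /// Builds the builder from the environment: the cache location comes from
    /// [`Cache::from_env`] and the endpoint is overridden from the
    /// `HF_ENDPOINT` environment variable when set.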
pub fn from_env() -> Self {
let cache = Cache::from_env();
let mut builder = Self::from_cache(cache);
if let Ok(endpoint) = std::env::var(HF_ENDPOINT) {
builder = builder.with_endpoint(endpoint);
}
builder
}
pub fn high(self) -> Self {
self.with_max_files(num_cpus::get())
.with_chunk_size(Some(10_000_000))
}
pub fn from_cache(cache: Cache) -> Self {
let token = cache.token();
let progress = true;
Self {
endpoint: "https://huggingface.co".to_string(),
cache,
token,
max_files: 1,
chunk_size: Some(10_000_000),
parallel_failures: 0,
max_retries: 0,
progress,
}
}
pub fn with_progress(mut self, progress: bool) -> Self {
self.progress = progress;
self
}
pub fn with_endpoint(mut self, endpoint: String) -> Self {
self.endpoint = endpoint;
self
}
pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
self.cache = Cache::new(cache_dir);
self
}
pub fn with_token(mut self, token: Option<String>) -> Self {
self.token = token;
self
}
pub fn with_max_files(mut self, max_files: usize) -> Self {
self.max_files = max_files;
self
}
pub fn with_chunk_size(mut self, chunk_size: Option<usize>) -> Self {
self.chunk_size = chunk_size;
self
}
fn build_headers(&self) -> Result<HeaderMap, ApiError> {
let mut headers = HeaderMap::new();
let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown");
headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);
if let Some(token) = &self.token {
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {token}"))?,
);
}
Ok(headers)
}
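    /// Consumes the builder and returns the [`Api`].
    ///
    /// Two HTTP clients are built: a regular one used for the actual
    /// downloads, and one whose redirect policy stops at absolute redirects,
    /// so the metadata request can inspect the hub's own headers (commit
    /// hash, etag) before manually following an absolute `Location`.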
pub fn build(self) -> Result<Api, ApiError> {
let headers = self.build_headers()?;
let client = Client::builder().default_headers(headers.clone()).build()?;
let relative_redirect_policy = Policy::custom(|attempt| {
if attempt.previous().len() > 10 {
return attempt.error("too many redirects");
}
if let Some(last) = attempt.previous().last() {
if last.make_relative(attempt.url()).is_none() {
return attempt.stop();
}
}
attempt.follow()
});
let relative_redirect_client = Client::builder()
.redirect(relative_redirect_policy)
.default_headers(headers)
.build()?;
Ok(Api {
endpoint: self.endpoint,
cache: self.cache,
client,
relative_redirect_client,
max_files: self.max_files,
chunk_size: self.chunk_size,
parallel_failures: self.parallel_failures,
max_retries: self.max_retries,
progress: self.progress,
})
}
}
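/// Metadata about a remote file, gathered from the response headers of a
/// 1-byte ranged GET.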
#[derive(Debug)]
struct Metadata {
commit_hash: String,
etag: String,
size: usize,
}
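/// The asynchronous Hub client. Created through [`ApiBuilder`]; cloning is
/// cheap since `reqwest::Client` is internally reference-counted.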
#[derive(Clone, Debug)]
pub struct Api {
endpoint: String,
cache: Cache,
client: Client,
relative_redirect_client: Client,
max_files: usize,
chunk_size: Option<usize>,
parallel_failures: usize,
max_retries: usize,
progress: bool,
}
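/// Computes `src` relative to the location of `dst`, so that a symlink
/// created at `dst` with the returned path resolves to `src`.
///
/// E.g. `src = /cache/blobs/abc` and `dst = /cache/snapshots/rev/config.json`
/// yields `../../blobs/abc`.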
fn make_relative(src: &Path, dst: &Path) -> PathBuf {
let path = src;
let base = dst;
assert_eq!(
path.is_absolute(),
base.is_absolute(),
"This function is made to look at absolute paths only"
);
let mut ita = path.components();
let mut itb = base.components();
loop {
match (ita.next(), itb.next()) {
(Some(a), Some(b)) if a == b => (),
(some_a, _) => {
let mut new_path = PathBuf::new();
for _ in itb {
new_path.push(Component::ParentDir);
}
if let Some(a) = some_a {
new_path.push(a);
for comp in ita {
new_path.push(comp);
}
}
return new_path;
}
}
}
}
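/// Links `dst` to `src`: a relative symlink on Unix; on Windows, where
/// creating symlinks may require elevated privileges, a plain rename is the
/// fallback. No-op if `dst` already exists.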
fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> {
if dst.exists() {
return Ok(());
}
let rel_src = make_relative(src, dst);
#[cfg(target_os = "windows")]
{
if std::os::windows::fs::symlink_file(rel_src, dst).is_err() {
std::fs::rename(src, dst)?;
}
}
#[cfg(target_family = "unix")]
std::os::unix::fs::symlink(rel_src, dst)?;
Ok(())
}
fn jitter() -> usize {
rand::thread_rng().gen_range(0..=500)
}
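/// Backoff wait time in milliseconds: despite the name, this grows
/// quadratically (`base_wait_time + n² + jitter`), capped at `max`.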
fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
(base_wait_time + n.pow(2) + jitter()).min(max)
}
impl Api {
pub fn new() -> Result<Self, ApiError> {
ApiBuilder::new().build()
}
pub fn client(&self) -> &Client {
&self.client
}
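    /// Fetches [`Metadata`] for `url` with a 1-byte ranged GET: the commit
    /// hash and etag come from the hub response headers (`x-repo-commit`,
    /// `x-linked-etag`/`etag`), and the total size from `Content-Range`,
    /// following an absolute redirect manually when needed.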
async fn metadata(&self, url: &str) -> Result<Metadata, ApiError> {
let response = self
.relative_redirect_client
.get(url)
.header(RANGE, "bytes=0-0")
.send()
.await?;
let response = response.error_for_status()?;
let headers = response.headers();
let header_commit = HeaderName::from_static("x-repo-commit");
let header_linked_etag = HeaderName::from_static("x-linked-etag");
let header_etag = HeaderName::from_static("etag");
let etag = match headers.get(&header_linked_etag) {
Some(etag) => etag,
None => headers
.get(&header_etag)
.ok_or(ApiError::MissingHeader(header_etag))?,
};
let etag = etag.to_str()?.to_string().replace('"', "");
let commit_hash = headers
.get(&header_commit)
.ok_or(ApiError::MissingHeader(header_commit))?
.to_str()?
.to_string();
let response = if response.status().is_redirection() {
self.client
                .get(
                    headers
                        .get(LOCATION)
                        .ok_or(ApiError::MissingHeader(LOCATION))?
                        .to_str()?
                        .to_string(),
                )
.header(RANGE, "bytes=0-0")
.send()
.await?
} else {
response
};
let headers = response.headers();
let content_range = headers
.get(CONTENT_RANGE)
.ok_or(ApiError::MissingHeader(CONTENT_RANGE))?
.to_str()?;
let size = content_range
.split('/')
.last()
.ok_or(ApiError::InvalidHeader(CONTENT_RANGE))?
.parse()?;
Ok(Metadata {
commit_hash,
etag,
size,
})
}
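    /// Creates a new handle on a particular repo.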
pub fn repo(&self, repo: Repo) -> ApiRepo {
ApiRepo::new(self.clone(), repo)
}
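    /// Shorthand for [`Api::repo`] with [`RepoType::Model`].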
pub fn model(&self, model_id: String) -> ApiRepo {
self.repo(Repo::new(model_id, RepoType::Model))
}
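    /// Shorthand for [`Api::repo`] with [`RepoType::Dataset`].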
pub fn dataset(&self, model_id: String) -> ApiRepo {
self.repo(Repo::new(model_id, RepoType::Dataset))
}
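    /// Shorthand for [`Api::repo`] with [`RepoType::Space`].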
pub fn space(&self, model_id: String) -> ApiRepo {
self.repo(Repo::new(model_id, RepoType::Space))
}
}
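/// Shorthand for accessing files within a particular repo.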
#[derive(Debug)]
pub struct ApiRepo {
api: Api,
repo: Repo,
}
impl ApiRepo {
fn new(api: Api, repo: Repo) -> Self {
Self { api, repo }
}
}
impl ApiRepo {
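    /// Gets the fully qualified URL of a remote `filename`, in the form
    /// `<endpoint>/<repo_id>/resolve/<revision>/<filename>`.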
pub fn url(&self, filename: &str) -> String {
let endpoint = &self.api.endpoint;
let revision = &self.repo.url_revision();
let repo_id = self.repo.url();
format!("{endpoint}/{repo_id}/resolve/{revision}/{filename}")
}
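    /// Downloads `url` into `filename` in parallel chunks, resuming from a
    /// partial file when one is present.
    ///
    /// The temporary file is `length + 8` bytes long: the last 8 bytes are a
    /// little-endian `u64` trailer recording how many contiguous bytes from
    /// the start have been committed, which is what resumption reads back.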
async fn download_tempfile<'a, P: Progress + Clone + Send + Sync + 'static>(
&self,
url: &str,
length: usize,
filename: PathBuf,
mut progressbar: P,
) -> Result<PathBuf, ApiError> {
let semaphore = Arc::new(Semaphore::new(self.api.max_files));
let parallel_failures_semaphore = Arc::new(Semaphore::new(self.api.parallel_failures));
const N_BYTES: usize = size_of::<u64>();
let start = match tokio::fs::OpenOptions::new()
.read(true)
.open(&filename)
.await
{
Ok(mut f) => {
let len = f.metadata().await?.len();
if len == (length + N_BYTES) as u64 {
f.seek(SeekFrom::Start(length as u64)).await?;
let mut buf = [0u8; N_BYTES];
let n = f.read(buf.as_mut_slice()).await?;
if n == N_BYTES {
let committed = u64::from_le_bytes(buf);
committed as usize
} else {
0
}
} else {
0
}
}
Err(_err) => {
tokio::fs::File::create(&filename)
.await?
.set_len((length + N_BYTES) as u64)
.await?;
0
}
};
progressbar.update(start).await;
let chunk_size = self.api.chunk_size.unwrap_or(length);
        // Round up: a trailing partial chunk still needs a slot.
        let n_chunks = length.div_ceil(chunk_size);
let mut handles = Vec::with_capacity(n_chunks);
for start in (start..length).step_by(chunk_size) {
let url = url.to_string();
let filename = filename.clone();
let client = self.api.client.clone();
            // The Range header is inclusive, so clamp to the last valid byte index.
            let stop = std::cmp::min(start + chunk_size, length) - 1;
let permit = semaphore.clone();
let parallel_failures = self.api.parallel_failures;
let max_retries = self.api.max_retries;
let parallel_failures_semaphore = parallel_failures_semaphore.clone();
let progress = progressbar.clone();
handles.push(tokio::spawn(async move {
let permit = permit.acquire_owned().await?;
let mut chunk =
Self::download_chunk(&client, &url, &filename, start, stop, progress.clone())
.await;
let mut i = 0;
if parallel_failures > 0 {
while let Err(dlerr) = chunk {
let parallel_failure_permit =
parallel_failures_semaphore.clone().try_acquire_owned()?;
let wait_time = exponential_backoff(300, i, 10_000);
tokio::time::sleep(tokio::time::Duration::from_millis(wait_time as u64))
.await;
chunk = Self::download_chunk(
&client,
&url,
&filename,
start,
stop,
progress.clone(),
)
.await;
i += 1;
if i > max_retries {
return Err(ApiError::TooManyRetries(dlerr.into()));
}
drop(parallel_failure_permit);
}
}
drop(permit);
chunk
}));
}
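        // As chunks complete (in arbitrary order), merge them into the
        // contiguous `committed` prefix via a min-heap, and persist the new
        // watermark into the trailer so an interrupted download can resume.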
let mut futures: FuturesUnordered<_> = handles.into_iter().collect();
let mut temporaries = BinaryHeap::new();
let mut committed: u64 = start as u64;
while let Some(chunk) = futures.next().await {
let chunk = chunk?;
let (start, stop) = chunk?;
temporaries.push(Reverse((start, stop)));
let mut modified = false;
while let Some(Reverse((min, max))) = temporaries.pop() {
if min as u64 == committed {
committed = max as u64 + 1;
modified = true;
} else {
temporaries.push(Reverse((min, max)));
break;
}
}
if modified {
let mut file = tokio::fs::OpenOptions::new()
.write(true)
.open(&filename)
.await?;
file.seek(SeekFrom::Start(length as u64)).await?;
file.write_all(&committed.to_le_bytes()).await?;
}
}
tokio::fs::OpenOptions::new()
.write(true)
.open(&filename)
.await?
.set_len(length as u64)
.await?;
progressbar.finish().await;
Ok(filename)
}
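    /// Downloads the inclusive byte range `start..=stop` of `url` and writes
    /// it into `filename` at offset `start`.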
async fn download_chunk<P>(
client: &reqwest::Client,
url: &str,
filename: &PathBuf,
start: usize,
stop: usize,
mut progress: P,
) -> Result<(usize, usize), ApiError>
where
P: Progress,
{
let range = format!("bytes={start}-{stop}");
let response = client
.get(url)
.header(RANGE, range)
.send()
.await?
.error_for_status()?;
let mut byte_stream = response.bytes_stream();
let mut buf: Vec<u8> = Vec::with_capacity(stop - start);
while let Some(next) = byte_stream.next().await {
let next = next?;
buf.extend(&next);
progress.update(next.len()).await;
}
let mut file = tokio::fs::OpenOptions::new()
.write(true)
.open(filename)
.await?;
file.seek(SeekFrom::Start(start as u64)).await?;
file.write_all(&buf).await?;
Ok((start, stop))
}
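    /// Returns the local path of `filename`, downloading it first if it is
    /// not already in the cache.
    ///
    /// Example (module path assumed to be `hf_hub::api::tokio`; the repo is
    /// the one used by the tests below):
    ///
    /// ```no_run
    /// # async fn run() -> Result<(), hf_hub::api::tokio::ApiError> {
    /// use hf_hub::api::tokio::Api;
    ///
    /// let api = Api::new()?;
    /// let path = api
    ///     .model("julien-c/dummy-unknown".to_string())
    ///     .get("config.json")
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```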
pub async fn get(&self, filename: &str) -> Result<PathBuf, ApiError> {
if let Some(path) = self.api.cache.repo(self.repo.clone()).get(filename) {
Ok(path)
} else {
self.download(filename).await
}
}
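    /// Downloads `filename` unconditionally, without consulting the cache
    /// first (see [`ApiRepo::get`] for the caching variant), showing a
    /// progress bar when enabled on the [`Api`].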
pub async fn download(&self, filename: &str) -> Result<PathBuf, ApiError> {
if self.api.progress {
self.download_with_progress(filename, ProgressBar::new(0))
.await
} else {
self.download_with_progress(filename, ()).await
}
}
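    /// Same as [`ApiRepo::download`] but with a custom [`Progress`]
    /// implementation. Concurrent callers downloading the same blob are
    /// serialized through a lockfile next to the blob path.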
pub async fn download_with_progress<P: Progress + Clone + Send + Sync + 'static>(
&self,
filename: &str,
mut progress: P,
) -> Result<PathBuf, ApiError> {
let url = self.url(filename);
let metadata = self.api.metadata(&url).await?;
let cache = self.api.cache.repo(self.repo.clone());
let blob_path = cache.blob_path(&metadata.etag);
std::fs::create_dir_all(blob_path.parent().unwrap())?;
let lock = lock_file(blob_path.clone()).await;
progress.init(metadata.size, filename).await;
let mut tmp_path = blob_path.clone();
tmp_path.set_extension(EXTENSION);
let tmp_filename = self
.download_tempfile(&url, metadata.size, tmp_path, progress)
.await?;
tokio::fs::rename(&tmp_filename, &blob_path).await?;
drop(lock);
let mut pointer_path = cache.pointer_path(&metadata.commit_hash);
pointer_path.push(filename);
std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();
symlink_or_rename(&blob_path, &pointer_path)?;
cache.create_ref(&metadata.commit_hash)?;
Ok(pointer_path)
}
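    /// Fetches information about the repo, deserialized as [`RepoInfo`].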
pub async fn info(&self) -> Result<RepoInfo, ApiError> {
Ok(self.info_request().send().await?.json().await?)
}
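    /// Returns the raw [`RequestBuilder`] for the repo info endpoint, so
    /// callers can add query parameters (e.g. `blobs=true`) before sending.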
pub fn info_request(&self) -> RequestBuilder {
let url = format!("{}/api/{}", self.api.endpoint, self.repo.api_url());
self.api.client.get(url)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::api::Siblings;
use crate::assert_no_diff;
use hex_literal::hex;
use rand::distributions::Alphanumeric;
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::io::{Seek, Write};
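    /// Temporary directory that removes itself (and its contents) on drop.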
struct TempDir {
path: PathBuf,
}
impl TempDir {
pub fn new() -> Self {
let s: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(7)
.map(char::from)
.collect();
let mut path = std::env::temp_dir();
path.push(s);
std::fs::create_dir(&path).unwrap();
Self { path }
}
}
impl Drop for TempDir {
fn drop(&mut self) {
std::fs::remove_dir_all(&self.path).unwrap();
}
}
#[tokio::test]
async fn simple() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let model_id = "julien-c/dummy-unknown".to_string();
let repo = Repo::new(model_id.clone(), RepoType::Model);
let downloaded_path = api.model(model_id).download("config.json").await.unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
let cache_path = api.cache.repo(repo.clone()).get("config.json").unwrap();
assert_eq!(cache_path, downloaded_path);
}
#[tokio::test]
async fn locking() {
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
let tmp = Arc::new(Mutex::new(TempDir::new()));
let mut handles = JoinSet::new();
for _ in 0..5 {
let tmp2 = tmp.clone();
handles.spawn(async move {
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp2.lock().await.path.clone())
.build()
.unwrap();
let millis: u64 = rand::random::<u8>().into();
tokio::time::sleep(Duration::from_millis(millis)).await;
let model_id = "julien-c/dummy-unknown".to_string();
api.model(model_id.clone())
.download("config.json")
.await
.unwrap()
});
}
while let Some(handle) = handles.join_next().await {
let downloaded_path = handle.unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
}
}
#[tokio::test]
async fn resume() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let model_id = "julien-c/dummy-unknown".to_string();
let downloaded_path = api
.model(model_id.clone())
.download("config.json")
.await
.unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
let blob = std::fs::canonicalize(&downloaded_path).unwrap();
let file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
let size = file.metadata().unwrap().len();
let truncate: f32 = rand::random();
let new_size = (size as f32 * truncate) as u64;
file.set_len(new_size).unwrap();
let mut blob_part = blob.clone();
blob_part.set_extension("sync.part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
assert_eq!(content.len() as u64, new_size);
let val = Sha256::digest(content);
assert!(
val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
let new_downloaded_path = api
.model(model_id.clone())
.download("config.json")
.await
.unwrap();
let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
assert_eq!(downloaded_path, new_downloaded_path);
assert_eq!(
val[..],
hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
let blob = std::fs::canonicalize(&downloaded_path).unwrap();
let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
let size = file.metadata().unwrap().len();
let truncate: f32 = rand::random();
let new_size = (size as f32 * truncate) as u64;
file.set_len(new_size).unwrap();
let total_size = size + size_of::<u64>() as u64;
file.set_len(total_size).unwrap();
file.seek(SeekFrom::Start(size)).unwrap();
file.write_all(&new_size.to_le_bytes()).unwrap();
let mut blob_part = blob.clone();
blob_part.set_extension("sync.part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
assert_eq!(content.len() as u64, total_size);
let val = Sha256::digest(content);
assert!(
val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
let new_downloaded_path = api
.model(model_id.clone())
.download("config.json")
.await
.unwrap();
let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
assert_eq!(downloaded_path, new_downloaded_path);
assert_eq!(
val[..],
hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
let blob = std::fs::canonicalize(&downloaded_path).unwrap();
let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
let size = file.metadata().unwrap().len();
let truncate: f32 = 0.5;
let new_size = (size as f32 * truncate) as u64;
file.set_len(new_size).unwrap();
let total_size = size + size_of::<u64>() as u64;
file.set_len(total_size).unwrap();
file.seek(SeekFrom::Start(size)).unwrap();
file.write_all(&new_size.to_le_bytes()).unwrap();
file.seek(SeekFrom::Start(new_size - 1)).unwrap();
file.write_all(&[0]).unwrap();
let mut blob_part = blob.clone();
blob_part.set_extension("sync.part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
assert_eq!(content.len() as u64, total_size);
let val = Sha256::digest(content);
assert!(
val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
let new_downloaded_path = api
.model(model_id.clone())
.download("config.json")
.await
.unwrap();
let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
assert_eq!(downloaded_path, new_downloaded_path);
assert_eq!(
val[..],
hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5")
);
}
#[tokio::test]
async fn revision() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let model_id = "BAAI/bge-base-en".to_string();
let repo = Repo::with_revision(model_id.clone(), RepoType::Model, "refs/pr/2".to_string());
let downloaded_path = api
.repo(repo.clone())
.download("tokenizer.json")
.await
.unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66")
);
let cache_path = api.cache.repo(repo).get("tokenizer.json").unwrap();
assert_eq!(cache_path, downloaded_path);
}
#[tokio::test]
async fn dataset() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let repo = Repo::with_revision(
"wikitext".to_string(),
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let downloaded_path = api
.repo(repo)
.download("wikitext-103-v1/test/0000.parquet")
.await
.unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("ABDFC9F83B1103B502924072460D4C92F277C9B49C313CEF3E48CFCF7428E125")
);
}
#[tokio::test]
async fn models() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let repo = Repo::with_revision(
"BAAI/bGe-reRanker-Base".to_string(),
RepoType::Model,
"refs/pr/5".to_string(),
);
let downloaded_path = api.repo(repo).download("tokenizer.json").await.unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("9EB652AC4E40CC093272BBBE0F55D521CF67570060227109B5CDC20945A4489E")
);
}
#[tokio::test]
async fn info() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let repo = Repo::with_revision(
"wikitext".to_string(),
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let model_info = api.repo(repo).info().await.unwrap();
assert_eq!(
model_info,
RepoInfo {
siblings: vec![
Siblings {
rfilename: ".gitattributes".to_string()
},
Siblings {
rfilename: "wikitext-103-raw-v1/test/0000.parquet".to_string()
},
Siblings {
rfilename: "wikitext-103-raw-v1/train/0000.parquet".to_string()
},
Siblings {
rfilename: "wikitext-103-raw-v1/train/0001.parquet".to_string()
},
Siblings {
rfilename: "wikitext-103-raw-v1/validation/0000.parquet".to_string()
},
Siblings {
rfilename: "wikitext-103-v1/test/0000.parquet".to_string()
},
Siblings {
rfilename: "wikitext-103-v1/train/0000.parquet".to_string()
},
Siblings {
rfilename: "wikitext-103-v1/train/0001.parquet".to_string()
},
Siblings {
rfilename: "wikitext-103-v1/validation/0000.parquet".to_string()
},
Siblings {
rfilename: "wikitext-2-raw-v1/test/0000.parquet".to_string()
},
Siblings {
rfilename: "wikitext-2-raw-v1/train/0000.parquet".to_string()
},
Siblings {
rfilename: "wikitext-2-raw-v1/validation/0000.parquet".to_string()
},
Siblings {
rfilename: "wikitext-2-v1/test/0000.parquet".to_string()
},
Siblings {
rfilename: "wikitext-2-v1/train/0000.parquet".to_string()
},
Siblings {
rfilename: "wikitext-2-v1/validation/0000.parquet".to_string()
}
],
sha: "3f68cd45302c7b4b532d933e71d9e6e54b1c7d5e".to_string()
}
);
}
#[tokio::test]
async fn info_request() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_token(None)
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let repo = Repo::with_revision(
"mcpotato/42-eicar-street".to_string(),
RepoType::Model,
"8b3861f6931c4026b0cd22b38dbc09e7668983ac".to_string(),
);
let blobs_info: Value = api
.repo(repo)
.info_request()
.query(&[("blobs", "true")])
.send()
.await
.unwrap()
.json()
.await
.unwrap();
assert_no_diff!(
blobs_info,
json!({
"_id": "621ffdc136468d709f17ddb4",
"author": "mcpotato",
"createdAt": "2022-03-02T23:29:05.000Z",
"disabled": false,
"downloads": 0,
"gated": false,
"id": "mcpotato/42-eicar-street",
"lastModified": "2022-11-30T19:54:16.000Z",
"likes": 1,
"modelId": "mcpotato/42-eicar-street",
"private": false,
"sha": "8b3861f6931c4026b0cd22b38dbc09e7668983ac",
"siblings": [
{
"blobId": "6d34772f5ca361021038b404fb913ec8dc0b1a5a",
"rfilename": ".gitattributes",
"size": 1175
},
{
"blobId": "be98037f7c542112c15a1d2fc7e2a2427e42cb50",
"rfilename": "build_pickles.py",
"size": 304
},
{
"blobId": "8acd02161fff53f9df9597e377e22b04bc34feff",
"rfilename": "danger.dat",
"size": 66
},
{
"blobId": "86b812515e075a1ae216e1239e615a1d9e0b316e",
"rfilename": "eicar_test_file",
"size": 70
},
{
"blobId": "86b812515e075a1ae216e1239e615a1d9e0b316e",
"rfilename": "eicar_test_file_bis",
"size":70
},
{
"blobId": "cd1c6d8bde5006076655711a49feae66f07d707e",
"lfs": {
"pointerSize": 127,
"sha256": "f9343d7d7ec5c3d8bcced056c438fc9f1d3819e9ca3d42418a40857050e10e20",
"size": 22
},
"rfilename": "pytorch_model.bin",
"size": 22
},
{
"blobId": "8ab39654695136173fee29cba0193f679dfbd652",
"rfilename": "supposedly_safe.pkl",
"size": 31
}
],
"spaces": [],
"tags": ["pytorch", "region:us"],
})
);
}
}