use std::{
collections::HashMap,
fmt::Write as _,
io::{ErrorKind, Write as _},
path::PathBuf,
sync::{Arc, RwLock},
};
use anyhow::{Context, Error};
use bytes::Bytes;
use http::{HeaderMap, Method};
use tempfile::NamedTempFile;
use url::Url;
use wasmer_package::{
package::WasmerPackageError,
utils::{from_bytes, from_disk},
};
use webc::DetectError;
use webc::{Container, ContainerError};
use crate::{
bin_factory::BinaryPackage,
http::{HttpClient, HttpRequest, USER_AGENT},
runtime::{
package_loader::PackageLoader,
resolver::{DistributionInfo, PackageSummary, Resolution, WebcHash},
},
};
/// A [`PackageLoader`] that obtains packages through an [`HttpClient`] (or
/// by reading `file://` URLs) and keeps loaded containers in an in-memory
/// cache, optionally backed by a directory on disk.
#[derive(Debug)]
pub struct BuiltinPackageLoader {
/// Client used to issue package download requests.
client: Arc<dyn HttpClient + Send + Sync>,
/// Containers that have already been loaded, keyed by their hash.
in_memory: InMemoryCache,
/// On-disk cache; `None` unless a cache directory was configured.
cache: Option<FileSystemCache>,
/// Bearer tokens, looked up by the authority component of the download URL.
tokens: HashMap<String, String>,
/// How the hash of a freshly obtained image is checked.
hash_validation: HashIntegrityValidationMode,
}
/// Controls what happens when the SHA-256 of an obtained image is compared
/// against the hash recorded in its [`DistributionInfo`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HashIntegrityValidationMode {
/// Skip the check entirely; the image is not hashed.
NoValidate,
/// Hash the image and log a warning on mismatch, but still accept it.
WarnOnHashMismatch,
/// Hash the image and fail with an [`ImageHashMismatchError`] on mismatch.
FailOnHashMismatch,
}
impl BuiltinPackageLoader {
/// Create a loader with an empty in-memory cache, no filesystem cache, no
/// tokens, and hash validation disabled
/// ([`HashIntegrityValidationMode::NoValidate`]).
///
/// # Panics
///
/// Panics if `crate::http::default_http_client()` does not provide a client.
/// Use [`BuiltinPackageLoader::with_http_client`] afterwards to swap in a
/// different client.
pub fn new() -> Self {
    // Spell out the broken invariant instead of a bare `unwrap()` so the
    // panic message tells the user what is missing.
    let client = crate::http::default_http_client()
        .expect("a default HTTP client must be available to create a BuiltinPackageLoader");

    BuiltinPackageLoader {
        in_memory: InMemoryCache::default(),
        client: Arc::new(client),
        cache: None,
        hash_validation: HashIntegrityValidationMode::NoValidate,
        tokens: HashMap::new(),
    }
}
pub fn with_hash_validation_mode(mut self, mode: HashIntegrityValidationMode) -> Self {
self.hash_validation = mode;
self
}
/// Builder method: enable the filesystem cache, rooted at `cache_dir`.
///
/// Any previously configured cache directory is replaced.
pub fn with_cache_dir(mut self, cache_dir: impl Into<PathBuf>) -> Self {
    let cache_dir = cache_dir.into();
    self.cache = Some(FileSystemCache { cache_dir });
    self
}
/// Check every image in the filesystem cache against the hash encoded in
/// its file name.
///
/// Each mismatch is logged and, with [`CacheValidationMode::PruneOnMismatch`],
/// the offending file is deleted (a failed deletion is logged, not returned).
///
/// # Errors
///
/// Fails when no filesystem cache is configured or when scanning the cache
/// directory fails.
///
/// # Returns
///
/// One [`ImageHashMismatchError`] per mismatching file, regardless of `mode`.
pub fn validate_cache(
    &self,
    mode: CacheValidationMode,
) -> Result<Vec<ImageHashMismatchError>, anyhow::Error> {
    let cache = self
        .cache
        .as_ref()
        .context("can not validate cache - no cache configured")?;

    let mismatches = cache.validate_hashes()?;
    let mut errors = Vec::with_capacity(mismatches.len());

    for (path, error) in mismatches {
        match mode {
            CacheValidationMode::PruneOnMismatch => {
                tracing::warn!(?error, "deleting cached image file due to hash mismatch");
                match std::fs::remove_file(&path) {
                    Err(fs_err) if fs_err.kind() != std::io::ErrorKind::NotFound => {
                        tracing::error!(
                            path=%error.source,
                            ?fs_err,
                            "could not delete cached image file with hash mismatch"
                        );
                    }
                    // Either the file was removed, or it was already gone.
                    _ => {}
                }
            }
            CacheValidationMode::WarnOnMismatch => {
                tracing::warn!(?error, "hash mismatch in cached image file");
            }
        }
        errors.push(error);
    }

    Ok(errors)
}
/// Builder method: use `client` for downloads, wrapping it in an [`Arc`].
pub fn with_http_client(self, client: impl HttpClient + Send + Sync + 'static) -> Self {
    let shared: Arc<dyn HttpClient + Send + Sync> = Arc::new(client);
    self.with_shared_http_client(shared)
}
/// Builder method: use an already shared HTTP client for downloads.
pub fn with_shared_http_client(mut self, client: Arc<dyn HttpClient + Send + Sync>) -> Self {
    self.client = client;
    self
}
/// Builder method: register several `(hostname, token)` pairs at once.
///
/// Pairs are inserted in iteration order, so a later entry for the same
/// hostname overwrites an earlier one.
pub fn with_tokens<I, K, V>(mut self, tokens: I) -> Self
where
    I: IntoIterator<Item = (K, V)>,
    K: Into<String>,
    V: Into<String>,
{
    self.tokens.extend(
        tokens
            .into_iter()
            .map(|(hostname, token)| (hostname.into(), token.into())),
    );
    self
}
pub fn with_token(mut self, hostname: impl Into<String>, token: impl Into<String>) -> Self {
self.tokens.insert(hostname.into(), token.into());
self
}
pub fn insert_cached(&self, hash: WebcHash, container: &Container) {
self.in_memory.save(container, hash);
}
/// Look `hash` up in the in-memory cache first and then, if configured, in
/// the filesystem cache. A filesystem hit is copied into the in-memory cache
/// before being returned.
#[tracing::instrument(level = "debug", skip_all, fields(pkg.hash=%hash))]
async fn get_cached(&self, hash: &WebcHash) -> Result<Option<Container>, Error> {
    if let Some(hit) = self.in_memory.lookup(hash) {
        return Ok(Some(hit));
    }

    let Some(fs_cache) = self.cache.as_ref() else {
        return Ok(None);
    };

    let on_disk = fs_cache.lookup(hash).await?;
    if let Some(container) = &on_disk {
        tracing::debug!("Copying from the filesystem cache to the in-memory cache");
        self.in_memory.save(container, *hash);
    }

    Ok(on_disk)
}
/// Async wrapper around [`Self::validate_hash_sync`]: hands the check to
/// `crate::spawn_blocking` with owned copies of the inputs.
async fn validate_hash(
    image: &bytes::Bytes,
    mode: HashIntegrityValidationMode,
    info: &DistributionInfo,
) -> Result<(), anyhow::Error> {
    // The closure must own its inputs, hence the clones.
    let owned_info = info.clone();
    let owned_image = image.clone();

    let task =
        crate::spawn_blocking(move || Self::validate_hash_sync(&owned_image, mode, &owned_info));

    // The outer `?` handles a failed task; the inner result is the verdict.
    task.await.context("tokio runtime failed")?
}
/// Compare the SHA-256 of `image` with `info.webc_sha256` according to `mode`.
///
/// # Errors
///
/// Returns an [`ImageHashMismatchError`] only when `mode` is
/// [`HashIntegrityValidationMode::FailOnHashMismatch`] and the hashes differ.
fn validate_hash_sync(
    image: &[u8],
    mode: HashIntegrityValidationMode,
    info: &DistributionInfo,
) -> Result<(), anyhow::Error> {
    // Decide up front whether a mismatch is fatal; `NoValidate` skips hashing.
    let fail_on_mismatch = match mode {
        HashIntegrityValidationMode::NoValidate => return Ok(()),
        HashIntegrityValidationMode::WarnOnHashMismatch => false,
        HashIntegrityValidationMode::FailOnHashMismatch => true,
    };

    let actual_hash = WebcHash::sha256(image);
    if actual_hash == info.webc_sha256 {
        return Ok(());
    }

    if fail_on_mismatch {
        let mismatch = ImageHashMismatchError {
            source: info.webc.to_string(),
            actual_hash,
            expected_hash: info.webc_sha256,
        };
        return Err(mismatch.into());
    }

    tracing::warn!(%info.webc_sha256, %actual_hash, "image hash mismatch - actual image hash does not match the expected hash!");
    Ok(())
}
/// Fetch the raw image bytes described by `dist` and validate their hash.
///
/// `file://` URLs are read straight from disk. If such a URL cannot be
/// converted to a path, this is logged and the HTTP path below is tried.
/// Every other URL is fetched with a GET request through the configured
/// client.
///
/// # Errors
///
/// Fails when the file cannot be read, the request fails, the response has
/// a non-OK status or no body, or hash validation rejects the image.
#[tracing::instrument(level = "debug", skip_all, fields(%dist.webc, %dist.webc_sha256))]
async fn download(&self, dist: &DistributionInfo) -> Result<Bytes, Error> {
    let url = &dist.webc;

    if url.scheme() == "file" {
        match crate::runtime::resolver::utils::file_path_from_url(&dist.webc) {
            Err(e) => {
                tracing::debug!(
                    url=%dist.webc,
                    error=&*e,
                    "Unable to convert the file:// URL to a path",
                );
            }
            Ok(path) => {
                let path_for_read = path.clone();
                let contents = crate::spawn_blocking(move || std::fs::read(path_for_read))
                    .await?
                    .with_context(|| format!("Unable to read \"{}\"", path.display()))?;
                let image = bytes::Bytes::from(contents);
                Self::validate_hash(&image, self.hash_validation, dist).await?;
                return Ok(image);
            }
        }
    }

    let request = HttpRequest {
        headers: self.headers(&dist.webc),
        url: dist.webc.clone(),
        method: Method::GET,
        body: None,
        options: Default::default(),
    };

    tracing::debug!(%request.url, %request.method, "webc_package_download_start");
    tracing::trace!(?request.headers);

    let response = self.client.request(request).await?;

    tracing::trace!(
        %response.status,
        %response.redirected,
        ?response.headers,
        response.len=response.body.as_ref().map(|body| body.len()),
        "Received a response",
    );

    if !response.is_ok() {
        let cause = crate::runtime::resolver::utils::http_error(&response);
        let message = format!(
            "package download failed: GET request to \"{}\" failed with status {}",
            url, response.status
        );
        return Err(cause.context(message));
    }

    let raw_body = response.body.context("package download failed")?;
    tracing::debug!(%url, "package_download_succeeded");

    let image = bytes::Bytes::from(raw_body);
    Self::validate_hash(&image, self.hash_validation, dist).await?;
    Ok(image)
}
/// Build the request headers for downloading from `url`.
///
/// `Accept` and `User-Agent` are always set. An `Authorization: Bearer …`
/// header is added when a token is registered for the URL's authority; a
/// token that is not a valid header value is logged and skipped.
fn headers(&self, url: &Url) -> HeaderMap {
    let mut headers = HeaderMap::new();
    headers.insert("Accept", "application/webc".parse().unwrap());
    headers.insert("User-Agent", USER_AGENT.parse().unwrap());

    let registered_token = if url.has_authority() {
        self.tokens.get(url.authority())
    } else {
        None
    };
    let Some(token) = registered_token else {
        return headers;
    };

    let bearer = format!("Bearer {token}");
    match bearer.parse() {
        Ok(value) => {
            headers.insert(http::header::AUTHORIZATION, value);
        }
        Err(e) => {
            tracing::warn!(
                error = &e as &dyn std::error::Error,
                "An error occurred while parsing the authorization header",
            );
        }
    }

    headers
}
}
impl Default for BuiltinPackageLoader {
fn default() -> Self {
BuiltinPackageLoader::new()
}
}
#[async_trait::async_trait]
impl PackageLoader for BuiltinPackageLoader {
/// Load the container described by `summary`.
///
/// Order of attempts: cached copy (memory, then disk) → download. A fresh
/// download is written to the filesystem cache and loaded back from there
/// when a cache is configured. If that fails, the failure is logged and the
/// container is built from the downloaded bytes instead. The result always
/// ends up in the in-memory cache.
#[tracing::instrument(level = "debug", skip_all, fields(pkg = %summary.pkg.id))]
async fn load(&self, summary: &PackageSummary) -> Result<Container, Error> {
    let dist = &summary.dist;

    let cached = self.get_cached(&dist.webc_sha256).await?;
    if let Some(container) = cached {
        tracing::debug!("Cache hit!");
        return Ok(container);
    }

    let bytes = self
        .download(dist)
        .await
        .with_context(|| format!("Unable to download \"{}\"", dist.webc))?;

    if let Some(cache) = &self.cache {
        let persisted = cache.save_and_load_as_mmapped(bytes.clone(), dist).await;
        match persisted {
            Ok(container) => {
                tracing::debug!("Cached to disk");
                self.in_memory.save(&container, dist.webc_sha256);
                return Ok(container);
            }
            Err(e) => {
                // Not fatal: fall through and build the container in memory.
                tracing::warn!(
                    error=&*e,
                    pkg=%summary.pkg.id,
                    pkg.hash=%summary.dist.webc_sha256,
                    pkg.url=%summary.dist.webc,
                    "Unable to save the downloaded package to disk",
                );
            }
        }
    }

    let container = crate::spawn_blocking(move || from_bytes(bytes)).await??;
    self.in_memory.save(&container, dist.webc_sha256);
    Ok(container)
}
/// Delegate to the parent module's `load_package_tree`, using this loader
/// to fetch packages.
async fn load_package_tree(
    &self,
    root: &Container,
    resolution: &Resolution,
    root_is_local_dir: bool,
) -> Result<BinaryPackage, Error> {
    let package = super::load_package_tree(root, self, resolution, root_is_local_dir).await?;
    Ok(package)
}
}
/// Error describing an image whose computed hash differs from the expected one.
#[derive(Clone, Debug)]
pub struct ImageHashMismatchError {
/// Where the image came from: the download URL or the cached file's path,
/// rendered as a string.
source: String,
/// The hash the image was supposed to have.
expected_hash: WebcHash,
/// The hash that was actually computed from the image contents.
actual_hash: WebcHash,
}
/// Human-readable rendering: expected hash, computed hash, then the source.
impl std::fmt::Display for ImageHashMismatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            source,
            expected_hash,
            actual_hash,
        } = self;
        write!(
            f,
            "image hash mismatch! expected hash '{}', but the computed hash is '{}' (source '{}')",
            expected_hash, actual_hash, source,
        )
    }
}
// Uses the default trait methods: the `source` field is a descriptive string,
// not an underlying error, so `Error::source()` is not overridden.
impl std::error::Error for ImageHashMismatchError {}
/// What [`BuiltinPackageLoader::validate_cache`] does with a cached image
/// whose contents do not match its expected hash.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CacheValidationMode {
/// Log a warning and leave the file in place.
WarnOnMismatch,
/// Log a warning and delete the file.
PruneOnMismatch,
}
/// On-disk cache that stores each image as `<hex hash><FILE_SUFFIX>` directly
/// inside `cache_dir`.
#[derive(Debug)]
struct FileSystemCache {
/// Directory holding the cached image files.
cache_dir: PathBuf,
}
impl FileSystemCache {
/// Suffix appended to the hex-encoded hash to form a cached file's name.
const FILE_SUFFIX: &'static str = ".bin";
/// Re-hash every cached image and report the ones whose contents do not
/// match the hash encoded in their file name.
///
/// Only regular files named exactly `<hex hash><FILE_SUFFIX>` are examined;
/// everything else in the directory is ignored. A missing cache directory
/// counts as an empty cache.
///
/// # Errors
///
/// Fails if the directory cannot be read, or an entry cannot be inspected or
/// hashed.
///
/// # Returns
///
/// `(path, error)` for each mismatching file.
fn validate_hashes(&self) -> Result<Vec<(PathBuf, ImageHashMismatchError)>, anyhow::Error> {
    let mut items = Vec::<(PathBuf, ImageHashMismatchError)>::new();

    let iter = match std::fs::read_dir(&self.cache_dir) {
        Ok(v) => v,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            // Nothing has been cached yet.
            return Ok(Vec::new());
        }
        Err(err) => {
            return Err(err).with_context(|| {
                format!(
                    "Could not read image cache dir: '{}'",
                    self.cache_dir.display()
                )
            });
        }
    };

    for res in iter {
        let entry = res?;
        if !entry.file_type()?.is_file() {
            continue;
        }

        // Require the suffix at the very END of the name. Splitting at the
        // first occurrence would also accept names like `<hash>.bin.bak`,
        // which are not cache files and must not be reported (or pruned).
        let file_name = entry.file_name();
        let Some(expected_hash) = file_name
            .to_str()
            .and_then(|name| name.strip_suffix(Self::FILE_SUFFIX))
            .and_then(|raw_hash| WebcHash::parse_hex(raw_hash).ok())
        else {
            continue;
        };

        let path = entry.path();
        let actual_hash = WebcHash::for_file(&path)?;
        if actual_hash != expected_hash {
            let err = ImageHashMismatchError {
                source: path.to_string_lossy().to_string(),
                actual_hash,
                expected_hash,
            };
            items.push((path, err));
        }
    }

    Ok(items)
}
async fn lookup(&self, hash: &WebcHash) -> Result<Option<Container>, Error> {
let path = self.path(hash);
let container = crate::spawn_blocking({
let path = path.clone();
move || from_disk(path)
})
.await?;
match container {
Ok(c) => Ok(Some(c)),
Err(WasmerPackageError::ContainerError(ContainerError::Open { error, .. }))
| Err(WasmerPackageError::ContainerError(ContainerError::Read { error, .. }))
| Err(WasmerPackageError::ContainerError(ContainerError::Detect(DetectError::Io(
error,
)))) if error.kind() == ErrorKind::NotFound => Ok(None),
Err(e) => {
let msg = format!("Unable to read \"{}\"", path.display());
Err(Error::new(e).context(msg))
}
}
}
async fn save(&self, webc: Bytes, dist: &DistributionInfo) -> Result<(), Error> {
let path = self.path(&dist.webc_sha256);
let dist = dist.clone();
crate::spawn_blocking(move || {
let parent = path.parent().expect("Always within cache_dir");
std::fs::create_dir_all(parent)
.with_context(|| format!("Unable to create \"{}\"", parent.display()))?;
let mut temp = NamedTempFile::new_in(parent)?;
temp.write_all(&webc)?;
temp.flush()?;
temp.as_file_mut().sync_all()?;
temp.persist(&path)?;
tracing::debug!(
pkg.hash=%dist.webc_sha256,
pkg.url=%dist.webc,
path=%path.display(),
num_bytes=webc.len(),
"Saved to disk",
);
Result::<_, Error>::Ok(())
})
.await??;
Ok(())
}
#[tracing::instrument(level = "debug", skip_all)]
async fn save_and_load_as_mmapped(
&self,
webc: Bytes,
dist: &DistributionInfo,
) -> Result<Container, Error> {
self.save(webc, dist).await?;
match self.lookup(&dist.webc_sha256).await? {
Some(container) => Ok(container),
None => {
Err(Error::msg("Unable to load the downloaded memory from disk"))
}
}
}
/// Location of the cache file for `hash`: `cache_dir/<lowercase hex of the
/// hash bytes><FILE_SUFFIX>`.
fn path(&self, hash: &WebcHash) -> PathBuf {
    let hash = hash.as_bytes();

    // Two hex digits per byte plus the suffix, so the string never has to
    // grow while it is being built.
    let mut filename = String::with_capacity(hash.len() * 2 + Self::FILE_SUFFIX.len());
    for b in hash {
        write!(filename, "{b:02x}").expect("writing to a String cannot fail");
    }
    filename.push_str(Self::FILE_SUFFIX);

    self.cache_dir.join(filename)
}
}
/// In-memory map from image hash to loaded [`Container`], guarded by an
/// [`RwLock`].
#[derive(Debug, Default)]
struct InMemoryCache(RwLock<HashMap<WebcHash, Container>>);
impl InMemoryCache {
    /// Return a clone of the container stored under `hash`, if any.
    ///
    /// Panics if the lock is poisoned.
    fn lookup(&self, hash: &WebcHash) -> Option<Container> {
        let entries = self.0.read().unwrap();
        entries.get(hash).cloned()
    }

    /// Store a clone of `container` under `hash` unless an entry already
    /// exists; existing entries are never overwritten.
    ///
    /// Panics if the lock is poisoned.
    fn save(&self, container: &Container, hash: WebcHash) {
        let mut entries = self.0.write().unwrap();
        if let std::collections::hash_map::Entry::Vacant(slot) = entries.entry(hash) {
            slot.insert(container.clone());
        }
    }
}
#[cfg(test)]
mod tests {
use std::{collections::VecDeque, sync::Mutex};
use futures::future::BoxFuture;
use http::{HeaderMap, StatusCode};
use tempfile::TempDir;
use wasmer_config::package::PackageId;
use crate::{
http::{HttpRequest, HttpResponse},
runtime::resolver::PackageInfo,
};
use super::*;
/// Raw bytes of the `python-0.1.0.wasmer` fixture, embedded at compile time.
const PYTHON: &[u8] = include_bytes!("../../../../c-api/examples/assets/python-0.1.0.wasmer");
/// Test double for [`HttpClient`]: replays canned responses in order and
/// records every request it receives.
#[derive(Debug)]
pub(crate) struct DummyClient {
/// Requests received so far, in arrival order.
requests: Mutex<Vec<HttpRequest>>,
/// Responses still to be handed out, front first.
responses: Mutex<VecDeque<HttpResponse>>,
}
impl DummyClient {
pub fn with_responses(responses: impl IntoIterator<Item = HttpResponse>) -> Self {
DummyClient {
requests: Mutex::new(Vec::new()),
responses: Mutex::new(responses.into_iter().collect()),
}
}
}
impl HttpClient for DummyClient {
    /// Record `request` and answer with the next queued response.
    ///
    /// Panics if no responses are left (the request is then not recorded).
    fn request(
        &self,
        request: HttpRequest,
    ) -> BoxFuture<'_, Result<HttpResponse, anyhow::Error>> {
        let next_response = {
            let mut queue = self.responses.lock().unwrap();
            queue.pop_front().unwrap()
        };
        self.requests.lock().unwrap().push(request);
        Box::pin(async move { Ok(next_response) })
    }
}
/// Shared test body: with empty caches, `load()` must issue exactly the
/// expected GET request, return the parsed container, write the image to the
/// filesystem cache, and populate the in-memory cache.
async fn cache_misses_will_trigger_a_download_internal() {
    // Arrange: an empty cache directory and a client with one canned reply.
    let cache_dir = TempDir::new().unwrap();
    let canned_response = HttpResponse {
        body: Some(PYTHON.to_vec()),
        redirected: false,
        status: StatusCode::OK,
        headers: HeaderMap::new(),
    };
    let client = Arc::new(DummyClient::with_responses([canned_response]));
    let loader = BuiltinPackageLoader::new()
        .with_cache_dir(cache_dir.path())
        .with_shared_http_client(client.clone());
    let pkg = PackageInfo {
        id: PackageId::new_named("python/python", "0.1.0".parse().unwrap()),
        dependencies: Vec::new(),
        commands: Vec::new(),
        entrypoint: Some("asdf".to_string()),
        filesystem: Vec::new(),
    };
    let dist = DistributionInfo {
        webc: "https://wasmer.io/python/python".parse().unwrap(),
        webc_sha256: [0xaa; 32].into(),
    };
    let summary = PackageSummary { pkg, dist };

    // Act.
    let container = loader.load(&summary).await.unwrap();

    // Assert: the request that went out.
    let recorded = client.requests.lock().unwrap();
    let sent = &recorded[0];
    assert_eq!(sent.url, summary.dist.webc);
    assert_eq!(sent.method, "GET");
    assert_eq!(sent.headers.len(), 2);
    assert_eq!(sent.headers["Accept"], "application/webc");
    assert_eq!(sent.headers["User-Agent"], USER_AGENT);

    // Assert: the container was parsed from the downloaded bytes.
    let manifest = container.manifest();
    assert_eq!(manifest.entrypoint.as_deref(), Some("python"));

    // Assert: the image landed in the filesystem cache unchanged.
    let cached_file = loader
        .cache
        .as_ref()
        .unwrap()
        .path(&summary.dist.webc_sha256);
    assert!(cached_file.exists());
    assert_eq!(std::fs::read(&cached_file).unwrap(), PYTHON);

    // Assert: the in-memory cache was populated too.
    let in_memory = loader.in_memory.0.read().unwrap();
    assert!(in_memory.contains_key(&summary.dist.webc_sha256));
}
// Non-wasm32 targets: run the shared test body on tokio's multi-threaded
// runtime flavor.
#[cfg(not(target_arch = "wasm32"))]
#[tokio::test(flavor = "multi_thread")]
async fn cache_misses_will_trigger_a_download() {
cache_misses_will_trigger_a_download_internal().await
}
// wasm32: run the same test body with the default `#[tokio::test]` runtime.
#[cfg(target_arch = "wasm32")]
#[tokio::test()]
async fn cache_misses_will_trigger_a_download() {
cache_misses_will_trigger_a_download_internal().await
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// A cached file whose contents do not hash to the value in its file
    /// name must be reported by `validate_cache` and, in prune mode, removed.
    #[tokio::test]
    async fn test_builtin_package_downloader_cache_validation() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path();

        // Store `contents` under a file name that claims a different hash.
        let contents = "fail";
        let correct_hash = WebcHash::sha256(contents);
        let used_hash =
            WebcHash::parse_hex("0000a28ea38a000f3a3328cb7fabe330638d3258affe1a869e3f92986222d997")
                .unwrap();
        let filename = format!("{}{}", used_hash, FileSystemCache::FILE_SUFFIX);
        let file_path = path.join(filename);
        std::fs::write(&file_path, contents).unwrap();

        let dl = BuiltinPackageLoader::new().with_cache_dir(path);
        let errors = dl
            .validate_cache(CacheValidationMode::PruneOnMismatch)
            .unwrap();

        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].actual_hash, correct_hash);
        assert_eq!(errors[0].expected_hash, used_hash);
        // Prune mode must have deleted the mismatching file.
        assert!(!file_path.exists());
    }
}