use super::ServeDir;
use crate::dep::{mime::Mime, mime_guess};
use crate::{HeaderValue, Request, Response};
use rama_core::{Context, Service};
use std::path::Path;
/// Service that serves one specific file.
///
/// This is a newtype around [`ServeDir`]: both constructors build the inner
/// value with `ServeDir::new_single_file`, and every other method forwards to
/// the wrapped `ServeDir`.
#[derive(Clone, Debug)]
pub struct ServeFile(ServeDir);
impl ServeFile {
pub fn new<P: AsRef<Path>>(path: P) -> Self {
let guess = mime_guess::from_path(path.as_ref());
let mime = guess
.first_raw()
.map(HeaderValue::from_static)
.unwrap_or_else(|| {
HeaderValue::from_str(mime::APPLICATION_OCTET_STREAM.as_ref()).unwrap()
});
ServeFile(ServeDir::new_single_file(path, mime))
}
pub fn new_with_mime<P: AsRef<Path>>(path: P, mime: &Mime) -> Self {
let mime = HeaderValue::from_str(mime.as_ref()).expect("mime isn't a valid header value");
ServeFile(ServeDir::new_single_file(path, mime))
}
pub fn precompressed_gzip(self) -> Self {
Self(self.0.precompressed_gzip())
}
pub fn precompressed_br(self) -> Self {
Self(self.0.precompressed_br())
}
pub fn precompressed_deflate(self) -> Self {
Self(self.0.precompressed_deflate())
}
pub fn precompressed_zstd(self) -> Self {
Self(self.0.precompressed_zstd())
}
pub fn with_buf_chunk_size(self, chunk_size: usize) -> Self {
Self(self.0.with_buf_chunk_size(chunk_size))
}
#[inline]
pub async fn try_call<State, ReqBody>(
&self,
ctx: Context<State>,
req: Request<ReqBody>,
) -> Result<Response, std::io::Error>
where
State: Clone + Send + Sync + 'static,
ReqBody: Send + 'static,
{
self.0.try_call(ctx, req).await
}
}
/// `ServeFile` is a `Service` exactly where the wrapped `ServeDir` is one: the
/// response and error types are taken from `ServeDir`'s implementation and
/// every request is handed straight to it.
impl<State, ReqBody> Service<State, Request<ReqBody>> for ServeFile
where
    State: Clone + Send + Sync + 'static,
    ReqBody: Send + 'static,
{
    type Response = <ServeDir as Service<State, Request<ReqBody>>>::Response;
    type Error = <ServeDir as Service<State, Request<ReqBody>>>::Error;

    /// Serves `req` by delegating to the inner `ServeDir`.
    #[inline]
    async fn serve(
        &self,
        ctx: Context<State>,
        req: Request<ReqBody>,
    ) -> Result<Self::Response, Self::Error> {
        let Self(dir) = self;
        <ServeDir as Service<State, Request<ReqBody>>>::serve(dir, ctx, req).await
    }
}
// Compiled only for test builds that also enable the `compression` feature.
#[cfg(all(test, feature = "compression"))]
mod compression_tests {
    use super::*;
    use crate::Body;

    /// A client accepting `zstd` gets the zstd-encoded variant, and decoding it
    /// yields the expected file content.
    #[tokio::test]
    async fn precompressed_zstd() {
        use async_compression::tokio::bufread::ZstdDecoder;
        use http_body_util::BodyExt;
        use tokio::io::AsyncReadExt;

        let req = Request::builder()
            .header("Accept-Encoding", "zstd,br")
            .body(Body::empty())
            .unwrap();
        let service = ServeFile::new("../test-files/precompressed.txt").precompressed_zstd();
        let response = service.serve(Context::default(), req).await.unwrap();

        let headers = response.headers();
        assert_eq!(headers["content-type"], "text/plain");
        assert_eq!(headers["content-encoding"], "zstd");

        // Decode the payload to verify it really is the zstd variant of the file.
        let payload = response.into_body().collect().await.unwrap().to_bytes();
        let mut zstd_reader = ZstdDecoder::new(&payload[..]);
        let mut text = String::new();
        zstd_reader.read_to_string(&mut text).await.unwrap();
        assert!(text.starts_with("\"This is a test file!\""));
    }
}
#[cfg(test)]
mod tests {
    use crate::dep::http_body_util::BodyExt;
    use crate::dep::mime::Mime;
    use crate::header;
    use crate::service::fs::ServeFile;
    use crate::Body;
    use crate::Method;
    use crate::{Request, StatusCode};
    use brotli::BrotliDecompress;
    use flate2::bufread::DeflateDecoder;
    use flate2::bufread::GzDecoder;
    use rama_core::{Context, Service};
    use std::io::Read;
    use std::str::FromStr;

    /// Serving the README yields a guessed `text/markdown` content type and the
    /// file's content as the body.
    #[tokio::test]
    async fn basic() {
        let svc = ServeFile::new("../README.md");
        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();
        assert_eq!(res.headers()["content-type"], "text/markdown");
        let body = res.into_body().collect().await.unwrap().to_bytes();
        let body = String::from_utf8(body.to_vec()).unwrap();
        assert!(body.starts_with("[![rama banner]"));
    }

    /// An explicitly supplied mime overrides the guessed content type while the
    /// body stays the file's content.
    #[tokio::test]
    async fn basic_with_mime() {
        let svc = ServeFile::new_with_mime("../README.md", &Mime::from_str("image/jpg").unwrap());
        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();
        assert_eq!(res.headers()["content-type"], "image/jpg");
        let body = res.into_body().collect().await.unwrap().to_bytes();
        let body = String::from_utf8(body.to_vec()).unwrap();
        assert!(body.starts_with("[![rama banner]"));
    }

    /// A `HEAD` request reports content type and length but carries no body.
    #[tokio::test]
    async fn head_request() {
        let svc = ServeFile::new("../test-files/precompressed.txt");
        let mut request = Request::new(Body::empty());
        *request.method_mut() = Method::HEAD;
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/plain");
        // NOTE(review): the fixture is one byte longer on Windows, presumably
        // because of CRLF line endings on checkout; confirm.
        #[cfg(target_os = "windows")]
        assert_eq!(res.headers()["content-length"], "24");
        #[cfg(not(target_os = "windows"))]
        assert_eq!(res.headers()["content-length"], "23");
        assert!(res.into_body().frame().await.is_none());
    }

    /// A `HEAD` request that accepts gzip reports the gzip variant's encoding
    /// and length, still without a body.
    #[tokio::test]
    async fn precompressed_head_request() {
        let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_gzip();
        let request = Request::builder()
            .header("Accept-Encoding", "gzip")
            .method(Method::HEAD)
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/plain");
        assert_eq!(res.headers()["content-encoding"], "gzip");
        assert_eq!(res.headers()["content-length"], "59");
        assert!(res.into_body().frame().await.is_none());
    }

    /// A client accepting gzip receives the gzip variant, which decodes to the
    /// expected content.
    #[tokio::test]
    async fn precompressed_gzip() {
        let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_gzip();
        let request = Request::builder()
            .header("Accept-Encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/plain");
        assert_eq!(res.headers()["content-encoding"], "gzip");
        let body = res.into_body().collect().await.unwrap().to_bytes();
        let mut decoder = GzDecoder::new(&body[..]);
        let mut decompressed = String::new();
        decoder.read_to_string(&mut decompressed).unwrap();
        assert!(decompressed.starts_with("\"This is a test file!\""));
    }

    /// When the client only accepts an encoding the service was not configured
    /// for, the plain file is served without a `content-encoding` header.
    #[tokio::test]
    async fn unsupported_precompression_algorithm_fallbacks_to_uncompressed() {
        let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_gzip();
        let request = Request::builder()
            .header("Accept-Encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/plain");
        assert!(res.headers().get("content-encoding").is_none());
        let body = res.into_body().collect().await.unwrap().to_bytes();
        let body = String::from_utf8(body.to_vec()).unwrap();
        assert!(body.starts_with("\"This is a test file!\""));
    }

    /// When the gzip variant is missing on disk, the plain file is served
    /// without a `content-encoding` header.
    #[tokio::test]
    async fn missing_precompressed_variant_fallbacks_to_uncompressed() {
        let svc = ServeFile::new("../test-files/missing_precompressed.txt").precompressed_gzip();
        let request = Request::builder()
            .header("Accept-Encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/plain");
        assert!(res.headers().get("content-encoding").is_none());
        let body = res.into_body().collect().await.unwrap().to_bytes();
        let body = String::from_utf8(body.to_vec()).unwrap();
        assert!(body.starts_with("Test file!"));
    }

    /// Same fallback as above for a `HEAD` request: the plain file's length is
    /// reported, with no encoding header and no body.
    #[tokio::test]
    async fn missing_precompressed_variant_fallbacks_to_uncompressed_head_request() {
        let svc = ServeFile::new("../test-files/missing_precompressed.txt").precompressed_gzip();
        let request = Request::builder()
            .header("Accept-Encoding", "gzip")
            .method(Method::HEAD)
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/plain");
        // NOTE(review): one extra byte on Windows, presumably from CRLF line
        // endings on checkout; confirm.
        #[cfg(target_os = "windows")]
        assert_eq!(res.headers()["content-length"], "12");
        #[cfg(not(target_os = "windows"))]
        assert_eq!(res.headers()["content-length"], "11");
        assert!(res.headers().get("content-encoding").is_none());
        assert!(res.into_body().frame().await.is_none());
    }

    /// If only the gzip variant exists, a request without `Accept-Encoding`
    /// gets a 404 while a gzip-accepting request gets the compressed file.
    #[tokio::test]
    async fn only_precompressed_variant_existing() {
        let svc = ServeFile::new("../test-files/only_gzipped.txt").precompressed_gzip();

        // `serve` is called through `&self`, so the service is reused without cloning.
        let request = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.status(), StatusCode::NOT_FOUND);

        let request = Request::builder()
            .header("Accept-Encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/plain");
        assert_eq!(res.headers()["content-encoding"], "gzip");
        let body = res.into_body().collect().await.unwrap().to_bytes();
        let mut decoder = GzDecoder::new(&body[..]);
        let mut decompressed = String::new();
        decoder.read_to_string(&mut decompressed).unwrap();
        assert!(decompressed.starts_with("\"This is a test file\""));
    }

    /// A client accepting `gzip,br` gets the brotli variant when only brotli is
    /// configured on the service.
    #[tokio::test]
    async fn precompressed_br() {
        let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_br();
        let request = Request::builder()
            .header("Accept-Encoding", "gzip,br")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/plain");
        assert_eq!(res.headers()["content-encoding"], "br");
        let body = res.into_body().collect().await.unwrap().to_bytes();
        let mut decompressed = Vec::new();
        BrotliDecompress(&mut &body[..], &mut decompressed).unwrap();
        // The buffer is already an owned `Vec<u8>`; hand it over without copying.
        let decompressed = String::from_utf8(decompressed).unwrap();
        assert!(decompressed.starts_with("\"This is a test file!\""));
    }

    /// A client accepting `deflate,br` gets the deflate variant when only
    /// deflate is configured on the service.
    #[tokio::test]
    async fn precompressed_deflate() {
        let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_deflate();
        let request = Request::builder()
            .header("Accept-Encoding", "deflate,br")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/plain");
        assert_eq!(res.headers()["content-encoding"], "deflate");
        let body = res.into_body().collect().await.unwrap().to_bytes();
        let mut decoder = DeflateDecoder::new(&body[..]);
        let mut decompressed = String::new();
        decoder.read_to_string(&mut decompressed).unwrap();
        assert!(decompressed.starts_with("\"This is a test file!\""));
    }

    /// With both gzip and brotli configured, each request is answered with the
    /// encoding that request asked for.
    #[tokio::test]
    async fn multi_precompressed() {
        let svc = ServeFile::new("../test-files/precompressed.txt")
            .precompressed_gzip()
            .precompressed_br();

        // `serve` is called through `&self`, so the service is reused without cloning.
        let request = Request::builder()
            .header("Accept-Encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/plain");
        assert_eq!(res.headers()["content-encoding"], "gzip");
        let body = res.into_body().collect().await.unwrap().to_bytes();
        let mut decoder = GzDecoder::new(&body[..]);
        let mut decompressed = String::new();
        decoder.read_to_string(&mut decompressed).unwrap();
        assert!(decompressed.starts_with("\"This is a test file!\""));

        let request = Request::builder()
            .header("Accept-Encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/plain");
        assert_eq!(res.headers()["content-encoding"], "br");
        let body = res.into_body().collect().await.unwrap().to_bytes();
        let mut decompressed = Vec::new();
        BrotliDecompress(&mut &body[..], &mut decompressed).unwrap();
        // The buffer is already an owned `Vec<u8>`; hand it over without copying.
        let decompressed = String::from_utf8(decompressed).unwrap();
        assert!(decompressed.starts_with("\"This is a test file!\""));
    }

    /// Setting a custom buffer chunk size does not change the served content
    /// type or body.
    #[tokio::test]
    async fn with_custom_chunk_size() {
        let svc = ServeFile::new("../README.md").with_buf_chunk_size(1024 * 32);
        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();
        assert_eq!(res.headers()["content-type"], "text/markdown");
        let body = res.into_body().collect().await.unwrap().to_bytes();
        let body = String::from_utf8(body.to_vec()).unwrap();
        assert!(body.starts_with("[![rama banner]"));
    }

    /// With gzip, deflate and brotli all configured and accepted, the response
    /// for this fixture is brotli-encoded.
    // NOTE(review): the test name implies only the brotli variant exists on
    // disk for this fixture; confirm against the test-files directory.
    #[tokio::test]
    async fn fallbacks_to_different_precompressed_variant_if_not_found() {
        let svc = ServeFile::new("../test-files/precompressed_br.txt")
            .precompressed_gzip()
            .precompressed_deflate()
            .precompressed_br();
        let request = Request::builder()
            .header("Accept-Encoding", "gzip,deflate,br")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/plain");
        assert_eq!(res.headers()["content-encoding"], "br");
        let body = res.into_body().collect().await.unwrap().to_bytes();
        let mut decompressed = Vec::new();
        BrotliDecompress(&mut &body[..], &mut decompressed).unwrap();
        // The buffer is already an owned `Vec<u8>`; hand it over without copying.
        let decompressed = String::from_utf8(decompressed).unwrap();
        assert!(decompressed.starts_with("Test file"));
    }

    /// `HEAD` counterpart of the variant fallback: brotli encoding and length
    /// are reported and the body is empty.
    #[tokio::test]
    async fn fallbacks_to_different_precompressed_variant_if_not_found_head_request() {
        let svc = ServeFile::new("../test-files/precompressed_br.txt")
            .precompressed_gzip()
            .precompressed_deflate()
            .precompressed_br();
        let request = Request::builder()
            .header("Accept-Encoding", "gzip,deflate,br")
            .method(Method::HEAD)
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.headers()["content-type"], "text/plain");
        assert_eq!(res.headers()["content-length"], "15");
        assert_eq!(res.headers()["content-encoding"], "br");
        assert!(res.into_body().frame().await.is_none());
    }

    /// A missing file produces a 404 response with no content type header.
    #[tokio::test]
    async fn returns_404_if_file_doesnt_exist() {
        let svc = ServeFile::new("../this-doesnt-exist.md");
        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
        assert!(res.headers().get(header::CONTENT_TYPE).is_none());
    }

    /// A missing file still produces a 404 when precompression is enabled and
    /// the client accepts the configured encoding.
    #[tokio::test]
    async fn returns_404_if_file_doesnt_exist_when_precompression_is_used() {
        let svc = ServeFile::new("../this-doesnt-exist.md").precompressed_deflate();
        let request = Request::builder()
            .header("Accept-Encoding", "deflate")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
        assert!(res.headers().get(header::CONTENT_TYPE).is_none());
    }

    /// Exercises conditional requests against the `Last-Modified` value the
    /// service reports for the README.
    #[tokio::test]
    async fn last_modified() {
        // An unconditional request supplies the reference `Last-Modified` value.
        let svc = ServeFile::new("../README.md");
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        let last_modified = res
            .headers()
            .get(header::LAST_MODIFIED)
            .expect("Missing last modified header!");

        // `If-Modified-Since` equal to the file's timestamp: 304 with no body.
        let svc = ServeFile::new("../README.md");
        let req = Request::builder()
            .header(header::IF_MODIFIED_SINCE, last_modified)
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.status(), StatusCode::NOT_MODIFIED);
        assert!(res.into_body().frame().await.is_none());

        // `If-Modified-Since` with a 1996 date: the full file is returned.
        let svc = ServeFile::new("../README.md");
        let req = Request::builder()
            .header(header::IF_MODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        let readme_bytes = include_bytes!("../../../../README.md");
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body.as_ref(), readme_bytes);

        // `If-Unmodified-Since` equal to the file's timestamp: the full file is returned.
        let svc = ServeFile::new("../README.md");
        let req = Request::builder()
            .header(header::IF_UNMODIFIED_SINCE, last_modified)
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body.as_ref(), readme_bytes);

        // `If-Unmodified-Since` with a 1996 date: 412 with no body.
        let svc = ServeFile::new("../README.md");
        let req = Request::builder()
            .header(header::IF_UNMODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.status(), StatusCode::PRECONDITION_FAILED);
        assert!(res.into_body().frame().await.is_none());
    }
}