pub mod predicate;
mod body;
mod future;
mod layer;
mod pin_project_cfg;
mod service;
#[doc(inline)]
pub use self::{
body::CompressionBody,
future::ResponseFuture,
layer::CompressionLayer,
predicate::{DefaultPredicate, Predicate},
service::Compression,
};
pub use crate::compression_utils::CompressionLevel;
#[cfg(test)]
mod tests {
    use crate::compression::predicate::SizeAbove;

    use super::*;
    use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder};
    use flate2::read::GzDecoder;
    use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_TYPE};
    use hyper::{Body, Error, Request, Response, Server};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::{io::Read, net::SocketAddr};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio_util::io::StreamReader;
    use tower::{make::Shared, service_fn, Service, ServiceExt};

    /// Predicate that opts every response into compression, so the encoder tests
    /// run regardless of body size or content type.
    #[derive(Clone)]
    struct Always;

    impl Predicate for Always {
        fn should_compress<B>(&self, _: &http::Response<B>) -> bool
        where
            B: http_body::Body,
        {
            true
        }
    }

    /// Drains `body` to completion and returns all of its data frames concatenated.
    ///
    /// # Panics
    ///
    /// Panics if the body yields an error frame.
    async fn collect_body<B>(mut body: B) -> Vec<u8>
    where
        B: http_body::Body + Unpin,
        B::Data: AsRef<[u8]>,
        B::Error: std::fmt::Debug,
    {
        let mut collected = Vec::new();
        while let Some(chunk) = body.data().await {
            collected.extend_from_slice(chunk.unwrap().as_ref());
        }
        collected
    }

    #[tokio::test]
    async fn gzip_works() {
        let svc = service_fn(handle);
        let mut svc = Compression::new(svc).compress_when(Always);

        let req = Request::builder()
            .header("accept-encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");

        let compressed_data = collect_body(res.into_body()).await;
        let mut decoder = GzDecoder::new(&compressed_data[..]);
        let mut decompressed = String::new();
        decoder.read_to_string(&mut decompressed).unwrap();
        assert_eq!(decompressed, "Hello, World!");
    }

    #[tokio::test]
    async fn zstd_works() {
        let svc = service_fn(handle);
        let mut svc = Compression::new(svc).compress_when(Always);

        let req = Request::builder()
            .header("accept-encoding", "zstd")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "zstd");

        let compressed_data = collect_body(res.into_body()).await;
        let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap();
        let decompressed = String::from_utf8(decompressed).unwrap();
        assert_eq!(decompressed, "Hello, World!");
    }

    // Never called: this function only exists so the compiler checks that a
    // `Compression` service can be served by a hyper `Server`. That is why the
    // `dead_code` lint is silenced.
    #[allow(dead_code)]
    async fn is_compatible_with_hyper() {
        let svc = service_fn(handle);
        let svc = Compression::new(svc);

        let make_service = Shared::new(svc);

        let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
        let server = Server::bind(&addr).serve(make_service);
        server.await.unwrap();
    }

    /// A response that already carries `content-encoding: br` must be passed
    /// through unchanged, even though the client asked for gzip.
    #[tokio::test]
    async fn no_recompress() {
        const DATA: &str = "Hello, World! I'm already compressed with br!";

        let svc = service_fn(|_| async {
            let buf = {
                let mut buf = Vec::new();

                let mut enc = BrotliEncoder::new(&mut buf);
                enc.write_all(DATA.as_bytes()).await?;
                enc.flush().await?;
                buf
            };

            let resp = Response::builder()
                .header("content-encoding", "br")
                .body(Body::from(buf))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });
        let mut svc = Compression::new(svc);

        let req = Request::builder()
            .header("accept-encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();

        // The original encoding header must survive untouched.
        assert_eq!(
            res.headers()
                .get("content-encoding")
                .and_then(|h| h.to_str().ok())
                .unwrap_or_default(),
            "br",
        );

        let data = collect_body(res.into_body()).await;

        // Decoding with brotli alone must yield the original text; a second
        // layer of compression would make this fail.
        let data = {
            let mut output_buf = Vec::new();
            let mut decoder = BrotliDecoder::new(&mut output_buf);
            decoder
                .write_all(&data)
                .await
                .expect("couldn't brotli-decode");
            decoder.flush().await.expect("couldn't flush");
            output_buf
        };

        assert_eq!(data, DATA.as_bytes());
    }

    /// Service handler shared by several tests: always answers with a fixed
    /// plain-text body.
    async fn handle(_req: Request<Body>) -> Result<Response<Body>, Error> {
        Ok(Response::new(Body::from("Hello, World!")))
    }

    #[tokio::test]
    async fn will_not_compress_if_filtered_out() {
        const DATA: &str = "Hello world uncompressed";

        let svc_fn = service_fn(|_| async {
            let resp = Response::builder()
                .body(Body::from(DATA.as_bytes()))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });

        /// Predicate that declines the first response, accepts the second, and
        /// keeps alternating. The counter sits behind an `Arc`, so every clone of
        /// the predicate shares the same count.
        #[derive(Default, Clone)]
        struct EveryOtherResponse(Arc<AtomicUsize>);

        impl Predicate for EveryOtherResponse {
            fn should_compress<B>(&self, _: &http::Response<B>) -> bool
            where
                B: http_body::Body,
            {
                // `fetch_add` returns the count *before* incrementing.
                let seen = self.0.fetch_add(1, Ordering::SeqCst);
                seen % 2 != 0
            }
        }

        let mut svc = Compression::new(svc_fn).compress_when(EveryOtherResponse::default());

        // First response: the predicate declines, so the body is passed through as-is.
        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        let data = collect_body(res.into_body()).await;
        let still_uncompressed = String::from_utf8(data).unwrap();
        assert_eq!(DATA, &still_uncompressed);

        // Second response: the predicate accepts, so the body is brotli-encoded.
        // The test relies on that output not being valid UTF-8.
        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        let data = collect_body(res.into_body()).await;
        assert!(String::from_utf8(data).is_err());
    }

    /// With the default predicate, a `image/png` response above the minimum size
    /// is still left uncompressed.
    #[tokio::test]
    async fn doesnt_compress_images() {
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Error> {
            let mut res = Response::new(Body::from(
                "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
            ));
            res.headers_mut()
                .insert(CONTENT_TYPE, "image/png".parse().unwrap());
            Ok(res)
        }

        let svc = Compression::new(service_fn(handle));

        let res = svc
            .oneshot(
                Request::builder()
                    .header(ACCEPT_ENCODING, "gzip")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    /// With the default predicate, an `image/svg+xml` response above the minimum
    /// size is compressed.
    #[tokio::test]
    async fn does_compress_svg() {
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Error> {
            let mut res = Response::new(Body::from(
                "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
            ));
            res.headers_mut()
                .insert(CONTENT_TYPE, "image/svg+xml".parse().unwrap());
            Ok(res)
        }

        let svc = Compression::new(service_fn(handle));

        let res = svc
            .oneshot(
                Request::builder()
                    .header(ACCEPT_ENCODING, "gzip")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");
    }

    /// The configured `quality` level must reach the encoder. The service output
    /// is compared byte-for-byte against a direct brotli encoding at the same level.
    #[tokio::test]
    async fn compress_with_quality() {
        const DATA: &str = "Check compression quality level! Check compression quality level! Check compression quality level!";
        let level = CompressionLevel::Best;

        let svc = service_fn(|_| async {
            let resp = Response::builder()
                .body(Body::from(DATA.as_bytes()))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });

        let mut svc = Compression::new(svc).quality(level);

        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        let compressed_data = collect_body(res.into_body()).await;

        // Reference encoding of the same input at the same quality level.
        let compressed_with_level = {
            use async_compression::tokio::bufread::BrotliEncoder;

            let stream = Box::pin(futures::stream::once(async move {
                Ok::<_, std::io::Error>(DATA.as_bytes())
            }));
            let reader = StreamReader::new(stream);
            let mut enc = BrotliEncoder::with_quality(reader, level.into_async_compression());

            let mut buf = Vec::new();
            enc.read_to_end(&mut buf).await.unwrap();
            buf
        };

        assert_eq!(
            compressed_data.as_slice(),
            compressed_with_level.as_slice(),
            "Compression level is not respected"
        );
    }
}