// tower_http/compression/layer.rs
use super::{Compression, Predicate};
use crate::compression::predicate::DefaultPredicate;
use crate::compression::CompressionLevel;
use crate::compression_utils::AcceptEncoding;
use tower_layer::Layer;
/// Tower [`Layer`] that wraps an inner service in the [`Compression`] middleware.
///
/// Every service produced by this layer receives a copy of the layer's
/// accepted-encoding settings, predicate, and compression quality.
///
/// `Default` (and therefore `CompressionLayer::new`) uses each field's own
/// `Default`. Which encodings that enables is decided by `AcceptEncoding`'s
/// `Default` impl, which is not visible here (presumably every compiled-in
/// encoding — TODO confirm).
#[derive(Clone, Debug, Default)]
pub struct CompressionLayer<P = DefaultPredicate> {
// Enabled content encodings; toggled by the `gzip`/`deflate`/`br`/`zstd`
// and `no_*` builder methods.
accept: AcceptEncoding,
// Predicate handed to each `Compression` service; `DefaultPredicate` unless
// replaced through `compress_when`.
predicate: P,
// Compression level handed to each `Compression` service; set via `quality`.
quality: CompressionLevel,
}
impl<S, P> Layer<S> for CompressionLayer<P>
where
    P: Predicate,
{
    type Service = Compression<S, P>;

    /// Wraps `inner` in a [`Compression`] service configured with this layer's
    /// accepted encodings, a clone of its predicate, and its quality setting.
    fn layer(&self, inner: S) -> Self::Service {
        // Borrow every configuration field at once; `accept` and `quality`
        // are copied out, and the predicate is cloned per wrapped service.
        let Self {
            accept,
            predicate,
            quality,
        } = self;

        Compression {
            inner,
            accept: *accept,
            predicate: predicate.clone(),
            quality: *quality,
        }
    }
}
impl CompressionLayer {
    /// Creates a new [`CompressionLayer`] with the [`DefaultPredicate`] and the
    /// `Default` settings for accepted encodings and compression quality.
    pub fn new() -> Self {
        Self::default()
    }
}

/// Builder methods available for every predicate type `P`.
///
/// These are generic over `P` so they can still be chained after
/// [`CompressionLayer::compress_when`] has changed the predicate type.
impl<P> CompressionLayer<P> {
    /// Sets whether to enable the gzip encoding.
    #[cfg(feature = "compression-gzip")]
    pub fn gzip(mut self, enable: bool) -> Self {
        self.accept.set_gzip(enable);
        self
    }

    /// Sets whether to enable the Deflate encoding.
    #[cfg(feature = "compression-deflate")]
    pub fn deflate(mut self, enable: bool) -> Self {
        self.accept.set_deflate(enable);
        self
    }

    /// Sets whether to enable the Brotli encoding.
    #[cfg(feature = "compression-br")]
    pub fn br(mut self, enable: bool) -> Self {
        self.accept.set_br(enable);
        self
    }

    /// Sets whether to enable the Zstd encoding.
    #[cfg(feature = "compression-zstd")]
    pub fn zstd(mut self, enable: bool) -> Self {
        self.accept.set_zstd(enable);
        self
    }

    /// Sets the compression quality handed to the wrapped [`Compression`] services.
    pub fn quality(mut self, quality: CompressionLevel) -> Self {
        self.quality = quality;
        self
    }

    /// Disables the gzip encoding.
    ///
    /// This method is not gated on a Cargo feature.
    pub fn no_gzip(mut self) -> Self {
        self.accept.set_gzip(false);
        self
    }

    /// Disables the Deflate encoding.
    ///
    /// This method is not gated on a Cargo feature.
    pub fn no_deflate(mut self) -> Self {
        self.accept.set_deflate(false);
        self
    }

    /// Disables the Brotli encoding.
    ///
    /// This method is not gated on a Cargo feature.
    pub fn no_br(mut self) -> Self {
        self.accept.set_br(false);
        self
    }

    /// Disables the Zstd encoding.
    ///
    /// This method is not gated on a Cargo feature.
    pub fn no_zstd(mut self) -> Self {
        self.accept.set_zstd(false);
        self
    }

    /// Replaces the predicate and keeps the accepted encodings and quality.
    ///
    /// The predicate type is part of the layer's type, so this returns a
    /// `CompressionLayer<C>`. The previous predicate is dropped.
    pub fn compress_when<C>(self, predicate: C) -> CompressionLayer<C>
    where
        C: Predicate,
    {
        CompressionLayer {
            accept: self.accept,
            predicate,
            quality: self.quality,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{header::ACCEPT_ENCODING, Request, Response};
    use http_body_util::BodyExt;
    use std::convert::Infallible;
    use tokio::fs::File;
    use tokio_util::io::ReaderStream;
    use tower::{Service, ServiceBuilder, ServiceExt};

    /// Test handler that streams the contents of `Cargo.toml` as the response body.
    async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        // The handler's error type is `Infallible`, so a missing fixture can only panic.
        let file = File::open("Cargo.toml").await.expect("file missing");
        let stream = ReaderStream::new(file);
        let body = Body::from_stream(stream);
        Ok(Response::new(body))
    }

    /// Checks that disabling encodings on the layer changes which encoding is
    /// negotiated, and that the resulting bodies differ in size as expected.
    #[tokio::test]
    async fn accept_encoding_configuration_works() -> Result<(), crate::BoxError> {
        // With br and gzip disabled, deflate is the only encoding the client
        // offers that the layer still accepts.
        let deflate_only_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .no_br()
            .no_gzip();

        let mut service = ServiceBuilder::new()
            .layer(deflate_only_layer)
            .service_fn(handle);

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;

        let response = service.ready().await?.call(request).await?;
        assert_eq!(response.headers()["content-encoding"], "deflate");

        let deflate_len = response.into_body().collect().await?.to_bytes().len();

        // With gzip and deflate disabled, br is the only remaining match.
        let br_only_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .no_gzip()
            .no_deflate();

        let mut service = ServiceBuilder::new()
            .layer(br_only_layer)
            .service_fn(handle);

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;

        let response = service.ready().await?.call(request).await?;
        assert_eq!(response.headers()["content-encoding"], "br");

        let br_len = response.into_body().collect().await?.to_bytes().len();

        // Confirms the two layers really used different algorithms: br is
        // expected to be at least 10% smaller than deflate on this input.
        assert!(br_len < deflate_len * 9 / 10);

        Ok(())
    }

    /// Checks that a zstd-compressed response can be decoded by a decoder
    /// whose window is capped at 2^23 bytes (8 MiB).
    #[tokio::test]
    async fn zstd_is_web_safe() -> Result<(), crate::BoxError> {
        /// Responds with 18 MiB (18 * 1024 * 1024 bytes) of zeroes, large enough
        /// that an unconstrained encoder could presumably choose a window above 8 MiB.
        async fn zeroes(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from(vec![0u8; 18_874_368])))
        }

        let zstd_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .no_br()
            .no_deflate()
            .no_gzip();

        let mut service = ServiceBuilder::new().layer(zstd_layer).service_fn(zeroes);

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "zstd")
            .body(Body::empty())?;

        let response = service.ready().await?.call(request).await?;
        assert_eq!(response.headers()["content-encoding"], "zstd");

        let bytes = response.into_body().collect().await?.to_bytes();

        let mut dec = zstd::Decoder::new(&*bytes)?;
        // Reject frames needing a window larger than 2^23 bytes (8 MiB).
        dec.window_log_max(23)?;
        std::io::copy(&mut dec, &mut std::io::sink())?;

        Ok(())
    }
}