// rama_http/layer/decompression/request/service.rs
use std::fmt;
use crate::dep::http_body::Body;
use crate::dep::http_body_util::{combinators::UnsyncBoxBody, BodyExt, Empty};
use crate::layer::{
decompression::body::BodyInner,
decompression::DecompressionBody,
util::compression::{AcceptEncoding, CompressionLevel, WrapBody},
util::content_encoding::SupportedEncodings,
};
use crate::{header, HeaderValue, Request, Response, StatusCode};
use bytes::Buf;
use rama_core::error::BoxError;
use rama_core::{Context, Service};
use rama_utils::macros::define_inner_service_accessors;
/// Middleware that decodes request bodies according to their `Content-Encoding`
/// header before passing the request on to the wrapped service.
///
/// Requests whose coding is not accepted are answered with
/// `415 Unsupported Media Type`, unless `pass_through_unaccepted` is enabled,
/// in which case they are forwarded with the body left as-is.
pub struct RequestDecompression<S> {
    /// The wrapped service that receives the (possibly decoded) request.
    pub(super) inner: S,
    /// Content codings this middleware is willing to decode.
    pub(super) accept: AcceptEncoding,
    /// Forward requests with an unaccepted `Content-Encoding` unchanged
    /// instead of rejecting them.
    pub(super) pass_through_unaccepted: bool,
}
impl<S: fmt::Debug> fmt::Debug for RequestDecompression<S> {
    /// Renders the wrapped service, the accepted encodings and the
    /// pass-through flag as a `RequestDecompression { .. }` struct.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            inner,
            accept,
            pass_through_unaccepted,
        } = self;

        let mut out = f.debug_struct("RequestDecompression");
        out.field("inner", inner);
        out.field("accept", accept);
        out.field("pass_through_unaccepted", pass_through_unaccepted);
        out.finish()
    }
}
impl<S: Clone> Clone for RequestDecompression<S> {
fn clone(&self) -> Self {
RequestDecompression {
inner: self.inner.clone(),
accept: self.accept,
pass_through_unaccepted: self.pass_through_unaccepted,
}
}
}
impl<S, State, ReqBody, ResBody, D> Service<State, Request<ReqBody>> for RequestDecompression<S>
where
S: Service<
State,
Request<DecompressionBody<ReqBody>>,
Response = Response<ResBody>,
Error: Into<BoxError>,
>,
State: Clone + Send + Sync + 'static,
ReqBody: Body + Send + 'static,
ResBody: Body<Data = D, Error: Into<BoxError>> + Send + 'static,
D: Buf + 'static,
{
type Response = Response<UnsyncBoxBody<D, BoxError>>;
type Error = BoxError;
async fn serve(
&self,
ctx: Context<State>,
req: Request<ReqBody>,
) -> Result<Self::Response, Self::Error> {
let (mut parts, body) = req.into_parts();
let body =
if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) {
match entry.get().as_bytes() {
b"gzip" if self.accept.gzip() => {
entry.remove();
parts.headers.remove(header::CONTENT_LENGTH);
BodyInner::gzip(WrapBody::new(body, CompressionLevel::default()))
}
b"deflate" if self.accept.deflate() => {
entry.remove();
parts.headers.remove(header::CONTENT_LENGTH);
BodyInner::deflate(WrapBody::new(body, CompressionLevel::default()))
}
b"br" if self.accept.br() => {
entry.remove();
parts.headers.remove(header::CONTENT_LENGTH);
BodyInner::brotli(WrapBody::new(body, CompressionLevel::default()))
}
b"zstd" if self.accept.zstd() => {
entry.remove();
parts.headers.remove(header::CONTENT_LENGTH);
BodyInner::zstd(WrapBody::new(body, CompressionLevel::default()))
}
b"identity" => BodyInner::identity(body),
_ if self.pass_through_unaccepted => BodyInner::identity(body),
_ => return unsupported_encoding(self.accept).await,
}
} else {
BodyInner::identity(body)
};
let body = DecompressionBody::new(body);
let req = Request::from_parts(parts, body);
self.inner
.serve(ctx, req)
.await
.map(|res| res.map(|body| body.map_err(Into::into).boxed_unsync()))
.map_err(Into::into)
}
}
/// Builds the `415 Unsupported Media Type` response sent when a request uses a
/// content coding this middleware will not decode.
///
/// The response has an empty body. Its `Accept-Encoding` header advertises the
/// accepted codings, falling back to `identity` when `accept.to_header_value()`
/// yields `None`.
///
/// # Errors
///
/// Never returns `Err`; the `Result` only mirrors the service's return type.
async fn unsupported_encoding<D>(
    accept: AcceptEncoding,
) -> Result<Response<UnsyncBoxBody<D, BoxError>>, BoxError>
where
    D: Buf + 'static,
{
    let advertised = match accept.to_header_value() {
        Some(value) => value,
        None => HeaderValue::from_static("identity"),
    };

    let empty: UnsyncBoxBody<D, BoxError> = Empty::new().map_err(Into::into).boxed_unsync();

    // Assembling the parts directly avoids the builder's fallible `body()` call.
    let mut res = Response::new(empty);
    *res.status_mut() = StatusCode::UNSUPPORTED_MEDIA_TYPE;
    res.headers_mut().insert(header::ACCEPT_ENCODING, advertised);
    Ok(res)
}
impl<S> RequestDecompression<S> {
    /// Creates a new `RequestDecompression` wrapping `service`.
    ///
    /// The accepted codings start from `AcceptEncoding::default()`. Pass-through
    /// is disabled, so a request with an encoding that is not accepted is
    /// answered with `415 Unsupported Media Type`.
    pub fn new(service: S) -> Self {
        let accept = AcceptEncoding::default();
        Self {
            pass_through_unaccepted: false,
            accept,
            inner: service,
        }
    }

    define_inner_service_accessors!();

    /// Consuming variant of [`Self::set_pass_through_unaccepted`].
    pub fn pass_through_unaccepted(mut self, enabled: bool) -> Self {
        self.set_pass_through_unaccepted(enabled);
        self
    }

    /// Chooses whether requests with an unaccepted `Content-Encoding` are
    /// forwarded unchanged (`true`) or rejected with 415 (`false`).
    pub fn set_pass_through_unaccepted(&mut self, enabled: bool) -> &mut Self {
        self.pass_through_unaccepted = enabled;
        self
    }

    /// Consuming variant of [`Self::set_gzip`].
    pub fn gzip(mut self, enable: bool) -> Self {
        self.set_gzip(enable);
        self
    }

    /// Enables or disables decoding of `gzip` request bodies.
    pub fn set_gzip(&mut self, enable: bool) -> &mut Self {
        self.accept.set_gzip(enable);
        self
    }

    /// Consuming variant of [`Self::set_deflate`].
    pub fn deflate(mut self, enable: bool) -> Self {
        self.set_deflate(enable);
        self
    }

    /// Enables or disables decoding of `deflate` request bodies.
    pub fn set_deflate(&mut self, enable: bool) -> &mut Self {
        self.accept.set_deflate(enable);
        self
    }

    /// Consuming variant of [`Self::set_br`].
    pub fn br(mut self, enable: bool) -> Self {
        self.set_br(enable);
        self
    }

    /// Enables or disables decoding of `br` (Brotli) request bodies.
    pub fn set_br(&mut self, enable: bool) -> &mut Self {
        self.accept.set_br(enable);
        self
    }

    /// Consuming variant of [`Self::set_zstd`].
    pub fn zstd(mut self, enable: bool) -> Self {
        self.set_zstd(enable);
        self
    }

    /// Enables or disables decoding of `zstd` request bodies.
    pub fn set_zstd(&mut self, enable: bool) -> &mut Self {
        self.accept.set_zstd(enable);
        self
    }
}