tower_http/compression/
future.rs1#![allow(unused_imports)]
2
3use super::{body::BodyInner, CompressionBody};
4use crate::compression::predicate::Predicate;
5use crate::compression::CompressionLevel;
6use crate::compression_utils::WrapBody;
7use crate::content_encoding::Encoding;
8use http::{header, HeaderMap, HeaderValue, Response};
9use http_body::Body;
10use pin_project_lite::pin_project;
11use std::{
12 future::Future,
13 pin::Pin,
14 task::{ready, Context, Poll},
15};
16
17pin_project! {
18 #[derive(Debug)]
22 pub struct ResponseFuture<F, P> {
23 #[pin]
24 pub(crate) inner: F,
25 pub(crate) encoding: Encoding,
26 pub(crate) predicate: P,
27 pub(crate) quality: CompressionLevel,
28 }
29}
30
31impl<F, B, E, P> Future for ResponseFuture<F, P>
32where
33 F: Future<Output = Result<Response<B>, E>>,
34 B: Body,
35 P: Predicate,
36{
37 type Output = Result<Response<CompressionBody<B>>, E>;
38
39 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40 let res = ready!(self.as_mut().project().inner.poll(cx)?);
41
42 let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING)
44 && !res.headers().contains_key(header::CONTENT_RANGE)
46 && self.predicate.should_compress(&res);
47
48 let (mut parts, body) = res.into_parts();
49
50 if should_compress {
51 parts
52 .headers
53 .append(header::VARY, header::ACCEPT_ENCODING.into());
54 }
55
56 let body = match (should_compress, self.encoding) {
57 (false, _) | (_, Encoding::Identity) => {
59 return Poll::Ready(Ok(Response::from_parts(
60 parts,
61 CompressionBody::new(BodyInner::identity(body)),
62 )))
63 }
64
65 #[cfg(feature = "compression-gzip")]
66 (_, Encoding::Gzip) => {
67 CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality)))
68 }
69 #[cfg(feature = "compression-deflate")]
70 (_, Encoding::Deflate) => {
71 CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality)))
72 }
73 #[cfg(feature = "compression-br")]
74 (_, Encoding::Brotli) => {
75 CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality)))
76 }
77 #[cfg(feature = "compression-zstd")]
78 (_, Encoding::Zstd) => {
79 CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality)))
80 }
81 #[cfg(feature = "fs")]
82 #[allow(unreachable_patterns)]
83 (true, _) => {
84 return Poll::Ready(Ok(Response::from_parts(
99 parts,
100 CompressionBody::new(BodyInner::identity(body)),
101 )));
102 }
103 };
104
105 parts.headers.remove(header::ACCEPT_RANGES);
106 parts.headers.remove(header::CONTENT_LENGTH);
107
108 parts
109 .headers
110 .insert(header::CONTENT_ENCODING, self.encoding.into_header_value());
111
112 let res = Response::from_parts(parts, body);
113 Poll::Ready(Ok(res))
114 }
115}