tower_http/compression/future.rs

1#![allow(unused_imports)]
2
3use super::{body::BodyInner, CompressionBody};
4use crate::compression::predicate::Predicate;
5use crate::compression::CompressionLevel;
6use crate::compression_utils::WrapBody;
7use crate::content_encoding::Encoding;
8use http::{header, HeaderMap, HeaderValue, Response};
9use http_body::Body;
10use pin_project_lite::pin_project;
11use std::{
12    future::Future,
13    pin::Pin,
14    task::{ready, Context, Poll},
15};
16
pin_project! {
    /// Response future of [`Compression`].
    ///
    /// Resolves to the response produced by the wrapped future, with its body wrapped in a
    /// [`CompressionBody`] that either passes the data through or runs it through the encoder
    /// selected by `encoding`.
    ///
    /// [`Compression`]: super::Compression
    #[derive(Debug)]
    pub struct ResponseFuture<F, P> {
        // The wrapped future. It is pinned structurally and polled (via projection) until it
        // yields the response that will be post-processed.
        #[pin]
        pub(crate) inner: F,
        // Encoding selected for this response (presumably negotiated from the request's
        // `Accept-Encoding` by the caller -- confirm at construction site). `Encoding::Identity`
        // means the body is passed through uncompressed.
        pub(crate) encoding: Encoding,
        // Consulted once the response arrives to decide whether it may be compressed at all.
        pub(crate) predicate: P,
        // Compression level handed to whichever encoder ends up wrapping the body.
        pub(crate) quality: CompressionLevel,
    }
}
30
31impl<F, B, E, P> Future for ResponseFuture<F, P>
32where
33    F: Future<Output = Result<Response<B>, E>>,
34    B: Body,
35    P: Predicate,
36{
37    type Output = Result<Response<CompressionBody<B>>, E>;
38
39    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40        let res = ready!(self.as_mut().project().inner.poll(cx)?);
41
42        // never recompress responses that are already compressed
43        let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING)
44            // never compress responses that are ranges
45            && !res.headers().contains_key(header::CONTENT_RANGE)
46            && self.predicate.should_compress(&res);
47
48        let (mut parts, body) = res.into_parts();
49
50        if should_compress {
51            parts
52                .headers
53                .append(header::VARY, header::ACCEPT_ENCODING.into());
54        }
55
56        let body = match (should_compress, self.encoding) {
57            // if compression is _not_ supported or the client doesn't accept it
58            (false, _) | (_, Encoding::Identity) => {
59                return Poll::Ready(Ok(Response::from_parts(
60                    parts,
61                    CompressionBody::new(BodyInner::identity(body)),
62                )))
63            }
64
65            #[cfg(feature = "compression-gzip")]
66            (_, Encoding::Gzip) => {
67                CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality)))
68            }
69            #[cfg(feature = "compression-deflate")]
70            (_, Encoding::Deflate) => {
71                CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality)))
72            }
73            #[cfg(feature = "compression-br")]
74            (_, Encoding::Brotli) => {
75                CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality)))
76            }
77            #[cfg(feature = "compression-zstd")]
78            (_, Encoding::Zstd) => {
79                CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality)))
80            }
81            #[cfg(feature = "fs")]
82            #[allow(unreachable_patterns)]
83            (true, _) => {
84                // This should never happen because the `AcceptEncoding` struct which is used to determine
85                // `self.encoding` will only enable the different compression algorithms if the
86                // corresponding crate feature has been enabled. This means
87                // Encoding::[Gzip|Brotli|Deflate] should be impossible at this point without the
88                // features enabled.
89                //
90                // The match arm is still required though because the `fs` feature uses the
91                // Encoding struct independently and requires no compression logic to be enabled.
92                // This means a combination of an individual compression feature and `fs` will fail
93                // to compile without this branch even though it will never be reached.
94                //
95                // To safeguard against refactors that changes this relationship or other bugs the
96                // server will return an uncompressed response instead of panicking since that could
97                // become a ddos attack vector.
98                return Poll::Ready(Ok(Response::from_parts(
99                    parts,
100                    CompressionBody::new(BodyInner::identity(body)),
101                )));
102            }
103        };
104
105        parts.headers.remove(header::ACCEPT_RANGES);
106        parts.headers.remove(header::CONTENT_LENGTH);
107
108        parts
109            .headers
110            .insert(header::CONTENT_ENCODING, self.encoding.into_header_value());
111
112        let res = Response::from_parts(parts, body);
113        Poll::Ready(Ok(res))
114    }
115}