tower_http/compression/body.rs

1#![allow(unused_imports)]
2
3use crate::compression::CompressionLevel;
4use crate::{
5    compression_utils::{AsyncReadBody, BodyIntoStream, DecorateAsyncRead, WrapBody},
6    BoxError,
7};
8#[cfg(feature = "compression-br")]
9use async_compression::tokio::bufread::BrotliEncoder;
10#[cfg(feature = "compression-gzip")]
11use async_compression::tokio::bufread::GzipEncoder;
12#[cfg(feature = "compression-deflate")]
13use async_compression::tokio::bufread::ZlibEncoder;
14#[cfg(feature = "compression-zstd")]
15use async_compression::tokio::bufread::ZstdEncoder;
16
17use bytes::{Buf, Bytes};
18use http::HeaderMap;
19use http_body::Body;
20use pin_project_lite::pin_project;
21use std::{
22    io,
23    marker::PhantomData,
24    pin::Pin,
25    task::{ready, Context, Poll},
26};
27use tokio_util::io::StreamReader;
28
29use super::pin_project_cfg::pin_project_cfg;
30
pin_project! {
    /// Response body of [`Compression`].
    ///
    /// Wraps a body `B` and either forwards it unchanged or streams it through one of the
    /// compression encoders enabled via Cargo features (gzip, deflate, brotli, zstd). The
    /// strategy is fixed when the value is constructed. Data frames are yielded as [`Bytes`]
    /// and errors are converted into [`BoxError`].
    ///
    /// [`Compression`]: super::Compression
    pub struct CompressionBody<B>
    where
        B: Body,
    {
        // The selected encoding strategy. Structurally pinned so `poll_frame` can project
        // down to the (possibly `!Unpin`) inner body or encoder.
        #[pin]
        pub(crate) inner: BodyInner<B>,
    }
}
43
44impl<B> Default for CompressionBody<B>
45where
46    B: Body + Default,
47{
48    fn default() -> Self {
49        Self {
50            inner: BodyInner::Identity {
51                inner: B::default(),
52            },
53        }
54    }
55}
56
57impl<B> CompressionBody<B>
58where
59    B: Body,
60{
61    pub(crate) fn new(inner: BodyInner<B>) -> Self {
62        Self { inner }
63    }
64
65    /// Get a reference to the inner body
66    pub fn get_ref(&self) -> &B {
67        match &self.inner {
68            #[cfg(feature = "compression-gzip")]
69            BodyInner::Gzip { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
70            #[cfg(feature = "compression-deflate")]
71            BodyInner::Deflate { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
72            #[cfg(feature = "compression-br")]
73            BodyInner::Brotli { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
74            #[cfg(feature = "compression-zstd")]
75            BodyInner::Zstd { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
76            BodyInner::Identity { inner } => inner,
77        }
78    }
79
80    /// Get a mutable reference to the inner body
81    pub fn get_mut(&mut self) -> &mut B {
82        match &mut self.inner {
83            #[cfg(feature = "compression-gzip")]
84            BodyInner::Gzip { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
85            #[cfg(feature = "compression-deflate")]
86            BodyInner::Deflate { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
87            #[cfg(feature = "compression-br")]
88            BodyInner::Brotli { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
89            #[cfg(feature = "compression-zstd")]
90            BodyInner::Zstd { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
91            BodyInner::Identity { inner } => inner,
92        }
93    }
94
95    /// Get a pinned mutable reference to the inner body
96    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> {
97        match self.project().inner.project() {
98            #[cfg(feature = "compression-gzip")]
99            BodyInnerProj::Gzip { inner } => inner
100                .project()
101                .read
102                .get_pin_mut()
103                .get_pin_mut()
104                .get_pin_mut()
105                .get_pin_mut(),
106            #[cfg(feature = "compression-deflate")]
107            BodyInnerProj::Deflate { inner } => inner
108                .project()
109                .read
110                .get_pin_mut()
111                .get_pin_mut()
112                .get_pin_mut()
113                .get_pin_mut(),
114            #[cfg(feature = "compression-br")]
115            BodyInnerProj::Brotli { inner } => inner
116                .project()
117                .read
118                .get_pin_mut()
119                .get_pin_mut()
120                .get_pin_mut()
121                .get_pin_mut(),
122            #[cfg(feature = "compression-zstd")]
123            BodyInnerProj::Zstd { inner } => inner
124                .project()
125                .read
126                .get_pin_mut()
127                .get_pin_mut()
128                .get_pin_mut()
129                .get_pin_mut(),
130            BodyInnerProj::Identity { inner } => inner,
131        }
132    }
133
134    /// Consume `self`, returning the inner body
135    pub fn into_inner(self) -> B {
136        match self.inner {
137            #[cfg(feature = "compression-gzip")]
138            BodyInner::Gzip { inner } => inner
139                .read
140                .into_inner()
141                .into_inner()
142                .into_inner()
143                .into_inner(),
144            #[cfg(feature = "compression-deflate")]
145            BodyInner::Deflate { inner } => inner
146                .read
147                .into_inner()
148                .into_inner()
149                .into_inner()
150                .into_inner(),
151            #[cfg(feature = "compression-br")]
152            BodyInner::Brotli { inner } => inner
153                .read
154                .into_inner()
155                .into_inner()
156                .into_inner()
157                .into_inner(),
158            #[cfg(feature = "compression-zstd")]
159            BodyInner::Zstd { inner } => inner
160                .read
161                .into_inner()
162                .into_inner()
163                .into_inner()
164                .into_inner(),
165            BodyInner::Identity { inner } => inner,
166        }
167    }
168}
169
/// Body type stored in `BodyInner::Gzip`: `B` wrapped by `WrapBody` with the gzip encoder
/// as its `DecorateAsyncRead` decorator.
#[cfg(feature = "compression-gzip")]
type GzipBody<B> = WrapBody<GzipEncoder<B>>;

/// Body type stored in `BodyInner::Deflate`: `B` wrapped by `WrapBody` with the zlib
/// encoder as its decorator.
#[cfg(feature = "compression-deflate")]
type DeflateBody<B> = WrapBody<ZlibEncoder<B>>;

/// Body type stored in `BodyInner::Brotli`: `B` wrapped by `WrapBody` with the brotli
/// encoder as its decorator.
#[cfg(feature = "compression-br")]
type BrotliBody<B> = WrapBody<BrotliEncoder<B>>;

/// Body type stored in `BodyInner::Zstd`: `B` wrapped by `WrapBody` with the zstd encoder
/// as its decorator.
#[cfg(feature = "compression-zstd")]
type ZstdBody<B> = WrapBody<ZstdEncoder<B>>;
181
// Internal state of `CompressionBody`: one variant per compression algorithm enabled via a
// Cargo feature, plus `Identity` for pass-through. `#[project = BodyInnerProj]` names the
// pinned-projection enum that `CompressionBody` matches on in `poll_frame`/`get_pin_mut`.
// Only `//` comments are used inside the macro input so the local `pin_project_cfg!`
// grammar is left untouched.
pin_project_cfg! {
    #[project = BodyInnerProj]
    pub(crate) enum BodyInner<B>
    where
        B: Body,
    {
        // `B` streamed through the gzip encoder.
        #[cfg(feature = "compression-gzip")]
        Gzip {
            #[pin]
            inner: GzipBody<B>,
        },
        // `B` streamed through the zlib ("deflate" content-coding) encoder.
        #[cfg(feature = "compression-deflate")]
        Deflate {
            #[pin]
            inner: DeflateBody<B>,
        },
        // `B` streamed through the brotli encoder.
        #[cfg(feature = "compression-br")]
        Brotli {
            #[pin]
            inner: BrotliBody<B>,
        },
        // `B` streamed through the zstd encoder.
        #[cfg(feature = "compression-zstd")]
        Zstd {
            #[pin]
            inner: ZstdBody<B>,
        },
        // `B` forwarded without compression; always available regardless of features.
        Identity {
            #[pin]
            inner: B,
        },
    }
}
214
215impl<B: Body> BodyInner<B> {
216    #[cfg(feature = "compression-gzip")]
217    pub(crate) fn gzip(inner: WrapBody<GzipEncoder<B>>) -> Self {
218        Self::Gzip { inner }
219    }
220
221    #[cfg(feature = "compression-deflate")]
222    pub(crate) fn deflate(inner: WrapBody<ZlibEncoder<B>>) -> Self {
223        Self::Deflate { inner }
224    }
225
226    #[cfg(feature = "compression-br")]
227    pub(crate) fn brotli(inner: WrapBody<BrotliEncoder<B>>) -> Self {
228        Self::Brotli { inner }
229    }
230
231    #[cfg(feature = "compression-zstd")]
232    pub(crate) fn zstd(inner: WrapBody<ZstdEncoder<B>>) -> Self {
233        Self::Zstd { inner }
234    }
235
236    pub(crate) fn identity(inner: B) -> Self {
237        Self::Identity { inner }
238    }
239}
240
241impl<B> Body for CompressionBody<B>
242where
243    B: Body,
244    B::Error: Into<BoxError>,
245{
246    type Data = Bytes;
247    type Error = BoxError;
248
249    fn poll_frame(
250        self: Pin<&mut Self>,
251        cx: &mut Context<'_>,
252    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
253        match self.project().inner.project() {
254            #[cfg(feature = "compression-gzip")]
255            BodyInnerProj::Gzip { inner } => inner.poll_frame(cx),
256            #[cfg(feature = "compression-deflate")]
257            BodyInnerProj::Deflate { inner } => inner.poll_frame(cx),
258            #[cfg(feature = "compression-br")]
259            BodyInnerProj::Brotli { inner } => inner.poll_frame(cx),
260            #[cfg(feature = "compression-zstd")]
261            BodyInnerProj::Zstd { inner } => inner.poll_frame(cx),
262            BodyInnerProj::Identity { inner } => match ready!(inner.poll_frame(cx)) {
263                Some(Ok(frame)) => {
264                    let frame = frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()));
265                    Poll::Ready(Some(Ok(frame)))
266                }
267                Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
268                None => Poll::Ready(None),
269            },
270        }
271    }
272
273    fn size_hint(&self) -> http_body::SizeHint {
274        if let BodyInner::Identity { inner } = &self.inner {
275            inner.size_hint()
276        } else {
277            http_body::SizeHint::new()
278        }
279    }
280}
281
#[cfg(feature = "compression-gzip")]
impl<B> DecorateAsyncRead for GzipEncoder<B>
where
    B: Body,
{
    type Input = AsyncReadBody<B>;
    type Output = GzipEncoder<Self::Input>;

    /// Wraps the body reader in a gzip encoder at the requested compression level.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output {
        let level = quality.into_async_compression();
        GzipEncoder::with_quality(input, level)
    }

    /// Pin-projects through the encoder to the reader it compresses.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        pinned.get_pin_mut()
    }
}
298
#[cfg(feature = "compression-deflate")]
impl<B> DecorateAsyncRead for ZlibEncoder<B>
where
    B: Body,
{
    type Input = AsyncReadBody<B>;
    type Output = ZlibEncoder<Self::Input>;

    /// Wraps the body reader in a zlib encoder (the "deflate" content-coding) at the
    /// requested compression level.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output {
        let level = quality.into_async_compression();
        ZlibEncoder::with_quality(input, level)
    }

    /// Pin-projects through the encoder to the reader it compresses.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        pinned.get_pin_mut()
    }
}
315
#[cfg(feature = "compression-br")]
impl<B> DecorateAsyncRead for BrotliEncoder<B>
where
    B: Body,
{
    type Input = AsyncReadBody<B>;
    type Output = BrotliEncoder<Self::Input>;

    /// Wraps the body reader in a brotli encoder.
    ///
    /// `CompressionLevel::Default` is mapped to brotli level 4 instead of the underlying
    /// crate's default; every other level is passed through unchanged.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output {
        // The underlying brotli crate defaults to level 11, brotli's maximum, which makes
        // on-the-fly compression extremely slow. Level 4 matches NGINX's default for
        // dynamic brotli compression.
        let level = if matches!(quality, CompressionLevel::Default) {
            async_compression::Level::Precise(4)
        } else {
            quality.into_async_compression()
        };
        BrotliEncoder::with_quality(input, level)
    }

    /// Pin-projects through the encoder to the reader it compresses.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        pinned.get_pin_mut()
    }
}
341
#[cfg(feature = "compression-zstd")]
impl<B> DecorateAsyncRead for ZstdEncoder<B>
where
    B: Body,
{
    type Input = AsyncReadBody<B>;
    type Output = ZstdEncoder<Self::Input>;

    /// Wraps the body reader in a zstd encoder.
    ///
    /// For high levels (`Best`, or `Precise(n)` with `n >= 17`) the window log is capped
    /// at 23 (8 MB) so browsers can decode the output.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output {
        // Why the 8 MB cap:
        // - https://issues.chromium.org/issues/41493659:
        //   "For memory usage reasons, Chromium limits the window size to 8MB"
        // - https://datatracker.ietf.org/doc/html/rfc8878#name-window-descriptor:
        //   "For improved interoperability, it's recommended for decoders to support values
        //   of Window_Size up to 8 MB and for encoders not to generate frames requiring a
        //   Window_Size larger than 8 MB."
        // As of zstd v1.5.6, level 17 is the first level whose window reaches 8 MB (2^23):
        // https://github.com/facebook/zstd/blob/v1.5.6/lib/compress/clevels.h#L25-L51
        // Setting the cap for every level >= 17 is either a no-op or a real limit, and
        // guards against future zstd tuning changes. `Best` (the top of zstd's level
        // range) always falls in this group.
        let needs_window_limit = matches!(quality, CompressionLevel::Best)
            || matches!(quality, CompressionLevel::Precise(n) if n >= 17);
        let level = quality.into_async_compression();

        if !needs_window_limit {
            // Below level 17 the parameter is deliberately omitted: forcing window_log(23)
            // there would *enlarge* the window those levels normally use.
            return ZstdEncoder::with_quality(input, level);
        }

        let params = [async_compression::zstd::CParameter::window_log(23)];
        ZstdEncoder::with_quality_and_params(input, level, &params)
    }

    /// Pin-projects through the encoder to the reader it compresses.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        pinned.get_pin_mut()
    }
}