tower_http/decompression/body.rs

1#![allow(unused_imports)]
2
3use crate::compression_utils::CompressionLevel;
4use crate::{
5    compression_utils::{AsyncReadBody, BodyIntoStream, DecorateAsyncRead, WrapBody},
6    BoxError,
7};
8#[cfg(feature = "decompression-br")]
9use async_compression::tokio::bufread::BrotliDecoder;
10#[cfg(feature = "decompression-gzip")]
11use async_compression::tokio::bufread::GzipDecoder;
12#[cfg(feature = "decompression-deflate")]
13use async_compression::tokio::bufread::ZlibDecoder;
14#[cfg(feature = "decompression-zstd")]
15use async_compression::tokio::bufread::ZstdDecoder;
16use bytes::{Buf, Bytes};
17use http::HeaderMap;
18use http_body::{Body, SizeHint};
19use pin_project_lite::pin_project;
20use std::task::Context;
21use std::{
22    io,
23    marker::PhantomData,
24    pin::Pin,
25    task::{ready, Poll},
26};
27use tokio_util::io::StreamReader;
28
pin_project! {
    /// Body type produced by [`RequestDecompression`] and [`Decompression`].
    ///
    /// Depending on the `Content-Encoding` that was negotiated, the wrapped body is
    /// either decoded on the fly (gzip, deflate, brotli or zstd, each behind its
    /// own cargo feature) or passed through unchanged.
    ///
    /// [`RequestDecompression`]: super::RequestDecompression
    /// [`Decompression`]: super::Decompression
    pub struct DecompressionBody<B>
    where
        B: Body
    {
        // Which codec (if any) is applied to the wrapped body; pinned because the
        // decoders are polled through `Pin`.
        #[pin]
        pub(crate) inner: BodyInner<B>,
    }
}
42
43impl<B> Default for DecompressionBody<B>
44where
45    B: Body + Default,
46{
47    fn default() -> Self {
48        Self {
49            inner: BodyInner::Identity {
50                inner: B::default(),
51            },
52        }
53    }
54}
55
56impl<B> DecompressionBody<B>
57where
58    B: Body,
59{
60    pub(crate) fn new(inner: BodyInner<B>) -> Self {
61        Self { inner }
62    }
63
64    /// Get a reference to the inner body
65    pub fn get_ref(&self) -> &B {
66        match &self.inner {
67            #[cfg(feature = "decompression-gzip")]
68            BodyInner::Gzip { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
69            #[cfg(feature = "decompression-deflate")]
70            BodyInner::Deflate { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
71            #[cfg(feature = "decompression-br")]
72            BodyInner::Brotli { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
73            #[cfg(feature = "decompression-zstd")]
74            BodyInner::Zstd { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
75            BodyInner::Identity { inner } => inner,
76
77            // FIXME: Remove once possible; see https://github.com/rust-lang/rust/issues/51085
78            #[cfg(not(feature = "decompression-gzip"))]
79            BodyInner::Gzip { inner } => match inner.0 {},
80            #[cfg(not(feature = "decompression-deflate"))]
81            BodyInner::Deflate { inner } => match inner.0 {},
82            #[cfg(not(feature = "decompression-br"))]
83            BodyInner::Brotli { inner } => match inner.0 {},
84            #[cfg(not(feature = "decompression-zstd"))]
85            BodyInner::Zstd { inner } => match inner.0 {},
86        }
87    }
88
89    /// Get a mutable reference to the inner body
90    pub fn get_mut(&mut self) -> &mut B {
91        match &mut self.inner {
92            #[cfg(feature = "decompression-gzip")]
93            BodyInner::Gzip { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
94            #[cfg(feature = "decompression-deflate")]
95            BodyInner::Deflate { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
96            #[cfg(feature = "decompression-br")]
97            BodyInner::Brotli { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
98            #[cfg(feature = "decompression-zstd")]
99            BodyInner::Zstd { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
100            BodyInner::Identity { inner } => inner,
101
102            #[cfg(not(feature = "decompression-gzip"))]
103            BodyInner::Gzip { inner } => match inner.0 {},
104            #[cfg(not(feature = "decompression-deflate"))]
105            BodyInner::Deflate { inner } => match inner.0 {},
106            #[cfg(not(feature = "decompression-br"))]
107            BodyInner::Brotli { inner } => match inner.0 {},
108            #[cfg(not(feature = "decompression-zstd"))]
109            BodyInner::Zstd { inner } => match inner.0 {},
110        }
111    }
112
113    /// Get a pinned mutable reference to the inner body
114    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> {
115        match self.project().inner.project() {
116            #[cfg(feature = "decompression-gzip")]
117            BodyInnerProj::Gzip { inner } => inner
118                .project()
119                .read
120                .get_pin_mut()
121                .get_pin_mut()
122                .get_pin_mut()
123                .get_pin_mut(),
124            #[cfg(feature = "decompression-deflate")]
125            BodyInnerProj::Deflate { inner } => inner
126                .project()
127                .read
128                .get_pin_mut()
129                .get_pin_mut()
130                .get_pin_mut()
131                .get_pin_mut(),
132            #[cfg(feature = "decompression-br")]
133            BodyInnerProj::Brotli { inner } => inner
134                .project()
135                .read
136                .get_pin_mut()
137                .get_pin_mut()
138                .get_pin_mut()
139                .get_pin_mut(),
140            #[cfg(feature = "decompression-zstd")]
141            BodyInnerProj::Zstd { inner } => inner
142                .project()
143                .read
144                .get_pin_mut()
145                .get_pin_mut()
146                .get_pin_mut()
147                .get_pin_mut(),
148            BodyInnerProj::Identity { inner } => inner,
149
150            #[cfg(not(feature = "decompression-gzip"))]
151            BodyInnerProj::Gzip { inner } => match inner.0 {},
152            #[cfg(not(feature = "decompression-deflate"))]
153            BodyInnerProj::Deflate { inner } => match inner.0 {},
154            #[cfg(not(feature = "decompression-br"))]
155            BodyInnerProj::Brotli { inner } => match inner.0 {},
156            #[cfg(not(feature = "decompression-zstd"))]
157            BodyInnerProj::Zstd { inner } => match inner.0 {},
158        }
159    }
160
161    /// Consume `self`, returning the inner body
162    pub fn into_inner(self) -> B {
163        match self.inner {
164            #[cfg(feature = "decompression-gzip")]
165            BodyInner::Gzip { inner } => inner
166                .read
167                .into_inner()
168                .into_inner()
169                .into_inner()
170                .into_inner(),
171            #[cfg(feature = "decompression-deflate")]
172            BodyInner::Deflate { inner } => inner
173                .read
174                .into_inner()
175                .into_inner()
176                .into_inner()
177                .into_inner(),
178            #[cfg(feature = "decompression-br")]
179            BodyInner::Brotli { inner } => inner
180                .read
181                .into_inner()
182                .into_inner()
183                .into_inner()
184                .into_inner(),
185            #[cfg(feature = "decompression-zstd")]
186            BodyInner::Zstd { inner } => inner
187                .read
188                .into_inner()
189                .into_inner()
190                .into_inner()
191                .into_inner(),
192            BodyInner::Identity { inner } => inner,
193
194            #[cfg(not(feature = "decompression-gzip"))]
195            BodyInner::Gzip { inner } => match inner.0 {},
196            #[cfg(not(feature = "decompression-deflate"))]
197            BodyInner::Deflate { inner } => match inner.0 {},
198            #[cfg(not(feature = "decompression-br"))]
199            BodyInner::Brotli { inner } => match inner.0 {},
200            #[cfg(not(feature = "decompression-zstd"))]
201            BodyInner::Zstd { inner } => match inner.0 {},
202        }
203    }
204}
205
// Only needed when at least one codec feature is disabled: the disabled
// codec's body alias (below) embeds this type so that its `BodyInner`
// variant can never be constructed.
#[cfg(any(
    not(feature = "decompression-gzip"),
    not(feature = "decompression-deflate"),
    not(feature = "decompression-br"),
    not(feature = "decompression-zstd")
))]
/// Uninhabited type: no value of it can exist, so `match never {}` is
/// accepted as exhaustive.
pub(crate) enum Never {}
213
// Per-codec body types. When a codec's feature is enabled, its alias is the
// real decoding body (`WrapBody` around the async-compression decoder). When
// the feature is disabled, it becomes `(Never, PhantomData<B>)`: uninhabited
// because of `Never`, with `PhantomData<B>` keeping `B` used so the alias
// stays generic.

#[cfg(feature = "decompression-gzip")]
type GzipBody<B> = WrapBody<GzipDecoder<B>>;
#[cfg(not(feature = "decompression-gzip"))]
type GzipBody<B> = (Never, PhantomData<B>);

#[cfg(feature = "decompression-deflate")]
type DeflateBody<B> = WrapBody<ZlibDecoder<B>>;
#[cfg(not(feature = "decompression-deflate"))]
type DeflateBody<B> = (Never, PhantomData<B>);

#[cfg(feature = "decompression-br")]
type BrotliBody<B> = WrapBody<BrotliDecoder<B>>;
#[cfg(not(feature = "decompression-br"))]
type BrotliBody<B> = (Never, PhantomData<B>);

#[cfg(feature = "decompression-zstd")]
type ZstdBody<B> = WrapBody<ZstdDecoder<B>>;
#[cfg(not(feature = "decompression-zstd"))]
type ZstdBody<B> = (Never, PhantomData<B>);
233
// The concrete body behind a `DecompressionBody`: one variant per supported
// content-coding plus a pass-through `Identity` variant. Variants whose cargo
// feature is disabled hold an uninhabited type and so can never be built.
// `BodyInnerProj` is the pin-projection type generated by `pin_project!`.
pin_project! {
    #[project = BodyInnerProj]
    pub(crate) enum BodyInner<B>
    where
        B: Body,
    {
        // gzip-decoded body (feature `decompression-gzip`).
        Gzip {
            #[pin]
            inner: GzipBody<B>,
        },
        // zlib-decoded body, used for the `deflate` coding
        // (feature `decompression-deflate`).
        Deflate {
            #[pin]
            inner: DeflateBody<B>,
        },
        // brotli-decoded body (feature `decompression-br`).
        Brotli {
            #[pin]
            inner: BrotliBody<B>,
        },
        // zstd-decoded body (feature `decompression-zstd`).
        Zstd {
            #[pin]
            inner: ZstdBody<B>,
        },
        // Body passed through without decoding.
        Identity {
            #[pin]
            inner: B,
        },
    }
}
262
263impl<B: Body> BodyInner<B> {
264    #[cfg(feature = "decompression-gzip")]
265    pub(crate) fn gzip(inner: WrapBody<GzipDecoder<B>>) -> Self {
266        Self::Gzip { inner }
267    }
268
269    #[cfg(feature = "decompression-deflate")]
270    pub(crate) fn deflate(inner: WrapBody<ZlibDecoder<B>>) -> Self {
271        Self::Deflate { inner }
272    }
273
274    #[cfg(feature = "decompression-br")]
275    pub(crate) fn brotli(inner: WrapBody<BrotliDecoder<B>>) -> Self {
276        Self::Brotli { inner }
277    }
278
279    #[cfg(feature = "decompression-zstd")]
280    pub(crate) fn zstd(inner: WrapBody<ZstdDecoder<B>>) -> Self {
281        Self::Zstd { inner }
282    }
283
284    pub(crate) fn identity(inner: B) -> Self {
285        Self::Identity { inner }
286    }
287}
288
289impl<B> Body for DecompressionBody<B>
290where
291    B: Body,
292    B::Error: Into<BoxError>,
293{
294    type Data = Bytes;
295    type Error = BoxError;
296
297    fn poll_frame(
298        self: Pin<&mut Self>,
299        cx: &mut Context<'_>,
300    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
301        match self.project().inner.project() {
302            #[cfg(feature = "decompression-gzip")]
303            BodyInnerProj::Gzip { inner } => inner.poll_frame(cx),
304            #[cfg(feature = "decompression-deflate")]
305            BodyInnerProj::Deflate { inner } => inner.poll_frame(cx),
306            #[cfg(feature = "decompression-br")]
307            BodyInnerProj::Brotli { inner } => inner.poll_frame(cx),
308            #[cfg(feature = "decompression-zstd")]
309            BodyInnerProj::Zstd { inner } => inner.poll_frame(cx),
310            BodyInnerProj::Identity { inner } => match ready!(inner.poll_frame(cx)) {
311                Some(Ok(frame)) => {
312                    let frame = frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()));
313                    Poll::Ready(Some(Ok(frame)))
314                }
315                Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
316                None => Poll::Ready(None),
317            },
318
319            #[cfg(not(feature = "decompression-gzip"))]
320            BodyInnerProj::Gzip { inner } => match inner.0 {},
321            #[cfg(not(feature = "decompression-deflate"))]
322            BodyInnerProj::Deflate { inner } => match inner.0 {},
323            #[cfg(not(feature = "decompression-br"))]
324            BodyInnerProj::Brotli { inner } => match inner.0 {},
325            #[cfg(not(feature = "decompression-zstd"))]
326            BodyInnerProj::Zstd { inner } => match inner.0 {},
327        }
328    }
329
330    fn size_hint(&self) -> SizeHint {
331        match self.inner {
332            BodyInner::Identity { ref inner } => inner.size_hint(),
333            _ => SizeHint::default(),
334        }
335    }
336}
337
#[cfg(feature = "decompression-gzip")]
impl<B> DecorateAsyncRead for GzipDecoder<B>
where
    B: Body,
{
    type Input = AsyncReadBody<B>;
    type Output = GzipDecoder<Self::Input>;

    /// Pins through the decoder to the reader it decodes from.
    fn get_pin_mut(decoder: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        decoder.get_pin_mut()
    }

    /// Builds a gzip decoder over `reader`; the compression level is
    /// meaningless when decoding and is ignored.
    fn apply(reader: Self::Input, _level: CompressionLevel) -> Self::Output {
        let mut gzip = GzipDecoder::new(reader);
        // Keep decoding across concatenated gzip members instead of stopping
        // after the first one.
        gzip.multiple_members(true);
        gzip
    }
}
356
#[cfg(feature = "decompression-deflate")]
impl<B> DecorateAsyncRead for ZlibDecoder<B>
where
    B: Body,
{
    type Input = AsyncReadBody<B>;
    type Output = ZlibDecoder<Self::Input>;

    /// Pins through the decoder to the reader it decodes from.
    fn get_pin_mut(decoder: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        decoder.get_pin_mut()
    }

    /// Builds a zlib decoder over `reader` (HTTP's `deflate` coding is the
    /// zlib format); the compression level is ignored when decoding.
    fn apply(reader: Self::Input, _level: CompressionLevel) -> Self::Output {
        let zlib = ZlibDecoder::new(reader);
        zlib
    }
}
373
#[cfg(feature = "decompression-br")]
impl<B> DecorateAsyncRead for BrotliDecoder<B>
where
    B: Body,
{
    type Input = AsyncReadBody<B>;
    type Output = BrotliDecoder<Self::Input>;

    /// Pins through the decoder to the reader it decodes from.
    fn get_pin_mut(decoder: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        decoder.get_pin_mut()
    }

    /// Builds a brotli decoder over `reader`; the compression level is
    /// ignored when decoding.
    fn apply(reader: Self::Input, _level: CompressionLevel) -> Self::Output {
        let brotli = BrotliDecoder::new(reader);
        brotli
    }
}
390
#[cfg(feature = "decompression-zstd")]
impl<B> DecorateAsyncRead for ZstdDecoder<B>
where
    B: Body,
{
    type Input = AsyncReadBody<B>;
    type Output = ZstdDecoder<Self::Input>;

    /// Pins through the decoder to the reader it decodes from.
    fn get_pin_mut(decoder: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        decoder.get_pin_mut()
    }

    /// Builds a zstd decoder over `reader`; the compression level is
    /// ignored when decoding.
    fn apply(reader: Self::Input, _level: CompressionLevel) -> Self::Output {
        let zstd = ZstdDecoder::new(reader);
        zstd
    }
}