http_body_util/limited.rs

1use bytes::Buf;
2use http_body::{Body, Frame, SizeHint};
3use pin_project_lite::pin_project;
4use std::error::Error;
5use std::fmt;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9pin_project! {
10    /// A length limited body.
11    ///
12    /// This body will return an error if more than the configured number
13    /// of bytes are returned on polling the wrapped body.
14    #[derive(Clone, Copy, Debug)]
15    pub struct Limited<B> {
16        remaining: usize,
17        #[pin]
18        inner: B,
19    }
20}
21
22impl<B> Limited<B> {
23    /// Create a new `Limited`.
24    pub fn new(inner: B, limit: usize) -> Self {
25        Self {
26            remaining: limit,
27            inner,
28        }
29    }
30}
31
32impl<B> Body for Limited<B>
33where
34    B: Body,
35    B::Error: Into<Box<dyn Error + Send + Sync>>,
36{
37    type Data = B::Data;
38    type Error = Box<dyn Error + Send + Sync>;
39
40    fn poll_frame(
41        self: Pin<&mut Self>,
42        cx: &mut Context<'_>,
43    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
44        let this = self.project();
45        let res = match this.inner.poll_frame(cx) {
46            Poll::Pending => return Poll::Pending,
47            Poll::Ready(None) => None,
48            Poll::Ready(Some(Ok(frame))) => {
49                if let Some(data) = frame.data_ref() {
50                    if data.remaining() > *this.remaining {
51                        *this.remaining = 0;
52                        Some(Err(LengthLimitError.into()))
53                    } else {
54                        *this.remaining -= data.remaining();
55                        Some(Ok(frame))
56                    }
57                } else {
58                    Some(Ok(frame))
59                }
60            }
61            Poll::Ready(Some(Err(err))) => Some(Err(err.into())),
62        };
63
64        Poll::Ready(res)
65    }
66
67    fn is_end_stream(&self) -> bool {
68        self.inner.is_end_stream()
69    }
70
71    fn size_hint(&self) -> SizeHint {
72        use std::convert::TryFrom;
73        match u64::try_from(self.remaining) {
74            Ok(n) => {
75                let mut hint = self.inner.size_hint();
76                if hint.lower() >= n {
77                    hint.set_exact(n)
78                } else if let Some(max) = hint.upper() {
79                    hint.set_upper(n.min(max))
80                } else {
81                    hint.set_upper(n)
82                }
83                hint
84            }
85            Err(_) => self.inner.size_hint(),
86        }
87    }
88}
89
/// Error yielded by [`Limited`] when the wrapped body produces more data
/// bytes than the configured limit permits.
#[derive(Debug)]
#[non_exhaustive]
pub struct LengthLimitError;

impl fmt::Display for LengthLimitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "length limit exceeded")
    }
}

impl Error for LengthLimitError {}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BodyExt, Full, StreamBody};
    use bytes::Bytes;
    use std::convert::Infallible;

    /// Asserts that the upper bound of `body`'s size hint equals `expected`.
    fn assert_upper_hint<B: Body>(body: &B, expected: u64) {
        assert_eq!(body.size_hint().upper(), Some(expected));
    }

    /// Asserts that `err` is a `LengthLimitError`.
    fn assert_length_limit_error(err: Box<dyn Error + Send + Sync>) {
        assert!(err.downcast_ref::<LengthLimitError>().is_some());
    }

    /// Builds a streaming body that yields each item of `into_iter` as one
    /// data frame and never fails.
    fn body_from_iter<I>(into_iter: I) -> impl Body<Data = Bytes, Error = Infallible>
    where
        I: IntoIterator,
        I::Item: Into<Bytes> + 'static,
        I::IntoIter: Send + 'static,
    {
        let frames = into_iter
            .into_iter()
            .map(|chunk| Ok::<_, Infallible>(Frame::data(chunk.into())));

        StreamBody::new(futures_util::stream::iter(frames))
    }

    #[tokio::test]
    async fn read_for_body_under_limit_returns_data() {
        const DATA: &[u8] = b"testing";
        let mut body = Limited::new(Full::new(Bytes::from(DATA)), 8);
        assert_upper_hint(&body, 7);

        let chunk = body.frame().await.unwrap().unwrap().into_data().unwrap();
        assert_eq!(chunk, DATA);
        assert_upper_hint(&body, 0);

        assert!(body.frame().await.is_none());
    }

    #[tokio::test]
    async fn read_for_body_over_limit_returns_error() {
        const DATA: &[u8] = b"testing a string that is too long";
        let mut body = Limited::new(Full::new(Bytes::from(DATA)), 8);
        assert_upper_hint(&body, 8);

        let err = body.frame().await.unwrap().unwrap_err();
        assert_length_limit_error(err);
    }

    #[tokio::test]
    async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk(
    ) {
        const DATA: [&[u8]; 2] = [b"testing ", b"a string that is too long"];
        let mut body = Limited::new(body_from_iter(DATA), 8);
        assert_upper_hint(&body, 8);

        let first = body.frame().await.unwrap().unwrap().into_data().unwrap();
        assert_eq!(first, DATA[0]);
        assert_upper_hint(&body, 0);

        let err = body.frame().await.unwrap().unwrap_err();
        assert_length_limit_error(err);
    }

    #[tokio::test]
    async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() {
        const DATA: [&[u8]; 2] = [b"testing a string", b" that is too long"];
        let mut body = Limited::new(body_from_iter(DATA), 8);
        assert_upper_hint(&body, 8);

        let err = body.frame().await.unwrap().unwrap_err();
        assert_length_limit_error(err);
    }

    #[tokio::test]
    async fn read_for_chunked_body_under_limit_is_okay() {
        const DATA: [&[u8]; 2] = [b"test", b"ing!"];
        let mut body = Limited::new(body_from_iter(DATA), 8);
        assert_upper_hint(&body, 8);

        let first = body.frame().await.unwrap().unwrap().into_data().unwrap();
        assert_eq!(first, DATA[0]);
        assert_upper_hint(&body, 4);

        let second = body.frame().await.unwrap().unwrap().into_data().unwrap();
        assert_eq!(second, DATA[1]);
        assert_upper_hint(&body, 0);

        assert!(body.frame().await.is_none());
    }

    /// A body that yields an empty trailers frame on every poll.
    struct SomeTrailers;

    impl Body for SomeTrailers {
        type Data = Bytes;
        type Error = Infallible;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            let trailers = http::HeaderMap::new();
            Poll::Ready(Some(Ok(Frame::trailers(trailers))))
        }
    }

    #[tokio::test]
    async fn read_for_trailers_propagates_inner_trailers() {
        let mut body = Limited::new(SomeTrailers, 8);
        let frame = body.frame().await.unwrap().unwrap();
        assert!(frame.is_trailers());
    }

    /// Marker error produced by `ErrorBody`; its `Display` output is empty.
    #[derive(Debug)]
    struct ErrorBodyError;

    impl fmt::Display for ErrorBodyError {
        fn fmt(&self, _: &mut fmt::Formatter) -> fmt::Result {
            Ok(())
        }
    }

    impl Error for ErrorBodyError {}

    /// A body that fails with `ErrorBodyError` on every poll.
    struct ErrorBody;

    impl Body for ErrorBody {
        type Data = &'static [u8];
        type Error = ErrorBodyError;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            Poll::Ready(Some(Err(ErrorBodyError)))
        }
    }

    #[tokio::test]
    async fn read_for_body_returning_error_propagates_error() {
        let mut body = Limited::new(ErrorBody, 8);
        let err = body.frame().await.unwrap().unwrap_err();
        assert!(err.downcast_ref::<ErrorBodyError>().is_some());
    }
}