http_body_util/
limited.rs1use bytes::Buf;
2use http_body::{Body, Frame, SizeHint};
3use pin_project_lite::pin_project;
4use std::error::Error;
5use std::fmt;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9pin_project! {
10 #[derive(Clone, Copy, Debug)]
15 pub struct Limited<B> {
16 remaining: usize,
17 #[pin]
18 inner: B,
19 }
20}
21
22impl<B> Limited<B> {
23 pub fn new(inner: B, limit: usize) -> Self {
25 Self {
26 remaining: limit,
27 inner,
28 }
29 }
30}
31
32impl<B> Body for Limited<B>
33where
34 B: Body,
35 B::Error: Into<Box<dyn Error + Send + Sync>>,
36{
37 type Data = B::Data;
38 type Error = Box<dyn Error + Send + Sync>;
39
40 fn poll_frame(
41 self: Pin<&mut Self>,
42 cx: &mut Context<'_>,
43 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
44 let this = self.project();
45 let res = match this.inner.poll_frame(cx) {
46 Poll::Pending => return Poll::Pending,
47 Poll::Ready(None) => None,
48 Poll::Ready(Some(Ok(frame))) => {
49 if let Some(data) = frame.data_ref() {
50 if data.remaining() > *this.remaining {
51 *this.remaining = 0;
52 Some(Err(LengthLimitError.into()))
53 } else {
54 *this.remaining -= data.remaining();
55 Some(Ok(frame))
56 }
57 } else {
58 Some(Ok(frame))
59 }
60 }
61 Poll::Ready(Some(Err(err))) => Some(Err(err.into())),
62 };
63
64 Poll::Ready(res)
65 }
66
67 fn is_end_stream(&self) -> bool {
68 self.inner.is_end_stream()
69 }
70
71 fn size_hint(&self) -> SizeHint {
72 use std::convert::TryFrom;
73 match u64::try_from(self.remaining) {
74 Ok(n) => {
75 let mut hint = self.inner.size_hint();
76 if hint.lower() >= n {
77 hint.set_exact(n)
78 } else if let Some(max) = hint.upper() {
79 hint.set_upper(n.min(max))
80 } else {
81 hint.set_upper(n)
82 }
83 hint
84 }
85 Err(_) => self.inner.size_hint(),
86 }
87 }
88}
89
/// Error returned by [`Limited`] when a data frame would push the body past
/// its configured length limit.
#[derive(Debug)]
#[non_exhaustive]
pub struct LengthLimitError;

impl fmt::Display for LengthLimitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "length limit exceeded")
    }
}

impl Error for LengthLimitError {}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BodyExt, Full, StreamBody};
    use bytes::Bytes;
    use std::convert::Infallible;

    #[tokio::test]
    async fn read_for_body_under_limit_returns_data() {
        const DATA: &[u8] = b"testing";
        let inner = Full::new(Bytes::from(DATA));
        let body = &mut Limited::new(inner, 8);

        // The 7-byte inner body is smaller than the limit (8).
        assert_eq!(body.size_hint().upper(), Some(7));

        let data = body.frame().await.unwrap().unwrap().into_data().unwrap();
        assert_eq!(data, DATA);
        assert_eq!(body.size_hint().upper(), Some(0));

        assert!(body.frame().await.is_none());
    }

    #[tokio::test]
    async fn read_for_body_over_limit_returns_error() {
        const DATA: &[u8] = b"testing a string that is too long";
        let inner = Full::new(Bytes::from(DATA));
        let body = &mut Limited::new(inner, 8);

        assert_eq!(body.size_hint().upper(), Some(8));

        let error = body.frame().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
    }

    /// Builds a streaming body that yields one data frame per item.
    fn body_from_iter<I>(into_iter: I) -> impl Body<Data = Bytes, Error = Infallible>
    where
        I: IntoIterator,
        I::Item: Into<Bytes> + 'static,
        I::IntoIter: Send + 'static,
    {
        let iter = into_iter
            .into_iter()
            .map(|it| Frame::data(it.into()))
            .map(Ok::<_, Infallible>);

        StreamBody::new(futures_util::stream::iter(iter))
    }

    #[tokio::test]
    async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk(
    ) {
        const DATA: [&[u8]; 2] = [b"testing ", b"a string that is too long"];
        let inner = body_from_iter(DATA);
        let body = &mut Limited::new(inner, 8);

        assert_eq!(body.size_hint().upper(), Some(8));

        let data = body.frame().await.unwrap().unwrap().into_data().unwrap();
        assert_eq!(data, DATA[0]);
        assert_eq!(body.size_hint().upper(), Some(0));

        let error = body.frame().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
    }

    #[tokio::test]
    async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() {
        const DATA: [&[u8]; 2] = [b"testing a string", b" that is too long"];
        let inner = body_from_iter(DATA);
        let body = &mut Limited::new(inner, 8);

        assert_eq!(body.size_hint().upper(), Some(8));

        let error = body.frame().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
    }

    #[tokio::test]
    async fn read_for_chunked_body_under_limit_is_okay() {
        const DATA: [&[u8]; 2] = [b"test", b"ing!"];
        let inner = body_from_iter(DATA);
        let body = &mut Limited::new(inner, 8);

        assert_eq!(body.size_hint().upper(), Some(8));

        let data = body.frame().await.unwrap().unwrap().into_data().unwrap();
        assert_eq!(data, DATA[0]);
        assert_eq!(body.size_hint().upper(), Some(4));

        let data = body.frame().await.unwrap().unwrap().into_data().unwrap();
        assert_eq!(data, DATA[1]);
        assert_eq!(body.size_hint().upper(), Some(0));

        assert!(body.frame().await.is_none());
    }

    #[tokio::test]
    async fn read_for_empty_chunk_at_zero_limit_is_okay() {
        // A zero-length data frame consumes no budget, so it passes even at limit 0.
        const DATA: [&[u8]; 1] = [b""];
        let inner = body_from_iter(DATA);
        let body = &mut Limited::new(inner, 0);

        let data = body.frame().await.unwrap().unwrap().into_data().unwrap();
        assert!(data.is_empty());

        assert!(body.frame().await.is_none());
    }

    /// A body that only ever yields an (empty) trailers frame.
    struct SomeTrailers;

    impl Body for SomeTrailers {
        type Data = Bytes;
        type Error = Infallible;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            Poll::Ready(Some(Ok(Frame::trailers(http::HeaderMap::new()))))
        }
    }

    #[tokio::test]
    async fn read_for_trailers_propagates_inner_trailers() {
        let body = &mut Limited::new(SomeTrailers, 8);
        let frame = body.frame().await.unwrap().unwrap();
        assert!(frame.is_trailers());
    }

    #[tokio::test]
    async fn read_for_trailers_is_not_counted_against_exhausted_limit() {
        // Trailers are not data frames, so they pass even with no budget left.
        let body = &mut Limited::new(SomeTrailers, 0);
        let frame = body.frame().await.unwrap().unwrap();
        assert!(frame.is_trailers());
    }

    #[derive(Debug)]
    struct ErrorBodyError;

    impl fmt::Display for ErrorBodyError {
        fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
            Ok(())
        }
    }

    impl Error for ErrorBodyError {}

    /// A body whose every poll fails with `ErrorBodyError`.
    struct ErrorBody;

    impl Body for ErrorBody {
        type Data = &'static [u8];
        type Error = ErrorBodyError;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            Poll::Ready(Some(Err(ErrorBodyError)))
        }
    }

    #[tokio::test]
    async fn read_for_body_returning_error_propagates_error() {
        let body = &mut Limited::new(ErrorBody, 8);
        let error = body.frame().await.unwrap().unwrap_err();
        assert!(matches!(error.downcast_ref(), Some(ErrorBodyError)));
    }

    #[test]
    fn length_limit_error_displays_message() {
        assert_eq!(LengthLimitError.to_string(), "length limit exceeded");
    }
}