tokio_buf/util/
limit.rs

1use BufStream;
2
3use bytes::Buf;
4use futures::Poll;
5
/// Stream adapter that caps the total amount of data yielded by `stream`.
///
/// Polling fails with a `LimitError` once the data produced by the inner
/// stream would exceed the cap.
#[derive(Debug)]
pub struct Limit<T> {
    /// The wrapped buffer stream.
    stream: T,
    /// Amount of data that may still be yielded before the limit is hit.
    remaining: u64,
}
12
/// Error type produced by `Limit`.
///
/// It either wraps an error raised by the inner stream, or signals that the
/// configured limit was exceeded.
#[derive(Debug)]
pub struct LimitError<T> {
    /// The inner stream's error; `None` means the limit was reached.
    inner: Option<T>,
}
19
20impl<T> Limit<T> {
21    pub(crate) fn new(stream: T, amount: u64) -> Limit<T> {
22        Limit {
23            stream,
24            remaining: amount,
25        }
26    }
27}
28
29impl<T> BufStream for Limit<T>
30where
31    T: BufStream,
32{
33    type Item = T::Item;
34    type Error = LimitError<T::Error>;
35
36    fn poll_buf(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
37        use futures::Async::Ready;
38
39        if self.stream.size_hint().lower() > self.remaining {
40            return Err(LimitError { inner: None });
41        }
42
43        let res = self
44            .stream
45            .poll_buf()
46            .map_err(|err| LimitError { inner: Some(err) });
47
48        match res {
49            Ok(Ready(Some(ref buf))) => {
50                if buf.remaining() as u64 > self.remaining {
51                    self.remaining = 0;
52                    return Err(LimitError { inner: None });
53                }
54
55                self.remaining -= buf.remaining() as u64;
56            }
57            _ => {}
58        }
59
60        res
61    }
62}
63
64// ===== impl LimitError =====
65
66impl<T> LimitError<T> {
67    /// Returns `true` if the error was caused by polling the stream.
68    pub fn is_stream_err(&self) -> bool {
69        self.inner.is_some()
70    }
71
72    /// Returns `true` if the stream reached its limit.
73    pub fn is_limit_err(&self) -> bool {
74        self.inner.is_none()
75    }
76}