// Timeout wrapper for HTTP bodies: every requested frame must arrive within a fixed duration.
use crate::BoxError;
use http_body::Body;
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
task::{ready, Context, Poll},
time::Duration,
};
use tokio::time::{sleep, Sleep};
pin_project! {
/// Middleware that applies a timeout to request and response bodies.
///
/// Wrapper around a [`http_body::Body`] that times out if data is not ready within the
/// specified duration.
///
/// Each requested frame must be produced within the specified timeout. The timer is started
/// when a frame is polled for and cleared once the inner body yields; if it elapses first,
/// the body returns an error instead of a frame.
///
/// # Differences from [`Timeout`][crate::timeout::Timeout]
///
/// [`Timeout`][crate::timeout::Timeout] applies a timeout to the request future, not the body.
/// That timeout is not reset when bytes are handled, whether the request is active or not.
/// Bodies are handled asynchronously outside of the tower stack's future and thus need an
/// additional timeout.
///
/// # Errors
///
/// When the timeout elapses, the body yields a boxed [`TimeoutError`].
///
/// # Example
///
/// ```
/// use http::{Request, Response};
/// use bytes::Bytes;
/// use http_body_util::Full;
/// use std::time::Duration;
/// use tower::ServiceBuilder;
/// use tower_http::timeout::RequestBodyTimeoutLayer;
///
/// async fn handle(_: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, std::convert::Infallible> {
///     // ...
///     # todo!()
/// }
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let svc = ServiceBuilder::new()
///     // Timeout bodies after 30 seconds of inactivity
///     .layer(RequestBodyTimeoutLayer::new(Duration::from_secs(30)))
///     .service_fn(handle);
/// # Ok(())
/// # }
/// ```
pub struct TimeoutBody<B> {
// Maximum time allowed between a frame being polled for and the inner body producing it.
timeout: Duration,
// Timer for the frame currently being awaited; `None` until a poll arms it, and reset to
// `None` each time the inner body yields.
#[pin]
sleep: Option<Sleep>,
// The wrapped body whose frames are forwarded.
#[pin]
body: B,
}
}
impl<B> TimeoutBody<B> {
/// Creates a new [`TimeoutBody`].
pub fn new(timeout: Duration, body: B) -> Self {
TimeoutBody {
timeout,
sleep: None,
body,
}
}
}
impl<B> Body for TimeoutBody<B>
where
    B: Body,
    B::Error: Into<BoxError>,
{
    type Data = B::Data;
    type Error = Box<dyn std::error::Error + Send + Sync>;

    /// Polls the inner body for its next frame, failing if the timer elapses first.
    ///
    /// The timer is armed on the first poll for a frame and cleared as soon as the inner body
    /// yields (a frame, an error, or end of stream), so every frame gets a fresh timeout.
    ///
    /// # Errors
    ///
    /// Yields a boxed [`TimeoutError`] when the timer has elapsed, and otherwise forwards the
    /// inner body's error converted via `Into`.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
        let mut this = self.project();

        // Start the `Sleep` if not active.
        let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() {
            some
        } else {
            this.sleep.set(Some(sleep(*this.timeout)));
            this.sleep
                .as_mut()
                .as_pin_mut()
                .expect("sleep was set to `Some` on the previous line")
        };

        // The timer is checked before the body, so an elapsed timer takes precedence.
        if let Poll::Ready(()) = sleep_pinned.poll(cx) {
            return Poll::Ready(Some(Err(Box::new(TimeoutError(())))));
        }

        // Check for body data; returns `Pending` with the timer still armed otherwise.
        let frame = ready!(this.body.poll_frame(cx));

        // The inner body yielded. Drop the timer so the next frame starts a fresh one.
        this.sleep.set(None);

        // Option<Result<_, B::Error>> -> Option<Result<_, Self::Error>>.
        Poll::Ready(frame.transpose().map_err(Into::into).transpose())
    }

    /// Forwards the inner body's end-of-stream indication instead of the trait default.
    fn is_end_stream(&self) -> bool {
        self.body.is_end_stream()
    }

    /// Forwards the inner body's size hint; this wrapper adds or removes no data frames.
    fn size_hint(&self) -> http_body::SizeHint {
        self.body.size_hint()
    }
}
/// Error yielded by [`TimeoutBody`] when a frame does not arrive in time.
///
/// The private unit field prevents construction outside this module.
#[derive(Debug)]
pub struct TimeoutError(());

impl std::fmt::Display for TimeoutError {
    /// Writes the fixed, human-readable timeout message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("data was not received within the designated timeout")
    }
}

// Marker impl: no `source`, since the timeout has no underlying cause.
impl std::error::Error for TimeoutError {}
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use http_body::Frame;
    use http_body_util::BodyExt;
    use pin_project_lite::pin_project;
    use std::{error::Error, fmt::Display};

    /// Error type required by the `Body` impl of [`MockBody`]; the mock never yields it.
    #[derive(Debug)]
    struct MockError;

    impl Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("mock error")
        }
    }

    impl Error for MockError {}

    pin_project! {
        /// Body that stays pending until its timer elapses, then yields an empty data frame.
        struct MockBody {
            #[pin]
            delay: Sleep
        }
    }

    impl Body for MockBody {
        type Data = Bytes;
        type Error = MockError;

        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            match self.project().delay.poll(cx) {
                Poll::Ready(()) => Poll::Ready(Some(Ok(Frame::data(Bytes::new())))),
                Poll::Pending => Poll::Pending,
            }
        }
    }

    /// Builds a [`TimeoutBody`] with the given `timeout` around a [`MockBody`] that becomes
    /// ready after `body_delay`. Must be called from within a Tokio runtime.
    fn delayed_body(body_delay: Duration, timeout: Duration) -> TimeoutBody<MockBody> {
        let inner = MockBody {
            delay: sleep(body_delay),
        };
        TimeoutBody::new(timeout, inner)
    }

    #[tokio::test]
    async fn test_body_available_within_timeout() {
        // The body is ready after 1s, comfortably inside the 2s timeout.
        let wrapped = delayed_body(Duration::from_secs(1), Duration::from_secs(2));

        let first = wrapped.boxed().frame().await.expect("no frame");
        assert!(first.is_ok());
    }

    #[tokio::test]
    async fn test_body_unavailable_within_timeout_error() {
        // The body needs 2s but only 1s is allowed, so the timeout error must win.
        let wrapped = delayed_body(Duration::from_secs(2), Duration::from_secs(1));

        let first = wrapped.boxed().frame().await.unwrap();
        assert!(first.is_err());
    }
}