1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
use crate::BoxError;
use futures_core::{ready, Future};
use http_body::Body;
use pin_project_lite::pin_project;
use std::{
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::time::{sleep, Sleep};
pin_project! {
    /// Body wrapper that applies an inactivity timeout to request and response bodies.
    ///
    /// Wraps an [`http_body::Body`] and fails if the next piece of data is not ready within
    /// the specified duration.
    ///
    /// Each data frame must be produced within the timeout. The timer is started on the first
    /// poll for a frame and restarted for every subsequent frame. Trailers get their own timer,
    /// started on the first poll for them. If a frame (or the trailers) is not ready in time,
    /// polling yields a [`TimeoutError`], boxed into the body's error type.
    ///
    /// # Differences from [`Timeout`][crate::timeout::Timeout]
    ///
    /// [`Timeout`][crate::timeout::Timeout] applies a timeout to the request future, not to the body.
    /// That timeout is not reset when bytes are handled, regardless of whether the request is still active.
    /// Bodies are streamed asynchronously outside of the tower stack's future and thus need a separate timeout.
    ///
    /// # Example
    ///
    /// ```
    /// use http::{Request, Response};
    /// use hyper::Body;
    /// use std::time::Duration;
    /// use tower::ServiceBuilder;
    /// use tower_http::timeout::RequestBodyTimeoutLayer;
    ///
    /// async fn handle(_: Request<Body>) -> Result<Response<Body>, std::convert::Infallible> {
    /// // ...
    /// # todo!()
    /// }
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let svc = ServiceBuilder::new()
    /// // Timeout bodies after 30 seconds of inactivity
    /// .layer(RequestBodyTimeoutLayer::new(Duration::from_secs(30)))
    /// .service_fn(handle);
    /// # Ok(())
    /// # }
    /// ```
    pub struct TimeoutBody<B> {
        // Maximum time allowed for each data frame, and separately for the trailers.
        timeout: Duration,
        // In http-body 1.0, `poll_*` will be merged into `poll_frame`.
        // Merge the two `sleep_data` and `sleep_trailers` into one `sleep`.
        // See: https://github.com/tower-rs/tower-http/pull/303#discussion_r1004834958
        //
        // Timer for the data frame currently being awaited; `None` when no timer is armed.
        #[pin]
        sleep_data: Option<Sleep>,
        // Timer for the trailers; `None` until trailers are first polled.
        #[pin]
        sleep_trailers: Option<Sleep>,
        // The wrapped body.
        #[pin]
        body: B,
    }
}
impl<B> TimeoutBody<B> {
/// Creates a new [`TimeoutBody`].
pub fn new(timeout: Duration, body: B) -> Self {
TimeoutBody {
timeout,
sleep_data: None,
sleep_trailers: None,
body,
}
}
}
impl<B> Body for TimeoutBody<B>
where
B: Body,
B::Error: Into<BoxError>,
{
type Data = B::Data;
type Error = Box<dyn std::error::Error + Send + Sync>;
fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
let mut this = self.project();
// Start the `Sleep` if not active.
let sleep_pinned = if let Some(some) = this.sleep_data.as_mut().as_pin_mut() {
some
} else {
this.sleep_data.set(Some(sleep(*this.timeout)));
this.sleep_data.as_mut().as_pin_mut().unwrap()
};
// Error if the timeout has expired.
if let Poll::Ready(()) = sleep_pinned.poll(cx) {
return Poll::Ready(Some(Err(Box::new(TimeoutError(())))));
}
// Check for body data.
let data = ready!(this.body.poll_data(cx));
// Some data is ready. Reset the `Sleep`...
this.sleep_data.set(None);
Poll::Ready(data.transpose().map_err(Into::into).transpose())
}
fn poll_trailers(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
let mut this = self.project();
// In http-body 1.0, `poll_*` will be merged into `poll_frame`.
// Merge the two `sleep_data` and `sleep_trailers` into one `sleep`.
// See: https://github.com/tower-rs/tower-http/pull/303#discussion_r1004834958
let sleep_pinned = if let Some(some) = this.sleep_trailers.as_mut().as_pin_mut() {
some
} else {
this.sleep_trailers.set(Some(sleep(*this.timeout)));
this.sleep_trailers.as_mut().as_pin_mut().unwrap()
};
// Error if the timeout has expired.
if let Poll::Ready(()) = sleep_pinned.poll(cx) {
return Poll::Ready(Err(Box::new(TimeoutError(()))));
}
this.body.poll_trailers(cx).map_err(Into::into)
}
}
/// Error for [`TimeoutBody`].
///
/// Returned (boxed) when the wrapped body does not produce a data frame, or its trailers,
/// before the configured timeout elapses.
#[derive(Debug)]
pub struct TimeoutError(());

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("data was not received within the designated timeout")
    }
}

impl std::error::Error for TimeoutError {}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use pin_project_lite::pin_project;
    use std::{error::Error, fmt::Display};

    #[derive(Debug)]
    struct MockError;

    impl Error for MockError {}

    impl Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // Must not panic: a failing assertion may try to print this error.
            f.write_str("mock error")
        }
    }

    pin_project! {
        // A body that yields an empty data frame once `sleep` has elapsed.
        struct MockBody {
            #[pin]
            sleep: Sleep
        }
    }

    impl Body for MockBody {
        type Data = Bytes;
        type Error = MockError;

        fn poll_data(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            let this = self.project();
            this.sleep.poll(cx).map(|_| Some(Ok(Bytes::new())))
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
            // The mock has no trailers.
            Poll::Ready(Ok(None))
        }
    }

    #[tokio::test]
    async fn test_body_available_within_timeout() {
        let mock_sleep = Duration::from_secs(1);
        let timeout_sleep = Duration::from_secs(2);
        let mock_body = MockBody {
            sleep: sleep(mock_sleep),
        };
        let timeout_body = TimeoutBody::new(timeout_sleep, mock_body);
        assert!(timeout_body.boxed().data().await.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_body_unavailable_within_timeout_error() {
        let mock_sleep = Duration::from_secs(2);
        let timeout_sleep = Duration::from_secs(1);
        let mock_body = MockBody {
            sleep: sleep(mock_sleep),
        };
        let timeout_body = TimeoutBody::new(timeout_sleep, mock_body);
        let err = timeout_body.boxed().data().await.unwrap().unwrap_err();
        // The failure must come from the timeout, not from the inner body.
        assert!(err.is::<TimeoutError>());
    }
}