rama_http/layer/
timeout.rs

1//! Middleware that applies a timeout to requests.
2//!
3//! If the request does not complete within the specified timeout it will be aborted and a `408
4//! Request Timeout` response will be sent.
5//!
6//! # Differences from `rama_core::service::layer::Timeout`
7//!
//! The generic `Timeout` middleware from `rama_core` uses an error to signal a timeout, i.e.
//! it changes the error type to [`BoxError`](rama_core::error::BoxError). For HTTP services that
//! is rarely what you want, as returning an error will terminate the connection without sending a
//! response.
11//!
//! This middleware won't change the error type and instead returns a `408 Request Timeout`
//! response. That means if your service's error type is [`Infallible`] it will still be
//! [`Infallible`] after applying this middleware.
15//!
16//! # Example
17//!
18//! ```
19//! use std::{convert::Infallible, time::Duration};
20//!
21//! use rama_core::Layer;
22//! use rama_core::service::service_fn;
23//! use rama_http::{Body, Request, Response};
24//! use rama_http::layer::timeout::TimeoutLayer;
25//! use rama_core::error::BoxError;
26//!
27//! async fn handle(_: Request) -> Result<Response, Infallible> {
28//!     // ...
29//!     # Ok(Response::new(Body::empty()))
30//! }
31//!
32//! # #[tokio::main]
33//! # async fn main() -> Result<(), BoxError> {
34//! let svc = (
35//!     // Timeout requests after 30 seconds
36//!     TimeoutLayer::new(Duration::from_secs(30)),
37//! ).layer(service_fn(handle));
38//! # Ok(())
39//! # }
40//! ```
41//!
42//! [`Infallible`]: std::convert::Infallible
43
44use crate::{Request, Response, StatusCode};
45use rama_core::{Context, Layer, Service};
46use rama_utils::macros::define_inner_service_accessors;
47use std::fmt;
48use std::time::Duration;
49
/// Layer that applies the [`Timeout`] middleware, which applies a timeout to requests.
///
/// Every service produced by this layer bounds each request by the same configured duration.
///
/// See the [module docs](crate::layer::timeout) for an example.
#[derive(Debug, Clone)]
pub struct TimeoutLayer {
    /// Maximum time the wrapped service may take to produce a response.
    timeout: Duration,
}
57
58impl TimeoutLayer {
59    /// Creates a new [`TimeoutLayer`].
60    pub const fn new(timeout: Duration) -> Self {
61        TimeoutLayer { timeout }
62    }
63}
64
65impl<S> Layer<S> for TimeoutLayer {
66    type Service = Timeout<S>;
67
68    fn layer(&self, inner: S) -> Self::Service {
69        Timeout::new(inner, self.timeout)
70    }
71}
72
/// Middleware which applies a timeout to requests.
///
/// If the request does not complete within the specified timeout it will be aborted and a `408
/// Request Timeout` response will be sent.
///
/// See the [module docs](crate::layer::timeout) for an example.
pub struct Timeout<S> {
    /// The wrapped service whose responses are bounded by `timeout`.
    inner: S,
    /// Maximum time `inner` may take to produce a response.
    timeout: Duration,
}
83
84impl<S> Timeout<S> {
85    /// Creates a new [`Timeout`].
86    pub const fn new(inner: S, timeout: Duration) -> Self {
87        Self { inner, timeout }
88    }
89
90    define_inner_service_accessors!();
91}
92
93impl<S: fmt::Debug> fmt::Debug for Timeout<S> {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        f.debug_struct("Timeout")
96            .field("inner", &self.inner)
97            .field("timeout", &self.timeout)
98            .finish()
99    }
100}
101
102impl<S: Clone> Clone for Timeout<S> {
103    fn clone(&self) -> Self {
104        Timeout {
105            inner: self.inner.clone(),
106            timeout: self.timeout,
107        }
108    }
109}
110
111impl<S: Copy> Copy for Timeout<S> {}
112
113impl<S, State, ReqBody, ResBody> Service<State, Request<ReqBody>> for Timeout<S>
114where
115    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
116    ReqBody: Send + 'static,
117    ResBody: Default + Send + 'static,
118    State: Clone + Send + Sync + 'static,
119{
120    type Response = S::Response;
121    type Error = S::Error;
122
123    async fn serve(
124        &self,
125        ctx: Context<State>,
126        req: Request<ReqBody>,
127    ) -> Result<Self::Response, Self::Error> {
128        tokio::select! {
129            res = self.inner.serve(ctx, req) => res,
130            _ = tokio::time::sleep(self.timeout) => {
131                let mut res = Response::new(ResBody::default());
132                *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
133                Ok(res)
134            }
135        }
136    }
137}