// tower_http/set_status.rs

//! Middleware to override status codes.
//!
//! # Example
//!
//! ```
//! use tower_http::set_status::SetStatusLayer;
//! use http::{Request, Response, StatusCode};
//! use bytes::Bytes;
//! use http_body_util::Full;
//! use std::{iter::once, convert::Infallible};
//! use tower::{ServiceBuilder, Service, ServiceExt};
//!
//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
//!     // ...
//!     # Ok(Response::new(Full::default()))
//! }
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut service = ServiceBuilder::new()
//!     // change the status to `404 Not Found` regardless of what the inner service returns
//!     .layer(SetStatusLayer::new(StatusCode::NOT_FOUND))
//!     .service_fn(handle);
//!
//! // Call the service.
//! let request = Request::builder().body(Full::default())?;
//!
//! let response = service.ready().await?.call(request).await?;
//!
//! assert_eq!(response.status(), StatusCode::NOT_FOUND);
//! #
//! # Ok(())
//! # }
//! ```

36use http::{Request, Response, StatusCode};
37use pin_project_lite::pin_project;
38use std::{
39    future::Future,
40    pin::Pin,
41    task::{ready, Context, Poll},
42};
43use tower_layer::Layer;
44use tower_service::Service;
45
46/// Layer that applies [`SetStatus`] which overrides the status codes.
47#[derive(Debug, Clone, Copy)]
48pub struct SetStatusLayer {
49    status: StatusCode,
50}
51
52impl SetStatusLayer {
53    /// Create a new [`SetStatusLayer`].
54    ///
55    /// The response status code will be `status` regardless of what the inner service returns.
56    pub fn new(status: StatusCode) -> Self {
57        SetStatusLayer { status }
58    }
59}
60
61impl<S> Layer<S> for SetStatusLayer {
62    type Service = SetStatus<S>;
63
64    fn layer(&self, inner: S) -> Self::Service {
65        SetStatus::new(inner, self.status)
66    }
67}
68
69/// Middleware to override status codes.
70///
71/// See the [module docs](self) for more details.
72#[derive(Debug, Clone, Copy)]
73pub struct SetStatus<S> {
74    inner: S,
75    status: StatusCode,
76}
77
78impl<S> SetStatus<S> {
79    /// Create a new [`SetStatus`].
80    ///
81    /// The response status code will be `status` regardless of what the inner service returns.
82    pub fn new(inner: S, status: StatusCode) -> Self {
83        Self { status, inner }
84    }
85
86    define_inner_service_accessors!();
87
88    /// Returns a new [`Layer`] that wraps services with a `SetStatus` middleware.
89    ///
90    /// [`Layer`]: tower_layer::Layer
91    pub fn layer(status: StatusCode) -> SetStatusLayer {
92        SetStatusLayer::new(status)
93    }
94}
95
96impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SetStatus<S>
97where
98    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
99{
100    type Response = S::Response;
101    type Error = S::Error;
102    type Future = ResponseFuture<S::Future>;
103
104    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
105        self.inner.poll_ready(cx)
106    }
107
108    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
109        ResponseFuture {
110            inner: self.inner.call(req),
111            status: Some(self.status),
112        }
113    }
114}
115
116pin_project! {
117    /// Response future for [`SetStatus`].
118    pub struct ResponseFuture<F> {
119        #[pin]
120        inner: F,
121        status: Option<StatusCode>,
122    }
123}
124
125impl<F, B, E> Future for ResponseFuture<F>
126where
127    F: Future<Output = Result<Response<B>, E>>,
128{
129    type Output = F::Output;
130
131    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
132        let this = self.project();
133        let mut response = ready!(this.inner.poll(cx)?);
134        *response.status_mut() = this.status.take().expect("future polled after completion");
135        Poll::Ready(Ok(response))
136    }
137}