//! Middleware to override status codes.
//!
//! # Example
//!
//! ```
//! use tower_http::set_status::SetStatusLayer;
//! use http::{Request, Response, StatusCode};
//! use bytes::Bytes;
//! use http_body_util::Full;
//! use std::{iter::once, convert::Infallible};
//! use tower::{ServiceBuilder, Service, ServiceExt};
//!
//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
//! // ...
//! # Ok(Response::new(Full::default()))
//! }
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut service = ServiceBuilder::new()
//! // change the status to `404 Not Found` regardless of what the inner service returns
//! .layer(SetStatusLayer::new(StatusCode::NOT_FOUND))
//! .service_fn(handle);
//!
//! // Call the service.
//! let request = Request::builder().body(Full::default())?;
//!
//! let response = service.ready().await?.call(request).await?;
//!
//! assert_eq!(response.status(), StatusCode::NOT_FOUND);
//! #
//! # Ok(())
//! # }
//! ```
use http::{Request, Response, StatusCode};
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
task::{ready, Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;
/// A [`Layer`] that wraps services in [`SetStatus`], overriding the status code
/// of the responses they produce with a fixed value.
#[derive(Clone, Copy, Debug)]
pub struct SetStatusLayer {
    status: StatusCode,
}

impl SetStatusLayer {
    /// Build a [`SetStatusLayer`] that stamps `status` onto responses.
    ///
    /// Whatever status code the inner service sets on a successful response is
    /// replaced by `status`.
    pub fn new(status: StatusCode) -> Self {
        Self { status }
    }
}

impl<S> Layer<S> for SetStatusLayer {
    type Service = SetStatus<S>;

    fn layer(&self, inner: S) -> Self::Service {
        // `StatusCode` is `Copy`, so each wrapped service gets its own copy.
        let status = self.status;
        SetStatus::new(inner, status)
    }
}
/// Middleware that replaces the status code of responses produced by the
/// wrapped service.
///
/// See the [module docs](self) for more details.
#[derive(Clone, Copy, Debug)]
pub struct SetStatus<S> {
    inner: S,
    status: StatusCode,
}

impl<S> SetStatus<S> {
    /// Wrap `inner` so that its successful responses carry the status code `status`.
    ///
    /// The response status code will be `status` regardless of what the inner
    /// service sets; errors from the inner service are passed through unchanged.
    pub fn new(inner: S, status: StatusCode) -> Self {
        SetStatus { inner, status }
    }

    define_inner_service_accessors!();

    /// Returns a new [`Layer`] that wraps services with a `SetStatus` middleware.
    ///
    /// This is shorthand for [`SetStatusLayer::new`].
    ///
    /// [`Layer`]: tower_layer::Layer
    pub fn layer(status: StatusCode) -> SetStatusLayer {
        SetStatusLayer::new(status)
    }
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SetStatus<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = ResponseFuture<S::Future>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forward `req` to the wrapped service and hand back a future that will
    /// overwrite the status of the eventual response.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let inner = self.inner.call(req);
        // Wrapped in `Some` so the future can detect being polled after completion.
        let status = Some(self.status);
        ResponseFuture { inner, status }
    }
}
pin_project! {
    /// Response future for [`SetStatus`].
    ///
    /// Resolves to the inner service's result. On success the response's status
    /// code is overwritten; errors are passed through unchanged.
    #[derive(Debug)]
    pub struct ResponseFuture<F> {
        #[pin]
        inner: F,
        // Status to apply; taken when a successful response is produced, so a
        // second completion is detected and reported by a panic.
        status: Option<StatusCode>,
    }
}
impl<F, B, E> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response<B>, E>>,
{
    type Output = F::Output;

    /// Drive the inner future; on `Ok`, overwrite the response's status code.
    ///
    /// An `Err` from the inner future is returned untouched.
    ///
    /// # Panics
    ///
    /// Panics with "future polled after completion" if the success path is
    /// reached a second time.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        match ready!(projected.inner.poll(cx)) {
            Ok(mut res) => {
                let new_status = projected
                    .status
                    .take()
                    .expect("future polled after completion");
                *res.status_mut() = new_status;
                Poll::Ready(Ok(res))
            }
            Err(err) => Poll::Ready(Err(err)),
        }
    }
}