tower_http/
propagate_header.rs

1//! Propagate a header from the request to the response.
2//!
3//! # Example
4//!
5//! ```rust
6//! use http::{Request, Response, header::HeaderName};
7//! use std::convert::Infallible;
8//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
9//! use tower_http::propagate_header::PropagateHeaderLayer;
10//! use bytes::Bytes;
11//! use http_body_util::Full;
12//!
13//! # #[tokio::main]
14//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
15//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
16//!     // ...
17//!     # Ok(Response::new(Full::default()))
18//! }
19//!
20//! let mut svc = ServiceBuilder::new()
21//!     // This will copy `x-request-id` headers from requests onto responses.
22//!     .layer(PropagateHeaderLayer::new(HeaderName::from_static("x-request-id")))
23//!     .service_fn(handle);
24//!
25//! // Call the service.
26//! let request = Request::builder()
27//!     .header("x-request-id", "1337")
28//!     .body(Full::default())?;
29//!
30//! let response = svc.ready().await?.call(request).await?;
31//!
32//! assert_eq!(response.headers()["x-request-id"], "1337");
33//! #
34//! # Ok(())
35//! # }
36//! ```
37
38use http::{header::HeaderName, HeaderValue, Request, Response};
39use pin_project_lite::pin_project;
40use std::future::Future;
41use std::{
42    pin::Pin,
43    task::{ready, Context, Poll},
44};
45use tower_layer::Layer;
46use tower_service::Service;
47
48/// Layer that applies [`PropagateHeader`] which propagates headers from requests to responses.
49///
50/// If the header is present on the request it'll be applied to the response as well. This could
51/// for example be used to propagate headers such as `X-Request-Id`.
52///
53/// See the [module docs](crate::propagate_header) for more details.
54#[derive(Clone, Debug)]
55pub struct PropagateHeaderLayer {
56    header: HeaderName,
57}
58
59impl PropagateHeaderLayer {
60    /// Create a new [`PropagateHeaderLayer`].
61    pub fn new(header: HeaderName) -> Self {
62        Self { header }
63    }
64}
65
66impl<S> Layer<S> for PropagateHeaderLayer {
67    type Service = PropagateHeader<S>;
68
69    fn layer(&self, inner: S) -> Self::Service {
70        PropagateHeader {
71            inner,
72            header: self.header.clone(),
73        }
74    }
75}
76
77/// Middleware that propagates headers from requests to responses.
78///
79/// If the header is present on the request it'll be applied to the response as well. This could
80/// for example be used to propagate headers such as `X-Request-Id`.
81///
82/// See the [module docs](crate::propagate_header) for more details.
83#[derive(Clone, Debug)]
84pub struct PropagateHeader<S> {
85    inner: S,
86    header: HeaderName,
87}
88
89impl<S> PropagateHeader<S> {
90    /// Create a new [`PropagateHeader`] that propagates the given header.
91    pub fn new(inner: S, header: HeaderName) -> Self {
92        Self { inner, header }
93    }
94
95    define_inner_service_accessors!();
96
97    /// Returns a new [`Layer`] that wraps services with a `PropagateHeader` middleware.
98    ///
99    /// [`Layer`]: tower_layer::Layer
100    pub fn layer(header: HeaderName) -> PropagateHeaderLayer {
101        PropagateHeaderLayer::new(header)
102    }
103}
104
105impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for PropagateHeader<S>
106where
107    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
108{
109    type Response = S::Response;
110    type Error = S::Error;
111    type Future = ResponseFuture<S::Future>;
112
113    #[inline]
114    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115        self.inner.poll_ready(cx)
116    }
117
118    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
119        let value = req.headers().get(&self.header).cloned();
120
121        ResponseFuture {
122            future: self.inner.call(req),
123            header_and_value: Some(self.header.clone()).zip(value),
124        }
125    }
126}
127
pin_project! {
    /// Response future for [`PropagateHeader`].
    ///
    /// Resolves to the inner service's result. On success, the captured header (if any) is
    /// inserted into the response before it is returned. Errors pass through unchanged.
    #[derive(Debug)]
    pub struct ResponseFuture<F> {
        // The inner service's response future. It is structurally pinned so it can be
        // polled in place through the projection.
        #[pin]
        future: F,
        // Header name and value captured from the request, or `None` when the request did
        // not carry the header. It is taken (left as `None`) once the response is ready.
        header_and_value: Option<(HeaderName, HeaderValue)>,
    }
}
137
138impl<F, ResBody, E> Future for ResponseFuture<F>
139where
140    F: Future<Output = Result<Response<ResBody>, E>>,
141{
142    type Output = F::Output;
143
144    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
145        let this = self.project();
146        let mut res = ready!(this.future.poll(cx)?);
147
148        if let Some((header, value)) = this.header_and_value.take() {
149            res.headers_mut().insert(header, value);
150        }
151
152        Poll::Ready(Ok(res))
153    }
154}