rama_http/layer/
propagate_headers.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
//! Propagate a header from the request to the response.
//!
//! # Example
//!
//! ```rust
//! use std::convert::Infallible;
//! use rama_core::error::BoxError;
//! use rama_core::service::service_fn;
//! use rama_core::{Context, Service, Layer};
//! use rama_http::{Body, Request, Response, header::HeaderName};
//! use rama_http::layer::propagate_headers::PropagateHeaderLayer;
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), BoxError> {
//! async fn handle(req: Request) -> Result<Response, Infallible> {
//!     // ...
//!     # Ok(Response::new(Body::default()))
//! }
//!
//! let mut svc = (
//!     // This will copy `x-request-id` headers from requests onto responses.
//!     PropagateHeaderLayer::new(HeaderName::from_static("x-request-id")),
//! ).layer(service_fn(handle));
//!
//! // Call the service.
//! let request = Request::builder()
//!     .header("x-request-id", "1337")
//!     .body(Body::default())?;
//!
//! let response = svc.serve(Context::default(), request).await?;
//!
//! assert_eq!(response.headers()["x-request-id"], "1337");
//! #
//! # Ok(())
//! # }
//! ```

use crate::{header::HeaderName, Request, Response};
use rama_core::{Context, Layer, Service};
use rama_utils::macros::define_inner_service_accessors;

/// Layer that applies [`PropagateHeader`] which propagates headers from requests to responses.
///
/// If the header is present on the request it'll be applied to the response as well. This could
/// for example be used to propagate headers such as `X-Request-Id`.
///
/// See the [module docs](crate::layer::propagate_headers) for more details.
#[derive(Clone, Debug)]
pub struct PropagateHeaderLayer {
    header: HeaderName,
}

impl PropagateHeaderLayer {
    /// Create a new [`PropagateHeaderLayer`].
    pub const fn new(header: HeaderName) -> Self {
        Self { header }
    }
}

impl<S> Layer<S> for PropagateHeaderLayer {
    type Service = PropagateHeader<S>;

    fn layer(&self, inner: S) -> Self::Service {
        PropagateHeader {
            inner,
            header: self.header.clone(),
        }
    }
}

/// Middleware that propagates headers from requests to responses.
///
/// If the header is present on the request it'll be applied to the response as well. This could
/// for example be used to propagate headers such as `X-Request-Id`.
///
/// See the [module docs](crate::layer::propagate_headers) for more details.
#[derive(Clone, Debug)]
pub struct PropagateHeader<S> {
    inner: S,
    header: HeaderName,
}

impl<S> PropagateHeader<S> {
    /// Create a new [`PropagateHeader`] that propagates the given header.
    pub const fn new(inner: S, header: HeaderName) -> Self {
        Self { inner, header }
    }

    define_inner_service_accessors!();
}

impl<ReqBody, ResBody, S, State> Service<State, Request<ReqBody>> for PropagateHeader<S>
where
    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
    State: Clone + Send + Sync + 'static,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;

    async fn serve(
        &self,
        ctx: Context<State>,
        req: Request<ReqBody>,
    ) -> Result<Self::Response, Self::Error> {
        let value = req.headers().get(&self.header).cloned();

        let mut res = self.inner.serve(ctx, req).await?;

        if let Some(value) = value {
            res.headers_mut().insert(self.header.clone(), value);
        }

        Ok(res)
    }
}