rama_http/layer/required_header/
request.rs

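//! Set required headers on the request, if they are missing.
//!
//! For now this sets the `Host` header (derived from the request's authority)
//! and the `User-Agent` header. Neither is restricted to a particular HTTP
//! version, and values that are already present are only replaced when
//! overwriting is enabled.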
5
6use crate::{
7    HeaderValue, Request, Response,
8    header::{self, HOST, RAMA_ID_HEADER_VALUE, USER_AGENT},
9    headers::HeaderMapExt,
10};
11use rama_core::{
12    Context, Layer, Service,
13    error::{BoxError, ErrorContext},
14};
15use rama_net::http::RequestContext;
16use rama_utils::macros::define_inner_service_accessors;
17use std::fmt;
18
19/// Layer that applies [`AddRequiredRequestHeaders`] which adds a request header.
20///
21/// See [`AddRequiredRequestHeaders`] for more details.
22#[derive(Debug, Clone, Default)]
23pub struct AddRequiredRequestHeadersLayer {
24    overwrite: bool,
25    user_agent_header_value: Option<HeaderValue>,
26}
27
28impl AddRequiredRequestHeadersLayer {
29    /// Create a new [`AddRequiredRequestHeadersLayer`].
30    pub const fn new() -> Self {
31        Self {
32            overwrite: false,
33            user_agent_header_value: None,
34        }
35    }
36
37    /// Set whether to overwrite the existing headers.
38    /// If set to `true`, the headers will be overwritten.
39    ///
40    /// Default is `false`.
41    pub const fn overwrite(mut self, overwrite: bool) -> Self {
42        self.overwrite = overwrite;
43        self
44    }
45
46    /// Set whether to overwrite the existing headers.
47    /// If set to `true`, the headers will be overwritten.
48    ///
49    /// Default is `false`.
50    pub fn set_overwrite(&mut self, overwrite: bool) -> &mut Self {
51        self.overwrite = overwrite;
52        self
53    }
54
55    /// Set a custom [`USER_AGENT`] header value.
56    ///
57    /// By default a versioned `rama` value is used.
58    pub fn user_agent_header_value(mut self, value: HeaderValue) -> Self {
59        self.user_agent_header_value = Some(value);
60        self
61    }
62
63    /// Maybe set a custom [`USER_AGENT`] header value.
64    ///
65    /// By default a versioned `rama` value is used.
66    pub fn maybe_user_agent_header_value(mut self, value: Option<HeaderValue>) -> Self {
67        self.user_agent_header_value = value;
68        self
69    }
70
71    /// Set a custom [`USER_AGENT`] header value.
72    ///
73    /// By default a versioned `rama` value is used.
74    pub fn set_user_agent_header_value(&mut self, value: HeaderValue) -> &mut Self {
75        self.user_agent_header_value = Some(value);
76        self
77    }
78}
79
80impl<S> Layer<S> for AddRequiredRequestHeadersLayer {
81    type Service = AddRequiredRequestHeaders<S>;
82
83    fn layer(&self, inner: S) -> Self::Service {
84        AddRequiredRequestHeaders {
85            inner,
86            overwrite: self.overwrite,
87            user_agent_header_value: self.user_agent_header_value.clone(),
88        }
89    }
90
91    fn into_layer(self, inner: S) -> Self::Service {
92        AddRequiredRequestHeaders {
93            inner,
94            overwrite: self.overwrite,
95            user_agent_header_value: self.user_agent_header_value,
96        }
97    }
98}
99
100/// Middleware that sets a header on the request.
101#[derive(Clone)]
102pub struct AddRequiredRequestHeaders<S> {
103    inner: S,
104    overwrite: bool,
105    user_agent_header_value: Option<HeaderValue>,
106}
107
108impl<S> AddRequiredRequestHeaders<S> {
109    /// Create a new [`AddRequiredRequestHeaders`].
110    pub const fn new(inner: S) -> Self {
111        Self {
112            inner,
113            overwrite: false,
114            user_agent_header_value: None,
115        }
116    }
117
118    /// Set whether to overwrite the existing headers.
119    /// If set to `true`, the headers will be overwritten.
120    ///
121    /// Default is `false`.
122    pub const fn overwrite(mut self, overwrite: bool) -> Self {
123        self.overwrite = overwrite;
124        self
125    }
126
127    /// Set whether to overwrite the existing headers.
128    /// If set to `true`, the headers will be overwritten.
129    ///
130    /// Default is `false`.
131    pub fn set_overwrite(&mut self, overwrite: bool) -> &mut Self {
132        self.overwrite = overwrite;
133        self
134    }
135
136    /// Set a custom [`USER_AGENT`] header value.
137    ///
138    /// By default a versioned `rama` value is used.
139    pub fn user_agent_header_value(mut self, value: HeaderValue) -> Self {
140        self.user_agent_header_value = Some(value);
141        self
142    }
143
144    /// Maybe set a custom [`USER_AGENT`] header value.
145    ///
146    /// By default a versioned `rama` value is used.
147    pub fn maybe_user_agent_header_value(mut self, value: Option<HeaderValue>) -> Self {
148        self.user_agent_header_value = value;
149        self
150    }
151
152    /// Set a custom [`USER_AGENT`] header value.
153    ///
154    /// By default a versioned `rama` value is used.
155    pub fn set_user_agent_header_value(&mut self, value: HeaderValue) -> &mut Self {
156        self.user_agent_header_value = Some(value);
157        self
158    }
159
160    define_inner_service_accessors!();
161}
162
163impl<S> fmt::Debug for AddRequiredRequestHeaders<S>
164where
165    S: fmt::Debug,
166{
167    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
168        f.debug_struct("AddRequiredRequestHeaders")
169            .field("inner", &self.inner)
170            .field("user_agent_header_value", &self.user_agent_header_value)
171            .finish()
172    }
173}
174
175impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for AddRequiredRequestHeaders<S>
176where
177    ReqBody: Send + 'static,
178    ResBody: Send + 'static,
179    State: Clone + Send + Sync + 'static,
180    S: Service<State, Request<ReqBody>, Response = Response<ResBody>, Error: Into<BoxError>>,
181{
182    type Response = S::Response;
183    type Error = BoxError;
184
185    async fn serve(
186        &self,
187        mut ctx: Context<State>,
188        mut req: Request<ReqBody>,
189    ) -> Result<Self::Response, Self::Error> {
190        if self.overwrite || !req.headers().contains_key(HOST) {
191            let request_ctx: &mut RequestContext = ctx
192                .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
193                .context(
194                    "AddRequiredRequestHeaders: get/compute RequestContext to set authority",
195                )?;
196            let host = crate::dep::http::uri::Authority::from_maybe_shared(
197                request_ctx.authority.to_string(),
198            )
199            .map(crate::headers::Host::from)
200            .context("AddRequiredRequestHeaders: set authority")?;
201            req.headers_mut().typed_insert(host);
202        }
203
204        if self.overwrite {
205            req.headers_mut().insert(
206                USER_AGENT,
207                self.user_agent_header_value
208                    .as_ref()
209                    .unwrap_or(&RAMA_ID_HEADER_VALUE)
210                    .clone(),
211            );
212        } else if let header::Entry::Vacant(header) = req.headers_mut().entry(USER_AGENT) {
213            header.insert(
214                self.user_agent_header_value
215                    .as_ref()
216                    .unwrap_or(&RAMA_ID_HEADER_VALUE)
217                    .clone(),
218            );
219        }
220
221        self.inner.serve(ctx, req).await.map_err(Into::into)
222    }
223}
224
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Body, Request};
    use rama_core::service::service_fn;
    use rama_core::{Context, Layer, Service};
    use std::convert::Infallible;

    #[tokio::test]
    async fn add_required_request_headers() {
        let svc = AddRequiredRequestHeadersLayer::default().into_layer(service_fn(
            async |_ctx: Context<()>, req: Request| {
                assert!(req.headers().contains_key(HOST));
                assert!(req.headers().contains_key(USER_AGENT));
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            },
        ));

        let req = Request::builder()
            .uri("http://www.example.com/")
            .body(Body::empty())
            .unwrap();
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    #[tokio::test]
    async fn add_required_request_headers_custom_ua() {
        let svc = AddRequiredRequestHeadersLayer::default()
            .user_agent_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert!(req.headers().contains_key(HOST));
                assert_eq!(
                    req.headers().get(USER_AGENT).and_then(|v| v.to_str().ok()),
                    Some("foo")
                );
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://www.example.com/")
            .body(Body::empty())
            .unwrap();
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    /// Without `overwrite`, headers already present on the request must be kept
    /// untouched, even when a custom user agent is configured.
    #[tokio::test]
    async fn add_required_request_headers_preserves_existing() {
        let svc = AddRequiredRequestHeadersLayer::new()
            .user_agent_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "example.com");
                assert_eq!(req.headers().get(USER_AGENT).unwrap(), "test");
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://127.0.0.1/")
            .header(HOST, "example.com")
            .header(USER_AGENT, "test")
            .body(Body::empty())
            .unwrap();

        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(resp.status().is_success());
    }

    #[tokio::test]
    async fn add_required_request_headers_overwrite() {
        let svc = AddRequiredRequestHeadersLayer::new()
            .overwrite(true)
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "127.0.0.1:80");
                assert_eq!(
                    req.headers().get(USER_AGENT).unwrap(),
                    RAMA_ID_HEADER_VALUE.to_str().unwrap()
                );
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://127.0.0.1/")
            .header(HOST, "example.com")
            .header(USER_AGENT, "test")
            .body(Body::empty())
            .unwrap();

        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    #[tokio::test]
    async fn add_required_request_headers_overwrite_custom_ua() {
        let svc = AddRequiredRequestHeadersLayer::new()
            .overwrite(true)
            .user_agent_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "127.0.0.1:80");
                assert_eq!(
                    req.headers().get(USER_AGENT).and_then(|v| v.to_str().ok()),
                    Some("foo")
                );
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://127.0.0.1/")
            .header(HOST, "example.com")
            .header(USER_AGENT, "test")
            .body(Body::empty())
            .unwrap();

        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }
}