rama_http/layer/required_header/
request.rs1use crate::{
7 HeaderValue, Request, Response,
8 header::{self, HOST, RAMA_ID_HEADER_VALUE, USER_AGENT},
9 headers::HeaderMapExt,
10};
11use rama_core::{
12 Context, Layer, Service,
13 error::{BoxError, ErrorContext},
14};
15use rama_net::http::RequestContext;
16use rama_utils::macros::define_inner_service_accessors;
17use std::fmt;
18
19#[derive(Debug, Clone, Default)]
23pub struct AddRequiredRequestHeadersLayer {
24 overwrite: bool,
25 user_agent_header_value: Option<HeaderValue>,
26}
27
28impl AddRequiredRequestHeadersLayer {
29 pub const fn new() -> Self {
31 Self {
32 overwrite: false,
33 user_agent_header_value: None,
34 }
35 }
36
37 pub const fn overwrite(mut self, overwrite: bool) -> Self {
42 self.overwrite = overwrite;
43 self
44 }
45
46 pub fn set_overwrite(&mut self, overwrite: bool) -> &mut Self {
51 self.overwrite = overwrite;
52 self
53 }
54
55 pub fn user_agent_header_value(mut self, value: HeaderValue) -> Self {
59 self.user_agent_header_value = Some(value);
60 self
61 }
62
63 pub fn maybe_user_agent_header_value(mut self, value: Option<HeaderValue>) -> Self {
67 self.user_agent_header_value = value;
68 self
69 }
70
71 pub fn set_user_agent_header_value(&mut self, value: HeaderValue) -> &mut Self {
75 self.user_agent_header_value = Some(value);
76 self
77 }
78}
79
80impl<S> Layer<S> for AddRequiredRequestHeadersLayer {
81 type Service = AddRequiredRequestHeaders<S>;
82
83 fn layer(&self, inner: S) -> Self::Service {
84 AddRequiredRequestHeaders {
85 inner,
86 overwrite: self.overwrite,
87 user_agent_header_value: self.user_agent_header_value.clone(),
88 }
89 }
90
91 fn into_layer(self, inner: S) -> Self::Service {
92 AddRequiredRequestHeaders {
93 inner,
94 overwrite: self.overwrite,
95 user_agent_header_value: self.user_agent_header_value,
96 }
97 }
98}
99
100#[derive(Clone)]
102pub struct AddRequiredRequestHeaders<S> {
103 inner: S,
104 overwrite: bool,
105 user_agent_header_value: Option<HeaderValue>,
106}
107
108impl<S> AddRequiredRequestHeaders<S> {
109 pub const fn new(inner: S) -> Self {
111 Self {
112 inner,
113 overwrite: false,
114 user_agent_header_value: None,
115 }
116 }
117
118 pub const fn overwrite(mut self, overwrite: bool) -> Self {
123 self.overwrite = overwrite;
124 self
125 }
126
127 pub fn set_overwrite(&mut self, overwrite: bool) -> &mut Self {
132 self.overwrite = overwrite;
133 self
134 }
135
136 pub fn user_agent_header_value(mut self, value: HeaderValue) -> Self {
140 self.user_agent_header_value = Some(value);
141 self
142 }
143
144 pub fn maybe_user_agent_header_value(mut self, value: Option<HeaderValue>) -> Self {
148 self.user_agent_header_value = value;
149 self
150 }
151
152 pub fn set_user_agent_header_value(&mut self, value: HeaderValue) -> &mut Self {
156 self.user_agent_header_value = Some(value);
157 self
158 }
159
160 define_inner_service_accessors!();
161}
162
163impl<S> fmt::Debug for AddRequiredRequestHeaders<S>
164where
165 S: fmt::Debug,
166{
167 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
168 f.debug_struct("AddRequiredRequestHeaders")
169 .field("inner", &self.inner)
170 .field("user_agent_header_value", &self.user_agent_header_value)
171 .finish()
172 }
173}
174
175impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for AddRequiredRequestHeaders<S>
176where
177 ReqBody: Send + 'static,
178 ResBody: Send + 'static,
179 State: Clone + Send + Sync + 'static,
180 S: Service<State, Request<ReqBody>, Response = Response<ResBody>, Error: Into<BoxError>>,
181{
182 type Response = S::Response;
183 type Error = BoxError;
184
185 async fn serve(
186 &self,
187 mut ctx: Context<State>,
188 mut req: Request<ReqBody>,
189 ) -> Result<Self::Response, Self::Error> {
190 if self.overwrite || !req.headers().contains_key(HOST) {
191 let request_ctx: &mut RequestContext = ctx
192 .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
193 .context(
194 "AddRequiredRequestHeaders: get/compute RequestContext to set authority",
195 )?;
196 let host = crate::dep::http::uri::Authority::from_maybe_shared(
197 request_ctx.authority.to_string(),
198 )
199 .map(crate::headers::Host::from)
200 .context("AddRequiredRequestHeaders: set authority")?;
201 req.headers_mut().typed_insert(host);
202 }
203
204 if self.overwrite {
205 req.headers_mut().insert(
206 USER_AGENT,
207 self.user_agent_header_value
208 .as_ref()
209 .unwrap_or(&RAMA_ID_HEADER_VALUE)
210 .clone(),
211 );
212 } else if let header::Entry::Vacant(header) = req.headers_mut().entry(USER_AGENT) {
213 header.insert(
214 self.user_agent_header_value
215 .as_ref()
216 .unwrap_or(&RAMA_ID_HEADER_VALUE)
217 .clone(),
218 );
219 }
220
221 self.inner.serve(ctx, req).await.map_err(Into::into)
222 }
223}
224
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Body, Request};
    use rama_core::service::service_fn;
    use rama_core::{Context, Layer, Service};
    use std::convert::Infallible;

    /// Missing `Host` and `User-Agent` headers are added with default settings.
    #[tokio::test]
    async fn add_required_request_headers() {
        let svc = AddRequiredRequestHeadersLayer::default().into_layer(service_fn(
            async |_ctx: Context<()>, req: Request| {
                assert!(req.headers().contains_key(HOST));
                assert!(req.headers().contains_key(USER_AGENT));
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            },
        ));

        let req = Request::builder()
            .uri("http://www.example.com/")
            .body(Body::empty())
            .unwrap();
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    /// A configured `User-Agent` value is used when the header is missing.
    #[tokio::test]
    async fn add_required_request_headers_custom_ua() {
        let svc = AddRequiredRequestHeadersLayer::default()
            .user_agent_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert!(req.headers().contains_key(HOST));
                assert_eq!(
                    req.headers().get(USER_AGENT).and_then(|v| v.to_str().ok()),
                    Some("foo")
                );
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://www.example.com/")
            .body(Body::empty())
            .unwrap();
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    /// Without `overwrite`, headers already present on the request are kept,
    /// even when a custom `User-Agent` value is configured.
    #[tokio::test]
    async fn add_required_request_headers_keep_existing() {
        let svc = AddRequiredRequestHeadersLayer::default()
            .user_agent_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "example.com");
                assert_eq!(req.headers().get(USER_AGENT).unwrap(), "test");
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://127.0.0.1/")
            .header(HOST, "example.com")
            .header(USER_AGENT, "test")
            .body(Body::empty())
            .unwrap();

        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    /// With `overwrite`, existing headers are replaced by the computed defaults.
    #[tokio::test]
    async fn add_required_request_headers_overwrite() {
        let svc = AddRequiredRequestHeadersLayer::new()
            .overwrite(true)
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "127.0.0.1:80");
                assert_eq!(
                    req.headers().get(USER_AGENT).unwrap(),
                    RAMA_ID_HEADER_VALUE.to_str().unwrap()
                );
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://127.0.0.1/")
            .header(HOST, "example.com")
            .header(USER_AGENT, "test")
            .body(Body::empty())
            .unwrap();

        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    /// With `overwrite` and a custom `User-Agent`, the custom value replaces
    /// the existing header.
    #[tokio::test]
    async fn add_required_request_headers_overwrite_custom_ua() {
        let svc = AddRequiredRequestHeadersLayer::new()
            .overwrite(true)
            .user_agent_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "127.0.0.1:80");
                assert_eq!(
                    req.headers().get(USER_AGENT).and_then(|v| v.to_str().ok()),
                    Some("foo")
                );
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://127.0.0.1/")
            .header(HOST, "example.com")
            .header(USER_AGENT, "test")
            .body(Body::empty())
            .unwrap();

        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }
}