1use std::fmt;
61
62use crate::{
63 header::{HeaderName, HeaderValue},
64 Request, Response,
65};
66use rama_core::{Context, Layer, Service};
67use rama_utils::macros::define_inner_service_accessors;
68use uuid::Uuid;
69
/// Header name used by the `x_request_id` convenience constructors in this module.
pub(crate) const X_REQUEST_ID: &str = "x-request-id";

/// Produces request ids for [`SetRequestId`].
///
/// [`SetRequestId`] only calls this when the incoming request does not already
/// carry the configured request-id header.
pub trait MakeRequestId: Send + Sync + 'static {
    /// Try to produce a request id for `request`.
    ///
    /// Returning `None` means no id is attached: [`SetRequestId`] then neither
    /// adds the header nor inserts a [`RequestId`] extension.
    fn make_request_id<B>(&self, request: &Request<B>) -> Option<RequestId>;
}
79
80#[derive(Debug, Clone)]
82pub struct RequestId(HeaderValue);
83
84impl RequestId {
85 pub const fn new(header_value: HeaderValue) -> Self {
87 Self(header_value)
88 }
89
90 pub fn header_value(&self) -> &HeaderValue {
92 &self.0
93 }
94
95 pub fn into_header_value(self) -> HeaderValue {
97 self.0
98 }
99}
100
101impl From<HeaderValue> for RequestId {
102 fn from(value: HeaderValue) -> Self {
103 Self::new(value)
104 }
105}
106
107pub struct SetRequestIdLayer<M> {
113 header_name: HeaderName,
114 make_request_id: M,
115}
116
117impl<M: fmt::Debug> fmt::Debug for SetRequestIdLayer<M> {
118 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
119 f.debug_struct("SetRequestIdLayer")
120 .field("header_name", &self.header_name)
121 .field("make_request_id", &self.make_request_id)
122 .finish()
123 }
124}
125
126impl<M: Clone> Clone for SetRequestIdLayer<M> {
127 fn clone(&self) -> Self {
128 Self {
129 header_name: self.header_name.clone(),
130 make_request_id: self.make_request_id.clone(),
131 }
132 }
133}
134
135impl<M> SetRequestIdLayer<M> {
136 pub const fn new(header_name: HeaderName, make_request_id: M) -> Self
138 where
139 M: MakeRequestId,
140 {
141 SetRequestIdLayer {
142 header_name,
143 make_request_id,
144 }
145 }
146
147 pub const fn x_request_id(make_request_id: M) -> Self
149 where
150 M: MakeRequestId,
151 {
152 SetRequestIdLayer::new(HeaderName::from_static(X_REQUEST_ID), make_request_id)
153 }
154}
155
156impl<S, M> Layer<S> for SetRequestIdLayer<M>
157where
158 M: Clone + MakeRequestId,
159{
160 type Service = SetRequestId<S, M>;
161
162 fn layer(&self, inner: S) -> Self::Service {
163 SetRequestId::new(
164 inner,
165 self.header_name.clone(),
166 self.make_request_id.clone(),
167 )
168 }
169}
170
171pub struct SetRequestId<S, M> {
181 inner: S,
182 header_name: HeaderName,
183 make_request_id: M,
184}
185
186impl<S: fmt::Debug, M: fmt::Debug> fmt::Debug for SetRequestId<S, M> {
187 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188 f.debug_struct("SetRequestId")
189 .field("inner", &self.inner)
190 .field("header_name", &self.header_name)
191 .field("make_request_id", &self.make_request_id)
192 .finish()
193 }
194}
195
196impl<S: Clone, M: Clone> Clone for SetRequestId<S, M> {
197 fn clone(&self) -> Self {
198 SetRequestId {
199 inner: self.inner.clone(),
200 header_name: self.header_name.clone(),
201 make_request_id: self.make_request_id.clone(),
202 }
203 }
204}
205
206impl<S, M> SetRequestId<S, M> {
207 pub const fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self
209 where
210 M: MakeRequestId,
211 {
212 Self {
213 inner,
214 header_name,
215 make_request_id,
216 }
217 }
218
219 pub const fn x_request_id(inner: S, make_request_id: M) -> Self
221 where
222 M: MakeRequestId,
223 {
224 Self::new(
225 inner,
226 HeaderName::from_static(X_REQUEST_ID),
227 make_request_id,
228 )
229 }
230
231 define_inner_service_accessors!();
232}
233
234impl<State, S, M, ReqBody, ResBody> Service<State, Request<ReqBody>> for SetRequestId<S, M>
235where
236 State: Clone + Send + Sync + 'static,
237 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
238 M: MakeRequestId,
239 ReqBody: Send + 'static,
240 ResBody: Send + 'static,
241{
242 type Response = S::Response;
243 type Error = S::Error;
244
245 async fn serve(
246 &self,
247 ctx: Context<State>,
248 mut req: Request<ReqBody>,
249 ) -> Result<Self::Response, Self::Error> {
250 if let Some(request_id) = req.headers().get(&self.header_name) {
251 if req.extensions().get::<RequestId>().is_none() {
252 let request_id = request_id.clone();
253 req.extensions_mut().insert(RequestId::new(request_id));
254 }
255 } else if let Some(request_id) = self.make_request_id.make_request_id(&req) {
256 req.extensions_mut().insert(request_id.clone());
257 req.headers_mut()
258 .insert(self.header_name.clone(), request_id.0);
259 }
260
261 self.inner.serve(ctx, req).await
262 }
263}
264
265#[derive(Debug, Clone)]
271pub struct PropagateRequestIdLayer {
272 header_name: HeaderName,
273}
274
275impl PropagateRequestIdLayer {
276 pub const fn new(header_name: HeaderName) -> Self {
278 PropagateRequestIdLayer { header_name }
279 }
280
281 pub const fn x_request_id() -> Self {
283 Self::new(HeaderName::from_static(X_REQUEST_ID))
284 }
285}
286
287impl<S> Layer<S> for PropagateRequestIdLayer {
288 type Service = PropagateRequestId<S>;
289
290 fn layer(&self, inner: S) -> Self::Service {
291 PropagateRequestId::new(inner, self.header_name.clone())
292 }
293}
294
295pub struct PropagateRequestId<S> {
302 inner: S,
303 header_name: HeaderName,
304}
305
306impl<S> PropagateRequestId<S> {
307 pub const fn new(inner: S, header_name: HeaderName) -> Self {
309 Self { inner, header_name }
310 }
311
312 pub const fn x_request_id(inner: S) -> Self {
314 Self::new(inner, HeaderName::from_static(X_REQUEST_ID))
315 }
316
317 define_inner_service_accessors!();
318}
319
320impl<S: fmt::Debug> fmt::Debug for PropagateRequestId<S> {
321 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322 f.debug_struct("PropagateRequestId")
323 .field("inner", &self.inner)
324 .field("header_name", &self.header_name)
325 .finish()
326 }
327}
328
329impl<S: Clone> Clone for PropagateRequestId<S> {
330 fn clone(&self) -> Self {
331 PropagateRequestId {
332 inner: self.inner.clone(),
333 header_name: self.header_name.clone(),
334 }
335 }
336}
337
338impl<State, S, ReqBody, ResBody> Service<State, Request<ReqBody>> for PropagateRequestId<S>
339where
340 State: Clone + Send + Sync + 'static,
341 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
342 ReqBody: Send + 'static,
343 ResBody: Send + 'static,
344{
345 type Response = S::Response;
346 type Error = S::Error;
347
348 async fn serve(
349 &self,
350 ctx: Context<State>,
351 req: Request<ReqBody>,
352 ) -> Result<Self::Response, Self::Error> {
353 let request_id = req
354 .headers()
355 .get(&self.header_name)
356 .cloned()
357 .map(RequestId::new);
358
359 let mut response = self.inner.serve(ctx, req).await?;
360
361 if let Some(current_id) = response.headers().get(&self.header_name) {
362 if response.extensions().get::<RequestId>().is_none() {
363 let current_id = current_id.clone();
364 response.extensions_mut().insert(RequestId::new(current_id));
365 }
366 } else if let Some(request_id) = request_id {
367 response
368 .headers_mut()
369 .insert(self.header_name.clone(), request_id.0.clone());
370 response.extensions_mut().insert(request_id);
371 }
372
373 Ok(response)
374 }
375}
376
377#[derive(Debug, Clone, Copy, Default)]
379pub struct MakeRequestUuid;
380
381impl MakeRequestId for MakeRequestUuid {
382 fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
383 let request_id = Uuid::new_v4().to_string().parse().unwrap();
384 Some(RequestId::new(request_id))
385 }
386}
387
#[cfg(test)]
mod tests {
    use crate::layer::set_header;
    use crate::{Body, Response};
    use rama_core::service::service_fn;
    use rama_core::Layer;
    use std::{
        convert::Infallible,
        sync::{
            atomic::{AtomicU64, Ordering},
            Arc,
        },
    };

    // Brings in the items under test (RequestId, the layers, MakeRequestId, ...)
    // plus the parent's imports (Request, Context, HeaderName, HeaderValue, Uuid).
    // NOTE(review): the reason for this `allow` is not visible here; the glob
    // appears to be used — confirm whether the attribute is still needed.
    #[allow(unused_imports)]
    use super::*;

    // Generated ids come from `Counter` (0, 1, 2, ...). A client-supplied id
    // ("foo") is echoed back and does not consume a counter value, so the
    // fourth request gets "2".
    #[tokio::test]
    async fn basic() {
        let svc = (
            SetRequestIdLayer::x_request_id(Counter::default()),
            PropagateRequestIdLayer::x_request_id(),
        )
            .layer(service_fn(handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "0");

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "1");

        let req = Request::builder()
            .header("x-request-id", "foo")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");

        // The id is also exposed via the response's `RequestId` extension.
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "2");
    }

    // A `SetResponseHeaderLayer` that overrides `x-request-id` on the response:
    // the response header value is kept and mirrored into the extensions.
    #[tokio::test]
    async fn other_middleware_setting_request_id_on_response() {
        let svc = (
            SetRequestIdLayer::x_request_id(Counter::default()),
            PropagateRequestIdLayer::x_request_id(),
            set_header::SetResponseHeaderLayer::overriding(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_str("foo").unwrap(),
            ),
        )
            .layer(service_fn(handler));

        let req = Request::builder()
            .header("x-request-id", "foo")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo");
    }

    // Deterministic id generator: yields "0", "1", "2", ... Clones share the
    // same counter through the `Arc`.
    #[derive(Clone, Default)]
    struct Counter(Arc<AtomicU64>);

    impl MakeRequestId for Counter {
        fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
            let id =
                HeaderValue::from_str(&self.0.fetch_add(1, Ordering::AcqRel).to_string()).unwrap();
            Some(RequestId::new(id))
        }
    }

    // Inner service for the tests: always answers with an empty-bodied response.
    async fn handler(_: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::empty()))
    }

    // `MakeRequestUuid` ids must parse back as a `Uuid`.
    #[tokio::test]
    async fn uuid() {
        let svc = (
            SetRequestIdLayer::x_request_id(MakeRequestUuid),
            PropagateRequestIdLayer::x_request_id(),
        )
            .layer(service_fn(handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let mut res = svc.serve(Context::default(), req).await.unwrap();
        let id = res.headers_mut().remove("x-request-id").unwrap();
        id.to_str().unwrap().parse::<Uuid>().unwrap();
    }
}