// rama_http/layer/request_id.rs

//! Set and propagate request ids.
//!
//! [`SetRequestIdLayer`] makes sure requests carry a request id header (generating one
//! with a [`MakeRequestId`] when it is missing), and [`PropagateRequestIdLayer`] copies
//! that header from the request onto the response (unless the response already sets it).
//!
//! # Example
//!
//! ```
//! use rama_http::layer::request_id::{
//!     SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
//! };
//! use rama_http::{Body, Request, Response, header::HeaderName};
//! use rama_core::service::service_fn;
//! use rama_core::{Context, Service, Layer};
//! use rama_core::error::BoxError;
//! use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), BoxError> {
//! # let handler = service_fn(|request: Request| async move {
//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
//! # });
//! #
//! // A `MakeRequestId` that increments an atomic counter
//! #[derive(Clone, Default)]
//! struct MyMakeRequestId {
//!     counter: Arc<AtomicU64>,
//! }
//!
//! impl MakeRequestId for MyMakeRequestId {
//!     fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
//!         let request_id = self.counter
//!             .fetch_add(1, Ordering::AcqRel)
//!             .to_string()
//!             .parse()
//!             .unwrap();
//!
//!         Some(RequestId::new(request_id))
//!     }
//! }
//!
//! let x_request_id = HeaderName::from_static("x-request-id");
//!
//! let svc = (
//!     // set the `x-request-id` header on requests that don't already have one
//!     SetRequestIdLayer::new(
//!         x_request_id.clone(),
//!         MyMakeRequestId::default(),
//!     ),
//!     // propagate `x-request-id` headers from request to response
//!     PropagateRequestIdLayer::new(x_request_id),
//! ).layer(handler);
//!
//! let request = Request::new(Body::empty());
//! let response = svc.serve(Context::default(), request).await?;
//!
//! assert_eq!(response.headers()["x-request-id"], "0");
//! #
//! # Ok(())
//! # }
//! ```

60use std::fmt;
61
62use crate::{
63    header::{HeaderName, HeaderValue},
64    Request, Response,
65};
66use rama_core::{Context, Layer, Service};
67use rama_utils::macros::define_inner_service_accessors;
68use uuid::Uuid;
69
/// Name of the conventional request id header, used by the `x_request_id`
/// constructors in this module.
pub(crate) const X_REQUEST_ID: &str = "x-request-id";
71
72/// Trait for producing [`RequestId`]s.
73///
74/// Used by [`SetRequestId`].
75pub trait MakeRequestId: Send + Sync + 'static {
76    /// Try and produce a [`RequestId`] from the request.
77    fn make_request_id<B>(&self, request: &Request<B>) -> Option<RequestId>;
78}
79
80/// An identifier for a request.
81#[derive(Debug, Clone)]
82pub struct RequestId(HeaderValue);
83
84impl RequestId {
85    /// Create a new `RequestId` from a [`HeaderValue`].
86    pub const fn new(header_value: HeaderValue) -> Self {
87        Self(header_value)
88    }
89
90    /// Gets a reference to the underlying [`HeaderValue`].
91    pub fn header_value(&self) -> &HeaderValue {
92        &self.0
93    }
94
95    /// Consumes `self`, returning the underlying [`HeaderValue`].
96    pub fn into_header_value(self) -> HeaderValue {
97        self.0
98    }
99}
100
101impl From<HeaderValue> for RequestId {
102    fn from(value: HeaderValue) -> Self {
103        Self::new(value)
104    }
105}
106
107/// Set request id headers and extensions on requests.
108///
109/// This layer applies the [`SetRequestId`] middleware.
110///
111/// See the [module docs](self) and [`SetRequestId`] for more details.
112pub struct SetRequestIdLayer<M> {
113    header_name: HeaderName,
114    make_request_id: M,
115}
116
117impl<M: fmt::Debug> fmt::Debug for SetRequestIdLayer<M> {
118    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
119        f.debug_struct("SetRequestIdLayer")
120            .field("header_name", &self.header_name)
121            .field("make_request_id", &self.make_request_id)
122            .finish()
123    }
124}
125
126impl<M: Clone> Clone for SetRequestIdLayer<M> {
127    fn clone(&self) -> Self {
128        Self {
129            header_name: self.header_name.clone(),
130            make_request_id: self.make_request_id.clone(),
131        }
132    }
133}
134
135impl<M> SetRequestIdLayer<M> {
136    /// Create a new `SetRequestIdLayer`.
137    pub const fn new(header_name: HeaderName, make_request_id: M) -> Self
138    where
139        M: MakeRequestId,
140    {
141        SetRequestIdLayer {
142            header_name,
143            make_request_id,
144        }
145    }
146
147    /// Create a new `SetRequestIdLayer` that uses `x-request-id` as the header name.
148    pub const fn x_request_id(make_request_id: M) -> Self
149    where
150        M: MakeRequestId,
151    {
152        SetRequestIdLayer::new(HeaderName::from_static(X_REQUEST_ID), make_request_id)
153    }
154}
155
156impl<S, M> Layer<S> for SetRequestIdLayer<M>
157where
158    M: Clone + MakeRequestId,
159{
160    type Service = SetRequestId<S, M>;
161
162    fn layer(&self, inner: S) -> Self::Service {
163        SetRequestId::new(
164            inner,
165            self.header_name.clone(),
166            self.make_request_id.clone(),
167        )
168    }
169}
170
171/// Set request id headers and extensions on requests.
172///
173/// See the [module docs](self) for an example.
174///
175/// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a
176/// header with the same name, then the header will be inserted.
177///
178/// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other
179/// services can access it.
180pub struct SetRequestId<S, M> {
181    inner: S,
182    header_name: HeaderName,
183    make_request_id: M,
184}
185
186impl<S: fmt::Debug, M: fmt::Debug> fmt::Debug for SetRequestId<S, M> {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        f.debug_struct("SetRequestId")
189            .field("inner", &self.inner)
190            .field("header_name", &self.header_name)
191            .field("make_request_id", &self.make_request_id)
192            .finish()
193    }
194}
195
196impl<S: Clone, M: Clone> Clone for SetRequestId<S, M> {
197    fn clone(&self) -> Self {
198        SetRequestId {
199            inner: self.inner.clone(),
200            header_name: self.header_name.clone(),
201            make_request_id: self.make_request_id.clone(),
202        }
203    }
204}
205
206impl<S, M> SetRequestId<S, M> {
207    /// Create a new `SetRequestId`.
208    pub const fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self
209    where
210        M: MakeRequestId,
211    {
212        Self {
213            inner,
214            header_name,
215            make_request_id,
216        }
217    }
218
219    /// Create a new `SetRequestId` that uses `x-request-id` as the header name.
220    pub const fn x_request_id(inner: S, make_request_id: M) -> Self
221    where
222        M: MakeRequestId,
223    {
224        Self::new(
225            inner,
226            HeaderName::from_static(X_REQUEST_ID),
227            make_request_id,
228        )
229    }
230
231    define_inner_service_accessors!();
232}
233
234impl<State, S, M, ReqBody, ResBody> Service<State, Request<ReqBody>> for SetRequestId<S, M>
235where
236    State: Clone + Send + Sync + 'static,
237    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
238    M: MakeRequestId,
239    ReqBody: Send + 'static,
240    ResBody: Send + 'static,
241{
242    type Response = S::Response;
243    type Error = S::Error;
244
245    async fn serve(
246        &self,
247        ctx: Context<State>,
248        mut req: Request<ReqBody>,
249    ) -> Result<Self::Response, Self::Error> {
250        if let Some(request_id) = req.headers().get(&self.header_name) {
251            if req.extensions().get::<RequestId>().is_none() {
252                let request_id = request_id.clone();
253                req.extensions_mut().insert(RequestId::new(request_id));
254            }
255        } else if let Some(request_id) = self.make_request_id.make_request_id(&req) {
256            req.extensions_mut().insert(request_id.clone());
257            req.headers_mut()
258                .insert(self.header_name.clone(), request_id.0);
259        }
260
261        self.inner.serve(ctx, req).await
262    }
263}
264
265/// Propagate request ids from requests to responses.
266///
267/// This layer applies the [`PropagateRequestId`] middleware.
268///
269/// See the [module docs](self) and [`PropagateRequestId`] for more details.
270#[derive(Debug, Clone)]
271pub struct PropagateRequestIdLayer {
272    header_name: HeaderName,
273}
274
275impl PropagateRequestIdLayer {
276    /// Create a new `PropagateRequestIdLayer`.
277    pub const fn new(header_name: HeaderName) -> Self {
278        PropagateRequestIdLayer { header_name }
279    }
280
281    /// Create a new `PropagateRequestIdLayer` that uses `x-request-id` as the header name.
282    pub const fn x_request_id() -> Self {
283        Self::new(HeaderName::from_static(X_REQUEST_ID))
284    }
285}
286
287impl<S> Layer<S> for PropagateRequestIdLayer {
288    type Service = PropagateRequestId<S>;
289
290    fn layer(&self, inner: S) -> Self::Service {
291        PropagateRequestId::new(inner, self.header_name.clone())
292    }
293}
294
295/// Propagate request ids from requests to responses.
296///
297/// See the [module docs](self) for an example.
298///
299/// If the request contains a matching header that header will be applied to responses. If a
300/// [`RequestId`] extension is also present it will be propagated as well.
301pub struct PropagateRequestId<S> {
302    inner: S,
303    header_name: HeaderName,
304}
305
306impl<S> PropagateRequestId<S> {
307    /// Create a new `PropagateRequestId`.
308    pub const fn new(inner: S, header_name: HeaderName) -> Self {
309        Self { inner, header_name }
310    }
311
312    /// Create a new `PropagateRequestId` that uses `x-request-id` as the header name.
313    pub const fn x_request_id(inner: S) -> Self {
314        Self::new(inner, HeaderName::from_static(X_REQUEST_ID))
315    }
316
317    define_inner_service_accessors!();
318}
319
320impl<S: fmt::Debug> fmt::Debug for PropagateRequestId<S> {
321    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322        f.debug_struct("PropagateRequestId")
323            .field("inner", &self.inner)
324            .field("header_name", &self.header_name)
325            .finish()
326    }
327}
328
329impl<S: Clone> Clone for PropagateRequestId<S> {
330    fn clone(&self) -> Self {
331        PropagateRequestId {
332            inner: self.inner.clone(),
333            header_name: self.header_name.clone(),
334        }
335    }
336}
337
338impl<State, S, ReqBody, ResBody> Service<State, Request<ReqBody>> for PropagateRequestId<S>
339where
340    State: Clone + Send + Sync + 'static,
341    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
342    ReqBody: Send + 'static,
343    ResBody: Send + 'static,
344{
345    type Response = S::Response;
346    type Error = S::Error;
347
348    async fn serve(
349        &self,
350        ctx: Context<State>,
351        req: Request<ReqBody>,
352    ) -> Result<Self::Response, Self::Error> {
353        let request_id = req
354            .headers()
355            .get(&self.header_name)
356            .cloned()
357            .map(RequestId::new);
358
359        let mut response = self.inner.serve(ctx, req).await?;
360
361        if let Some(current_id) = response.headers().get(&self.header_name) {
362            if response.extensions().get::<RequestId>().is_none() {
363                let current_id = current_id.clone();
364                response.extensions_mut().insert(RequestId::new(current_id));
365            }
366        } else if let Some(request_id) = request_id {
367            response
368                .headers_mut()
369                .insert(self.header_name.clone(), request_id.0.clone());
370            response.extensions_mut().insert(request_id);
371        }
372
373        Ok(response)
374    }
375}
376
377/// A [`MakeRequestId`] that generates `UUID`s.
378#[derive(Debug, Clone, Copy, Default)]
379pub struct MakeRequestUuid;
380
381impl MakeRequestId for MakeRequestUuid {
382    fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
383        let request_id = Uuid::new_v4().to_string().parse().unwrap();
384        Some(RequestId::new(request_id))
385    }
386}
387
#[cfg(test)]
mod tests {
    use crate::layer::set_header;
    use crate::{Body, Response};
    use rama_core::service::service_fn;
    use rama_core::Layer;
    use std::{
        convert::Infallible,
        sync::{
            atomic::{AtomicU64, Ordering},
            Arc,
        },
    };

    use super::*;

    #[tokio::test]
    async fn basic() {
        let svc = (
            SetRequestIdLayer::x_request_id(Counter::default()),
            PropagateRequestIdLayer::x_request_id(),
        )
            .layer(service_fn(handler));

        // header on response
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "0");

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "1");

        // doesn't override if header is already there
        let req = Request::builder()
            .header("x-request-id", "foo")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");

        // extension propagated
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "2");
    }

    #[tokio::test]
    async fn other_middleware_setting_request_id_on_response() {
        let svc = (
            SetRequestIdLayer::x_request_id(Counter::default()),
            PropagateRequestIdLayer::x_request_id(),
            set_header::SetResponseHeaderLayer::overriding(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_str("foo").unwrap(),
            ),
        )
            .layer(service_fn(handler));

        // The request carries a *different* id than the one the inner middleware
        // writes, so the assertions can tell that the response's own header wins.
        let req = Request::builder()
            .header("x-request-id", "bar")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo");
    }

    #[tokio::test]
    async fn propagate_without_request_header_leaves_response_untouched() {
        let svc = PropagateRequestIdLayer::x_request_id().layer(service_fn(handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert!(res.headers().get("x-request-id").is_none());
        assert!(res.extensions().get::<RequestId>().is_none());
    }

    /// Test id generator yielding "0", "1", "2", ... in call order.
    #[derive(Clone, Default)]
    struct Counter(Arc<AtomicU64>);

    impl MakeRequestId for Counter {
        fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
            let id =
                HeaderValue::from_str(&self.0.fetch_add(1, Ordering::AcqRel).to_string()).unwrap();
            Some(RequestId::new(id))
        }
    }

    async fn handler(_: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::empty()))
    }

    #[tokio::test]
    async fn uuid() {
        let svc = (
            SetRequestIdLayer::x_request_id(MakeRequestUuid),
            PropagateRequestIdLayer::x_request_id(),
        )
            .layer(service_fn(handler));

        // header on response
        let req = Request::builder().body(Body::empty()).unwrap();
        let mut res = svc.serve(Context::default(), req).await.unwrap();
        let id = res.headers_mut().remove("x-request-id").unwrap();
        id.to_str().unwrap().parse::<Uuid>().unwrap();
    }
}