tower_http/
request_id.rs

1//! Set and propagate request ids.
2//!
3//! # Example
4//!
5//! ```
6//! use http::{Request, Response, header::HeaderName};
7//! use tower::{Service, ServiceExt, ServiceBuilder};
8//! use tower_http::request_id::{
9//!     SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
10//! };
11//! use http_body_util::Full;
12//! use bytes::Bytes;
13//! use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
14//!
15//! # #[tokio::main]
16//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
17//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
18//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
19//! # });
20//! #
21//! // A `MakeRequestId` that increments an atomic counter
22//! #[derive(Clone, Default)]
23//! struct MyMakeRequestId {
24//!     counter: Arc<AtomicU64>,
25//! }
26//!
27//! impl MakeRequestId for MyMakeRequestId {
28//!     fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
29//!         let request_id = self.counter
30//!             .fetch_add(1, Ordering::SeqCst)
31//!             .to_string()
32//!             .parse()
33//!             .unwrap();
34//!
35//!         Some(RequestId::new(request_id))
36//!     }
37//! }
38//!
39//! let x_request_id = HeaderName::from_static("x-request-id");
40//!
41//! let mut svc = ServiceBuilder::new()
42//!     // set `x-request-id` header on all requests
43//!     .layer(SetRequestIdLayer::new(
44//!         x_request_id.clone(),
45//!         MyMakeRequestId::default(),
46//!     ))
47//!     // propagate `x-request-id` headers from request to response
48//!     .layer(PropagateRequestIdLayer::new(x_request_id))
49//!     .service(handler);
50//!
51//! let request = Request::new(Full::default());
52//! let response = svc.ready().await?.call(request).await?;
53//!
54//! assert_eq!(response.headers()["x-request-id"], "0");
55//! #
56//! # Ok(())
57//! # }
58//! ```
59//!
60//! Additional convenience methods are available on [`ServiceBuilderExt`]:
61//!
62//! ```
63//! use tower_http::ServiceBuilderExt;
64//! # use http::{Request, Response, header::HeaderName};
65//! # use tower::{Service, ServiceExt, ServiceBuilder};
66//! # use tower_http::request_id::{
67//! #     SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
68//! # };
69//! # use bytes::Bytes;
70//! # use http_body_util::Full;
71//! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
72//! # #[tokio::main]
73//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
74//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
75//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
76//! # });
77//! # #[derive(Clone, Default)]
78//! # struct MyMakeRequestId {
79//! #     counter: Arc<AtomicU64>,
80//! # }
81//! # impl MakeRequestId for MyMakeRequestId {
82//! #     fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
83//! #         let request_id = self.counter
84//! #             .fetch_add(1, Ordering::SeqCst)
85//! #             .to_string()
86//! #             .parse()
87//! #             .unwrap();
88//! #         Some(RequestId::new(request_id))
89//! #     }
90//! # }
91//!
92//! let mut svc = ServiceBuilder::new()
93//!     .set_x_request_id(MyMakeRequestId::default())
94//!     .propagate_x_request_id()
95//!     .service(handler);
96//!
97//! let request = Request::new(Full::default());
98//! let response = svc.ready().await?.call(request).await?;
99//!
100//! assert_eq!(response.headers()["x-request-id"], "0");
101//! #
102//! # Ok(())
103//! # }
104//! ```
105//!
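//! See [`SetRequestId`] and [`PropagateRequestId`] for more details. [`MakeRequestUuid`] is a
//! ready-made [`MakeRequestId`] that generates a fresh [`Uuid`] for every request.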
107//!
108//! # Using `Trace`
109//!
//! To have request ids show up correctly in logs produced by [`Trace`], you must apply the layers
//! in this order:
112//!
113//! ```
114//! use tower_http::{
115//!     ServiceBuilderExt,
116//!     trace::{TraceLayer, DefaultMakeSpan, DefaultOnResponse},
117//! };
118//! # use http::{Request, Response, header::HeaderName};
119//! # use tower::{Service, ServiceExt, ServiceBuilder};
120//! # use tower_http::request_id::{
121//! #     SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
122//! # };
123//! # use http_body_util::Full;
124//! # use bytes::Bytes;
125//! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
126//! # #[tokio::main]
127//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
128//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
129//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
130//! # });
131//! # #[derive(Clone, Default)]
132//! # struct MyMakeRequestId {
133//! #     counter: Arc<AtomicU64>,
134//! # }
135//! # impl MakeRequestId for MyMakeRequestId {
136//! #     fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
137//! #         let request_id = self.counter
138//! #             .fetch_add(1, Ordering::SeqCst)
139//! #             .to_string()
140//! #             .parse()
141//! #             .unwrap();
142//! #         Some(RequestId::new(request_id))
143//! #     }
144//! # }
145//!
146//! let svc = ServiceBuilder::new()
147//!     // make sure to set request ids before the request reaches `TraceLayer`
148//!     .set_x_request_id(MyMakeRequestId::default())
149//!     // log requests and responses
150//!     .layer(
151//!         TraceLayer::new_for_http()
152//!             .make_span_with(DefaultMakeSpan::new().include_headers(true))
153//!             .on_response(DefaultOnResponse::new().include_headers(true))
154//!     )
155//!     // propagate the header to the response before the response reaches `TraceLayer`
156//!     .propagate_x_request_id()
157//!     .service(handler);
158//! #
159//! # Ok(())
160//! # }
161//! ```
162//!
163//! # Doesn't override existing headers
164//!
//! [`SetRequestId`] and [`PropagateRequestId`] won't override request ids that are already present
//! on requests or responses. Among other things, this allows other middleware to conditionally set
//! request ids and use the middleware in this module as a fallback.
168//!
169//! [`ServiceBuilderExt`]: crate::ServiceBuilderExt
170//! [`Uuid`]: https://crates.io/crates/uuid
171//! [`Trace`]: crate::trace::Trace
172
173use http::{
174    header::{HeaderName, HeaderValue},
175    Request, Response,
176};
177use pin_project_lite::pin_project;
178use std::task::{ready, Context, Poll};
179use std::{future::Future, pin::Pin};
180use tower_layer::Layer;
181use tower_service::Service;
182use uuid::Uuid;
183
/// Name of the conventional request id header, used by the `x_request_id` constructors.
pub(crate) const X_REQUEST_ID: &str = "x-request-id";
185
186/// Trait for producing [`RequestId`]s.
187///
188/// Used by [`SetRequestId`].
189pub trait MakeRequestId {
190    /// Try and produce a [`RequestId`] from the request.
191    fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId>;
192}
193
194/// An identifier for a request.
195#[derive(Debug, Clone)]
196pub struct RequestId(HeaderValue);
197
198impl RequestId {
199    /// Create a new `RequestId` from a [`HeaderValue`].
200    pub fn new(header_value: HeaderValue) -> Self {
201        Self(header_value)
202    }
203
204    /// Gets a reference to the underlying [`HeaderValue`].
205    pub fn header_value(&self) -> &HeaderValue {
206        &self.0
207    }
208
209    /// Consumes `self`, returning the underlying [`HeaderValue`].
210    pub fn into_header_value(self) -> HeaderValue {
211        self.0
212    }
213}
214
215impl From<HeaderValue> for RequestId {
216    fn from(value: HeaderValue) -> Self {
217        Self::new(value)
218    }
219}
220
221/// Set request id headers and extensions on requests.
222///
223/// This layer applies the [`SetRequestId`] middleware.
224///
225/// See the [module docs](self) and [`SetRequestId`] for more details.
226#[derive(Debug, Clone)]
227pub struct SetRequestIdLayer<M> {
228    header_name: HeaderName,
229    make_request_id: M,
230}
231
232impl<M> SetRequestIdLayer<M> {
233    /// Create a new `SetRequestIdLayer`.
234    pub fn new(header_name: HeaderName, make_request_id: M) -> Self
235    where
236        M: MakeRequestId,
237    {
238        SetRequestIdLayer {
239            header_name,
240            make_request_id,
241        }
242    }
243
244    /// Create a new `SetRequestIdLayer` that uses `x-request-id` as the header name.
245    pub fn x_request_id(make_request_id: M) -> Self
246    where
247        M: MakeRequestId,
248    {
249        SetRequestIdLayer::new(HeaderName::from_static(X_REQUEST_ID), make_request_id)
250    }
251}
252
253impl<S, M> Layer<S> for SetRequestIdLayer<M>
254where
255    M: Clone + MakeRequestId,
256{
257    type Service = SetRequestId<S, M>;
258
259    fn layer(&self, inner: S) -> Self::Service {
260        SetRequestId::new(
261            inner,
262            self.header_name.clone(),
263            self.make_request_id.clone(),
264        )
265    }
266}
267
268/// Set request id headers and extensions on requests.
269///
270/// See the [module docs](self) for an example.
271///
272/// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a
273/// header with the same name, then the header will be inserted.
274///
275/// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other
276/// services can access it.
277#[derive(Debug, Clone)]
278pub struct SetRequestId<S, M> {
279    inner: S,
280    header_name: HeaderName,
281    make_request_id: M,
282}
283
284impl<S, M> SetRequestId<S, M> {
285    /// Create a new `SetRequestId`.
286    pub fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self
287    where
288        M: MakeRequestId,
289    {
290        Self {
291            inner,
292            header_name,
293            make_request_id,
294        }
295    }
296
297    /// Create a new `SetRequestId` that uses `x-request-id` as the header name.
298    pub fn x_request_id(inner: S, make_request_id: M) -> Self
299    where
300        M: MakeRequestId,
301    {
302        Self::new(
303            inner,
304            HeaderName::from_static(X_REQUEST_ID),
305            make_request_id,
306        )
307    }
308
309    define_inner_service_accessors!();
310
311    /// Returns a new [`Layer`] that wraps services with a `SetRequestId` middleware.
312    pub fn layer(header_name: HeaderName, make_request_id: M) -> SetRequestIdLayer<M>
313    where
314        M: MakeRequestId,
315    {
316        SetRequestIdLayer::new(header_name, make_request_id)
317    }
318}
319
320impl<S, M, ReqBody, ResBody> Service<Request<ReqBody>> for SetRequestId<S, M>
321where
322    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
323    M: MakeRequestId,
324{
325    type Response = S::Response;
326    type Error = S::Error;
327    type Future = S::Future;
328
329    #[inline]
330    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
331        self.inner.poll_ready(cx)
332    }
333
334    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
335        if let Some(request_id) = req.headers().get(&self.header_name) {
336            if req.extensions().get::<RequestId>().is_none() {
337                let request_id = request_id.clone();
338                req.extensions_mut().insert(RequestId::new(request_id));
339            }
340        } else if let Some(request_id) = self.make_request_id.make_request_id(&req) {
341            req.extensions_mut().insert(request_id.clone());
342            req.headers_mut()
343                .insert(self.header_name.clone(), request_id.0);
344        }
345
346        self.inner.call(req)
347    }
348}
349
350/// Propagate request ids from requests to responses.
351///
352/// This layer applies the [`PropagateRequestId`] middleware.
353///
354/// See the [module docs](self) and [`PropagateRequestId`] for more details.
355#[derive(Debug, Clone)]
356pub struct PropagateRequestIdLayer {
357    header_name: HeaderName,
358}
359
360impl PropagateRequestIdLayer {
361    /// Create a new `PropagateRequestIdLayer`.
362    pub fn new(header_name: HeaderName) -> Self {
363        PropagateRequestIdLayer { header_name }
364    }
365
366    /// Create a new `PropagateRequestIdLayer` that uses `x-request-id` as the header name.
367    pub fn x_request_id() -> Self {
368        Self::new(HeaderName::from_static(X_REQUEST_ID))
369    }
370}
371
372impl<S> Layer<S> for PropagateRequestIdLayer {
373    type Service = PropagateRequestId<S>;
374
375    fn layer(&self, inner: S) -> Self::Service {
376        PropagateRequestId::new(inner, self.header_name.clone())
377    }
378}
379
380/// Propagate request ids from requests to responses.
381///
382/// See the [module docs](self) for an example.
383///
384/// If the request contains a matching header that header will be applied to responses. If a
385/// [`RequestId`] extension is also present it will be propagated as well.
386#[derive(Debug, Clone)]
387pub struct PropagateRequestId<S> {
388    inner: S,
389    header_name: HeaderName,
390}
391
392impl<S> PropagateRequestId<S> {
393    /// Create a new `PropagateRequestId`.
394    pub fn new(inner: S, header_name: HeaderName) -> Self {
395        Self { inner, header_name }
396    }
397
398    /// Create a new `PropagateRequestId` that uses `x-request-id` as the header name.
399    pub fn x_request_id(inner: S) -> Self {
400        Self::new(inner, HeaderName::from_static(X_REQUEST_ID))
401    }
402
403    define_inner_service_accessors!();
404
405    /// Returns a new [`Layer`] that wraps services with a `PropagateRequestId` middleware.
406    pub fn layer(header_name: HeaderName) -> PropagateRequestIdLayer {
407        PropagateRequestIdLayer::new(header_name)
408    }
409}
410
411impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for PropagateRequestId<S>
412where
413    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
414{
415    type Response = S::Response;
416    type Error = S::Error;
417    type Future = PropagateRequestIdResponseFuture<S::Future>;
418
419    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
420        self.inner.poll_ready(cx)
421    }
422
423    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
424        let request_id = req
425            .headers()
426            .get(&self.header_name)
427            .cloned()
428            .map(RequestId::new);
429
430        PropagateRequestIdResponseFuture {
431            inner: self.inner.call(req),
432            header_name: self.header_name.clone(),
433            request_id,
434        }
435    }
436}
437
438pin_project! {
439    /// Response future for [`PropagateRequestId`].
440    pub struct PropagateRequestIdResponseFuture<F> {
441        #[pin]
442        inner: F,
443        header_name: HeaderName,
444        request_id: Option<RequestId>,
445    }
446}
447
448impl<F, B, E> Future for PropagateRequestIdResponseFuture<F>
449where
450    F: Future<Output = Result<Response<B>, E>>,
451{
452    type Output = Result<Response<B>, E>;
453
454    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
455        let this = self.project();
456        let mut response = ready!(this.inner.poll(cx))?;
457
458        if let Some(current_id) = response.headers().get(&*this.header_name) {
459            if response.extensions().get::<RequestId>().is_none() {
460                let current_id = current_id.clone();
461                response.extensions_mut().insert(RequestId::new(current_id));
462            }
463        } else if let Some(request_id) = this.request_id.take() {
464            response
465                .headers_mut()
466                .insert(this.header_name.clone(), request_id.0.clone());
467            response.extensions_mut().insert(request_id);
468        }
469
470        Poll::Ready(Ok(response))
471    }
472}
473
474/// A [`MakeRequestId`] that generates `UUID`s.
475#[derive(Clone, Copy, Default)]
476pub struct MakeRequestUuid;
477
478impl MakeRequestId for MakeRequestUuid {
479    fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
480        let request_id = Uuid::new_v4().to_string().parse().unwrap();
481        Some(RequestId::new(request_id))
482    }
483}
484
#[cfg(test)]
mod tests {
    use crate::test_helpers::Body;
    use crate::ServiceBuilderExt as _;
    use http::Response;
    use std::{
        convert::Infallible,
        sync::{
            atomic::{AtomicU64, Ordering},
            Arc,
        },
    };
    use tower::{ServiceBuilder, ServiceExt};

    #[allow(unused_imports)]
    use super::*;

    #[tokio::test]
    async fn basic() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(Counter::default())
            .propagate_x_request_id()
            .service_fn(handler);

        // header on response
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.clone().oneshot(req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "0");

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.clone().oneshot(req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "1");

        // doesn't override if header is already there
        let req = Request::builder()
            .header("x-request-id", "foo")
            .body(Body::empty())
            .unwrap();
        let res = svc.clone().oneshot(req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");

        // extension propagated
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.clone().oneshot(req).await.unwrap();
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "2");
    }

    #[tokio::test]
    async fn other_middleware_setting_request_id() {
        let svc = ServiceBuilder::new()
            .override_request_header(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_str("foo").unwrap(),
            )
            .set_x_request_id(Counter::default())
            .map_request(|request: Request<_>| {
                // `set_x_request_id` should set the extension if its missing
                assert_eq!(request.extensions().get::<RequestId>().unwrap().0, "foo");
                request
            })
            .propagate_x_request_id()
            .service_fn(handler);

        let req = Request::builder()
            .header(
                "x-request-id",
                "this-will-be-overriden-by-override_request_header-middleware",
            )
            .body(Body::empty())
            .unwrap();
        let res = svc.clone().oneshot(req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo");
    }

    #[tokio::test]
    async fn other_middleware_setting_request_id_on_response() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(Counter::default())
            .propagate_x_request_id()
            .override_response_header(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_str("foo").unwrap(),
            )
            .service_fn(handler);

        // The request id differs from the value the inner middleware writes on the response, so
        // the assertions can tell "left alone" apart from "overwritten by propagation".
        let req = Request::builder()
            .header("x-request-id", "bar")
            .body(Body::empty())
            .unwrap();
        let res = svc.clone().oneshot(req).await.unwrap();
        // `propagate_x_request_id` must not replace a header the inner service already set...
        assert_eq!(res.headers()["x-request-id"], "foo");
        // ...and the response extension is backfilled from that header, not from the request.
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo");
    }

    #[tokio::test]
    async fn make_request_id_returning_none() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(NoRequestId)
            .propagate_x_request_id()
            .service_fn(handler);

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.oneshot(req).await.unwrap();
        assert!(res.headers().get("x-request-id").is_none());
        assert!(res.extensions().get::<RequestId>().is_none());
    }

    /// Hands out sequential ids starting at `0`; clones share the same counter.
    #[derive(Clone, Default)]
    struct Counter(Arc<AtomicU64>);

    impl MakeRequestId for Counter {
        fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
            let id = HeaderValue::from(self.0.fetch_add(1, Ordering::SeqCst));
            Some(RequestId::new(id))
        }
    }

    /// A generator that never produces an id.
    #[derive(Clone, Default)]
    struct NoRequestId;

    impl MakeRequestId for NoRequestId {
        fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
            None
        }
    }

    async fn handler(_: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::empty()))
    }

    #[tokio::test]
    async fn uuid() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(MakeRequestUuid)
            .propagate_x_request_id()
            .service_fn(handler);

        // header on response
        let req = Request::builder().body(Body::empty()).unwrap();
        let mut res = svc.clone().oneshot(req).await.unwrap();
        let id = res.headers_mut().remove("x-request-id").unwrap();
        id.to_str().unwrap().parse::<Uuid>().unwrap();

        // the extension carries the same generated id as the header
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, id);
    }
}