tower_http/request_id.rs
1//! Set and propagate request ids.
2//!
3//! # Example
4//!
5//! ```
6//! use http::{Request, Response, header::HeaderName};
7//! use tower::{Service, ServiceExt, ServiceBuilder};
8//! use tower_http::request_id::{
9//! SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
10//! };
11//! use http_body_util::Full;
12//! use bytes::Bytes;
13//! use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
14//!
15//! # #[tokio::main]
16//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
17//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
18//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
19//! # });
20//! #
21//! // A `MakeRequestId` that increments an atomic counter
22//! #[derive(Clone, Default)]
23//! struct MyMakeRequestId {
24//! counter: Arc<AtomicU64>,
25//! }
26//!
27//! impl MakeRequestId for MyMakeRequestId {
28//! fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
29//! let request_id = self.counter
30//! .fetch_add(1, Ordering::SeqCst)
31//! .to_string()
32//! .parse()
33//! .unwrap();
34//!
35//! Some(RequestId::new(request_id))
36//! }
37//! }
38//!
39//! let x_request_id = HeaderName::from_static("x-request-id");
40//!
41//! let mut svc = ServiceBuilder::new()
42//! // set `x-request-id` header on all requests
43//! .layer(SetRequestIdLayer::new(
44//! x_request_id.clone(),
45//! MyMakeRequestId::default(),
46//! ))
47//! // propagate `x-request-id` headers from request to response
48//! .layer(PropagateRequestIdLayer::new(x_request_id))
49//! .service(handler);
50//!
51//! let request = Request::new(Full::default());
52//! let response = svc.ready().await?.call(request).await?;
53//!
54//! assert_eq!(response.headers()["x-request-id"], "0");
55//! #
56//! # Ok(())
57//! # }
58//! ```
59//!
60//! Additional convenience methods are available on [`ServiceBuilderExt`]:
61//!
62//! ```
63//! use tower_http::ServiceBuilderExt;
64//! # use http::{Request, Response, header::HeaderName};
65//! # use tower::{Service, ServiceExt, ServiceBuilder};
66//! # use tower_http::request_id::{
67//! # SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
68//! # };
69//! # use bytes::Bytes;
70//! # use http_body_util::Full;
71//! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
72//! # #[tokio::main]
73//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
74//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
75//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
76//! # });
77//! # #[derive(Clone, Default)]
78//! # struct MyMakeRequestId {
79//! # counter: Arc<AtomicU64>,
80//! # }
81//! # impl MakeRequestId for MyMakeRequestId {
82//! # fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
83//! # let request_id = self.counter
84//! # .fetch_add(1, Ordering::SeqCst)
85//! # .to_string()
86//! # .parse()
87//! # .unwrap();
88//! # Some(RequestId::new(request_id))
89//! # }
90//! # }
91//!
92//! let mut svc = ServiceBuilder::new()
93//! .set_x_request_id(MyMakeRequestId::default())
94//! .propagate_x_request_id()
95//! .service(handler);
96//!
97//! let request = Request::new(Full::default());
98//! let response = svc.ready().await?.call(request).await?;
99//!
100//! assert_eq!(response.headers()["x-request-id"], "0");
101//! #
102//! # Ok(())
103//! # }
104//! ```
105//!
106//! See [`SetRequestId`] and [`PropagateRequestId`] for more details.
107//!
108//! # Using `Trace`
109//!
//! To have request ids show up correctly in logs produced by [`Trace`], you must apply the layers
//! in this order:
112//!
113//! ```
114//! use tower_http::{
115//! ServiceBuilderExt,
116//! trace::{TraceLayer, DefaultMakeSpan, DefaultOnResponse},
117//! };
118//! # use http::{Request, Response, header::HeaderName};
119//! # use tower::{Service, ServiceExt, ServiceBuilder};
120//! # use tower_http::request_id::{
121//! # SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
122//! # };
123//! # use http_body_util::Full;
124//! # use bytes::Bytes;
125//! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
126//! # #[tokio::main]
127//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
128//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
129//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
130//! # });
131//! # #[derive(Clone, Default)]
132//! # struct MyMakeRequestId {
133//! # counter: Arc<AtomicU64>,
134//! # }
135//! # impl MakeRequestId for MyMakeRequestId {
136//! # fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
137//! # let request_id = self.counter
138//! # .fetch_add(1, Ordering::SeqCst)
139//! # .to_string()
140//! # .parse()
141//! # .unwrap();
142//! # Some(RequestId::new(request_id))
143//! # }
144//! # }
145//!
146//! let svc = ServiceBuilder::new()
147//! // make sure to set request ids before the request reaches `TraceLayer`
148//! .set_x_request_id(MyMakeRequestId::default())
149//! // log requests and responses
150//! .layer(
151//! TraceLayer::new_for_http()
152//! .make_span_with(DefaultMakeSpan::new().include_headers(true))
153//! .on_response(DefaultOnResponse::new().include_headers(true))
154//! )
155//! // propagate the header to the response before the response reaches `TraceLayer`
156//! .propagate_x_request_id()
157//! .service(handler);
158//! #
159//! # Ok(())
160//! # }
161//! ```
162//!
163//! # Doesn't override existing headers
164//!
//! [`SetRequestId`] and [`PropagateRequestId`] won't override request ids if they're already present
//! on requests or responses. Among other things, this allows other middleware to conditionally set
//! request ids and use the middleware in this module as a fallback.
168//!
169//! [`ServiceBuilderExt`]: crate::ServiceBuilderExt
170//! [`Uuid`]: https://crates.io/crates/uuid
171//! [`Trace`]: crate::trace::Trace
172
173use http::{
174 header::{HeaderName, HeaderValue},
175 Request, Response,
176};
177use pin_project_lite::pin_project;
178use std::task::{ready, Context, Poll};
179use std::{future::Future, pin::Pin};
180use tower_layer::Layer;
181use tower_service::Service;
182use uuid::Uuid;
183
184pub(crate) const X_REQUEST_ID: &str = "x-request-id";
185
186/// Trait for producing [`RequestId`]s.
187///
188/// Used by [`SetRequestId`].
189pub trait MakeRequestId {
190 /// Try and produce a [`RequestId`] from the request.
191 fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId>;
192}
193
194/// An identifier for a request.
195#[derive(Debug, Clone)]
196pub struct RequestId(HeaderValue);
197
198impl RequestId {
199 /// Create a new `RequestId` from a [`HeaderValue`].
200 pub fn new(header_value: HeaderValue) -> Self {
201 Self(header_value)
202 }
203
204 /// Gets a reference to the underlying [`HeaderValue`].
205 pub fn header_value(&self) -> &HeaderValue {
206 &self.0
207 }
208
209 /// Consumes `self`, returning the underlying [`HeaderValue`].
210 pub fn into_header_value(self) -> HeaderValue {
211 self.0
212 }
213}
214
215impl From<HeaderValue> for RequestId {
216 fn from(value: HeaderValue) -> Self {
217 Self::new(value)
218 }
219}
220
221/// Set request id headers and extensions on requests.
222///
223/// This layer applies the [`SetRequestId`] middleware.
224///
225/// See the [module docs](self) and [`SetRequestId`] for more details.
226#[derive(Debug, Clone)]
227pub struct SetRequestIdLayer<M> {
228 header_name: HeaderName,
229 make_request_id: M,
230}
231
232impl<M> SetRequestIdLayer<M> {
233 /// Create a new `SetRequestIdLayer`.
234 pub fn new(header_name: HeaderName, make_request_id: M) -> Self
235 where
236 M: MakeRequestId,
237 {
238 SetRequestIdLayer {
239 header_name,
240 make_request_id,
241 }
242 }
243
244 /// Create a new `SetRequestIdLayer` that uses `x-request-id` as the header name.
245 pub fn x_request_id(make_request_id: M) -> Self
246 where
247 M: MakeRequestId,
248 {
249 SetRequestIdLayer::new(HeaderName::from_static(X_REQUEST_ID), make_request_id)
250 }
251}
252
253impl<S, M> Layer<S> for SetRequestIdLayer<M>
254where
255 M: Clone + MakeRequestId,
256{
257 type Service = SetRequestId<S, M>;
258
259 fn layer(&self, inner: S) -> Self::Service {
260 SetRequestId::new(
261 inner,
262 self.header_name.clone(),
263 self.make_request_id.clone(),
264 )
265 }
266}
267
268/// Set request id headers and extensions on requests.
269///
270/// See the [module docs](self) for an example.
271///
272/// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a
273/// header with the same name, then the header will be inserted.
274///
275/// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other
276/// services can access it.
277#[derive(Debug, Clone)]
278pub struct SetRequestId<S, M> {
279 inner: S,
280 header_name: HeaderName,
281 make_request_id: M,
282}
283
284impl<S, M> SetRequestId<S, M> {
285 /// Create a new `SetRequestId`.
286 pub fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self
287 where
288 M: MakeRequestId,
289 {
290 Self {
291 inner,
292 header_name,
293 make_request_id,
294 }
295 }
296
297 /// Create a new `SetRequestId` that uses `x-request-id` as the header name.
298 pub fn x_request_id(inner: S, make_request_id: M) -> Self
299 where
300 M: MakeRequestId,
301 {
302 Self::new(
303 inner,
304 HeaderName::from_static(X_REQUEST_ID),
305 make_request_id,
306 )
307 }
308
309 define_inner_service_accessors!();
310
311 /// Returns a new [`Layer`] that wraps services with a `SetRequestId` middleware.
312 pub fn layer(header_name: HeaderName, make_request_id: M) -> SetRequestIdLayer<M>
313 where
314 M: MakeRequestId,
315 {
316 SetRequestIdLayer::new(header_name, make_request_id)
317 }
318}
319
320impl<S, M, ReqBody, ResBody> Service<Request<ReqBody>> for SetRequestId<S, M>
321where
322 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
323 M: MakeRequestId,
324{
325 type Response = S::Response;
326 type Error = S::Error;
327 type Future = S::Future;
328
329 #[inline]
330 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
331 self.inner.poll_ready(cx)
332 }
333
334 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
335 if let Some(request_id) = req.headers().get(&self.header_name) {
336 if req.extensions().get::<RequestId>().is_none() {
337 let request_id = request_id.clone();
338 req.extensions_mut().insert(RequestId::new(request_id));
339 }
340 } else if let Some(request_id) = self.make_request_id.make_request_id(&req) {
341 req.extensions_mut().insert(request_id.clone());
342 req.headers_mut()
343 .insert(self.header_name.clone(), request_id.0);
344 }
345
346 self.inner.call(req)
347 }
348}
349
350/// Propagate request ids from requests to responses.
351///
352/// This layer applies the [`PropagateRequestId`] middleware.
353///
354/// See the [module docs](self) and [`PropagateRequestId`] for more details.
355#[derive(Debug, Clone)]
356pub struct PropagateRequestIdLayer {
357 header_name: HeaderName,
358}
359
360impl PropagateRequestIdLayer {
361 /// Create a new `PropagateRequestIdLayer`.
362 pub fn new(header_name: HeaderName) -> Self {
363 PropagateRequestIdLayer { header_name }
364 }
365
366 /// Create a new `PropagateRequestIdLayer` that uses `x-request-id` as the header name.
367 pub fn x_request_id() -> Self {
368 Self::new(HeaderName::from_static(X_REQUEST_ID))
369 }
370}
371
372impl<S> Layer<S> for PropagateRequestIdLayer {
373 type Service = PropagateRequestId<S>;
374
375 fn layer(&self, inner: S) -> Self::Service {
376 PropagateRequestId::new(inner, self.header_name.clone())
377 }
378}
379
380/// Propagate request ids from requests to responses.
381///
382/// See the [module docs](self) for an example.
383///
384/// If the request contains a matching header that header will be applied to responses. If a
385/// [`RequestId`] extension is also present it will be propagated as well.
386#[derive(Debug, Clone)]
387pub struct PropagateRequestId<S> {
388 inner: S,
389 header_name: HeaderName,
390}
391
392impl<S> PropagateRequestId<S> {
393 /// Create a new `PropagateRequestId`.
394 pub fn new(inner: S, header_name: HeaderName) -> Self {
395 Self { inner, header_name }
396 }
397
398 /// Create a new `PropagateRequestId` that uses `x-request-id` as the header name.
399 pub fn x_request_id(inner: S) -> Self {
400 Self::new(inner, HeaderName::from_static(X_REQUEST_ID))
401 }
402
403 define_inner_service_accessors!();
404
405 /// Returns a new [`Layer`] that wraps services with a `PropagateRequestId` middleware.
406 pub fn layer(header_name: HeaderName) -> PropagateRequestIdLayer {
407 PropagateRequestIdLayer::new(header_name)
408 }
409}
410
411impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for PropagateRequestId<S>
412where
413 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
414{
415 type Response = S::Response;
416 type Error = S::Error;
417 type Future = PropagateRequestIdResponseFuture<S::Future>;
418
419 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
420 self.inner.poll_ready(cx)
421 }
422
423 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
424 let request_id = req
425 .headers()
426 .get(&self.header_name)
427 .cloned()
428 .map(RequestId::new);
429
430 PropagateRequestIdResponseFuture {
431 inner: self.inner.call(req),
432 header_name: self.header_name.clone(),
433 request_id,
434 }
435 }
436}
437
438pin_project! {
439 /// Response future for [`PropagateRequestId`].
440 pub struct PropagateRequestIdResponseFuture<F> {
441 #[pin]
442 inner: F,
443 header_name: HeaderName,
444 request_id: Option<RequestId>,
445 }
446}
447
448impl<F, B, E> Future for PropagateRequestIdResponseFuture<F>
449where
450 F: Future<Output = Result<Response<B>, E>>,
451{
452 type Output = Result<Response<B>, E>;
453
454 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
455 let this = self.project();
456 let mut response = ready!(this.inner.poll(cx))?;
457
458 if let Some(current_id) = response.headers().get(&*this.header_name) {
459 if response.extensions().get::<RequestId>().is_none() {
460 let current_id = current_id.clone();
461 response.extensions_mut().insert(RequestId::new(current_id));
462 }
463 } else if let Some(request_id) = this.request_id.take() {
464 response
465 .headers_mut()
466 .insert(this.header_name.clone(), request_id.0.clone());
467 response.extensions_mut().insert(request_id);
468 }
469
470 Poll::Ready(Ok(response))
471 }
472}
473
474/// A [`MakeRequestId`] that generates `UUID`s.
475#[derive(Clone, Copy, Default)]
476pub struct MakeRequestUuid;
477
478impl MakeRequestId for MakeRequestUuid {
479 fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
480 let request_id = Uuid::new_v4().to_string().parse().unwrap();
481 Some(RequestId::new(request_id))
482 }
483}
484
485#[cfg(test)]
486mod tests {
487 use crate::test_helpers::Body;
488 use crate::ServiceBuilderExt as _;
489 use http::Response;
490 use std::{
491 convert::Infallible,
492 sync::{
493 atomic::{AtomicU64, Ordering},
494 Arc,
495 },
496 };
497 use tower::{ServiceBuilder, ServiceExt};
498
499 #[allow(unused_imports)]
500 use super::*;
501
502 #[tokio::test]
503 async fn basic() {
504 let svc = ServiceBuilder::new()
505 .set_x_request_id(Counter::default())
506 .propagate_x_request_id()
507 .service_fn(handler);
508
509 // header on response
510 let req = Request::builder().body(Body::empty()).unwrap();
511 let res = svc.clone().oneshot(req).await.unwrap();
512 assert_eq!(res.headers()["x-request-id"], "0");
513
514 let req = Request::builder().body(Body::empty()).unwrap();
515 let res = svc.clone().oneshot(req).await.unwrap();
516 assert_eq!(res.headers()["x-request-id"], "1");
517
518 // doesn't override if header is already there
519 let req = Request::builder()
520 .header("x-request-id", "foo")
521 .body(Body::empty())
522 .unwrap();
523 let res = svc.clone().oneshot(req).await.unwrap();
524 assert_eq!(res.headers()["x-request-id"], "foo");
525
526 // extension propagated
527 let req = Request::builder().body(Body::empty()).unwrap();
528 let res = svc.clone().oneshot(req).await.unwrap();
529 assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "2");
530 }
531
532 #[tokio::test]
533 async fn other_middleware_setting_request_id() {
534 let svc = ServiceBuilder::new()
535 .override_request_header(
536 HeaderName::from_static("x-request-id"),
537 HeaderValue::from_str("foo").unwrap(),
538 )
539 .set_x_request_id(Counter::default())
540 .map_request(|request: Request<_>| {
541 // `set_x_request_id` should set the extension if its missing
542 assert_eq!(request.extensions().get::<RequestId>().unwrap().0, "foo");
543 request
544 })
545 .propagate_x_request_id()
546 .service_fn(handler);
547
548 let req = Request::builder()
549 .header(
550 "x-request-id",
551 "this-will-be-overriden-by-override_request_header-middleware",
552 )
553 .body(Body::empty())
554 .unwrap();
555 let res = svc.clone().oneshot(req).await.unwrap();
556 assert_eq!(res.headers()["x-request-id"], "foo");
557 assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo");
558 }
559
560 #[tokio::test]
561 async fn other_middleware_setting_request_id_on_response() {
562 let svc = ServiceBuilder::new()
563 .set_x_request_id(Counter::default())
564 .propagate_x_request_id()
565 .override_response_header(
566 HeaderName::from_static("x-request-id"),
567 HeaderValue::from_str("foo").unwrap(),
568 )
569 .service_fn(handler);
570
571 let req = Request::builder()
572 .header("x-request-id", "foo")
573 .body(Body::empty())
574 .unwrap();
575 let res = svc.clone().oneshot(req).await.unwrap();
576 assert_eq!(res.headers()["x-request-id"], "foo");
577 assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo");
578 }
579
580 #[derive(Clone, Default)]
581 struct Counter(Arc<AtomicU64>);
582
583 impl MakeRequestId for Counter {
584 fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
585 let id =
586 HeaderValue::from_str(&self.0.fetch_add(1, Ordering::SeqCst).to_string()).unwrap();
587 Some(RequestId::new(id))
588 }
589 }
590
591 async fn handler(_: Request<Body>) -> Result<Response<Body>, Infallible> {
592 Ok(Response::new(Body::empty()))
593 }
594
595 #[tokio::test]
596 async fn uuid() {
597 let svc = ServiceBuilder::new()
598 .set_x_request_id(MakeRequestUuid)
599 .propagate_x_request_id()
600 .service_fn(handler);
601
602 // header on response
603 let req = Request::builder().body(Body::empty()).unwrap();
604 let mut res = svc.clone().oneshot(req).await.unwrap();
605 let id = res.headers_mut().remove("x-request-id").unwrap();
606 id.to_str().unwrap().parse::<Uuid>().unwrap();
607 }
608}