// rama_http/layer/trace/mod.rs

1//! Middleware that adds high level [tracing] to a [`Service`].
2//!
3//! # Example
4//!
5//! Adding tracing to your service can be as simple as:
6//!
7//! ```rust
8//! use rama_http::{Body, Request, Response};
9//! use rama_core::service::service_fn;
10//! use rama_core::{Context, Layer, Service};
11//! use rama_http::layer::trace::TraceLayer;
12//! use std::convert::Infallible;
13//!
14//! async fn handle(request: Request) -> Result<Response, Infallible> {
15//!     Ok(Response::new(Body::from("foo")))
16//! }
17//!
18//! # #[tokio::main]
19//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
20//! // Setup tracing
21//! tracing_subscriber::fmt::init();
22//!
23//! let mut service = TraceLayer::new_for_http().layer(service_fn(handle));
24//!
25//! let request = Request::new(Body::from("foo"));
26//!
27//! let response = service
28//!     .serve(Context::default(), request)
29//!     .await?;
30//! # Ok(())
31//! # }
32//! ```
33//!
34//! If you run this application with `RUST_LOG=rama=trace cargo run` you should see logs like:
35//!
36//! ```text
37//! Mar 05 20:50:28.523 DEBUG request{method=GET path="/foo"}: rama_http::layer::trace::on_request: started processing request
38//! Mar 05 20:50:28.524 DEBUG request{method=GET path="/foo"}: rama_http::layer::trace::on_response: finished processing request latency=1 ms status=200
39//! ```
40//!
41//! # Customization
42//!
43//! [`Trace`] comes with good defaults but also supports customizing many aspects of the output.
44//!
45//! The default behaviour supports some customization:
46//!
47//! ```rust
48//! use rama_http::{Body, Request, Response, HeaderMap, StatusCode};
49//! use rama_core::service::service_fn;
50//! use rama_core::{Context, Service, Layer};
51//! use tracing::Level;
52//! use rama_http::layer::trace::{
53//!     TraceLayer, DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse,
54//! };
55//! use rama_utils::latency::LatencyUnit;
56//! use std::time::Duration;
57//! use std::convert::Infallible;
58//!
59//! # async fn handle(request: Request) -> Result<Response, Infallible> {
60//! #     Ok(Response::new(Body::from("foo")))
61//! # }
62//! # #[tokio::main]
63//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
64//! # tracing_subscriber::fmt::init();
65//! #
66//! let service = (
67//!     TraceLayer::new_for_http()
68//!         .make_span_with(
69//!             DefaultMakeSpan::new().include_headers(true)
70//!         )
71//!         .on_request(
72//!             DefaultOnRequest::new().level(Level::INFO)
73//!         )
74//!         .on_response(
75//!             DefaultOnResponse::new()
76//!                 .level(Level::INFO)
77//!                 .latency_unit(LatencyUnit::Micros)
78//!         ),
//!         // and so on for `on_eos`, `on_body_chunk`, and `on_failure`
80//! ).layer(service_fn(handle));
81//! # let mut service = service;
82//! # let response = service
83//! #     .serve(Context::default(), Request::new(Body::from("foo")))
84//! #     .await?;
85//! # Ok(())
86//! # }
87//! ```
88//!
89//! However for maximum control you can provide callbacks:
90//!
91//! ```rust
92//! use rama_http::{Body, Request, Response, HeaderMap, StatusCode};
93//! use rama_core::service::service_fn;
94//! use rama_core::{Context, Service, Layer};
95//! use rama_http::layer::{classify::ServerErrorsFailureClass, trace::TraceLayer};
96//! use std::time::Duration;
97//! use tracing::Span;
98//! use std::convert::Infallible;
99//! use bytes::Bytes;
100//!
101//! # async fn handle(request: Request) -> Result<Response, Infallible> {
102//! #     Ok(Response::new(Body::from("foo")))
103//! # }
104//! # #[tokio::main]
105//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
106//! # tracing_subscriber::fmt::init();
107//! #
108//! let service = (
109//!     TraceLayer::new_for_http()
110//!         .make_span_with(|request: &Request| {
111//!             tracing::debug_span!("http-request")
112//!         })
113//!         .on_request(|request: &Request, _span: &Span| {
114//!             tracing::debug!("started {} {}", request.method(), request.uri().path())
115//!         })
116//!         .on_response(|response: &Response, latency: Duration, _span: &Span| {
117//!             tracing::debug!("response generated in {:?}", latency)
118//!         })
119//!         .on_body_chunk(|chunk: &Bytes, latency: Duration, _span: &Span| {
120//!             tracing::debug!("sending {} bytes", chunk.len())
121//!         })
122//!         .on_eos(|trailers: Option<&HeaderMap>, stream_duration: Duration, _span: &Span| {
123//!             tracing::debug!("stream closed after {:?}", stream_duration)
124//!         })
125//!         .on_failure(|error: ServerErrorsFailureClass, latency: Duration, _span: &Span| {
126//!             tracing::debug!("something went wrong")
127//!         })
128//! ).layer(service_fn(handle));
129//! # let mut service = service;
130//! # let response = service
131//! #     .serve(Context::default(), Request::new(Body::from("foo")))
132//! #     .await?;
133//! # Ok(())
134//! # }
135//! ```
136//!
137//! ## Disabling something
138//!
//! Setting the behaviour to `()` will disable that particular step:
140//!
141//! ```rust
142//! use rama_http::{Body, Request, Response, StatusCode};
143//! use rama_core::service::service_fn;
144//! use rama_core::{Context, Service, Layer};
145//! use rama_http::layer::{classify::ServerErrorsFailureClass, trace::TraceLayer};
146//! use std::time::Duration;
147//! use tracing::Span;
148//! # use std::convert::Infallible;
149//!
150//! # async fn handle(request: Request) -> Result<Response, Infallible> {
151//! #     Ok(Response::new(Body::from("foo")))
152//! # }
153//! # #[tokio::main]
154//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
155//! # tracing_subscriber::fmt::init();
156//! #
157//! let service = (
158//!     // This configuration will only emit events on failures
159//!     TraceLayer::new_for_http()
160//!         .on_request(())
161//!         .on_response(())
162//!         .on_body_chunk(())
163//!         .on_eos(())
164//!         .on_failure(|error: ServerErrorsFailureClass, latency: Duration, _span: &Span| {
165//!             tracing::debug!("something went wrong")
166//!         })
167//! ).layer(service_fn(handle));
168//! # let mut service = service;
169//! # let response = service
170//! #     .serve(Context::default(), Request::new(Body::from("foo")))
171//! #     .await?;
172//! # Ok(())
173//! # }
174//! ```
175//!
176//! # When the callbacks are called
177//!
178//! ### `on_request`
179//!
180//! The `on_request` callback is called when the request arrives at the
181//! middleware in [`Service::serve`] just prior to passing the request to the
182//! inner service.
183//!
184//! ### `on_response`
185//!
186//! The `on_response` callback is called when the inner service's response
//! future completes with `Ok(response)` regardless of whether the response is
188//! classified as a success or a failure.
189//!
190//! For example if you're using [`ServerErrorsAsFailures`] as your classifier
191//! and the inner service responds with `500 Internal Server Error` then the
192//! `on_response` callback is still called. `on_failure` would _also_ be called
193//! in this case since the response was classified as a failure.
194//!
195//! ### `on_body_chunk`
196//!
197//! The `on_body_chunk` callback is called when the response body produces a new
198//! chunk, that is when [`http_body::Body::poll_frame`] returns `Poll::Ready(Some(Ok(chunk)))`.
199//!
200//! `on_body_chunk` is called even if the chunk is empty.
201//!
202//! ### `on_eos`
203//!
204//! The `on_eos` callback is called when a streaming response body ends, that is
205//! when `http_body::Body::poll_frame` returns `Poll::Ready(None)`.
206//!
207//! `on_eos` is called even if the trailers produced are `None`.
208//!
209//! ### `on_failure`
210//!
211//! The `on_failure` callback is called when:
212//!
213//! - The inner [`Service`]'s response future resolves to an error.
214//! - A response is classified as a failure.
215//! - [`http_body::Body::poll_frame`] returns an error.
216//! - An end-of-stream is classified as a failure.
217//!
218//! # Recording fields on the span
219//!
220//! All callbacks receive a reference to the [tracing] [`Span`], corresponding to this request,
221//! produced by the closure passed to [`TraceLayer::make_span_with`]. It can be used to [record
222//! field values][record] that weren't known when the span was created.
223//!
224//! ```rust
225//! use rama_http::{Body, Request, Response, HeaderMap, StatusCode};
226//! use rama_core::service::service_fn;
227//! use rama_core::Layer;
228//! use rama_http::layer::trace::TraceLayer;
229//! use tracing::Span;
230//! use std::time::Duration;
231//! use std::convert::Infallible;
232//!
233//! # async fn handle(request: Request) -> Result<Response, Infallible> {
234//! #     Ok(Response::new(Body::from("foo")))
235//! # }
236//! # #[tokio::main]
237//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
238//! # tracing_subscriber::fmt::init();
239//! #
240//! let service = (
241//!     TraceLayer::new_for_http()
242//!         .make_span_with(|request: &Request| {
243//!             tracing::debug_span!(
244//!                 "http-request",
245//!                 status_code = tracing::field::Empty,
246//!             )
247//!         })
248//!         .on_response(|response: &Response, _latency: Duration, span: &Span| {
249//!             span.record("status_code", &tracing::field::display(response.status()));
250//!
251//!             tracing::debug!("response generated")
252//!         }),
253//! ).layer(service_fn(handle));
254//! # Ok(())
255//! # }
256//! ```
257//!
258//! # Providing classifiers
259//!
260//! Tracing requires determining if a response is a success or failure. [`MakeClassifier`] is used
261//! to create a classifier for the incoming request. See the docs for [`MakeClassifier`] and
262//! [`ClassifyResponse`] for more details on classification.
263//!
264//! A [`MakeClassifier`] can be provided when creating a [`TraceLayer`]:
265//!
266//! ```rust
267//! use rama_http::{Body, Request, Response};
268//! use rama_core::service::service_fn;
269//! use rama_core::Layer;
270//! use rama_http::layer::{
271//!     trace::TraceLayer,
272//!     classify::{
273//!         MakeClassifier, ClassifyResponse, ClassifiedResponse, NeverClassifyEos,
274//!         SharedClassifier,
275//!     },
276//! };
277//! use std::convert::Infallible;
278//!
279//! # async fn handle(request: Request) -> Result<Response, Infallible> {
280//! #     Ok(Response::new(Body::from("foo")))
281//! # }
282//! # #[tokio::main]
283//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
284//! # tracing_subscriber::fmt::init();
285//! #
//! // Our `MakeClassifier` that always creates `MyClassifier` classifiers.
287//! #[derive(Copy, Clone)]
288//! struct MyMakeClassify;
289//!
290//! impl MakeClassifier for MyMakeClassify {
291//!     type Classifier = MyClassifier;
292//!     type FailureClass = &'static str;
293//!     type ClassifyEos = NeverClassifyEos<&'static str>;
294//!
295//!     fn make_classifier<B>(&self, req: &Request<B>) -> Self::Classifier {
296//!         MyClassifier
297//!     }
298//! }
299//!
300//! // A classifier that classifies failures as `"something went wrong..."`.
301//! #[derive(Copy, Clone)]
302//! struct MyClassifier;
303//!
304//! impl ClassifyResponse for MyClassifier {
305//!     type FailureClass = &'static str;
306//!     type ClassifyEos = NeverClassifyEos<&'static str>;
307//!
308//!     fn classify_response<B>(
309//!         self,
310//!         res: &Response<B>
311//!     ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> {
312//!         // Classify based on the status code.
313//!         if res.status().is_server_error() {
314//!             ClassifiedResponse::Ready(Err("something went wrong..."))
315//!         } else {
316//!             ClassifiedResponse::Ready(Ok(()))
317//!         }
318//!     }
319//!
320//!     fn classify_error<E>(self, error: &E) -> Self::FailureClass
321//!     where
322//!         E: std::fmt::Display,
323//!     {
324//!         "something went wrong..."
325//!     }
326//! }
327//!
328//! let service = (
329//!     // Create a trace layer that uses our classifier.
330//!     TraceLayer::new(MyMakeClassify),
331//! ).layer(service_fn(handle));
332//!
333//! // Since `MyClassifier` is `Clone` we can also use `SharedClassifier`
334//! // to avoid having to define a separate `MakeClassifier`.
335//! let service = TraceLayer::new(SharedClassifier::new(MyClassifier)).layer(service_fn(handle));
336//! # Ok(())
337//! # }
338//! ```
339//!
340//! [`TraceLayer`] comes with convenience methods for using common classifiers:
341//!
342//! - [`TraceLayer::new_for_http`] classifies based on the status code. It doesn't consider
343//!   streaming responses.
344//! - [`TraceLayer::new_for_grpc`] classifies based on the gRPC protocol and supports streaming
345//!   responses.
346//!
347//! [tracing]: https://crates.io/crates/tracing
348//! [`Service`]: rama_core::Service
349//! [`Service::serve`]: rama_core::Service::serve
350//! [`MakeClassifier`]: crate::layer::classify::MakeClassifier
351//! [`ClassifyResponse`]: crate::layer::classify::ClassifyResponse
352//! [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record
353//! [`TraceLayer::make_span_with`]: crate::layer::trace::TraceLayer::make_span_with
354//! [`Span`]: tracing::Span
355//! [`ServerErrorsAsFailures`]: crate::layer::classify::ServerErrorsAsFailures
356
357use std::{fmt, time::Duration};
358
359use tracing::Level;
360
361#[doc(inline)]
362pub use self::{
363    body::ResponseBody,
364    layer::TraceLayer,
365    make_span::{DefaultMakeSpan, MakeSpan},
366    on_body_chunk::{DefaultOnBodyChunk, OnBodyChunk},
367    on_eos::{DefaultOnEos, OnEos},
368    on_failure::{DefaultOnFailure, OnFailure},
369    on_request::{DefaultOnRequest, OnRequest},
370    on_response::{DefaultOnResponse, OnResponse},
371    service::Trace,
372};
373
374use crate::layer::classify::{GrpcErrorsAsFailures, ServerErrorsAsFailures, SharedClassifier};
375use rama_utils::latency::LatencyUnit;
376
/// MakeClassifier for HTTP requests.
///
/// Classifies responses based on their HTTP status code via
/// [`ServerErrorsAsFailures`]; streaming responses are not considered.
/// This is the classifier used by `TraceLayer::new_for_http`.
pub type HttpMakeClassifier = SharedClassifier<ServerErrorsAsFailures>;

/// MakeClassifier for gRPC requests.
///
/// Classifies based on the gRPC protocol via [`GrpcErrorsAsFailures`] and
/// supports streaming responses. This is the classifier used by
/// `TraceLayer::new_for_grpc`.
pub type GrpcMakeClassifier = SharedClassifier<GrpcErrorsAsFailures>;
382
/// Emits a `tracing::event!` whose level is only known at runtime.
///
/// `tracing::event!` requires the level to be a compile-time constant, so this
/// macro matches on the runtime [`tracing::Level`] value and expands to one
/// `event!` invocation per possible level, forwarding the optional
/// `target:`/`parent:` arguments and the remaining tokens unchanged.
macro_rules! event_dynamic_lvl {
    ( $(target: $target:expr,)? $(parent: $parent:expr,)? $lvl:expr, $($tt:tt)* ) => {
        match $lvl {
            tracing::Level::ERROR => {
                tracing::event!(
                    $(target: $target,)?
                    $(parent: $parent,)?
                    tracing::Level::ERROR,
                    $($tt)*
                );
            }
            tracing::Level::WARN => {
                tracing::event!(
                    $(target: $target,)?
                    $(parent: $parent,)?
                    tracing::Level::WARN,
                    $($tt)*
                );
            }
            tracing::Level::INFO => {
                tracing::event!(
                    $(target: $target,)?
                    $(parent: $parent,)?
                    tracing::Level::INFO,
                    $($tt)*
                );
            }
            tracing::Level::DEBUG => {
                tracing::event!(
                    $(target: $target,)?
                    $(parent: $parent,)?
                    tracing::Level::DEBUG,
                    $($tt)*
                );
            }
            tracing::Level::TRACE => {
                tracing::event!(
                    $(target: $target,)?
                    $(parent: $parent,)?
                    tracing::Level::TRACE,
                    $($tt)*
                );
            }
        }
    };
}
429
430mod body;
431mod layer;
432mod make_span;
433mod on_body_chunk;
434mod on_eos;
435mod on_failure;
436mod on_request;
437mod on_response;
438mod service;
439
/// Default level for the regular trace events emitted by this middleware
/// (presumably the `on_request`/`on_response`-style messages — defined in the
/// sibling `on_*` modules, not visible here).
const DEFAULT_MESSAGE_LEVEL: Level = Level::DEBUG;
/// Default level for events emitted when a request is classified as a failure.
const DEFAULT_ERROR_LEVEL: Level = Level::ERROR;
442
/// Pairs a measured elapsed [`Duration`] with the [`LatencyUnit`] it should
/// be rendered in; see the [`fmt::Display`] impl for the output format.
struct Latency {
    // Unit the duration is converted to when displayed.
    unit: LatencyUnit,
    // The measured elapsed time.
    duration: Duration,
}
447
448impl fmt::Display for Latency {
449    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
450        match self.unit {
451            LatencyUnit::Seconds => write!(f, "{} s", self.duration.as_secs_f64()),
452            LatencyUnit::Millis => write!(f, "{} ms", self.duration.as_millis()),
453            LatencyUnit::Micros => write!(f, "{} μs", self.duration.as_micros()),
454            LatencyUnit::Nanos => write!(f, "{} ns", self.duration.as_nanos()),
455        }
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    use crate::dep::http_body_util::BodyExt as _;
    use crate::layer::classify::ServerErrorsFailureClass;
    use crate::{Body, HeaderMap, Request, Response};
    use bytes::Bytes;
    use rama_core::error::BoxError;
    use rama_core::service::service_fn;
    use rama_core::{Context, Layer, Service};
    use std::sync::OnceLock;
    use std::{
        sync::atomic::{AtomicU32, Ordering},
        time::Duration,
    };
    use tracing::Span;

    /// Generates one lazily-initialized `&'static AtomicU32` accessor function
    /// per given name, each counter starting at 0.
    ///
    /// The `static` lives inside the generated function, so counters expanded
    /// in different test functions are independent of each other.
    macro_rules! lazy_atomic_u32 {
        ($($name:ident),+) => {
            $(
                #[allow(non_snake_case)]
                fn $name() -> &'static AtomicU32 {
                    static $name: OnceLock<AtomicU32> = OnceLock::new();
                    $name.get_or_init(|| AtomicU32::new(0))
                }
            )+
        };
    }

    /// A non-streaming roundtrip: `on_request`/`on_response` fire during
    /// `serve`, while the body callbacks only fire once the response body is
    /// actually consumed.
    #[tokio::test]
    async fn unary_request() {
        lazy_atomic_u32!(
            ON_REQUEST_COUNT,
            ON_RESPONSE_COUNT,
            ON_BODY_CHUNK_COUNT,
            ON_EOS,
            ON_FAILURE
        );

        let trace_layer = TraceLayer::new_for_http()
            .make_span_with(|_req: &Request| {
                tracing::info_span!("test-span", foo = tracing::field::Empty)
            })
            .on_request(|_req: &Request, span: &Span| {
                // Record a span field that was left empty at span creation.
                span.record("foo", 42);
                ON_REQUEST_COUNT().fetch_add(1, Ordering::AcqRel);
            })
            .on_response(|_res: &Response, _latency: Duration, _span: &Span| {
                ON_RESPONSE_COUNT().fetch_add(1, Ordering::AcqRel);
            })
            .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {
                ON_BODY_CHUNK_COUNT().fetch_add(1, Ordering::AcqRel);
            })
            .on_eos(
                |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| {
                    ON_EOS().fetch_add(1, Ordering::AcqRel);
                },
            )
            .on_failure(
                |_class: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
                    ON_FAILURE().fetch_add(1, Ordering::AcqRel);
                },
            );

        let svc = trace_layer.layer(service_fn(echo));

        let res = svc
            .serve(Context::default(), Request::new(Body::from("foobar")))
            .await
            .unwrap();

        assert_eq!(1, ON_REQUEST_COUNT().load(Ordering::Acquire), "request");
        assert_eq!(1, ON_RESPONSE_COUNT().load(Ordering::Acquire), "response");
        // The body has not been polled yet, so no chunk/eos/failure callbacks.
        assert_eq!(
            0,
            ON_BODY_CHUNK_COUNT().load(Ordering::Acquire),
            "body chunk"
        );
        assert_eq!(0, ON_EOS().load(Ordering::Acquire), "eos");
        assert_eq!(0, ON_FAILURE().load(Ordering::Acquire), "failure");

        // Consuming the body drives `on_body_chunk` exactly once for the
        // single "foobar" chunk; the HTTP classifier doesn't consider
        // streaming responses, so `on_eos` stays at 0.
        res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(
            1,
            ON_BODY_CHUNK_COUNT().load(Ordering::Acquire),
            "body chunk"
        );
        assert_eq!(0, ON_EOS().load(Ordering::Acquire), "eos");
        assert_eq!(0, ON_FAILURE().load(Ordering::Acquire), "failure");
    }

    /// Same as `unary_request` but with a three-chunk streaming body:
    /// `on_body_chunk` must fire once per chunk while the body is consumed.
    #[tokio::test]
    async fn streaming_response() {
        lazy_atomic_u32!(
            ON_REQUEST_COUNT,
            ON_RESPONSE_COUNT,
            ON_BODY_CHUNK_COUNT,
            ON_EOS,
            ON_FAILURE
        );

        let trace_layer = TraceLayer::new_for_http()
            .on_request(|_req: &Request, _span: &Span| {
                ON_REQUEST_COUNT().fetch_add(1, Ordering::AcqRel);
            })
            .on_response(|_res: &Response, _latency: Duration, _span: &Span| {
                ON_RESPONSE_COUNT().fetch_add(1, Ordering::AcqRel);
            })
            .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {
                ON_BODY_CHUNK_COUNT().fetch_add(1, Ordering::AcqRel);
            })
            .on_eos(
                |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| {
                    ON_EOS().fetch_add(1, Ordering::AcqRel);
                },
            )
            .on_failure(
                |_class: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
                    ON_FAILURE().fetch_add(1, Ordering::AcqRel);
                },
            );

        let svc = trace_layer.layer(service_fn(streaming_body));

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(1, ON_REQUEST_COUNT().load(Ordering::Acquire), "request");
        assert_eq!(1, ON_RESPONSE_COUNT().load(Ordering::Acquire), "response");
        // Nothing has polled the body yet.
        assert_eq!(
            0,
            ON_BODY_CHUNK_COUNT().load(Ordering::Acquire),
            "body chunk"
        );
        assert_eq!(0, ON_EOS().load(Ordering::Acquire), "eos");
        assert_eq!(0, ON_FAILURE().load(Ordering::Acquire), "failure");

        // Collecting the stream yields the three chunks produced by
        // `streaming_body`; `on_eos` stays at 0 for the HTTP classifier.
        res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(
            3,
            ON_BODY_CHUNK_COUNT().load(Ordering::Acquire),
            "body chunk"
        );
        assert_eq!(0, ON_EOS().load(Ordering::Acquire), "eos");
        assert_eq!(0, ON_FAILURE().load(Ordering::Acquire), "failure");
    }

    /// Echoes the request body back as the response body.
    async fn echo(req: Request) -> Result<Response, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    /// Responds with a body streamed from three fixed chunks.
    async fn streaming_body(_req: Request) -> Result<Response, BoxError> {
        use futures_lite::stream::iter;

        let stream = iter(vec![
            Ok::<_, BoxError>(Bytes::from("one")),
            Ok::<_, BoxError>(Bytes::from("two")),
            Ok::<_, BoxError>(Bytes::from("three")),
        ]);

        let body = Body::from_stream(stream);

        Ok(Response::new(body))
    }
}