tower_http/trace/body.rs

1use super::{DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, OnBodyChunk, OnEos, OnFailure};
2use crate::classify::ClassifyEos;
3use http_body::{Body, Frame};
4use pin_project_lite::pin_project;
5use std::{
6    fmt,
7    pin::Pin,
8    task::{ready, Context, Poll},
9    time::Instant,
10};
11use tracing::Span;
12
13pin_project! {
14    /// Response body for [`Trace`].
15    ///
16    /// [`Trace`]: super::Trace
17    pub struct ResponseBody<B, C, OnBodyChunk = DefaultOnBodyChunk, OnEos = DefaultOnEos, OnFailure = DefaultOnFailure> {
18        #[pin]
19        pub(crate) inner: B,
20        pub(crate) classify_eos: Option<C>,
21        pub(crate) on_eos: Option<(OnEos, Instant)>,
22        pub(crate) on_body_chunk: OnBodyChunk,
23        pub(crate) on_failure: Option<OnFailure>,
24        pub(crate) start: Instant,
25        pub(crate) span: Span,
26    }
27}
28
29impl<B, C, OnBodyChunkT, OnEosT, OnFailureT> Body
30    for ResponseBody<B, C, OnBodyChunkT, OnEosT, OnFailureT>
31where
32    B: Body,
33    B::Error: fmt::Display + 'static,
34    C: ClassifyEos,
35    OnEosT: OnEos,
36    OnBodyChunkT: OnBodyChunk<B::Data>,
37    OnFailureT: OnFailure<C::FailureClass>,
38{
39    type Data = B::Data;
40    type Error = B::Error;
41
42    fn poll_frame(
43        self: Pin<&mut Self>,
44        cx: &mut Context<'_>,
45    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
46        let this = self.project();
47        let _guard = this.span.enter();
48        let result = ready!(this.inner.poll_frame(cx));
49
50        let latency = this.start.elapsed();
51        *this.start = Instant::now();
52
53        match result {
54            Some(Ok(frame)) => {
55                let frame = match frame.into_data() {
56                    Ok(chunk) => {
57                        this.on_body_chunk.on_body_chunk(&chunk, latency, this.span);
58                        Frame::data(chunk)
59                    }
60                    Err(frame) => frame,
61                };
62
63                let frame = match frame.into_trailers() {
64                    Ok(trailers) => {
65                        if let Some((on_eos, stream_start)) = this.on_eos.take() {
66                            on_eos.on_eos(Some(&trailers), stream_start.elapsed(), this.span);
67                        }
68                        Frame::trailers(trailers)
69                    }
70                    Err(frame) => frame,
71                };
72
73                Poll::Ready(Some(Ok(frame)))
74            }
75            Some(Err(err)) => {
76                if let Some((classify_eos, mut on_failure)) =
77                    this.classify_eos.take().zip(this.on_failure.take())
78                {
79                    let failure_class = classify_eos.classify_error(&err);
80                    on_failure.on_failure(failure_class, latency, this.span);
81                }
82
83                Poll::Ready(Some(Err(err)))
84            }
85            None => {
86                if let Some((on_eos, stream_start)) = this.on_eos.take() {
87                    on_eos.on_eos(None, stream_start.elapsed(), this.span);
88                }
89
90                Poll::Ready(None)
91            }
92        }
93    }
94
95    fn is_end_stream(&self) -> bool {
96        self.inner.is_end_stream()
97    }
98
99    fn size_hint(&self) -> http_body::SizeHint {
100        self.inner.size_hint()
101    }
102}