tower_http/trace/
future.rs

1use super::{
2    DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnResponse, OnBodyChunk, OnEos,
3    OnFailure, OnResponse, ResponseBody,
4};
5use crate::classify::{ClassifiedResponse, ClassifyResponse};
6use http::Response;
7use http_body::Body;
8use pin_project_lite::pin_project;
9use std::{
10    future::Future,
11    pin::Pin,
12    task::{ready, Context, Poll},
13    time::Instant,
14};
15use tracing::Span;
16
17pin_project! {
18    /// Response future for [`Trace`].
19    ///
20    /// [`Trace`]: super::Trace
21    pub struct ResponseFuture<F, C, OnResponse = DefaultOnResponse, OnBodyChunk = DefaultOnBodyChunk, OnEos = DefaultOnEos, OnFailure = DefaultOnFailure> {
22        #[pin]
23        pub(crate) inner: F,
24        pub(crate) span: Span,
25        pub(crate) classifier: Option<C>,
26        pub(crate) on_response: Option<OnResponse>,
27        pub(crate) on_body_chunk: Option<OnBodyChunk>,
28        pub(crate) on_eos: Option<OnEos>,
29        pub(crate) on_failure: Option<OnFailure>,
30        pub(crate) start: Instant,
31    }
32}
33
34impl<Fut, ResBody, E, C, OnResponseT, OnBodyChunkT, OnEosT, OnFailureT> Future
35    for ResponseFuture<Fut, C, OnResponseT, OnBodyChunkT, OnEosT, OnFailureT>
36where
37    Fut: Future<Output = Result<Response<ResBody>, E>>,
38    ResBody: Body,
39    ResBody::Error: std::fmt::Display + 'static,
40    E: std::fmt::Display + 'static,
41    C: ClassifyResponse,
42    OnResponseT: OnResponse<ResBody>,
43    OnFailureT: OnFailure<C::FailureClass>,
44    OnBodyChunkT: OnBodyChunk<ResBody::Data>,
45    OnEosT: OnEos,
46{
47    type Output = Result<
48        Response<ResponseBody<ResBody, C::ClassifyEos, OnBodyChunkT, OnEosT, OnFailureT>>,
49        E,
50    >;
51
52    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
53        let this = self.project();
54        let _guard = this.span.enter();
55        let result = ready!(this.inner.poll(cx));
56        let latency = this.start.elapsed();
57
58        let classifier = this.classifier.take().unwrap();
59        let on_eos = this.on_eos.take();
60        let on_body_chunk = this.on_body_chunk.take().unwrap();
61        let mut on_failure = this.on_failure.take().unwrap();
62
63        match result {
64            Ok(res) => {
65                let classification = classifier.classify_response(&res);
66                let start = *this.start;
67
68                this.on_response
69                    .take()
70                    .unwrap()
71                    .on_response(&res, latency, this.span);
72
73                match classification {
74                    ClassifiedResponse::Ready(classification) => {
75                        if let Err(failure_class) = classification {
76                            on_failure.on_failure(failure_class, latency, this.span);
77                        }
78
79                        let span = this.span.clone();
80                        let res = res.map(|body| ResponseBody {
81                            inner: body,
82                            classify_eos: None,
83                            on_eos: None,
84                            on_body_chunk,
85                            on_failure: Some(on_failure),
86                            start,
87                            span,
88                        });
89
90                        Poll::Ready(Ok(res))
91                    }
92                    ClassifiedResponse::RequiresEos(classify_eos) => {
93                        let span = this.span.clone();
94                        let res = res.map(|body| ResponseBody {
95                            inner: body,
96                            classify_eos: Some(classify_eos),
97                            on_eos: on_eos.zip(Some(Instant::now())),
98                            on_body_chunk,
99                            on_failure: Some(on_failure),
100                            start,
101                            span,
102                        });
103
104                        Poll::Ready(Ok(res))
105                    }
106                }
107            }
108            Err(err) => {
109                let failure_class = classifier.classify_error(&err);
110                on_failure.on_failure(failure_class, latency, this.span);
111
112                Poll::Ready(Err(err))
113            }
114        }
115    }
116}