tower_http/trace/
future.rs1use super::{
2 DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnResponse, OnBodyChunk, OnEos,
3 OnFailure, OnResponse, ResponseBody,
4};
5use crate::classify::{ClassifiedResponse, ClassifyResponse};
6use http::Response;
7use http_body::Body;
8use pin_project_lite::pin_project;
9use std::{
10 future::Future,
11 pin::Pin,
12 task::{ready, Context, Poll},
13 time::Instant,
14};
15use tracing::Span;
16
17pin_project! {
18 pub struct ResponseFuture<F, C, OnResponse = DefaultOnResponse, OnBodyChunk = DefaultOnBodyChunk, OnEos = DefaultOnEos, OnFailure = DefaultOnFailure> {
22 #[pin]
23 pub(crate) inner: F,
24 pub(crate) span: Span,
25 pub(crate) classifier: Option<C>,
26 pub(crate) on_response: Option<OnResponse>,
27 pub(crate) on_body_chunk: Option<OnBodyChunk>,
28 pub(crate) on_eos: Option<OnEos>,
29 pub(crate) on_failure: Option<OnFailure>,
30 pub(crate) start: Instant,
31 }
32}
33
34impl<Fut, ResBody, E, C, OnResponseT, OnBodyChunkT, OnEosT, OnFailureT> Future
35 for ResponseFuture<Fut, C, OnResponseT, OnBodyChunkT, OnEosT, OnFailureT>
36where
37 Fut: Future<Output = Result<Response<ResBody>, E>>,
38 ResBody: Body,
39 ResBody::Error: std::fmt::Display + 'static,
40 E: std::fmt::Display + 'static,
41 C: ClassifyResponse,
42 OnResponseT: OnResponse<ResBody>,
43 OnFailureT: OnFailure<C::FailureClass>,
44 OnBodyChunkT: OnBodyChunk<ResBody::Data>,
45 OnEosT: OnEos,
46{
47 type Output = Result<
48 Response<ResponseBody<ResBody, C::ClassifyEos, OnBodyChunkT, OnEosT, OnFailureT>>,
49 E,
50 >;
51
52 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
53 let this = self.project();
54 let _guard = this.span.enter();
55 let result = ready!(this.inner.poll(cx));
56 let latency = this.start.elapsed();
57
58 let classifier = this.classifier.take().unwrap();
59 let on_eos = this.on_eos.take();
60 let on_body_chunk = this.on_body_chunk.take().unwrap();
61 let mut on_failure = this.on_failure.take().unwrap();
62
63 match result {
64 Ok(res) => {
65 let classification = classifier.classify_response(&res);
66 let start = *this.start;
67
68 this.on_response
69 .take()
70 .unwrap()
71 .on_response(&res, latency, this.span);
72
73 match classification {
74 ClassifiedResponse::Ready(classification) => {
75 if let Err(failure_class) = classification {
76 on_failure.on_failure(failure_class, latency, this.span);
77 }
78
79 let span = this.span.clone();
80 let res = res.map(|body| ResponseBody {
81 inner: body,
82 classify_eos: None,
83 on_eos: None,
84 on_body_chunk,
85 on_failure: Some(on_failure),
86 start,
87 span,
88 });
89
90 Poll::Ready(Ok(res))
91 }
92 ClassifiedResponse::RequiresEos(classify_eos) => {
93 let span = this.span.clone();
94 let res = res.map(|body| ResponseBody {
95 inner: body,
96 classify_eos: Some(classify_eos),
97 on_eos: on_eos.zip(Some(Instant::now())),
98 on_body_chunk,
99 on_failure: Some(on_failure),
100 start,
101 span,
102 });
103
104 Poll::Ready(Ok(res))
105 }
106 }
107 }
108 Err(err) => {
109 let failure_class = classifier.classify_error(&err);
110 on_failure.on_failure(failure_class, latency, this.span);
111
112 Poll::Ready(Err(err))
113 }
114 }
115 }
116}