1use super::{DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, OnBodyChunk, OnEos, OnFailure};
2use crate::classify::ClassifyEos;
3use http_body::{Body, Frame};
4use pin_project_lite::pin_project;
5use std::{
6 fmt,
7 pin::Pin,
8 task::{ready, Context, Poll},
9 time::Instant,
10};
11use tracing::Span;
12
13pin_project! {
14 pub struct ResponseBody<B, C, OnBodyChunk = DefaultOnBodyChunk, OnEos = DefaultOnEos, OnFailure = DefaultOnFailure> {
18 #[pin]
19 pub(crate) inner: B,
20 pub(crate) classify_eos: Option<C>,
21 pub(crate) on_eos: Option<(OnEos, Instant)>,
22 pub(crate) on_body_chunk: OnBodyChunk,
23 pub(crate) on_failure: Option<OnFailure>,
24 pub(crate) start: Instant,
25 pub(crate) span: Span,
26 }
27}
28
29impl<B, C, OnBodyChunkT, OnEosT, OnFailureT> Body
30 for ResponseBody<B, C, OnBodyChunkT, OnEosT, OnFailureT>
31where
32 B: Body,
33 B::Error: fmt::Display + 'static,
34 C: ClassifyEos,
35 OnEosT: OnEos,
36 OnBodyChunkT: OnBodyChunk<B::Data>,
37 OnFailureT: OnFailure<C::FailureClass>,
38{
39 type Data = B::Data;
40 type Error = B::Error;
41
42 fn poll_frame(
43 self: Pin<&mut Self>,
44 cx: &mut Context<'_>,
45 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
46 let this = self.project();
47 let _guard = this.span.enter();
48 let result = ready!(this.inner.poll_frame(cx));
49
50 let latency = this.start.elapsed();
51 *this.start = Instant::now();
52
53 match result {
54 Some(Ok(frame)) => {
55 let frame = match frame.into_data() {
56 Ok(chunk) => {
57 this.on_body_chunk.on_body_chunk(&chunk, latency, this.span);
58 Frame::data(chunk)
59 }
60 Err(frame) => frame,
61 };
62
63 let frame = match frame.into_trailers() {
64 Ok(trailers) => {
65 if let Some((on_eos, stream_start)) = this.on_eos.take() {
66 on_eos.on_eos(Some(&trailers), stream_start.elapsed(), this.span);
67 }
68 Frame::trailers(trailers)
69 }
70 Err(frame) => frame,
71 };
72
73 Poll::Ready(Some(Ok(frame)))
74 }
75 Some(Err(err)) => {
76 if let Some((classify_eos, mut on_failure)) =
77 this.classify_eos.take().zip(this.on_failure.take())
78 {
79 let failure_class = classify_eos.classify_error(&err);
80 on_failure.on_failure(failure_class, latency, this.span);
81 }
82
83 Poll::Ready(Some(Err(err)))
84 }
85 None => {
86 if let Some((on_eos, stream_start)) = this.on_eos.take() {
87 on_eos.on_eos(None, stream_start.elapsed(), this.span);
88 }
89
90 Poll::Ready(None)
91 }
92 }
93 }
94
95 fn is_end_stream(&self) -> bool {
96 self.inner.is_end_stream()
97 }
98
99 fn size_hint(&self) -> http_body::SizeHint {
100 self.inner.size_hint()
101 }
102}