opentelemetry_http/
lib.rs

1use async_trait::async_trait;
2use std::fmt::Debug;
3
4#[doc(no_inline)]
5pub use bytes::Bytes;
6#[doc(no_inline)]
7pub use http::{Request, Response};
8use opentelemetry::propagation::{Extractor, Injector};
9
10/// Helper for injecting headers into HTTP Requests. This is used for OpenTelemetry context
11/// propagation over HTTP.
12/// See [this](https://github.com/open-telemetry/opentelemetry-rust/blob/main/examples/tracing-http-propagator/README.md)
13/// for example usage.
14pub struct HeaderInjector<'a>(pub &'a mut http::HeaderMap);
15
16impl Injector for HeaderInjector<'_> {
17    /// Set a key and value in the HeaderMap.  Does nothing if the key or value are not valid inputs.
18    fn set(&mut self, key: &str, value: String) {
19        if let Ok(name) = http::header::HeaderName::from_bytes(key.as_bytes()) {
20            if let Ok(val) = http::header::HeaderValue::from_str(&value) {
21                self.0.insert(name, val);
22            }
23        }
24    }
25}
26
27/// Helper for extracting headers from HTTP Requests. This is used for OpenTelemetry context
28/// propagation over HTTP.
29/// See [this](https://github.com/open-telemetry/opentelemetry-rust/blob/main/examples/tracing-http-propagator/README.md)
30/// for example usage.
31pub struct HeaderExtractor<'a>(pub &'a http::HeaderMap);
32
33impl Extractor for HeaderExtractor<'_> {
34    /// Get a value for a key from the HeaderMap.  If the value is not valid ASCII, returns None.
35    fn get(&self, key: &str) -> Option<&str> {
36        self.0.get(key).and_then(|value| value.to_str().ok())
37    }
38
39    /// Collect all the keys from the HeaderMap.
40    fn keys(&self) -> Vec<&str> {
41        self.0
42            .keys()
43            .map(|value| value.as_str())
44            .collect::<Vec<_>>()
45    }
46}
47
/// Boxed, thread-safe error type returned by [`HttpClient`] operations.
///
/// Any `std::error::Error + Send + Sync + 'static` (as well as `&str` / `String` messages)
/// converts into it, so implementations can propagate failures with `?`.
pub type HttpError = Box<dyn std::error::Error + Send + Sync + 'static>;
49
50/// A minimal interface necessary for sending requests over HTTP.
51/// Used primarily for exporting telemetry over HTTP. Also used for fetching
52/// sampling strategies for JaegerRemoteSampler
53///
54/// Users sometime choose HTTP clients that relay on a certain async runtime. This trait allows
55/// users to bring their choice of HTTP client.
56#[async_trait]
57pub trait HttpClient: Debug + Send + Sync {
58    /// Send the specified HTTP request with `Vec<u8>` payload
59    ///
60    /// Returns the HTTP response including the status code and body.
61    ///
62    /// Returns an error if it can't connect to the server or the request could not be completed,
63    /// e.g. because of a timeout, infinite redirects, or a loss of connection.
64    #[deprecated(note = "Use `send_bytes` with `Bytes` payload instead.")]
65    async fn send(&self, request: Request<Vec<u8>>) -> Result<Response<Bytes>, HttpError> {
66        self.send_bytes(request.map(Into::into)).await
67    }
68
69    /// Send the specified HTTP request with `Bytes` payload.
70    ///
71    /// Returns the HTTP response including the status code and body.
72    ///
73    /// Returns an error if it can't connect to the server or the request could not be completed,
74    /// e.g. because of a timeout, infinite redirects, or a loss of connection.
75    async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError>;
76}
77
#[cfg(feature = "reqwest")]
mod reqwest {
    use opentelemetry::otel_debug;

    use super::{async_trait, Bytes, HttpClient, HttpError, Request, Response};

    /// Asynchronous `reqwest` client.
    ///
    /// Converts the `http` request into a `reqwest` request, rejects non-success statuses via
    /// `reqwest`'s `error_for_status`, then buffers the whole response body into [`Bytes`].
    #[async_trait]
    impl HttpClient for reqwest::Client {
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "ReqwestClient.Send");
            let outgoing = request.try_into()?;
            let mut incoming = self.execute(outgoing).await?.error_for_status()?;

            let status = incoming.status();
            let header_map = std::mem::take(incoming.headers_mut());
            let payload = incoming.bytes().await?;

            let mut converted = Response::builder().status(status).body(payload)?;
            *converted.headers_mut() = header_map;
            Ok(converted)
        }
    }

    /// Blocking `reqwest` client (not available on `wasm32`).
    ///
    /// The request is executed synchronously inside the `async fn`, so polling the returned
    /// future blocks the current thread until the response body has been read.
    #[cfg(not(target_arch = "wasm32"))]
    #[async_trait]
    impl HttpClient for reqwest::blocking::Client {
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "ReqwestBlockingClient.Send");
            let outgoing = request.try_into()?;
            let mut incoming = self.execute(outgoing)?.error_for_status()?;

            let status = incoming.status();
            let header_map = std::mem::take(incoming.headers_mut());
            let payload = incoming.bytes()?;

            let mut converted = Response::builder().status(status).body(payload)?;
            *converted.headers_mut() = header_map;
            Ok(converted)
        }
    }
}
117
#[cfg(feature = "hyper")]
pub mod hyper {
    use super::{async_trait, Bytes, HttpClient, HttpError, Request, Response};
    use crate::ResponseExt;
    use http::HeaderValue;
    use http_body_util::{BodyExt, Full};
    use hyper::body::{Body as HttpBody, Frame};
    use hyper_util::client::legacy::{
        connect::{Connect, HttpConnector},
        Client,
    };
    use opentelemetry::otel_debug;
    use std::fmt::Debug;
    use std::pin::Pin;
    use std::task::{self, Poll};
    use std::time::Duration;
    use tokio::time;

    /// An [`HttpClient`] backed by a `hyper_util` legacy [`Client`].
    ///
    /// Each request (including reading the response body) is bounded by `timeout`. When
    /// `authorization` is set, it is sent as the `Authorization` header on every request,
    /// replacing any value the request already carried.
    #[derive(Debug, Clone)]
    pub struct HyperClient<C = HttpConnector>
    where
        C: Connect + Clone + Send + Sync + 'static,
    {
        inner: Client<C, Body>,
        timeout: Duration,
        authorization: Option<HeaderValue>,
    }

    impl<C> HyperClient<C>
    where
        C: Connect + Clone + Send + Sync + 'static,
    {
        /// Creates a new `HyperClient` that sends requests through `connector`.
        ///
        /// Background connection tasks are spawned on `hyper_util`'s `TokioExecutor`.
        pub fn new(connector: C, timeout: Duration, authorization: Option<HeaderValue>) -> Self {
            // TODO - support custom executor
            let inner = Client::builder(hyper_util::rt::TokioExecutor::new()).build(connector);
            Self {
                inner,
                timeout,
                authorization,
            }
        }
    }

    impl HyperClient<HttpConnector> {
        /// Creates a new `HyperClient` with a default `HttpConnector`.
        pub fn with_default_connector(
            timeout: Duration,
            authorization: Option<HeaderValue>,
        ) -> Self {
            Self::new(HttpConnector::new(), timeout, authorization)
        }
    }

    // Implemented for every connector (not only the default `HttpConnector`) so that clients
    // built via `HyperClient::new` with a custom connector can be used as an `HttpClient`.
    #[async_trait]
    impl<C> HttpClient for HyperClient<C>
    where
        C: Connect + Clone + Send + Sync + 'static,
        HyperClient<C>: Debug,
    {
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "HyperClient.Send");
            let (parts, body) = request.into_parts();
            let mut request = Request::from_parts(parts, Body(Full::from(body)));
            if let Some(ref authorization) = self.authorization {
                request
                    .headers_mut()
                    .insert(http::header::AUTHORIZATION, authorization.clone());
            }

            // Bound the whole exchange -- response head *and* body -- by `self.timeout`, so a
            // server that stalls after sending headers cannot hang the caller indefinitely.
            let exchange = async move {
                let mut response = self.inner.request(request).await?;
                let headers = std::mem::take(response.headers_mut());
                let status = response.status();
                let body = response.into_body().collect().await?.to_bytes();

                let mut http_response = Response::builder().status(status).body(body)?;
                *http_response.headers_mut() = headers;
                Ok::<_, HttpError>(http_response)
            };
            let http_response = time::timeout(self.timeout, exchange).await??;

            Ok(http_response.error_for_status()?)
        }
    }

    /// Request body type used by [`HyperClient`]: a single, fully buffered chunk of [`Bytes`].
    pub struct Body(Full<Bytes>);

    impl HttpBody for Body {
        type Data = Bytes;
        type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

        #[inline]
        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            // `Full<Bytes>` is `Unpin` (the legacy client already requires `Body: Unpin`), so
            // the inner body can be re-pinned safely without an `unsafe` projection.
            Pin::new(&mut self.get_mut().0)
                .poll_frame(cx)
                .map_err(Into::into)
        }

        #[inline]
        fn is_end_stream(&self) -> bool {
            self.0.is_end_stream()
        }

        #[inline]
        fn size_hint(&self) -> hyper::body::SizeHint {
            self.0.size_hint()
        }
    }
}
220
221/// Methods to make working with responses from the [`HttpClient`] trait easier.
222pub trait ResponseExt: Sized {
223    /// Turn a response into an error if the HTTP status does not indicate success (200 - 299).
224    fn error_for_status(self) -> Result<Self, HttpError>;
225}
226
227impl<T> ResponseExt for Response<T> {
228    fn error_for_status(self) -> Result<Self, HttpError> {
229        if self.status().is_success() {
230            Ok(self)
231        } else {
232            Err(format!("request failed with status {}", self.status()).into())
233        }
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn http_headers_get() {
        let mut carrier = http::HeaderMap::new();
        HeaderInjector(&mut carrier).set("headerName", "value".to_string());

        assert_eq!(
            HeaderExtractor(&carrier).get("HEADERNAME"),
            Some("value"),
            "case insensitive extraction"
        )
    }

    #[test]
    fn http_headers_keys() {
        let mut carrier = http::HeaderMap::new();
        HeaderInjector(&mut carrier).set("headerName1", "value1".to_string());
        HeaderInjector(&mut carrier).set("headerName2", "value2".to_string());

        let extractor = HeaderExtractor(&carrier);
        let got = extractor.keys();
        assert_eq!(got.len(), 2);
        assert!(got.contains(&"headername1"));
        assert!(got.contains(&"headername2"));
    }

    #[test]
    fn http_headers_set_ignores_invalid_input() {
        let mut carrier = http::HeaderMap::new();
        // A space is not a valid header-name character.
        HeaderInjector(&mut carrier).set("invalid name", "value".to_string());
        // A newline is not a valid header-value character.
        HeaderInjector(&mut carrier).set("valid-name", "bad\nvalue".to_string());

        assert!(carrier.is_empty(), "invalid keys/values must be skipped");
    }

    #[test]
    fn http_headers_set_replaces_existing_value() {
        let mut carrier = http::HeaderMap::new();
        let mut injector = HeaderInjector(&mut carrier);
        injector.set("traceparent", "first".to_string());
        injector.set("traceparent", "second".to_string());

        assert_eq!(carrier.get_all("traceparent").iter().count(), 1);
        assert_eq!(HeaderExtractor(&carrier).get("traceparent"), Some("second"));
    }

    #[test]
    fn http_headers_get_non_ascii_value_is_none() {
        let mut carrier = http::HeaderMap::new();
        carrier.insert(
            "x-binary",
            http::HeaderValue::from_bytes(&[0xff]).expect("obs-text bytes are valid header values"),
        );

        assert_eq!(HeaderExtractor(&carrier).get("x-binary"), None);
    }
}