yup_oauth2/
client.rs

1//! Module containing the HTTP client used for sending requests
2use std::time::Duration;
3
4use http::Uri;
5use hyper_util::client::legacy::{connect::Connect, Error as LegacyHyperError};
6#[cfg(all(feature = "aws-lc-rs", feature = "hyper-rustls", not(feature = "ring")))]
7use rustls::crypto::aws_lc_rs::default_provider as default_crypto_provider;
8#[cfg(all(feature = "ring", feature = "hyper-rustls"))]
9use rustls::crypto::ring::default_provider as default_crypto_provider;
10#[cfg(all(
11    feature = "hyper-rustls",
12    not(any(feature = "ring", feature = "aws-lc-rs"))
13))]
14compile_error!(
15    "The `hyper-rustls` feature requires either the `ring` or `aws-lc-rs` feature to be enabled"
16);
17use thiserror::Error as ThisError;
18
19use crate::Error;
20
21type HyperResponse = http::Response<hyper::body::Incoming>;
22pub(crate) type LegacyClient<C> = hyper_util::client::legacy::Client<C, String>;
23
24#[derive(Debug, ThisError)]
25/// Errors that can happen when a request is sent
26pub enum SendError {
27    /// Request could not complete before timeout elapsed
28    #[error("Request timed out")]
29    Timeout,
30    /// Wrapper for hyper errors
31    #[error("Hyper error: {0}")]
32    Hyper(#[source] LegacyHyperError),
33}
34
35/// A trait implemented for any hyper_util::client::legacy::Client as well as the DefaultHyperClient.
36pub trait HyperClientBuilder {
37    /// The hyper connector that the resulting hyper client will use.
38    type Connector: Connect + Clone + Send + Sync + 'static;
39
40    /// Sets duration after which a request times out
41    fn with_timeout(self, timeout: Duration) -> Self;
42
43    /// Create a hyper::Client
44    fn build_hyper_client(self) -> Result<HttpClient<Self::Connector>, Error>;
45}
46
47/// Client that can be configured that a request will timeout after a specified
48/// duration.
49#[derive(Clone)]
50pub struct HttpClient<C>
51where
52    C: Connect + Clone + Send + Sync + 'static,
53{
54    client: LegacyClient<C>,
55    timeout: Option<Duration>,
56}
57
58impl<C> HttpClient<C>
59where
60    C: Connect + Clone + Send + Sync + 'static,
61{
62    pub(crate) fn new(hyper_client: LegacyClient<C>, timeout: Option<Duration>) -> Self {
63        Self {
64            client: hyper_client,
65            timeout,
66        }
67    }
68
69    pub(crate) fn set_timeout(&mut self, timeout: Duration) {
70        self.timeout = Some(timeout);
71    }
72
73    /// Execute a get request with the underlying hyper client
74    #[doc(hidden)]
75    pub async fn get(&self, uri: Uri) -> Result<HyperResponse, hyper_util::client::legacy::Error> {
76        self.client.get(uri).await
77    }
78}
79
80impl<C> HyperClientBuilder for HttpClient<C>
81where
82    C: Connect + Clone + Send + Sync + 'static,
83{
84    type Connector = C;
85
86    fn with_timeout(mut self, timeout: Duration) -> Self {
87        self.set_timeout(timeout);
88        self
89    }
90
91    fn build_hyper_client(self) -> Result<HttpClient<Self::Connector>, Error> {
92        Ok(self)
93    }
94}
95
96impl<C> SendRequest for HttpClient<C>
97where
98    C: Connect + Clone + Send + Sync + 'static,
99{
100    async fn request(&self, payload: http::Request<String>) -> Result<HyperResponse, SendError> {
101        let future = self.client.request(payload);
102        match self.timeout {
103            Some(duration) => tokio::time::timeout(duration, future)
104                .await
105                .map_err(|_| SendError::Timeout)?,
106            None => future.await,
107        }
108        .map_err(SendError::Hyper)
109    }
110}
111
112pub(crate) trait SendRequest {
113    async fn request(&self, payload: http::Request<String>) -> Result<HyperResponse, SendError>;
114}
115
/// Builder for the default hyper client, whose HTTPS connector comes from the
/// enabled TLS feature (`hyper-rustls` takes precedence over `hyper-tls`).
#[cfg(any(feature = "hyper-rustls", feature = "hyper-tls"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "hyper-rustls", feature = "hyper-tls"))))]
#[derive(Default)]
pub struct DefaultHyperClientBuilder {
    /// Timeout forwarded to the built [`HttpClient`]; `None` by default.
    timeout: Option<Duration>,
}
123
#[cfg(any(feature = "hyper-rustls", feature = "hyper-tls"))]
impl DefaultHyperClientBuilder {
    /// Set the duration after which a request times out.
    ///
    /// Being an inherent method, this can be called without importing
    /// `HyperClientBuilder`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        let mut builder = self;
        builder.timeout = Some(timeout);
        builder
    }
}
132
#[cfg(any(feature = "hyper-rustls", feature = "hyper-tls"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "hyper-rustls", feature = "hyper-tls"))))]
impl HyperClientBuilder for DefaultHyperClientBuilder {
    // When both TLS features are enabled, the rustls connector is the one selected.
    #[cfg(feature = "hyper-rustls")]
    type Connector =
        hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>;
    #[cfg(all(not(feature = "hyper-rustls"), feature = "hyper-tls"))]
    type Connector = hyper_tls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>;

    /// Store `timeout` for the client produced by `build_hyper_client`.
    fn with_timeout(self, timeout: Duration) -> Self {
        let mut builder = self;
        builder.timeout = Some(timeout);
        builder
    }

    /// Build a hyper client on a Tokio executor over the feature-selected HTTPS connector.
    ///
    /// # Errors
    /// With `hyper-rustls`, propagates any error returned by
    /// `HttpsConnectorBuilder::with_provider_and_native_roots`.
    fn build_hyper_client(self) -> Result<HttpClient<Self::Connector>, Error> {
        #[cfg(feature = "hyper-rustls")]
        let connector = hyper_rustls::HttpsConnectorBuilder::new()
            .with_provider_and_native_roots(default_crypto_provider())?
            .https_or_http()
            .enable_http1()
            .enable_http2()
            .build();
        #[cfg(all(not(feature = "hyper-rustls"), feature = "hyper-tls"))]
        let connector = hyper_tls::HttpsConnector::new();

        let executor = hyper_util::rt::TokioExecutor::new();
        let hyper_client = hyper_util::client::legacy::Client::builder(executor)
            // Retain no idle connections per host in the pool.
            .pool_max_idle_per_host(0)
            .build::<_, String>(connector);
        Ok(HttpClient::new(hyper_client, self.timeout))
    }
}
166
167/// Intended for using an existing hyper client with `yup-oauth2`. Instantiate
168/// with [`CustomHyperClientBuilder::from`]
169pub struct CustomHyperClientBuilder<C>
170where
171    C: Connect + Clone + Send + Sync + 'static,
172{
173    client: HttpClient<C>,
174    timeout: Option<Duration>,
175}
176
177impl<C> From<LegacyClient<C>> for CustomHyperClientBuilder<C>
178where
179    C: Connect + Clone + Send + Sync + 'static,
180{
181    fn from(client: LegacyClient<C>) -> Self {
182        Self {
183            client: HttpClient::new(client, None),
184            timeout: None,
185        }
186    }
187}
188
189impl<C> HyperClientBuilder for CustomHyperClientBuilder<C>
190where
191    C: Connect + Clone + Send + Sync + 'static,
192{
193    type Connector = C;
194
195    fn with_timeout(mut self, timeout: Duration) -> Self {
196        self.timeout = Some(timeout);
197        self
198    }
199
200    fn build_hyper_client(self) -> Result<HttpClient<Self::Connector>, Error> {
201        Ok(self.client)
202    }
203}