use std::{
num::{NonZeroU16, NonZeroU64, NonZeroU8},
ops::ControlFlow,
sync::Arc,
time::{Duration, SystemTime},
};
use bytes::Bytes;
use futures_util::Stream;
use httpdate::parse_http_date;
use reqwest::{
header::{HeaderMap, HeaderValue, RETRY_AFTER},
Request,
};
use thiserror::Error as ThisError;
use tracing::{debug, info, instrument};
pub use reqwest::{header, Error as ReqwestError, Method, StatusCode};
pub use url::Url;
mod delay_request;
use delay_request::DelayRequest;
mod certificate;
pub use certificate::Certificate;
mod request_builder;
pub use request_builder::{Body, RequestBuilder, Response};
mod tls_version;
pub use tls_version::TLSVersion;
#[cfg(feature = "hickory-dns")]
mod resolver;
#[cfg(feature = "hickory-dns")]
use resolver::TrustDnsResolver;
#[cfg(feature = "json")]
pub use request_builder::JsonError;
/// Upper bound applied to any delay derived from `Retry-After` or
/// `x-ratelimit-reset` response headers.
const MAX_RETRY_DURATION: Duration = Duration::from_secs(120);
/// Maximum number of attempts (the first one included) made for one request.
const MAX_RETRY_COUNT: u8 = 3;
/// Delay used for 429/503 responses, and when a rate-limited response carries
/// no usable `x-ratelimit-reset` value.
const DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT: Duration = Duration::from_millis(200);
/// Delay used after timeout/connect errors and after 408/504 responses.
const RETRY_DURATION_FOR_TIMEOUT: Duration = Duration::from_millis(200);
// Only referenced when the `__tls` feature is enabled, hence the lint opt-out.
#[allow(dead_code)]
/// Lowest TLS version the client will accept, even if the caller asks for less.
const DEFAULT_MIN_TLS: TLSVersion = TLSVersion::TLS_1_2;
/// Errors produced by this module.
#[derive(Debug, ThisError)]
#[non_exhaustive]
pub enum Error {
/// A `reqwest` error that is not tied to a specific request, e.g. a failure
/// to build the underlying client.
#[error("Reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
/// A request failed; the payload records the method and URL involved.
// NOTE(review): boxed presumably to keep `Error` small — confirm.
#[error(transparent)]
Http(Box<HttpError>),
/// The response body could not be handled as JSON (only with the `json` feature).
#[cfg(feature = "json")]
#[error("Failed to parse http response body as Json: {0}")]
Json(#[from] JsonError),
}
/// A failed HTTP request together with the method and URL that were used.
#[derive(Debug, ThisError)]
#[error("could not {method} {url}: {err}")]
pub struct HttpError {
// HTTP method of the request that failed.
method: reqwest::Method,
// URL of the request that failed.
url: url::Url,
// Underlying `reqwest` error, exposed as the error source.
#[source]
err: reqwest::Error,
}
impl HttpError {
/// Returns `true` when the wrapped `reqwest` error reports itself as a
/// status error (delegates to `reqwest::Error::is_status`).
pub fn is_status(&self) -> bool {
self.err.is_status()
}
}
/// Shared state behind every [`Client`] handle.
#[derive(Debug)]
struct Inner {
// Plain `reqwest` client; used to create request builders and handed out
// by `Client::get_inner`.
client: reqwest::Client,
// Wrapper through which requests are actually sent; it is told which URLs
// to delay and for how long.
service: DelayRequest,
}
/// HTTP client handle. Cloning only clones the inner `Arc`, so all clones
/// share the same underlying state.
#[derive(Clone, Debug)]
pub struct Client(Arc<Inner>);
#[cfg_attr(not(feature = "__tls"), allow(unused_variables, unused_mut))]
impl Client {
/// Creates a new HTTPS-only client.
///
/// * `user_agent` - value used for the `User-Agent` header.
/// * `min_tls` - requested minimum TLS version; it is raised to the module's
///   default minimum when lower or absent (only applied with the `__tls`
///   feature).
/// * `per_millis`, `num_request` - forwarded to the request-delaying service
///   as a duration in milliseconds and a request count.
/// * `certificates` - additional root certificates to trust (only applied
///   with the `__tls` feature).
///
/// # Errors
///
/// Returns [`Error::Reqwest`] if the underlying `reqwest` client cannot be
/// built.
pub fn new(
    user_agent: impl AsRef<str>,
    min_tls: Option<TLSVersion>,
    per_millis: NonZeroU16,
    num_request: NonZeroU64,
    certificates: impl IntoIterator<Item = Certificate>,
) -> Result<Self, Error> {
    // Non-generic worker; the generic wrapper below only converts its
    // arguments (presumably to limit monomorphization).
    fn build_client(
        user_agent: &str,
        min_tls: Option<TLSVersion>,
        per_millis: NonZeroU16,
        num_request: NonZeroU64,
        certificates: &mut dyn Iterator<Item = Certificate>,
    ) -> Result<Client, Error> {
        let mut builder = reqwest::ClientBuilder::new()
            .user_agent(user_agent)
            .https_only(true)
            .tcp_nodelay(false);

        #[cfg(feature = "hickory-dns")]
        {
            let resolver = Arc::new(TrustDnsResolver::default());
            builder = builder.dns_resolver(resolver);
        }

        #[cfg(feature = "__tls")]
        {
            // Never go below the module-wide minimum, whatever was requested.
            let floor = min_tls.map_or(DEFAULT_MIN_TLS, |requested| {
                requested.max(DEFAULT_MIN_TLS)
            });
            builder = builder.min_tls_version(floor.into());

            for extra_root in certificates {
                builder = builder.add_root_certificate(extra_root.0);
            }
        }

        let client = builder.build()?;
        let interval = Duration::from_millis(u64::from(per_millis.get()));
        let service = DelayRequest::new(num_request, interval, client.clone());

        Ok(Client(Arc::new(Inner { client, service })))
    }

    let mut certificates = certificates.into_iter();
    build_client(
        user_agent.as_ref(),
        min_tls,
        per_millis,
        num_request,
        &mut certificates,
    )
}
/// Returns a reference to the underlying `reqwest::Client`.
// NOTE(review): requests issued directly on the returned client presumably
// bypass the delay/retry handling done by this wrapper — confirm.
pub fn get_inner(&self) -> &reqwest::Client {
&self.0.client
}
/// Sends `request` once and classifies the outcome for the retry loop.
///
/// * `Ok(ControlFlow::Break(response))` - final response, no retry wanted.
/// * `Ok(ControlFlow::Continue(Ok(response)))` - the response asked for
///   back-off (via `Retry-After`, an exhausted `x-ratelimit-remaining`, or a
///   429/503/408/504 status); the URLs involved have been registered for
///   delay and the caller may retry.
/// * `Ok(ControlFlow::Continue(Err(err)))` - timeout or connect error; `url`
///   has been registered for delay and the caller may retry.
/// * `Err(err)` - any other send error.
#[instrument(
    skip(self, url),
    fields(
        url = format_args!("{url}"),
    ),
)]
async fn do_send_request(
    &self,
    request: Request,
    url: &Url,
) -> Result<ControlFlow<reqwest::Response, Result<reqwest::Response, ReqwestError>>, ReqwestError>
{
    static NO_REQUESTS_LEFT: HeaderValue = HeaderValue::from_static("0");

    let response = match self.0.service.call(request).await {
        Ok(response) => response,
        Err(err) if err.is_timeout() || err.is_connect() => {
            let duration = RETRY_DURATION_FOR_TIMEOUT;
            info!("Received timeout error from reqwest. Delay future request by {duration:#?}");
            self.0.service.add_urls_to_delay(&[url], duration);
            return Ok(ControlFlow::Continue(Err(err)));
        }
        Err(err) => return Err(err),
    };

    let status = response.status();
    let headers = response.headers();

    // Work out how long to back off, if at all. Header hints take precedence
    // over the status code and are honoured whatever the status is.
    let backoff = if let Some(requested) = parse_header_retry_after(headers) {
        Some(requested.min(MAX_RETRY_DURATION))
    } else if headers.get("x-ratelimit-remaining") == Some(&NO_REQUESTS_LEFT) {
        // NOTE(review): the reset value is interpreted as a number of seconds
        // to wait; a server sending an absolute UNIX timestamp here would
        // always be clamped to `MAX_RETRY_DURATION` — confirm this is intended.
        let reset = headers
            .get("x-ratelimit-reset")
            .and_then(|value| value.to_str().ok())
            .and_then(|text| text.parse::<u64>().ok())
            .map(Duration::from_secs);

        Some(
            reset
                .unwrap_or(DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT)
                .min(MAX_RETRY_DURATION),
        )
    } else {
        match status {
            StatusCode::SERVICE_UNAVAILABLE | StatusCode::TOO_MANY_REQUESTS => {
                Some(DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT)
            }
            StatusCode::REQUEST_TIMEOUT | StatusCode::GATEWAY_TIMEOUT => {
                Some(RETRY_DURATION_FOR_TIMEOUT)
            }
            _ => None,
        }
    };

    match backoff {
        Some(duration) => {
            info!("Received status code {status}, will wait for {duration:#?} and retry");
            // Delay both the requested URL and the URL that actually answered.
            self.0
                .service
                .add_urls_to_delay(&[url, response.url()], duration);
            Ok(ControlFlow::Continue(Ok(response)))
        }
        None => Ok(ControlFlow::Break(response)),
    }
}
/// Sends `request`, retrying while `do_send_request` asks to continue, for at
/// most `MAX_RETRY_COUNT` attempts in total.
///
/// Returns the first final response, or the outcome of the last attempt once
/// the attempt limit is reached.
///
/// # Errors
///
/// Propagates any error `do_send_request` reports as non-retryable, or the
/// last retryable error when the attempts are exhausted.
///
/// # Panics
///
/// Panics if `request` cannot be cloned (`Request::try_clone` returns `None`).
async fn send_request_inner(
    &self,
    request: &Request,
) -> Result<reqwest::Response, ReqwestError> {
    let attempt_limit = NonZeroU8::new(MAX_RETRY_COUNT).unwrap().get();
    let mut attempts = 0;

    // The limit is non-zero, so at least one attempt is always made.
    loop {
        // Count the attempt before deciding whether it was the last one.
        attempts += 1;

        let attempt = request.try_clone().unwrap();
        match self.do_send_request(attempt, request.url()).await? {
            ControlFlow::Break(response) => return Ok(response),
            ControlFlow::Continue(last_outcome) => {
                if attempts >= attempt_limit {
                    return last_outcome;
                }
            }
        }
    }
}
async fn send_request(
&self,
request: Request,
error_for_status: bool,
) -> Result<reqwest::Response, Error> {
debug!("Downloading from: '{}'", request.url());
self.send_request_inner(&request)
.await
.and_then(|response| {
if error_for_status {
response.error_for_status()
} else {
Ok(response)
}
})
.map_err(|err| {
Error::Http(Box::new(HttpError {
method: request.method().clone(),
url: request.url().clone(),
err,
}))
})
}
async fn head_or_fallback_to_get(
&self,
url: Url,
error_for_status: bool,
) -> Result<reqwest::Response, Error> {
let res = self
.send_request(Request::new(Method::HEAD, url.clone()), error_for_status)
.await;
let retry_with_get = move || async move {
info!("HEAD on {url} is not allowed, fallback to GET");
self.send_request(Request::new(Method::GET, url), error_for_status)
.await
};
let is_retryable = |status| {
matches!(
status,
StatusCode::BAD_REQUEST | StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN | StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED | StatusCode::GONE )
};
match res {
Err(Error::Http(http_error))
if http_error.err.status().map(is_retryable).unwrap_or(false) =>
{
retry_with_get().await
}
Ok(response) if is_retryable(response.status()) => retry_with_get().await,
res => res,
}
}
/// Checks whether `url` can be fetched: sends a `GET` and reports whether the
/// response status is a success.
///
/// # Errors
///
/// Returns an error if sending the request fails; a non-success status is
/// reported as `Ok(false)`, not as an error.
pub async fn remote_gettable(&self, url: Url) -> Result<bool, Error> {
    let response = self.get(url).send(false).await?;
    Ok(response.status().is_success())
}
/// Resolves `url` to the URL of the final response, using `HEAD` and falling
/// back to `GET` where needed.
///
/// # Errors
///
/// Returns an error if the request fails or the final status is an error
/// status.
pub async fn get_redirected_final_url(&self, url: Url) -> Result<Url, Error> {
    let response = self.head_or_fallback_to_get(url, true).await?;
    Ok(response.url().clone())
}
/// Sends a `GET` to `url` and returns the response body as a stream of byte
/// chunks.
///
/// # Errors
///
/// Returns an error if the request fails or the response has an error status.
pub async fn get_stream(
    &self,
    url: Url,
) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
    let response = self.get(url).send(true).await?;
    Ok(response.bytes_stream())
}
/// Starts building a request with the given `method` and `url`.
///
/// The returned builder keeps its own handle to this client.
pub fn request(&self, method: Method, url: Url) -> RequestBuilder {
    let inner = self.0.client.request(method, url);
    let client = self.clone();

    RequestBuilder { client, inner }
}
/// Starts building a `GET` request to `url`; shorthand for
/// `request(Method::GET, url)`.
pub fn get(&self, url: Url) -> RequestBuilder {
self.request(Method::GET, url)
}
/// Starts building a `POST` request to `url` carrying `body`.
pub fn post(&self, url: Url, body: impl Into<Body>) -> RequestBuilder {
    let builder = self.request(Method::POST, url);
    builder.body(body.into())
}
}
/// Extracts the delay requested by the `Retry-After` header, if present.
///
/// Two forms are accepted: a non-negative integer number of seconds, or an
/// HTTP date. A date that is already in the past yields `Duration::ZERO`.
/// When the header occurs several times, the last value is used.
///
/// Returns `None` when the header is absent, cannot be read as a string, or
/// parses as neither form.
fn parse_header_retry_after(headers: &HeaderMap) -> Option<Duration> {
    let header = headers
        .get_all(RETRY_AFTER)
        .into_iter()
        // Only the final value is needed; take it from the back rather than
        // walking the whole iterator with `last()`.
        .next_back()?
        .to_str()
        .ok()?;

    match header.parse::<u64>() {
        Ok(secs) => Some(Duration::from_secs(secs)),
        Err(_) => {
            let retry_at = parse_http_date(header).ok()?;

            // Compare against the current time directly instead of going
            // through UNIX-epoch offsets: this cannot panic when the system
            // clock is set before the epoch. `duration_since` fails when
            // `retry_at` is already in the past, which means "retry now".
            Some(
                retry_at
                    .duration_since(SystemTime::now())
                    .unwrap_or(Duration::ZERO),
            )
        }
    }
}