use dragonfly_client_core::{
error::{ErrorType, OrErr},
Error, Result,
};
use futures::TryStreamExt;
use rustls_pki_types::CertificateDer;
use std::io::{Error as IOError, ErrorKind};
use tokio_util::io::StreamReader;
use tracing::error;
/// Backend that talks to origin servers over HTTP(S).
///
/// The type carries no state. `Debug` and `Clone` are derived so the backend
/// can be embedded in types that derive those traits and can appear in
/// diagnostic output.
#[derive(Debug, Clone)]
pub struct HTTP;
impl HTTP {
pub fn new() -> HTTP {
Self {}
}
fn client(
&self,
client_certs: Option<Vec<CertificateDer<'static>>>,
) -> Result<reqwest::Client> {
match client_certs {
Some(client_certs) => {
let mut root_cert_store = rustls::RootCertStore::empty();
root_cert_store.add_parsable_certificates(&client_certs);
let client_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store)
.with_no_client_auth();
let client = reqwest::Client::builder()
.use_preconfigured_tls(client_config)
.build()
.or_err(ErrorType::HTTPError)?;
Ok(client)
}
None => {
let client = reqwest::Client::builder()
.use_native_tls()
.build()
.or_err(ErrorType::HTTPError)?;
Ok(client)
}
}
}
}
#[tonic::async_trait]
impl crate::Backend for HTTP {
async fn head(&self, request: crate::HeadRequest) -> Result<crate::HeadResponse> {
let header = request.http_header.ok_or(Error::InvalidParameter)?;
let response = self
.client(request.client_certs)?
.get(&request.url)
.headers(header)
.timeout(request.timeout)
.send()
.await
.or_err(ErrorType::HTTPError)
.map_err(|err| {
error!("head request failed: {}", err);
err
})?;
let header = response.headers().clone();
let status_code = response.status();
Ok(crate::HeadResponse {
http_header: Some(header),
http_status_code: Some(status_code),
})
}
async fn get(&self, request: crate::GetRequest) -> Result<crate::GetResponse<crate::Body>> {
let header = request.http_header.ok_or(Error::InvalidParameter)?;
let response = self
.client(request.client_certs)?
.get(&request.url)
.headers(header)
.timeout(request.timeout)
.send()
.await
.or_err(ErrorType::HTTPError)
.map_err(|err| {
error!("get request failed: {}", err);
err
})?;
let header = response.headers().clone();
let status_code = response.status();
let reader = Box::new(StreamReader::new(
response
.bytes_stream()
.map_err(|err| IOError::new(ErrorKind::Other, err)),
));
Ok(crate::GetResponse {
http_header: Some(header),
http_status_code: Some(status_code),
reader,
})
}
}
impl Default for HTTP {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Backend, GetRequest, HeadRequest};
    use httpmock::{Method, MockServer};
    use reqwest::{header::HeaderMap, StatusCode};
    use std::time::Duration;

    /// Starts a mock server that answers `GET <path>` with status 200, an HTML
    /// content type and the given body.
    fn start_mock_server(path: &str, body: &str) -> MockServer {
        let server = MockServer::start();
        server.mock(|when, then| {
            when.method(Method::GET).path(path);
            then.status(200)
                .header("content-type", "text/html; charset=UTF-8")
                .body(body);
        });
        server
    }

    /// `head` reports the status code returned by the server.
    #[tokio::test]
    async fn should_get_head_response() {
        let server = start_mock_server("/head", "");
        let backend = HTTP::new();

        let request = HeadRequest {
            url: server.url("/head"),
            http_header: Some(HeaderMap::new()),
            timeout: Duration::from_secs(5),
            client_certs: None,
        };
        let response = backend.head(request).await.unwrap();

        assert_eq!(response.http_status_code, Some(StatusCode::OK));
    }

    /// `head` fails when the request carries no header map, even though the
    /// server would have answered successfully.
    #[tokio::test]
    async fn should_return_error_response_when_head_notexists() {
        let server = start_mock_server("/head", "");
        let backend = HTTP::new();

        let request = HeadRequest {
            url: server.url("/head"),
            http_header: None,
            timeout: Duration::from_secs(5),
            client_certs: None,
        };
        let outcome = backend.head(request).await;

        assert!(outcome.is_err());
    }

    /// `get` reports the status code and exposes the response body.
    #[tokio::test]
    async fn should_get_response() {
        let server = start_mock_server("/get", "OK");
        let backend = HTTP::new();

        let request = GetRequest {
            url: server.url("/get"),
            http_header: Some(HeaderMap::new()),
            timeout: Duration::from_secs(5),
            client_certs: None,
        };
        let mut response = backend.get(request).await.unwrap();

        assert_eq!(response.http_status_code, Some(StatusCode::OK));
        assert_eq!(response.text().await.unwrap(), "OK");
    }
}