use base64::Engine;
use hyper::body::Bytes;
use hyper::http::{HeaderMap, HeaderValue};
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use jsonrpsee_core::tracing::client::{rx_log_from_bytes, tx_log_from_str};
use jsonrpsee_core::BoxError;
use jsonrpsee_core::{
http_helpers::{self, HttpError},
TEN_MB_SIZE_BYTES,
};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use thiserror::Error;
use tower::layer::util::Identity;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{HttpBody, HttpRequest, HttpResponse};
#[cfg(feature = "tls")]
use crate::{CertificateStore, CustomCertStore};
/// Media type used as the default value of both the `Content-Type` and `Accept` request headers.
const CONTENT_TYPE_JSON: &str = "application/json";
/// The underlying hyper client used to send requests: plain HTTP, or HTTPS when the
/// `tls` feature is enabled.
#[derive(Debug)]
pub enum HttpBackend<B = HttpBody> {
/// Hyper client wrapping a `hyper_rustls` HTTPS connector (only with the `tls` feature).
#[cfg(feature = "tls")]
Https(Client<hyper_rustls::HttpsConnector<HttpConnector>, B>),
/// Hyper client wrapping a plain `HttpConnector`.
Http(Client<HttpConnector, B>),
}
impl<B> Clone for HttpBackend<B> {
fn clone(&self) -> Self {
match self {
Self::Http(inner) => Self::Http(inner.clone()),
#[cfg(feature = "tls")]
Self::Https(inner) => Self::Https(inner.clone()),
}
}
}
impl<B> tower::Service<HttpRequest<B>> for HttpBackend<B>
where
B: http_body::Body<Data = Bytes> + Send + Unpin + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
{
type Response = HttpResponse<hyper::body::Incoming>;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self {
Self::Http(inner) => inner.poll_ready(ctx),
#[cfg(feature = "tls")]
Self::Https(inner) => inner.poll_ready(ctx),
}
.map_err(|e| Error::Http(HttpError::Stream(e.into())))
}
fn call(&mut self, req: HttpRequest<B>) -> Self::Future {
let resp = match self {
Self::Http(inner) => inner.call(req),
#[cfg(feature = "tls")]
Self::Https(inner) => inner.call(req),
};
Box::pin(async move { resp.await.map_err(|e| Error::Http(HttpError::Stream(e.into()))) })
}
}
/// Builder for [`HttpTransportClient`].
#[derive(Debug)]
pub struct HttpTransportClientBuilder<L> {
/// Certificate store used to configure TLS for `https` targets.
#[cfg(feature = "tls")]
pub(crate) certificate_store: CertificateStore,
/// Maximum request body length in bytes; larger bodies are rejected before sending.
pub(crate) max_request_size: u32,
/// Size limit handed to `http_helpers::read_body` when reading a response body.
pub(crate) max_response_size: u32,
/// Length limit handed to the request/response logging helpers.
pub(crate) max_log_length: u32,
/// Custom headers merged on top of the default `Content-Type`/`Accept` headers.
pub(crate) headers: HeaderMap,
/// Tower middleware that is wrapped around the HTTP backend.
pub(crate) service_builder: tower::ServiceBuilder<L>,
/// Value passed to `HttpConnector::set_nodelay`.
pub(crate) tcp_no_delay: bool,
}
impl Default for HttpTransportClientBuilder<Identity> {
fn default() -> Self {
Self::new()
}
}
impl HttpTransportClientBuilder<Identity> {
pub fn new() -> Self {
Self {
#[cfg(feature = "tls")]
certificate_store: CertificateStore::Native,
max_request_size: TEN_MB_SIZE_BYTES,
max_response_size: TEN_MB_SIZE_BYTES,
max_log_length: 1024,
headers: HeaderMap::new(),
service_builder: tower::ServiceBuilder::new(),
tcp_no_delay: true,
}
}
}
impl<L> HttpTransportClientBuilder<L> {
    /// Use the given custom TLS configuration instead of the native certificate store.
    #[cfg(feature = "tls")]
    pub fn with_custom_cert_store(mut self, cfg: CustomCertStore) -> Self {
        self.certificate_store = CertificateStore::Custom(cfg);
        self
    }

    /// Set the maximum size of a request body in bytes (default: `TEN_MB_SIZE_BYTES`).
    pub fn max_request_size(mut self, size: u32) -> Self {
        self.max_request_size = size;
        self
    }

    /// Set the maximum size of a response body in bytes (default: `TEN_MB_SIZE_BYTES`).
    pub fn max_response_size(mut self, size: u32) -> Self {
        self.max_response_size = size;
        self
    }

    /// Set custom headers to send with every request.
    ///
    /// They are applied on top of the default `Content-Type` and `Accept` headers and
    /// replace those defaults when the same header name is supplied.
    pub fn set_headers(mut self, headers: HeaderMap) -> Self {
        self.headers = headers;
        self
    }

    /// Configure `TCP_NODELAY` on the underlying connector (default: `true`).
    pub fn set_tcp_no_delay(mut self, no_delay: bool) -> Self {
        self.tcp_no_delay = no_delay;
        self
    }

    /// Set the length limit passed to the request/response logging helpers (default: 1024).
    pub fn set_max_logging_length(mut self, max: u32) -> Self {
        self.max_log_length = max;
        self
    }

    /// Replace the tower middleware stack, keeping every other setting unchanged.
    pub fn set_service<T>(self, service: tower::ServiceBuilder<T>) -> HttpTransportClientBuilder<T> {
        HttpTransportClientBuilder {
            #[cfg(feature = "tls")]
            certificate_store: self.certificate_store,
            headers: self.headers,
            max_log_length: self.max_log_length,
            max_request_size: self.max_request_size,
            max_response_size: self.max_response_size,
            service_builder: service,
            tcp_no_delay: self.tcp_no_delay,
        }
    }

    /// Build a [`HttpTransportClient`] that posts requests to `target`.
    ///
    /// The URL fragment is dropped. If the URL carries a password and no `Authorization`
    /// header was configured, a `Basic` authorization header is derived from the URL
    /// credentials. The credentials are then removed from the stored target URL.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Url`] when `target` cannot be parsed, has no host, uses an
    /// unsupported scheme, or when the derived authorization header value is invalid.
    pub fn build<S, B>(self, target: impl AsRef<str>) -> Result<HttpTransportClient<S>, Error>
    where
        L: Layer<HttpBackend, Service = S>,
        S: Service<HttpRequest, Response = HttpResponse<B>, Error = Error> + Clone,
        B: http_body::Body<Data = Bytes> + Send + 'static,
        B::Data: Send,
        B::Error: Into<BoxError>,
    {
        let Self {
            #[cfg(feature = "tls")]
            certificate_store,
            max_request_size,
            max_response_size,
            max_log_length,
            headers,
            service_builder,
            tcp_no_delay,
        } = self;

        let mut url = Url::parse(target.as_ref()).map_err(|e| Error::Url(format!("Invalid URL: {e}")))?;
        if url.host_str().is_none() {
            return Err(Error::Url("Invalid host".into()));
        }
        // A fragment is never sent to the server, so it is not kept in the target.
        url.set_fragment(None);

        let client = match url.scheme() {
            "http" => {
                let mut connector = HttpConnector::new();
                connector.set_nodelay(tcp_no_delay);
                HttpBackend::Http(Client::builder(TokioExecutor::new()).build(connector))
            }
            #[cfg(feature = "tls")]
            "https" => {
                let mut http_conn = HttpConnector::new();
                http_conn.set_nodelay(tcp_no_delay);
                // The HTTPS connector wraps this one, so non-`http` schemes must be allowed.
                http_conn.enforce_http(false);

                let https_conn = match certificate_store {
                    CertificateStore::Native => hyper_rustls::HttpsConnectorBuilder::new()
                        .with_tls_config(rustls_platform_verifier::tls_config())
                        .https_or_http()
                        .enable_all_versions()
                        .wrap_connector(http_conn),
                    CertificateStore::Custom(tls_config) => hyper_rustls::HttpsConnectorBuilder::new()
                        .with_tls_config(tls_config)
                        .https_or_http()
                        .enable_all_versions()
                        .wrap_connector(http_conn),
                };
                HttpBackend::Https(Client::builder(TokioExecutor::new()).build(https_conn))
            }
            _ => {
                #[cfg(feature = "tls")]
                let err = "URL scheme not supported, expects 'http' or 'https'";
                #[cfg(not(feature = "tls"))]
                let err = "URL scheme not supported, expects 'http'";
                return Err(Error::Url(err.into()));
            }
        };

        // Headers are computed once here and cloned into every request.
        let mut cached_headers = HeaderMap::with_capacity(2 + headers.len());
        cached_headers.insert(hyper::header::CONTENT_TYPE, HeaderValue::from_static(CONTENT_TYPE_JSON));
        cached_headers.insert(hyper::header::ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON));

        // `HeaderMap::into_iter` yields the header name only with the first value of a
        // multi-valued header; following values of the same header come with `None`.
        // The first value replaces any default, the remaining ones are appended so that
        // no user-supplied value is silently lost.
        let mut last_name = None;
        for (name, value) in headers {
            match name {
                Some(name) => {
                    cached_headers.insert(name.clone(), value);
                    last_name = Some(name);
                }
                None => {
                    if let Some(name) = &last_name {
                        cached_headers.append(name.clone(), value);
                    }
                }
            }
        }

        // An explicitly configured `Authorization` header takes precedence over URL credentials.
        if let Some(pwd) = url.password() {
            if !cached_headers.contains_key(hyper::header::AUTHORIZATION) {
                let digest = base64::engine::general_purpose::STANDARD.encode(format!("{}:{pwd}", url.username()));
                cached_headers.insert(
                    hyper::header::AUTHORIZATION,
                    HeaderValue::from_str(&format!("Basic {digest}"))
                        .map_err(|_| Error::Url("Header value `authorization basic user:pwd` invalid".into()))?,
                );
            }
        }

        // Credentials belong in the `Authorization` header only: strip the userinfo so
        // the stored request URI no longer contains `user:password@`. These setters only
        // fail for URLs without a host, which was ruled out above.
        let _ = url.set_password(None);
        let _ = url.set_username("");

        Ok(HttpTransportClient {
            target: url.as_str().to_owned(),
            client: service_builder.service(client),
            max_request_size,
            max_response_size,
            max_log_length,
            headers: cached_headers,
        })
    }
}
/// HTTP transport client that posts request bodies to a fixed target URL.
#[derive(Debug, Clone)]
pub struct HttpTransportClient<S> {
/// Target URL, as a string, that every request is posted to.
target: String,
/// Service used to send requests.
client: S,
/// Maximum request body length in bytes; larger bodies are rejected before sending.
max_request_size: u32,
/// Size limit handed to `http_helpers::read_body` when reading a response body.
max_response_size: u32,
/// Length limit handed to the request/response logging helpers.
max_log_length: u32,
/// Headers copied into every outgoing request.
headers: HeaderMap,
}
impl<B, S> HttpTransportClient<S>
where
S: Service<HttpRequest, Response = HttpResponse<B>, Error = Error> + Clone,
B: http_body::Body<Data = Bytes> + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
{
async fn inner_send(&self, body: String) -> Result<HttpResponse<B>, Error> {
if body.len() > self.max_request_size as usize {
return Err(Error::RequestTooLarge);
}
let mut req = HttpRequest::post(&self.target);
if let Some(headers) = req.headers_mut() {
*headers = self.headers.clone();
}
let req = req.body(body.into()).expect("URI and request headers are valid; qed");
let response = self.client.clone().ready().await?.call(req).await?;
if response.status().is_success() {
Ok(response)
} else {
Err(Error::Rejected { status_code: response.status().into() })
}
}
pub(crate) async fn send_and_read_body(&self, body: String) -> Result<Vec<u8>, Error> {
tx_log_from_str(&body, self.max_log_length);
let response = self.inner_send(body).await?;
let (parts, body) = response.into_parts();
let (body, _is_single) = http_helpers::read_body(&parts.headers, body, self.max_response_size).await?;
rx_log_from_bytes(&body, self.max_log_length);
Ok(body)
}
pub(crate) async fn send(&self, body: String) -> Result<(), Error> {
let _ = self.inner_send(body).await?;
Ok(())
}
}
/// Errors that can occur when building or using the HTTP transport client.
#[derive(Debug, Error)]
pub enum Error {
/// The target URL could not be parsed, has no host, or uses an unsupported scheme.
#[error("Invalid Url: {0}")]
Url(String),
/// Error from the underlying HTTP layer.
#[error("{0}")]
Http(#[from] HttpError),
/// The server answered with a non-success status code.
#[error("Request rejected `{status_code}`")]
Rejected {
/// HTTP status code of the rejected response.
status_code: u16,
},
/// The request body exceeded the configured maximum request size.
#[error("The request body was too large")]
RequestTooLarge,
/// The certificate store is invalid.
// NOTE(review): the variant name is misspelled ("Certficate"), but it is part of the
// public API, so renaming it would be a breaking change.
#[error("Invalid certificate store")]
InvalidCertficateStore,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client for `target` using the default builder settings.
    fn build_default(target: &str) -> Result<HttpTransportClient<HttpBackend>, Error> {
        HttpTransportClientBuilder::new().build(target)
    }

    /// Asserts that building a client for `target` fails with `Error::Url`.
    fn assert_url_error(target: &str) {
        let err = build_default(target).unwrap_err();
        assert!(matches!(err, Error::Url(_)));
    }

    /// Asserts that the client built for `input` stores `expected` as its target.
    fn assert_target(input: &str, expected: &str) {
        let client = build_default(input).unwrap();
        assert_eq!(&client.target, expected);
    }

    #[test]
    fn invalid_http_url_rejected() {
        assert_url_error("ws://localhost:9933");
    }

    #[cfg(feature = "tls")]
    #[test]
    fn https_works() {
        assert_target("https://localhost", "https://localhost/");
    }

    #[cfg(not(feature = "tls"))]
    #[test]
    fn https_fails_without_tls_feature() {
        assert_url_error("https://localhost");
    }

    #[test]
    fn faulty_port() {
        assert_url_error("http://localhost:-43");
        assert_url_error("http://localhost:-99999");
    }

    #[test]
    fn url_with_path_works() {
        assert_target("http://localhost/my-special-path", "http://localhost/my-special-path");
    }

    #[test]
    fn url_with_query_works() {
        assert_target(
            "http://127.0.0.1/my?name1=value1&name2=value2",
            "http://127.0.0.1/my?name1=value1&name2=value2",
        );
    }

    #[test]
    fn url_with_fragment_is_ignored() {
        assert_target("http://127.0.0.1/my.htm#ignore", "http://127.0.0.1/my.htm");
    }

    #[test]
    fn url_default_port_is_omitted() {
        assert_target("http://127.0.0.1:80", "http://127.0.0.1/");
    }

    #[cfg(feature = "tls")]
    #[test]
    fn https_custom_port_works() {
        assert_target("https://localhost:9999", "https://localhost:9999/");
    }

    #[test]
    fn http_custom_port_works() {
        assert_target("http://localhost:9999", "http://localhost:9999/");
    }

    #[tokio::test]
    async fn request_limit_works() {
        let (request_limit, response_limit) = (80, 50);
        let client = HttpTransportClientBuilder::new()
            .max_request_size(request_limit)
            .max_response_size(response_limit)
            .build("http://localhost:9933")
            .unwrap();
        assert_eq!(client.max_request_size, request_limit);
        assert_eq!(client.max_response_size, response_limit);

        // One byte over the limit: rejected locally, before any network I/O.
        let oversized = "a".repeat(81);
        assert_eq!(oversized.len(), 81);
        let err = client.send(oversized).await.unwrap_err();
        assert!(matches!(err, Error::RequestTooLarge));
    }
}