mod stream;
use std::io;
use std::net::SocketAddr;
use std::time::Duration;
use base64::Engine;
use futures_util::io::{BufReader, BufWriter};
use jsonrpsee_core::client::{MaybeSend, ReceivedMessage, TransportReceiverT, TransportSenderT};
use jsonrpsee_core::TEN_MB_SIZE_BYTES;
use jsonrpsee_core::{async_trait, Cow};
use soketto::connection::Error::Utf8;
use soketto::data::ByteSlice125;
use soketto::handshake::client::{Client as WsHandshakeClient, ServerResponse};
use soketto::{connection, Data, Incoming};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
pub use http::{uri::InvalidUri, HeaderMap, HeaderValue, Uri};
pub use soketto::handshake::client::Header;
pub use stream::EitherStream;
pub use tokio::io::{AsyncRead, AsyncWrite};
pub use url::Url;
/// `tracing` target passed to every log statement in this module.
const LOG_TARGET: &str = "jsonrpsee-client";
/// Custom TLS client configuration accepted by
/// [`WsTransportClientBuilder::with_custom_cert_store`]; an alias for `rustls::ClientConfig`.
#[cfg(feature = "tls")]
pub type CustomCertStore = rustls::ClientConfig;
/// Source of the TLS configuration used for `wss://` connections.
// `Custom` stores a whole `rustls::ClientConfig` inline while `Native` is empty, which is what
// triggers `clippy::large_enum_variant`. NOTE(review): boxing `Custom` would change the public
// variant's payload type, so the size difference is accepted here.
#[allow(clippy::large_enum_variant)]
#[cfg(feature = "tls")]
#[derive(Debug, Clone)]
pub enum CertificateStore {
/// Use the platform verifier configuration from `rustls_platform_verifier::tls_config()`.
Native,
/// Use the supplied `rustls` client configuration (cloned when the TLS connector is built).
Custom(CustomCertStore),
}
/// Sending half of a WebSocket connection produced by [`WsTransportClientBuilder`].
#[derive(Debug)]
pub struct Sender<T> {
/// Underlying `soketto` sender over a buffered reader/writer stack wrapping `T`.
inner: connection::Sender<BufReader<BufWriter<T>>>,
/// Largest body length in bytes that `send` accepts; longer bodies fail with
/// [`WsError::MessageTooLarge`] before anything is written.
max_request_size: u32,
}
/// Receiving half of a WebSocket connection produced by [`WsTransportClientBuilder`].
#[derive(Debug)]
pub struct Receiver<T> {
/// Underlying `soketto` receiver over a buffered reader/writer stack wrapping `T`.
inner: connection::Receiver<BufReader<BufWriter<T>>>,
}
/// Builder for WebSocket client transports: configure it, then call `build` (dial over TCP)
/// or `build_with_stream` (use an existing stream).
#[derive(Debug)]
pub struct WsTransportClientBuilder {
/// Where the TLS configuration for `wss://` URLs comes from.
#[cfg(feature = "tls")]
pub certificate_store: CertificateStore,
/// Time budget for opening a connection (default: 10 seconds).
pub connection_timeout: Duration,
/// Extra HTTP headers sent with the WebSocket handshake request.
pub headers: http::HeaderMap,
/// Maximum size in bytes of an outgoing message; larger messages are refused by the sender.
pub max_request_size: u32,
/// Maximum size in bytes of an incoming message, passed to the `soketto` connection builder.
pub max_response_size: u32,
/// Bounds how many server redirects `build` will follow.
pub max_redirections: usize,
/// Desired `TCP_NODELAY` setting for sockets opened by `build` (default: `true`).
pub tcp_no_delay: bool,
}
impl Default for WsTransportClientBuilder {
    /// Defaults: native certificate store (with the `tls` feature), `TEN_MB_SIZE_BYTES` for both
    /// the request and response limits, a 10 second connection timeout, no extra headers,
    /// 5 redirections and `TCP_NODELAY` enabled.
    fn default() -> Self {
        let size_limit = TEN_MB_SIZE_BYTES;
        Self {
            #[cfg(feature = "tls")]
            certificate_store: CertificateStore::Native,
            connection_timeout: Duration::from_secs(10),
            headers: http::HeaderMap::new(),
            max_request_size: size_limit,
            max_response_size: size_limit,
            max_redirections: 5,
            tcp_no_delay: true,
        }
    }
}
impl WsTransportClientBuilder {
    /// Use `cfg` as the TLS client configuration instead of the platform verifier.
    #[cfg(feature = "tls")]
    pub fn with_custom_cert_store(mut self, cfg: CustomCertStore) -> Self {
        self.certificate_store = CertificateStore::Custom(cfg);
        self
    }

    /// Set the maximum size in bytes of an outgoing request (default: `TEN_MB_SIZE_BYTES`).
    pub fn max_request_size(mut self, size: u32) -> Self {
        self.max_request_size = size;
        self
    }

    /// Set the maximum size in bytes of an incoming response (default: `TEN_MB_SIZE_BYTES`).
    pub fn max_response_size(mut self, size: u32) -> Self {
        self.max_response_size = size;
        self
    }

    /// Set the connection timeout (default: 10 seconds).
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }

    /// Replace the extra HTTP headers sent with the handshake request.
    pub fn set_headers(mut self, headers: http::HeaderMap) -> Self {
        self.headers = headers;
        self
    }

    /// Set the maximum number of redirections to follow (default: 5).
    pub fn max_redirections(mut self, redirect: usize) -> Self {
        self.max_redirections = redirect;
        self
    }

    /// Set the `tcp_no_delay` option used for sockets opened by `build` (default: `true`).
    ///
    /// Every other public option already has a builder method; this one could previously
    /// only be changed by assigning the public field.
    pub fn tcp_no_delay(mut self, no_delay: bool) -> Self {
        self.tcp_no_delay = no_delay;
        self
    }
}
/// Transport security of a connection target, derived from the URL scheme.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Mode {
/// `ws://`: plain TCP.
Plain,
/// `wss://`: TLS over TCP (only accepted with the `tls` feature).
Tls,
}
/// Errors that can occur while establishing a WebSocket connection.
#[derive(Debug, Error)]
pub enum WsHandshakeError {
/// Loading the system certificate store failed. NOTE(review): not constructed anywhere in this file.
#[error("Failed to load system certs: {0}")]
CertificateStore(io::Error),
/// The URL or a redirect location is unusable (unsupported scheme, missing/invalid host,
/// invalid basic-auth header value, unparsable redirect, ...).
#[error("Invalid URL: {0}")]
Url(Cow<'static, str>),
/// Any `io::Error` converted via `From`, e.g. a failed TCP connect or TLS handshake.
#[error("Error when opening the TCP socket: {0}")]
Io(io::Error),
/// The `soketto` handshake itself failed.
#[error("{0}")]
Transport(#[source] soketto::handshake::Error),
/// The server rejected the handshake with the given HTTP status code.
#[error("Connection rejected with status code: {status_code}")]
Rejected {
status_code: u16,
},
/// The server answered with a redirect. `build` follows these itself; `build_with_stream`
/// returns them to the caller.
#[error("Connection redirected with status code: {status_code} and location: {location}")]
Redirected {
status_code: u16,
location: String,
},
/// Connecting did not finish within the configured `connection_timeout`.
#[error("Connection timeout exceeded: {0:?}")]
Timeout(Duration),
/// Resolving the URL's host to socket addresses failed.
#[error("Failed to resolve IP addresses for this hostname: {0}")]
ResolutionFailed(io::Error),
/// No connection attempt succeeded and no more specific error was recorded; carries the host.
#[error("No IP address found for this hostname: {0}")]
NoAddressFound(String),
}
/// Errors produced by an established WebSocket connection.
#[derive(Debug, Error)]
pub enum WsError {
/// Error from the underlying `soketto` connection, including invalid UTF-8 in text frames.
#[error("{0}")]
Connection(#[source] soketto::connection::Error),
/// An outgoing message exceeded the configured `max_request_size`.
#[error("The message was too large")]
MessageTooLarge,
}
#[async_trait]
impl<T> TransportSenderT for Sender<T>
where
    T: futures_util::io::AsyncRead + futures_util::io::AsyncWrite + Unpin + MaybeSend + 'static,
{
    type Error = WsError;

    /// Send `body` as a single text frame and flush it.
    ///
    /// Bodies longer than `max_request_size` bytes are refused with
    /// [`WsError::MessageTooLarge`] before anything is written.
    async fn send(&mut self, body: String) -> Result<(), Self::Error> {
        let too_large = body.len() > self.max_request_size as usize;
        if too_large {
            return Err(WsError::MessageTooLarge);
        }
        self.inner.send_text(body).await?;
        self.inner.flush().await.map_err(WsError::from)
    }

    /// Send a ping frame with an empty payload and flush it.
    async fn send_ping(&mut self) -> Result<(), Self::Error> {
        tracing::debug!(target: LOG_TARGET, "Send ping");
        let empty: &[u8] = &[];
        // An empty payload is always within the 125-byte control-frame limit.
        let payload = ByteSlice125::try_from(empty).expect("Empty slice should fit into ByteSlice125");
        self.inner.send_ping(payload).await?;
        self.inner.flush().await.map_err(WsError::from)
    }

    /// Close the connection.
    async fn close(&mut self) -> Result<(), WsError> {
        Ok(self.inner.close().await?)
    }
}
#[async_trait]
impl<T> TransportReceiverT for Receiver<T>
where
    T: futures_util::io::AsyncRead + futures_util::io::AsyncWrite + Unpin + MaybeSend + 'static,
{
    type Error = WsError;

    /// Wait for the next text, binary or pong message.
    ///
    /// Text payloads are validated as UTF-8. Any other kind of incoming item is skipped and the
    /// next one is awaited.
    async fn receive(&mut self) -> Result<ReceivedMessage, Self::Error> {
        loop {
            let mut buf = Vec::new();
            let incoming = self.inner.receive(&mut buf).await?;
            let msg = match incoming {
                Incoming::Data(Data::Text(_)) => {
                    let text = String::from_utf8(buf).map_err(|e| WsError::Connection(Utf8(e.utf8_error())))?;
                    ReceivedMessage::Text(text)
                }
                Incoming::Data(Data::Binary(_)) => ReceivedMessage::Bytes(buf),
                Incoming::Pong(_) => ReceivedMessage::Pong,
                _ => continue,
            };
            return Ok(msg);
        }
    }
}
impl WsTransportClientBuilder {
    /// Try to establish a WebSocket connection to `uri` over TCP.
    ///
    /// With the `tls` feature enabled, `wss://` URLs are wrapped in TLS. Server redirects are
    /// followed up to `max_redirections` times.
    ///
    /// # Errors
    ///
    /// Returns a [`WsHandshakeError`] if the URL is not a valid `ws`/`wss` target, the host
    /// cannot be resolved, or no resolved address yields a successful handshake.
    pub async fn build(
        self,
        uri: Url,
    ) -> Result<(Sender<Compat<EitherStream>>, Receiver<Compat<EitherStream>>), WsHandshakeError> {
        self.try_connect_over_tcp(uri).await
    }

    /// Perform the WebSocket handshake for `uri` over an already-connected `data_stream`.
    ///
    /// No name resolution, TCP connect, TLS or redirect following happens here. The URL only
    /// supplies the `Host` header, request path and optional basic-auth credentials, and a
    /// redirect response is returned as [`WsHandshakeError::Redirected`].
    pub async fn build_with_stream<T>(
        self,
        uri: Url,
        data_stream: T,
    ) -> Result<(Sender<Compat<T>>, Receiver<Compat<T>>), WsHandshakeError>
    where
        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    {
        let target: Target = uri.try_into()?;
        self.try_connect(&target, data_stream.compat()).await
    }

    /// Resolve `uri`, then try each address in turn, following redirects.
    ///
    /// Each round tries the current address list, and a successful handshake returns
    /// immediately. An absolute redirect re-resolves the new URL and starts a fresh round
    /// against its addresses. A relative redirect rewrites the request path and retries the
    /// same addresses. At most `max_redirections` redirects are followed. If nothing succeeds,
    /// the last recorded error is returned, or `NoAddressFound` if none was recorded.
    async fn try_connect_over_tcp(
        &self,
        uri: Url,
    ) -> Result<(Sender<Compat<EitherStream>>, Receiver<Compat<EitherStream>>), WsHandshakeError> {
        // Convert first so an unsupported scheme is reported as a URL error, not a lookup error.
        let mut target: Target = uri.clone().try_into()?;
        let mut err = None;

        // Only build the TLS connector when the initial target is `wss`.
        #[cfg(feature = "tls")]
        let mut connector = match target._mode {
            Mode::Tls => Some(build_tls_config(&self.certificate_store)?),
            Mode::Plain => None,
        };

        let mut target_sockaddrs = uri.socket_addrs(|| None).map_err(WsHandshakeError::ResolutionFailed)?;

        // One initial round plus up to `max_redirections` rounds triggered by redirects, so that
        // `max_redirections == 0` still makes a connection attempt.
        for _ in 0..=self.max_redirections {
            let sockaddrs = std::mem::take(&mut target_sockaddrs);
            // No addresses for this round (lookup returned none, or the previous round ended
            // without following a redirect): further rounds would do nothing.
            if sockaddrs.is_empty() {
                break;
            }
            tracing::debug!(target: LOG_TARGET, "Connecting to target: {:?}", target);

            for sockaddr in &sockaddrs {
                #[cfg(feature = "tls")]
                let tcp_stream = match connect(
                    *sockaddr,
                    self.connection_timeout,
                    &target.host,
                    connector.as_ref(),
                    self.tcp_no_delay,
                )
                .await
                {
                    Ok(stream) => stream,
                    Err(e) => {
                        tracing::debug!(target: LOG_TARGET, "Failed to connect to sockaddr: {:?}", sockaddr);
                        err = Some(Err(e));
                        continue;
                    }
                };
                #[cfg(not(feature = "tls"))]
                let tcp_stream = match connect(*sockaddr, self.connection_timeout, self.tcp_no_delay).await {
                    Ok(stream) => stream,
                    Err(e) => {
                        tracing::debug!(target: LOG_TARGET, "Failed to connect to sockaddr: {:?}", sockaddr);
                        err = Some(Err(e));
                        continue;
                    }
                };

                match self.try_connect(&target, tcp_stream.compat()).await {
                    Ok(result) => return Ok(result),
                    Err(WsHandshakeError::Redirected { status_code, location }) => {
                        tracing::debug!(target: LOG_TARGET, "Redirection: status_code: {}, location: {}", status_code, location);
                        match Url::parse(&location) {
                            // Absolute redirect: resolve the new host and restart with its addresses.
                            Ok(uri) => {
                                target_sockaddrs = uri.socket_addrs(|| None).map_err(|e| {
                                    tracing::debug!(target: LOG_TARGET, "Redirection failed: {:?}", e);
                                    WsHandshakeError::ResolutionFailed(e)
                                })?;
                                target = uri.try_into().map_err(|e| {
                                    tracing::debug!(target: LOG_TARGET, "Redirection failed: {:?}", e);
                                    e
                                })?;

                                // Build a TLS connector only if the new target needs one, and drop a
                                // previously built one when redirected to plain `ws`.
                                #[cfg(feature = "tls")]
                                match target._mode {
                                    Mode::Tls if connector.is_none() => {
                                        connector = Some(build_tls_config(&self.certificate_store)?);
                                    }
                                    Mode::Tls => (),
                                    Mode::Plain => {
                                        connector = None;
                                    }
                                };

                                // The remaining addresses in `sockaddrs` belong to the previous host;
                                // dialing them with the new target would send the wrong Host/SNI.
                                break;
                            }
                            // Relative redirect: same host, new path; retry the same addresses.
                            Err(url::ParseError::RelativeUrlWithoutBase) => {
                                if location.starts_with('/') {
                                    target.path_and_query = location;
                                } else {
                                    match target.path_and_query.rfind('/') {
                                        Some(offset) => target.path_and_query.replace_range(offset + 1.., &location),
                                        None => {
                                            let e = format!("path_and_query: {location}; this is a bug it must contain `/` please open issue");
                                            err = Some(Err(WsHandshakeError::Url(e.into())));
                                            continue;
                                        }
                                    };
                                }
                                target_sockaddrs = sockaddrs;
                                break;
                            }
                            Err(e) => {
                                err = Some(Err(WsHandshakeError::Url(e.to_string().into())));
                            }
                        };
                    }
                    Err(e) => {
                        err = Some(Err(e));
                    }
                };
            }
        }
        err.unwrap_or_else(|| Err(WsHandshakeError::NoAddressFound(target.host)))
    }

    /// Run the WebSocket client handshake for `target` over `data_stream`.
    ///
    /// Sends the builder's extra headers, plus the target's basic-auth `Authorization` header
    /// unless the caller already supplied one. On success the connection is limited to
    /// `max_response_size` bytes per incoming message.
    async fn try_connect<T>(
        &self,
        target: &Target,
        data_stream: T,
    ) -> Result<(Sender<T>, Receiver<T>), WsHandshakeError>
    where
        T: futures_util::AsyncRead + futures_util::AsyncWrite + Unpin,
    {
        let mut client = WsHandshakeClient::new(
            BufReader::new(BufWriter::new(data_stream)),
            &target.host_header,
            &target.path_and_query,
        );

        let mut headers: Vec<Header<'_>> = self
            .headers
            .iter()
            .map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() })
            .collect();
        // Credentials embedded in the URL never override an explicit `Authorization` header.
        if let Some(basic_auth) = &target.basic_auth {
            if !self.headers.contains_key(http::header::AUTHORIZATION) {
                headers.push(Header { name: http::header::AUTHORIZATION.as_str(), value: basic_auth.as_bytes() });
            }
        }
        client.set_headers(&headers);

        match client.handshake().await {
            Ok(ServerResponse::Accepted { .. }) => {
                tracing::debug!(target: LOG_TARGET, "Connection established to target: {:?}", target);
                let mut builder = client.into_builder();
                builder.set_max_message_size(self.max_response_size as usize);
                let (sender, receiver) = builder.finish();
                Ok((Sender { inner: sender, max_request_size: self.max_request_size }, Receiver { inner: receiver }))
            }
            Ok(ServerResponse::Rejected { status_code }) => {
                tracing::debug!(target: LOG_TARGET, "Connection rejected: {:?}", status_code);
                Err(WsHandshakeError::Rejected { status_code })
            }
            Ok(ServerResponse::Redirect { status_code, location }) => {
                tracing::debug!(target: LOG_TARGET, "Redirection: status_code: {}, location: {}", status_code, location);
                Err(WsHandshakeError::Redirected { status_code, location })
            }
            Err(e) => Err(e.into()),
        }
    }
}

/// Open a TCP connection to `sockaddr` and, when `tls_connector` is set, perform a TLS
/// handshake using `host` as the server name.
///
/// The TCP connect and the TLS handshake together must finish within `timeout_dur`. Previously,
/// `select!` stopped polling the timer once the TCP connect completed, so a stalled TLS
/// handshake was not bounded. A failure to apply `tcp_no_delay` is only logged.
#[cfg(feature = "tls")]
async fn connect(
    sockaddr: SocketAddr,
    timeout_dur: Duration,
    host: &str,
    tls_connector: Option<&tokio_rustls::TlsConnector>,
    tcp_no_delay: bool,
) -> Result<EitherStream, WsHandshakeError> {
    let attempt = async {
        let socket = TcpStream::connect(sockaddr).await?;
        if let Err(err) = socket.set_nodelay(tcp_no_delay) {
            tracing::warn!(target: LOG_TARGET, "set nodelay failed: {:?}", err);
        }
        let stream = match tls_connector {
            None => EitherStream::Plain(socket),
            Some(connector) => {
                let server_name: rustls_pki_types::ServerName = host
                    .try_into()
                    .map_err(|e| WsHandshakeError::Url(format!("Invalid host: {host} {e:?}").into()))?;
                EitherStream::Tls(connector.connect(server_name.to_owned(), socket).await?)
            }
        };
        Ok::<_, WsHandshakeError>(stream)
    };
    match tokio::time::timeout(timeout_dur, attempt).await {
        Ok(result) => result,
        Err(_elapsed) => Err(WsHandshakeError::Timeout(timeout_dur)),
    }
}

/// Open a plain TCP connection to `sockaddr` within `timeout_dur`.
///
/// Applies `tcp_no_delay`, which was previously hard-coded to `true` on this path. A failure to
/// set it is only logged.
#[cfg(not(feature = "tls"))]
async fn connect(
    sockaddr: SocketAddr,
    timeout_dur: Duration,
    tcp_no_delay: bool,
) -> Result<EitherStream, WsHandshakeError> {
    let socket = match tokio::time::timeout(timeout_dur, TcpStream::connect(sockaddr)).await {
        Ok(result) => result?,
        Err(_elapsed) => return Err(WsHandshakeError::Timeout(timeout_dur)),
    };
    if let Err(err) = socket.set_nodelay(tcp_no_delay) {
        tracing::warn!(target: LOG_TARGET, "set nodelay failed: {:?}", err);
    }
    Ok(EitherStream::Plain(socket))
}
impl From<io::Error> for WsHandshakeError {
fn from(err: io::Error) -> WsHandshakeError {
WsHandshakeError::Io(err)
}
}
impl From<soketto::handshake::Error> for WsHandshakeError {
fn from(err: soketto::handshake::Error) -> WsHandshakeError {
WsHandshakeError::Transport(err)
}
}
impl From<soketto::connection::Error> for WsError {
fn from(err: soketto::connection::Error) -> Self {
WsError::Connection(err)
}
}
/// Connection target derived from a `ws://` / `wss://` URL via `TryFrom<Url>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Target {
/// Host name or IP from the URL, without port; used as the TLS server name and in `NoAddressFound`.
host: String,
/// Value for the `Host` header: `host`, plus `:port` when the URL carries a non-default port.
host_header: String,
/// Plain or TLS, from the URL scheme.
_mode: Mode,
/// Handshake request target: the URL path, followed by `?query` when present (fragments dropped).
path_and_query: String,
/// `Basic` authorization header value, present only when the URL contains a password.
basic_auth: Option<HeaderValue>,
}
impl TryFrom<url::Url> for Target {
    type Error = WsHandshakeError;

    /// Build a [`Target`] from a `ws://` (or, with the `tls` feature, `wss://`) URL.
    ///
    /// # Errors
    ///
    /// Returns [`WsHandshakeError::Url`] for an unsupported scheme, a missing host, or
    /// credentials that cannot form a valid header value.
    fn try_from(url: Url) -> Result<Self, Self::Error> {
        let _mode = match url.scheme() {
            "ws" => Mode::Plain,
            #[cfg(feature = "tls")]
            "wss" => Mode::Tls,
            invalid_scheme => {
                #[cfg(feature = "tls")]
                let err = format!("`{invalid_scheme}` not supported, expects 'ws' or 'wss'");
                #[cfg(not(feature = "tls"))]
                let err = format!("`{invalid_scheme}` not supported, expects 'ws' ('wss' requires the tls feature)");
                return Err(WsHandshakeError::Url(err.into()));
            }
        };
        let host = url.host_str().map(ToOwned::to_owned).ok_or_else(|| WsHandshakeError::Url("Invalid host".into()))?;

        let mut path_and_query = url.path().to_owned();
        if let Some(query) = url.query() {
            path_and_query.push('?');
            path_and_query.push_str(query);
        }

        let basic_auth = match url.password() {
            Some(pwd) => {
                // `Url::username`/`Url::password` are percent-encoded, but Basic credentials must
                // carry the raw octets, so decode before base64-encoding `user:password`.
                let mut credentials = percent_decode_userinfo(url.username());
                credentials.push(b':');
                credentials.extend(percent_decode_userinfo(pwd));
                let digest = base64::engine::general_purpose::STANDARD.encode(credentials);
                let val = HeaderValue::from_str(&format!("Basic {digest}"))
                    .map_err(|_| WsHandshakeError::Url("Header value `authorization basic user:pwd` invalid".into()))?;
                Some(val)
            }
            None => None,
        };

        // `Url::port` is `None` for the scheme's default port, so it is omitted from the header.
        let host_header = match url.port() {
            Some(port) => format!("{host}:{port}"),
            None => host.clone(),
        };
        Ok(Self { host, host_header, _mode, path_and_query, basic_auth })
    }
}

/// Decode `%XX` escapes in a URL userinfo component into raw bytes.
///
/// A `%` that is not followed by two hex digits is kept as a literal byte.
fn percent_decode_userinfo(input: &str) -> Vec<u8> {
    let bytes = input.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' && i + 2 < bytes.len() {
            let hi = (bytes[i + 1] as char).to_digit(16);
            let lo = (bytes[i + 2] as char).to_digit(16);
            if let (Some(hi), Some(lo)) = (hi, lo) {
                // Two hex digits give at most 0xFF, so the cast cannot truncate.
                decoded.push((hi * 16 + lo) as u8);
                i += 3;
                continue;
            }
        }
        decoded.push(bytes[i]);
        i += 1;
    }
    decoded
}
/// Build a `tokio_rustls::TlsConnector` from the configured certificate store.
///
/// `Custom` clones the supplied `rustls` configuration; `Native` uses
/// `rustls_platform_verifier::tls_config()`. No branch currently returns `Err`.
#[cfg(feature = "tls")]
fn build_tls_config(cert_store: &CertificateStore) -> Result<tokio_rustls::TlsConnector, WsHandshakeError> {
    let client_config = match cert_store {
        CertificateStore::Custom(custom) => custom.clone(),
        CertificateStore::Native => rustls_platform_verifier::tls_config(),
    };
    let shared = std::sync::Arc::new(client_config);
    Ok(tokio_rustls::TlsConnector::from(shared))
}
#[cfg(test)]
mod tests {
    use http::HeaderValue;

    use super::{Mode, Target, Url, WsHandshakeError};

    /// Assert that every field of `target` matches the expected values.
    fn assert_ws_target(
        target: Target,
        host: &str,
        host_header: &str,
        mode: Mode,
        path_and_query: &str,
        basic_auth: Option<HeaderValue>,
    ) {
        assert_eq!(&target.host, host);
        assert_eq!(&target.host_header, host_header);
        assert_eq!(target._mode, mode);
        assert_eq!(&target.path_and_query, path_and_query);
        assert_eq!(target.basic_auth, basic_auth);
    }

    /// Parse `uri` into a [`Target`], mapping URL parse failures to [`WsHandshakeError::Url`].
    fn parse_target(uri: &str) -> Result<Target, WsHandshakeError> {
        Url::parse(uri).map_err(|e| WsHandshakeError::Url(e.to_string().into()))?.try_into()
    }

    #[test]
    fn ws_works_with_port() {
        let target = parse_target("ws://127.0.0.1:9933").unwrap();
        assert_ws_target(target, "127.0.0.1", "127.0.0.1:9933", Mode::Plain, "/", None);
    }

    #[cfg(feature = "tls")]
    #[test]
    fn wss_works_with_port() {
        let target = parse_target("wss://kusama-rpc.polkadot.io:9999").unwrap();
        assert_ws_target(target, "kusama-rpc.polkadot.io", "kusama-rpc.polkadot.io:9999", Mode::Tls, "/", None);
    }

    /// `wss` must be rejected when the crate is built *without* the `tls` feature.
    #[cfg(not(feature = "tls"))]
    #[test]
    fn wss_fails_without_tls_feature() {
        let err = parse_target("wss://kusama-rpc.polkadot.io").unwrap_err();
        assert!(matches!(err, WsHandshakeError::Url(_)));
    }

    #[test]
    fn faulty_url_scheme() {
        let err = parse_target("http://kusama-rpc.polkadot.io:443").unwrap_err();
        assert!(matches!(err, WsHandshakeError::Url(_)));
    }

    #[test]
    fn faulty_port() {
        let err = parse_target("ws://127.0.0.1:-43").unwrap_err();
        assert!(matches!(err, WsHandshakeError::Url(_)));
        let err = parse_target("ws://127.0.0.1:99999").unwrap_err();
        assert!(matches!(err, WsHandshakeError::Url(_)));
    }

    #[test]
    fn url_with_path_works() {
        let target = parse_target("ws://127.0.0.1/my-special-path").unwrap();
        assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/my-special-path", None);
    }

    #[test]
    fn url_with_query_works() {
        let target = parse_target("ws://127.0.0.1/my?name1=value1&name2=value2").unwrap();
        assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/my?name1=value1&name2=value2", None);
    }

    #[test]
    fn url_with_fragment_is_ignored() {
        let target = parse_target("ws://127.0.0.1:/my.htm#ignore").unwrap();
        assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/my.htm", None);
    }

    #[cfg(feature = "tls")]
    #[test]
    fn wss_default_port_is_omitted() {
        let target = parse_target("wss://127.0.0.1:443").unwrap();
        assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Tls, "/", None);
    }

    #[test]
    fn ws_default_port_is_omitted() {
        let target = parse_target("ws://127.0.0.1:80").unwrap();
        assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/", None);
    }

    #[test]
    fn ws_with_username_and_password() {
        // `use` declarations are not inherited by child modules, so bring the trait in here.
        use base64::Engine;
        let target = parse_target("ws://user:pwd@127.0.0.1").unwrap();
        let digest = base64::engine::general_purpose::STANDARD.encode("user:pwd");
        let basic_auth = HeaderValue::from_str(&format!("Basic {digest}")).unwrap();
        assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/", Some(basic_auth));
    }
}