use std::error::Error as StdError;
use std::future::Future;
use std::net::{SocketAddr, TcpListener as StdTcpListener};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use crate::future::{ConnectionGuard, FutureDriver, ServerHandle, StopHandle};
use crate::logger::{Logger, TransportProtocol};
use crate::transport::{http, ws};
use futures_util::future::{BoxFuture, FutureExt};
use futures_util::io::{BufReader, BufWriter};
use hyper::body::HttpBody;
use jsonrpsee_core::id_providers::RandomIntegerIdProvider;
use jsonrpsee_core::server::helpers::MethodResponse;
use jsonrpsee_core::server::host_filtering::AllowHosts;
use jsonrpsee_core::server::resource_limiting::Resources;
use jsonrpsee_core::server::rpc_module::Methods;
use jsonrpsee_core::traits::IdProvider;
use jsonrpsee_core::{http_helpers, Error, TEN_MB_SIZE_BYTES};
use soketto::handshake::http::is_upgrade_request;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio::sync::{watch, OwnedSemaphorePermit};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tower::layer::util::Identity;
use tower::{Layer, Service};
use tracing::{instrument, Instrument};
/// Default cap on the number of concurrent connections (used by `Settings::default`).
const MAX_CONNECTIONS: u32 = 100;
/// A server that accepts TCP connections and serves them over HTTP and/or WebSocket.
///
/// Created with `Builder::build` or `Builder::build_from_tcp` and run with `start`.
/// `B` is the tower layer stack wrapped around each connection's `TowerService`;
/// `L` is the logger type (`()` by default).
pub struct Server<B = Identity, L = ()> {
// Listener that `start` polls for incoming connections.
listener: TcpListener,
// Configuration copied from the builder.
cfg: Settings,
// Resources registered on the builder; used to initialize `Methods` in `start`.
resources: Resources,
// Logger cloned into every connection.
logger: L,
// Id provider shared (via `Arc`) with every connection.
id_provider: Arc<dyn IdProvider>,
// Tower middleware applied to each connection's service.
service_builder: tower::ServiceBuilder<B>,
}
/// Debug output lists the listener, settings, id provider and resources.
///
/// Implemented for every `B`/`L` combination. The layer stack and the logger are not printed,
/// so neither type parameter needs to implement `Debug`. Naming both parameters matters: a
/// single `impl<L> ... for Server<L>` would bind the *layer* slot and leave the logger at its
/// default `()`, so servers with a custom logger would have no `Debug` impl.
impl<B, L> std::fmt::Debug for Server<B, L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Server")
            .field("listener", &self.listener)
            .field("cfg", &self.cfg)
            .field("id_provider", &self.id_provider)
            .field("resources", &self.resources)
            .finish()
    }
}
impl<B, L> Server<B, L> {
    /// Returns the local socket address the server's listener is bound to.
    ///
    /// # Errors
    /// Propagates the I/O error from querying the listener, converted into [`Error`].
    pub fn local_addr(&self) -> Result<SocketAddr, Error> {
        let addr = self.listener.local_addr()?;
        Ok(addr)
    }
}
impl<B, U, L> Server<B, L>
where
L: Logger,
B: Layer<TowerService<L>> + Send + 'static,
<B as Layer<TowerService<L>>>::Service: Send
+ Service<
hyper::Request<hyper::Body>,
Response = hyper::Response<U>,
Error = Box<(dyn StdError + Send + Sync + 'static)>,
>,
<<B as Layer<TowerService<L>>>::Service as Service<hyper::Request<hyper::Body>>>::Future: Send,
U: HttpBody + Send + 'static,
<U as HttpBody>::Error: Send + Sync + StdError,
<U as HttpBody>::Data: Send,
{
pub fn start(mut self, methods: impl Into<Methods>) -> Result<ServerHandle, Error> {
let methods = methods.into().initialize_resources(&self.resources)?;
let (stop_tx, stop_rx) = watch::channel(());
let stop_handle = StopHandle::new(stop_rx);
match self.cfg.tokio_runtime.take() {
Some(rt) => rt.spawn(self.start_inner(methods, stop_handle)),
None => tokio::spawn(self.start_inner(methods, stop_handle)),
};
Ok(ServerHandle::new(stop_tx))
}
async fn start_inner(self, methods: Methods, stop_handle: StopHandle) {
let max_request_body_size = self.cfg.max_request_body_size;
let max_response_body_size = self.cfg.max_response_body_size;
let max_log_length = self.cfg.max_log_length;
let allow_hosts = self.cfg.allow_hosts;
let resources = self.resources;
let logger = self.logger;
let batch_requests_supported = self.cfg.batch_requests_supported;
let id_provider = self.id_provider;
let max_subscriptions_per_connection = self.cfg.max_subscriptions_per_connection;
let mut id: u32 = 0;
let connection_guard = ConnectionGuard::new(self.cfg.max_connections as usize);
let mut connections = FutureDriver::default();
let mut incoming = Monitored::new(Incoming(self.listener), &stop_handle);
loop {
match connections.select_with(&mut incoming).await {
Ok((socket, remote_addr)) => {
let data = ProcessConnection {
remote_addr,
methods: methods.clone(),
allow_hosts: allow_hosts.clone(),
resources: resources.clone(),
max_request_body_size,
max_response_body_size,
max_log_length,
batch_requests_supported,
id_provider: id_provider.clone(),
ping_interval: self.cfg.ping_interval,
stop_handle: stop_handle.clone(),
max_subscriptions_per_connection,
conn_id: id,
logger: logger.clone(),
max_connections: self.cfg.max_connections,
enable_http: self.cfg.enable_http,
enable_ws: self.cfg.enable_ws,
};
process_connection(&self.service_builder, &connection_guard, data, socket, &mut connections);
id = id.wrapping_add(1);
}
Err(MonitoredError::Selector(err)) => {
tracing::error!("Error while awaiting a new connection: {:?}", err);
}
Err(MonitoredError::Shutdown) => break,
}
}
connections.await;
}
}
/// Server configuration assembled by `Builder` and stored in `Server`.
#[derive(Debug, Clone)]
struct Settings {
/// Maximum request body size in bytes (default: `TEN_MB_SIZE_BYTES`).
max_request_body_size: u32,
/// Maximum response body size in bytes (default: `TEN_MB_SIZE_BYTES`).
max_response_body_size: u32,
/// Maximum concurrent connections enforced by the `ConnectionGuard` (default: `MAX_CONNECTIONS`).
max_connections: u32,
/// Forwarded to each connection; presumably caps active subscriptions per connection (default: 1024).
max_subscriptions_per_connection: u32,
/// Forwarded to the transports; presumably limits how much of a message is logged (default: 4096).
max_log_length: u32,
/// Host-header filter applied to every request (default: `AllowHosts::Any`).
allow_hosts: AllowHosts,
/// Forwarded to the HTTP and WebSocket handlers (default: `true`).
batch_requests_supported: bool,
/// Runtime used to spawn the accept loop; `None` means the ambient runtime (`tokio::spawn`).
tokio_runtime: Option<tokio::runtime::Handle>,
/// Forwarded to WebSocket connections; presumably the ping period (default: 60s).
ping_interval: Duration,
/// Whether plain HTTP (non-upgrade) requests are served (default: `true`).
enable_http: bool,
/// Whether WebSocket upgrade requests are served (default: `true`).
enable_ws: bool,
}
impl Default for Settings {
fn default() -> Self {
Self {
max_request_body_size: TEN_MB_SIZE_BYTES,
max_response_body_size: TEN_MB_SIZE_BYTES,
max_log_length: 4096,
max_subscriptions_per_connection: 1024,
max_connections: MAX_CONNECTIONS,
batch_requests_supported: true,
allow_hosts: AllowHosts::Any,
tokio_runtime: None,
ping_interval: Duration::from_secs(60),
enable_http: true,
enable_ws: true,
}
}
}
/// Builder for [`Server`].
///
/// `B` is the tower layer stack (set via `set_middleware`) and `L` the logger type
/// (set via `set_logger`).
#[derive(Debug)]
pub struct Builder<B = Identity, L = ()> {
settings: Settings,
resources: Resources,
logger: L,
id_provider: Arc<dyn IdProvider>,
service_builder: tower::ServiceBuilder<B>,
}
impl Default for Builder {
fn default() -> Self {
Builder {
settings: Settings::default(),
resources: Resources::default(),
logger: (),
id_provider: Arc::new(RandomIntegerIdProvider),
service_builder: tower::ServiceBuilder::new(),
}
}
}
impl Builder {
    /// Creates a builder with default configuration; equivalent to `Builder::default()`.
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
impl<B, L> Builder<B, L> {
pub fn max_request_body_size(mut self, size: u32) -> Self {
self.settings.max_request_body_size = size;
self
}
pub fn max_response_body_size(mut self, size: u32) -> Self {
self.settings.max_response_body_size = size;
self
}
pub fn max_connections(mut self, max: u32) -> Self {
self.settings.max_connections = max;
self
}
pub fn batch_requests_supported(mut self, supported: bool) -> Self {
self.settings.batch_requests_supported = supported;
self
}
pub fn max_subscriptions_per_connection(mut self, max: u32) -> Self {
self.settings.max_subscriptions_per_connection = max;
self
}
pub fn register_resource(mut self, label: &'static str, capacity: u16, default: u16) -> Result<Self, Error> {
self.resources.register(label, capacity, default)?;
Ok(self)
}
pub fn set_logger<T: Logger>(self, logger: T) -> Builder<B, T> {
Builder {
settings: self.settings,
resources: self.resources,
logger,
id_provider: self.id_provider,
service_builder: self.service_builder,
}
}
pub fn custom_tokio_runtime(mut self, rt: tokio::runtime::Handle) -> Self {
self.settings.tokio_runtime = Some(rt);
self
}
pub fn ping_interval(mut self, interval: Duration) -> Self {
self.settings.ping_interval = interval;
self
}
pub fn set_id_provider<I: IdProvider + 'static>(mut self, id_provider: I) -> Self {
self.id_provider = Arc::new(id_provider);
self
}
pub fn set_host_filtering(mut self, allow: AllowHosts) -> Self {
self.settings.allow_hosts = allow;
self
}
pub fn set_middleware<T>(self, service_builder: tower::ServiceBuilder<T>) -> Builder<T, L> {
Builder {
settings: self.settings,
resources: self.resources,
logger: self.logger,
id_provider: self.id_provider,
service_builder,
}
}
pub fn http_only(mut self) -> Self {
self.settings.enable_http = true;
self.settings.enable_ws = false;
self
}
pub fn ws_only(mut self) -> Self {
self.settings.enable_http = false;
self.settings.enable_ws = true;
self
}
pub async fn build(self, addrs: impl ToSocketAddrs) -> Result<Server<B, L>, Error> {
let listener = TcpListener::bind(addrs).await?;
Ok(Server {
listener,
cfg: self.settings,
resources: self.resources,
logger: self.logger,
id_provider: self.id_provider,
service_builder: self.service_builder,
})
}
pub fn build_from_tcp(self, listener: impl Into<StdTcpListener>) -> Result<Server<B, L>, Error> {
let listener = TcpListener::from_std(listener.into())?;
Ok(Server {
listener,
cfg: self.settings,
resources: self.resources,
logger: self.logger,
id_provider: self.id_provider,
service_builder: self.service_builder,
})
}
}
/// A method call's response tagged with how it should be handled.
///
/// NOTE(review): the variant names suggest `JustLogger` responses are only logged while
/// `SendAndLogger` responses are also sent to the client — confirm against callers.
pub(crate) enum MethodResult {
JustLogger(MethodResponse),
SendAndLogger(MethodResponse),
}
impl MethodResult {
    /// Borrows the wrapped response, whichever variant holds it.
    pub(crate) fn as_inner(&self) -> &MethodResponse {
        match self {
            Self::JustLogger(inner) | Self::SendAndLogger(inner) => inner,
        }
    }

    /// Consumes the wrapper and returns the response, whichever variant held it.
    pub(crate) fn into_inner(self) -> MethodResponse {
        match self {
            Self::JustLogger(inner) | Self::SendAndLogger(inner) => inner,
        }
    }
}
/// Per-connection state shared by every request served on that connection.
#[derive(Debug, Clone)]
pub(crate) struct ServiceData<L: Logger> {
/// Address of the connected peer.
pub(crate) remote_addr: SocketAddr,
/// RPC methods served on this connection.
pub(crate) methods: Methods,
/// Host-header filter checked in `TowerService::call`.
pub(crate) allow_hosts: AllowHosts,
pub(crate) resources: Resources,
pub(crate) max_request_body_size: u32,
pub(crate) max_response_body_size: u32,
pub(crate) max_log_length: u32,
pub(crate) batch_requests_supported: bool,
pub(crate) id_provider: Arc<dyn IdProvider>,
pub(crate) ping_interval: Duration,
/// Server-wide stop signal.
pub(crate) stop_handle: StopHandle,
pub(crate) max_subscriptions_per_connection: u32,
/// Informational connection id assigned by the accept loop (wraps on overflow).
pub(crate) conn_id: u32,
pub(crate) logger: L,
/// Permit from `ConnectionGuard::try_acquire`, shared via `Arc` so request handlers can hold it.
pub(crate) conn: Arc<OwnedSemaphorePermit>,
/// Whether plain HTTP requests are served.
pub(crate) enable_http: bool,
/// Whether WebSocket upgrade requests are served.
pub(crate) enable_ws: bool,
}
/// Per-connection hyper service that routes each request to the WebSocket or HTTP transport.
#[derive(Debug, Clone)]
pub struct TowerService<L: Logger> {
inner: ServiceData<L>,
}
impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerService<L> {
    type Response = hyper::Response<hyper::Body>;
    // The boxed error type the surrounding tower/hyper plumbing expects.
    type Error = Box<dyn StdError + Send + Sync + 'static>;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always ready: the service holds no per-call capacity of its own.
    fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Handles one request on this connection.
    ///
    /// 1. Reads the host from the `Host` header (or, for HTTP/2 without one, from the URI) and
    ///    answers `malformed` when there is none.
    /// 2. Answers `host_not_allowed` when the host fails the `allow_hosts` check.
    /// 3. If WebSocket is enabled and this is an upgrade request, performs the handshake and
    ///    spawns the WebSocket background task on the upgraded stream.
    /// 4. Otherwise, if HTTP is enabled and this is not an upgrade request, handles it as HTTP.
    /// 5. Otherwise answers `denied`.
    fn call(&mut self, request: hyper::Request<hyper::Body>) -> Self::Future {
        tracing::trace!("{:?}", request);

        let host = match http_helpers::read_header_value(request.headers(), hyper::header::HOST) {
            Some(host) => host,
            // HTTP/2 requests may carry the authority in the URI instead of a `Host` header.
            None if request.version() == hyper::Version::HTTP_2 => match request.uri().host() {
                Some(host) => host,
                None => return async move { Ok(http::response::malformed()) }.boxed(),
            },
            None => return async move { Ok(http::response::malformed()) }.boxed(),
        };

        if let Err(e) = self.inner.allow_hosts.verify(host) {
            tracing::warn!("Denied request: {}", e);
            return async { Ok(http::response::host_not_allowed()) }.boxed();
        }

        let is_upgrade_request = is_upgrade_request(&request);

        if self.inner.enable_ws && is_upgrade_request {
            let mut server = soketto::handshake::http::Server::new();

            let response = match server.receive_request(&request) {
                Ok(response) => {
                    self.inner.logger.on_connect(self.inner.remote_addr, &request, TransportProtocol::WebSocket);
                    let data = self.inner.clone();

                    // The upgraded stream only becomes available after hyper has sent the
                    // handshake response returned below, so finish the upgrade in a task.
                    tokio::spawn(
                        async move {
                            let upgraded = match hyper::upgrade::on(request).await {
                                Ok(u) => u,
                                Err(e) => {
                                    tracing::warn!("Could not upgrade connection: {}", e);
                                    return;
                                }
                            };

                            let stream = BufReader::new(BufWriter::new(upgraded.compat()));
                            let mut ws_builder = server.into_builder(stream);
                            ws_builder.set_max_message_size(data.max_request_body_size as usize);
                            let (sender, receiver) = ws_builder.finish();

                            let _ = ws::background_task::<L>(sender, receiver, data).await;
                        }
                        .in_current_span(),
                    );

                    response.map(|()| hyper::Body::empty())
                }
                Err(e) => {
                    tracing::error!("Could not upgrade connection: {}", e);
                    let mut response =
                        hyper::Response::new(hyper::Body::from(format!("Could not upgrade connection: {}", e)));
                    // RFC 6455 §4.2.1: an invalid handshake must be answered with an error
                    // status (e.g. 400 Bad Request), not `Response::new`'s default 200 OK.
                    *response.status_mut() = hyper::StatusCode::BAD_REQUEST;
                    response
                }
            };

            async { Ok(response) }.boxed()
        } else if self.inner.enable_http && !is_upgrade_request {
            let data = http::HandleRequest {
                methods: self.inner.methods.clone(),
                resources: self.inner.resources.clone(),
                max_request_body_size: self.inner.max_request_body_size,
                max_response_body_size: self.inner.max_response_body_size,
                max_log_length: self.inner.max_log_length,
                batch_requests_supported: self.inner.batch_requests_supported,
                logger: self.inner.logger.clone(),
                conn: self.inner.conn.clone(),
                remote_addr: self.inner.remote_addr,
            };

            self.inner.logger.on_connect(self.inner.remote_addr, &request, TransportProtocol::Http);

            Box::pin(http::handle_request(request, data).map(Ok))
        } else {
            Box::pin(async { http::response::denied() }.map(Ok))
        }
    }
}
/// Wraps a pollable value and checks `stop_monitor` before every poll.
///
/// Once shutdown is requested, polling resolves to `MonitoredError::Shutdown`
/// without polling the inner value.
struct Monitored<'a, F> {
future: F,
stop_monitor: &'a StopHandle,
}
impl<'a, F> Monitored<'a, F> {
fn new(future: F, stop_monitor: &'a StopHandle) -> Self {
Monitored { future, stop_monitor }
}
}
/// Failure outcome of polling a `Monitored` value.
enum MonitoredError<E> {
/// The stop handle reported that shutdown was requested.
Shutdown,
/// The inner value itself produced an error.
Selector(E),
}
/// Newtype over the tokio listener so `Monitored<Incoming>` can poll `accept` as a future.
struct Incoming(TcpListener);
impl<'a> Future for Monitored<'a, Incoming> {
type Output = Result<(TcpStream, SocketAddr), MonitoredError<std::io::Error>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = Pin::into_inner(self);
if this.stop_monitor.shutdown_requested() {
return Poll::Ready(Err(MonitoredError::Shutdown));
}
this.future.0.poll_accept(cx).map_err(MonitoredError::Selector)
}
}
impl<'a, 'f, F, T, E> Future for Monitored<'a, Pin<&'f mut F>>
where
F: Future<Output = Result<T, E>>,
{
type Output = Result<T, MonitoredError<E>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = Pin::into_inner(self);
if this.stop_monitor.shutdown_requested() {
return Poll::Ready(Err(MonitoredError::Shutdown));
}
this.future.poll_unpin(cx).map_err(MonitoredError::Selector)
}
}
/// Everything `process_connection` needs to set up one accepted connection.
///
/// Mirrors `ServiceData`. It additionally carries `max_connections` (used only for the
/// "Accepting new connection" log line) and lacks the connection permit, which is acquired
/// inside `process_connection`.
struct ProcessConnection<L> {
remote_addr: SocketAddr,
methods: Methods,
allow_hosts: AllowHosts,
resources: Resources,
max_request_body_size: u32,
max_response_body_size: u32,
max_log_length: u32,
batch_requests_supported: bool,
id_provider: Arc<dyn IdProvider>,
ping_interval: Duration,
stop_handle: StopHandle,
max_subscriptions_per_connection: u32,
max_connections: u32,
conn_id: u32,
logger: L,
enable_http: bool,
enable_ws: bool,
}
#[instrument(name = "connection", skip_all, fields(remote_addr = %cfg.remote_addr, conn_id = %cfg.conn_id), level = "INFO")]
fn process_connection<'a, L: Logger, B, U>(
service_builder: &tower::ServiceBuilder<B>,
connection_guard: &ConnectionGuard,
cfg: ProcessConnection<L>,
socket: TcpStream,
connections: &mut FutureDriver<BoxFuture<'a, ()>>,
) where
B: Layer<TowerService<L>> + Send + 'static,
<B as Layer<TowerService<L>>>::Service: Send
+ Service<
hyper::Request<hyper::Body>,
Response = hyper::Response<U>,
Error = Box<(dyn StdError + Send + Sync + 'static)>,
>,
<<B as Layer<TowerService<L>>>::Service as Service<hyper::Request<hyper::Body>>>::Future: Send,
U: HttpBody + Send + 'static,
<U as HttpBody>::Error: Send + Sync + StdError,
<U as HttpBody>::Data: Send,
{
if let Err(e) = socket.set_nodelay(true) {
tracing::warn!("Could not set NODELAY on socket: {:?}", e);
return;
}
let conn = match connection_guard.try_acquire() {
Some(conn) => conn,
None => {
tracing::warn!("Too many connections. Please try again later.");
connections.add(http::reject_connection(socket).in_current_span().boxed());
return;
}
};
let max_conns = cfg.max_connections as usize;
let curr_conns = max_conns - connection_guard.available_connections();
tracing::info!("Accepting new connection {}/{}", curr_conns, max_conns);
let tower_service = TowerService {
inner: ServiceData {
remote_addr: cfg.remote_addr,
methods: cfg.methods,
allow_hosts: cfg.allow_hosts,
resources: cfg.resources,
max_request_body_size: cfg.max_request_body_size,
max_response_body_size: cfg.max_response_body_size,
max_log_length: cfg.max_log_length,
batch_requests_supported: cfg.batch_requests_supported,
id_provider: cfg.id_provider,
ping_interval: cfg.ping_interval,
stop_handle: cfg.stop_handle.clone(),
max_subscriptions_per_connection: cfg.max_subscriptions_per_connection,
conn_id: cfg.conn_id,
logger: cfg.logger,
conn: Arc::new(conn),
enable_http: cfg.enable_http,
enable_ws: cfg.enable_ws,
},
};
let service = service_builder.service(tower_service);
connections.add(Box::pin(try_accept_connection(socket, service, cfg.stop_handle).in_current_span()));
}
/// Serves `socket` with `service` (HTTP, with upgrades enabled) until the connection ends
/// or the server is stopped.
///
/// When stop is requested, the connection is asked to shut down gracefully. It is then
/// driven to completion so that requests already in flight can finish. Connection errors
/// are logged, not returned.
async fn try_accept_connection<S, Bd>(socket: TcpStream, service: S, mut stop_handle: StopHandle)
where
    S: Service<hyper::Request<hyper::Body>, Response = hyper::Response<Bd>> + Send + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: Send,
    Bd: HttpBody + Send + 'static,
    <Bd as HttpBody>::Error: Send + Sync + StdError,
    <Bd as HttpBody>::Data: Send,
{
    let conn = hyper::server::conn::Http::new().serve_connection(socket, service).with_upgrades();

    tokio::pin!(conn);

    tokio::select! {
        res = &mut conn => {
            if let Err(e) = res {
                tracing::warn!("HTTP serve connection failed {:?}", e);
            }
        }
        _ = stop_handle.shutdown() => {
            // `graceful_shutdown` only *starts* the shutdown. Hyper requires the connection to
            // keep being polled until it completes; returning here would drop it and abort
            // in-flight requests.
            conn.as_mut().graceful_shutdown();
            if let Err(e) = conn.await {
                tracing::warn!("HTTP serve connection failed {:?}", e);
            }
        }
    }
}