1use std::error::Error as StdError;
28use std::future::Future;
29use std::net::{SocketAddr, TcpListener as StdTcpListener};
30use std::pin::Pin;
31use std::sync::atomic::AtomicU32;
32use std::sync::Arc;
33use std::task::Poll;
34use std::time::Duration;
35
36use crate::future::{session_close, ConnectionGuard, ServerHandle, SessionClose, SessionClosedFuture, StopHandle};
37use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT};
38use crate::transport::ws::BackgroundTaskParams;
39use crate::transport::{http, ws};
40use crate::utils::deserialize;
41use crate::{Extensions, HttpBody, HttpRequest, HttpResponse, LOG_TARGET};
42
43use futures_util::future::{self, Either, FutureExt};
44use futures_util::io::{BufReader, BufWriter};
45
46use hyper::body::Bytes;
47use hyper_util::rt::{TokioExecutor, TokioIo};
48use jsonrpsee_core::id_providers::RandomIntegerIdProvider;
49use jsonrpsee_core::server::helpers::prepare_error;
50use jsonrpsee_core::server::{
51 BatchResponseBuilder, BoundedSubscriptions, ConnectionId, MethodResponse, MethodSink, Methods,
52};
53use jsonrpsee_core::traits::IdProvider;
54use jsonrpsee_core::{BoxError, JsonRawValue, TEN_MB_SIZE_BYTES};
55
56use jsonrpsee_types::error::{
57 reject_too_big_batch_request, ErrorCode, BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG,
58};
59use jsonrpsee_types::{ErrorObject, Id, InvalidRequest, Notification};
60use soketto::handshake::http::is_upgrade_request;
61use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
62use tokio::sync::{mpsc, watch, OwnedSemaphorePermit};
63use tokio_util::compat::TokioAsyncReadCompatExt;
64use tower::layer::util::Identity;
65use tower::{Layer, Service};
66use tracing::{instrument, Instrument};
67
/// A JSON-RPC notification whose `params` (if any) are kept as an unparsed raw JSON value.
///
/// Only used by `handle_rpc_call` to recognise notifications; the parsed value is discarded.
type Notif<'a> = Notification<'a, Option<&'a JsonRawValue>>;

/// Default maximum number of concurrent connections, used by `ServerConfig::default`.
const MAX_CONNECTIONS: u32 = 100;
72
73pub struct Server<HttpMiddleware = Identity, RpcMiddleware = Identity> {
75 listener: TcpListener,
76 server_cfg: ServerConfig,
77 rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
78 http_middleware: tower::ServiceBuilder<HttpMiddleware>,
79}
80
81impl Server<Identity, Identity> {
82 pub fn builder() -> Builder<Identity, Identity> {
84 Builder::new()
85 }
86}
87
88impl<RpcMiddleware, HttpMiddleware> std::fmt::Debug for Server<RpcMiddleware, HttpMiddleware> {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("Server").field("listener", &self.listener).field("server_cfg", &self.server_cfg).finish()
91 }
92}
93
94impl<RpcMiddleware, HttpMiddleware> Server<RpcMiddleware, HttpMiddleware> {
95 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
97 self.listener.local_addr()
98 }
99}
100
101impl<HttpMiddleware, RpcMiddleware, Body> Server<HttpMiddleware, RpcMiddleware>
102where
103 RpcMiddleware: tower::Layer<RpcService> + Clone + Send + 'static,
104 for<'a> <RpcMiddleware as Layer<RpcService>>::Service: RpcServiceT<'a>,
105 HttpMiddleware: Layer<TowerServiceNoHttp<RpcMiddleware>> + Send + 'static,
106 <HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service:
107 Send + Clone + Service<HttpRequest, Response = HttpResponse<Body>, Error = BoxError>,
108 <<HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service as Service<HttpRequest>>::Future: Send,
109 Body: http_body::Body<Data = Bytes> + Send + 'static,
110 <Body as http_body::Body>::Error: Into<BoxError>,
111 <Body as http_body::Body>::Data: Send,
112{
113 pub fn start(mut self, methods: impl Into<Methods>) -> ServerHandle {
117 let methods = methods.into();
118 let (stop_tx, stop_rx) = watch::channel(());
119
120 let stop_handle = StopHandle::new(stop_rx);
121
122 match self.server_cfg.tokio_runtime.take() {
123 Some(rt) => rt.spawn(self.start_inner(methods, stop_handle)),
124 None => tokio::spawn(self.start_inner(methods, stop_handle)),
125 };
126
127 ServerHandle::new(stop_tx)
128 }
129
130 async fn start_inner(self, methods: Methods, stop_handle: StopHandle) {
131 let mut id: u32 = 0;
132 let connection_guard = ConnectionGuard::new(self.server_cfg.max_connections as usize);
133 let listener = self.listener;
134
135 let stopped = stop_handle.clone().shutdown();
136 tokio::pin!(stopped);
137
138 let (drop_on_completion, mut process_connection_awaiter) = mpsc::channel::<()>(1);
139
140 loop {
141 match try_accept_conn(&listener, stopped).await {
142 AcceptConnection::Established { socket, remote_addr, stop } => {
143 process_connection(ProcessConnection {
144 http_middleware: &self.http_middleware,
145 rpc_middleware: self.rpc_middleware.clone(),
146 remote_addr,
147 methods: methods.clone(),
148 stop_handle: stop_handle.clone(),
149 conn_id: id,
150 server_cfg: self.server_cfg.clone(),
151 conn_guard: &connection_guard,
152 socket,
153 drop_on_completion: drop_on_completion.clone(),
154 });
155 id = id.wrapping_add(1);
156 stopped = stop;
157 }
158 AcceptConnection::Err((e, stop)) => {
159 tracing::debug!(target: LOG_TARGET, "Error while awaiting a new connection: {:?}", e);
160 stopped = stop;
161 }
162 AcceptConnection::Shutdown => break,
163 }
164 }
165
166 drop(drop_on_completion);
168
169 while process_connection_awaiter.recv().await.is_some() {
171 }
174 }
175}
176
/// Server-wide configuration, cloned into every connection/service.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Maximum request size in bytes; used as the WebSocket max message size and passed to the HTTP handler.
    pub(crate) max_request_body_size: u32,
    /// Maximum response size in bytes; passed to every `RpcService` and to the HTTP handler.
    pub(crate) max_response_body_size: u32,
    /// Maximum number of concurrent connections; sizes the `ConnectionGuard` of the server.
    pub(crate) max_connections: u32,
    /// Bound handed to `BoundedSubscriptions` for each WebSocket connection.
    pub(crate) max_subscriptions_per_connection: u32,
    /// Policy for JSON-RPC batch requests.
    pub(crate) batch_requests_config: BatchRequestConfig,
    /// Optional runtime to spawn the accept loop on; taken (left `None`) by `Server::start`.
    pub(crate) tokio_runtime: Option<tokio::runtime::Handle>,
    /// Whether plain (non-upgrade) HTTP requests are served.
    pub(crate) enable_http: bool,
    /// Whether WebSocket upgrade requests are accepted.
    pub(crate) enable_ws: bool,
    /// Capacity of the outgoing message channel of each WebSocket connection.
    pub(crate) message_buffer_capacity: u32,
    /// WebSocket ping settings; `None` presumably disables pings (handled by the ws transport — confirm there).
    pub(crate) ping_config: Option<PingConfig>,
    /// Id provider handed to WebSocket RPC services (presumably for subscription ids — confirm).
    pub(crate) id_provider: Arc<dyn IdProvider>,
    /// Value passed to `set_nodelay` on every accepted TCP socket.
    pub(crate) tcp_no_delay: bool,
}
205
/// Builder for [`ServerConfig`].
///
/// Holds the same settings as [`ServerConfig`] except `tokio_runtime` and `tcp_no_delay`,
/// which have no corresponding field here.
#[derive(Debug, Clone)]
pub struct ServerConfigBuilder {
    /// Maximum request body size in bytes.
    max_request_body_size: u32,
    /// Maximum response body size in bytes.
    max_response_body_size: u32,
    /// Maximum number of concurrent connections.
    max_connections: u32,
    /// Maximum number of subscriptions per connection.
    max_subscriptions_per_connection: u32,
    /// Batch request policy.
    batch_requests_config: BatchRequestConfig,
    /// Whether plain HTTP is served.
    enable_http: bool,
    /// Whether WebSocket upgrades are accepted.
    enable_ws: bool,
    /// WebSocket outgoing message buffer capacity.
    message_buffer_capacity: u32,
    /// WebSocket ping settings, if enabled.
    ping_config: Option<PingConfig>,
    /// Id provider for WebSocket RPC services.
    id_provider: Arc<dyn IdProvider>,
}
229
/// Builder for [`TowerService`] values, obtained from `Builder::to_service_builder`.
#[derive(Debug, Clone)]
pub struct TowerServiceBuilder<RpcMiddleware, HttpMiddleware> {
    /// Configuration moved into each built service.
    pub(crate) server_cfg: ServerConfig,
    /// RPC middleware moved into each built service.
    pub(crate) rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
    /// HTTP middleware moved into each built service.
    pub(crate) http_middleware: tower::ServiceBuilder<HttpMiddleware>,
    /// Counter each `build` call draws its connection id from; shared by clones via `Arc`.
    pub(crate) conn_id: Arc<AtomicU32>,
    /// Connection limiter handed to built services (presumably shared across clones — confirm in `ConnectionGuard`).
    pub(crate) conn_guard: ConnectionGuard,
}
244
/// Policy for JSON-RPC batch requests.
#[derive(Debug, Copy, Clone)]
pub enum BatchRequestConfig {
    /// Every batch is answered with a "batches not supported" error.
    Disabled,
    /// Batches with more than this many entries are rejected.
    Limit(u32),
    /// Batches of any length are accepted.
    Unlimited,
}
255
256#[derive(Debug, Clone)]
259pub struct ConnectionState {
260 pub(crate) stop_handle: StopHandle,
262 pub(crate) conn_id: u32,
264 pub(crate) _conn_permit: Arc<OwnedSemaphorePermit>,
266}
267
268impl ConnectionState {
269 pub fn new(stop_handle: StopHandle, conn_id: u32, conn_permit: OwnedSemaphorePermit) -> ConnectionState {
271 Self { stop_handle, conn_id, _conn_permit: Arc::new(conn_permit) }
272 }
273}
274
/// WebSocket ping settings for server connections.
///
/// Defaults: a 30 second ping interval, a 40 second inactivity limit and 1 allowed failure.
/// How these values are enforced is up to the WebSocket transport.
#[derive(Debug, Copy, Clone)]
pub struct PingConfig {
    /// Time between two pings.
    pub(crate) ping_interval: Duration,
    /// Inactivity limit used by the WebSocket transport.
    pub(crate) inactive_limit: Duration,
    /// Number of failures tolerated by the WebSocket transport; always at least 1.
    pub(crate) max_failures: usize,
}

impl Default for PingConfig {
    fn default() -> Self {
        PingConfig {
            ping_interval: Duration::from_secs(30),
            inactive_limit: Duration::from_secs(40),
            max_failures: 1,
        }
    }
}

impl PingConfig {
    /// Same as [`PingConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Set the interval between pings.
    pub fn ping_interval(self, ping_interval: Duration) -> Self {
        PingConfig { ping_interval, ..self }
    }

    /// Set the inactivity limit.
    pub fn inactive_limit(self, inactivity_limit: Duration) -> Self {
        PingConfig { inactive_limit: inactivity_limit, ..self }
    }

    /// Set the maximum number of tolerated failures.
    ///
    /// # Panics
    ///
    /// Panics if `max` is zero.
    pub fn max_failures(self, max: usize) -> Self {
        assert!(max > 0);
        PingConfig { max_failures: max, ..self }
    }
}
337
338impl Default for ServerConfig {
339 fn default() -> Self {
340 Self {
341 max_request_body_size: TEN_MB_SIZE_BYTES,
342 max_response_body_size: TEN_MB_SIZE_BYTES,
343 max_connections: MAX_CONNECTIONS,
344 max_subscriptions_per_connection: 1024,
345 batch_requests_config: BatchRequestConfig::Unlimited,
346 tokio_runtime: None,
347 enable_http: true,
348 enable_ws: true,
349 message_buffer_capacity: 1024,
350 ping_config: None,
351 id_provider: Arc::new(RandomIntegerIdProvider),
352 tcp_no_delay: true,
353 }
354 }
355}
356
357impl ServerConfig {
358 pub fn builder() -> ServerConfigBuilder {
360 ServerConfigBuilder::default()
361 }
362}
363
364impl Default for ServerConfigBuilder {
365 fn default() -> Self {
366 let this = ServerConfig::default();
367
368 ServerConfigBuilder {
369 max_request_body_size: this.max_request_body_size,
370 max_response_body_size: this.max_response_body_size,
371 max_connections: this.max_connections,
372 max_subscriptions_per_connection: this.max_subscriptions_per_connection,
373 batch_requests_config: this.batch_requests_config,
374 enable_http: this.enable_http,
375 enable_ws: this.enable_ws,
376 message_buffer_capacity: this.message_buffer_capacity,
377 ping_config: this.ping_config,
378 id_provider: this.id_provider,
379 }
380 }
381}
382
383impl ServerConfigBuilder {
384 pub fn new() -> Self {
386 Self::default()
387 }
388
389 pub fn max_request_body_size(mut self, size: u32) -> Self {
391 self.max_request_body_size = size;
392 self
393 }
394
395 pub fn max_response_body_size(mut self, size: u32) -> Self {
397 self.max_response_body_size = size;
398 self
399 }
400
401 pub fn max_connections(mut self, max: u32) -> Self {
403 self.max_connections = max;
404 self
405 }
406
407 pub fn set_batch_request_config(mut self, cfg: BatchRequestConfig) -> Self {
409 self.batch_requests_config = cfg;
410 self
411 }
412
413 pub fn max_subscriptions_per_connection(mut self, max: u32) -> Self {
415 self.max_subscriptions_per_connection = max;
416 self
417 }
418
419 pub fn http_only(mut self) -> Self {
421 self.enable_http = true;
422 self.enable_ws = false;
423 self
424 }
425
426 pub fn ws_only(mut self) -> Self {
428 self.enable_http = false;
429 self.enable_ws = true;
430 self
431 }
432
433 pub fn set_message_buffer_capacity(mut self, c: u32) -> Self {
435 self.message_buffer_capacity = c;
436 self
437 }
438
439 pub fn enable_ws_ping(mut self, config: PingConfig) -> Self {
441 self.ping_config = Some(config);
442 self
443 }
444
445 pub fn disable_ws_ping(mut self) -> Self {
447 self.ping_config = None;
448 self
449 }
450
451 pub fn set_id_provider<I: IdProvider + 'static>(mut self, id_provider: I) -> Self {
453 self.id_provider = Arc::new(id_provider);
454 self
455 }
456}
457
/// Builder for [`Server`] (and, via `to_service_builder`, for [`TowerServiceBuilder`]).
#[derive(Debug)]
pub struct Builder<HttpMiddleware, RpcMiddleware> {
    /// Configuration being assembled.
    server_cfg: ServerConfig,
    /// RPC middleware applied around each `RpcService`.
    rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
    /// HTTP middleware applied to each connection's service.
    http_middleware: tower::ServiceBuilder<HttpMiddleware>,
}
465
466impl Default for Builder<Identity, Identity> {
467 fn default() -> Self {
468 Builder {
469 server_cfg: ServerConfig::default(),
470 rpc_middleware: RpcServiceBuilder::new(),
471 http_middleware: tower::ServiceBuilder::new(),
472 }
473 }
474}
475
476impl Builder<Identity, Identity> {
477 pub fn new() -> Self {
479 Self::default()
480 }
481}
482
483impl<RpcMiddleware, HttpMiddleware> TowerServiceBuilder<RpcMiddleware, HttpMiddleware> {
484 pub fn build(
486 self,
487 methods: impl Into<Methods>,
488 stop_handle: StopHandle,
489 ) -> TowerService<RpcMiddleware, HttpMiddleware> {
490 let conn_id = self.conn_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
491
492 let rpc_middleware = TowerServiceNoHttp {
493 rpc_middleware: self.rpc_middleware,
494 inner: ServiceData {
495 methods: methods.into(),
496 stop_handle,
497 conn_id,
498 conn_guard: self.conn_guard,
499 server_cfg: self.server_cfg,
500 },
501 on_session_close: None,
502 };
503
504 TowerService { rpc_middleware, http_middleware: self.http_middleware }
505 }
506
507 pub fn connection_id(mut self, id: u32) -> Self {
511 self.conn_id = Arc::new(AtomicU32::new(id));
512 self
513 }
514
515 pub fn max_connections(mut self, limit: u32) -> Self {
517 self.conn_guard = ConnectionGuard::new(limit as usize);
518 self
519 }
520
521 pub fn set_rpc_middleware<T>(self, rpc_middleware: RpcServiceBuilder<T>) -> TowerServiceBuilder<T, HttpMiddleware> {
523 TowerServiceBuilder {
524 server_cfg: self.server_cfg,
525 rpc_middleware,
526 http_middleware: self.http_middleware,
527 conn_id: self.conn_id,
528 conn_guard: self.conn_guard,
529 }
530 }
531
532 pub fn set_http_middleware<T>(
534 self,
535 http_middleware: tower::ServiceBuilder<T>,
536 ) -> TowerServiceBuilder<RpcMiddleware, T> {
537 TowerServiceBuilder {
538 server_cfg: self.server_cfg,
539 rpc_middleware: self.rpc_middleware,
540 http_middleware,
541 conn_id: self.conn_id,
542 conn_guard: self.conn_guard,
543 }
544 }
545}
546
547impl<HttpMiddleware, RpcMiddleware> Builder<HttpMiddleware, RpcMiddleware> {
548 pub fn max_request_body_size(mut self, size: u32) -> Self {
550 self.server_cfg.max_request_body_size = size;
551 self
552 }
553
554 pub fn max_response_body_size(mut self, size: u32) -> Self {
556 self.server_cfg.max_response_body_size = size;
557 self
558 }
559
560 pub fn max_connections(mut self, max: u32) -> Self {
562 self.server_cfg.max_connections = max;
563 self
564 }
565
566 pub fn set_batch_request_config(mut self, cfg: BatchRequestConfig) -> Self {
571 self.server_cfg.batch_requests_config = cfg;
572 self
573 }
574
575 pub fn max_subscriptions_per_connection(mut self, max: u32) -> Self {
577 self.server_cfg.max_subscriptions_per_connection = max;
578 self
579 }
580
581 pub fn set_rpc_middleware<T>(self, rpc_middleware: RpcServiceBuilder<T>) -> Builder<HttpMiddleware, T> {
640 Builder { server_cfg: self.server_cfg, rpc_middleware, http_middleware: self.http_middleware }
641 }
642
643 pub fn custom_tokio_runtime(mut self, rt: tokio::runtime::Handle) -> Self {
647 self.server_cfg.tokio_runtime = Some(rt);
648 self
649 }
650
651 pub fn enable_ws_ping(mut self, config: PingConfig) -> Self {
666 self.server_cfg.ping_config = Some(config);
667 self
668 }
669
670 pub fn disable_ws_ping(mut self) -> Self {
674 self.server_cfg.ping_config = None;
675 self
676 }
677
678 pub fn set_id_provider<I: IdProvider + 'static>(mut self, id_provider: I) -> Self {
699 self.server_cfg.id_provider = Arc::new(id_provider);
700 self
701 }
702
703 pub fn set_http_middleware<T>(self, http_middleware: tower::ServiceBuilder<T>) -> Builder<T, RpcMiddleware> {
726 Builder { server_cfg: self.server_cfg, http_middleware, rpc_middleware: self.rpc_middleware }
727 }
728
729 pub fn set_tcp_no_delay(mut self, no_delay: bool) -> Self {
733 self.server_cfg.tcp_no_delay = no_delay;
734 self
735 }
736
737 pub fn http_only(mut self) -> Self {
741 self.server_cfg.enable_http = true;
742 self.server_cfg.enable_ws = false;
743 self
744 }
745
746 pub fn ws_only(mut self) -> Self {
752 self.server_cfg.enable_http = false;
753 self.server_cfg.enable_ws = true;
754 self
755 }
756
757 pub fn set_message_buffer_capacity(mut self, c: u32) -> Self {
776 self.server_cfg.message_buffer_capacity = c;
777 self
778 }
779
780 pub fn to_service_builder(self) -> TowerServiceBuilder<RpcMiddleware, HttpMiddleware> {
854 let max_conns = self.server_cfg.max_connections as usize;
855
856 TowerServiceBuilder {
857 server_cfg: self.server_cfg,
858 rpc_middleware: self.rpc_middleware,
859 http_middleware: self.http_middleware,
860 conn_id: Arc::new(AtomicU32::new(0)),
861 conn_guard: ConnectionGuard::new(max_conns),
862 }
863 }
864
865 pub async fn build(self, addrs: impl ToSocketAddrs) -> std::io::Result<Server<HttpMiddleware, RpcMiddleware>> {
882 let listener = TcpListener::bind(addrs).await?;
883
884 Ok(Server {
885 listener,
886 server_cfg: self.server_cfg,
887 rpc_middleware: self.rpc_middleware,
888 http_middleware: self.http_middleware,
889 })
890 }
891
892 pub fn build_from_tcp(
916 self,
917 listener: impl Into<StdTcpListener>,
918 ) -> std::io::Result<Server<HttpMiddleware, RpcMiddleware>> {
919 let listener = TcpListener::from_std(listener.into())?;
920
921 Ok(Server {
922 listener,
923 server_cfg: self.server_cfg,
924 rpc_middleware: self.rpc_middleware,
925 http_middleware: self.http_middleware,
926 })
927 }
928}
929
/// Data shared by every request handled by a [`TowerServiceNoHttp`].
#[derive(Debug, Clone)]
struct ServiceData {
    /// Registered RPC methods.
    methods: Methods,
    /// Handle used to observe server shutdown.
    stop_handle: StopHandle,
    /// Connection id assigned to this service.
    conn_id: u32,
    /// Connection limiter; `TowerServiceNoHttp::call` acquires a permit per request.
    conn_guard: ConnectionGuard,
    /// Server configuration.
    server_cfg: ServerConfig,
}
944
945#[derive(Debug, Clone)]
950pub struct TowerService<RpcMiddleware, HttpMiddleware> {
951 rpc_middleware: TowerServiceNoHttp<RpcMiddleware>,
952 http_middleware: tower::ServiceBuilder<HttpMiddleware>,
953}
954
955impl<RpcMiddleware, HttpMiddleware> TowerService<RpcMiddleware, HttpMiddleware> {
956 pub fn on_session_closed(&mut self) -> SessionClosedFuture {
962 if let Some(n) = self.rpc_middleware.on_session_close.as_mut() {
963 n.closed()
965 } else {
966 let (session_close, fut) = session_close();
967 self.rpc_middleware.on_session_close = Some(session_close);
968 fut
969 }
970 }
971}
972
973impl<RequestBody, ResponseBody, RpcMiddleware, HttpMiddleware> Service<HttpRequest<RequestBody>> for TowerService<RpcMiddleware, HttpMiddleware>
974where
975 RpcMiddleware: for<'a> tower::Layer<RpcService> + Clone,
976 <RpcMiddleware as Layer<RpcService>>::Service: Send + Sync + 'static,
977 for<'a> <RpcMiddleware as Layer<RpcService>>::Service: RpcServiceT<'a>,
978 HttpMiddleware: Layer<TowerServiceNoHttp<RpcMiddleware>> + Send + 'static,
979 <HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service:
980 Send + Service<HttpRequest<RequestBody>, Response = HttpResponse<ResponseBody>, Error = Box<(dyn StdError + Send + Sync + 'static)>>,
981 <<HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service as Service<HttpRequest<RequestBody>>>::Future:
982 Send + 'static,
983 RequestBody: http_body::Body<Data = Bytes> + Send + 'static,
984 RequestBody::Error: Into<BoxError>,
985{
986 type Response = HttpResponse<ResponseBody>;
987 type Error = BoxError;
988 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
989
990 fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
991 Poll::Ready(Ok(()))
992 }
993
994 fn call(&mut self, request: HttpRequest<RequestBody>) -> Self::Future {
995 Box::pin(self.http_middleware.service(self.rpc_middleware.clone()).call(request))
996 }
997}
998
/// Tower service handling JSON-RPC over plain HTTP and over WebSocket upgrades, without any
/// HTTP middleware applied.
#[derive(Debug, Clone)]
pub struct TowerServiceNoHttp<L> {
    /// Data shared by every request handled by this service.
    inner: ServiceData,
    /// RPC middleware wrapped around each `RpcService` this service creates.
    rpc_middleware: RpcServiceBuilder<L>,
    /// Session-close notifier; taken at the start of every `call` and only used on WebSocket upgrade.
    on_session_close: Option<SessionClose>,
}
1009
1010impl<Body, RpcMiddleware> Service<HttpRequest<Body>> for TowerServiceNoHttp<RpcMiddleware>
1011where
1012 RpcMiddleware: for<'a> tower::Layer<RpcService>,
1013 <RpcMiddleware as Layer<RpcService>>::Service: Send + Sync + 'static,
1014 for<'a> <RpcMiddleware as Layer<RpcService>>::Service: RpcServiceT<'a>,
1015 Body: http_body::Body<Data = Bytes> + Send + 'static,
1016 Body::Error: Into<BoxError>,
1017{
1018 type Response = HttpResponse;
1019
1020 type Error = BoxError;
1023
1024 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1025
1026 fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
1027 Poll::Ready(Ok(()))
1028 }
1029
1030 fn call(&mut self, request: HttpRequest<Body>) -> Self::Future {
1031 let mut request = request.map(HttpBody::new);
1032
1033 let conn_guard = &self.inner.conn_guard;
1034 let stop_handle = self.inner.stop_handle.clone();
1035 let conn_id = self.inner.conn_id;
1036 let on_session_close = self.on_session_close.take();
1037
1038 tracing::trace!(target: LOG_TARGET, "{:?}", request);
1039
1040 let Some(conn_permit) = conn_guard.try_acquire() else {
1041 return async move { Ok(http::response::too_many_requests()) }.boxed();
1042 };
1043
1044 let conn = ConnectionState::new(stop_handle.clone(), conn_id, conn_permit);
1045
1046 let max_conns = conn_guard.max_connections();
1047 let curr_conns = max_conns - conn_guard.available_connections();
1048 tracing::debug!(target: LOG_TARGET, "Accepting new connection {}/{}", curr_conns, max_conns);
1049
1050 let req_ext = request.extensions_mut();
1051 req_ext.insert::<ConnectionGuard>(conn_guard.clone());
1052 req_ext.insert::<ConnectionId>(conn.conn_id.into());
1053
1054 let is_upgrade_request = is_upgrade_request(&request);
1055
1056 if self.inner.server_cfg.enable_ws && is_upgrade_request {
1057 let this = self.inner.clone();
1058
1059 let mut server = soketto::handshake::http::Server::new();
1060
1061 let response = match server.receive_request(&request) {
1062 Ok(response) => {
1063 let (tx, rx) = mpsc::channel::<String>(this.server_cfg.message_buffer_capacity as usize);
1064 let sink = MethodSink::new(tx);
1065
1066 let (pending_calls, pending_calls_completed) = mpsc::channel::<()>(1);
1070
1071 let cfg = RpcServiceCfg::CallsAndSubscriptions {
1072 bounded_subscriptions: BoundedSubscriptions::new(
1073 this.server_cfg.max_subscriptions_per_connection,
1074 ),
1075 id_provider: this.server_cfg.id_provider.clone(),
1076 sink: sink.clone(),
1077 _pending_calls: pending_calls,
1078 };
1079
1080 let rpc_service = RpcService::new(
1081 this.methods.clone(),
1082 this.server_cfg.max_response_body_size as usize,
1083 this.conn_id.into(),
1084 cfg,
1085 );
1086
1087 let rpc_service = self.rpc_middleware.service(rpc_service);
1088
1089 tokio::spawn(
1090 async move {
1091 let extensions = request.extensions().clone();
1092
1093 let upgraded = match hyper::upgrade::on(request).await {
1094 Ok(u) => u,
1095 Err(e) => {
1096 tracing::debug!(target: LOG_TARGET, "Could not upgrade connection: {}", e);
1097 return;
1098 }
1099 };
1100
1101 let io = hyper_util::rt::TokioIo::new(upgraded);
1102
1103 let stream = BufReader::new(BufWriter::new(io.compat()));
1104 let mut ws_builder = server.into_builder(stream);
1105 ws_builder.set_max_message_size(this.server_cfg.max_request_body_size as usize);
1106 let (sender, receiver) = ws_builder.finish();
1107
1108 let params = BackgroundTaskParams {
1109 server_cfg: this.server_cfg,
1110 conn,
1111 ws_sender: sender,
1112 ws_receiver: receiver,
1113 rpc_service,
1114 sink,
1115 rx,
1116 pending_calls_completed,
1117 on_session_close,
1118 extensions,
1119 };
1120
1121 ws::background_task(params).await;
1122 }
1123 .in_current_span(),
1124 );
1125
1126 response.map(|()| HttpBody::empty())
1127 }
1128 Err(e) => {
1129 tracing::debug!(target: LOG_TARGET, "Could not upgrade connection: {}", e);
1130 HttpResponse::new(HttpBody::from(format!("Could not upgrade connection: {e}")))
1131 }
1132 };
1133
1134 async { Ok(response) }.boxed()
1135 } else if self.inner.server_cfg.enable_http && !is_upgrade_request {
1136 let this = &self.inner;
1137 let max_response_size = this.server_cfg.max_response_body_size;
1138 let max_request_size = this.server_cfg.max_request_body_size;
1139 let methods = this.methods.clone();
1140 let batch_config = this.server_cfg.batch_requests_config;
1141
1142 let rpc_service = self.rpc_middleware.service(RpcService::new(
1143 methods,
1144 max_response_size as usize,
1145 this.conn_id.into(),
1146 RpcServiceCfg::OnlyCalls,
1147 ));
1148
1149 Box::pin(async move {
1150 let rp =
1151 http::call_with_service(request, batch_config, max_request_size, rpc_service, max_response_size)
1152 .await;
1153 drop(conn);
1156 Ok(rp)
1157 })
1158 } else {
1159 Box::pin(async { Ok(http::response::denied()) })
1162 }
1163 }
1164}
1165
/// Everything `process_connection` needs to serve one accepted TCP connection.
struct ProcessConnection<'a, HttpMiddleware, RpcMiddleware> {
    /// HTTP middleware, borrowed from the server.
    http_middleware: &'a tower::ServiceBuilder<HttpMiddleware>,
    /// RPC middleware for this connection's service.
    rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
    /// Server-wide connection limiter, borrowed from the accept loop.
    conn_guard: &'a ConnectionGuard,
    /// Id assigned to this connection by the accept loop.
    conn_id: u32,
    /// Server configuration.
    server_cfg: ServerConfig,
    /// Handle used to observe server shutdown.
    stop_handle: StopHandle,
    /// The accepted socket.
    socket: TcpStream,
    /// Dropped when the connection task ends, letting the accept loop wait for completion.
    drop_on_completion: mpsc::Sender<()>,
    /// Peer address (used only for the tracing span).
    remote_addr: SocketAddr,
    /// Registered RPC methods.
    methods: Methods,
}
1178
1179#[instrument(name = "connection", skip_all, fields(remote_addr = %params.remote_addr, conn_id = %params.conn_id), level = "INFO")]
1180fn process_connection<'a, RpcMiddleware, HttpMiddleware, Body>(params: ProcessConnection<HttpMiddleware, RpcMiddleware>)
1181where
1182 RpcMiddleware: 'static,
1183 HttpMiddleware: Layer<TowerServiceNoHttp<RpcMiddleware>> + Send + 'static,
1184 <HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service:
1185 Send + 'static + Clone + Service<HttpRequest, Response = HttpResponse<Body>, Error = BoxError>,
1186 <<HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service as Service<HttpRequest>>::Future:
1187 Send + 'static,
1188 Body: http_body::Body<Data = Bytes> + Send + 'static,
1189 <Body as http_body::Body>::Error: Into<BoxError>,
1190 <Body as http_body::Body>::Data: Send,
1191{
1192 let ProcessConnection {
1193 http_middleware,
1194 rpc_middleware,
1195 conn_guard,
1196 conn_id,
1197 server_cfg,
1198 socket,
1199 stop_handle,
1200 drop_on_completion,
1201 methods,
1202 ..
1203 } = params;
1204
1205 if let Err(e) = socket.set_nodelay(server_cfg.tcp_no_delay) {
1206 tracing::warn!(target: LOG_TARGET, "Could not set NODELAY on socket: {:?}", e);
1207 return;
1208 }
1209
1210 let tower_service = TowerServiceNoHttp {
1211 inner: ServiceData {
1212 server_cfg,
1213 methods,
1214 stop_handle: stop_handle.clone(),
1215 conn_id,
1216 conn_guard: conn_guard.clone(),
1217 },
1218 rpc_middleware,
1219 on_session_close: None,
1220 };
1221
1222 let service = http_middleware.service(tower_service);
1223
1224 tokio::spawn(async {
1225 let service = crate::utils::TowerToHyperService::new(service);
1227 let io = TokioIo::new(socket);
1228 let builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
1229
1230 let conn = builder.serve_connection_with_upgrades(io, service);
1231 let stopped = stop_handle.shutdown();
1232
1233 tokio::pin!(stopped, conn);
1234
1235 let res = match future::select(conn, stopped).await {
1236 Either::Left((conn, _)) => conn,
1237 Either::Right((_, mut conn)) => {
1238 conn.as_mut().graceful_shutdown();
1241 conn.await
1242 }
1243 };
1244
1245 if let Err(e) = res {
1246 tracing::debug!(target: LOG_TARGET, "HTTP serve connection failed {:?}", e);
1247 }
1248 drop(drop_on_completion)
1249 });
1250}
1251
/// Outcome of `try_accept_conn`; `S` is the stop future, handed back while still pending.
enum AcceptConnection<S> {
    /// The stop future resolved; the server should shut down.
    Shutdown,
    /// A connection was accepted.
    Established { socket: TcpStream, remote_addr: SocketAddr, stop: S },
    /// Accepting failed; the server keeps running.
    Err((std::io::Error, S)),
}
1257
1258async fn try_accept_conn<S>(listener: &TcpListener, stopped: S) -> AcceptConnection<S>
1259where
1260 S: Future + Unpin,
1261{
1262 let accept = listener.accept();
1263 tokio::pin!(accept);
1264
1265 match futures_util::future::select(accept, stopped).await {
1266 Either::Left((res, stop)) => match res {
1267 Ok((socket, remote_addr)) => AcceptConnection::Established { socket, remote_addr, stop },
1268 Err(e) => AcceptConnection::Err((e, stop)),
1269 },
1270 Either::Right(_) => AcceptConnection::Shutdown,
1271 }
1272}
1273
1274pub(crate) async fn handle_rpc_call<S>(
1275 body: &[u8],
1276 is_single: bool,
1277 batch_config: BatchRequestConfig,
1278 max_response_size: u32,
1279 rpc_service: &S,
1280 extensions: Extensions,
1281) -> Option<MethodResponse>
1282where
1283 for<'a> S: RpcServiceT<'a> + Send,
1284{
1285 if is_single {
1287 if let Ok(req) = deserialize::from_slice_with_extensions(body, extensions) {
1288 Some(rpc_service.call(req).await)
1289 } else if let Ok(_notif) = serde_json::from_slice::<Notif>(body) {
1290 None
1291 } else {
1292 let (id, code) = prepare_error(body);
1293 Some(MethodResponse::error(id, ErrorObject::from(code)))
1294 }
1295 }
1296 else {
1298 let max_len = match batch_config {
1299 BatchRequestConfig::Disabled => {
1300 let rp = MethodResponse::error(
1301 Id::Null,
1302 ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, None),
1303 );
1304 return Some(rp);
1305 }
1306 BatchRequestConfig::Limit(limit) => limit as usize,
1307 BatchRequestConfig::Unlimited => usize::MAX,
1308 };
1309
1310 if let Ok(batch) = serde_json::from_slice::<Vec<&JsonRawValue>>(body) {
1311 if batch.len() > max_len {
1312 return Some(MethodResponse::error(Id::Null, reject_too_big_batch_request(max_len)));
1313 }
1314
1315 let mut got_notif = false;
1316 let mut batch_response = BatchResponseBuilder::new_with_limit(max_response_size as usize);
1317
1318 for call in batch {
1319 if let Ok(req) = deserialize::from_str_with_extensions(call.get(), extensions.clone()) {
1320 let rp = rpc_service.call(req).await;
1321
1322 if let Err(too_large) = batch_response.append(&rp) {
1323 return Some(too_large);
1324 }
1325 } else if let Ok(_notif) = serde_json::from_str::<Notif>(call.get()) {
1326 got_notif = true;
1328 } else {
1329 let id = match serde_json::from_str::<InvalidRequest>(call.get()) {
1331 Ok(err) => err.id,
1332 Err(_) => Id::Null,
1333 };
1334
1335 if let Err(too_large) =
1336 batch_response.append(&MethodResponse::error(id, ErrorObject::from(ErrorCode::InvalidRequest)))
1337 {
1338 return Some(too_large);
1339 }
1340 }
1341 }
1342
1343 if got_notif && batch_response.is_empty() {
1344 None
1345 } else {
1346 let batch_rp = batch_response.finish();
1347 Some(MethodResponse::from_batch(batch_rp))
1348 }
1349 } else {
1350 Some(MethodResponse::error(Id::Null, ErrorObject::from(ErrorCode::ParseError)))
1351 }
1352 }
1353}