sc_rpc_server/middleware/
mod.rs1use std::{
22 num::NonZeroU32,
23 time::{Duration, Instant},
24};
25
26use futures::future::{BoxFuture, FutureExt};
27use governor::{clock::Clock, Jitter};
28use jsonrpsee::{
29 server::middleware::rpc::RpcServiceT,
30 types::{ErrorObject, Id, Request},
31 MethodResponse,
32};
33
34mod metrics;
35mod node_health;
36mod rate_limit;
37
38pub use metrics::*;
39pub use node_health::*;
40pub use rate_limit::*;
41
/// Maximum number of back-off waits a rate-limited call may go through before
/// the middleware gives up and rejects it.
const MAX_RETRIES: usize = 10;
/// Upper bound passed to `Jitter::up_to` for the random delay added to every
/// rate-limit back-off wait.
const MAX_JITTER: Duration = Duration::from_millis(50);
44
45#[derive(Debug, Clone, Default)]
47pub struct MiddlewareLayer {
48 rate_limit: Option<RateLimit>,
49 metrics: Option<Metrics>,
50}
51
52impl MiddlewareLayer {
53 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn with_rate_limit_per_minute(self, n: NonZeroU32) -> Self {
60 Self { rate_limit: Some(RateLimit::per_minute(n)), metrics: self.metrics }
61 }
62
63 pub fn with_metrics(self, metrics: Metrics) -> Self {
65 Self { rate_limit: self.rate_limit, metrics: Some(metrics) }
66 }
67
68 pub fn ws_connect(&self) {
70 self.metrics.as_ref().map(|m| m.ws_connect());
71 }
72
73 pub fn ws_disconnect(&self, now: Instant) {
75 self.metrics.as_ref().map(|m| m.ws_disconnect(now));
76 }
77}
78
79impl<S> tower::Layer<S> for MiddlewareLayer {
80 type Service = Middleware<S>;
81
82 fn layer(&self, service: S) -> Self::Service {
83 Middleware { service, rate_limit: self.rate_limit.clone(), metrics: self.metrics.clone() }
84 }
85}
86
87pub struct Middleware<S> {
95 service: S,
96 rate_limit: Option<RateLimit>,
97 metrics: Option<Metrics>,
98}
99
100impl<'a, S> RpcServiceT<'a> for Middleware<S>
101where
102 S: Send + Sync + RpcServiceT<'a> + Clone + 'static,
103{
104 type Future = BoxFuture<'a, MethodResponse>;
105
106 fn call(&self, req: Request<'a>) -> Self::Future {
107 let now = Instant::now();
108
109 self.metrics.as_ref().map(|m| m.on_call(&req));
110
111 let service = self.service.clone();
112 let rate_limit = self.rate_limit.clone();
113 let metrics = self.metrics.clone();
114
115 async move {
116 let mut is_rate_limited = false;
117
118 if let Some(limit) = rate_limit.as_ref() {
119 let mut attempts = 0;
120 let jitter = Jitter::up_to(MAX_JITTER);
121
122 loop {
123 if attempts >= MAX_RETRIES {
124 return reject_too_many_calls(req.id);
125 }
126
127 if let Err(rejected) = limit.inner.check() {
128 tokio::time::sleep(jitter + rejected.wait_time_from(limit.clock.now()))
129 .await;
130 } else {
131 break;
132 }
133
134 is_rate_limited = true;
135 attempts += 1;
136 }
137 }
138
139 let rp = service.call(req.clone()).await;
140 metrics.as_ref().map(|m| m.on_response(&req, &rp, is_rate_limited, now));
141
142 rp
143 }
144 .boxed()
145 }
146}
147
148fn reject_too_many_calls(id: Id) -> MethodResponse {
149 MethodResponse::error(id, ErrorObject::owned(-32999, "RPC rate limit exceeded", None::<()>))
150}