1use std::collections::hash_map::Entry;
28use std::fmt::{self, Debug};
29use std::future::Future;
30use std::ops::{Deref, DerefMut};
31use std::sync::Arc;
32
33use crate::error::RegisterMethodError;
34use crate::id_providers::RandomIntegerIdProvider;
35use crate::server::helpers::MethodSink;
36use crate::server::method_response::MethodResponse;
37use crate::server::subscription::{
38 sub_message_to_json, BoundedSubscriptions, IntoSubscriptionCloseResponse, PendingSubscriptionSink,
39 SubNotifResultOrError, Subscribers, Subscription, SubscriptionCloseResponse, SubscriptionKey, SubscriptionPermit,
40 SubscriptionState,
41};
42use crate::server::{ResponsePayload, LOG_TARGET};
43use crate::traits::ToRpcParams;
44use futures_util::{future::BoxFuture, FutureExt};
45use http::Extensions;
46use jsonrpsee_types::error::{ErrorCode, ErrorObject};
47use jsonrpsee_types::{
48 ErrorObjectOwned, Id, Params, Request, Response, ResponseSuccess, SubscriptionId as RpcSubscriptionId,
49};
50use rustc_hash::FxHashMap;
51use serde::de::DeserializeOwned;
52use tokio::sync::{mpsc, oneshot};
53
54use super::IntoResponse;
55
56pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, Params, MaxResponseSize, Extensions) -> MethodResponse>;
61pub type AsyncMethod<'a> = Arc<
63 dyn Send
64 + Sync
65 + Fn(Id<'a>, Params<'a>, ConnectionId, MaxResponseSize, Extensions) -> BoxFuture<'a, MethodResponse>,
66>;
67
68pub type SubscriptionMethod<'a> =
70 Arc<dyn Send + Sync + Fn(Id, Params, MethodSink, SubscriptionState, Extensions) -> BoxFuture<'a, MethodResponse>>;
71type UnsubscriptionMethod =
73 Arc<dyn Send + Sync + Fn(Id, Params, ConnectionId, MaxResponseSize, Extensions) -> MethodResponse>;
74
75#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default, serde::Deserialize, serde::Serialize)]
77pub struct ConnectionId(pub usize);
78
79impl From<u32> for ConnectionId {
80 fn from(id: u32) -> Self {
81 Self(id as usize)
82 }
83}
84
85impl From<usize> for ConnectionId {
86 fn from(id: usize) -> Self {
87 Self(id)
88 }
89}
90
91pub type MaxResponseSize = usize;
93
94pub type RawRpcResponse = (String, mpsc::Receiver<String>);
99
100#[derive(thiserror::Error, Debug)]
102pub enum MethodsError {
103 #[error(transparent)]
105 Parse(#[from] serde_json::Error),
106 #[error(transparent)]
108 JsonRpc(#[from] ErrorObjectOwned),
109 #[error("Invalid subscription ID: `{0}`")]
110 InvalidSubscriptionId(String),
112}
113
114#[derive(Debug)]
119pub enum CallOrSubscription {
120 Subscription(MethodResponse),
123 Call(MethodResponse),
125}
126
127impl CallOrSubscription {
128 pub fn as_response(&self) -> &MethodResponse {
130 match &self {
131 Self::Subscription(r) => r,
132 Self::Call(r) => r,
133 }
134 }
135
136 pub fn into_response(self) -> MethodResponse {
138 match self {
139 Self::Subscription(r) => r,
140 Self::Call(r) => r,
141 }
142 }
143}
144
145#[derive(Clone)]
147pub enum MethodCallback {
148 Sync(SyncMethod),
150 Async(AsyncMethod<'static>),
152 Subscription(SubscriptionMethod<'static>),
154 Unsubscription(UnsubscriptionMethod),
156}
157
/// Classification of a method call, rendered as a human-readable label by `Display`.
#[derive(Debug, Copy, Clone)]
pub enum MethodKind {
    /// A subscribe request.
    Subscription,
    /// An unsubscribe request.
    Unsubscription,
    /// An ordinary method call.
    MethodCall,
    /// No method with the requested name exists.
    NotFound,
}

impl std::fmt::Display for MethodKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Subscription => "subscription",
            Self::Unsubscription => "unsubscription",
            Self::MethodCall => "method call",
            Self::NotFound => "method not found",
        })
    }
}
183
184pub enum MethodResult<T> {
186 Sync(T),
188 Async(BoxFuture<'static, T>),
190}
191
192impl<T: Debug> Debug for MethodResult<T> {
193 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
194 match self {
195 MethodResult::Sync(result) => result.fmt(f),
196 MethodResult::Async(_) => f.write_str("<future>"),
197 }
198 }
199}
200
201impl Debug for MethodCallback {
202 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203 match self {
204 Self::Async(_) => write!(f, "Async"),
205 Self::Sync(_) => write!(f, "Sync"),
206 Self::Subscription(_) => write!(f, "Subscription"),
207 Self::Unsubscription(_) => write!(f, "Unsubscription"),
208 }
209 }
210}
211
212#[derive(Default, Debug, Clone)]
214pub struct Methods {
215 callbacks: Arc<FxHashMap<&'static str, MethodCallback>>,
216 extensions: Extensions,
217}
218
219impl Methods {
220 pub fn new() -> Self {
222 Self::default()
223 }
224
225 pub fn verify_method_name(&mut self, name: &'static str) -> Result<(), RegisterMethodError> {
227 if self.callbacks.contains_key(name) {
228 return Err(RegisterMethodError::AlreadyRegistered(name.into()));
229 }
230
231 Ok(())
232 }
233
234 pub fn verify_and_insert(
237 &mut self,
238 name: &'static str,
239 callback: MethodCallback,
240 ) -> Result<&mut MethodCallback, RegisterMethodError> {
241 match self.mut_callbacks().entry(name) {
242 Entry::Occupied(_) => Err(RegisterMethodError::AlreadyRegistered(name.into())),
243 Entry::Vacant(vacant) => Ok(vacant.insert(callback)),
244 }
245 }
246
247 fn mut_callbacks(&mut self) -> &mut FxHashMap<&'static str, MethodCallback> {
249 Arc::make_mut(&mut self.callbacks)
250 }
251
252 pub fn merge(&mut self, other: impl Into<Methods>) -> Result<(), RegisterMethodError> {
255 let mut other = other.into();
256
257 for name in other.callbacks.keys() {
258 self.verify_method_name(name)?;
259 }
260
261 let callbacks = self.mut_callbacks();
262
263 for (name, callback) in other.mut_callbacks().drain() {
264 callbacks.insert(name, callback);
265 }
266
267 Ok(())
268 }
269
270 pub fn method(&self, method_name: &str) -> Option<&MethodCallback> {
272 self.callbacks.get(method_name)
273 }
274
275 pub fn method_with_name(&self, method_name: &str) -> Option<(&'static str, &MethodCallback)> {
278 self.callbacks.get_key_value(method_name).map(|(k, v)| (*k, v))
279 }
280
281 pub async fn call<Params: ToRpcParams, T: DeserializeOwned + Clone>(
305 &self,
306 method: &str,
307 params: Params,
308 ) -> Result<T, MethodsError> {
309 let params = params.to_rpc_params()?;
310 let req = Request::new(method.into(), params.as_ref().map(|p| p.as_ref()), Id::Number(0));
311 tracing::trace!(target: LOG_TARGET, "[Methods::call] Method: {:?}, params: {:?}", method, params);
312 let (rp, _) = self.inner_call(req, 1, mock_subscription_permit()).await;
313
314 let rp = serde_json::from_str::<Response<T>>(&rp)?;
315 ResponseSuccess::try_from(rp).map(|s| s.result).map_err(|e| MethodsError::JsonRpc(e.into_owned()))
316 }
317
318 pub async fn raw_json_request(
351 &self,
352 request: &str,
353 buf_size: usize,
354 ) -> Result<(String, mpsc::Receiver<String>), serde_json::Error> {
355 tracing::trace!("[Methods::raw_json_request] Request: {:?}", request);
356 let req: Request = serde_json::from_str(request)?;
357 let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await;
358
359 Ok((resp, rx))
360 }
361
362 async fn inner_call(
364 &self,
365 req: Request<'_>,
366 buf_size: usize,
367 subscription_permit: SubscriptionPermit,
368 ) -> RawRpcResponse {
369 let (tx, mut rx) = mpsc::channel(buf_size);
370 let Request { id, method, params, .. } = req;
373 let params = Params::new(params.as_ref().map(|params| params.as_ref().get()));
374 let max_response_size = usize::MAX;
375 let conn_id = ConnectionId(0);
376 let mut ext = self.extensions.clone();
377 ext.insert(conn_id);
378
379 let response = match self.method(&method) {
380 None => MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)),
381 Some(MethodCallback::Sync(cb)) => (cb)(id, params, max_response_size, ext),
382 Some(MethodCallback::Async(cb)) => {
383 (cb)(id.into_owned(), params.into_owned(), conn_id, max_response_size, ext).await
384 }
385 Some(MethodCallback::Subscription(cb)) => {
386 let conn_state =
387 SubscriptionState { conn_id, id_provider: &RandomIntegerIdProvider, subscription_permit };
388 let res = (cb)(id, params, MethodSink::new(tx.clone()), conn_state, ext).await;
389
390 let _ = rx.recv().await.expect("Every call must at least produce one response; qed");
395
396 res
397 }
398 Some(MethodCallback::Unsubscription(cb)) => (cb)(id, params, conn_id, max_response_size, ext),
399 };
400
401 let is_success = response.is_success();
402 let (rp, notif) = response.into_parts();
403
404 if let Some(n) = notif {
405 n.notify(is_success);
406 }
407
408 tracing::trace!(target: LOG_TARGET, "[Methods::inner_call] Method: {}, response: {}", method, rp);
409
410 (rp, rx)
411 }
412
413 pub async fn subscribe_unbounded(
441 &self,
442 sub_method: &str,
443 params: impl ToRpcParams,
444 ) -> Result<Subscription, MethodsError> {
445 self.subscribe(sub_method, params, u32::MAX as usize).await
446 }
447
448 pub async fn subscribe(
452 &self,
453 sub_method: &str,
454 params: impl ToRpcParams,
455 buf_size: usize,
456 ) -> Result<Subscription, MethodsError> {
457 let params = params.to_rpc_params()?;
458 let req = Request::new(sub_method.into(), params.as_ref().map(|p| p.as_ref()), Id::Number(0));
459
460 tracing::trace!(target: LOG_TARGET, "[Methods::subscribe] Method: {}, params: {:?}", sub_method, params);
461
462 let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await;
463
464 let as_success: ResponseSuccess<serde_json::Value> = serde_json::from_str::<Response<_>>(&resp)?.try_into()?;
466
467 let sub_id = as_success.result.try_into().map_err(|_| MethodsError::InvalidSubscriptionId(resp.clone()))?;
468
469 Ok(Subscription { sub_id, rx })
470 }
471
472 pub fn method_names(&self) -> impl Iterator<Item = &'static str> + '_ {
474 self.callbacks.keys().copied()
475 }
476
477 pub fn extensions(&mut self) -> &Extensions {
479 &self.extensions
480 }
481
482 pub fn extensions_mut(&mut self) -> &mut Extensions {
511 &mut self.extensions
512 }
513}
514
515impl<Context> Deref for RpcModule<Context> {
516 type Target = Methods;
517
518 fn deref(&self) -> &Methods {
519 &self.methods
520 }
521}
522
523impl<Context> DerefMut for RpcModule<Context> {
524 fn deref_mut(&mut self) -> &mut Methods {
525 &mut self.methods
526 }
527}
528
529#[derive(Debug, Clone)]
533pub struct RpcModule<Context> {
534 ctx: Arc<Context>,
535 methods: Methods,
536}
537
538impl<Context> RpcModule<Context> {
539 pub fn new(ctx: Context) -> Self {
541 Self::from_arc(Arc::new(ctx))
542 }
543
544 pub fn from_arc(ctx: Arc<Context>) -> Self {
548 Self { ctx, methods: Default::default() }
549 }
550
551 pub fn remove_context(self) -> RpcModule<()> {
553 let mut module = RpcModule::new(());
554 module.methods = self.methods;
555 module
556 }
557}
558
559impl<Context> From<RpcModule<Context>> for Methods {
560 fn from(module: RpcModule<Context>) -> Methods {
561 module.methods
562 }
563}
564
565impl<Context: Send + Sync + 'static> RpcModule<Context> {
566 pub fn register_method<R, F>(
577 &mut self,
578 method_name: &'static str,
579 callback: F,
580 ) -> Result<&mut MethodCallback, RegisterMethodError>
581 where
582 Context: Send + Sync + 'static,
583 R: IntoResponse + 'static,
584 F: Fn(Params, &Context, &Extensions) -> R + Send + Sync + 'static,
585 {
586 let ctx = self.ctx.clone();
587 self.methods.verify_and_insert(
588 method_name,
589 MethodCallback::Sync(Arc::new(move |id, params, max_response_size, extensions| {
590 let rp = callback(params, &*ctx, &extensions).into_response();
591 MethodResponse::response(id, rp, max_response_size).with_extensions(extensions)
592 })),
593 )
594 }
595
596 pub fn remove_method(&mut self, method_name: &'static str) -> Option<MethodCallback> {
601 self.methods.mut_callbacks().remove(method_name)
602 }
603
604 pub fn register_async_method<R, Fun, Fut>(
617 &mut self,
618 method_name: &'static str,
619 callback: Fun,
620 ) -> Result<&mut MethodCallback, RegisterMethodError>
621 where
622 R: IntoResponse + 'static,
623 Fut: Future<Output = R> + Send,
624 Fun: (Fn(Params<'static>, Arc<Context>, Extensions) -> Fut) + Clone + Send + Sync + 'static,
625 {
626 let ctx = self.ctx.clone();
627 self.methods.verify_and_insert(
628 method_name,
629 MethodCallback::Async(Arc::new(move |id, params, _, max_response_size, extensions| {
630 let ctx = ctx.clone();
631 let callback = callback.clone();
632
633 let future = async move {
636 let rp = callback(params, ctx, extensions.clone()).await.into_response();
637 MethodResponse::response(id, rp, max_response_size).with_extensions(extensions)
638 };
639 future.boxed()
640 })),
641 )
642 }
643
644 pub fn register_blocking_method<R, F>(
648 &mut self,
649 method_name: &'static str,
650 callback: F,
651 ) -> Result<&mut MethodCallback, RegisterMethodError>
652 where
653 Context: Send + Sync + 'static,
654 R: IntoResponse + 'static,
655 F: Fn(Params, Arc<Context>, Extensions) -> R + Clone + Send + Sync + 'static,
656 {
657 let ctx = self.ctx.clone();
658 let callback = self.methods.verify_and_insert(
659 method_name,
660 MethodCallback::Async(Arc::new(move |id, params, _, max_response_size, extensions| {
661 let ctx = ctx.clone();
662 let callback = callback.clone();
663
664 let extensions2 = extensions.clone();
667
668 tokio::task::spawn_blocking(move || {
669 let rp = callback(params, ctx, extensions2.clone()).into_response();
670 MethodResponse::response(id, rp, max_response_size).with_extensions(extensions2)
671 })
672 .map(|result| match result {
673 Ok(r) => r,
674 Err(err) => {
675 tracing::error!(target: LOG_TARGET, "Join error for blocking RPC method: {:?}", err);
676 MethodResponse::error(Id::Null, ErrorObject::from(ErrorCode::InternalError))
677 .with_extensions(extensions)
678 }
679 })
680 .boxed()
681 })),
682 )?;
683
684 Ok(callback)
685 }
686
687 pub fn register_subscription<R, F, Fut>(
782 &mut self,
783 subscribe_method_name: &'static str,
784 notif_method_name: &'static str,
785 unsubscribe_method_name: &'static str,
786 callback: F,
787 ) -> Result<&mut MethodCallback, RegisterMethodError>
788 where
789 Context: Send + Sync + 'static,
790 F: (Fn(Params<'static>, PendingSubscriptionSink, Arc<Context>, Extensions) -> Fut)
791 + Send
792 + Sync
793 + Clone
794 + 'static,
795 Fut: Future<Output = R> + Send + 'static,
796 R: IntoSubscriptionCloseResponse + Send,
797 {
798 let subscribers = self.verify_and_register_unsubscribe(subscribe_method_name, unsubscribe_method_name)?;
799 let ctx = self.ctx.clone();
800
801 let callback = {
803 self.methods.verify_and_insert(
804 subscribe_method_name,
805 MethodCallback::Subscription(Arc::new(move |id, params, method_sink, conn, extensions| {
806 let uniq_sub = SubscriptionKey { conn_id: conn.conn_id, sub_id: conn.id_provider.next_id() };
807
808 let (tx, rx) = oneshot::channel();
810 let (accepted_tx, accepted_rx) = oneshot::channel();
811
812 let sub_id = uniq_sub.sub_id.clone();
813 let method = notif_method_name;
814
815 let sink = PendingSubscriptionSink {
816 inner: method_sink.clone(),
817 method: notif_method_name,
818 subscribers: subscribers.clone(),
819 uniq_sub,
820 id: id.clone().into_owned(),
821 subscribe: tx,
822 permit: conn.subscription_permit,
823 };
824
825 let sub_fut = callback(params.into_owned(), sink, ctx.clone(), extensions.clone());
833
834 tokio::spawn(async move {
835 let response = match futures_util::future::try_join(sub_fut.map(|f| Ok(f)), accepted_rx).await {
837 Ok((r, _)) => r.into_response(),
838 Err(_) => return,
840 };
841
842 match response {
843 SubscriptionCloseResponse::Notif(msg) => {
844 let json = sub_message_to_json(msg, SubNotifResultOrError::Result, &sub_id, method);
845 let _ = method_sink.send(json).await;
846 }
847 SubscriptionCloseResponse::NotifErr(msg) => {
848 let json = sub_message_to_json(msg, SubNotifResultOrError::Error, &sub_id, method);
849 let _ = method_sink.send(json).await;
850 }
851 SubscriptionCloseResponse::None => (),
852 }
853 });
854
855 let id = id.clone().into_owned();
856
857 Box::pin(async move {
858 let rp = match rx.await {
859 Ok(rp) => {
860 if rp.is_success() {
863 let _ = accepted_tx.send(());
864 }
865 rp
866 }
867 Err(_) => MethodResponse::error(id, ErrorCode::InternalError),
868 };
869
870 rp.with_extensions(extensions)
871 })
872 })),
873 )?
874 };
875
876 Ok(callback)
877 }
878
879 pub fn register_subscription_raw<R, F>(
925 &mut self,
926 subscribe_method_name: &'static str,
927 notif_method_name: &'static str,
928 unsubscribe_method_name: &'static str,
929 callback: F,
930 ) -> Result<&mut MethodCallback, RegisterMethodError>
931 where
932 Context: Send + Sync + 'static,
933 F: (Fn(Params, PendingSubscriptionSink, Arc<Context>, &Extensions) -> R) + Send + Sync + Clone + 'static,
934 R: IntoSubscriptionCloseResponse,
935 {
936 let subscribers = self.verify_and_register_unsubscribe(subscribe_method_name, unsubscribe_method_name)?;
937 let ctx = self.ctx.clone();
938
939 let callback = {
941 self.methods.verify_and_insert(
942 subscribe_method_name,
943 MethodCallback::Subscription(Arc::new(move |id, params, method_sink, conn, extensions| {
944 let uniq_sub = SubscriptionKey { conn_id: conn.conn_id, sub_id: conn.id_provider.next_id() };
945
946 let (tx, rx) = oneshot::channel();
948
949 let sink = PendingSubscriptionSink {
950 inner: method_sink.clone(),
951 method: notif_method_name,
952 subscribers: subscribers.clone(),
953 uniq_sub,
954 id: id.clone().into_owned(),
955 subscribe: tx,
956 permit: conn.subscription_permit,
957 };
958
959 callback(params, sink, ctx.clone(), &extensions);
960
961 let id = id.clone().into_owned();
962
963 Box::pin(async move {
964 let rp = match rx.await {
965 Ok(rp) => rp,
966 Err(_) => MethodResponse::error(id, ErrorCode::InternalError),
967 };
968
969 rp.with_extensions(extensions)
970 })
971 })),
972 )?
973 };
974
975 Ok(callback)
976 }
977
978 fn verify_and_register_unsubscribe(
981 &mut self,
982 subscribe_method_name: &'static str,
983 unsubscribe_method_name: &'static str,
984 ) -> Result<Subscribers, RegisterMethodError> {
985 if subscribe_method_name == unsubscribe_method_name {
986 return Err(RegisterMethodError::SubscriptionNameConflict(subscribe_method_name.into()));
987 }
988
989 self.methods.verify_method_name(subscribe_method_name)?;
990 self.methods.verify_method_name(unsubscribe_method_name)?;
991
992 let subscribers = Subscribers::default();
993
994 {
996 let subscribers = subscribers.clone();
997 self.methods.mut_callbacks().insert(
998 unsubscribe_method_name,
999 MethodCallback::Unsubscription(Arc::new(move |id, params, conn_id, max_response_size, extensions| {
1000 let sub_id = match params.one::<RpcSubscriptionId>() {
1001 Ok(sub_id) => sub_id,
1002 Err(_) => {
1003 tracing::warn!(
1004 target: LOG_TARGET,
1005 "Unsubscribe call `{}` failed: couldn't parse subscription id={:?} request id={:?}",
1006 unsubscribe_method_name,
1007 params,
1008 id
1009 );
1010
1011 return MethodResponse::response(id, ResponsePayload::success(false), max_response_size)
1012 .with_extensions(extensions);
1013 }
1014 };
1015
1016 let key = SubscriptionKey { conn_id, sub_id: sub_id.into_owned() };
1017 let result = subscribers.lock().remove(&key).is_some();
1018
1019 if !result {
1020 tracing::debug!(
1021 target: LOG_TARGET,
1022 "Unsubscribe call `{}` subscription key={:?} not an active subscription",
1023 unsubscribe_method_name,
1024 key,
1025 );
1026 }
1027
1028 MethodResponse::response(id, ResponsePayload::success(result), max_response_size)
1029 })),
1030 );
1031 }
1032
1033 Ok(subscribers)
1034 }
1035
1036 pub fn register_alias(
1038 &mut self,
1039 alias: &'static str,
1040 existing_method: &'static str,
1041 ) -> Result<(), RegisterMethodError> {
1042 self.methods.verify_method_name(alias)?;
1043
1044 let callback = match self.methods.callbacks.get(existing_method) {
1045 Some(callback) => callback.clone(),
1046 None => return Err(RegisterMethodError::MethodNotFound(existing_method.into())),
1047 };
1048
1049 self.methods.mut_callbacks().insert(alias, callback);
1050
1051 Ok(())
1052 }
1053}
1054
1055fn mock_subscription_permit() -> SubscriptionPermit {
1056 BoundedSubscriptions::new(1).acquire().expect("1 permit should exist; qed")
1057}