jsonrpc_pubsub/
subscription.rs

//! Subscription primitives: client sessions, notification sinks, subscribers, and the subscribe/unsubscribe RPC methods.
2
3use parking_lot::Mutex;
4use std::collections::HashMap;
5use std::fmt;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use crate::core::futures::channel::mpsc;
10use crate::core::futures::{
11	self, future,
12	task::{Context, Poll},
13	Future, Sink as FuturesSink, TryFutureExt,
14};
15use crate::core::{self, BoxFuture};
16
17use crate::handler::{SubscribeRpcMethod, UnsubscribeRpcMethod};
18use crate::types::{PubSubMetadata, SinkResult, SubscriptionId, TransportError, TransportSender};
19
20lazy_static::lazy_static! {
21	static ref UNSUBSCRIBE_POOL: futures::executor::ThreadPool = futures::executor::ThreadPool::new()
22		.expect("Unable to spawn background pool for unsubscribe tasks.");
23}
24
25/// RPC client session
26/// Keeps track of active subscriptions and unsubscribes from them upon dropping.
27pub struct Session {
28	active_subscriptions: Mutex<HashMap<(SubscriptionId, String), Box<dyn Fn(SubscriptionId) + Send + 'static>>>,
29	transport: TransportSender,
30	on_drop: Mutex<Vec<Box<dyn FnMut() + Send>>>,
31}
32
33impl fmt::Debug for Session {
34	fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
35		fmt.debug_struct("pubsub::Session")
36			.field("active_subscriptions", &self.active_subscriptions.lock().len())
37			.field("transport", &self.transport)
38			.finish()
39	}
40}
41
42impl Session {
43	/// Creates new session given transport raw send capabilities.
44	/// Session should be created as part of metadata, `sender` should be returned by transport.
45	pub fn new(sender: TransportSender) -> Self {
46		Session {
47			active_subscriptions: Default::default(),
48			transport: sender,
49			on_drop: Default::default(),
50		}
51	}
52
53	/// Returns transport write stream
54	pub fn sender(&self) -> TransportSender {
55		self.transport.clone()
56	}
57
58	/// Adds a function to call when session is dropped.
59	pub fn on_drop<F: FnOnce() + Send + 'static>(&self, on_drop: F) {
60		let mut func = Some(on_drop);
61		self.on_drop.lock().push(Box::new(move || {
62			if let Some(f) = func.take() {
63				f();
64			}
65		}));
66	}
67
68	/// Adds new active subscription
69	fn add_subscription<F>(&self, name: &str, id: &SubscriptionId, remove: F)
70	where
71		F: Fn(SubscriptionId) + Send + 'static,
72	{
73		let ret = self
74			.active_subscriptions
75			.lock()
76			.insert((id.clone(), name.into()), Box::new(remove));
77		if let Some(remove) = ret {
78			warn!("SubscriptionId collision. Unsubscribing previous client.");
79			remove(id.clone());
80		}
81	}
82
83	/// Removes existing subscription.
84	fn remove_subscription(&self, name: &str, id: &SubscriptionId) -> bool {
85		self.active_subscriptions
86			.lock()
87			.remove(&(id.clone(), name.into()))
88			.is_some()
89	}
90}
91
92impl Drop for Session {
93	fn drop(&mut self) {
94		let mut active = self.active_subscriptions.lock();
95		for (id, remove) in active.drain() {
96			remove(id.0)
97		}
98
99		let mut on_drop = self.on_drop.lock();
100		for mut on_drop in on_drop.drain(..) {
101			on_drop();
102		}
103	}
104}
105
106/// A handle to send notifications directly to subscribed client.
107#[derive(Debug, Clone)]
108pub struct Sink {
109	notification: String,
110	transport: TransportSender,
111}
112
113impl Sink {
114	/// Sends a notification to a client.
115	pub fn notify(&self, val: core::Params) -> SinkResult {
116		let val = self.params_to_string(val);
117		self.transport.clone().unbounded_send(val)
118	}
119
120	fn params_to_string(&self, val: core::Params) -> String {
121		let notification = core::Notification {
122			jsonrpc: Some(core::Version::V2),
123			method: self.notification.clone(),
124			params: val,
125		};
126		core::to_string(&notification).expect("Notification serialization never fails.")
127	}
128}
129
130impl FuturesSink<core::Params> for Sink {
131	type Error = TransportError;
132
133	fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
134		Pin::new(&mut self.transport).poll_ready(cx)
135	}
136
137	fn start_send(mut self: Pin<&mut Self>, item: core::Params) -> Result<(), Self::Error> {
138		let val = self.params_to_string(item);
139		Pin::new(&mut self.transport).start_send(val)
140	}
141
142	fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
143		Pin::new(&mut self.transport).poll_flush(cx)
144	}
145
146	fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
147		Pin::new(&mut self.transport).poll_close(cx)
148	}
149}
150
151/// Represents a subscribing client.
152/// Subscription handlers can either reject this subscription request or assign an unique id.
153#[derive(Debug)]
154pub struct Subscriber {
155	notification: String,
156	transport: TransportSender,
157	sender: crate::oneshot::Sender<Result<SubscriptionId, core::Error>>,
158}
159
160impl Subscriber {
161	/// Creates new subscriber.
162	///
163	/// Should only be used for tests.
164	pub fn new_test<T: Into<String>>(
165		method: T,
166	) -> (
167		Self,
168		crate::oneshot::Receiver<Result<SubscriptionId, core::Error>>,
169		mpsc::UnboundedReceiver<String>,
170	) {
171		let (sender, id_receiver) = crate::oneshot::channel();
172		let (transport, transport_receiver) = mpsc::unbounded();
173
174		let subscriber = Subscriber {
175			notification: method.into(),
176			transport,
177			sender,
178		};
179
180		(subscriber, id_receiver, transport_receiver)
181	}
182
183	/// Consumes `Subscriber` and assigns unique id to a requestor.
184	///
185	/// Returns `Err` if request has already terminated.
186	pub fn assign_id(self, id: SubscriptionId) -> Result<Sink, ()> {
187		let Self {
188			notification,
189			transport,
190			sender,
191		} = self;
192		sender
193			.send(Ok(id))
194			.map(|_| Sink {
195				notification,
196				transport,
197			})
198			.map_err(|_| ())
199	}
200
201	/// Consumes `Subscriber` and assigns unique id to a requestor.
202	///
203	/// The returned `Future` resolves when the subscriber receives subscription id.
204	/// Resolves to `Err` if request has already terminated.
205	pub fn assign_id_async(self, id: SubscriptionId) -> impl Future<Output = Result<Sink, ()>> {
206		let Self {
207			notification,
208			transport,
209			sender,
210		} = self;
211		sender.send_and_wait(Ok(id)).map_ok(|_| Sink {
212			notification,
213			transport,
214		})
215	}
216
217	/// Rejects this subscription request with given error.
218	///
219	/// Returns `Err` if request has already terminated.
220	pub fn reject(self, error: core::Error) -> Result<(), ()> {
221		self.sender.send(Err(error)).map_err(|_| ())
222	}
223
224	/// Rejects this subscription request with given error.
225	///
226	/// The returned `Future` resolves when the rejection is sent to the client.
227	/// Resolves to `Err` if request has already terminated.
228	pub fn reject_async(self, error: core::Error) -> impl Future<Output = Result<(), ()>> {
229		self.sender.send_and_wait(Err(error)).map_ok(|_| ()).map_err(|_| ())
230	}
231}
232
233/// Creates new subscribe and unsubscribe RPC methods
234pub fn new_subscription<M, F, G>(notification: &str, subscribe: F, unsubscribe: G) -> (Subscribe<F, G>, Unsubscribe<G>)
235where
236	M: PubSubMetadata,
237	F: SubscribeRpcMethod<M>,
238	G: UnsubscribeRpcMethod<M>,
239{
240	let unsubscribe = Arc::new(unsubscribe);
241	let subscribe = Subscribe {
242		notification: notification.to_owned(),
243		unsubscribe: unsubscribe.clone(),
244		subscribe,
245	};
246
247	let unsubscribe = Unsubscribe {
248		notification: notification.into(),
249		unsubscribe,
250	};
251
252	(subscribe, unsubscribe)
253}
254
255fn subscription_rejected() -> core::Error {
256	core::Error {
257		code: core::ErrorCode::ServerError(-32091),
258		message: "Subscription rejected".into(),
259		data: None,
260	}
261}
262
263fn subscriptions_unavailable() -> core::Error {
264	core::Error {
265		code: core::ErrorCode::ServerError(-32090),
266		message: "Subscriptions are not available on this transport.".into(),
267		data: None,
268	}
269}
270
/// Subscribe RPC implementation.
///
/// Built by [`new_subscription`] together with its matching [`Unsubscribe`].
pub struct Subscribe<F, G> {
	/// Notification method name handed to every `Subscriber`.
	notification: String,
	/// User handler invoked for each subscribe request.
	subscribe: F,
	/// Unsubscribe handler, shared with the paired `Unsubscribe` method.
	unsubscribe: Arc<G>,
}
277
278impl<M, F, G> core::RpcMethod<M> for Subscribe<F, G>
279where
280	M: PubSubMetadata,
281	F: SubscribeRpcMethod<M>,
282	G: UnsubscribeRpcMethod<M>,
283{
284	fn call(&self, params: core::Params, meta: M) -> BoxFuture<core::Result<core::Value>> {
285		match meta.session() {
286			Some(session) => {
287				let (tx, rx) = crate::oneshot::channel();
288
289				// Register the subscription
290				let subscriber = Subscriber {
291					notification: self.notification.clone(),
292					transport: session.sender(),
293					sender: tx,
294				};
295				self.subscribe.call(params, meta, subscriber);
296
297				let unsub = self.unsubscribe.clone();
298				let notification = self.notification.clone();
299				let subscribe_future = rx.map_err(|_| subscription_rejected()).and_then(move |result| {
300					futures::future::ready(match result {
301						Ok(id) => {
302							session.add_subscription(&notification, &id, move |id| {
303								// TODO [#570] [ToDr] We currently run unsubscribe tasks on a shared thread pool.
304								// In the future we should use some kind of `::spawn` method
305								// that spawns a task on an existing executor or pass the spawner handle here.
306								let f = unsub.call(id, None);
307								UNSUBSCRIBE_POOL.spawn_ok(async move {
308									let _ = f.await;
309								});
310							});
311							Ok(id.into())
312						}
313						Err(e) => Err(e),
314					})
315				});
316				Box::pin(subscribe_future)
317			}
318			None => Box::pin(future::err(subscriptions_unavailable())),
319		}
320	}
321}
322
/// Unsubscribe RPC implementation.
///
/// Built by [`new_subscription`] together with its matching [`Subscribe`].
pub struct Unsubscribe<G> {
	/// Notification method name; only subscriptions registered under it can be removed.
	notification: String,
	/// Unsubscribe handler, shared with the paired `Subscribe` method.
	unsubscribe: Arc<G>,
}
328
329impl<M, G> core::RpcMethod<M> for Unsubscribe<G>
330where
331	M: PubSubMetadata,
332	G: UnsubscribeRpcMethod<M>,
333{
334	fn call(&self, params: core::Params, meta: M) -> BoxFuture<core::Result<core::Value>> {
335		let id = match params {
336			core::Params::Array(ref vec) if vec.len() == 1 => SubscriptionId::parse_value(&vec[0]),
337			_ => None,
338		};
339		match (meta.session(), id) {
340			(Some(session), Some(id)) => {
341				if session.remove_subscription(&self.notification, &id) {
342					Box::pin(self.unsubscribe.call(id, Some(meta)))
343				} else {
344					Box::pin(future::err(core::Error::invalid_params("Invalid subscription id.")))
345				}
346			}
347			(Some(_), None) => Box::pin(future::err(core::Error::invalid_params("Expected subscription id."))),
348			_ => Box::pin(future::err(subscriptions_unavailable())),
349		}
350	}
351}
352
#[cfg(test)]
mod tests {
	use crate::core;
	use crate::core::futures::channel::mpsc;
	use crate::core::RpcMethod;
	use crate::types::{PubSubMetadata, SubscriptionId};
	use std::sync::atomic::{AtomicBool, Ordering};
	use std::sync::Arc;

	use super::{new_subscription, Session, Sink, Subscriber};

	/// Creates a session backed by a fresh unbounded channel and returns the receiving end too.
	fn session() -> (Session, mpsc::UnboundedReceiver<String>) {
		let (tx, rx) = mpsc::unbounded();
		(Session::new(tx), rx)
	}

	#[test]
	fn should_unregister_on_drop() {
		// given
		let id = SubscriptionId::Number(1);
		let called = Arc::new(AtomicBool::new(false));
		let called2 = called.clone();
		let session = session().0;
		session.add_subscription("test", &id, move |id| {
			assert_eq!(id, SubscriptionId::Number(1));
			called2.store(true, Ordering::SeqCst);
		});

		// when
		drop(session);

		// then
		assert!(called.load(Ordering::SeqCst));
	}

	#[test]
	fn should_remove_subscription() {
		// given
		let id = SubscriptionId::Number(1);
		let called = Arc::new(AtomicBool::new(false));
		let called2 = called.clone();
		let session = session().0;
		session.add_subscription("test", &id, move |id| {
			assert_eq!(id, SubscriptionId::Number(1));
			called2.store(true, Ordering::SeqCst);
		});

		// when
		let removed = session.remove_subscription("test", &id);
		drop(session);

		// then
		assert!(removed);
		assert!(!called.load(Ordering::SeqCst));
	}

	#[test]
	fn should_not_remove_subscription_if_invalid() {
		// given
		let id = SubscriptionId::Number(1);
		let called = Arc::new(AtomicBool::new(false));
		let called2 = called.clone();
		let other_session = session().0;
		let session = session().0;
		session.add_subscription("test", &id, move |id| {
			assert_eq!(id, SubscriptionId::Number(1));
			called2.store(true, Ordering::SeqCst);
		});

		// when
		let removed = other_session.remove_subscription("test", &id);
		drop(session);

		// then
		assert!(!removed);
		assert!(called.load(Ordering::SeqCst));
	}

	#[test]
	fn should_unregister_in_case_of_collision() {
		// given
		let id = SubscriptionId::Number(1);
		let called = Arc::new(AtomicBool::new(false));
		let called2 = called.clone();
		let session = session().0;
		session.add_subscription("test", &id, move |id| {
			assert_eq!(id, SubscriptionId::Number(1));
			called2.store(true, Ordering::SeqCst);
		});

		// when
		session.add_subscription("test", &id, |_| {});

		// then
		assert!(called.load(Ordering::SeqCst));
		// the replacement subscription is the one that stays registered
		assert!(session.remove_subscription("test", &id));
	}

	#[test]
	fn should_send_notification_to_the_transport() {
		// given
		let (tx, mut rx) = mpsc::unbounded();
		let sink = Sink {
			notification: "test".into(),
			transport: tx,
		};

		// when
		sink.notify(core::Params::Array(vec![core::Value::Number(10.into())]))
			.unwrap();

		// then
		let val = rx.try_next().unwrap();
		assert_eq!(val, Some(r#"{"jsonrpc":"2.0","method":"test","params":[10]}"#.into()));
	}

	#[test]
	fn should_assign_id() {
		// given
		let (transport, _) = mpsc::unbounded();
		let (tx, rx) = crate::oneshot::channel();
		let subscriber = Subscriber {
			notification: "test".into(),
			transport,
			sender: tx,
		};

		// when
		let sink = subscriber.assign_id_async(SubscriptionId::Number(5));

		// then
		futures::executor::block_on(async move {
			let id = rx.await;
			assert_eq!(id, Ok(Ok(SubscriptionId::Number(5))));
			let sink = sink.await.unwrap();
			assert_eq!(sink.notification, "test".to_owned());
		})
	}

	#[test]
	fn should_reject() {
		// given
		let (transport, _) = mpsc::unbounded();
		let (tx, rx) = crate::oneshot::channel();
		let subscriber = Subscriber {
			notification: "test".into(),
			transport,
			sender: tx,
		};
		let error = core::Error {
			code: core::ErrorCode::InvalidRequest,
			message: "Cannot start subscription now.".into(),
			data: None,
		};

		// when
		let reject = subscriber.reject_async(error.clone());

		// then
		futures::executor::block_on(async move {
			assert_eq!(rx.await.unwrap(), Err(error));
			reject.await.unwrap();
		});
	}

	/// Test metadata that always exposes the same session.
	#[derive(Clone)]
	struct Metadata(Arc<Session>);
	impl core::Metadata for Metadata {}
	impl PubSubMetadata for Metadata {
		fn session(&self) -> Option<Arc<Session>> {
			Some(self.0.clone())
		}
	}
	impl Default for Metadata {
		fn default() -> Self {
			Self(Arc::new(session().0))
		}
	}

	#[test]
	fn should_subscribe() {
		// given
		let (subscribe, _) = new_subscription(
			"test",
			move |params, _meta, subscriber: Subscriber| {
				assert_eq!(params, core::Params::None);
				let _sink = subscriber.assign_id(SubscriptionId::Number(5)).unwrap();
			},
			|_id, _meta| async { Ok(core::Value::Bool(true)) },
		);

		// when
		let meta = Metadata::default();
		let result = subscribe.call(core::Params::None, meta);

		// then
		assert_eq!(futures::executor::block_on(result), Ok(serde_json::json!(5)));
	}

	#[test]
	fn should_unsubscribe() {
		// given
		const SUB_ID: u64 = 5;
		let (subscribe, unsubscribe) = new_subscription(
			"test",
			move |params, _meta, subscriber: Subscriber| {
				assert_eq!(params, core::Params::None);
				let _sink = subscriber.assign_id(SubscriptionId::Number(SUB_ID)).unwrap();
			},
			|_id, _meta| async { Ok(core::Value::Bool(true)) },
		);

		// when
		let meta = Metadata::default();
		futures::executor::block_on(subscribe.call(core::Params::None, meta.clone())).unwrap();
		let result = unsubscribe.call(core::Params::Array(vec![serde_json::json!(SUB_ID)]), meta);

		// then
		assert_eq!(futures::executor::block_on(result), Ok(serde_json::json!(true)));
	}

	#[test]
	fn should_not_unsubscribe_if_invalid() {
		// given
		const SUB_ID: u64 = 5;
		let (subscribe, unsubscribe) = new_subscription(
			"test",
			move |params, _meta, subscriber: Subscriber| {
				assert_eq!(params, core::Params::None);
				let _sink = subscriber.assign_id(SubscriptionId::Number(SUB_ID)).unwrap();
			},
			|_id, _meta| async { Ok(core::Value::Bool(true)) },
		);

		// when
		let meta = Metadata::default();
		futures::executor::block_on(subscribe.call(core::Params::None, meta.clone())).unwrap();
		let result = unsubscribe.call(core::Params::Array(vec![serde_json::json!(SUB_ID + 1)]), meta);

		// then
		assert_eq!(
			futures::executor::block_on(result),
			Err(core::Error {
				code: core::ErrorCode::InvalidParams,
				message: "Invalid subscription id.".into(),
				data: None,
			})
		);
	}
}