1use parking_lot::Mutex;
4use std::collections::HashMap;
5use std::fmt;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use crate::core::futures::channel::mpsc;
10use crate::core::futures::{
11 self, future,
12 task::{Context, Poll},
13 Future, Sink as FuturesSink, TryFutureExt,
14};
15use crate::core::{self, BoxFuture};
16
17use crate::handler::{SubscribeRpcMethod, UnsubscribeRpcMethod};
18use crate::types::{PubSubMetadata, SinkResult, SubscriptionId, TransportError, TransportSender};
19
lazy_static::lazy_static! {
    /// Background pool that drives unsubscribe futures started by the session itself
    /// (via the teardown callbacks registered in `Subscribe::call`). In that path no
    /// client request is waiting for the result, so the future is just polled to
    /// completion here. Built lazily on first use; panics if the pool cannot be created.
    static ref UNSUBSCRIBE_POOL: futures::executor::ThreadPool = futures::executor::ThreadPool::new()
        .expect("Unable to spawn background pool for unsubscribe tasks.");
}
24
/// Pub-sub session bound to one transport sender: the subscriptions opened
/// through it and the callbacks to run when it is dropped.
pub struct Session {
    /// Active subscriptions keyed by `(subscription id, notification name)`.
    /// Each value is the callback that tears the subscription down when the
    /// session itself discards it (on `Drop` or on an id collision). An explicit
    /// client unsubscribe removes the entry without calling it.
    active_subscriptions: Mutex<HashMap<(SubscriptionId, String), Box<dyn Fn(SubscriptionId) + Send + 'static>>>,
    /// Channel used for outgoing serialized messages; cloned into every `Subscriber`/`Sink`.
    transport: TransportSender,
    /// Callbacks registered through `on_drop`; each runs once when the session is dropped.
    on_drop: Mutex<Vec<Box<dyn FnMut() + Send>>>,
}
32
33impl fmt::Debug for Session {
34 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
35 fmt.debug_struct("pubsub::Session")
36 .field("active_subscriptions", &self.active_subscriptions.lock().len())
37 .field("transport", &self.transport)
38 .finish()
39 }
40}
41
42impl Session {
43 pub fn new(sender: TransportSender) -> Self {
46 Session {
47 active_subscriptions: Default::default(),
48 transport: sender,
49 on_drop: Default::default(),
50 }
51 }
52
53 pub fn sender(&self) -> TransportSender {
55 self.transport.clone()
56 }
57
58 pub fn on_drop<F: FnOnce() + Send + 'static>(&self, on_drop: F) {
60 let mut func = Some(on_drop);
61 self.on_drop.lock().push(Box::new(move || {
62 if let Some(f) = func.take() {
63 f();
64 }
65 }));
66 }
67
68 fn add_subscription<F>(&self, name: &str, id: &SubscriptionId, remove: F)
70 where
71 F: Fn(SubscriptionId) + Send + 'static,
72 {
73 let ret = self
74 .active_subscriptions
75 .lock()
76 .insert((id.clone(), name.into()), Box::new(remove));
77 if let Some(remove) = ret {
78 warn!("SubscriptionId collision. Unsubscribing previous client.");
79 remove(id.clone());
80 }
81 }
82
83 fn remove_subscription(&self, name: &str, id: &SubscriptionId) -> bool {
85 self.active_subscriptions
86 .lock()
87 .remove(&(id.clone(), name.into()))
88 .is_some()
89 }
90}
91
92impl Drop for Session {
93 fn drop(&mut self) {
94 let mut active = self.active_subscriptions.lock();
95 for (id, remove) in active.drain() {
96 remove(id.0)
97 }
98
99 let mut on_drop = self.on_drop.lock();
100 for mut on_drop in on_drop.drain(..) {
101 on_drop();
102 }
103 }
104}
105
/// Handle for pushing notifications of one subscription to the client.
///
/// Produced by `Subscriber::assign_id` / `Subscriber::assign_id_async`.
#[derive(Debug, Clone)]
pub struct Sink {
    /// JSON-RPC method name used for every notification sent through this sink.
    notification: String,
    /// Channel that serialized notifications are written to.
    transport: TransportSender,
}
112
113impl Sink {
114 pub fn notify(&self, val: core::Params) -> SinkResult {
116 let val = self.params_to_string(val);
117 self.transport.clone().unbounded_send(val)
118 }
119
120 fn params_to_string(&self, val: core::Params) -> String {
121 let notification = core::Notification {
122 jsonrpc: Some(core::Version::V2),
123 method: self.notification.clone(),
124 params: val,
125 };
126 core::to_string(¬ification).expect("Notification serialization never fails.")
127 }
128}
129
130impl FuturesSink<core::Params> for Sink {
131 type Error = TransportError;
132
133 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
134 Pin::new(&mut self.transport).poll_ready(cx)
135 }
136
137 fn start_send(mut self: Pin<&mut Self>, item: core::Params) -> Result<(), Self::Error> {
138 let val = self.params_to_string(item);
139 Pin::new(&mut self.transport).start_send(val)
140 }
141
142 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
143 Pin::new(&mut self.transport).poll_flush(cx)
144 }
145
146 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
147 Pin::new(&mut self.transport).poll_close(cx)
148 }
149}
150
/// A pending subscription handed to the user's subscribe handler.
///
/// The handler resolves it with `assign_id` / `assign_id_async` (which yield a
/// `Sink`) or with `reject` / `reject_async`. The outcome is delivered to the
/// waiting subscribe request through `sender`.
#[derive(Debug)]
pub struct Subscriber {
    /// Method name for notifications sent by the resulting `Sink`.
    notification: String,
    /// Outgoing message channel; moved into the `Sink` once an id is assigned.
    transport: TransportSender,
    /// One-shot channel that reports the assigned id or the rejection error.
    sender: crate::oneshot::Sender<Result<SubscriptionId, core::Error>>,
}
159
160impl Subscriber {
161 pub fn new_test<T: Into<String>>(
165 method: T,
166 ) -> (
167 Self,
168 crate::oneshot::Receiver<Result<SubscriptionId, core::Error>>,
169 mpsc::UnboundedReceiver<String>,
170 ) {
171 let (sender, id_receiver) = crate::oneshot::channel();
172 let (transport, transport_receiver) = mpsc::unbounded();
173
174 let subscriber = Subscriber {
175 notification: method.into(),
176 transport,
177 sender,
178 };
179
180 (subscriber, id_receiver, transport_receiver)
181 }
182
183 pub fn assign_id(self, id: SubscriptionId) -> Result<Sink, ()> {
187 let Self {
188 notification,
189 transport,
190 sender,
191 } = self;
192 sender
193 .send(Ok(id))
194 .map(|_| Sink {
195 notification,
196 transport,
197 })
198 .map_err(|_| ())
199 }
200
201 pub fn assign_id_async(self, id: SubscriptionId) -> impl Future<Output = Result<Sink, ()>> {
206 let Self {
207 notification,
208 transport,
209 sender,
210 } = self;
211 sender.send_and_wait(Ok(id)).map_ok(|_| Sink {
212 notification,
213 transport,
214 })
215 }
216
217 pub fn reject(self, error: core::Error) -> Result<(), ()> {
221 self.sender.send(Err(error)).map_err(|_| ())
222 }
223
224 pub fn reject_async(self, error: core::Error) -> impl Future<Output = Result<(), ()>> {
229 self.sender.send_and_wait(Err(error)).map_ok(|_| ()).map_err(|_| ())
230 }
231}
232
233pub fn new_subscription<M, F, G>(notification: &str, subscribe: F, unsubscribe: G) -> (Subscribe<F, G>, Unsubscribe<G>)
235where
236 M: PubSubMetadata,
237 F: SubscribeRpcMethod<M>,
238 G: UnsubscribeRpcMethod<M>,
239{
240 let unsubscribe = Arc::new(unsubscribe);
241 let subscribe = Subscribe {
242 notification: notification.to_owned(),
243 unsubscribe: unsubscribe.clone(),
244 subscribe,
245 };
246
247 let unsubscribe = Unsubscribe {
248 notification: notification.into(),
249 unsubscribe,
250 };
251
252 (subscribe, unsubscribe)
253}
254
255fn subscription_rejected() -> core::Error {
256 core::Error {
257 code: core::ErrorCode::ServerError(-32091),
258 message: "Subscription rejected".into(),
259 data: None,
260 }
261}
262
263fn subscriptions_unavailable() -> core::Error {
264 core::Error {
265 code: core::ErrorCode::ServerError(-32090),
266 message: "Subscriptions are not available on this transport.".into(),
267 data: None,
268 }
269}
270
/// Subscribe half of a pub-sub method pair (see `new_subscription`).
pub struct Subscribe<F, G> {
    /// Notification method name used for messages sent to subscribers.
    notification: String,
    /// User handler invoked for each subscribe request.
    subscribe: F,
    /// Unsubscribe handler shared with the matching `Unsubscribe`; invoked when
    /// the session tears a subscription down itself.
    unsubscribe: Arc<G>,
}
277
278impl<M, F, G> core::RpcMethod<M> for Subscribe<F, G>
279where
280 M: PubSubMetadata,
281 F: SubscribeRpcMethod<M>,
282 G: UnsubscribeRpcMethod<M>,
283{
284 fn call(&self, params: core::Params, meta: M) -> BoxFuture<core::Result<core::Value>> {
285 match meta.session() {
286 Some(session) => {
287 let (tx, rx) = crate::oneshot::channel();
288
289 let subscriber = Subscriber {
291 notification: self.notification.clone(),
292 transport: session.sender(),
293 sender: tx,
294 };
295 self.subscribe.call(params, meta, subscriber);
296
297 let unsub = self.unsubscribe.clone();
298 let notification = self.notification.clone();
299 let subscribe_future = rx.map_err(|_| subscription_rejected()).and_then(move |result| {
300 futures::future::ready(match result {
301 Ok(id) => {
302 session.add_subscription(¬ification, &id, move |id| {
303 let f = unsub.call(id, None);
307 UNSUBSCRIBE_POOL.spawn_ok(async move {
308 let _ = f.await;
309 });
310 });
311 Ok(id.into())
312 }
313 Err(e) => Err(e),
314 })
315 });
316 Box::pin(subscribe_future)
317 }
318 None => Box::pin(future::err(subscriptions_unavailable())),
319 }
320 }
321}
322
/// Unsubscribe half of a pub-sub method pair (see `new_subscription`).
pub struct Unsubscribe<G> {
    /// Notification method name the subscriptions were registered under.
    notification: String,
    /// User unsubscribe handler, shared with the matching `Subscribe`.
    unsubscribe: Arc<G>,
}
328
329impl<M, G> core::RpcMethod<M> for Unsubscribe<G>
330where
331 M: PubSubMetadata,
332 G: UnsubscribeRpcMethod<M>,
333{
334 fn call(&self, params: core::Params, meta: M) -> BoxFuture<core::Result<core::Value>> {
335 let id = match params {
336 core::Params::Array(ref vec) if vec.len() == 1 => SubscriptionId::parse_value(&vec[0]),
337 _ => None,
338 };
339 match (meta.session(), id) {
340 (Some(session), Some(id)) => {
341 if session.remove_subscription(&self.notification, &id) {
342 Box::pin(self.unsubscribe.call(id, Some(meta)))
343 } else {
344 Box::pin(future::err(core::Error::invalid_params("Invalid subscription id.")))
345 }
346 }
347 (Some(_), None) => Box::pin(future::err(core::Error::invalid_params("Expected subscription id."))),
348 _ => Box::pin(future::err(subscriptions_unavailable())),
349 }
350 }
351}
352
#[cfg(test)]
mod tests {
    use crate::core;
    use crate::core::futures::channel::mpsc;
    use crate::core::RpcMethod;
    use crate::types::{PubSubMetadata, SubscriptionId};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    use super::{new_subscription, Session, Sink, Subscriber};

    // Creates a session together with the receiving end of its transport.
    fn session() -> (Session, mpsc::UnboundedReceiver<String>) {
        let (tx, rx) = mpsc::unbounded();
        (Session::new(tx), rx)
    }

    #[test]
    fn should_unregister_on_drop() {
        // given
        let id = SubscriptionId::Number(1);
        let called = Arc::new(AtomicBool::new(false));
        let called2 = called.clone();
        let session = session().0;
        session.add_subscription("test", &id, move |id| {
            assert_eq!(id, SubscriptionId::Number(1));
            called2.store(true, Ordering::SeqCst);
        });

        // when
        drop(session);

        // then
        assert!(called.load(Ordering::SeqCst));
    }

    #[test]
    fn should_remove_subscription() {
        // given
        let id = SubscriptionId::Number(1);
        let called = Arc::new(AtomicBool::new(false));
        let called2 = called.clone();
        let session = session().0;
        session.add_subscription("test", &id, move |id| {
            assert_eq!(id, SubscriptionId::Number(1));
            called2.store(true, Ordering::SeqCst);
        });

        // when
        let removed = session.remove_subscription("test", &id);
        drop(session);

        // then: explicit removal must not run the teardown callback
        assert!(removed);
        assert!(!called.load(Ordering::SeqCst));
    }

    #[test]
    fn should_not_remove_subscription_if_invalid() {
        // given
        let id = SubscriptionId::Number(1);
        let called = Arc::new(AtomicBool::new(false));
        let called2 = called.clone();
        let other_session = session().0;
        let session = session().0;
        session.add_subscription("test", &id, move |id| {
            assert_eq!(id, SubscriptionId::Number(1));
            called2.store(true, Ordering::SeqCst);
        });

        // when: removing through a different session
        let removed = other_session.remove_subscription("test", &id);
        drop(session);

        // then
        assert!(!removed);
        assert!(called.load(Ordering::SeqCst));
    }

    #[test]
    fn should_unregister_in_case_of_collision() {
        // given
        let id = SubscriptionId::Number(1);
        let called = Arc::new(AtomicBool::new(false));
        let called2 = called.clone();
        let session = session().0;
        session.add_subscription("test", &id, move |id| {
            assert_eq!(id, SubscriptionId::Number(1));
            called2.store(true, Ordering::SeqCst);
        });

        // when
        session.add_subscription("test", &id, |_| {});

        // then
        assert!(called.load(Ordering::SeqCst));
    }

    #[test]
    fn should_send_notification_to_the_transport() {
        // given
        let (tx, mut rx) = mpsc::unbounded();
        let sink = Sink {
            notification: "test".into(),
            transport: tx,
        };

        // when
        sink.notify(core::Params::Array(vec![core::Value::Number(10.into())]))
            .unwrap();

        // then
        let val = rx.try_next().unwrap();
        assert_eq!(val, Some(r#"{"jsonrpc":"2.0","method":"test","params":[10]}"#.into()));
    }

    #[test]
    fn should_assign_id() {
        // given
        let (transport, _) = mpsc::unbounded();
        let (tx, rx) = crate::oneshot::channel();
        let subscriber = Subscriber {
            notification: "test".into(),
            transport,
            sender: tx,
        };

        // when
        let sink = subscriber.assign_id_async(SubscriptionId::Number(5));

        // then
        futures::executor::block_on(async move {
            let id = rx.await;
            assert_eq!(id, Ok(Ok(SubscriptionId::Number(5))));
            let sink = sink.await.unwrap();
            assert_eq!(sink.notification, "test".to_owned());
        })
    }

    #[test]
    fn should_reject() {
        // given
        let (transport, _) = mpsc::unbounded();
        let (tx, rx) = crate::oneshot::channel();
        let subscriber = Subscriber {
            notification: "test".into(),
            transport,
            sender: tx,
        };
        let error = core::Error {
            code: core::ErrorCode::InvalidRequest,
            message: "Cannot start subscription now.".into(),
            data: None,
        };

        // when
        let reject = subscriber.reject_async(error.clone());

        // then
        futures::executor::block_on(async move {
            assert_eq!(rx.await.unwrap(), Err(error));
            reject.await.unwrap();
        });
    }

    #[derive(Clone)]
    struct Metadata(Arc<Session>);
    impl core::Metadata for Metadata {}
    impl PubSubMetadata for Metadata {
        fn session(&self) -> Option<Arc<Session>> {
            Some(self.0.clone())
        }
    }
    impl Default for Metadata {
        fn default() -> Self {
            Self(Arc::new(session().0))
        }
    }

    #[test]
    fn should_subscribe() {
        // given
        let (subscribe, _) = new_subscription(
            "test",
            move |params, _meta, subscriber: Subscriber| {
                assert_eq!(params, core::Params::None);
                let _sink = subscriber.assign_id(SubscriptionId::Number(5)).unwrap();
            },
            |_id, _meta| async { Ok(core::Value::Bool(true)) },
        );

        // when
        let meta = Metadata::default();
        let result = subscribe.call(core::Params::None, meta);

        // then
        assert_eq!(futures::executor::block_on(result), Ok(serde_json::json!(5)));
    }

    #[test]
    fn should_unsubscribe() {
        // given
        const SUB_ID: u64 = 5;
        let (subscribe, unsubscribe) = new_subscription(
            "test",
            move |params, _meta, subscriber: Subscriber| {
                assert_eq!(params, core::Params::None);
                let _sink = subscriber.assign_id(SubscriptionId::Number(SUB_ID)).unwrap();
            },
            |_id, _meta| async { Ok(core::Value::Bool(true)) },
        );

        // when
        let meta = Metadata::default();
        futures::executor::block_on(subscribe.call(core::Params::None, meta.clone())).unwrap();
        let result = unsubscribe.call(core::Params::Array(vec![serde_json::json!(SUB_ID)]), meta);

        // then
        assert_eq!(futures::executor::block_on(result), Ok(serde_json::json!(true)));
    }

    #[test]
    fn should_not_unsubscribe_if_invalid() {
        // given
        const SUB_ID: u64 = 5;
        let (subscribe, unsubscribe) = new_subscription(
            "test",
            move |params, _meta, subscriber: Subscriber| {
                assert_eq!(params, core::Params::None);
                let _sink = subscriber.assign_id(SubscriptionId::Number(SUB_ID)).unwrap();
            },
            |_id, _meta| async { Ok(core::Value::Bool(true)) },
        );

        // when
        let meta = Metadata::default();
        futures::executor::block_on(subscribe.call(core::Params::None, meta.clone())).unwrap();
        let result = unsubscribe.call(core::Params::Array(vec![serde_json::json!(SUB_ID + 1)]), meta);

        // then
        assert_eq!(
            futures::executor::block_on(result),
            Err(core::Error {
                code: core::ErrorCode::InvalidParams,
                message: "Invalid subscription id.".into(),
                data: None,
            })
        );
    }

    #[test]
    fn should_not_unsubscribe_without_id() {
        // given
        const SUB_ID: u64 = 5;
        let (subscribe, unsubscribe) = new_subscription(
            "test",
            move |params, _meta, subscriber: Subscriber| {
                assert_eq!(params, core::Params::None);
                let _sink = subscriber.assign_id(SubscriptionId::Number(SUB_ID)).unwrap();
            },
            |_id, _meta| async { Ok(core::Value::Bool(true)) },
        );

        // when: the params array does not contain an id
        let meta = Metadata::default();
        futures::executor::block_on(subscribe.call(core::Params::None, meta.clone())).unwrap();
        let result = unsubscribe.call(core::Params::Array(vec![]), meta);

        // then
        assert_eq!(
            futures::executor::block_on(result),
            Err(core::Error {
                code: core::ErrorCode::InvalidParams,
                message: "Expected subscription id.".into(),
                data: None,
            })
        );
    }
}