// safecoin_client/nonblocking/pubsub_client.rs

1use {
2    crate::{
3        http_sender::RpcErrorObject,
4        rpc_config::{
5            RpcAccountInfoConfig, RpcBlockSubscribeConfig, RpcBlockSubscribeFilter,
6            RpcProgramAccountsConfig, RpcSignatureSubscribeConfig, RpcTransactionLogsConfig,
7            RpcTransactionLogsFilter,
8        },
9        rpc_filter::maybe_map_filters,
10        rpc_response::{
11            Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse,
12            RpcSignatureResult, RpcVersionInfo, RpcVote, SlotInfo, SlotUpdate,
13        },
14    },
15    futures_util::{
16        future::{ready, BoxFuture, FutureExt},
17        sink::SinkExt,
18        stream::{BoxStream, StreamExt},
19    },
20    log::*,
21    serde::de::DeserializeOwned,
22    serde_json::{json, Map, Value},
23    safecoin_account_decoder::UiAccount,
24    solana_sdk::{clock::Slot, pubkey::Pubkey, signature::Signature},
25    std::collections::BTreeMap,
26    thiserror::Error,
27    tokio::{
28        net::TcpStream,
29        sync::{mpsc, oneshot, RwLock},
30        task::JoinHandle,
31        time::{sleep, Duration},
32    },
33    tokio_stream::wrappers::UnboundedReceiverStream,
34    tokio_tungstenite::{
35        connect_async,
36        tungstenite::{
37            protocol::frame::{coding::CloseCode, CloseFrame},
38            Message,
39        },
40        MaybeTlsStream, WebSocketStream,
41    },
42    url::Url,
43};
44
45pub type PubsubClientResult<T = ()> = Result<T, PubsubClientError>;
46
47#[derive(Debug, Error)]
48pub enum PubsubClientError {
49    #[error("url parse error")]
50    UrlParseError(#[from] url::ParseError),
51
52    #[error("unable to connect to server")]
53    ConnectionError(tokio_tungstenite::tungstenite::Error),
54
55    #[error("websocket error")]
56    WsError(#[from] tokio_tungstenite::tungstenite::Error),
57
58    #[error("connection closed (({0})")]
59    ConnectionClosed(String),
60
61    #[error("json parse error")]
62    JsonParseError(#[from] serde_json::error::Error),
63
64    #[error("subscribe failed: {reason}")]
65    SubscribeFailed { reason: String, message: String },
66
67    #[error("request failed: {reason}")]
68    RequestFailed { reason: String, message: String },
69}
70
71type UnsubscribeFn = Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send>;
72type SubscribeResponseMsg =
73    Result<(mpsc::UnboundedReceiver<Value>, UnsubscribeFn), PubsubClientError>;
74type SubscribeRequestMsg = (String, Value, oneshot::Sender<SubscribeResponseMsg>);
75type SubscribeResult<'a, T> = PubsubClientResult<(BoxStream<'a, T>, UnsubscribeFn)>;
76type RequestMsg = (
77    String,
78    Value,
79    oneshot::Sender<Result<Value, PubsubClientError>>,
80);
81
82#[derive(Debug)]
83pub struct PubsubClient {
84    subscribe_tx: mpsc::UnboundedSender<SubscribeRequestMsg>,
85    request_tx: mpsc::UnboundedSender<RequestMsg>,
86    shutdown_tx: oneshot::Sender<()>,
87    node_version: RwLock<Option<semver::Version>>,
88    ws: JoinHandle<PubsubClientResult>,
89}
90
91impl PubsubClient {
92    pub async fn new(url: &str) -> PubsubClientResult<Self> {
93        let url = Url::parse(url)?;
94        let (ws, _response) = connect_async(url)
95            .await
96            .map_err(PubsubClientError::ConnectionError)?;
97
98        let (subscribe_tx, subscribe_rx) = mpsc::unbounded_channel();
99        let (request_tx, request_rx) = mpsc::unbounded_channel();
100        let (shutdown_tx, shutdown_rx) = oneshot::channel();
101
102        Ok(Self {
103            subscribe_tx,
104            request_tx,
105            shutdown_tx,
106            node_version: RwLock::new(None),
107            ws: tokio::spawn(PubsubClient::run_ws(
108                ws,
109                subscribe_rx,
110                request_rx,
111                shutdown_rx,
112            )),
113        })
114    }
115
116    pub async fn shutdown(self) -> PubsubClientResult {
117        let _ = self.shutdown_tx.send(());
118        self.ws.await.unwrap() // WS future should not be cancelled or panicked
119    }
120
121    async fn get_node_version(&self) -> PubsubClientResult<semver::Version> {
122        let r_node_version = self.node_version.read().await;
123        if let Some(version) = &*r_node_version {
124            Ok(version.clone())
125        } else {
126            drop(r_node_version);
127            let mut w_node_version = self.node_version.write().await;
128            let node_version = self.get_version().await?;
129            *w_node_version = Some(node_version.clone());
130            Ok(node_version)
131        }
132    }
133
134    async fn get_version(&self) -> PubsubClientResult<semver::Version> {
135        let (response_tx, response_rx) = oneshot::channel();
136        self.request_tx
137            .send(("getVersion".to_string(), Value::Null, response_tx))
138            .map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))?;
139        let result = response_rx
140            .await
141            .map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))??;
142        let node_version: RpcVersionInfo = serde_json::from_value(result)?;
143        let node_version = semver::Version::parse(&node_version.solana_core).map_err(|e| {
144            PubsubClientError::RequestFailed {
145                reason: format!("failed to parse cluster version: {}", e),
146                message: "getVersion".to_string(),
147            }
148        })?;
149        Ok(node_version)
150    }
151
152    async fn subscribe<'a, T>(&self, operation: &str, params: Value) -> SubscribeResult<'a, T>
153    where
154        T: DeserializeOwned + Send + 'a,
155    {
156        let (response_tx, response_rx) = oneshot::channel();
157        self.subscribe_tx
158            .send((operation.to_string(), params, response_tx))
159            .map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))?;
160
161        let (notifications, unsubscribe) = response_rx
162            .await
163            .map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))??;
164        Ok((
165            UnboundedReceiverStream::new(notifications)
166                .filter_map(|value| ready(serde_json::from_value::<T>(value).ok()))
167                .boxed(),
168            unsubscribe,
169        ))
170    }
171
172    pub async fn account_subscribe(
173        &self,
174        pubkey: &Pubkey,
175        config: Option<RpcAccountInfoConfig>,
176    ) -> SubscribeResult<'_, RpcResponse<UiAccount>> {
177        let params = json!([pubkey.to_string(), config]);
178        self.subscribe("account", params).await
179    }
180
181    pub async fn block_subscribe(
182        &self,
183        filter: RpcBlockSubscribeFilter,
184        config: Option<RpcBlockSubscribeConfig>,
185    ) -> SubscribeResult<'_, RpcResponse<RpcBlockUpdate>> {
186        self.subscribe("block", json!([filter, config])).await
187    }
188
189    pub async fn logs_subscribe(
190        &self,
191        filter: RpcTransactionLogsFilter,
192        config: RpcTransactionLogsConfig,
193    ) -> SubscribeResult<'_, RpcResponse<RpcLogsResponse>> {
194        self.subscribe("logs", json!([filter, config])).await
195    }
196
197    pub async fn program_subscribe(
198        &self,
199        pubkey: &Pubkey,
200        mut config: Option<RpcProgramAccountsConfig>,
201    ) -> SubscribeResult<'_, RpcResponse<RpcKeyedAccount>> {
202        if let Some(ref mut config) = config {
203            if let Some(ref mut filters) = config.filters {
204                let node_version = self.get_node_version().await.ok();
205                // If node does not support the pubsub `getVersion` method, assume version is old
206                // and filters should be mapped (node_version.is_none()).
207                maybe_map_filters(node_version, filters).map_err(|e| {
208                    PubsubClientError::RequestFailed {
209                        reason: e,
210                        message: "maybe_map_filters".to_string(),
211                    }
212                })?;
213            }
214        }
215
216        let params = json!([pubkey.to_string(), config]);
217        self.subscribe("program", params).await
218    }
219
220    pub async fn vote_subscribe(&self) -> SubscribeResult<'_, RpcVote> {
221        self.subscribe("vote", json!([])).await
222    }
223
224    pub async fn root_subscribe(&self) -> SubscribeResult<'_, Slot> {
225        self.subscribe("root", json!([])).await
226    }
227
228    pub async fn signature_subscribe(
229        &self,
230        signature: &Signature,
231        config: Option<RpcSignatureSubscribeConfig>,
232    ) -> SubscribeResult<'_, RpcResponse<RpcSignatureResult>> {
233        let params = json!([signature.to_string(), config]);
234        self.subscribe("signature", params).await
235    }
236
237    pub async fn slot_subscribe(&self) -> SubscribeResult<'_, SlotInfo> {
238        self.subscribe("slot", json!([])).await
239    }
240
241    pub async fn slot_updates_subscribe(&self) -> SubscribeResult<'_, SlotUpdate> {
242        self.subscribe("slotsUpdates", json!([])).await
243    }
244
245    async fn run_ws(
246        mut ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
247        mut subscribe_rx: mpsc::UnboundedReceiver<SubscribeRequestMsg>,
248        mut request_rx: mpsc::UnboundedReceiver<RequestMsg>,
249        mut shutdown_rx: oneshot::Receiver<()>,
250    ) -> PubsubClientResult {
251        let mut request_id: u64 = 0;
252
253        let mut requests_subscribe = BTreeMap::new();
254        let mut requests_unsubscribe = BTreeMap::<u64, oneshot::Sender<()>>::new();
255        let mut other_requests = BTreeMap::new();
256        let mut subscriptions = BTreeMap::new();
257        let (unsubscribe_tx, mut unsubscribe_rx) = mpsc::unbounded_channel();
258
259        loop {
260            tokio::select! {
261                // Send close on shutdown signal
262                _ = (&mut shutdown_rx) => {
263                    let frame = CloseFrame { code: CloseCode::Normal, reason: "".into() };
264                    ws.send(Message::Close(Some(frame))).await?;
265                    ws.flush().await?;
266                    break;
267                },
268                // Send `Message::Ping` each 10s if no any other communication
269                () = sleep(Duration::from_secs(10)) => {
270                    ws.send(Message::Ping(Vec::new())).await?;
271                },
272                // Read message for subscribe
273                Some((operation, params, response_tx)) = subscribe_rx.recv() => {
274                    request_id += 1;
275                    let method = format!("{}Subscribe", operation);
276                    let text = json!({"jsonrpc":"2.0","id":request_id,"method":method,"params":params}).to_string();
277                    ws.send(Message::Text(text)).await?;
278                    requests_subscribe.insert(request_id, (operation, response_tx));
279                },
280                // Read message for unsubscribe
281                Some((operation, sid, response_tx)) = unsubscribe_rx.recv() => {
282                    subscriptions.remove(&sid);
283                    request_id += 1;
284                    let method = format!("{}Unsubscribe", operation);
285                    let text = json!({"jsonrpc":"2.0","id":request_id,"method":method,"params":[sid]}).to_string();
286                    ws.send(Message::Text(text)).await?;
287                    requests_unsubscribe.insert(request_id, response_tx);
288                },
289                // Read message for other requests
290                Some((method, params, response_tx)) = request_rx.recv() => {
291                    request_id += 1;
292                    let text = json!({"jsonrpc":"2.0","id":request_id,"method":method,"params":params}).to_string();
293                    ws.send(Message::Text(text)).await?;
294                    other_requests.insert(request_id, response_tx);
295                }
296                // Read incoming WebSocket message
297                next_msg = ws.next() => {
298                    let msg = match next_msg {
299                        Some(msg) => msg?,
300                        None => break,
301                    };
302                    trace!("ws.next(): {:?}", &msg);
303
304                    // Get text from the message
305                    let text = match msg {
306                        Message::Text(text) => text,
307                        Message::Binary(_data) => continue, // Ignore
308                        Message::Ping(data) => {
309                            ws.send(Message::Pong(data)).await?;
310                            continue
311                        },
312                        Message::Pong(_data) => continue,
313                        Message::Close(_frame) => break,
314                        Message::Frame(_frame) => continue,
315                    };
316
317
318                    let mut json: Map<String, Value> = serde_json::from_str(&text)?;
319
320                    // Subscribe/Unsubscribe response, example:
321                    // `{"jsonrpc":"2.0","result":5308752,"id":1}`
322                    if let Some(id) = json.get("id") {
323                        let id = id.as_u64().ok_or_else(|| {
324                            PubsubClientError::SubscribeFailed { reason: "invalid `id` field".into(), message: text.clone() }
325                        })?;
326
327                        let err = json.get("error").map(|error_object| {
328                            match serde_json::from_value::<RpcErrorObject>(error_object.clone()) {
329                                Ok(rpc_error_object) => {
330                                    format!("{} ({})",  rpc_error_object.message, rpc_error_object.code)
331                                }
332                                Err(err) => format!(
333                                    "Failed to deserialize RPC error response: {} [{}]",
334                                    serde_json::to_string(error_object).unwrap(),
335                                    err
336                                )
337                            }
338                        });
339
340                        if let Some(response_tx) = other_requests.remove(&id) {
341                            match err {
342                                Some(reason) => {
343                                    let _ = response_tx.send(Err(PubsubClientError::RequestFailed { reason, message: text.clone()}));
344                                },
345                                None => {
346                                    let json_result = json.get("result").ok_or_else(|| {
347                                        PubsubClientError::RequestFailed { reason: "missing `result` field".into(), message: text.clone() }
348                                    })?;
349                                    if response_tx.send(Ok(json_result.clone())).is_err() {
350                                        break;
351                                    }
352                                }
353                            }
354                        } else if let Some(response_tx) = requests_unsubscribe.remove(&id) {
355                            let _ = response_tx.send(()); // do not care if receiver is closed
356                        } else if let Some((operation, response_tx)) = requests_subscribe.remove(&id) {
357                            match err {
358                                Some(reason) => {
359                                    let _ = response_tx.send(Err(PubsubClientError::SubscribeFailed { reason, message: text.clone()}));
360                                },
361                                None => {
362                                    // Subscribe Id
363                                    let sid = json.get("result").and_then(Value::as_u64).ok_or_else(|| {
364                                        PubsubClientError::SubscribeFailed { reason: "invalid `result` field".into(), message: text.clone() }
365                                    })?;
366
367                                    // Create notifications channel and unsubscribe function
368                                    let (notifications_tx, notifications_rx) = mpsc::unbounded_channel();
369                                    let unsubscribe_tx = unsubscribe_tx.clone();
370                                    let unsubscribe = Box::new(move || async move {
371                                        let (response_tx, response_rx) = oneshot::channel();
372                                        // do nothing if ws already closed
373                                        if unsubscribe_tx.send((operation, sid, response_tx)).is_ok() {
374                                            let _ = response_rx.await; // channel can be closed only if ws is closed
375                                        }
376                                    }.boxed());
377
378                                    if response_tx.send(Ok((notifications_rx, unsubscribe))).is_err() {
379                                        break;
380                                    }
381                                    subscriptions.insert(sid, notifications_tx);
382                                }
383                            }
384                        } else {
385                            error!("Unknown request id: {}", id);
386                            break;
387                        }
388                        continue;
389                    }
390
391                    // Notification, example:
392                    // `{"jsonrpc":"2.0","method":"logsNotification","params":{"result":{...},"subscription":3114862}}`
393                    if let Some(Value::Object(params)) = json.get_mut("params") {
394                        if let Some(sid) = params.get("subscription").and_then(Value::as_u64) {
395                            let mut unsubscribe_required = false;
396
397                            if let Some(notifications_tx) = subscriptions.get(&sid) {
398                                if let Some(result) = params.remove("result") {
399                                    if notifications_tx.send(result).is_err() {
400                                        unsubscribe_required = true;
401                                    }
402                                }
403                            } else {
404                                unsubscribe_required = true;
405                            }
406
407                            if unsubscribe_required {
408                                if let Some(Value::String(method)) = json.remove("method") {
409                                    if let Some(operation) = method.strip_suffix("Notification") {
410                                        let (response_tx, _response_rx) = oneshot::channel();
411                                        let _ = unsubscribe_tx.send((operation.to_string(), sid, response_tx));
412                                    }
413                                }
414                            }
415                        }
416                    }
417                }
418            }
419        }
420
421        Ok(())
422    }
423}
424
// Tests for this client live in client-test/test/client.rs.
#[cfg(test)]
mod tests {}