1use {
2 crate::{
3 http_sender::RpcErrorObject,
4 rpc_config::{
5 RpcAccountInfoConfig, RpcBlockSubscribeConfig, RpcBlockSubscribeFilter,
6 RpcProgramAccountsConfig, RpcSignatureSubscribeConfig, RpcTransactionLogsConfig,
7 RpcTransactionLogsFilter,
8 },
9 rpc_filter::maybe_map_filters,
10 rpc_response::{
11 Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse,
12 RpcSignatureResult, RpcVersionInfo, RpcVote, SlotInfo, SlotUpdate,
13 },
14 },
15 futures_util::{
16 future::{ready, BoxFuture, FutureExt},
17 sink::SinkExt,
18 stream::{BoxStream, StreamExt},
19 },
20 log::*,
21 serde::de::DeserializeOwned,
22 serde_json::{json, Map, Value},
23 safecoin_account_decoder::UiAccount,
24 solana_sdk::{clock::Slot, pubkey::Pubkey, signature::Signature},
25 std::collections::BTreeMap,
26 thiserror::Error,
27 tokio::{
28 net::TcpStream,
29 sync::{mpsc, oneshot, RwLock},
30 task::JoinHandle,
31 time::{sleep, Duration},
32 },
33 tokio_stream::wrappers::UnboundedReceiverStream,
34 tokio_tungstenite::{
35 connect_async,
36 tungstenite::{
37 protocol::frame::{coding::CloseCode, CloseFrame},
38 Message,
39 },
40 MaybeTlsStream, WebSocketStream,
41 },
42 url::Url,
43};
44
45pub type PubsubClientResult<T = ()> = Result<T, PubsubClientError>;
46
47#[derive(Debug, Error)]
48pub enum PubsubClientError {
49 #[error("url parse error")]
50 UrlParseError(#[from] url::ParseError),
51
52 #[error("unable to connect to server")]
53 ConnectionError(tokio_tungstenite::tungstenite::Error),
54
55 #[error("websocket error")]
56 WsError(#[from] tokio_tungstenite::tungstenite::Error),
57
58 #[error("connection closed (({0})")]
59 ConnectionClosed(String),
60
61 #[error("json parse error")]
62 JsonParseError(#[from] serde_json::error::Error),
63
64 #[error("subscribe failed: {reason}")]
65 SubscribeFailed { reason: String, message: String },
66
67 #[error("request failed: {reason}")]
68 RequestFailed { reason: String, message: String },
69}
70
71type UnsubscribeFn = Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send>;
72type SubscribeResponseMsg =
73 Result<(mpsc::UnboundedReceiver<Value>, UnsubscribeFn), PubsubClientError>;
74type SubscribeRequestMsg = (String, Value, oneshot::Sender<SubscribeResponseMsg>);
75type SubscribeResult<'a, T> = PubsubClientResult<(BoxStream<'a, T>, UnsubscribeFn)>;
76type RequestMsg = (
77 String,
78 Value,
79 oneshot::Sender<Result<Value, PubsubClientError>>,
80);
81
/// Async client for the PubSub websocket API.
///
/// All websocket I/O happens on a background task spawned by `PubsubClient::new`; this
/// handle only communicates with that task through the channels below.
#[derive(Debug)]
pub struct PubsubClient {
    // Subscribe requests forwarded to the websocket task.
    subscribe_tx: mpsc::UnboundedSender<SubscribeRequestMsg>,
    // One-shot JSON-RPC requests (e.g. `getVersion`) forwarded to the websocket task.
    request_tx: mpsc::UnboundedSender<RequestMsg>,
    // Tells the websocket task to close the connection; dropping this sender also
    // completes the task's shutdown branch.
    shutdown_tx: oneshot::Sender<()>,
    // Node version, fetched lazily via `getVersion` and cached afterwards.
    node_version: RwLock<Option<semver::Version>>,
    // Handle of the background task driving the websocket connection.
    ws: JoinHandle<PubsubClientResult>,
}
90
91impl PubsubClient {
92 pub async fn new(url: &str) -> PubsubClientResult<Self> {
93 let url = Url::parse(url)?;
94 let (ws, _response) = connect_async(url)
95 .await
96 .map_err(PubsubClientError::ConnectionError)?;
97
98 let (subscribe_tx, subscribe_rx) = mpsc::unbounded_channel();
99 let (request_tx, request_rx) = mpsc::unbounded_channel();
100 let (shutdown_tx, shutdown_rx) = oneshot::channel();
101
102 Ok(Self {
103 subscribe_tx,
104 request_tx,
105 shutdown_tx,
106 node_version: RwLock::new(None),
107 ws: tokio::spawn(PubsubClient::run_ws(
108 ws,
109 subscribe_rx,
110 request_rx,
111 shutdown_rx,
112 )),
113 })
114 }
115
116 pub async fn shutdown(self) -> PubsubClientResult {
117 let _ = self.shutdown_tx.send(());
118 self.ws.await.unwrap() }
120
121 async fn get_node_version(&self) -> PubsubClientResult<semver::Version> {
122 let r_node_version = self.node_version.read().await;
123 if let Some(version) = &*r_node_version {
124 Ok(version.clone())
125 } else {
126 drop(r_node_version);
127 let mut w_node_version = self.node_version.write().await;
128 let node_version = self.get_version().await?;
129 *w_node_version = Some(node_version.clone());
130 Ok(node_version)
131 }
132 }
133
134 async fn get_version(&self) -> PubsubClientResult<semver::Version> {
135 let (response_tx, response_rx) = oneshot::channel();
136 self.request_tx
137 .send(("getVersion".to_string(), Value::Null, response_tx))
138 .map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))?;
139 let result = response_rx
140 .await
141 .map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))??;
142 let node_version: RpcVersionInfo = serde_json::from_value(result)?;
143 let node_version = semver::Version::parse(&node_version.solana_core).map_err(|e| {
144 PubsubClientError::RequestFailed {
145 reason: format!("failed to parse cluster version: {}", e),
146 message: "getVersion".to_string(),
147 }
148 })?;
149 Ok(node_version)
150 }
151
152 async fn subscribe<'a, T>(&self, operation: &str, params: Value) -> SubscribeResult<'a, T>
153 where
154 T: DeserializeOwned + Send + 'a,
155 {
156 let (response_tx, response_rx) = oneshot::channel();
157 self.subscribe_tx
158 .send((operation.to_string(), params, response_tx))
159 .map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))?;
160
161 let (notifications, unsubscribe) = response_rx
162 .await
163 .map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))??;
164 Ok((
165 UnboundedReceiverStream::new(notifications)
166 .filter_map(|value| ready(serde_json::from_value::<T>(value).ok()))
167 .boxed(),
168 unsubscribe,
169 ))
170 }
171
172 pub async fn account_subscribe(
173 &self,
174 pubkey: &Pubkey,
175 config: Option<RpcAccountInfoConfig>,
176 ) -> SubscribeResult<'_, RpcResponse<UiAccount>> {
177 let params = json!([pubkey.to_string(), config]);
178 self.subscribe("account", params).await
179 }
180
181 pub async fn block_subscribe(
182 &self,
183 filter: RpcBlockSubscribeFilter,
184 config: Option<RpcBlockSubscribeConfig>,
185 ) -> SubscribeResult<'_, RpcResponse<RpcBlockUpdate>> {
186 self.subscribe("block", json!([filter, config])).await
187 }
188
189 pub async fn logs_subscribe(
190 &self,
191 filter: RpcTransactionLogsFilter,
192 config: RpcTransactionLogsConfig,
193 ) -> SubscribeResult<'_, RpcResponse<RpcLogsResponse>> {
194 self.subscribe("logs", json!([filter, config])).await
195 }
196
197 pub async fn program_subscribe(
198 &self,
199 pubkey: &Pubkey,
200 mut config: Option<RpcProgramAccountsConfig>,
201 ) -> SubscribeResult<'_, RpcResponse<RpcKeyedAccount>> {
202 if let Some(ref mut config) = config {
203 if let Some(ref mut filters) = config.filters {
204 let node_version = self.get_node_version().await.ok();
205 maybe_map_filters(node_version, filters).map_err(|e| {
208 PubsubClientError::RequestFailed {
209 reason: e,
210 message: "maybe_map_filters".to_string(),
211 }
212 })?;
213 }
214 }
215
216 let params = json!([pubkey.to_string(), config]);
217 self.subscribe("program", params).await
218 }
219
220 pub async fn vote_subscribe(&self) -> SubscribeResult<'_, RpcVote> {
221 self.subscribe("vote", json!([])).await
222 }
223
224 pub async fn root_subscribe(&self) -> SubscribeResult<'_, Slot> {
225 self.subscribe("root", json!([])).await
226 }
227
228 pub async fn signature_subscribe(
229 &self,
230 signature: &Signature,
231 config: Option<RpcSignatureSubscribeConfig>,
232 ) -> SubscribeResult<'_, RpcResponse<RpcSignatureResult>> {
233 let params = json!([signature.to_string(), config]);
234 self.subscribe("signature", params).await
235 }
236
237 pub async fn slot_subscribe(&self) -> SubscribeResult<'_, SlotInfo> {
238 self.subscribe("slot", json!([])).await
239 }
240
241 pub async fn slot_updates_subscribe(&self) -> SubscribeResult<'_, SlotUpdate> {
242 self.subscribe("slotsUpdates", json!([])).await
243 }
244
245 async fn run_ws(
246 mut ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
247 mut subscribe_rx: mpsc::UnboundedReceiver<SubscribeRequestMsg>,
248 mut request_rx: mpsc::UnboundedReceiver<RequestMsg>,
249 mut shutdown_rx: oneshot::Receiver<()>,
250 ) -> PubsubClientResult {
251 let mut request_id: u64 = 0;
252
253 let mut requests_subscribe = BTreeMap::new();
254 let mut requests_unsubscribe = BTreeMap::<u64, oneshot::Sender<()>>::new();
255 let mut other_requests = BTreeMap::new();
256 let mut subscriptions = BTreeMap::new();
257 let (unsubscribe_tx, mut unsubscribe_rx) = mpsc::unbounded_channel();
258
259 loop {
260 tokio::select! {
261 _ = (&mut shutdown_rx) => {
263 let frame = CloseFrame { code: CloseCode::Normal, reason: "".into() };
264 ws.send(Message::Close(Some(frame))).await?;
265 ws.flush().await?;
266 break;
267 },
268 () = sleep(Duration::from_secs(10)) => {
270 ws.send(Message::Ping(Vec::new())).await?;
271 },
272 Some((operation, params, response_tx)) = subscribe_rx.recv() => {
274 request_id += 1;
275 let method = format!("{}Subscribe", operation);
276 let text = json!({"jsonrpc":"2.0","id":request_id,"method":method,"params":params}).to_string();
277 ws.send(Message::Text(text)).await?;
278 requests_subscribe.insert(request_id, (operation, response_tx));
279 },
280 Some((operation, sid, response_tx)) = unsubscribe_rx.recv() => {
282 subscriptions.remove(&sid);
283 request_id += 1;
284 let method = format!("{}Unsubscribe", operation);
285 let text = json!({"jsonrpc":"2.0","id":request_id,"method":method,"params":[sid]}).to_string();
286 ws.send(Message::Text(text)).await?;
287 requests_unsubscribe.insert(request_id, response_tx);
288 },
289 Some((method, params, response_tx)) = request_rx.recv() => {
291 request_id += 1;
292 let text = json!({"jsonrpc":"2.0","id":request_id,"method":method,"params":params}).to_string();
293 ws.send(Message::Text(text)).await?;
294 other_requests.insert(request_id, response_tx);
295 }
296 next_msg = ws.next() => {
298 let msg = match next_msg {
299 Some(msg) => msg?,
300 None => break,
301 };
302 trace!("ws.next(): {:?}", &msg);
303
304 let text = match msg {
306 Message::Text(text) => text,
307 Message::Binary(_data) => continue, Message::Ping(data) => {
309 ws.send(Message::Pong(data)).await?;
310 continue
311 },
312 Message::Pong(_data) => continue,
313 Message::Close(_frame) => break,
314 Message::Frame(_frame) => continue,
315 };
316
317
318 let mut json: Map<String, Value> = serde_json::from_str(&text)?;
319
320 if let Some(id) = json.get("id") {
323 let id = id.as_u64().ok_or_else(|| {
324 PubsubClientError::SubscribeFailed { reason: "invalid `id` field".into(), message: text.clone() }
325 })?;
326
327 let err = json.get("error").map(|error_object| {
328 match serde_json::from_value::<RpcErrorObject>(error_object.clone()) {
329 Ok(rpc_error_object) => {
330 format!("{} ({})", rpc_error_object.message, rpc_error_object.code)
331 }
332 Err(err) => format!(
333 "Failed to deserialize RPC error response: {} [{}]",
334 serde_json::to_string(error_object).unwrap(),
335 err
336 )
337 }
338 });
339
340 if let Some(response_tx) = other_requests.remove(&id) {
341 match err {
342 Some(reason) => {
343 let _ = response_tx.send(Err(PubsubClientError::RequestFailed { reason, message: text.clone()}));
344 },
345 None => {
346 let json_result = json.get("result").ok_or_else(|| {
347 PubsubClientError::RequestFailed { reason: "missing `result` field".into(), message: text.clone() }
348 })?;
349 if response_tx.send(Ok(json_result.clone())).is_err() {
350 break;
351 }
352 }
353 }
354 } else if let Some(response_tx) = requests_unsubscribe.remove(&id) {
355 let _ = response_tx.send(()); } else if let Some((operation, response_tx)) = requests_subscribe.remove(&id) {
357 match err {
358 Some(reason) => {
359 let _ = response_tx.send(Err(PubsubClientError::SubscribeFailed { reason, message: text.clone()}));
360 },
361 None => {
362 let sid = json.get("result").and_then(Value::as_u64).ok_or_else(|| {
364 PubsubClientError::SubscribeFailed { reason: "invalid `result` field".into(), message: text.clone() }
365 })?;
366
367 let (notifications_tx, notifications_rx) = mpsc::unbounded_channel();
369 let unsubscribe_tx = unsubscribe_tx.clone();
370 let unsubscribe = Box::new(move || async move {
371 let (response_tx, response_rx) = oneshot::channel();
372 if unsubscribe_tx.send((operation, sid, response_tx)).is_ok() {
374 let _ = response_rx.await; }
376 }.boxed());
377
378 if response_tx.send(Ok((notifications_rx, unsubscribe))).is_err() {
379 break;
380 }
381 subscriptions.insert(sid, notifications_tx);
382 }
383 }
384 } else {
385 error!("Unknown request id: {}", id);
386 break;
387 }
388 continue;
389 }
390
391 if let Some(Value::Object(params)) = json.get_mut("params") {
394 if let Some(sid) = params.get("subscription").and_then(Value::as_u64) {
395 let mut unsubscribe_required = false;
396
397 if let Some(notifications_tx) = subscriptions.get(&sid) {
398 if let Some(result) = params.remove("result") {
399 if notifications_tx.send(result).is_err() {
400 unsubscribe_required = true;
401 }
402 }
403 } else {
404 unsubscribe_required = true;
405 }
406
407 if unsubscribe_required {
408 if let Some(Value::String(method)) = json.remove("method") {
409 if let Some(operation) = method.strip_suffix("Notification") {
410 let (response_tx, _response_rx) = oneshot::channel();
411 let _ = unsubscribe_tx.send((operation.to_string(), sid, response_tx));
412 }
413 }
414 }
415 }
416 }
417 }
418 }
419 }
420
421 Ok(())
422 }
423}
424
#[cfg(test)]
mod tests {
    // NOTE(review): this module defines no tests; coverage for `PubsubClient` presumably
    // lives in separate integration tests — confirm before relying on it.
}