alloy_transport_ws/native.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
use crate::WsBackend;
use alloy_pubsub::PubSubConnect;
use alloy_transport::{utils::Spawnable, Authorization, TransportErrorKind, TransportResult};
use futures::{SinkExt, StreamExt};
use serde_json::value::RawValue;
use std::time::Duration;
pub use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::{
tungstenite::{self, client::IntoClientRequest, Message},
MaybeTlsStream, WebSocketStream,
};
#[cfg(target_arch = "wasm32")]
use wasmtimer::tokio::sleep;
#[cfg(not(target_arch = "wasm32"))]
use tokio::time::sleep;
type TungsteniteStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
const KEEPALIVE: u64 = 10;
/// Simple connection details for a websocket connection.
#[derive(Clone, Debug)]
pub struct WsConnect {
/// The URL to connect to.
pub url: String,
/// The authorization header to use.
pub auth: Option<Authorization>,
/// The websocket config.
pub config: Option<WebSocketConfig>,
}
impl WsConnect {
/// Creates a new websocket connection configuration.
pub fn new<S: Into<String>>(url: S) -> Self {
Self { url: url.into(), auth: None, config: None }
}
/// Sets the authorization header.
pub fn with_auth(mut self, auth: Authorization) -> Self {
self.auth = Some(auth);
self
}
/// Sets the websocket config.
pub const fn with_config(mut self, config: WebSocketConfig) -> Self {
self.config = Some(config);
self
}
}
impl IntoClientRequest for WsConnect {
fn into_client_request(self) -> tungstenite::Result<tungstenite::handshake::client::Request> {
let mut request: http::Request<()> = self.url.into_client_request()?;
if let Some(auth) = self.auth {
let mut auth_value = http::HeaderValue::from_str(&auth.to_string())?;
auth_value.set_sensitive(true);
request.headers_mut().insert(http::header::AUTHORIZATION, auth_value);
}
request.into_client_request()
}
}
impl PubSubConnect for WsConnect {
fn is_local(&self) -> bool {
alloy_transport::utils::guess_local_url(&self.url)
}
async fn connect(&self) -> TransportResult<alloy_pubsub::ConnectionHandle> {
let request = self.clone().into_client_request();
let req = request.map_err(TransportErrorKind::custom)?;
let (socket, _) = tokio_tungstenite::connect_async_with_config(req, self.config, false)
.await
.map_err(TransportErrorKind::custom)?;
let (handle, interface) = alloy_pubsub::ConnectionHandle::new();
let backend = WsBackend { socket, interface };
backend.spawn();
Ok(handle)
}
}
impl WsBackend<TungsteniteStream> {
/// Handle a message from the server.
#[allow(clippy::result_unit_err)]
pub fn handle(&mut self, msg: Message) -> Result<(), ()> {
match msg {
Message::Text(text) => self.handle_text(&text),
Message::Close(frame) => {
if frame.is_some() {
error!(?frame, "Received close frame with data");
} else {
error!("WS server has gone away");
}
Err(())
}
Message::Binary(_) => {
error!("Received binary message, expected text");
Err(())
}
Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => Ok(()),
}
}
/// Send a message to the server.
pub async fn send(&mut self, msg: Box<RawValue>) -> Result<(), tungstenite::Error> {
self.socket.send(Message::Text(msg.get().to_owned())).await
}
/// Spawn a new backend task.
pub fn spawn(mut self) {
let fut = async move {
let mut errored = false;
let keepalive = sleep(Duration::from_secs(KEEPALIVE));
tokio::pin!(keepalive);
loop {
// We bias the loop as follows
// 1. New dispatch to server.
// 2. Keepalive.
// 3. Response or notification from server.
// This ensures that keepalive is sent only if no other messages
// have been sent in the last 10 seconds. And prioritizes new
// dispatches over responses from the server. This will fail if
// the client saturates the task with dispatches, but that's
// probably not a big deal.
tokio::select! {
biased;
// we've received a new dispatch, so we send it via
// websocket. We handle new work before processing any
// responses from the server.
inst = self.interface.recv_from_frontend() => {
match inst {
Some(msg) => {
// Reset the keepalive timer.
keepalive.set(sleep(Duration::from_secs(KEEPALIVE)));
if let Err(err) = self.send(msg).await {
error!(%err, "WS connection error");
errored = true;
break
}
},
// dispatcher has gone away, or shutdown was received
None => {
break
},
}
},
// Send a ping to the server, if no other messages have been
// sent in the last 10 seconds.
_ = &mut keepalive => {
// Reset the keepalive timer.
keepalive.set(sleep(Duration::from_secs(KEEPALIVE)));
if let Err(err) = self.socket.send(Message::Ping(vec![])).await {
error!(%err, "WS connection error");
errored = true;
break
}
}
resp = self.socket.next() => {
match resp {
Some(Ok(item)) => {
errored = self.handle(item).is_err();
if errored { break }
},
Some(Err(err)) => {
error!(%err, "WS connection error");
errored = true;
break
}
None => {
error!("WS server has gone away");
errored = true;
break
},
}
}
}
}
if errored {
self.interface.close_with_error();
}
};
fut.spawn_task()
}
}