mod web_context;
use bytes::BytesMut;
use futures::task::AtomicWaker;
use futures::{future::Ready, io, prelude::*};
use js_sys::Array;
use libp2p_core::transport::DialOpts;
use libp2p_core::{
multiaddr::{Multiaddr, Protocol},
transport::{ListenerId, TransportError, TransportEvent},
};
use send_wrapper::SendWrapper;
use std::cmp::min;
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use std::{pin::Pin, task::Context, task::Poll};
use wasm_bindgen::prelude::*;
use web_sys::{CloseEvent, Event, MessageEvent, WebSocket};
use crate::web_context::WebContext;
/// Dial-only libp2p transport that opens WebSocket connections through the
/// browser's native `WebSocket` object. Listening is not supported.
#[derive(Default)]
pub struct Transport {
    // Private unit field: forces construction through `Transport::default()`.
    _private: (),
}
/// Cap (1 MiB) on both the unread receive buffer and the browser-side
/// `bufferedAmount` of outgoing data for a single connection.
const MAX_BUFFER: usize = 1 << 20;
impl libp2p_core::Transport for Transport {
type Output = Connection;
type Error = Error;
type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
type Dial = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
fn listen_on(
&mut self,
_: ListenerId,
addr: Multiaddr,
) -> Result<(), TransportError<Self::Error>> {
Err(TransportError::MultiaddrNotSupported(addr))
}
fn remove_listener(&mut self, _id: ListenerId) -> bool {
false
}
fn dial(
&mut self,
addr: Multiaddr,
dial_opts: DialOpts,
) -> Result<Self::Dial, TransportError<Self::Error>> {
if dial_opts.role.is_listener() {
return Err(TransportError::MultiaddrNotSupported(addr));
}
let url = extract_websocket_url(&addr)
.ok_or_else(|| TransportError::MultiaddrNotSupported(addr))?;
Ok(async move {
let socket = match WebSocket::new(&url) {
Ok(ws) => ws,
Err(_) => return Err(Error::invalid_websocket_url(&url)),
};
Ok(Connection::new(socket))
}
.boxed())
}
fn poll(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> std::task::Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
Poll::Pending
}
}
/// Turns `/{ip4,ip6,dns,dns4,dns6,dnsaddr}/<host>/tcp/<port>` followed by
/// `/ws`, `/wss` or `/tls/ws` into a `ws://` or `wss://` URL.
///
/// Components after the WebSocket one (e.g. `/p2p/<peer-id>`) are ignored.
/// Returns `None` for any other shape, including `/tls/wss`.
fn extract_websocket_url(addr: &Multiaddr) -> Option<String> {
    let mut components = addr.iter();

    let host = match components.next()? {
        Protocol::Ip4(ip) => ip.to_string(),
        // IPv6 literals must be bracketed inside a URL authority.
        Protocol::Ip6(ip) => format!("[{ip}]"),
        Protocol::Dns(name)
        | Protocol::Dns4(name)
        | Protocol::Dns6(name)
        | Protocol::Dnsaddr(name) => name.to_string(),
        _ => return None,
    };

    let port = match components.next()? {
        Protocol::Tcp(port) => port,
        _ => return None,
    };

    let (scheme, path) = match components.next()? {
        Protocol::Ws(path) => ("ws", path),
        Protocol::Wss(path) => ("wss", path),
        // `/tls/ws` is treated exactly like `/wss`.
        Protocol::Tls => match components.next()? {
            Protocol::Ws(path) => ("wss", path),
            _ => return None,
        },
        _ => return None,
    };

    Some(format!("{scheme}://{host}:{port}{path}"))
}
/// Error produced by this transport.
#[derive(thiserror::Error, Debug)]
#[error("{msg}")]
pub struct Error {
    /// Human-readable description; also the `Display` output.
    msg: String,
}

impl Error {
    /// Error reported when the browser refuses to create a `WebSocket` for `url`.
    fn invalid_websocket_url(url: &str) -> Self {
        let msg = format!("Invalid websocket url: {url}");
        Self { msg }
    }
}
/// A dialed WebSocket connection, exposed as `AsyncRead` + `AsyncWrite`.
pub struct Connection {
    // Wraps the browser-bound state so `Connection` can be produced by the
    // `Send` dial future (`Transport::Dial`); presumably needed because the
    // JS handles inside `Inner` are not `Send` — confirm against `send_wrapper` docs.
    inner: SendWrapper<Inner>,
}
/// Browser-side state of a [`Connection`].
///
/// NOTE: fields are dropped in declaration order, so `socket` is released
/// before the closures below.
struct Inner {
    // The underlying browser WebSocket.
    socket: WebSocket,
    // Woken when new bytes are appended to `read_buffer`.
    new_data_waker: Rc<AtomicWaker>,
    // Bytes received but not yet read; kept at or below `MAX_BUFFER`.
    read_buffer: Rc<Mutex<BytesMut>>,
    // Woken by the socket's `open` event.
    open_waker: Rc<AtomicWaker>,
    // Woken by the polling interval once `bufferedAmount` has dropped to 0.
    write_waker: Rc<AtomicWaker>,
    // Woken by the socket's `close` event.
    close_waker: Rc<AtomicWaker>,
    // Set on an `error` event or when the remote overflows `read_buffer`;
    // checked by `error_barrier`.
    errored: Rc<AtomicBool>,
    // The JS callbacks below must stay alive for as long as they are installed
    // on the socket (or on the interval); they are only held, never called from Rust.
    _on_open_closure: Rc<Closure<dyn FnMut(Event)>>,
    _on_buffered_amount_low_closure: Rc<Closure<dyn FnMut(Event)>>,
    _on_close_closure: Rc<Closure<dyn FnMut(CloseEvent)>>,
    _on_error_closure: Rc<Closure<dyn FnMut(CloseEvent)>>,
    _on_message_closure: Rc<Closure<dyn FnMut(MessageEvent)>>,
    // Handle of the interval that runs `_on_buffered_amount_low_closure`;
    // cleared when the connection is dropped.
    buffered_amount_low_interval: i32,
}
impl Inner {
    /// Maps the browser's numeric `readyState` onto [`ReadyState`].
    ///
    /// # Panics
    ///
    /// Panics on a value outside `0..=3`, which the browser is not expected to report.
    fn ready_state(&self) -> ReadyState {
        let raw_state = self.socket.ready_state();
        match raw_state {
            0 => ReadyState::Connecting,
            1 => ReadyState::Open,
            2 => ReadyState::Closing,
            3 => ReadyState::Closed,
            unknown => unreachable!("invalid `ReadyState` value: {unknown}"),
        }
    }

    /// Resolves once the socket is open.
    ///
    /// While still connecting, the task is parked on `open_waker`; a closing or
    /// closed socket yields `BrokenPipe`.
    fn poll_open(&mut self, cx: &Context<'_>) -> Poll<io::Result<()>> {
        match self.ready_state() {
            ReadyState::Open => Poll::Ready(Ok(())),
            ReadyState::Closing | ReadyState::Closed => {
                Poll::Ready(Err(io::Error::from(io::ErrorKind::BrokenPipe)))
            }
            ReadyState::Connecting => {
                self.open_waker.register(cx.waker());
                Poll::Pending
            }
        }
    }

    /// Fails with `BrokenPipe` once the connection has been flagged as errored.
    fn error_barrier(&self) -> io::Result<()> {
        if self.errored.load(Ordering::SeqCst) {
            Err(io::Error::from(io::ErrorKind::BrokenPipe))
        } else {
            Ok(())
        }
    }
}
/// Typed view of the browser's numeric `WebSocket.readyState`.
#[derive(PartialEq)]
enum ReadyState {
    /// Reported as `0`.
    Connecting = 0,
    /// Reported as `1`; the only state in which reads and writes proceed.
    Open = 1,
    /// Reported as `2`.
    Closing = 2,
    /// Reported as `3`.
    Closed = 3,
}
impl Connection {
    /// Wraps a freshly created browser `WebSocket` into a [`Connection`].
    ///
    /// Switches the socket to `ArrayBuffer` binary messages and installs the
    /// `onopen`/`onclose`/`onerror`/`onmessage` handlers that feed the wakers and
    /// the read buffer used by the `AsyncRead`/`AsyncWrite` impls. It also registers
    /// a 100 ms interval that wakes a parked writer once `bufferedAmount` reaches 0.
    ///
    /// # Panics
    ///
    /// Panics if no window or worker context is available, or if the interval
    /// cannot be registered.
    fn new(socket: WebSocket) -> Self {
        socket.set_binary_type(web_sys::BinaryType::Arraybuffer);

        let open_waker = Rc::new(AtomicWaker::new());
        let new_data_waker = Rc::new(AtomicWaker::new());
        let write_waker = Rc::new(AtomicWaker::new());
        let close_waker = Rc::new(AtomicWaker::new());
        let errored = Rc::new(AtomicBool::new(false));

        // Wakers of tasks parked in `poll_open`, `poll_read` and `poll_write`/`poll_flush`.
        // Terminal events (close, error, overload) must wake all of them. Otherwise a
        // task waiting for the socket to open, for more data, or for the send buffer to
        // drain is never polled again and hangs instead of observing the failure.
        let io_wakers = [
            open_waker.clone(),
            new_data_waker.clone(),
            write_waker.clone(),
        ];

        let onopen_closure = Closure::<dyn FnMut(_)>::new({
            let open_waker = open_waker.clone();
            move |_| {
                open_waker.wake();
            }
        });
        socket.set_onopen(Some(onopen_closure.as_ref().unchecked_ref()));

        let onclose_closure = Closure::<dyn FnMut(_)>::new({
            let io_wakers = io_wakers.clone();
            let close_waker = close_waker.clone();
            move |_| {
                io_wakers.iter().for_each(|waker| waker.wake());
                close_waker.wake();
            }
        });
        socket.set_onclose(Some(onclose_closure.as_ref().unchecked_ref()));

        // `close_waker` is deliberately not woken here: `poll_close` keeps waiting for
        // the `close` event so that it still resolves once the socket reaches CLOSED.
        let onerror_closure = Closure::<dyn FnMut(_)>::new({
            let errored = errored.clone();
            let io_wakers = io_wakers.clone();
            move |_| {
                errored.store(true, Ordering::SeqCst);
                io_wakers.iter().for_each(|waker| waker.wake());
            }
        });
        socket.set_onerror(Some(onerror_closure.as_ref().unchecked_ref()));

        let read_buffer = Rc::new(Mutex::new(BytesMut::new()));
        let onmessage_closure = Closure::<dyn FnMut(_)>::new({
            let read_buffer = read_buffer.clone();
            let new_data_waker = new_data_waker.clone();
            let errored = errored.clone();
            let io_wakers = io_wakers.clone();
            move |e: MessageEvent| {
                let data = js_sys::Uint8Array::new(&e.data());
                let mut read_buffer = read_buffer.lock().unwrap();
                // Never hold more than `MAX_BUFFER` bytes of unread data.
                if read_buffer.len() + data.length() as usize > MAX_BUFFER {
                    tracing::warn!("Remote is overloading us with messages, closing connection");
                    errored.store(true, Ordering::SeqCst);
                    drop(read_buffer);
                    // A reader may be parked on an empty buffer (e.g. a single
                    // oversized message), so wake everyone to surface the error.
                    io_wakers.iter().for_each(|waker| waker.wake());
                    return;
                }
                read_buffer.extend_from_slice(&data.to_vec());
                new_data_waker.wake();
            }
        });
        socket.set_onmessage(Some(onmessage_closure.as_ref().unchecked_ref()));

        let on_buffered_amount_low_closure = Closure::<dyn FnMut(_)>::new({
            let write_waker = write_waker.clone();
            let socket = socket.clone();
            move |_| {
                if socket.buffered_amount() == 0 {
                    write_waker.wake();
                }
            }
        });
        let buffered_amount_low_interval = WebContext::new()
            .expect("to have a window or worker context")
            .set_interval_with_callback_and_timeout_and_arguments(
                on_buffered_amount_low_closure.as_ref().unchecked_ref(),
                100, // Polling period in ms; chosen arbitrarily.
                &Array::new(),
            )
            .expect("to be able to set an interval");

        Self {
            inner: SendWrapper::new(Inner {
                socket,
                new_data_waker,
                read_buffer,
                open_waker,
                write_waker,
                close_waker,
                errored,
                _on_open_closure: Rc::new(onopen_closure),
                _on_buffered_amount_low_closure: Rc::new(on_buffered_amount_low_closure),
                _on_close_closure: Rc::new(onclose_closure),
                _on_error_closure: Rc::new(onerror_closure),
                _on_message_closure: Rc::new(onmessage_closure),
                buffered_amount_low_interval,
            }),
        }
    }

    /// Bytes queued on the socket that the browser has not yet transmitted.
    fn buffered_amount(&self) -> usize {
        self.inner.socket.buffered_amount() as usize
    }
}
impl AsyncRead for Connection {
    /// Copies bytes received from the remote into `buf`.
    ///
    /// Data that has already been received is handed out even if the socket has
    /// since moved to CLOSING/CLOSED. A peer that sends a final message and closes
    /// immediately therefore does not lose that message. With an empty buffer, the
    /// task waits for the socket to open and then for new data. A closing or closed
    /// socket yields `BrokenPipe`.
    ///
    /// # Errors
    ///
    /// `BrokenPipe` if the connection has errored (checked first), or if the
    /// buffer is empty and the socket is closing/closed.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.get_mut();
        this.inner.error_barrier()?;

        // The guard is a temporary released at the end of this statement, so
        // `poll_open` below can borrow `this.inner` mutably.
        let buffered = this.inner.read_buffer.lock().unwrap().len();
        if buffered == 0 {
            futures::ready!(this.inner.poll_open(cx))?;
            this.inner.new_data_waker.register(cx.waker());
            return Poll::Pending;
        }

        let mut read_buffer = this.inner.read_buffer.lock().unwrap();
        // Return at most what the caller can hold and at most what we have.
        let split_index = min(buf.len(), read_buffer.len());
        let bytes_to_return = read_buffer.split_to(split_index);
        let len = bytes_to_return.len();
        buf[..len].copy_from_slice(&bytes_to_return);
        Poll::Ready(Ok(len))
    }
}
impl AsyncWrite for Connection {
    /// Hands up to `MAX_BUFFER - bufferedAmount` bytes of `buf` to the browser.
    ///
    /// Parks on `write_waker` while the browser-side send buffer is full; the
    /// interval registered in `Connection::new` wakes the task once it has drained.
    ///
    /// # Errors
    ///
    /// `BrokenPipe` if the connection has errored, the socket is closing/closed,
    /// or the browser rejects the send.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        this.inner.error_barrier()?;
        futures::ready!(this.inner.poll_open(cx))?;

        // Read `bufferedAmount` once so the assertion and the arithmetic agree.
        let buffered = this.buffered_amount();
        debug_assert!(buffered <= MAX_BUFFER);
        // Never underflow in release builds: treat any overshoot as a full buffer.
        let remaining_space = MAX_BUFFER.saturating_sub(buffered);
        if remaining_space == 0 {
            this.inner.write_waker.register(cx.waker());
            return Poll::Pending;
        }

        let bytes_to_send = min(buf.len(), remaining_space);
        if this
            .inner
            .socket
            .send_with_u8_array(&buf[..bytes_to_send])
            .is_err()
        {
            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
        }
        Poll::Ready(Ok(bytes_to_send))
    }

    /// Resolves once the browser has transmitted everything queued on the socket.
    ///
    /// # Errors
    ///
    /// `BrokenPipe` if the connection has errored, or if the socket is already
    /// closed while data is still queued.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if self.buffered_amount() == 0 {
            return Poll::Ready(Ok(()));
        }
        self.inner.error_barrier()?;
        // After the socket reaches CLOSED the queued bytes are never sent. The
        // `bufferedAmount` value stays non-zero, so the write waker would never
        // fire; report the failure instead of pending forever.
        if self.inner.ready_state() == ReadyState::Closed {
            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
        }
        self.inner.write_waker.register(cx.waker());
        Poll::Pending
    }

    /// Starts a normal close handshake (if not already closing) and resolves once
    /// the socket reports CLOSED.
    ///
    /// # Errors
    ///
    /// `BrokenPipe` if the connection has errored before reaching CLOSED.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // RFC 6455 §7.4.1 "normal closure"; browsers only accept 1000 or 3000-4999
        // from user code.
        const REGULAR_CLOSE: u16 = 1000;

        if self.inner.ready_state() == ReadyState::Closed {
            return Poll::Ready(Ok(()));
        }
        self.inner.error_barrier()?;
        if self.inner.ready_state() != ReadyState::Closing {
            // Best effort: a failed close still ends up waiting for the close event.
            let _ = self
                .inner
                .socket
                .close_with_code_and_reason(REGULAR_CLOSE, "user initiated");
        }
        self.inner.close_waker.register(cx.waker());
        Poll::Pending
    }
}
impl Drop for Connection {
    /// Detaches the JS event handlers and closes the socket if it is still
    /// connecting or open. Also cancels the `bufferedAmount` polling interval.
    fn drop(&mut self) {
        // RFC 6455 §7.4.1 "normal closure".
        const REGULAR_CLOSE: u16 = 1000;

        let inner = &*self.inner;
        let socket = &inner.socket;

        // Unhook the handlers first: the Rust closures behind them are held by
        // `inner` and are released together with this connection.
        socket.set_onclose(None);
        socket.set_onerror(None);
        socket.set_onopen(None);
        socket.set_onmessage(None);

        if matches!(inner.ready_state(), ReadyState::Connecting | ReadyState::Open) {
            // Best effort: a failure to close cannot be reported from `drop`.
            let _ = socket.close_with_code_and_reason(REGULAR_CLOSE, "connection dropped");
        }

        WebContext::new()
            .expect("to have a window or worker context")
            .clear_interval_with_handle(inner.buffered_amount_low_interval);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use libp2p_identity::PeerId;

    /// Table of multiaddrs and the URL `extract_websocket_url` must produce
    /// (`None` = unsupported shape).
    #[test]
    fn extract_url() {
        let peer_id = PeerId::random();

        let cases: Vec<(String, Option<&str>)> = vec![
            // `/tls/ws` is treated as `wss`.
            ("/dns4/example.com/tcp/2222/tls/ws".to_owned(), Some("wss://example.com:2222/")),
            (format!("/dns4/example.com/tcp/2222/tls/ws/p2p/{peer_id}"), Some("wss://example.com:2222/")),
            ("/ip4/127.0.0.1/tcp/2222/tls/ws".to_owned(), Some("wss://127.0.0.1:2222/")),
            ("/ip6/::1/tcp/2222/tls/ws".to_owned(), Some("wss://[::1]:2222/")),
            // `/wss`.
            ("/dns4/example.com/tcp/2222/wss".to_owned(), Some("wss://example.com:2222/")),
            (format!("/dns4/example.com/tcp/2222/wss/p2p/{peer_id}"), Some("wss://example.com:2222/")),
            ("/ip4/127.0.0.1/tcp/2222/wss".to_owned(), Some("wss://127.0.0.1:2222/")),
            ("/ip6/::1/tcp/2222/wss".to_owned(), Some("wss://[::1]:2222/")),
            ("/dns/example.com/tcp/443/wss".to_owned(), Some("wss://example.com:443/")),
            ("/dns6/example.com/tcp/443/wss".to_owned(), Some("wss://example.com:443/")),
            // Plain `/ws`.
            ("/dns4/example.com/tcp/2222/ws".to_owned(), Some("ws://example.com:2222/")),
            (format!("/dns4/example.com/tcp/2222/ws/p2p/{peer_id}"), Some("ws://example.com:2222/")),
            ("/ip4/127.0.0.1/tcp/2222/ws".to_owned(), Some("ws://127.0.0.1:2222/")),
            ("/ip6/::1/tcp/2222/ws".to_owned(), Some("ws://[::1]:2222/")),
            // Unsupported shapes.
            ("/ip4/127.0.0.1/tcp/2222/tls/wss".to_owned(), None),
            ("/ip4/127.0.0.1/tcp/2222".to_owned(), None),
            ("/ip4/127.0.0.1/udp/2222/ws".to_owned(), None),
        ];

        for (addr, expected) in cases {
            let addr = addr
                .parse::<Multiaddr>()
                .expect("test multiaddr to parse");
            assert_eq!(
                extract_websocket_url(&addr).as_deref(),
                expected,
                "unexpected URL for {addr}"
            );
        }
    }
}