use std::{
net::SocketAddr,
num::NonZeroU32,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use anyhow::{anyhow, bail, ensure, Context as _, Result};
use bytes::Bytes;
use futures_lite::Stream;
use futures_sink::Sink;
use futures_util::{
stream::{SplitSink, SplitStream, StreamExt},
SinkExt,
};
use tokio::sync::mpsc;
use tokio_tungstenite_wasm::WebSocketStream;
use tokio_util::{
codec::{FramedRead, FramedWrite},
task::AbortOnDropHandle,
};
use tracing::{debug, info_span, trace, Instrument};
use crate::{
defaults::timeouts::relay::CLIENT_RECV_TIMEOUT,
key::{PublicKey, SecretKey},
relay::{
client::streams::{MaybeTlsStreamReader, MaybeTlsStreamWriter},
codec::{
write_frame, ClientInfo, DerpCodec, Frame, MAX_PACKET_SIZE,
PER_CLIENT_READ_QUEUE_DEPTH, PER_CLIENT_SEND_QUEUE_DEPTH, PROTOCOL_VERSION,
},
},
};
/// A cheaply clonable handle to a relay connection.
///
/// Every clone shares the same [`ConnTasks`] through an `Arc`.
#[derive(Debug, Clone)]
pub struct Conn {
    inner: Arc<ConnTasks>,
}

/// Handles compare equal only when they point at the very same [`ConnTasks`]
/// allocation, i.e. when one is a clone of the other.
impl PartialEq for Conn {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.inner, &other.inner);
        Arc::ptr_eq(lhs, rhs)
    }
}

impl Eq for Conn {}
/// Receiving half of a relay connection, fed by the background reader task.
#[derive(Debug)]
pub struct ConnReceiver {
    /// Items produced by the reader task.
    reader_channel: mpsc::Receiver<Result<ReceivedMessage>>,
}

impl ConnReceiver {
    /// Waits for the next message from the relay server.
    ///
    /// # Errors
    ///
    /// Returns an error with the message `"shut down"` once the channel is closed and
    /// drained. Any `Err` value received over the channel is returned as-is.
    pub async fn recv(&mut self) -> Result<ReceivedMessage> {
        // `ok_or_else` builds the (allocating, possibly backtrace-capturing) error
        // only when the channel is actually closed, not on every received message.
        self.reader_channel
            .recv()
            .await
            .ok_or_else(|| anyhow!("shut down"))?
    }
}
/// Shared state behind a [`Conn`]: handles to the background reader/writer tasks
/// and the queue used to hand messages to the writer task.
#[derive(derive_more::Debug)]
pub struct ConnTasks {
    /// Local socket address of the connection, as passed to the builder.
    local_addr: Option<SocketAddr>,
    /// Queue of outgoing messages consumed by the writer task.
    writer_channel: mpsc::Sender<ConnWriterMessage>,
    /// Writer task handle; the task is aborted when this handle is dropped.
    writer_task: AbortOnDropHandle<Result<()>>,
    /// Reader task handle; the task is aborted when this handle is dropped.
    reader_task: AbortOnDropHandle<()>,
}
impl Conn {
pub async fn send(&self, dstkey: PublicKey, packet: Bytes) -> Result<()> {
trace!(%dstkey, len = packet.len(), "[RELAY] send");
self.inner
.writer_channel
.send(ConnWriterMessage::Packet((dstkey, packet)))
.await?;
Ok(())
}
pub async fn send_ping(&self, data: [u8; 8]) -> Result<()> {
self.inner
.writer_channel
.send(ConnWriterMessage::Ping(data))
.await?;
Ok(())
}
pub async fn send_pong(&self, data: [u8; 8]) -> Result<()> {
self.inner
.writer_channel
.send(ConnWriterMessage::Pong(data))
.await?;
Ok(())
}
pub async fn note_preferred(&self, preferred: bool) -> Result<()> {
self.inner
.writer_channel
.send(ConnWriterMessage::NotePreferred(preferred))
.await?;
Ok(())
}
pub fn local_addr(&self) -> Option<SocketAddr> {
self.inner.local_addr
}
pub fn is_closed(&self) -> bool {
self.inner.writer_task.is_finished()
}
pub async fn close(&self) {
if self.inner.writer_task.is_finished() && self.inner.reader_task.is_finished() {
return;
}
self.inner
.writer_channel
.send(ConnWriterMessage::Shutdown)
.await
.ok();
self.inner.reader_task.abort();
}
}
/// Converts a [`Frame`] read from the relay server into a [`ReceivedMessage`].
///
/// # Errors
///
/// Fails if a `Health` frame's problem text is not valid UTF-8. Also fails for any
/// frame type not listed here.
fn process_incoming_frame(frame: Frame) -> Result<ReceivedMessage> {
    let message = match frame {
        Frame::KeepAlive => ReceivedMessage::KeepAlive,
        Frame::PeerGone { peer } => ReceivedMessage::PeerGone(peer),
        Frame::RecvPacket { src_key, content } => ReceivedMessage::ReceivedPacket {
            source: src_key,
            data: content,
        },
        Frame::Ping { data } => ReceivedMessage::Ping(data),
        Frame::Pong { data } => ReceivedMessage::Pong(data),
        Frame::Health { problem } => {
            let text = std::str::from_utf8(&problem)?;
            ReceivedMessage::Health {
                problem: Some(text.to_owned()),
            }
        }
        Frame::Restarting {
            reconnect_in,
            try_for,
        } => ReceivedMessage::ServerRestarting {
            // Both fields are carried as milliseconds on the wire.
            reconnect_in: Duration::from_millis(reconnect_in as u64),
            try_for: Duration::from_millis(try_for as u64),
        },
        _ => bail!("unexpected packet: {:?}", frame.typ()),
    };
    Ok(message)
}
/// Instructions handed to the writer task.
#[derive(Debug)]
enum ConnWriterMessage {
    /// Send a packet to the given peer.
    Packet((PublicKey, Bytes)),
    /// Send a pong frame with this payload.
    Pong([u8; 8]),
    /// Send a ping frame with this payload.
    Ping([u8; 8]),
    /// Send a `NotePreferred` frame with this flag.
    NotePreferred(bool),
    /// Stop the writer task.
    Shutdown,
}

/// State owned by the writer task.
struct ConnWriterTasks {
    recv_msgs: mpsc::Receiver<ConnWriterMessage>,
    writer: ConnWriter,
    rate_limiter: Option<RateLimiter>,
}

impl ConnWriterTasks {
    /// Processes queued messages until told to stop.
    ///
    /// Returns `Ok(())` on [`ConnWriterMessage::Shutdown`]. Returns an error when a write
    /// fails, or when all senders are dropped without asking for shutdown.
    async fn run(mut self) -> Result<()> {
        loop {
            let msg = match self.recv_msgs.recv().await {
                Some(msg) => msg,
                None => bail!("channel unexpectedly closed"),
            };
            let frame = match msg {
                ConnWriterMessage::Shutdown => return Ok(()),
                ConnWriterMessage::Packet((key, bytes)) => {
                    // Packets go through `send_packet` for the size and rate-limit checks.
                    send_packet(&mut self.writer, &self.rate_limiter, key, bytes).await?;
                    continue;
                }
                ConnWriterMessage::Pong(data) => Frame::Pong { data },
                ConnWriterMessage::Ping(data) => Frame::Ping { data },
                ConnWriterMessage::NotePreferred(preferred) => Frame::NotePreferred { preferred },
            };
            write_frame(&mut self.writer, frame, None).await?;
            self.writer.flush().await?;
        }
    }
}
/// A relay transport that is connected but not yet running.
///
/// [`ConnBuilder::build`] performs the client handshake and spawns the I/O tasks.
pub struct ConnBuilder {
    secret_key: SecretKey,
    reader: ConnReader,
    writer: ConnWriter,
    local_addr: Option<SocketAddr>,
}
/// Read half of a relay transport.
pub(crate) enum ConnReader {
    /// Relay frames decoded with `DerpCodec` from a (possibly TLS) byte stream.
    Derp(FramedRead<MaybeTlsStreamReader, DerpCodec>),
    /// Relay frames carried inside binary WebSocket messages.
    Ws(SplitStream<WebSocketStream>),
}
/// Write half of a relay transport.
pub(crate) enum ConnWriter {
    /// Relay frames encoded with `DerpCodec` onto a (possibly TLS) byte stream.
    Derp(FramedWrite<MaybeTlsStreamWriter, DerpCodec>),
    /// Relay frames sent as binary WebSocket messages.
    Ws(SplitSink<WebSocketStream, tokio_tungstenite_wasm::Message>),
}
/// Maps a WebSocket error onto [`std::io::Error`].
///
/// I/O errors pass through unchanged. Every other variant becomes an
/// [`std::io::ErrorKind::Other`] error carrying the original error's display text.
fn tung_wasm_to_io_err(e: tokio_tungstenite_wasm::Error) -> std::io::Error {
    match e {
        tokio_tungstenite_wasm::Error::Io(io_err) => io_err,
        other => {
            let text = other.to_string();
            std::io::Error::new(std::io::ErrorKind::Other, text)
        }
    }
}
impl Stream for ConnReader {
    type Item = Result<Frame>;

    /// Yields the next relay [`Frame`] from the underlying transport.
    ///
    /// On the WebSocket transport only binary messages carry frames. Other message
    /// kinds are logged and skipped.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match *self {
            Self::Derp(ref mut ws) => Pin::new(ws).poll_next(cx),
            Self::Ws(ref mut ws) => loop {
                // After skipping a message we must poll again rather than return
                // `Pending`. The inner stream just returned `Ready`, so no waker is
                // registered, and returning `Pending` would stall this task.
                match Pin::new(&mut *ws).poll_next(cx) {
                    Poll::Ready(Some(Ok(tokio_tungstenite_wasm::Message::Binary(vec)))) => {
                        return Poll::Ready(Some(Frame::decode_from_ws_msg(vec)));
                    }
                    Poll::Ready(Some(Ok(msg))) => {
                        tracing::warn!(
                            ?msg,
                            "Got websocket message of unsupported type, skipping."
                        );
                    }
                    Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e.into()))),
                    Poll::Ready(None) => return Poll::Ready(None),
                    Poll::Pending => return Poll::Pending,
                }
            },
        }
    }
}
impl Sink<Frame> for ConnWriter {
type Error = std::io::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Derp(ref mut ws) => Pin::new(ws).poll_ready(cx),
Self::Ws(ref mut ws) => Pin::new(ws).poll_ready(cx).map_err(tung_wasm_to_io_err),
}
}
fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
match *self {
Self::Derp(ref mut ws) => Pin::new(ws).start_send(item),
Self::Ws(ref mut ws) => Pin::new(ws)
.start_send(tokio_tungstenite_wasm::Message::binary(
item.encode_for_ws_msg(),
))
.map_err(tung_wasm_to_io_err),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Derp(ref mut ws) => Pin::new(ws).poll_flush(cx),
Self::Ws(ref mut ws) => Pin::new(ws).poll_flush(cx).map_err(tung_wasm_to_io_err),
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Derp(ref mut ws) => Pin::new(ws).poll_close(cx),
Self::Ws(ref mut ws) => Pin::new(ws).poll_close(cx).map_err(tung_wasm_to_io_err),
}
}
}
impl ConnBuilder {
pub fn new(
secret_key: SecretKey,
local_addr: Option<SocketAddr>,
reader: ConnReader,
writer: ConnWriter,
) -> Self {
Self {
secret_key,
reader,
writer,
local_addr,
}
}
async fn server_handshake(&mut self) -> Result<Option<RateLimiter>> {
debug!("server_handshake: started");
let client_info = ClientInfo {
version: PROTOCOL_VERSION,
};
debug!("server_handshake: sending client_key: {:?}", &client_info);
crate::relay::codec::send_client_key(&mut self.writer, &self.secret_key, &client_info)
.await?;
let rate_limiter = RateLimiter::new(0, 0)?;
debug!("server_handshake: done");
Ok(rate_limiter)
}
pub async fn build(mut self) -> Result<(Conn, ConnReceiver)> {
let rate_limiter = self.server_handshake().await?;
let (writer_sender, writer_recv) = mpsc::channel(PER_CLIENT_SEND_QUEUE_DEPTH);
let writer_task = tokio::task::spawn(
ConnWriterTasks {
rate_limiter,
writer: self.writer,
recv_msgs: writer_recv,
}
.run()
.instrument(info_span!("conn.writer")),
);
let (reader_sender, reader_recv) = mpsc::channel(PER_CLIENT_READ_QUEUE_DEPTH);
let reader_task = tokio::task::spawn({
let writer_sender = writer_sender.clone();
async move {
loop {
let frame = tokio::time::timeout(CLIENT_RECV_TIMEOUT, self.reader.next()).await;
let res = match frame {
Ok(Some(Ok(frame))) => process_incoming_frame(frame),
Ok(Some(Err(err))) => {
Err(err)
}
Ok(None) => {
Err(anyhow::anyhow!("EOF: reader stream ended"))
}
Err(err) => {
Err(err.into())
}
};
if res.is_err() {
writer_sender.send(ConnWriterMessage::Shutdown).await.ok();
break;
}
if reader_sender.send(res).await.is_err() {
writer_sender.send(ConnWriterMessage::Shutdown).await.ok();
break;
}
}
}
.instrument(info_span!("conn.reader"))
});
let conn = Conn {
inner: Arc::new(ConnTasks {
local_addr: self.local_addr,
writer_channel: writer_sender,
writer_task: AbortOnDropHandle::new(writer_task),
reader_task: AbortOnDropHandle::new(reader_task),
}),
};
let conn_receiver = ConnReceiver {
reader_channel: reader_recv,
};
Ok((conn, conn_receiver))
}
}
/// A message received from the relay server, as produced by `process_incoming_frame`.
#[derive(derive_more::Debug, Clone)]
pub enum ReceivedMessage {
    /// Built from a `RecvPacket` frame.
    ReceivedPacket {
        /// Key of the client that sent the packet (the frame's `src_key`).
        source: PublicKey,
        /// Packet payload; omitted from `Debug` output.
        #[debug(skip)]
        data: Bytes, },
    /// Built from a `PeerGone` frame; carries the peer key it names.
    PeerGone(PublicKey),
    /// Built from a `Ping` frame; carries its 8-byte payload.
    Ping([u8; 8]),
    /// Built from a `Pong` frame; carries its 8-byte payload.
    Pong([u8; 8]),
    /// Built from a `KeepAlive` frame.
    KeepAlive,
    /// Built from a `Health` frame.
    Health {
        /// The frame's problem text (always `Some` when built by `process_incoming_frame`).
        problem: Option<String>,
    },
    /// Built from a `Restarting` frame; both durations come from millisecond fields.
    ServerRestarting {
        reconnect_in: Duration,
        try_for: Duration,
    },
}
/// Writes one `SendPacket` frame addressed to `dst_key`, then flushes `writer`.
///
/// If a rate limiter is given and rejects the frame's length, the packet is
/// dropped with a warning and `Ok(())` is returned.
///
/// # Errors
///
/// Fails if `packet` is larger than `MAX_PACKET_SIZE`, or if sending or flushing fails.
pub(crate) async fn send_packet<S: Sink<Frame, Error = std::io::Error> + Unpin>(
    mut writer: S,
    rate_limiter: &Option<RateLimiter>,
    dst_key: PublicKey,
    packet: Bytes,
) -> Result<()> {
    let packet_len = packet.len();
    ensure!(packet_len <= MAX_PACKET_SIZE, "packet too big: {}", packet_len);

    let frame = Frame::SendPacket { dst_key, packet };
    let admitted = match rate_limiter {
        Some(limiter) => limiter.check_n(frame.len()).is_ok(),
        None => true,
    };
    if !admitted {
        tracing::warn!("dropping send: rate limit reached");
        return Ok(());
    }

    writer.send(frame).await?;
    writer.flush().await?;
    Ok(())
}
/// Byte-based rate limiter wrapping a direct (un-keyed) `governor` limiter.
pub(crate) struct RateLimiter {
    inner: governor::RateLimiter<
        governor::state::direct::NotKeyed,
        governor::state::InMemoryState,
        governor::clock::DefaultClock,
        governor::middleware::NoOpMiddleware,
    >,
}

impl RateLimiter {
    /// Builds a limiter allowing `bytes_per_second`, with bursts of up to `bytes_burst`.
    ///
    /// Returns `Ok(None)` (no limiting) when either argument is zero.
    ///
    /// # Errors
    ///
    /// Fails if either value does not fit in a `u32`.
    pub(crate) fn new(bytes_per_second: usize, bytes_burst: usize) -> Result<Option<Self>> {
        if bytes_per_second == 0 || bytes_burst == 0 {
            return Ok(None);
        }
        let bytes_per_second = NonZeroU32::new(u32::try_from(bytes_per_second)?)
            .context("bytes_per_second not non-zero")?;
        let bytes_burst =
            NonZeroU32::new(u32::try_from(bytes_burst)?).context("bytes_burst not non-zero")?;
        Ok(Some(Self {
            inner: governor::RateLimiter::direct(
                governor::Quota::per_second(bytes_per_second).allow_burst(bytes_burst),
            ),
        }))
    }

    /// Tries to consume `n` bytes of capacity right now.
    ///
    /// # Errors
    ///
    /// Fails if `n` is zero or does not fit in a `u32`. Also fails if the batch is
    /// currently rate-limited, or is larger than the burst size and can never pass.
    pub(crate) fn check_n(&self, n: usize) -> Result<()> {
        let n = NonZeroU32::new(u32::try_from(n)?).context("n not non-zero")?;
        // `check_n` returns a nested result: the outer `Err` means `n` exceeds the
        // burst capacity, and the inner `Err` means the batch is rate-limited right
        // now. Only `Ok(Ok(_))` admits the batch; matching plain `Ok(_)` would let
        // rate-limited batches through.
        match self.inner.check_n(n) {
            Ok(Ok(_)) => Ok(()),
            Ok(Err(_)) | Err(_) => bail!("batch cannot go through"),
        }
    }
}