// iroh_net/relay/server/streams.rs
use std::{
pin::Pin,
task::{Context, Poll},
};
use anyhow::Result;
use futures_lite::Stream;
use futures_sink::Sink;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tungstenite::WebSocketStream;
use tokio_util::codec::Framed;
use crate::relay::codec::{DerpCodec, Frame};
/// Transport for relay frames: either raw DERP framing via [`DerpCodec`] or a websocket.
///
/// Both variants run on top of a [`MaybeTlsStream`]. Frames are written through the
/// `Sink<Frame>` impl and read through the `Stream` impl of this type.
#[derive(Debug)]
pub(crate) enum RelayIo {
/// Frames are encoded directly onto the byte stream by [`DerpCodec`].
Derp(Framed<MaybeTlsStream, DerpCodec>),
/// Each frame is carried as one binary websocket message.
Ws(WebSocketStream<MaybeTlsStream>),
}
fn tung_to_io_err(e: tungstenite::Error) -> std::io::Error {
match e {
tungstenite::Error::Io(io_err) => io_err,
_ => std::io::Error::new(std::io::ErrorKind::Other, e.to_string()),
}
}
impl Sink<Frame> for RelayIo {
type Error = std::io::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Derp(ref mut framed) => Pin::new(framed).poll_ready(cx),
Self::Ws(ref mut ws) => Pin::new(ws).poll_ready(cx).map_err(tung_to_io_err),
}
}
fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
match *self {
Self::Derp(ref mut framed) => Pin::new(framed).start_send(item),
Self::Ws(ref mut ws) => Pin::new(ws)
.start_send(tungstenite::Message::Binary(item.encode_for_ws_msg()))
.map_err(tung_to_io_err),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Derp(ref mut framed) => Pin::new(framed).poll_flush(cx),
Self::Ws(ref mut ws) => Pin::new(ws).poll_flush(cx).map_err(tung_to_io_err),
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
Self::Derp(ref mut framed) => Pin::new(framed).poll_close(cx),
Self::Ws(ref mut ws) => Pin::new(ws).poll_close(cx).map_err(tung_to_io_err),
}
}
}
/// Reads relay [`Frame`]s from the underlying transport.
impl Stream for RelayIo {
    type Item = anyhow::Result<Frame>;

    /// Polls for the next [`Frame`].
    ///
    /// On websocket connections only binary messages carry frames. Any other message type
    /// is logged and skipped, and the socket is polled again immediately. Websocket errors
    /// are converted into `anyhow::Error`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.get_mut() {
            Self::Derp(framed) => Pin::new(framed).poll_next(cx),
            Self::Ws(ws) => loop {
                match Pin::new(&mut *ws).poll_next(cx) {
                    Poll::Ready(Some(Ok(tungstenite::Message::Binary(vec)))) => {
                        return Poll::Ready(Some(Frame::decode_from_ws_msg(vec)));
                    }
                    Poll::Ready(Some(Ok(msg))) => {
                        // The inner stream returned `Ready`, so it registered no wakeup.
                        // Returning `Pending` here could stall the task forever, so keep
                        // polling until we get a frame, an error, EOF or a real `Pending`.
                        tracing::warn!(?msg, "Got websocket message of unsupported type, skipping.");
                    }
                    Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e.into()))),
                    Poll::Ready(None) => return Poll::Ready(None),
                    Poll::Pending => return Poll::Pending,
                }
            },
        }
    }
}
/// A byte stream that is either plain TCP or TCP wrapped in a server-side rustls TLS session.
///
/// Implements [`AsyncRead`] and [`AsyncWrite`] by forwarding to the wrapped stream.
#[derive(Debug)]
pub enum MaybeTlsStream {
/// Unencrypted TCP.
Plain(tokio::net::TcpStream),
/// TCP wrapped in a server-side rustls TLS session.
Tls(tokio_rustls::server::TlsStream<tokio::net::TcpStream>),
/// In-memory duplex stream; only compiled in test builds.
#[cfg(test)]
Test(tokio::io::DuplexStream),
}
/// Reads by delegating to whichever stream this value wraps.
impl AsyncRead for MaybeTlsStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let inner = self.get_mut();
        match inner {
            Self::Plain(tcp) => Pin::new(tcp).poll_read(cx, buf),
            Self::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
            #[cfg(test)]
            Self::Test(duplex) => Pin::new(duplex).poll_read(cx, buf),
        }
    }
}
/// Writes by delegating to whichever stream this value wraps.
impl AsyncWrite for MaybeTlsStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, std::io::Error>> {
        match self.get_mut() {
            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
            #[cfg(test)]
            MaybeTlsStream::Test(s) => Pin::new(s).poll_write(cx, buf),
        }
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        match self.get_mut() {
            MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
            #[cfg(test)]
            MaybeTlsStream::Test(s) => Pin::new(s).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        match self.get_mut() {
            MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
            #[cfg(test)]
            MaybeTlsStream::Test(s) => Pin::new(s).poll_shutdown(cx),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::result::Result<usize, std::io::Error>> {
        match self.get_mut() {
            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
            #[cfg(test)]
            MaybeTlsStream::Test(s) => Pin::new(s).poll_write_vectored(cx, bufs),
        }
    }

    /// Reports whether the wrapped stream has an efficient vectored-write implementation.
    ///
    /// Without this override the trait default (`false`) applies, so callers that check it
    /// (e.g. buffered writers used by `Framed`) would never use `poll_write_vectored`.
    fn is_write_vectored(&self) -> bool {
        match self {
            MaybeTlsStream::Plain(s) => s.is_write_vectored(),
            MaybeTlsStream::Tls(s) => s.is_write_vectored(),
            #[cfg(test)]
            MaybeTlsStream::Test(s) => s.is_write_vectored(),
        }
    }
}