use std::{collections::HashMap, future::Future};
use bytes::{Buf, Bytes};
use futures::{
channel::{mpsc, oneshot},
future, FutureExt, SinkExt, StreamExt,
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream};
use tokio_tungstenite::{tungstenite as ws, WebSocketStream};
use tokio_util::io::ReaderStream;
#[derive(Debug, Error)]
pub enum Error {
#[error("received invalid channel {0}")]
InvalidChannel(usize),
#[error("received initial frame with invalid size")]
InvalidInitialFrameSize,
#[error("invalid port mapping in initial frame, got {actual}, expected {expected}")]
InvalidPortMapping { actual: u16, expected: u16 },
#[error("failed to forward bytes from Pod: {0}")]
ForwardFromPod(#[source] futures::channel::mpsc::SendError),
#[error("failed to forward bytes to Pod: {0}")]
ForwardToPod(#[source] futures::channel::mpsc::SendError),
#[error("failed to write bytes from Pod: {0}")]
WriteBytesFromPod(#[source] std::io::Error),
#[error("failed to read bytes to send to Pod: {0}")]
ReadBytesToSend(#[source] std::io::Error),
#[error("received invalid error message from Pod: {0}")]
InvalidErrorMessage(#[source] std::string::FromUtf8Error),
#[error("failed to forward an error message {0:?}")]
ForwardErrorMessage(String),
#[error("failed to send a WebSocket message: {0}")]
SendWebSocketMessage(#[source] ws::Error),
#[error("failed to receive a WebSocket message: {0}")]
ReceiveWebSocketMessage(#[source] ws::Error),
#[error("failed to complete the background task: {0}")]
Spawn(#[source] tokio::task::JoinError),
#[error("failed to shutdown write to Pod channel: {0}")]
Shutdown(#[source] std::io::Error),
}
/// Sending half of a port's one-shot error channel; held by the background task.
type ErrorSender = oneshot::Sender<String>;
/// Receiving half of a port's one-shot error channel; handed out to callers.
type ErrorReceiver = oneshot::Receiver<String>;
/// Events sent from the reader loops to the central forwarder loop.
enum Message {
    /// A chunk read from a local stream, to be sent to the Pod on the given channel.
    ToPod(u8, Bytes),
    /// The local stream feeding the given channel reached end-of-file.
    ToPodClose(u8),
    /// A payload received from the Pod on the given channel (channel byte already stripped).
    FromPod(u8, Bytes),
    /// The WebSocket connection from the Pod has closed.
    FromPodClose,
}
/// Handle to an active port-forwarding session.
///
/// For every requested port it holds the local end of an in-memory duplex
/// stream and a receiver for that port's error message. It also holds the
/// handle of the spawned task that drives the WebSocket connection.
pub struct Portforwarder {
    /// Local ends of the per-port duplex streams, handed out by `take_stream`.
    ports: HashMap<u16, DuplexStream>,
    /// Per-port one-shot receivers for the Pod's error message, handed out by `take_error`.
    errors: HashMap<u16, ErrorReceiver>,
    /// Background task running the message loops; awaited by `join`, cancelled by `abort`.
    task: tokio::task::JoinHandle<Result<(), Error>>,
}
impl Portforwarder {
pub(crate) fn new<S>(stream: WebSocketStream<S>, port_nums: &[u16]) -> Self
where
S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
{
let mut ports = HashMap::with_capacity(port_nums.len());
let mut error_rxs = HashMap::with_capacity(port_nums.len());
let mut error_txs = Vec::with_capacity(port_nums.len());
let mut task_ios = Vec::with_capacity(port_nums.len());
for port in port_nums.iter() {
let (a, b) = tokio::io::duplex(1024 * 1024);
ports.insert(*port, a);
task_ios.push(b);
let (tx, rx) = oneshot::channel();
error_rxs.insert(*port, rx);
error_txs.push(Some(tx));
}
let task = tokio::spawn(start_message_loop(
stream,
port_nums.to_vec(),
task_ios,
error_txs,
));
Portforwarder {
ports,
errors: error_rxs,
task,
}
}
#[inline]
pub fn take_stream(&mut self, port: u16) -> Option<impl AsyncRead + AsyncWrite + Unpin> {
self.ports.remove(&port)
}
#[inline]
pub fn take_error(&mut self, port: u16) -> Option<impl Future<Output = Option<String>>> {
self.errors.remove(&port).map(|recv| recv.map(|res| res.ok()))
}
#[inline]
pub fn abort(&self) {
self.task.abort();
}
pub async fn join(self) -> Result<(), Error> {
let Self {
mut ports,
mut errors,
task,
} = self;
ports.clear();
errors.clear();
task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
}
}
/// Drives a port-forwarding session until every loop finishes or one of them fails.
///
/// Each duplex is split in two:
/// - its read half feeds a `to_pod_loop` on data channel `2 * i`;
/// - its write half is handed to `forwarder_loop`.
///
/// A single `from_pod_loop` reads the WebSocket. All reader loops send their
/// events to `forwarder_loop` through one shared `mpsc` channel. The first
/// error returned by any loop is propagated.
async fn start_message_loop<S>(
    stream: WebSocketStream<S>,
    ports: Vec<u16>,
    duplexes: Vec<DuplexStream>,
    error_senders: Vec<Option<ErrorSender>>,
) -> Result<(), Error>
where
    S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
{
    let (sender, receiver) = mpsc::channel::<Message>(1);
    let mut writers = Vec::with_capacity(duplexes.len());
    let mut loops = Vec::with_capacity(ports.len() + 2);

    for (index, duplex) in duplexes.into_iter().enumerate() {
        let (reader, writer) = tokio::io::split(duplex);
        writers.push(writer);
        // Data channel of port `index`; its error channel is the next odd number.
        // NOTE(review): channels are a single byte, so this presumably assumes
        // fewer than 128 ports — confirm callers enforce that.
        let data_channel = 2 * (index as u8);
        loops.push(to_pod_loop(data_channel, reader, sender.clone()).boxed());
    }

    let (ws_sink, ws_stream) = stream.split();
    loops.push(from_pod_loop(ws_stream, sender).boxed());
    loops.push(forwarder_loop(&ports, receiver, ws_sink, writers, error_senders).boxed());

    future::try_join_all(loops).await?;
    Ok(())
}
async fn to_pod_loop(
ch: u8,
reader: tokio::io::ReadHalf<DuplexStream>,
mut sender: mpsc::Sender<Message>,
) -> Result<(), Error> {
let mut read_stream = ReaderStream::new(reader);
while let Some(bytes) = read_stream
.next()
.await
.transpose()
.map_err(Error::ReadBytesToSend)?
{
if !bytes.is_empty() {
sender
.send(Message::ToPod(ch, bytes))
.await
.map_err(Error::ForwardToPod)?;
}
}
sender
.send(Message::ToPodClose(ch))
.await
.map_err(Error::ForwardToPod)?;
Ok(())
}
/// Reads WebSocket messages from the Pod and forwards them to the forwarder loop.
///
/// A binary message longer than one byte is split into its leading channel byte
/// and the remaining payload, then forwarded as `Message::FromPod`. All other
/// messages are ignored: text, ping and pong frames, and binary messages with no
/// payload.
///
/// The connection counts as closed when a Close frame arrives or when the stream
/// simply ends. In both cases `Message::FromPodClose` is sent, so the forwarder
/// shuts down the local write halves and local readers see end-of-file instead
/// of waiting indefinitely.
///
/// # Errors
/// - `Error::ReceiveWebSocketMessage` if the WebSocket stream yields an error.
/// - `Error::ForwardFromPod` if an event cannot be queued for the forwarder loop.
async fn from_pod_loop<S>(
    mut ws_stream: futures::stream::SplitStream<WebSocketStream<S>>,
    mut sender: mpsc::Sender<Message>,
) -> Result<(), Error>
where
    S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
{
    while let Some(msg) = ws_stream
        .next()
        .await
        .transpose()
        .map_err(Error::ReceiveWebSocketMessage)?
    {
        match msg {
            ws::Message::Binary(bin) if bin.len() > 1 => {
                let mut bytes = Bytes::from(bin);
                // The first byte selects the channel: `2 * i` is port i's data channel,
                // `2 * i + 1` its error channel.
                let ch = bytes.split_to(1)[0];
                sender
                    .send(Message::FromPod(ch, bytes))
                    .await
                    .map_err(Error::ForwardFromPod)?;
            }
            message if message.is_close() => break,
            _ => {}
        }
    }
    // Reached after a Close frame or when the stream ends without one. Either way
    // the Pod side is gone, so have the forwarder signal EOF to the local streams.
    sender
        .send(Message::FromPodClose)
        .await
        .map_err(Error::ForwardFromPod)?;
    Ok(())
}
/// Central event loop of a port-forwarding session.
///
/// Owns the WebSocket sink and the write halves of the local duplex streams. It
/// handles every `Message` produced by the reader loops:
///
/// - `FromPod`: the first frame on each channel must be exactly two bytes holding
///   the port number (little-endian), and that number must equal the requested
///   port. After initialization:
///   - data on an even channel is written to the port's local stream, unless that
///     stream was already shut down;
///   - the first error frame on an odd channel is decoded as UTF-8 and delivered
///     through the port's one-shot error sender. Later error frames are ignored.
/// - `ToPod`: the payload is prefixed with its channel byte and sent as a binary
///   WebSocket message.
/// - `ToPodClose`: the port's local write half is shut down (at most once) and
///   the port is counted as closed.
/// - `FromPodClose`: every local write half is shut down.
///
/// Once every port has been counted as closed, a single Close frame is sent on
/// the WebSocket. The loop returns `Ok(())` when all senders feeding `receiver`
/// are gone.
///
/// # Errors
/// Stops at the first of:
/// - `InvalidChannel`, `InvalidInitialFrameSize`, `InvalidPortMapping`;
/// - `InvalidErrorMessage`, `ForwardErrorMessage`;
/// - `WriteBytesFromPod`, `SendWebSocketMessage`, `Shutdown`.
async fn forwarder_loop<S>(
    ports: &[u16],
    mut receiver: mpsc::Receiver<Message>,
    mut ws_sink: futures::stream::SplitSink<WebSocketStream<S>, ws::Message>,
    mut writers: Vec<tokio::io::WriteHalf<DuplexStream>>,
    mut error_senders: Vec<Option<ErrorSender>>,
) -> Result<(), Error>
where
    S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
{
    // Per-channel bookkeeping.
    #[derive(Default, Clone)]
    struct ChannelState {
        // Whether the initial port-number frame has been received on this channel.
        initialized: bool,
        // Whether the local write half was shut down via `ToPodClose`
        // (only data channels receive `ToPodClose`).
        shutdown: bool,
    }
    // Two channels per port: `2 * i` carries data and `2 * i + 1` carries errors for `ports[i]`.
    let mut chan_state = vec![ChannelState::default(); 2 * ports.len()];
    // Number of ports whose local side has been closed via `ToPodClose`.
    let mut closed_ports = 0;
    // Whether the Close frame has already been sent on the WebSocket.
    let mut socket_shutdown = false;
    while let Some(msg) = receiver.next().await {
        match msg {
            Message::FromPod(ch, mut bytes) => {
                let ch = ch as usize;
                let channel = chan_state.get_mut(ch).ok_or_else(|| Error::InvalidChannel(ch))?;
                let port_index = ch / 2;
                // The first frame on every channel only announces its port number.
                if !channel.initialized {
                    if bytes.len() != 2 {
                        return Err(Error::InvalidInitialFrameSize);
                    }
                    let port = bytes.get_u16_le();
                    if port != ports[port_index] {
                        return Err(Error::InvalidPortMapping {
                            actual: port,
                            expected: ports[port_index],
                        });
                    }
                    channel.initialized = true;
                    continue;
                }
                if ch % 2 != 0 {
                    // Error channel: the sender is consumed, so only the first message is delivered.
                    // NOTE(review): `send` fails if the receiver was dropped (e.g. by `join`, or by
                    // a caller discarding the future from `take_error`). That ends the whole loop
                    // with `ForwardErrorMessage` — confirm this is intended.
                    if let Some(sender) = error_senders[port_index].take() {
                        let s = String::from_utf8(bytes.into_iter().collect())
                            .map_err(Error::InvalidErrorMessage)?;
                        sender.send(s).map_err(Error::ForwardErrorMessage)?;
                    }
                } else if !channel.shutdown {
                    // NOTE(review): this presumably fails if the caller dropped the local stream,
                    // which aborts forwarding for every port — confirm this is intended.
                    writers[port_index]
                        .write_all(&bytes)
                        .await
                        .map_err(Error::WriteBytesFromPod)?;
                }
            }
            Message::ToPod(ch, bytes) => {
                // Wire format: one channel byte followed by the payload.
                let mut bin = Vec::with_capacity(bytes.len() + 1);
                bin.push(ch);
                bin.extend(bytes);
                ws_sink
                    .send(ws::Message::binary(bin))
                    .await
                    .map_err(Error::SendWebSocketMessage)?;
            }
            Message::ToPodClose(ch) => {
                let ch = ch as usize;
                let channel = chan_state.get_mut(ch).ok_or_else(|| Error::InvalidChannel(ch))?;
                let port_index = ch / 2;
                // The `shutdown` flag keeps a port from being counted as closed twice.
                if !channel.shutdown {
                    writers[port_index].shutdown().await.map_err(Error::Shutdown)?;
                    channel.shutdown = true;
                    closed_ports += 1;
                }
            }
            Message::FromPodClose => {
                // Signal end-of-file to every local stream. This does not touch
                // `chan_state` or `closed_ports`.
                for writer in &mut writers {
                    writer.shutdown().await.map_err(Error::Shutdown)?;
                }
            }
        }
        // After the last port has closed locally, send the Close frame exactly once.
        if closed_ports == ports.len() && !socket_shutdown {
            ws_sink
                .send(ws::Message::Close(None))
                .await
                .map_err(Error::SendWebSocketMessage)?;
            socket_shutdown = true;
        }
    }
    Ok(())
}