1use std::{collections::HashMap, future::Future};
2
3use bytes::{Buf, Bytes};
4use futures::{
5 channel::{mpsc, oneshot},
6 future, FutureExt, SinkExt, StreamExt,
7};
8use thiserror::Error;
9use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream};
10use tokio_tungstenite::{tungstenite as ws, WebSocketStream};
11use tokio_util::io::ReaderStream;
12
13#[derive(Debug, Error)]
15pub enum Error {
16 #[error("received invalid channel {0}")]
18 InvalidChannel(usize),
19
20 #[error("received initial frame with invalid size")]
22 InvalidInitialFrameSize,
23
24 #[error("invalid port mapping in initial frame, got {actual}, expected {expected}")]
27 InvalidPortMapping { actual: u16, expected: u16 },
28
29 #[error("failed to forward bytes from Pod: {0}")]
31 ForwardFromPod(#[source] futures::channel::mpsc::SendError),
32
33 #[error("failed to forward bytes to Pod: {0}")]
35 ForwardToPod(#[source] futures::channel::mpsc::SendError),
36
37 #[error("failed to write bytes from Pod: {0}")]
39 WriteBytesFromPod(#[source] std::io::Error),
40
41 #[error("failed to read bytes to send to Pod: {0}")]
43 ReadBytesToSend(#[source] std::io::Error),
44
45 #[error("received invalid error message from Pod: {0}")]
47 InvalidErrorMessage(#[source] std::string::FromUtf8Error),
48
49 #[error("failed to forward an error message {0:?}")]
51 ForwardErrorMessage(String),
52
53 #[error("failed to send a WebSocket message: {0}")]
55 SendWebSocketMessage(#[source] ws::Error),
56
57 #[error("failed to receive a WebSocket message: {0}")]
59 ReceiveWebSocketMessage(#[source] ws::Error),
60
61 #[error("failed to complete the background task: {0}")]
62 Spawn(#[source] tokio::task::JoinError),
63
64 #[error("failed to shutdown write to Pod channel: {0}")]
66 Shutdown(#[source] std::io::Error),
67}
68
69type ErrorReceiver = oneshot::Receiver<String>;
70type ErrorSender = oneshot::Sender<String>;
71
72enum Message {
74 FromPod(u8, Bytes),
75 ToPod(u8, Bytes),
76 FromPodClose,
77 ToPodClose(u8),
78}
79
80pub struct Portforwarder {
86 ports: HashMap<u16, DuplexStream>,
87 errors: HashMap<u16, ErrorReceiver>,
88 task: tokio::task::JoinHandle<Result<(), Error>>,
89}
90
91impl Portforwarder {
92 pub(crate) fn new<S>(stream: WebSocketStream<S>, port_nums: &[u16]) -> Self
93 where
94 S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
95 {
96 let mut ports = HashMap::with_capacity(port_nums.len());
97 let mut error_rxs = HashMap::with_capacity(port_nums.len());
98 let mut error_txs = Vec::with_capacity(port_nums.len());
99 let mut task_ios = Vec::with_capacity(port_nums.len());
100 for port in port_nums.iter() {
101 let (a, b) = tokio::io::duplex(1024 * 1024);
102 ports.insert(*port, a);
103 task_ios.push(b);
104
105 let (tx, rx) = oneshot::channel();
106 error_rxs.insert(*port, rx);
107 error_txs.push(Some(tx));
108 }
109 let task = tokio::spawn(start_message_loop(
110 stream,
111 port_nums.to_vec(),
112 task_ios,
113 error_txs,
114 ));
115
116 Portforwarder {
117 ports,
118 errors: error_rxs,
119 task,
120 }
121 }
122
123 #[inline]
127 pub fn take_stream(&mut self, port: u16) -> Option<impl AsyncRead + AsyncWrite + Unpin> {
128 self.ports.remove(&port)
129 }
130
131 #[inline]
136 pub fn take_error(&mut self, port: u16) -> Option<impl Future<Output = Option<String>>> {
137 self.errors.remove(&port).map(|recv| recv.map(|res| res.ok()))
138 }
139
140 #[inline]
142 pub fn abort(&self) {
143 self.task.abort();
144 }
145
146 pub async fn join(self) -> Result<(), Error> {
148 let Self {
149 mut ports,
150 mut errors,
151 task,
152 } = self;
153 ports.clear();
156 errors.clear();
157 task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
158 }
159}
160
161async fn start_message_loop<S>(
162 stream: WebSocketStream<S>,
163 ports: Vec<u16>,
164 duplexes: Vec<DuplexStream>,
165 error_senders: Vec<Option<ErrorSender>>,
166) -> Result<(), Error>
167where
168 S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
169{
170 let mut writers = Vec::new();
171 let mut loops = Vec::with_capacity(ports.len() + 2);
175 let (sender, receiver) = mpsc::channel::<Message>(1);
177 for (i, (r, w)) in duplexes.into_iter().map(tokio::io::split).enumerate() {
178 writers.push(w);
179 let ch = 2 * (i as u8);
181 loops.push(to_pod_loop(ch, r, sender.clone()).boxed());
182 }
183
184 let (ws_sink, ws_stream) = stream.split();
185 loops.push(from_pod_loop(ws_stream, sender).boxed());
186 loops.push(forwarder_loop(&ports, receiver, ws_sink, writers, error_senders).boxed());
187
188 future::try_join_all(loops).await.map(|_| ())
189}
190
191async fn to_pod_loop(
192 ch: u8,
193 reader: tokio::io::ReadHalf<DuplexStream>,
194 mut sender: mpsc::Sender<Message>,
195) -> Result<(), Error> {
196 let mut read_stream = ReaderStream::new(reader);
197 while let Some(bytes) = read_stream
198 .next()
199 .await
200 .transpose()
201 .map_err(Error::ReadBytesToSend)?
202 {
203 if !bytes.is_empty() {
204 sender
205 .send(Message::ToPod(ch, bytes))
206 .await
207 .map_err(Error::ForwardToPod)?;
208 }
209 }
210 sender
211 .send(Message::ToPodClose(ch))
212 .await
213 .map_err(Error::ForwardToPod)?;
214 Ok(())
215}
216
217async fn from_pod_loop<S>(
218 mut ws_stream: futures::stream::SplitStream<WebSocketStream<S>>,
219 mut sender: mpsc::Sender<Message>,
220) -> Result<(), Error>
221where
222 S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
223{
224 while let Some(msg) = ws_stream
225 .next()
226 .await
227 .transpose()
228 .map_err(Error::ReceiveWebSocketMessage)?
229 {
230 match msg {
231 ws::Message::Binary(mut bytes) if bytes.len() > 1 => {
232 let ch = bytes.split_to(1)[0];
233 sender
234 .send(Message::FromPod(ch, bytes))
235 .await
236 .map_err(Error::ForwardFromPod)?;
237 }
238 message if message.is_close() => {
239 sender
240 .send(Message::FromPodClose)
241 .await
242 .map_err(Error::ForwardFromPod)?;
243 break;
244 }
245 _ => {}
247 }
248 }
249 Ok(())
250}
251
252async fn forwarder_loop<S>(
257 ports: &[u16],
258 mut receiver: mpsc::Receiver<Message>,
259 mut ws_sink: futures::stream::SplitSink<WebSocketStream<S>, ws::Message>,
260 mut writers: Vec<tokio::io::WriteHalf<DuplexStream>>,
261 mut error_senders: Vec<Option<ErrorSender>>,
262) -> Result<(), Error>
263where
264 S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
265{
266 #[derive(Default, Clone)]
267 struct ChannelState {
268 initialized: bool,
270 shutdown: bool,
272 }
273 let mut chan_state = vec![ChannelState::default(); 2 * ports.len()];
274 let mut closed_ports = 0;
275 let mut socket_shutdown = false;
276 while let Some(msg) = receiver.next().await {
277 match msg {
278 Message::FromPod(ch, mut bytes) => {
279 let ch = ch as usize;
280 let channel = chan_state.get_mut(ch).ok_or_else(|| Error::InvalidChannel(ch))?;
281
282 let port_index = ch / 2;
283 if !channel.initialized {
285 if bytes.len() != 2 {
287 return Err(Error::InvalidInitialFrameSize);
288 }
289
290 let port = bytes.get_u16_le();
291 if port != ports[port_index] {
292 return Err(Error::InvalidPortMapping {
293 actual: port,
294 expected: ports[port_index],
295 });
296 }
297
298 channel.initialized = true;
299 continue;
300 }
301
302 if ch % 2 != 0 {
304 if let Some(sender) = error_senders[port_index].take() {
306 let s = String::from_utf8(bytes.into_iter().collect())
307 .map_err(Error::InvalidErrorMessage)?;
308 sender.send(s).map_err(Error::ForwardErrorMessage)?;
309 }
310 } else if !channel.shutdown {
311 writers[port_index]
312 .write_all(&bytes)
313 .await
314 .map_err(Error::WriteBytesFromPod)?;
315 }
316 }
317
318 Message::ToPod(ch, bytes) => {
319 let mut bin = Vec::with_capacity(bytes.len() + 1);
320 bin.push(ch);
321 bin.extend(bytes);
322 ws_sink
323 .send(ws::Message::binary(bin))
324 .await
325 .map_err(Error::SendWebSocketMessage)?;
326 }
327 Message::ToPodClose(ch) => {
328 let ch = ch as usize;
329 let channel = chan_state.get_mut(ch).ok_or_else(|| Error::InvalidChannel(ch))?;
330 let port_index = ch / 2;
331
332 if !channel.shutdown {
333 writers[port_index].shutdown().await.map_err(Error::Shutdown)?;
334 channel.shutdown = true;
335
336 closed_ports += 1;
337 }
338 }
339 Message::FromPodClose => {
340 for writer in &mut writers {
341 writer.shutdown().await.map_err(Error::Shutdown)?;
342 }
343 }
344 }
345
346 if closed_ports == ports.len() && !socket_shutdown {
347 ws_sink
348 .send(ws::Message::Close(None))
349 .await
350 .map_err(Error::SendWebSocketMessage)?;
351 socket_shutdown = true;
352 }
353 }
354 Ok(())
355}