// kube_client/api/portforward.rs

1use std::{collections::HashMap, future::Future};
2
3use bytes::{Buf, Bytes};
4use futures::{
5    channel::{mpsc, oneshot},
6    future, FutureExt, SinkExt, StreamExt,
7};
8use thiserror::Error;
9use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream};
10use tokio_tungstenite::{tungstenite as ws, WebSocketStream};
11use tokio_util::io::ReaderStream;
12
13/// Errors from Portforwarder.
14#[derive(Debug, Error)]
15pub enum Error {
16    /// Received invalid channel in WebSocket message.
17    #[error("received invalid channel {0}")]
18    InvalidChannel(usize),
19
20    /// Received initial frame with invalid size. The initial frame must be 3 bytes, including the channel prefix.
21    #[error("received initial frame with invalid size")]
22    InvalidInitialFrameSize,
23
24    /// Received initial frame with invalid port mapping.
25    /// The port included in the initial frame did not match the port number associated with the channel.
26    #[error("invalid port mapping in initial frame, got {actual}, expected {expected}")]
27    InvalidPortMapping { actual: u16, expected: u16 },
28
29    /// Failed to forward bytes from Pod.
30    #[error("failed to forward bytes from Pod: {0}")]
31    ForwardFromPod(#[source] futures::channel::mpsc::SendError),
32
33    /// Failed to forward bytes to Pod.
34    #[error("failed to forward bytes to Pod: {0}")]
35    ForwardToPod(#[source] futures::channel::mpsc::SendError),
36
37    /// Failed to write bytes from Pod.
38    #[error("failed to write bytes from Pod: {0}")]
39    WriteBytesFromPod(#[source] std::io::Error),
40
41    /// Failed to read bytes to send to Pod.
42    #[error("failed to read bytes to send to Pod: {0}")]
43    ReadBytesToSend(#[source] std::io::Error),
44
45    /// Received an error message from pod that is not a valid UTF-8.
46    #[error("received invalid error message from Pod: {0}")]
47    InvalidErrorMessage(#[source] std::string::FromUtf8Error),
48
49    /// Failed to forward an error message from pod.
50    #[error("failed to forward an error message {0:?}")]
51    ForwardErrorMessage(String),
52
53    /// Failed to send a WebSocket message to the server.
54    #[error("failed to send a WebSocket message: {0}")]
55    SendWebSocketMessage(#[source] ws::Error),
56
57    /// Failed to receive a WebSocket message from the server.
58    #[error("failed to receive a WebSocket message: {0}")]
59    ReceiveWebSocketMessage(#[source] ws::Error),
60
61    #[error("failed to complete the background task: {0}")]
62    Spawn(#[source] tokio::task::JoinError),
63
64    /// Failed to shutdown a pod writer channel.
65    #[error("failed to shutdown write to Pod channel: {0}")]
66    Shutdown(#[source] std::io::Error),
67}
68
69type ErrorReceiver = oneshot::Receiver<String>;
70type ErrorSender = oneshot::Sender<String>;
71
72// Internal message used by the futures to communicate with each other.
73enum Message {
74    FromPod(u8, Bytes),
75    ToPod(u8, Bytes),
76    FromPodClose,
77    ToPodClose(u8),
78}
79
80/// Manages port-forwarded streams.
81///
82/// Provides `AsyncRead + AsyncWrite` for each port and **does not** bind to local ports.  Error
83/// channel for each port is only written by the server when there's an exception and
84/// the port cannot be used (didn't initialize or can't be used anymore).
85pub struct Portforwarder {
86    ports: HashMap<u16, DuplexStream>,
87    errors: HashMap<u16, ErrorReceiver>,
88    task: tokio::task::JoinHandle<Result<(), Error>>,
89}
90
91impl Portforwarder {
92    pub(crate) fn new<S>(stream: WebSocketStream<S>, port_nums: &[u16]) -> Self
93    where
94        S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
95    {
96        let mut ports = HashMap::with_capacity(port_nums.len());
97        let mut error_rxs = HashMap::with_capacity(port_nums.len());
98        let mut error_txs = Vec::with_capacity(port_nums.len());
99        let mut task_ios = Vec::with_capacity(port_nums.len());
100        for port in port_nums.iter() {
101            let (a, b) = tokio::io::duplex(1024 * 1024);
102            ports.insert(*port, a);
103            task_ios.push(b);
104
105            let (tx, rx) = oneshot::channel();
106            error_rxs.insert(*port, rx);
107            error_txs.push(Some(tx));
108        }
109        let task = tokio::spawn(start_message_loop(
110            stream,
111            port_nums.to_vec(),
112            task_ios,
113            error_txs,
114        ));
115
116        Portforwarder {
117            ports,
118            errors: error_rxs,
119            task,
120        }
121    }
122
123    /// Take a port stream by the port on the target resource.
124    ///
125    /// A value is returned at most once per port.
126    #[inline]
127    pub fn take_stream(&mut self, port: u16) -> Option<impl AsyncRead + AsyncWrite + Unpin> {
128        self.ports.remove(&port)
129    }
130
131    /// Take a future that resolves with any error message or when the error sender is dropped.
132    /// When the future resolves, the port should be considered no longer usable.
133    ///
134    /// A value is returned at most once per port.
135    #[inline]
136    pub fn take_error(&mut self, port: u16) -> Option<impl Future<Output = Option<String>>> {
137        self.errors.remove(&port).map(|recv| recv.map(|res| res.ok()))
138    }
139
140    /// Abort the background task, causing port forwards to fail.
141    #[inline]
142    pub fn abort(&self) {
143        self.task.abort();
144    }
145
146    /// Waits for port forwarding task to complete.
147    pub async fn join(self) -> Result<(), Error> {
148        let Self {
149            mut ports,
150            mut errors,
151            task,
152        } = self;
153        // Start by terminating any streams that have not yet been taken
154        // since they would otherwise keep the connection open indefinitely
155        ports.clear();
156        errors.clear();
157        task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
158    }
159}
160
161async fn start_message_loop<S>(
162    stream: WebSocketStream<S>,
163    ports: Vec<u16>,
164    duplexes: Vec<DuplexStream>,
165    error_senders: Vec<Option<ErrorSender>>,
166) -> Result<(), Error>
167where
168    S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
169{
170    let mut writers = Vec::new();
171    // Loops to run concurrently.
172    // We can spawn tasks to run `to_pod_loop` in parallel and flatten the errors, but the other 2 loops
173    // are over a single WebSocket connection and cannot process each port in parallel.
174    let mut loops = Vec::with_capacity(ports.len() + 2);
175    // Channel to communicate with the main loop
176    let (sender, receiver) = mpsc::channel::<Message>(1);
177    for (i, (r, w)) in duplexes.into_iter().map(tokio::io::split).enumerate() {
178        writers.push(w);
179        // Each port uses 2 channels. Duplex data channel and error.
180        let ch = 2 * (i as u8);
181        loops.push(to_pod_loop(ch, r, sender.clone()).boxed());
182    }
183
184    let (ws_sink, ws_stream) = stream.split();
185    loops.push(from_pod_loop(ws_stream, sender).boxed());
186    loops.push(forwarder_loop(&ports, receiver, ws_sink, writers, error_senders).boxed());
187
188    future::try_join_all(loops).await.map(|_| ())
189}
190
191async fn to_pod_loop(
192    ch: u8,
193    reader: tokio::io::ReadHalf<DuplexStream>,
194    mut sender: mpsc::Sender<Message>,
195) -> Result<(), Error> {
196    let mut read_stream = ReaderStream::new(reader);
197    while let Some(bytes) = read_stream
198        .next()
199        .await
200        .transpose()
201        .map_err(Error::ReadBytesToSend)?
202    {
203        if !bytes.is_empty() {
204            sender
205                .send(Message::ToPod(ch, bytes))
206                .await
207                .map_err(Error::ForwardToPod)?;
208        }
209    }
210    sender
211        .send(Message::ToPodClose(ch))
212        .await
213        .map_err(Error::ForwardToPod)?;
214    Ok(())
215}
216
217async fn from_pod_loop<S>(
218    mut ws_stream: futures::stream::SplitStream<WebSocketStream<S>>,
219    mut sender: mpsc::Sender<Message>,
220) -> Result<(), Error>
221where
222    S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
223{
224    while let Some(msg) = ws_stream
225        .next()
226        .await
227        .transpose()
228        .map_err(Error::ReceiveWebSocketMessage)?
229    {
230        match msg {
231            ws::Message::Binary(mut bytes) if bytes.len() > 1 => {
232                let ch = bytes.split_to(1)[0];
233                sender
234                    .send(Message::FromPod(ch, bytes))
235                    .await
236                    .map_err(Error::ForwardFromPod)?;
237            }
238            message if message.is_close() => {
239                sender
240                    .send(Message::FromPodClose)
241                    .await
242                    .map_err(Error::ForwardFromPod)?;
243                break;
244            }
245            // REVIEW should we error on unexpected websocket message?
246            _ => {}
247        }
248    }
249    Ok(())
250}
251
252// Start a loop to handle messages received from other futures.
253// On `Message::ToPod(ch, bytes)`, a WebSocket message is sent with the channel prefix.
254// On `Message::FromPod(ch, bytes)` with an even `ch`, `bytes` are written to the port's sink.
255// On `Message::FromPod(ch, bytes)` with an odd `ch`, an error message is sent to the error channel of the port.
256async fn forwarder_loop<S>(
257    ports: &[u16],
258    mut receiver: mpsc::Receiver<Message>,
259    mut ws_sink: futures::stream::SplitSink<WebSocketStream<S>, ws::Message>,
260    mut writers: Vec<tokio::io::WriteHalf<DuplexStream>>,
261    mut error_senders: Vec<Option<ErrorSender>>,
262) -> Result<(), Error>
263where
264    S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
265{
266    #[derive(Default, Clone)]
267    struct ChannelState {
268        // Keep track if the channel has received the initialization frame.
269        initialized: bool,
270        // Keep track if the channel has shutdown.
271        shutdown: bool,
272    }
273    let mut chan_state = vec![ChannelState::default(); 2 * ports.len()];
274    let mut closed_ports = 0;
275    let mut socket_shutdown = false;
276    while let Some(msg) = receiver.next().await {
277        match msg {
278            Message::FromPod(ch, mut bytes) => {
279                let ch = ch as usize;
280                let channel = chan_state.get_mut(ch).ok_or_else(|| Error::InvalidChannel(ch))?;
281
282                let port_index = ch / 2;
283                // Initialization
284                if !channel.initialized {
285                    // The initial message must be 3 bytes including the channel prefix.
286                    if bytes.len() != 2 {
287                        return Err(Error::InvalidInitialFrameSize);
288                    }
289
290                    let port = bytes.get_u16_le();
291                    if port != ports[port_index] {
292                        return Err(Error::InvalidPortMapping {
293                            actual: port,
294                            expected: ports[port_index],
295                        });
296                    }
297
298                    channel.initialized = true;
299                    continue;
300                }
301
302                // Odd channels are for errors for (n - 1)/2 th port
303                if ch % 2 != 0 {
304                    // A port sends at most one error message because it's considered unusable after this.
305                    if let Some(sender) = error_senders[port_index].take() {
306                        let s = String::from_utf8(bytes.into_iter().collect())
307                            .map_err(Error::InvalidErrorMessage)?;
308                        sender.send(s).map_err(Error::ForwardErrorMessage)?;
309                    }
310                } else if !channel.shutdown {
311                    writers[port_index]
312                        .write_all(&bytes)
313                        .await
314                        .map_err(Error::WriteBytesFromPod)?;
315                }
316            }
317
318            Message::ToPod(ch, bytes) => {
319                let mut bin = Vec::with_capacity(bytes.len() + 1);
320                bin.push(ch);
321                bin.extend(bytes);
322                ws_sink
323                    .send(ws::Message::binary(bin))
324                    .await
325                    .map_err(Error::SendWebSocketMessage)?;
326            }
327            Message::ToPodClose(ch) => {
328                let ch = ch as usize;
329                let channel = chan_state.get_mut(ch).ok_or_else(|| Error::InvalidChannel(ch))?;
330                let port_index = ch / 2;
331
332                if !channel.shutdown {
333                    writers[port_index].shutdown().await.map_err(Error::Shutdown)?;
334                    channel.shutdown = true;
335
336                    closed_ports += 1;
337                }
338            }
339            Message::FromPodClose => {
340                for writer in &mut writers {
341                    writer.shutdown().await.map_err(Error::Shutdown)?;
342                }
343            }
344        }
345
346        if closed_ports == ports.len() && !socket_shutdown {
347            ws_sink
348                .send(ws::Message::Close(None))
349                .await
350                .map_err(Error::SendWebSocketMessage)?;
351            socket_shutdown = true;
352        }
353    }
354    Ok(())
355}