1use std::future::Future;
2
3use k8s_openapi::apimachinery::pkg::apis::meta::v1::Status;
4
5use futures::{
6 channel::{mpsc, oneshot},
7 FutureExt, SinkExt, StreamExt,
8};
9use serde::{Deserialize, Serialize};
10use thiserror::Error;
11use tokio::{
12 io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream},
13 select,
14};
15use tokio_tungstenite::{
16 tungstenite::{self as ws},
17 WebSocketStream,
18};
19
20use super::AttachParams;
21
22type StatusReceiver = oneshot::Receiver<Status>;
23type StatusSender = oneshot::Sender<Status>;
24
25type TerminalSizeReceiver = mpsc::Receiver<TerminalSize>;
26type TerminalSizeSender = mpsc::Sender<TerminalSize>;
27
28#[derive(Debug, Serialize, Deserialize)]
30#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
31#[serde(rename_all = "PascalCase")]
32pub struct TerminalSize {
33 pub width: u16,
35 pub height: u16,
37}
38
39#[derive(Debug, Error)]
41pub enum Error {
42 #[error("failed to read from stdin: {0}")]
44 ReadStdin(#[source] std::io::Error),
45
46 #[error("failed to send a stdin data: {0}")]
48 SendStdin(#[source] ws::Error),
49
50 #[error("failed to write to stdout: {0}")]
52 WriteStdout(#[source] std::io::Error),
53
54 #[error("failed to write to stderr: {0}")]
56 WriteStderr(#[source] std::io::Error),
57
58 #[error("failed to receive a WebSocket message: {0}")]
60 ReceiveWebSocketMessage(#[source] ws::Error),
61
62 #[error("failed to complete the background task: {0}")]
64 Spawn(#[source] tokio::task::JoinError),
65
66 #[error("failed to send a WebSocket close message: {0}")]
68 SendClose(#[source] ws::Error),
69
70 #[error("failed to deserialize status object: {0}")]
72 DeserializeStatus(#[source] serde_json::Error),
73
74 #[error("failed to send status object")]
76 SendStatus,
77
78 #[error("failed to serialize TerminalSize object: {0}")]
80 SerializeTerminalSize(#[source] serde_json::Error),
81
82 #[error("failed to send terminal size message")]
84 SendTerminalSize(#[source] ws::Error),
85
86 #[error("failed to set terminal size, tty need to be true to resize the terminal")]
88 TtyNeedToBeTrue,
89}
90
/// Fallback capacity, in bytes, of each in-memory pipe when the attach parameters
/// do not specify a buffer size for that stream.
const MAX_BUF_SIZE: usize = 1024;
92
93#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
102pub struct AttachedProcess {
103 has_stdin: bool,
104 has_stdout: bool,
105 has_stderr: bool,
106 stdin_writer: Option<DuplexStream>,
107 stdout_reader: Option<DuplexStream>,
108 stderr_reader: Option<DuplexStream>,
109 status_rx: Option<StatusReceiver>,
110 terminal_resize_tx: Option<TerminalSizeSender>,
111 task: tokio::task::JoinHandle<Result<(), Error>>,
112}
113
114impl AttachedProcess {
115 pub(crate) fn new<S>(stream: WebSocketStream<S>, ap: &AttachParams) -> Self
116 where
117 S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
118 {
119 let (stdin_writer, stdin_reader) = tokio::io::duplex(ap.max_stdin_buf_size.unwrap_or(MAX_BUF_SIZE));
122 let (stdout_writer, stdout_reader) = if ap.stdout {
123 let (w, r) = tokio::io::duplex(ap.max_stdout_buf_size.unwrap_or(MAX_BUF_SIZE));
124 (Some(w), Some(r))
125 } else {
126 (None, None)
127 };
128 let (stderr_writer, stderr_reader) = if ap.stderr {
129 let (w, r) = tokio::io::duplex(ap.max_stderr_buf_size.unwrap_or(MAX_BUF_SIZE));
130 (Some(w), Some(r))
131 } else {
132 (None, None)
133 };
134 let (status_tx, status_rx) = oneshot::channel();
135 let (terminal_resize_tx, terminal_resize_rx) = if ap.tty {
136 let (w, r) = mpsc::channel(10);
137 (Some(w), Some(r))
138 } else {
139 (None, None)
140 };
141
142 let task = tokio::spawn(start_message_loop(
143 stream,
144 stdin_reader,
145 stdout_writer,
146 stderr_writer,
147 status_tx,
148 terminal_resize_rx,
149 ));
150
151 AttachedProcess {
152 has_stdin: ap.stdin,
153 has_stdout: ap.stdout,
154 has_stderr: ap.stderr,
155 task,
156 stdin_writer: Some(stdin_writer),
157 stdout_reader,
158 stderr_reader,
159 terminal_resize_tx,
160 status_rx: Some(status_rx),
161 }
162 }
163
164 pub fn stdin(&mut self) -> Option<impl AsyncWrite + Unpin> {
177 if !self.has_stdin {
178 return None;
179 }
180 self.stdin_writer.take()
181 }
182
183 pub fn stdout(&mut self) -> Option<impl AsyncRead + Unpin> {
197 if !self.has_stdout {
198 return None;
199 }
200 self.stdout_reader.take()
201 }
202
203 pub fn stderr(&mut self) -> Option<impl AsyncRead + Unpin> {
217 if !self.has_stderr {
218 return None;
219 }
220 self.stderr_reader.take()
221 }
222
223 #[inline]
225 pub fn abort(&self) {
226 self.task.abort();
227 }
228
229 pub async fn join(self) -> Result<(), Error> {
231 self.task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
232 }
233
234 pub fn take_status(&mut self) -> Option<impl Future<Output = Option<Status>>> {
238 self.status_rx.take().map(|recv| recv.map(|res| res.ok()))
239 }
240
241 pub fn terminal_size(&mut self) -> Option<TerminalSizeSender> {
258 self.terminal_resize_tx.take()
259 }
260}
261
// Channel identifiers: the first byte of every binary WebSocket frame names the stream
// the rest of the frame belongs to.

/// Prefix byte for frames carrying stdin data to the server.
const STDIN_CHANNEL: u8 = 0;
/// Prefix byte of frames carrying stdout data from the server.
const STDOUT_CHANNEL: u8 = 1;
/// Prefix byte of frames carrying stderr data from the server.
const STDERR_CHANNEL: u8 = 2;
/// Prefix byte of frames carrying the JSON-encoded status object.
const STATUS_CHANNEL: u8 = 3;
/// Prefix byte for frames carrying a JSON-encoded terminal resize request.
const RESIZE_CHANNEL: u8 = 4;
270
271async fn start_message_loop<S>(
272 stream: WebSocketStream<S>,
273 stdin: impl AsyncRead + Unpin,
274 mut stdout: Option<impl AsyncWrite + Unpin>,
275 mut stderr: Option<impl AsyncWrite + Unpin>,
276 status_tx: StatusSender,
277 mut terminal_size_rx: Option<TerminalSizeReceiver>,
278) -> Result<(), Error>
279where
280 S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
281{
282 let mut stdin_stream = tokio_util::io::ReaderStream::new(stdin);
283 let (mut server_send, raw_server_recv) = stream.split();
284 let mut server_recv = raw_server_recv.filter_map(filter_message).boxed();
286 let mut have_terminal_size_rx = terminal_size_rx.is_some();
287
288 loop {
289 let terminal_size_next = async {
290 match terminal_size_rx.as_mut() {
291 Some(tmp) => Some(tmp.next().await),
292 None => None,
293 }
294 };
295 select! {
296 server_message = server_recv.next() => {
297 match server_message {
298 Some(Ok(Message::Stdout(bin))) => {
299 if let Some(stdout) = stdout.as_mut() {
300 stdout.write_all(&bin[1..]).await.map_err(Error::WriteStdout)?;
301 }
302 },
303 Some(Ok(Message::Stderr(bin))) => {
304 if let Some(stderr) = stderr.as_mut() {
305 stderr.write_all(&bin[1..]).await.map_err(Error::WriteStderr)?;
306 }
307 },
308 Some(Ok(Message::Status(bin))) => {
309 let status = serde_json::from_slice::<Status>(&bin[1..]).map_err(Error::DeserializeStatus)?;
310 status_tx.send(status).map_err(|_| Error::SendStatus)?;
311 break
312 },
313 Some(Err(err)) => {
314 return Err(Error::ReceiveWebSocketMessage(err));
315 },
316 None => {
317 break
319 },
320 }
321 },
322 stdin_message = stdin_stream.next() => {
323 match stdin_message {
324 Some(Ok(bytes)) => {
325 if !bytes.is_empty() {
326 let mut vec = Vec::with_capacity(bytes.len() + 1);
327 vec.push(STDIN_CHANNEL);
328 vec.extend_from_slice(&bytes[..]);
329 server_send
330 .send(ws::Message::binary(vec))
331 .await
332 .map_err(Error::SendStdin)?;
333 }
334 },
335 Some(Err(err)) => {
336 return Err(Error::ReadStdin(err));
337 }
338 None => {
339 server_send.close().await.map_err(Error::SendClose)?;
342 break;
343 }
344 }
345 },
346 Some(terminal_size_message) = terminal_size_next, if have_terminal_size_rx => {
347 match terminal_size_message {
348 Some(new_size) => {
349 let new_size = serde_json::to_vec(&new_size).map_err(Error::SerializeTerminalSize)?;
350 let mut vec = Vec::with_capacity(new_size.len() + 1);
351 vec.push(RESIZE_CHANNEL);
352 vec.extend_from_slice(&new_size[..]);
353 server_send.send(ws::Message::Binary(vec.into())).await.map_err(Error::SendTerminalSize)?;
354 },
355 None => {
356 have_terminal_size_rx = false;
357 }
358 }
359 },
360 }
361 }
362
363 Ok(())
364}
365
/// A binary frame received from the server, classified by its channel byte.
///
/// Each variant holds the complete frame, *including* the leading channel byte;
/// consumers skip index 0 to reach the payload.
enum Message {
    /// Frame received on the stdout channel.
    Stdout(Vec<u8>),
    /// Frame received on the stderr channel.
    Stderr(Vec<u8>),
    /// Frame received on the status channel (JSON-encoded status object).
    Status(Vec<u8>),
}
375
376async fn filter_message(wsm: Result<ws::Message, ws::Error>) -> Option<Result<Message, ws::Error>> {
378 match wsm {
379 Ok(ws::Message::Binary(bin)) if bin.len() > 1 => match bin[0] {
382 STDOUT_CHANNEL => Some(Ok(Message::Stdout(bin.into()))),
383 STDERR_CHANNEL => Some(Ok(Message::Stderr(bin.into()))),
384 STATUS_CHANNEL => Some(Ok(Message::Status(bin.into()))),
385 _ => None,
387 },
388 Ok(_) => None,
392 Err(err) => Some(Err(err)),
395 }
396}