1use std::future::Future;
2
3use k8s_openapi::apimachinery::pkg::apis::meta::v1::Status;
4
5use futures::{
6 channel::{mpsc, oneshot},
7 FutureExt, SinkExt, StreamExt,
8};
9use serde::{Deserialize, Serialize};
10use thiserror::Error;
11use tokio::{
12 io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream},
13 select,
14};
15use tokio_tungstenite::tungstenite as ws;
16
17use crate::client::Connection;
18
19use super::AttachParams;
20
21type StatusReceiver = oneshot::Receiver<Status>;
22type StatusSender = oneshot::Sender<Status>;
23
24type TerminalSizeReceiver = mpsc::Receiver<TerminalSize>;
25type TerminalSizeSender = mpsc::Sender<TerminalSize>;
26
27#[derive(Debug, Serialize, Deserialize)]
29#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
30#[serde(rename_all = "PascalCase")]
31pub struct TerminalSize {
32 pub width: u16,
34 pub height: u16,
36}
37
38#[derive(Debug, Error)]
40pub enum Error {
41 #[error("failed to read from stdin: {0}")]
43 ReadStdin(#[source] std::io::Error),
44
45 #[error("failed to send a stdin data: {0}")]
47 SendStdin(#[source] ws::Error),
48
49 #[error("failed to write to stdout: {0}")]
51 WriteStdout(#[source] std::io::Error),
52
53 #[error("failed to write to stderr: {0}")]
55 WriteStderr(#[source] std::io::Error),
56
57 #[error("failed to receive a WebSocket message: {0}")]
59 ReceiveWebSocketMessage(#[source] ws::Error),
60
61 #[error("failed to complete the background task: {0}")]
63 Spawn(#[source] tokio::task::JoinError),
64
65 #[error("failed to send a WebSocket close message: {0}")]
67 SendClose(#[source] ws::Error),
68
69 #[error("failed to deserialize status object: {0}")]
71 DeserializeStatus(#[source] serde_json::Error),
72
73 #[error("failed to send status object")]
75 SendStatus,
76
77 #[error("failed to serialize TerminalSize object: {0}")]
79 SerializeTerminalSize(#[source] serde_json::Error),
80
81 #[error("failed to send terminal size message")]
83 SendTerminalSize(#[source] ws::Error),
84
85 #[error("failed to set terminal size, tty need to be true to resize the terminal")]
87 TtyNeedToBeTrue,
88}
89
/// Default capacity, in bytes, of each in-memory pipe when `AttachParams`
/// does not specify its own `max_*_buf_size`.
const MAX_BUF_SIZE: usize = 1 << 10;
91
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
/// A process attached to over a WebSocket connection.
///
/// Created by `AttachedProcess::new`, which spawns a background task that shuttles
/// data between the connection and in-memory pipes. Use `stdin`, `stdout` and
/// `stderr` to take the caller-side pipe ends, `take_status` to await the final
/// `Status`, `terminal_size` to send resize requests (TTY sessions only), and
/// `join` to wait for the background task to finish.
pub struct AttachedProcess {
    // Whether each stream was requested in the `AttachParams`; gates the accessors.
    has_stdin: bool,
    has_stdout: bool,
    has_stderr: bool,
    // Caller-side pipe ends; `None` once taken (or, for output, when not requested).
    stdin_writer: Option<DuplexStream>,
    stdout_reader: Option<DuplexStream>,
    stderr_reader: Option<DuplexStream>,
    // Receives the `Status` reported by the server; `None` once taken.
    status_rx: Option<StatusReceiver>,
    // Present only when `tty` was requested; `None` once taken.
    terminal_resize_tx: Option<TerminalSizeSender>,
    // Background task running the WebSocket message loop.
    task: tokio::task::JoinHandle<Result<(), Error>>,
}
112
113impl AttachedProcess {
114 pub(crate) fn new(connection: Connection, ap: &AttachParams) -> Self {
115 let (stdin_writer, stdin_reader) = tokio::io::duplex(ap.max_stdin_buf_size.unwrap_or(MAX_BUF_SIZE));
118 let (stdout_writer, stdout_reader) = if ap.stdout {
119 let (w, r) = tokio::io::duplex(ap.max_stdout_buf_size.unwrap_or(MAX_BUF_SIZE));
120 (Some(w), Some(r))
121 } else {
122 (None, None)
123 };
124 let (stderr_writer, stderr_reader) = if ap.stderr {
125 let (w, r) = tokio::io::duplex(ap.max_stderr_buf_size.unwrap_or(MAX_BUF_SIZE));
126 (Some(w), Some(r))
127 } else {
128 (None, None)
129 };
130 let (status_tx, status_rx) = oneshot::channel();
131 let (terminal_resize_tx, terminal_resize_rx) = if ap.tty {
132 let (w, r) = mpsc::channel(10);
133 (Some(w), Some(r))
134 } else {
135 (None, None)
136 };
137
138 let task = tokio::spawn(start_message_loop(
139 connection,
140 stdin_reader,
141 stdout_writer,
142 stderr_writer,
143 status_tx,
144 terminal_resize_rx,
145 ));
146
147 AttachedProcess {
148 has_stdin: ap.stdin,
149 has_stdout: ap.stdout,
150 has_stderr: ap.stderr,
151 task,
152 stdin_writer: Some(stdin_writer),
153 stdout_reader,
154 stderr_reader,
155 terminal_resize_tx,
156 status_rx: Some(status_rx),
157 }
158 }
159
160 pub fn stdin(&mut self) -> Option<impl AsyncWrite + Unpin> {
173 if !self.has_stdin {
174 return None;
175 }
176 self.stdin_writer.take()
177 }
178
179 pub fn stdout(&mut self) -> Option<impl AsyncRead + Unpin> {
193 if !self.has_stdout {
194 return None;
195 }
196 self.stdout_reader.take()
197 }
198
199 pub fn stderr(&mut self) -> Option<impl AsyncRead + Unpin> {
213 if !self.has_stderr {
214 return None;
215 }
216 self.stderr_reader.take()
217 }
218
219 #[inline]
221 pub fn abort(&self) {
222 self.task.abort();
223 }
224
225 pub async fn join(self) -> Result<(), Error> {
227 self.task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
228 }
229
230 pub fn take_status(&mut self) -> Option<impl Future<Output = Option<Status>>> {
234 self.status_rx.take().map(|recv| recv.map(|res| res.ok()))
235 }
236
237 pub fn terminal_size(&mut self) -> Option<TerminalSizeSender> {
254 self.terminal_resize_tx.take()
255 }
256}
257
// Every binary WebSocket frame starts with one of these channel bytes; the rest
// of the frame is that channel's payload.

/// Outgoing stdin data.
const STDIN_CHANNEL: u8 = 0x00;
/// Incoming stdout data.
const STDOUT_CHANNEL: u8 = 0x01;
/// Incoming stderr data.
const STDERR_CHANNEL: u8 = 0x02;
/// Incoming JSON-encoded `Status` object.
const STATUS_CHANNEL: u8 = 0x03;
/// Outgoing JSON-encoded `TerminalSize`.
const RESIZE_CHANNEL: u8 = 0x04;
/// Outgoing close request; the following byte names the channel to close.
const CLOSE_CHANNEL: u8 = u8::MAX;
268
269async fn start_message_loop(
270 connection: Connection,
271 stdin: impl AsyncRead + Unpin,
272 mut stdout: Option<impl AsyncWrite + Unpin>,
273 mut stderr: Option<impl AsyncWrite + Unpin>,
274 status_tx: StatusSender,
275 mut terminal_size_rx: Option<TerminalSizeReceiver>,
276) -> Result<(), Error> {
277 let supports_stream_close = connection.supports_stream_close();
278 let stream = connection.into_stream();
279 let mut stdin_stream = tokio_util::io::ReaderStream::new(stdin);
280 let (mut server_send, raw_server_recv) = stream.split();
281 let mut server_recv = raw_server_recv.filter_map(filter_message).boxed();
283 let mut have_terminal_size_rx = terminal_size_rx.is_some();
284
285 let mut stdin_is_open = true;
287
288 loop {
289 let terminal_size_next = async {
290 match terminal_size_rx.as_mut() {
291 Some(tmp) => Some(tmp.next().await),
292 None => None,
293 }
294 };
295 select! {
296 server_message = server_recv.next() => {
297 match server_message {
298 Some(Ok(Message::Stdout(bin))) => {
299 if let Some(stdout) = stdout.as_mut() {
300 stdout.write_all(&bin[1..]).await.map_err(Error::WriteStdout)?;
301 }
302 },
303 Some(Ok(Message::Stderr(bin))) => {
304 if let Some(stderr) = stderr.as_mut() {
305 stderr.write_all(&bin[1..]).await.map_err(Error::WriteStderr)?;
306 }
307 },
308 Some(Ok(Message::Status(bin))) => {
309 let status = serde_json::from_slice::<Status>(&bin[1..]).map_err(Error::DeserializeStatus)?;
310 status_tx.send(status).map_err(|_| Error::SendStatus)?;
311 break
312 },
313 Some(Err(err)) => {
314 return Err(Error::ReceiveWebSocketMessage(err));
315 },
316 None => {
317 break
319 },
320 }
321 },
322 stdin_message = stdin_stream.next(), if stdin_is_open => {
323 match stdin_message {
324 Some(Ok(bytes)) => {
325 if !bytes.is_empty() {
326 let mut vec = Vec::with_capacity(bytes.len() + 1);
327 vec.push(STDIN_CHANNEL);
328 vec.extend_from_slice(&bytes[..]);
329 server_send
330 .send(ws::Message::binary(vec))
331 .await
332 .map_err(Error::SendStdin)?;
333 }
334 },
335 Some(Err(err)) => {
336 return Err(Error::ReadStdin(err));
337 }
338 None => {
339 if supports_stream_close {
342 let vec = vec![CLOSE_CHANNEL, STDIN_CHANNEL];
345 server_send
346 .send(ws::Message::binary(vec))
347 .await
348 .map_err(Error::SendStdin)?;
349 } else {
350 server_send.close().await.map_err(Error::SendClose)?;
354 }
355
356 stdin_is_open = false;
358 }
359 }
360 },
361 Some(terminal_size_message) = terminal_size_next, if have_terminal_size_rx => {
362 match terminal_size_message {
363 Some(new_size) => {
364 let new_size = serde_json::to_vec(&new_size).map_err(Error::SerializeTerminalSize)?;
365 let mut vec = Vec::with_capacity(new_size.len() + 1);
366 vec.push(RESIZE_CHANNEL);
367 vec.extend_from_slice(&new_size[..]);
368 server_send.send(ws::Message::Binary(vec.into())).await.map_err(Error::SendTerminalSize)?;
369 },
370 None => {
371 have_terminal_size_rx = false;
372 }
373 }
374 },
375 }
376 }
377
378 Ok(())
379}
380
/// A binary frame from the server, classified by its leading channel byte.
///
/// Each payload still includes that leading channel byte; consumers skip it with `[1..]`.
enum Message {
    /// Frame received on the stdout channel.
    Stdout(Vec<u8>),
    /// Frame received on the stderr channel.
    Stderr(Vec<u8>),
    /// Frame received on the status channel; after the channel byte it holds a JSON `Status`.
    Status(Vec<u8>),
}
390
391async fn filter_message(wsm: Result<ws::Message, ws::Error>) -> Option<Result<Message, ws::Error>> {
393 match wsm {
394 Ok(ws::Message::Binary(bin)) if bin.len() > 1 => match bin[0] {
397 STDOUT_CHANNEL => Some(Ok(Message::Stdout(bin.into()))),
398 STDERR_CHANNEL => Some(Ok(Message::Stderr(bin.into()))),
399 STATUS_CHANNEL => Some(Ok(Message::Status(bin.into()))),
400 _ => None,
402 },
403 Ok(_) => None,
407 Err(err) => Some(Err(err)),
410 }
411}