wasmtime_wasi/stdio/
worker_thread_stdin.rs

1//! Handling for standard in using a worker task.
2//!
3//! Standard input is a global singleton resource for the entire program which
4//! needs special care. Currently this implementation adheres to a few
5//! constraints which make this nontrivial to implement.
6//!
//! * Any number of guest wasm programs can read stdin. While this doesn't make
//!   much sense semantically, readers shouldn't block forever. Instead it's a
//!   race to see who actually reads which parts of stdin.
10//!
//! * Data from stdin isn't actually read unless requested. This is done to try
//!   to be a good neighbor to others running in the process. Under the
//!   assumption that most programs have one "thing" which reads stdin, the
//!   actual consumption of bytes is delayed until the wasm guest is dynamically
//!   chosen to be that "thing". Before that, data from stdin is not consumed, to
//!   avoid taking it from other components in the process.
17//!
18//! * Tokio's documentation indicates that "interactive stdin" is best done with
19//!   a helper thread to avoid blocking shutdown of the event loop. That's
20//!   respected here where all stdin reading happens on a blocking helper thread
21//!   that, at this time, is never shut down.
22//!
//! This module is likely to change over time as new systems are encountered
//! and preexisting bugs are discovered.
25
26use crate::stdio::StdinStream;
27use bytes::{Bytes, BytesMut};
28use std::io::{IsTerminal, Read};
29use std::mem;
30use std::sync::{Condvar, Mutex, OnceLock};
31use tokio::sync::Notify;
32use wasmtime_wasi_io::{
33    poll::Pollable,
34    streams::{InputStream, StreamError},
35};
36
37#[derive(Default)]
38struct GlobalStdin {
39    state: Mutex<StdinState>,
40    read_requested: Condvar,
41    read_completed: Notify,
42}
43
44#[derive(Default, Debug)]
45enum StdinState {
46    #[default]
47    ReadNotRequested,
48    ReadRequested,
49    Data(BytesMut),
50    Error(std::io::Error),
51    Closed,
52}
53
54impl GlobalStdin {
55    fn get() -> &'static GlobalStdin {
56        static STDIN: OnceLock<GlobalStdin> = OnceLock::new();
57        STDIN.get_or_init(|| create())
58    }
59}
60
61fn create() -> GlobalStdin {
62    std::thread::spawn(|| {
63        let state = GlobalStdin::get();
64        loop {
65            // Wait for a read to be requested, but don't hold the lock across
66            // the blocking read.
67            let mut lock = state.state.lock().unwrap();
68            lock = state
69                .read_requested
70                .wait_while(lock, |state| !matches!(state, StdinState::ReadRequested))
71                .unwrap();
72            drop(lock);
73
74            let mut bytes = BytesMut::zeroed(1024);
75            let (new_state, done) = match std::io::stdin().read(&mut bytes) {
76                Ok(0) => (StdinState::Closed, true),
77                Ok(nbytes) => {
78                    bytes.truncate(nbytes);
79                    (StdinState::Data(bytes), false)
80                }
81                Err(e) => (StdinState::Error(e), true),
82            };
83
84            // After the blocking read completes the state should not have been
85            // tampered with.
86            debug_assert!(matches!(
87                *state.state.lock().unwrap(),
88                StdinState::ReadRequested
89            ));
90            *state.state.lock().unwrap() = new_state;
91            state.read_completed.notify_waiters();
92            if done {
93                break;
94            }
95        }
96    });
97
98    GlobalStdin::default()
99}
100
/// Handle to the host's standard input.
///
/// The type carries no data of its own; its only public interface is the
/// [`InputStream`] impl.
#[derive(Clone)]
pub struct Stdin;

/// Creates a [`Stdin`] handle that reads from the host's standard input.
///
/// Pass the result to
/// [`WasiCtxBuilder::stdin`](crate::WasiCtxBuilder::stdin).
pub fn stdin() -> Stdin {
    Stdin {}
}
112
113impl StdinStream for Stdin {
114    fn stream(&self) -> Box<dyn InputStream> {
115        Box::new(Stdin)
116    }
117
118    fn isatty(&self) -> bool {
119        std::io::stdin().is_terminal()
120    }
121}
122
123#[async_trait::async_trait]
124impl InputStream for Stdin {
125    fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
126        let g = GlobalStdin::get();
127        let mut locked = g.state.lock().unwrap();
128        match mem::replace(&mut *locked, StdinState::ReadRequested) {
129            StdinState::ReadNotRequested => {
130                g.read_requested.notify_one();
131                Ok(Bytes::new())
132            }
133            StdinState::ReadRequested => Ok(Bytes::new()),
134            StdinState::Data(mut data) => {
135                let size = data.len().min(size);
136                let bytes = data.split_to(size);
137                *locked = if data.is_empty() {
138                    StdinState::ReadNotRequested
139                } else {
140                    StdinState::Data(data)
141                };
142                Ok(bytes.freeze())
143            }
144            StdinState::Error(e) => {
145                *locked = StdinState::Closed;
146                Err(StreamError::LastOperationFailed(e.into()))
147            }
148            StdinState::Closed => {
149                *locked = StdinState::Closed;
150                Err(StreamError::Closed)
151            }
152        }
153    }
154}
155
156#[async_trait::async_trait]
157impl Pollable for Stdin {
158    async fn ready(&mut self) {
159        let g = GlobalStdin::get();
160
161        // Scope the synchronous `state.lock()` to this block which does not
162        // `.await` inside of it.
163        let notified = {
164            let mut locked = g.state.lock().unwrap();
165            match *locked {
166                // If a read isn't requested yet
167                StdinState::ReadNotRequested => {
168                    g.read_requested.notify_one();
169                    *locked = StdinState::ReadRequested;
170                    g.read_completed.notified()
171                }
172                StdinState::ReadRequested => g.read_completed.notified(),
173                StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return,
174            }
175        };
176
177        notified.await;
178    }
179}