// wasmtime_wasi/write_stream.rs

1use crate::{OutputStream, Pollable, StreamError};
2use anyhow::anyhow;
3use bytes::Bytes;
4use std::sync::{Arc, Mutex};
5
6#[derive(Debug)]
7struct WorkerState {
8    alive: bool,
9    items: std::collections::VecDeque<Bytes>,
10    write_budget: usize,
11    flush_pending: bool,
12    error: Option<anyhow::Error>,
13}
14
15impl WorkerState {
16    fn check_error(&mut self) -> Result<(), StreamError> {
17        if let Some(e) = self.error.take() {
18            return Err(StreamError::LastOperationFailed(e));
19        }
20        if !self.alive {
21            return Err(StreamError::Closed);
22        }
23        Ok(())
24    }
25}
26
27struct Worker {
28    state: Mutex<WorkerState>,
29    new_work: tokio::sync::Notify,
30    write_ready_changed: tokio::sync::Notify,
31}
32
33enum Job {
34    Flush,
35    Write(Bytes),
36}
37
38impl Worker {
39    fn new(write_budget: usize) -> Self {
40        Self {
41            state: Mutex::new(WorkerState {
42                alive: true,
43                items: std::collections::VecDeque::new(),
44                write_budget,
45                flush_pending: false,
46                error: None,
47            }),
48            new_work: tokio::sync::Notify::new(),
49            write_ready_changed: tokio::sync::Notify::new(),
50        }
51    }
52    async fn ready(&self) {
53        loop {
54            {
55                let state = self.state();
56                if state.error.is_some()
57                    || !state.alive
58                    || (!state.flush_pending && state.write_budget > 0)
59                {
60                    return;
61                }
62            }
63            self.write_ready_changed.notified().await;
64        }
65    }
66    fn check_write(&self) -> Result<usize, StreamError> {
67        let mut state = self.state();
68        if let Err(e) = state.check_error() {
69            return Err(e);
70        }
71
72        if state.flush_pending || state.write_budget == 0 {
73            return Ok(0);
74        }
75
76        Ok(state.write_budget)
77    }
78    fn state(&self) -> std::sync::MutexGuard<WorkerState> {
79        self.state.lock().unwrap()
80    }
81    fn pop(&self) -> Option<Job> {
82        let mut state = self.state();
83        if state.items.is_empty() {
84            if state.flush_pending {
85                return Some(Job::Flush);
86            }
87        } else if let Some(bytes) = state.items.pop_front() {
88            return Some(Job::Write(bytes));
89        }
90
91        None
92    }
93    fn report_error(&self, e: std::io::Error) {
94        {
95            let mut state = self.state();
96            state.alive = false;
97            state.error = Some(e.into());
98            state.flush_pending = false;
99        }
100        self.write_ready_changed.notify_one();
101    }
102    async fn work<T: tokio::io::AsyncWrite + Send + Unpin + 'static>(&self, mut writer: T) {
103        use tokio::io::AsyncWriteExt;
104        loop {
105            while let Some(job) = self.pop() {
106                match job {
107                    Job::Flush => {
108                        if let Err(e) = writer.flush().await {
109                            self.report_error(e);
110                            return;
111                        }
112
113                        tracing::debug!("worker marking flush complete");
114                        self.state().flush_pending = false;
115                    }
116
117                    Job::Write(mut bytes) => {
118                        tracing::debug!("worker writing: {bytes:?}");
119                        let len = bytes.len();
120                        match writer.write_all_buf(&mut bytes).await {
121                            Err(e) => {
122                                self.report_error(e);
123                                return;
124                            }
125                            Ok(_) => {
126                                self.state().write_budget += len;
127                            }
128                        }
129                    }
130                }
131
132                self.write_ready_changed.notify_one();
133            }
134            self.new_work.notified().await;
135        }
136    }
137}
138
139/// Provides a [`OutputStream`] impl from a [`tokio::io::AsyncWrite`] impl
140pub struct AsyncWriteStream {
141    worker: Arc<Worker>,
142    join_handle: Option<crate::runtime::AbortOnDropJoinHandle<()>>,
143}
144
145impl AsyncWriteStream {
146    /// Create a [`AsyncWriteStream`]. In order to use the [`OutputStream`] impl
147    /// provided by this struct, the argument must impl [`tokio::io::AsyncWrite`].
148    pub fn new<T: tokio::io::AsyncWrite + Send + Unpin + 'static>(
149        write_budget: usize,
150        writer: T,
151    ) -> Self {
152        let worker = Arc::new(Worker::new(write_budget));
153
154        let w = Arc::clone(&worker);
155        let join_handle = crate::runtime::spawn(async move { w.work(writer).await });
156
157        AsyncWriteStream {
158            worker,
159            join_handle: Some(join_handle),
160        }
161    }
162}
163
164#[async_trait::async_trait]
165impl OutputStream for AsyncWriteStream {
166    fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
167        let mut state = self.worker.state();
168        state.check_error()?;
169        if state.flush_pending {
170            return Err(StreamError::Trap(anyhow!(
171                "write not permitted while flush pending"
172            )));
173        }
174        match state.write_budget.checked_sub(bytes.len()) {
175            Some(remaining_budget) => {
176                state.write_budget = remaining_budget;
177                state.items.push_back(bytes);
178            }
179            None => return Err(StreamError::Trap(anyhow!("write exceeded budget"))),
180        }
181        drop(state);
182        self.worker.new_work.notify_one();
183        Ok(())
184    }
185    fn flush(&mut self) -> Result<(), StreamError> {
186        let mut state = self.worker.state();
187        state.check_error()?;
188
189        state.flush_pending = true;
190        self.worker.new_work.notify_one();
191
192        Ok(())
193    }
194
195    fn check_write(&mut self) -> Result<usize, StreamError> {
196        self.worker.check_write()
197    }
198
199    async fn cancel(&mut self) {
200        match self.join_handle.take() {
201            Some(task) => _ = task.cancel().await,
202            None => {}
203        }
204    }
205}
206#[async_trait::async_trait]
207impl Pollable for AsyncWriteStream {
208    async fn ready(&mut self) {
209        self.worker.ready().await;
210    }
211}