wasmtime_wasi/
write_stream.rs1use crate::{OutputStream, Pollable, StreamError};
2use anyhow::anyhow;
3use bytes::Bytes;
4use std::sync::{Arc, Mutex};
5
6#[derive(Debug)]
7struct WorkerState {
8 alive: bool,
9 items: std::collections::VecDeque<Bytes>,
10 write_budget: usize,
11 flush_pending: bool,
12 error: Option<anyhow::Error>,
13}
14
15impl WorkerState {
16 fn check_error(&mut self) -> Result<(), StreamError> {
17 if let Some(e) = self.error.take() {
18 return Err(StreamError::LastOperationFailed(e));
19 }
20 if !self.alive {
21 return Err(StreamError::Closed);
22 }
23 Ok(())
24 }
25}
26
27struct Worker {
28 state: Mutex<WorkerState>,
29 new_work: tokio::sync::Notify,
30 write_ready_changed: tokio::sync::Notify,
31}
32
33enum Job {
34 Flush,
35 Write(Bytes),
36}
37
38impl Worker {
39 fn new(write_budget: usize) -> Self {
40 Self {
41 state: Mutex::new(WorkerState {
42 alive: true,
43 items: std::collections::VecDeque::new(),
44 write_budget,
45 flush_pending: false,
46 error: None,
47 }),
48 new_work: tokio::sync::Notify::new(),
49 write_ready_changed: tokio::sync::Notify::new(),
50 }
51 }
52 async fn ready(&self) {
53 loop {
54 {
55 let state = self.state();
56 if state.error.is_some()
57 || !state.alive
58 || (!state.flush_pending && state.write_budget > 0)
59 {
60 return;
61 }
62 }
63 self.write_ready_changed.notified().await;
64 }
65 }
66 fn check_write(&self) -> Result<usize, StreamError> {
67 let mut state = self.state();
68 if let Err(e) = state.check_error() {
69 return Err(e);
70 }
71
72 if state.flush_pending || state.write_budget == 0 {
73 return Ok(0);
74 }
75
76 Ok(state.write_budget)
77 }
78 fn state(&self) -> std::sync::MutexGuard<WorkerState> {
79 self.state.lock().unwrap()
80 }
81 fn pop(&self) -> Option<Job> {
82 let mut state = self.state();
83 if state.items.is_empty() {
84 if state.flush_pending {
85 return Some(Job::Flush);
86 }
87 } else if let Some(bytes) = state.items.pop_front() {
88 return Some(Job::Write(bytes));
89 }
90
91 None
92 }
93 fn report_error(&self, e: std::io::Error) {
94 {
95 let mut state = self.state();
96 state.alive = false;
97 state.error = Some(e.into());
98 state.flush_pending = false;
99 }
100 self.write_ready_changed.notify_one();
101 }
102 async fn work<T: tokio::io::AsyncWrite + Send + Unpin + 'static>(&self, mut writer: T) {
103 use tokio::io::AsyncWriteExt;
104 loop {
105 while let Some(job) = self.pop() {
106 match job {
107 Job::Flush => {
108 if let Err(e) = writer.flush().await {
109 self.report_error(e);
110 return;
111 }
112
113 tracing::debug!("worker marking flush complete");
114 self.state().flush_pending = false;
115 }
116
117 Job::Write(mut bytes) => {
118 tracing::debug!("worker writing: {bytes:?}");
119 let len = bytes.len();
120 match writer.write_all_buf(&mut bytes).await {
121 Err(e) => {
122 self.report_error(e);
123 return;
124 }
125 Ok(_) => {
126 self.state().write_budget += len;
127 }
128 }
129 }
130 }
131
132 self.write_ready_changed.notify_one();
133 }
134 self.new_work.notified().await;
135 }
136 }
137}
138
139pub struct AsyncWriteStream {
141 worker: Arc<Worker>,
142 join_handle: Option<crate::runtime::AbortOnDropJoinHandle<()>>,
143}
144
145impl AsyncWriteStream {
146 pub fn new<T: tokio::io::AsyncWrite + Send + Unpin + 'static>(
149 write_budget: usize,
150 writer: T,
151 ) -> Self {
152 let worker = Arc::new(Worker::new(write_budget));
153
154 let w = Arc::clone(&worker);
155 let join_handle = crate::runtime::spawn(async move { w.work(writer).await });
156
157 AsyncWriteStream {
158 worker,
159 join_handle: Some(join_handle),
160 }
161 }
162}
163
164#[async_trait::async_trait]
165impl OutputStream for AsyncWriteStream {
166 fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
167 let mut state = self.worker.state();
168 state.check_error()?;
169 if state.flush_pending {
170 return Err(StreamError::Trap(anyhow!(
171 "write not permitted while flush pending"
172 )));
173 }
174 match state.write_budget.checked_sub(bytes.len()) {
175 Some(remaining_budget) => {
176 state.write_budget = remaining_budget;
177 state.items.push_back(bytes);
178 }
179 None => return Err(StreamError::Trap(anyhow!("write exceeded budget"))),
180 }
181 drop(state);
182 self.worker.new_work.notify_one();
183 Ok(())
184 }
185 fn flush(&mut self) -> Result<(), StreamError> {
186 let mut state = self.worker.state();
187 state.check_error()?;
188
189 state.flush_pending = true;
190 self.worker.new_work.notify_one();
191
192 Ok(())
193 }
194
195 fn check_write(&mut self) -> Result<usize, StreamError> {
196 self.worker.check_write()
197 }
198
199 async fn cancel(&mut self) {
200 match self.join_handle.take() {
201 Some(task) => _ = task.cancel().await,
202 None => {}
203 }
204 }
205}
206#[async_trait::async_trait]
207impl Pollable for AsyncWriteStream {
208 async fn ready(&mut self) {
209 self.worker.ready().await;
210 }
211}