1use anyhow::anyhow;
11use bytes::Bytes;
12use std::sync::{Arc, Mutex};
13use tokio::sync::mpsc;
14use wasmtime_wasi_io::{
15 poll::Pollable,
16 streams::{InputStream, OutputStream, StreamError},
17};
18
19pub use crate::write_stream::AsyncWriteStream;
20
21#[derive(Debug, Clone)]
22pub struct MemoryInputPipe {
23 buffer: Arc<Mutex<Bytes>>,
24}
25
26impl MemoryInputPipe {
27 pub fn new(bytes: impl Into<Bytes>) -> Self {
28 Self {
29 buffer: Arc::new(Mutex::new(bytes.into())),
30 }
31 }
32
33 pub fn is_empty(&self) -> bool {
34 self.buffer.lock().unwrap().is_empty()
35 }
36}
37
38#[async_trait::async_trait]
39impl InputStream for MemoryInputPipe {
40 fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
41 let mut buffer = self.buffer.lock().unwrap();
42 if buffer.is_empty() {
43 return Err(StreamError::Closed);
44 }
45
46 let size = size.min(buffer.len());
47 let read = buffer.split_to(size);
48 Ok(read)
49 }
50}
51
52#[async_trait::async_trait]
53impl Pollable for MemoryInputPipe {
54 async fn ready(&mut self) {}
55}
56
57#[derive(Debug, Clone)]
58pub struct MemoryOutputPipe {
59 capacity: usize,
60 buffer: Arc<Mutex<bytes::BytesMut>>,
61}
62
63impl MemoryOutputPipe {
64 pub fn new(capacity: usize) -> Self {
65 MemoryOutputPipe {
66 capacity,
67 buffer: std::sync::Arc::new(std::sync::Mutex::new(bytes::BytesMut::new())),
68 }
69 }
70
71 pub fn contents(&self) -> bytes::Bytes {
72 self.buffer.lock().unwrap().clone().freeze()
73 }
74
75 pub fn try_into_inner(self) -> Option<bytes::BytesMut> {
76 std::sync::Arc::into_inner(self.buffer).map(|m| m.into_inner().unwrap())
77 }
78}
79
80#[async_trait::async_trait]
81impl OutputStream for MemoryOutputPipe {
82 fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
83 let mut buf = self.buffer.lock().unwrap();
84 if bytes.len() > self.capacity - buf.len() {
85 return Err(StreamError::Trap(anyhow!(
86 "write beyond capacity of MemoryOutputPipe"
87 )));
88 }
89 buf.extend_from_slice(bytes.as_ref());
90 Ok(())
92 }
93 fn flush(&mut self) -> Result<(), StreamError> {
94 Ok(())
96 }
97 fn check_write(&mut self) -> Result<usize, StreamError> {
98 let consumed = self.buffer.lock().unwrap().len();
99 if consumed < self.capacity {
100 Ok(self.capacity - consumed)
101 } else {
102 Err(StreamError::Closed)
104 }
105 }
106}
107
108#[async_trait::async_trait]
109impl Pollable for MemoryOutputPipe {
110 async fn ready(&mut self) {}
111}
112
113pub struct AsyncReadStream {
115 closed: bool,
116 buffer: Option<Result<Bytes, StreamError>>,
117 receiver: mpsc::Receiver<Result<Bytes, StreamError>>,
118 join_handle: Option<crate::runtime::AbortOnDropJoinHandle<()>>,
119}
120
121impl AsyncReadStream {
122 pub fn new<T: tokio::io::AsyncRead + Send + Unpin + 'static>(mut reader: T) -> Self {
125 let (sender, receiver) = mpsc::channel(1);
126 let join_handle = crate::runtime::spawn(async move {
127 loop {
128 use tokio::io::AsyncReadExt;
129 let mut buf = bytes::BytesMut::with_capacity(4096);
130 let sent = match reader.read_buf(&mut buf).await {
131 Ok(nbytes) if nbytes == 0 => sender.send(Err(StreamError::Closed)).await,
132 Ok(_) => sender.send(Ok(buf.freeze())).await,
133 Err(e) => {
134 sender
135 .send(Err(StreamError::LastOperationFailed(e.into())))
136 .await
137 }
138 };
139 if sent.is_err() {
140 break;
142 }
143 }
144 });
145 AsyncReadStream {
146 closed: false,
147 buffer: None,
148 receiver,
149 join_handle: Some(join_handle),
150 }
151 }
152}
153
154#[async_trait::async_trait]
155impl InputStream for AsyncReadStream {
156 fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
157 use mpsc::error::TryRecvError;
158
159 match self.buffer.take() {
160 Some(Ok(mut bytes)) => {
161 let len = bytes.len().min(size);
163 let rest = bytes.split_off(len);
164 if !rest.is_empty() {
165 self.buffer = Some(Ok(rest));
166 }
167 return Ok(bytes);
168 }
169 Some(Err(e)) => {
170 self.closed = true;
171 return Err(e);
172 }
173 None => {}
174 }
175
176 match self.receiver.try_recv() {
177 Ok(Ok(mut bytes)) => {
178 let len = bytes.len().min(size);
179 let rest = bytes.split_off(len);
180 if !rest.is_empty() {
181 self.buffer = Some(Ok(rest));
182 }
183
184 Ok(bytes)
185 }
186 Ok(Err(e)) => {
187 self.closed = true;
188 Err(e)
189 }
190 Err(TryRecvError::Empty) => Ok(Bytes::new()),
191 Err(TryRecvError::Disconnected) => Err(StreamError::Trap(anyhow!(
192 "AsyncReadStream sender died - should be impossible"
193 ))),
194 }
195 }
196
197 async fn cancel(&mut self) {
198 match self.join_handle.take() {
199 Some(task) => _ = task.cancel().await,
200 None => {}
201 }
202 }
203}
204#[async_trait::async_trait]
205impl Pollable for AsyncReadStream {
206 async fn ready(&mut self) {
207 if self.buffer.is_some() || self.closed {
208 return;
209 }
210 match self.receiver.recv().await {
211 Some(res) => self.buffer = Some(res),
212 None => {
213 panic!("no more sender for an open AsyncReadStream - should be impossible")
214 }
215 }
216 }
217}
218
/// An output stream that accepts every write and discards the data.
#[derive(Copy, Clone)]
pub struct SinkOutputStream;
222
223#[async_trait::async_trait]
224impl OutputStream for SinkOutputStream {
225 fn write(&mut self, _buf: Bytes) -> Result<(), StreamError> {
226 Ok(())
227 }
228 fn flush(&mut self) -> Result<(), StreamError> {
229 Ok(())
231 }
232
233 fn check_write(&mut self) -> Result<usize, StreamError> {
234 Ok(usize::MAX)
236 }
237}
238
239#[async_trait::async_trait]
240impl Pollable for SinkOutputStream {
241 async fn ready(&mut self) {}
242}
243
/// An input stream whose reads always fail with `StreamError::Closed`.
#[derive(Copy, Clone)]
pub struct ClosedInputStream;
247
248#[async_trait::async_trait]
249impl InputStream for ClosedInputStream {
250 fn read(&mut self, _size: usize) -> Result<Bytes, StreamError> {
251 Err(StreamError::Closed)
252 }
253}
254
255#[async_trait::async_trait]
256impl Pollable for ClosedInputStream {
257 async fn ready(&mut self) {}
258}
259
/// An output stream whose operations always fail with `StreamError::Closed`.
#[derive(Copy, Clone)]
pub struct ClosedOutputStream;
263
264#[async_trait::async_trait]
265impl OutputStream for ClosedOutputStream {
266 fn write(&mut self, _: Bytes) -> Result<(), StreamError> {
267 Err(StreamError::Closed)
268 }
269 fn flush(&mut self) -> Result<(), StreamError> {
270 Err(StreamError::Closed)
271 }
272
273 fn check_write(&mut self) -> Result<usize, StreamError> {
274 Err(StreamError::Closed)
275 }
276}
277
278#[async_trait::async_trait]
279impl Pollable for ClosedOutputStream {
280 async fn ready(&mut self) {}
281}
282
283#[cfg(test)]
284mod test {
285 use super::*;
286 use std::time::Duration;
287 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
288
289 #[cfg(not(target_arch = "x86_64"))]
291 const TEST_ITERATIONS: usize = 10;
292
293 #[cfg(target_arch = "x86_64")]
294 const TEST_ITERATIONS: usize = 100;
295
296 async fn resolves_immediately<F, O>(fut: F) -> O
297 where
298 F: futures::Future<Output = O>,
299 {
300 tokio::time::timeout(Duration::from_secs(2), fut)
304 .await
305 .expect("operation timed out")
306 }
307
308 async fn never_resolves<F: futures::Future>(fut: F) {
309 tokio::time::timeout(Duration::from_millis(10), fut)
313 .await
314 .err()
315 .expect("operation should time out");
316 }
317
318 pub fn simplex(size: usize) -> (impl AsyncRead, impl AsyncWrite) {
319 let (a, b) = tokio::io::duplex(size);
320 let (_read_half, write_half) = tokio::io::split(a);
321 let (read_half, _write_half) = tokio::io::split(b);
322 (read_half, write_half)
323 }
324
325 #[test_log::test(tokio::test(flavor = "multi_thread"))]
326 async fn empty_read_stream() {
327 let mut reader = AsyncReadStream::new(tokio::io::empty());
328
329 match reader.read(10) {
332 Err(StreamError::Closed) => {}
334
335 Ok(bs) => {
337 assert!(bs.is_empty());
338 resolves_immediately(reader.ready()).await;
339 assert!(matches!(reader.read(0), Err(StreamError::Closed)));
340 }
341 res => panic!("unexpected: {res:?}"),
342 }
343 }
344
345 #[test_log::test(tokio::test(flavor = "multi_thread"))]
346 async fn infinite_read_stream() {
347 let mut reader = AsyncReadStream::new(tokio::io::repeat(0));
348
349 let bs = reader.read(10).unwrap();
350 if bs.is_empty() {
351 resolves_immediately(reader.ready()).await;
353 let bs = reader.read(10).unwrap();
355 assert_eq!(bs.len(), 10);
356 } else {
357 assert_eq!(bs.len(), 10);
358 }
359
360 let bs = reader.read(10).unwrap();
362 assert_eq!(bs.len(), 10);
363
364 let bs = reader.read(0).unwrap();
366 assert_eq!(bs.len(), 0);
367 }
368
369 async fn finite_async_reader(contents: &[u8]) -> impl AsyncRead + Send + 'static + use<> {
370 let (r, mut w) = simplex(contents.len());
371 w.write_all(contents).await.unwrap();
372 r
373 }
374
375 #[test_log::test(tokio::test(flavor = "multi_thread"))]
376 async fn finite_read_stream() {
377 let mut reader = AsyncReadStream::new(finite_async_reader(&[1; 123]).await);
378
379 let bs = reader.read(123).unwrap();
380 if bs.is_empty() {
381 resolves_immediately(reader.ready()).await;
383 let bs = reader.read(123).unwrap();
385 assert_eq!(bs.len(), 123);
386 } else {
387 assert_eq!(bs.len(), 123);
388 }
389
390 match reader.read(0) {
393 Err(StreamError::Closed) => {} Ok(bs) => {
395 assert!(bs.is_empty());
396 resolves_immediately(reader.ready()).await;
398 assert!(matches!(reader.read(0), Err(StreamError::Closed)));
400 }
401 res => panic!("unexpected: {res:?}"),
402 }
403 }
404
405 #[test_log::test(tokio::test(flavor = "multi_thread"))]
406 async fn multiple_chunks_read_stream() {
409 let (r, mut w) = simplex(1024);
410 let mut reader = AsyncReadStream::new(r);
411
412 w.write_all(&[123]).await.unwrap();
413
414 let bs = reader.read(1).unwrap();
415 if bs.is_empty() {
416 resolves_immediately(reader.ready()).await;
418 let bs = reader.read(1).unwrap();
420 assert_eq!(*bs, [123u8]);
421 } else {
422 assert_eq!(*bs, [123u8]);
423 }
424
425 let bs = reader.read(1).unwrap();
427 assert!(bs.is_empty());
428
429 never_resolves(reader.ready()).await;
431
432 let bs = reader.read(1).unwrap();
434 assert!(bs.is_empty());
435
436 w.write_all(&[45]).await.unwrap();
438
439 resolves_immediately(reader.ready()).await;
442
443 let bs = reader.read(1).unwrap();
445 assert_eq!(*bs, [45u8]);
446
447 let bs = reader.read(1).unwrap();
449 assert!(bs.is_empty());
450
451 never_resolves(reader.ready()).await;
453
454 let bs = reader.read(1).unwrap();
456 assert!(bs.is_empty());
457
458 drop(w);
460
461 resolves_immediately(reader.ready()).await;
464
465 assert!(matches!(reader.read(1), Err(StreamError::Closed)));
467 }
468
469 #[test_log::test(tokio::test(flavor = "multi_thread"))]
470 async fn backpressure_read_stream() {
474 let (r, mut w) = simplex(16 * 1024); let mut reader = AsyncReadStream::new(r);
476
477 let writer_task = tokio::task::spawn(async move {
478 w.write_all(&[123; 8192]).await.unwrap();
480 w
481 });
482
483 resolves_immediately(reader.ready()).await;
484
485 let bs = reader.read(4097).unwrap();
488 assert_eq!(bs.len(), 4096);
489
490 resolves_immediately(reader.ready()).await;
492
493 let bs = reader.read(4097).unwrap();
496 assert_eq!(bs.len(), 4096);
497
498 let w = resolves_immediately(writer_task).await;
500
501 drop(w);
503
504 resolves_immediately(reader.ready()).await;
506
507 assert!(matches!(reader.read(4097), Err(StreamError::Closed)));
509 }
510
511 #[test_log::test(test_log::test(tokio::test(flavor = "multi_thread")))]
512 async fn sink_write_stream() {
513 let mut writer = AsyncWriteStream::new(2048, tokio::io::sink());
514 let chunk = Bytes::from_static(&[0; 1024]);
515
516 let readiness = resolves_immediately(writer.write_ready())
517 .await
518 .expect("write_ready does not trap");
519 assert_eq!(readiness, 2048);
520 writer.write(chunk.clone()).expect("write does not error");
522
523 let readiness = resolves_immediately(writer.write_ready())
525 .await
526 .expect("write_ready does not trap");
527 assert!(
528 readiness == 1024 || readiness == 2048,
529 "readiness should be 1024 or 2048, got {readiness}"
530 );
531
532 if readiness == 1024 {
533 writer.write(chunk.clone()).expect("write does not error");
534
535 let readiness = resolves_immediately(writer.write_ready())
536 .await
537 .expect("write_ready does not trap");
538 assert!(
539 readiness == 1024 || readiness == 2048,
540 "readiness should be 1024 or 2048, got {readiness}"
541 );
542 }
543 }
544
545 #[test_log::test(tokio::test(flavor = "multi_thread"))]
546 async fn closed_write_stream() {
547 for n in 0..TEST_ITERATIONS {
549 closed_write_stream_(n).await
550 }
551 }
552 #[tracing::instrument]
553 async fn closed_write_stream_(n: usize) {
554 let (reader, writer) = simplex(1);
555 let mut writer = AsyncWriteStream::new(1024, writer);
556
557 drop(reader);
559
560 let mut should_be_closed = false;
563
564 let chunk = Bytes::from_static(&[0; 1]);
566 writer
567 .write(chunk.clone())
568 .expect("first write should succeed");
569
570 let mut write_ready_res = None;
572 if n % 2 == 0 {
573 let r = resolves_immediately(writer.write_ready()).await;
574 match r {
576 Ok(1023) => {}
578 Err(StreamError::LastOperationFailed(_)) => {
580 tracing::debug!("discovered stream failure in first write_ready");
581 should_be_closed = true;
582 }
583 r => panic!("unexpected write_ready: {r:?}"),
584 }
585 write_ready_res = Some(r);
586 }
587
588 let flush_res = writer.flush();
591 match flush_res {
592 Err(StreamError::LastOperationFailed(_)) => {
594 tracing::debug!("discovered stream failure trying to flush");
595 assert!(!should_be_closed);
596 should_be_closed = true;
597 }
598 Err(StreamError::Closed) => {
600 assert!(
601 should_be_closed,
602 "expected a LastOperationFailed before we see Closed. {write_ready_res:?}"
603 );
604 }
605 Ok(()) => {}
607 Err(e) => panic!("unexpected flush error: {e:?} {write_ready_res:?}"),
608 }
609
610 match resolves_immediately(writer.write_ready()).await {
613 Err(StreamError::LastOperationFailed(_)) => {
615 tracing::debug!("discovered stream failure trying to flush");
616 assert!(!should_be_closed);
617 }
618 Err(StreamError::Closed) => {
620 assert!(should_be_closed);
621 }
622 r => {
623 panic!("stream should be reported closed by the end of write_ready after flush, got {r:?}. {write_ready_res:?} {flush_res:?}")
624 }
625 }
626 }
627
628 #[test_log::test(tokio::test(flavor = "multi_thread"))]
629 async fn multiple_chunks_write_stream() {
630 for n in 0..TEST_ITERATIONS {
632 multiple_chunks_write_stream_aux(n).await
633 }
634 }
635 #[tracing::instrument]
636 async fn multiple_chunks_write_stream_aux(_: usize) {
637 use std::ops::Deref;
638
639 let (mut reader, writer) = simplex(1024);
640 let mut writer = AsyncWriteStream::new(1024, writer);
641
642 let chunk = Bytes::from_static(&[123; 1]);
644
645 let permit = resolves_immediately(writer.write_ready())
646 .await
647 .expect("write should be ready");
648 assert_eq!(permit, 1024);
649
650 writer.write(chunk.clone()).expect("write does not trap");
651
652 let permit = resolves_immediately(writer.write_ready())
655 .await
656 .expect("write should be ready");
657 assert!(matches!(permit, 1023 | 1024));
658
659 let mut read_buf = vec![0; chunk.len()];
660 let read_len = reader.read_exact(&mut read_buf).await.unwrap();
661 assert_eq!(read_len, chunk.len());
662 assert_eq!(read_buf.as_slice(), chunk.deref());
663
664 let chunk2 = Bytes::from_static(&[45; 1]);
666
667 writer.flush().expect("channel is still alive");
669
670 let permit = resolves_immediately(writer.write_ready())
671 .await
672 .expect("write should be ready");
673 assert_eq!(permit, 1024);
674
675 writer.write(chunk2.clone()).expect("write does not trap");
676
677 let permit = resolves_immediately(writer.write_ready())
680 .await
681 .expect("write should be ready");
682 assert!(matches!(permit, 1023 | 1024));
683
684 let mut read2_buf = vec![0; chunk2.len()];
685 let read2_len = reader.read_exact(&mut read2_buf).await.unwrap();
686 assert_eq!(read2_len, chunk2.len());
687 assert_eq!(read2_buf.as_slice(), chunk2.deref());
688
689 writer.flush().expect("channel is still alive");
691
692 let permit = resolves_immediately(writer.write_ready())
693 .await
694 .expect("write should be ready");
695 assert_eq!(permit, 1024);
696 }
697
698 #[test_log::test(tokio::test(flavor = "multi_thread"))]
699 async fn backpressure_write_stream() {
700 for n in 0..TEST_ITERATIONS {
702 backpressure_write_stream_aux(n).await
703 }
704 }
705 #[tracing::instrument]
706 async fn backpressure_write_stream_aux(_: usize) {
707 use futures::future::poll_immediate;
708
709 let (mut reader, writer) = simplex(1024);
712 let mut writer = AsyncWriteStream::new(1024, writer);
713
714 let chunk = Bytes::from_static(&[0; 1024]);
715
716 let permit = resolves_immediately(writer.write_ready())
717 .await
718 .expect("write should be ready");
719 assert_eq!(permit, 1024);
720
721 writer.write(chunk.clone()).expect("write succeeds");
722
723 let permit = poll_immediate(writer.write_ready()).await;
726 assert!(matches!(permit, None | Some(Ok(1024))));
727
728 let permit = resolves_immediately(writer.write_ready())
731 .await
732 .expect("write should be ready");
733 assert_eq!(permit, 1024);
734
735 writer.write(chunk.clone()).expect("write does not trap");
738
739 writer
741 .write(chunk.clone())
742 .err()
743 .expect("unpermitted write does trap");
744
745 never_resolves(writer.write_ready()).await;
748
749 let mut buf = [0; 2048];
752 reader.read_exact(&mut buf).await.unwrap();
753
754 never_resolves(reader.read(&mut buf)).await;
756
757 let permit = resolves_immediately(writer.write_ready())
759 .await
760 .expect("ready is ok");
761 assert_eq!(permit, 1024);
762
763 writer.write(chunk.clone()).expect("write does not trap");
765 }
766
767 #[test_log::test(tokio::test(flavor = "multi_thread"))]
768 async fn backpressure_write_stream_with_flush() {
769 for n in 0..TEST_ITERATIONS {
770 backpressure_write_stream_with_flush_aux(n).await;
771 }
772 }
773
774 async fn backpressure_write_stream_with_flush_aux(_: usize) {
775 let (mut reader, writer) = simplex(1024);
778 let mut writer = AsyncWriteStream::new(1024, writer);
779
780 let chunk = Bytes::from_static(&[0; 1024]);
781
782 let permit = resolves_immediately(writer.write_ready())
783 .await
784 .expect("write should be ready");
785 assert_eq!(permit, 1024);
786
787 writer.write(chunk.clone()).expect("write succeeds");
788
789 writer.flush().expect("flush succeeds");
790
791 let permit = resolves_immediately(writer.write_ready())
794 .await
795 .expect("write_ready succeeds");
796 assert_eq!(permit, 1024);
797
798 writer.write(chunk.clone()).expect("write does not trap");
800
801 writer.flush().expect("flush succeeds");
803
804 writer
806 .write(chunk.clone())
807 .err()
808 .expect("unpermitted write does trap");
809
810 never_resolves(writer.write_ready()).await;
813
814 let mut buf = [0; 2048];
817 reader.read_exact(&mut buf).await.unwrap();
818
819 never_resolves(reader.read(&mut buf)).await;
821
822 let permit = resolves_immediately(writer.write_ready())
824 .await
825 .expect("ready is ok");
826 assert_eq!(permit, 1024);
827
828 writer.write(chunk.clone()).expect("write does not trap");
830
831 writer.flush().expect("flush succeeds");
832
833 let permit = resolves_immediately(writer.write_ready())
834 .await
835 .expect("ready is ok");
836 assert_eq!(permit, 1024);
837 }
838}