use crate::preview2::{HostInputStream, HostOutputStream, StreamState};
use anyhow::Error;
use bytes::Bytes;
/// An input stream backed by an in-memory byte buffer.
///
/// Reads consume the buffer front to back; once every byte has been handed out the
/// stream reports [`StreamState::Closed`].
#[derive(Debug)]
pub struct MemoryInputPipe {
    /// The backing bytes together with a cursor marking how much has been read.
    buffer: std::io::Cursor<Bytes>,
}

impl MemoryInputPipe {
    /// Creates a pipe that will yield `bytes`, starting from the first byte.
    pub fn new(bytes: Bytes) -> Self {
        let buffer = std::io::Cursor::new(bytes);
        Self { buffer }
    }

    /// Returns `true` when the read cursor has reached the end of the buffer.
    pub fn is_empty(&self) -> bool {
        let total_len = self.buffer.get_ref().len() as u64;
        self.buffer.position() == total_len
    }
}
#[async_trait::async_trait]
impl HostInputStream for MemoryInputPipe {
    /// Returns up to `size` bytes starting at the current cursor position.
    ///
    /// The returned chunk is a zero-copy slice of the backing [`Bytes`], so no scratch
    /// buffer proportional to `size` is allocated; a very large `size` costs the same
    /// as a small one. The state is [`StreamState::Closed`] once the buffer has been
    /// fully consumed, including on the call that hands out the final bytes.
    fn read(&mut self, size: usize) -> Result<(Bytes, StreamState), Error> {
        if self.is_empty() {
            return Ok((Bytes::new(), StreamState::Closed));
        }
        let data = self.buffer.get_ref();
        // The cursor starts at 0 and is only advanced by this method, so it never
        // passes the end of `data`; clamp anyway so the slice below cannot panic.
        let start = usize::try_from(self.buffer.position())
            .unwrap_or(usize::MAX)
            .min(data.len());
        let end = start + size.min(data.len() - start);
        let chunk = data.slice(start..end);
        self.buffer.set_position(end as u64);
        let state = if self.is_empty() {
            StreamState::Closed
        } else {
            StreamState::Open
        };
        Ok((chunk, state))
    }

    /// Always ready: the data is already in memory.
    async fn ready(&mut self) -> Result<(), Error> {
        Ok(())
    }
}
/// An output stream that appends everything written to it to an in-memory buffer.
///
/// The buffer sits behind an `Arc<Mutex<_>>` and `Clone` shares it, so a clone kept by
/// the host can observe what was written through another clone.
#[derive(Debug, Clone, Default)]
pub struct MemoryOutputPipe {
    buffer: std::sync::Arc<std::sync::Mutex<bytes::BytesMut>>,
}

impl MemoryOutputPipe {
    /// Creates a pipe with an empty buffer. Equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of everything written so far.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's mutex was poisoned by a panic while it was held.
    pub fn contents(&self) -> bytes::Bytes {
        self.buffer.lock().unwrap().clone().freeze()
    }

    /// Takes the buffer out of the pipe if this is the only remaining handle to it.
    ///
    /// Returns `None` (dropping this handle) while other clones still share the buffer.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's mutex was poisoned by a panic while it was held.
    pub fn try_into_inner(self) -> Option<bytes::BytesMut> {
        std::sync::Arc::into_inner(self.buffer).map(|m| m.into_inner().unwrap())
    }
}
#[async_trait::async_trait]
impl HostOutputStream for MemoryOutputPipe {
    /// Appends all of `bytes` to the shared buffer; the whole chunk is always accepted
    /// and the stream never closes.
    fn write(&mut self, bytes: Bytes) -> Result<(usize, StreamState), anyhow::Error> {
        let accepted = bytes.len();
        self.buffer.lock().unwrap().extend_from_slice(&bytes);
        Ok((accepted, StreamState::Open))
    }

    /// Always ready: writes go straight into memory.
    async fn ready(&mut self) -> Result<(), Error> {
        Ok(())
    }
}
/// Creates an in-memory pipe with `size` bytes of buffering.
///
/// Bytes written to the returned [`AsyncWriteStream`] become readable from the
/// returned [`AsyncReadStream`].
pub fn pipe(size: usize) -> (AsyncReadStream, AsyncWriteStream) {
    let (upstream, downstream) = tokio::io::duplex(size);
    // Only one direction of the duplex is used: we write into `upstream` and read
    // from `downstream`. The other two halves are unused.
    let (_unused_read, write_half) = tokio::io::split(upstream);
    let (read_half, _unused_write) = tokio::io::split(downstream);
    let reader = AsyncReadStream::new(read_half);
    let writer = AsyncWriteStream::new(write_half);
    (reader, writer)
}
/// Adapts a tokio [`AsyncRead`](tokio::io::AsyncRead) into a [`HostInputStream`].
///
/// A background task reads from the wrapped reader in chunks of up to 4096 bytes and
/// forwards each one over a capacity-1 channel, so the task waits for the consumer
/// before it can queue further chunks.
pub struct AsyncReadStream {
    /// Most recent state received from the background task.
    state: StreamState,
    /// Part of a chunk (or an error) that was received but not yet returned by `read`.
    buffer: Option<Result<Bytes, std::io::Error>>,
    /// Receives chunks, end-of-stream markers and errors from the background task.
    receiver: tokio::sync::mpsc::Receiver<Result<(Bytes, StreamState), std::io::Error>>,
    /// Handle to the background task; aborted when the stream is dropped.
    pub(crate) join_handle: tokio::task::JoinHandle<()>,
}

impl AsyncReadStream {
    /// Wraps `reader`, spawning a background task that reads from it.
    ///
    /// A zero-length read from `reader` is forwarded as [`StreamState::Closed`] and read
    /// errors are forwarded as-is. The task stops once the receiving end of its channel
    /// has been dropped.
    pub fn new<T: tokio::io::AsyncRead + Send + Sync + Unpin + 'static>(mut reader: T) -> Self {
        let (sender, receiver) = tokio::sync::mpsc::channel(1);
        let join_handle = crate::preview2::spawn(async move {
            use tokio::io::AsyncReadExt;
            loop {
                let mut chunk = bytes::BytesMut::with_capacity(4096);
                let message = match reader.read_buf(&mut chunk).await {
                    // A zero-length read means the reader reached end-of-stream.
                    Ok(0) => Ok((Bytes::new(), StreamState::Closed)),
                    Ok(_) => Ok((chunk.freeze(), StreamState::Open)),
                    Err(e) => Err(e),
                };
                // A failed send means the stream is gone and nobody will read again.
                if sender.send(message).await.is_err() {
                    break;
                }
            }
        });
        Self {
            state: StreamState::Open,
            buffer: None,
            receiver,
            join_handle,
        }
    }
}
impl Drop for AsyncReadStream {
    /// Aborts the background reader task so it does not outlive the stream.
    fn drop(&mut self) {
        let task = &self.join_handle;
        task.abort();
    }
}
#[async_trait::async_trait]
impl HostInputStream for AsyncReadStream {
    /// Returns up to `size` bytes without blocking.
    ///
    /// Any bytes left over from a chunk longer than `size` are kept for the next call.
    /// If nothing is available yet, an empty chunk with the current state is returned;
    /// await `ready` to wait for data. Once [`StreamState::Closed`] has been reported it
    /// is reported on every later call.
    fn read(&mut self, size: usize) -> Result<(Bytes, StreamState), Error> {
        use tokio::sync::mpsc::error::TryRecvError;

        let mut bytes = match self.buffer.take() {
            Some(Ok(bytes)) => bytes,
            Some(Err(e)) => return Err(e.into()),
            // Closed is terminal. The background task keeps reading after it sends
            // `Closed`, and a reader that yields data after a zero-length read would
            // otherwise flip the state back to `Open` here. `ready` already treats
            // `Closed` as final.
            None if self.state == StreamState::Closed => {
                return Ok((Bytes::new(), StreamState::Closed));
            }
            None => match self.receiver.try_recv() {
                Ok(Ok((bytes, state))) => {
                    self.state = state;
                    bytes
                }
                Ok(Err(e)) => return Err(e.into()),
                Err(TryRecvError::Empty) => return Ok((Bytes::new(), self.state)),
                Err(TryRecvError::Disconnected) => {
                    return Err(anyhow::anyhow!(
                        "AsyncReadStream sender died - should be impossible"
                    ));
                }
            },
        };

        // Hand back at most `size` bytes and stash the remainder for the next read.
        // While a remainder is pending the stream is necessarily still open.
        let rest = bytes.split_off(bytes.len().min(size));
        if rest.is_empty() {
            Ok((bytes, self.state))
        } else {
            self.buffer = Some(Ok(rest));
            Ok((bytes, StreamState::Open))
        }
    }

    /// Waits until `read` has something to report: data, an error, or end-of-stream.
    ///
    /// Returns immediately if something is already buffered or the stream is closed.
    /// Returns an error if the background task has gone away while the stream is open.
    async fn ready(&mut self) -> Result<(), Error> {
        if self.buffer.is_some() || self.state == StreamState::Closed {
            return Ok(());
        }
        match self.receiver.recv().await {
            Some(Ok((bytes, state))) => {
                if state == StreamState::Closed {
                    self.state = state;
                }
                self.buffer = Some(Ok(bytes));
            }
            Some(Err(e)) => self.buffer = Some(Err(e)),
            None => {
                return Err(anyhow::anyhow!(
                    "no more sender for an open AsyncReadStream - should be impossible"
                ))
            }
        }
        Ok(())
    }
}
/// Progress of the chunk most recently handed to an `AsyncWriteStream`'s background task.
#[derive(Debug)]
enum WriteState {
    /// No write is in flight; the next chunk can be accepted.
    Ready,
    /// A chunk has been queued and its outcome has not been observed yet.
    Pending,
    /// The background write failed; the error is handed back by the next `write`.
    Err(std::io::Error),
}
/// Adapts a tokio [`AsyncWrite`](tokio::io::AsyncWrite) into a [`HostOutputStream`].
///
/// Each accepted chunk is handed to a background task that writes all of it to the
/// wrapped writer and then reports the outcome. Only one chunk is in flight at a time.
pub struct AsyncWriteStream {
    /// `None` once the stream has closed or its write error has been reported.
    state: Option<WriteState>,
    /// Queues chunks for the background task.
    sender: tokio::sync::mpsc::Sender<Bytes>,
    /// Receives the outcome of each queued chunk from the background task.
    result_receiver: tokio::sync::mpsc::Receiver<Result<StreamState, std::io::Error>>,
    /// Handle to the background task; aborted when the stream is dropped.
    join_handle: tokio::task::JoinHandle<()>,
}

impl AsyncWriteStream {
    /// Wraps `writer`, spawning a background task that performs the actual writes.
    ///
    /// For each chunk it receives, the task writes the whole chunk and then reports
    /// `Ok(Open)` on success, `Ok(Closed)` if the writer accepted zero bytes, or the
    /// I/O error. After reporting `Closed` or an error the task exits.
    pub fn new<T: tokio::io::AsyncWrite + Send + Sync + Unpin + 'static>(mut writer: T) -> Self {
        let (sender, mut receiver) = tokio::sync::mpsc::channel::<Bytes>(1);
        let (result_sender, result_receiver) = tokio::sync::mpsc::channel(1);
        let join_handle = crate::preview2::spawn(async move {
            'outer: loop {
                use tokio::io::AsyncWriteExt;
                match receiver.recv().await {
                    Some(mut bytes) => {
                        // `write_buf` may write only part of the chunk, so keep going
                        // until all of it is written. An empty chunk would skip this
                        // loop and produce no report, which is why `send` never
                        // forwards one.
                        while !bytes.is_empty() {
                            match writer.write_buf(&mut bytes).await {
                                // The writer accepted nothing: report closure and stop.
                                Ok(0) => {
                                    let _ = result_sender.send(Ok(StreamState::Closed)).await;
                                    break 'outer;
                                }
                                Ok(_) => {
                                    if bytes.is_empty() {
                                        // Whole chunk written. If the stream is gone,
                                        // nobody is listening any more: stop.
                                        match result_sender.send(Ok(StreamState::Open)).await {
                                            Ok(_) => break,
                                            Err(_) => break 'outer,
                                        }
                                    }
                                    continue;
                                }
                                Err(e) => {
                                    let _ = result_sender.send(Err(e)).await;
                                    break 'outer;
                                }
                            }
                        }
                    }
                    // The stream dropped its sender: no more chunks will arrive.
                    None => break 'outer,
                }
            }
        });
        AsyncWriteStream {
            state: Some(WriteState::Ready),
            sender,
            result_receiver,
            join_handle,
        }
    }

    /// Queues `bytes` for the background task and moves to [`WriteState::Pending`].
    ///
    /// Must only be called in [`WriteState::Ready`], when the task is idle and the
    /// channel therefore has room. An empty chunk is accepted immediately without
    /// involving the task. The task only reports an outcome from inside its
    /// `while !bytes.is_empty()` loop, so a queued empty chunk would leave the stream
    /// `Pending` forever: `ready` would never resolve and `write` would never accept
    /// data again.
    fn send(&mut self, bytes: Bytes) -> anyhow::Result<(usize, StreamState)> {
        use tokio::sync::mpsc::error::TrySendError;
        debug_assert!(matches!(self.state, Some(WriteState::Ready)));
        if bytes.is_empty() {
            return Ok((0, StreamState::Open));
        }
        let len = bytes.len();
        match self.sender.try_send(bytes) {
            Ok(_) => {
                self.state = Some(WriteState::Pending);
                Ok((len, StreamState::Open))
            }
            Err(TrySendError::Full(_)) => {
                unreachable!("task shouldnt be full when writestate is ready")
            }
            Err(TrySendError::Closed(_)) => unreachable!("task shouldn't die while not closed"),
        }
    }
}
impl Drop for AsyncWriteStream {
    /// Aborts the background writer task; a chunk it is still writing may be left
    /// partially written.
    fn drop(&mut self) {
        let task = &self.join_handle;
        task.abort();
    }
}
#[async_trait::async_trait]
impl HostOutputStream for AsyncWriteStream {
    /// Offers `bytes` to the stream without blocking.
    ///
    /// Returns how many bytes were accepted. That is the whole chunk once the previous
    /// write has completed, or `0` while it is still in progress. A failed background
    /// write is returned as an error exactly once; after that the stream reports
    /// `Closed`.
    fn write(&mut self, bytes: Bytes) -> Result<(usize, StreamState), anyhow::Error> {
        use tokio::sync::mpsc::error::TryRecvError;
        let Some(current) = &self.state else {
            return Ok((0, StreamState::Closed));
        };
        match current {
            WriteState::Ready => self.send(bytes),
            WriteState::Pending => match self.result_receiver.try_recv() {
                Ok(Ok(StreamState::Open)) => {
                    self.state = Some(WriteState::Ready);
                    self.send(bytes)
                }
                Ok(Ok(StreamState::Closed)) => {
                    self.state = None;
                    Ok((0, StreamState::Closed))
                }
                Ok(Err(e)) => {
                    self.state = None;
                    Err(e.into())
                }
                // The previous chunk is still being written: accept nothing this time.
                Err(TryRecvError::Empty) => Ok((0, StreamState::Open)),
                Err(TryRecvError::Disconnected) => {
                    unreachable!("task shouldn't die while pending")
                }
            },
            WriteState::Err(_) => match self.state.take() {
                Some(WriteState::Err(e)) => Err(e.into()),
                _ => unreachable!("self.state shown to be Some(Err(e)) in match clause"),
            },
        }
    }

    /// Waits for the in-flight chunk, if any, to finish and records its outcome.
    ///
    /// Returns immediately when no write is pending.
    async fn ready(&mut self) -> Result<(), Error> {
        if !matches!(self.state, Some(WriteState::Pending)) {
            return Ok(());
        }
        self.state = match self.result_receiver.recv().await {
            Some(Ok(StreamState::Open)) => Some(WriteState::Ready),
            Some(Ok(StreamState::Closed)) => None,
            Some(Err(e)) => Some(WriteState::Err(e)),
            None => unreachable!("task shouldn't die while pending"),
        };
        Ok(())
    }
}
/// An output stream that accepts and discards everything written to it.
pub struct SinkOutputStream;

#[async_trait::async_trait]
impl HostOutputStream for SinkOutputStream {
    /// Reports the whole chunk as written without storing it.
    fn write(&mut self, buf: Bytes) -> Result<(usize, StreamState), Error> {
        let discarded = buf.len();
        Ok((discarded, StreamState::Open))
    }

    /// Always ready: `write` never blocks.
    async fn ready(&mut self) -> Result<(), Error> {
        Ok(())
    }
}
/// An input stream that is already at end-of-stream.
pub struct ClosedInputStream;

#[async_trait::async_trait]
impl HostInputStream for ClosedInputStream {
    /// Returns no data and reports [`StreamState::Closed`], whatever size is requested.
    fn read(&mut self, _size: usize) -> Result<(Bytes, StreamState), Error> {
        let nothing = Bytes::new();
        Ok((nothing, StreamState::Closed))
    }

    /// Always ready: `read` never blocks.
    async fn ready(&mut self) -> Result<(), Error> {
        Ok(())
    }
}
/// An output stream that is already closed and accepts nothing.
pub struct ClosedOutputStream;

#[async_trait::async_trait]
impl HostOutputStream for ClosedOutputStream {
    /// Accepts no bytes and reports [`StreamState::Closed`].
    fn write(&mut self, _: Bytes) -> Result<(usize, StreamState), Error> {
        let accepted = 0;
        Ok((accepted, StreamState::Closed))
    }

    /// Always ready: `write` never blocks.
    async fn ready(&mut self) -> Result<(), Error> {
        Ok(())
    }
}
// Tests for the tokio-backed streams. Most of them race the background tasks spawned by
// `AsyncReadStream::new` / `AsyncWriteStream::new`, so they accept either outcome where
// the task may or may not have run yet, and use `REASONABLE_DURATION` timeouts to tell
// "ready now" apart from "blocked".
#[cfg(test)]
mod test {
    use super::*;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

    // Bound passed to `tokio::time::timeout`: anything expected to be ready must finish
    // within it, and anything expected to block must still be pending when it elapses.
    const REASONABLE_DURATION: std::time::Duration = std::time::Duration::from_millis(100);

    // One-directional in-memory pipe built from a tokio duplex with `size` bytes of
    // buffering: bytes written to the returned writer are read from the returned reader.
    pub fn simplex(size: usize) -> (impl AsyncRead, impl AsyncWrite) {
        let (a, b) = tokio::io::duplex(size);
        let (_read_half, write_half) = tokio::io::split(a);
        let (read_half, _write_half) = tokio::io::split(b);
        (read_half, write_half)
    }

    // An empty source reports `Closed`, either on the first read or, if that read ran
    // before the background task delivered anything, after `ready`.
    #[tokio::test(flavor = "multi_thread")]
    async fn empty_read_stream() {
        let mut reader = AsyncReadStream::new(tokio::io::empty());
        let (bs, state) = reader.read(10).unwrap();
        assert!(bs.is_empty());
        match state {
            StreamState::Closed => {}
            StreamState::Open => {
                tokio::time::timeout(REASONABLE_DURATION, reader.ready())
                    .await
                    .expect("the reader should be ready instantly")
                    .expect("ready is ok");
                let (bs, state) = reader.read(0).unwrap();
                assert!(bs.is_empty());
                assert_eq!(state, StreamState::Closed);
            }
        }
    }

    // A never-ending source stays `Open`. Reads of 10 bytes return exactly 10 bytes
    // (later ones are served from the remainder of an earlier, larger chunk), and a read
    // of 0 returns nothing.
    #[tokio::test(flavor = "multi_thread")]
    async fn infinite_read_stream() {
        let mut reader = AsyncReadStream::new(tokio::io::repeat(0));
        let (bs, state) = reader.read(10).unwrap();
        assert_eq!(state, StreamState::Open);
        if bs.is_empty() {
            // The first read raced the background task; wait for data to arrive.
            tokio::time::timeout(REASONABLE_DURATION, reader.ready())
                .await
                .expect("the reader should be ready instantly")
                .expect("ready is ok");
            let (bs, state) = reader.read(10).unwrap();
            assert_eq!(bs.len(), 10);
            assert_eq!(state, StreamState::Open);
        } else {
            assert_eq!(bs.len(), 10);
        }
        let (bs, state) = reader.read(10).unwrap();
        assert_eq!(state, StreamState::Open);
        assert_eq!(bs.len(), 10);
        let (bs, state) = reader.read(0).unwrap();
        assert_eq!(state, StreamState::Open);
        assert_eq!(bs.len(), 0);
    }

    // Reader that yields `contents` and then end-of-stream. The pipe is sized to hold
    // all of `contents`, so `write_all` completes with no concurrent reader, and `w` is
    // dropped when this function returns.
    async fn finite_async_reader(contents: &[u8]) -> impl AsyncRead + Send + Sync + 'static {
        let (r, mut w) = simplex(contents.len());
        w.write_all(contents).await.unwrap();
        r
    }

    // A 123-byte source yields its 123 bytes while `Open`, then reports `Closed`.
    #[tokio::test(flavor = "multi_thread")]
    async fn finite_read_stream() {
        let mut reader = AsyncReadStream::new(finite_async_reader(&[1; 123]).await);
        let (bs, state) = reader.read(123).unwrap();
        assert_eq!(state, StreamState::Open);
        if bs.is_empty() {
            tokio::time::timeout(REASONABLE_DURATION, reader.ready())
                .await
                .expect("the reader should be ready instantly")
                .expect("ready is ok");
            let (bs, state) = reader.read(123).unwrap();
            assert_eq!(bs.len(), 123);
            assert_eq!(state, StreamState::Open);
        } else {
            assert_eq!(bs.len(), 123);
        }
        let (bs, state) = reader.read(0).unwrap();
        assert!(bs.is_empty());
        match state {
            StreamState::Closed => {}
            StreamState::Open => {
                // End-of-stream has not been delivered yet; `ready` must pick it up.
                tokio::time::timeout(REASONABLE_DURATION, reader.ready())
                    .await
                    .expect("the reader should be ready instantly")
                    .expect("ready is ok");
                let (bs, state) = reader.read(0).unwrap();
                assert_eq!(bs.len(), 0);
                assert_eq!(state, StreamState::Closed);
            }
        }
    }

    // Chunks written separately are read separately. Between chunks the stream is
    // `Open` with no data and `ready` blocks; dropping the writer closes the stream.
    #[tokio::test(flavor = "multi_thread")]
    async fn multiple_chunks_read_stream() {
        let (r, mut w) = simplex(1024);
        let mut reader = AsyncReadStream::new(r);
        w.write_all(&[123]).await.unwrap();
        let (bs, state) = reader.read(1).unwrap();
        assert_eq!(state, StreamState::Open);
        if bs.is_empty() {
            tokio::time::timeout(REASONABLE_DURATION, reader.ready())
                .await
                .expect("the reader should be ready instantly")
                .expect("ready is ok");
            let (bs, state) = reader.read(1).unwrap();
            assert_eq!(*bs, [123u8]);
            assert_eq!(state, StreamState::Open);
        } else {
            assert_eq!(*bs, [123u8]);
        }
        // Nothing more has been written: reads are empty and `ready` must block.
        let (bs, state) = reader.read(1).unwrap();
        assert!(bs.is_empty());
        assert_eq!(state, StreamState::Open);
        tokio::time::timeout(REASONABLE_DURATION, reader.ready())
            .await
            .err()
            .expect("the reader should time out");
        let (bs, state) = reader.read(1).unwrap();
        assert!(bs.is_empty());
        assert_eq!(state, StreamState::Open);
        // A second chunk wakes `ready` and is returned by the next read.
        w.write_all(&[45]).await.unwrap();
        tokio::time::timeout(REASONABLE_DURATION, reader.ready())
            .await
            .expect("the reader should be ready instantly")
            .expect("the ready is ok");
        let (bs, state) = reader.read(1).unwrap();
        assert_eq!(*bs, [45u8]);
        assert_eq!(state, StreamState::Open);
        let (bs, state) = reader.read(1).unwrap();
        assert!(bs.is_empty());
        assert_eq!(state, StreamState::Open);
        tokio::time::timeout(REASONABLE_DURATION, reader.ready())
            .await
            .err()
            .expect("the reader should time out");
        let (bs, state) = reader.read(1).unwrap();
        assert!(bs.is_empty());
        assert_eq!(state, StreamState::Open);
        // Dropping the only writer ends the stream.
        drop(w);
        tokio::time::timeout(REASONABLE_DURATION, reader.ready())
            .await
            .expect("the reader should be ready instantly")
            .expect("the ready is ok");
        let (bs, state) = reader.read(1).unwrap();
        assert!(bs.is_empty());
        assert_eq!(state, StreamState::Closed);
    }

    // 8192 bytes written at once arrive as two 4096-byte chunks, because the background
    // task reads into 4096-byte buffers, even though 4097 bytes are requested each time.
    #[tokio::test(flavor = "multi_thread")]
    async fn backpressure_read_stream() {
        let (r, mut w) = simplex(16 * 1024);
        let mut reader = AsyncReadStream::new(r);
        let writer_task = tokio::task::spawn(async move {
            w.write_all(&[123; 8192]).await.unwrap();
            w
        });
        tokio::time::timeout(REASONABLE_DURATION, reader.ready())
            .await
            .expect("the reader should be ready instantly")
            .expect("ready is ok");
        let (bs, state) = reader.read(4097).unwrap();
        assert_eq!(bs.len(), 4096);
        assert_eq!(state, StreamState::Open);
        tokio::time::timeout(REASONABLE_DURATION, reader.ready())
            .await
            .expect("the reader should be ready instantly")
            .expect("ready is ok");
        let (bs, state) = reader.read(4097).unwrap();
        assert_eq!(bs.len(), 4096);
        assert_eq!(state, StreamState::Open);
        // Recover the writer from the task and drop it so the stream reaches its end.
        let w = tokio::time::timeout(REASONABLE_DURATION, writer_task)
            .await
            .expect("the join should be ready instantly");
        drop(w);
        tokio::time::timeout(REASONABLE_DURATION, reader.ready())
            .await
            .expect("the reader should be ready instantly")
            .expect("ready is ok");
        let (bs, state) = reader.read(4097).unwrap();
        assert_eq!(bs.len(), 0);
        assert_eq!(state, StreamState::Closed);
    }

    // Writing into a sink: the first write is accepted whole. An immediate second write
    // may or may not be accepted, depending on whether the background task has finished
    // the first. After `ready`, a write is accepted whole again.
    #[tokio::test(flavor = "multi_thread")]
    async fn sink_write_stream() {
        let mut writer = AsyncWriteStream::new(tokio::io::sink());
        let chunk = Bytes::from_static(&[0; 1024]);
        let (len, state) = writer.write(chunk.clone()).unwrap();
        assert_eq!(len, chunk.len());
        assert_eq!(state, StreamState::Open);
        let (len, state) = writer.write(chunk.clone()).unwrap();
        assert_eq!(state, StreamState::Open);
        if !(len == 0 || len == chunk.len()) {
            unreachable!()
        }
        tokio::time::timeout(REASONABLE_DURATION, writer.ready())
            .await
            .expect("the writer should be ready instantly")
            .expect("ready is ok");
        let (len, state) = writer.write(chunk.clone()).unwrap();
        assert_eq!(len, chunk.len());
        assert_eq!(state, StreamState::Open);
    }

    // Writing after the reading end is gone. The first write is accepted because it is
    // only queued; the failure surfaces after `ready` as a `BrokenPipe` error from the
    // next write, and every write after that reports `Closed`.
    #[tokio::test(flavor = "multi_thread")]
    async fn closed_write_stream() {
        let (reader, writer) = simplex(1024);
        drop(reader);
        let mut writer = AsyncWriteStream::new(writer);
        let chunk = Bytes::from_static(&[0; 1]);
        let (len, state) = writer.write(chunk.clone()).unwrap();
        assert_eq!(len, chunk.len());
        assert_eq!(state, StreamState::Open);
        tokio::time::timeout(REASONABLE_DURATION, writer.ready())
            .await
            .expect("the writer should be ready instantly")
            .expect("ready is ok");
        let err = writer.write(chunk.clone()).err().unwrap();
        assert_eq!(
            err.downcast_ref::<std::io::Error>().unwrap().kind(),
            std::io::ErrorKind::BrokenPipe
        );
        let (len, state) = writer.write(chunk.clone()).unwrap();
        assert_eq!(len, 0);
        assert_eq!(state, StreamState::Closed);
    }

    // Each chunk written (followed by `ready`) can be read back, in order, from the
    // other end of the pipe.
    #[tokio::test(flavor = "multi_thread")]
    async fn multiple_chunks_write_stream() {
        use std::ops::Deref;
        let (mut reader, writer) = simplex(1024);
        let mut writer = AsyncWriteStream::new(writer);
        let chunk = Bytes::from_static(&[123; 1]);
        let (len, state) = writer.write(chunk.clone()).unwrap();
        assert_eq!(len, chunk.len());
        assert_eq!(state, StreamState::Open);
        tokio::time::timeout(REASONABLE_DURATION, writer.ready())
            .await
            .expect("the writer should be ready instantly")
            .expect("ready is ok");
        let mut read_buf = vec![0; chunk.len()];
        let read_len = reader.read_exact(&mut read_buf).await.unwrap();
        assert_eq!(read_len, chunk.len());
        assert_eq!(read_buf.as_slice(), chunk.deref());
        let chunk2 = Bytes::from_static(&[45; 1]);
        let (len, state) = writer.write(chunk2.clone()).unwrap();
        assert_eq!(len, chunk2.len());
        assert_eq!(state, StreamState::Open);
        tokio::time::timeout(REASONABLE_DURATION, writer.ready())
            .await
            .expect("the writer should be ready instantly")
            .expect("ready is ok");
        let mut read2_buf = vec![0; chunk2.len()];
        let read2_len = reader.read_exact(&mut read2_buf).await.unwrap();
        assert_eq!(read2_len, chunk2.len());
        assert_eq!(read2_buf.as_slice(), chunk2.deref());
    }

    // With 1024 bytes of pipe buffering and no reader, the first 1024-byte chunk fits in
    // the pipe. The second is accepted but stays in flight in the background task, so
    // further writes accept 0 bytes and `ready` blocks. Once the reader drains both
    // chunks, `ready` resolves and writes are accepted again.
    #[tokio::test(flavor = "multi_thread")]
    async fn backpressure_write_stream() {
        let (mut reader, writer) = simplex(1024);
        let mut writer = AsyncWriteStream::new(writer);
        let chunk = Bytes::from_static(&[0; 1024]);
        let (len, state) = writer.write(chunk.clone()).unwrap();
        assert_eq!(len, chunk.len());
        assert_eq!(state, StreamState::Open);
        tokio::time::timeout(REASONABLE_DURATION, writer.ready())
            .await
            .expect("the writer should be ready instantly")
            .expect("ready is ok");
        let (len, state) = writer.write(chunk.clone()).unwrap();
        assert_eq!(len, chunk.len());
        assert_eq!(state, StreamState::Open);
        let (len, state) = writer.write(chunk.clone()).unwrap();
        assert_eq!(len, 0);
        assert_eq!(state, StreamState::Open);
        tokio::time::timeout(REASONABLE_DURATION, writer.ready())
            .await
            .err()
            .expect("the writer should be not become ready");
        let (len, state) = writer.write(chunk.clone()).unwrap();
        assert_eq!(len, 0);
        assert_eq!(state, StreamState::Open);
        // Drain both accepted chunks and confirm nothing else was buffered.
        let mut buf = [0; 2048];
        reader.read_exact(&mut buf).await.unwrap();
        tokio::time::timeout(REASONABLE_DURATION, reader.read(&mut buf))
            .await
            .err()
            .expect("nothing more buffered in the system");
        tokio::time::timeout(REASONABLE_DURATION, writer.ready())
            .await
            .expect("the writer should be ready instantly")
            .expect("ready is ok");
        let (len, state) = writer.write(chunk.clone()).unwrap();
        assert_eq!(len, chunk.len());
        assert_eq!(state, StreamState::Open);
    }
}