compio_io/
split.rs

1use std::sync::Arc;
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use futures_util::lock::Mutex;
5
6use crate::{AsyncRead, AsyncWrite, IoResult};
7
8/// Splits a single value implementing `AsyncRead + AsyncWrite` into separate
9/// [`AsyncRead`] and [`AsyncWrite`] handles.
10pub fn split<T: AsyncRead + AsyncWrite>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) {
11    let stream = Arc::new(Mutex::new(stream));
12    (ReadHalf(stream.clone()), WriteHalf(stream))
13}
14
15/// The readable half of a value returned from [`split`].
16#[derive(Debug)]
17pub struct ReadHalf<T>(Arc<Mutex<T>>);
18
19impl<T: Unpin> ReadHalf<T> {
20    /// Reunites with a previously split [`WriteHalf`].
21    ///
22    /// # Panics
23    ///
24    /// If this [`ReadHalf`] and the given [`WriteHalf`] do not originate from
25    /// the same [`split`] operation this method will panic.
26    /// This can be checked ahead of time by comparing the stored pointer
27    /// of the two halves.
28    #[track_caller]
29    pub fn unsplit(self, w: WriteHalf<T>) -> T {
30        if Arc::ptr_eq(&self.0, &w.0) {
31            drop(w);
32            let inner = Arc::try_unwrap(self.0).expect("`Arc::try_unwrap` failed");
33            inner.into_inner()
34        } else {
35            #[cold]
36            fn panic_unrelated() -> ! {
37                panic!("Unrelated `WriteHalf` passed to `ReadHalf::unsplit`.")
38            }
39
40            panic_unrelated()
41        }
42    }
43}
44
45impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
46    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
47        self.0.lock().await.read(buf).await
48    }
49
50    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
51        self.0.lock().await.read_vectored(buf).await
52    }
53}
54
55/// The writable half of a value returned from [`split`].
56#[derive(Debug)]
57pub struct WriteHalf<T>(Arc<Mutex<T>>);
58
59impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
60    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
61        self.0.lock().await.write(buf).await
62    }
63
64    async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
65        self.0.lock().await.write_vectored(buf).await
66    }
67
68    async fn flush(&mut self) -> IoResult<()> {
69        self.0.lock().await.flush().await
70    }
71
72    async fn shutdown(&mut self) -> IoResult<()> {
73        self.0.lock().await.shutdown().await
74    }
75}