compio_net/
split.rs

1use std::{error::Error, fmt, io};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::AsRawFd;
5use compio_io::{AsyncRead, AsyncWrite};
6
7pub(crate) fn split<T>(stream: &T) -> (ReadHalf<T>, WriteHalf<T>)
8where
9    for<'a> &'a T: AsyncRead + AsyncWrite,
10{
11    (ReadHalf(stream), WriteHalf(stream))
12}
13
14/// Borrowed read half.
15#[derive(Debug)]
16pub struct ReadHalf<'a, T>(&'a T);
17
18impl<T> AsyncRead for ReadHalf<'_, T>
19where
20    for<'a> &'a T: AsyncRead,
21{
22    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
23        self.0.read(buf).await
24    }
25
26    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
27        self.0.read_vectored(buf).await
28    }
29}
30
31/// Borrowed write half.
32#[derive(Debug)]
33pub struct WriteHalf<'a, T>(&'a T);
34
35impl<T> AsyncWrite for WriteHalf<'_, T>
36where
37    for<'a> &'a T: AsyncWrite,
38{
39    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
40        self.0.write(buf).await
41    }
42
43    async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
44        self.0.write_vectored(buf).await
45    }
46
47    async fn flush(&mut self) -> io::Result<()> {
48        self.0.flush().await
49    }
50
51    async fn shutdown(&mut self) -> io::Result<()> {
52        self.0.shutdown().await
53    }
54}
55
56pub(crate) fn into_split<T>(stream: T) -> (OwnedReadHalf<T>, OwnedWriteHalf<T>)
57where
58    for<'a> &'a T: AsyncRead + AsyncWrite,
59    T: Clone,
60{
61    (OwnedReadHalf(stream.clone()), OwnedWriteHalf(stream))
62}
63
64/// Owned read half.
65#[derive(Debug)]
66pub struct OwnedReadHalf<T>(T);
67
68impl<T: AsRawFd> OwnedReadHalf<T> {
69    /// Attempts to put the two halves of a `TcpStream` back together and
70    /// recover the original socket. Succeeds only if the two halves
71    /// originated from the same call to `into_split`.
72    pub fn reunite(self, w: OwnedWriteHalf<T>) -> Result<T, ReuniteError<T>> {
73        if self.0.as_raw_fd() == w.0.as_raw_fd() {
74            drop(w);
75            Ok(self.0)
76        } else {
77            Err(ReuniteError(self, w))
78        }
79    }
80}
81
82impl<T> AsyncRead for OwnedReadHalf<T>
83where
84    for<'a> &'a T: AsyncRead,
85{
86    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
87        (&self.0).read(buf).await
88    }
89
90    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
91        (&self.0).read_vectored(buf).await
92    }
93}
94
95/// Owned write half.
96#[derive(Debug)]
97pub struct OwnedWriteHalf<T>(T);
98
99impl<T> AsyncWrite for OwnedWriteHalf<T>
100where
101    for<'a> &'a T: AsyncWrite,
102{
103    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
104        (&self.0).write(buf).await
105    }
106
107    async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
108        (&self.0).write_vectored(buf).await
109    }
110
111    async fn flush(&mut self) -> io::Result<()> {
112        (&self.0).flush().await
113    }
114
115    async fn shutdown(&mut self) -> io::Result<()> {
116        (&self.0).shutdown().await
117    }
118}
119
120/// Error indicating that two halves were not from the same socket, and thus
121/// could not be reunited.
122#[derive(Debug)]
123pub struct ReuniteError<T>(pub OwnedReadHalf<T>, pub OwnedWriteHalf<T>);
124
125impl<T> fmt::Display for ReuniteError<T> {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        write!(
128            f,
129            "tried to reunite halves that are not from the same socket"
130        )
131    }
132}
133
134impl<T: fmt::Debug> Error for ReuniteError<T> {}