compio_net/
unix.rs

1use std::{future::Future, io, path::Path};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::impl_raw_fd;
5use compio_io::{AsyncRead, AsyncWrite};
6use socket2::{SockAddr, Socket as Socket2, Type};
7
8use crate::{OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, WriteHalf};
9
10/// A Unix socket server, listening for connections.
11///
12/// You can accept a new connection by using the [`UnixListener::accept`]
13/// method.
14///
15/// # Examples
16///
17/// ```
18/// use compio_io::{AsyncReadExt, AsyncWriteExt};
19/// use compio_net::{UnixListener, UnixStream};
20/// use tempfile::tempdir;
21///
22/// let dir = tempdir().unwrap();
23/// let sock_file = dir.path().join("unix-server.sock");
24///
25/// # compio_runtime::Runtime::new().unwrap().block_on(async move {
26/// let listener = UnixListener::bind(&sock_file).await.unwrap();
27///
28/// let (mut tx, (mut rx, _)) =
29///     futures_util::try_join!(UnixStream::connect(&sock_file), listener.accept()).unwrap();
30///
31/// tx.write_all("test").await.0.unwrap();
32///
33/// let (_, buf) = rx.read_exact(Vec::with_capacity(4)).await.unwrap();
34///
35/// assert_eq!(buf, b"test");
36/// # });
37/// ```
38#[derive(Debug, Clone)]
39pub struct UnixListener {
40    inner: Socket,
41}
42
43impl UnixListener {
44    /// Creates a new [`UnixListener`], which will be bound to the specified
45    /// file path. The file path cannot yet exist, and will be cleaned up
46    /// upon dropping [`UnixListener`]
47    pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
48        Self::bind_addr(&SockAddr::unix(path)?).await
49    }
50
51    /// Creates a new [`UnixListener`] with [`SockAddr`], which will be bound to
52    /// the specified file path. The file path cannot yet exist, and will be
53    /// cleaned up upon dropping [`UnixListener`]
54    pub async fn bind_addr(addr: &SockAddr) -> io::Result<Self> {
55        if !addr.is_unix() {
56            return Err(io::Error::new(
57                io::ErrorKind::InvalidInput,
58                "addr is not unix socket address",
59            ));
60        }
61
62        let socket = Socket::bind(addr, Type::STREAM, None).await?;
63        socket.listen(1024)?;
64        Ok(UnixListener { inner: socket })
65    }
66
67    /// Close the socket. If the returned future is dropped before polling, the
68    /// socket won't be closed.
69    pub fn close(self) -> impl Future<Output = io::Result<()>> {
70        self.inner.close()
71    }
72
73    /// Accepts a new incoming connection from this listener.
74    ///
75    /// This function will yield once a new Unix domain socket connection
76    /// is established. When established, the corresponding [`UnixStream`] and
77    /// will be returned.
78    pub async fn accept(&self) -> io::Result<(UnixStream, SockAddr)> {
79        let (socket, addr) = self.inner.accept().await?;
80        let stream = UnixStream { inner: socket };
81        Ok((stream, addr))
82    }
83
84    /// Returns the local address that this listener is bound to.
85    pub fn local_addr(&self) -> io::Result<SockAddr> {
86        self.inner.local_addr()
87    }
88}
89
90impl_raw_fd!(UnixListener, socket2::Socket, inner, socket);
91
92/// A Unix stream between two local sockets on Windows & WSL.
93///
94/// A Unix stream can either be created by connecting to an endpoint, via the
95/// `connect` method, or by accepting a connection from a listener.
96///
97/// # Examples
98///
99/// ```no_run
100/// use compio_io::AsyncWrite;
101/// use compio_net::UnixStream;
102///
103/// # compio_runtime::Runtime::new().unwrap().block_on(async {
104/// // Connect to a peer
105/// let mut stream = UnixStream::connect("unix-server.sock").await.unwrap();
106///
107/// // Write some data.
108/// stream.write("hello world!").await.unwrap();
109/// # })
110/// ```
111#[derive(Debug, Clone)]
112pub struct UnixStream {
113    inner: Socket,
114}
115
116impl UnixStream {
117    /// Opens a Unix connection to the specified file path. There must be a
118    /// [`UnixListener`] or equivalent listening on the corresponding Unix
119    /// domain socket to successfully connect and return a `UnixStream`.
120    pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
121        Self::connect_addr(&SockAddr::unix(path)?).await
122    }
123
124    /// Opens a Unix connection to the specified address. There must be a
125    /// [`UnixListener`] or equivalent listening on the corresponding Unix
126    /// domain socket to successfully connect and return a `UnixStream`.
127    pub async fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
128        if !addr.is_unix() {
129            return Err(io::Error::new(
130                io::ErrorKind::InvalidInput,
131                "addr is not unix socket address",
132            ));
133        }
134
135        #[cfg(windows)]
136        let socket = {
137            let new_addr = empty_unix_socket();
138            Socket::bind(&new_addr, Type::STREAM, None).await?
139        };
140        #[cfg(unix)]
141        let socket = {
142            use socket2::Domain;
143            Socket::new(Domain::UNIX, Type::STREAM, None).await?
144        };
145        socket.connect_async(addr).await?;
146        let unix_stream = UnixStream { inner: socket };
147        Ok(unix_stream)
148    }
149
150    #[cfg(unix)]
151    /// Creates new UnixStream from a std::os::unix::net::UnixStream.
152    pub fn from_std(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
153        Ok(Self {
154            inner: Socket::from_socket2(Socket2::from(stream))?,
155        })
156    }
157
158    /// Close the socket. If the returned future is dropped before polling, the
159    /// socket won't be closed.
160    pub fn close(self) -> impl Future<Output = io::Result<()>> {
161        self.inner.close()
162    }
163
164    /// Returns the socket path of the remote peer of this connection.
165    pub fn peer_addr(&self) -> io::Result<SockAddr> {
166        #[allow(unused_mut)]
167        let mut addr = self.inner.peer_addr()?;
168        #[cfg(windows)]
169        {
170            fix_unix_socket_length(&mut addr);
171        }
172        Ok(addr)
173    }
174
175    /// Returns the socket path of the local half of this connection.
176    pub fn local_addr(&self) -> io::Result<SockAddr> {
177        self.inner.local_addr()
178    }
179
180    /// Splits a [`UnixStream`] into a read half and a write half, which can be
181    /// used to read and write the stream concurrently.
182    ///
183    /// This method is more efficient than
184    /// [`into_split`](UnixStream::into_split), but the halves cannot
185    /// be moved into independently spawned tasks.
186    pub fn split(&self) -> (ReadHalf<Self>, WriteHalf<Self>) {
187        crate::split(self)
188    }
189
190    /// Splits a [`UnixStream`] into a read half and a write half, which can be
191    /// used to read and write the stream concurrently.
192    ///
193    /// Unlike [`split`](UnixStream::split), the owned halves can be moved to
194    /// separate tasks, however this comes at the cost of a heap allocation.
195    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
196        crate::into_split(self)
197    }
198
199    /// Create [`PollFd`] from inner socket.
200    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
201        self.inner.to_poll_fd()
202    }
203
204    /// Create [`PollFd`] from inner socket.
205    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
206        self.inner.into_poll_fd()
207    }
208}
209
210impl AsyncRead for UnixStream {
211    #[inline]
212    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
213        (&*self).read(buf).await
214    }
215
216    #[inline]
217    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
218        (&*self).read_vectored(buf).await
219    }
220}
221
222impl AsyncRead for &UnixStream {
223    #[inline]
224    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
225        self.inner.recv(buf).await
226    }
227
228    #[inline]
229    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
230        self.inner.recv_vectored(buf).await
231    }
232}
233
234impl AsyncWrite for UnixStream {
235    #[inline]
236    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
237        (&*self).write(buf).await
238    }
239
240    #[inline]
241    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
242        (&*self).write_vectored(buf).await
243    }
244
245    #[inline]
246    async fn flush(&mut self) -> io::Result<()> {
247        (&*self).flush().await
248    }
249
250    #[inline]
251    async fn shutdown(&mut self) -> io::Result<()> {
252        (&*self).shutdown().await
253    }
254}
255
256impl AsyncWrite for &UnixStream {
257    #[inline]
258    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
259        self.inner.send(buf).await
260    }
261
262    #[inline]
263    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
264        self.inner.send_vectored(buf).await
265    }
266
267    #[inline]
268    async fn flush(&mut self) -> io::Result<()> {
269        Ok(())
270    }
271
272    #[inline]
273    async fn shutdown(&mut self) -> io::Result<()> {
274        self.inner.shutdown().await
275    }
276}
277
278impl_raw_fd!(UnixStream, socket2::Socket, inner, socket);
279
/// Builds an empty Unix socket address: `AF_UNIX` with an all-zero
/// `sun_path`. The reported length covers only the family field plus the
/// first (nul) byte of the path.
#[cfg(windows)]
#[inline]
fn empty_unix_socket() -> SockAddr {
    use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN};

    // SAFETY: `try_init` hands the closure a pointer to socket-address storage
    // that is large enough to hold a `SOCKADDR_UN`. The length written is at
    // most the number of bytes of that struct that were just initialized.
    let result = unsafe {
        SockAddr::try_init(|storage, len| {
            let unix_addr = SOCKADDR_UN {
                sun_family: AF_UNIX,
                sun_path: [0; 108],
            };
            // `sun_family` plus a single nul byte of `sun_path` (3 bytes).
            let addr_len = std::mem::size_of_val(&unix_addr.sun_family) + 1;
            std::ptr::write(storage.cast::<SOCKADDR_UN>(), unix_addr);
            std::ptr::write(len, addr_len as _);
            Ok(())
        })
    };
    result
        .expect("initializing an empty unix socket address never fails")
        .1
}
301
// The peer addr returned after ConnectEx is buggy: its length covers bytes
// that do not belong to the address. A Unix socket path does not contain `\0`
// before its end, so the first nul in `sun_path` marks where the path stops,
// and the length is recomputed from it.
#[cfg(windows)]
#[inline]
fn fix_unix_socket_length(addr: &mut SockAddr) {
    use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN;

    // SAFETY: the only caller passes the peer address of a Unix stream socket,
    // so the storage behind `addr` holds a `SOCKADDR_UN`.
    let unix_addr: &SOCKADDR_UN = unsafe { &*addr.as_ptr().cast() };
    let addr_len = match std::ffi::CStr::from_bytes_until_nul(&unix_addr.sun_path) {
        // Family field, then the path bytes including the terminating nul.
        Ok(path) => {
            std::mem::size_of_val(&unix_addr.sun_family) + path.to_bytes_with_nul().len()
        }
        // No nul anywhere: treat the whole `sun_path` array as the path.
        Err(_) => std::mem::size_of::<SOCKADDR_UN>(),
    };
    // SAFETY: `addr_len` never exceeds `size_of::<SOCKADDR_UN>()`. That size
    // fits within the address storage owned by `addr`.
    unsafe {
        addr.set_length(addr_len as _);
    }
}