compio_net/
unix.rs

1use std::{future::Future, io, path::Path};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::impl_raw_fd;
5use compio_io::{AsyncRead, AsyncWrite};
6use socket2::{SockAddr, Socket as Socket2, Type};
7
8use crate::{OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, WriteHalf};
9
10/// A Unix socket server, listening for connections.
11///
12/// You can accept a new connection by using the [`UnixListener::accept`]
13/// method.
14///
15/// # Examples
16///
17/// ```
18/// use compio_io::{AsyncReadExt, AsyncWriteExt};
19/// use compio_net::{UnixListener, UnixStream};
20/// use tempfile::tempdir;
21///
22/// let dir = tempdir().unwrap();
23/// let sock_file = dir.path().join("unix-server.sock");
24///
25/// # compio_runtime::Runtime::new().unwrap().block_on(async move {
26/// let listener = UnixListener::bind(&sock_file).await.unwrap();
27///
28/// let (mut tx, (mut rx, _)) =
29///     futures_util::try_join!(UnixStream::connect(&sock_file), listener.accept()).unwrap();
30///
31/// tx.write_all("test").await.0.unwrap();
32///
33/// let (_, buf) = rx.read_exact(Vec::with_capacity(4)).await.unwrap();
34///
35/// assert_eq!(buf, b"test");
36/// # });
37/// ```
38#[derive(Debug, Clone)]
39pub struct UnixListener {
40    inner: Socket,
41}
42
43impl UnixListener {
44    /// Creates a new [`UnixListener`], which will be bound to the specified
45    /// file path. The file path cannot yet exist, and will be cleaned up
46    /// upon dropping [`UnixListener`]
47    pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
48        Self::bind_addr(&SockAddr::unix(path)?).await
49    }
50
51    /// Creates a new [`UnixListener`] with [`SockAddr`], which will be bound to
52    /// the specified file path. The file path cannot yet exist, and will be
53    /// cleaned up upon dropping [`UnixListener`]
54    pub async fn bind_addr(addr: &SockAddr) -> io::Result<Self> {
55        if !addr.is_unix() {
56            return Err(io::Error::new(
57                io::ErrorKind::InvalidInput,
58                "addr is not unix socket address",
59            ));
60        }
61
62        let socket = Socket::bind(addr, Type::STREAM, None).await?;
63        socket.listen(1024)?;
64        Ok(UnixListener { inner: socket })
65    }
66
67    #[cfg(unix)]
68    /// Creates new UnixListener from a [`std::os::unix::net::UnixListener`].
69    pub fn from_std(stream: std::os::unix::net::UnixListener) -> io::Result<Self> {
70        Ok(Self {
71            inner: Socket::from_socket2(Socket2::from(stream))?,
72        })
73    }
74
75    /// Close the socket. If the returned future is dropped before polling, the
76    /// socket won't be closed.
77    pub fn close(self) -> impl Future<Output = io::Result<()>> {
78        self.inner.close()
79    }
80
81    /// Accepts a new incoming connection from this listener.
82    ///
83    /// This function will yield once a new Unix domain socket connection
84    /// is established. When established, the corresponding [`UnixStream`] and
85    /// will be returned.
86    pub async fn accept(&self) -> io::Result<(UnixStream, SockAddr)> {
87        let (socket, addr) = self.inner.accept().await?;
88        let stream = UnixStream { inner: socket };
89        Ok((stream, addr))
90    }
91
92    /// Returns the local address that this listener is bound to.
93    pub fn local_addr(&self) -> io::Result<SockAddr> {
94        self.inner.local_addr()
95    }
96}
97
98impl_raw_fd!(UnixListener, socket2::Socket, inner, socket);
99
100/// A Unix stream between two local sockets on Windows & WSL.
101///
102/// A Unix stream can either be created by connecting to an endpoint, via the
103/// `connect` method, or by accepting a connection from a listener.
104///
105/// # Examples
106///
107/// ```no_run
108/// use compio_io::AsyncWrite;
109/// use compio_net::UnixStream;
110///
111/// # compio_runtime::Runtime::new().unwrap().block_on(async {
112/// // Connect to a peer
113/// let mut stream = UnixStream::connect("unix-server.sock").await.unwrap();
114///
115/// // Write some data.
116/// stream.write("hello world!").await.unwrap();
117/// # })
118/// ```
119#[derive(Debug, Clone)]
120pub struct UnixStream {
121    inner: Socket,
122}
123
124impl UnixStream {
125    /// Opens a Unix connection to the specified file path. There must be a
126    /// [`UnixListener`] or equivalent listening on the corresponding Unix
127    /// domain socket to successfully connect and return a `UnixStream`.
128    pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
129        Self::connect_addr(&SockAddr::unix(path)?).await
130    }
131
132    /// Opens a Unix connection to the specified address. There must be a
133    /// [`UnixListener`] or equivalent listening on the corresponding Unix
134    /// domain socket to successfully connect and return a `UnixStream`.
135    pub async fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
136        if !addr.is_unix() {
137            return Err(io::Error::new(
138                io::ErrorKind::InvalidInput,
139                "addr is not unix socket address",
140            ));
141        }
142
143        #[cfg(windows)]
144        let socket = {
145            let new_addr = empty_unix_socket();
146            Socket::bind(&new_addr, Type::STREAM, None).await?
147        };
148        #[cfg(unix)]
149        let socket = {
150            use socket2::Domain;
151            Socket::new(Domain::UNIX, Type::STREAM, None).await?
152        };
153        socket.connect_async(addr).await?;
154        let unix_stream = UnixStream { inner: socket };
155        Ok(unix_stream)
156    }
157
158    #[cfg(unix)]
159    /// Creates new UnixStream from a [`std::os::unix::net::UnixStream`].
160    pub fn from_std(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
161        Ok(Self {
162            inner: Socket::from_socket2(Socket2::from(stream))?,
163        })
164    }
165
166    /// Close the socket. If the returned future is dropped before polling, the
167    /// socket won't be closed.
168    pub fn close(self) -> impl Future<Output = io::Result<()>> {
169        self.inner.close()
170    }
171
172    /// Returns the socket path of the remote peer of this connection.
173    pub fn peer_addr(&self) -> io::Result<SockAddr> {
174        #[allow(unused_mut)]
175        let mut addr = self.inner.peer_addr()?;
176        #[cfg(windows)]
177        {
178            fix_unix_socket_length(&mut addr);
179        }
180        Ok(addr)
181    }
182
183    /// Returns the socket path of the local half of this connection.
184    pub fn local_addr(&self) -> io::Result<SockAddr> {
185        self.inner.local_addr()
186    }
187
188    /// Splits a [`UnixStream`] into a read half and a write half, which can be
189    /// used to read and write the stream concurrently.
190    ///
191    /// This method is more efficient than
192    /// [`into_split`](UnixStream::into_split), but the halves cannot
193    /// be moved into independently spawned tasks.
194    pub fn split(&self) -> (ReadHalf<Self>, WriteHalf<Self>) {
195        crate::split(self)
196    }
197
198    /// Splits a [`UnixStream`] into a read half and a write half, which can be
199    /// used to read and write the stream concurrently.
200    ///
201    /// Unlike [`split`](UnixStream::split), the owned halves can be moved to
202    /// separate tasks, however this comes at the cost of a heap allocation.
203    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
204        crate::into_split(self)
205    }
206
207    /// Create [`PollFd`] from inner socket.
208    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
209        self.inner.to_poll_fd()
210    }
211
212    /// Create [`PollFd`] from inner socket.
213    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
214        self.inner.into_poll_fd()
215    }
216}
217
218impl AsyncRead for UnixStream {
219    #[inline]
220    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
221        (&*self).read(buf).await
222    }
223
224    #[inline]
225    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
226        (&*self).read_vectored(buf).await
227    }
228}
229
230impl AsyncRead for &UnixStream {
231    #[inline]
232    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
233        self.inner.recv(buf).await
234    }
235
236    #[inline]
237    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
238        self.inner.recv_vectored(buf).await
239    }
240}
241
242impl AsyncWrite for UnixStream {
243    #[inline]
244    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
245        (&*self).write(buf).await
246    }
247
248    #[inline]
249    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
250        (&*self).write_vectored(buf).await
251    }
252
253    #[inline]
254    async fn flush(&mut self) -> io::Result<()> {
255        (&*self).flush().await
256    }
257
258    #[inline]
259    async fn shutdown(&mut self) -> io::Result<()> {
260        (&*self).shutdown().await
261    }
262}
263
264impl AsyncWrite for &UnixStream {
265    #[inline]
266    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
267        self.inner.send(buf).await
268    }
269
270    #[inline]
271    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
272        self.inner.send_vectored(buf).await
273    }
274
275    #[inline]
276    async fn flush(&mut self) -> io::Result<()> {
277        Ok(())
278    }
279
280    #[inline]
281    async fn shutdown(&mut self) -> io::Result<()> {
282        self.inner.shutdown().await
283    }
284}
285
286impl_raw_fd!(UnixStream, socket2::Socket, inner, socket);
287
/// Builds an unnamed `AF_UNIX` address: the family field followed by an empty,
/// NUL-filled path, with a reported length of 3 bytes (presumably the 2-byte
/// family plus a single NUL path byte).
#[cfg(windows)]
#[inline]
fn empty_unix_socket() -> SockAddr {
    use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN};

    // SAFETY: `try_init` passes zeroed address storage sized for any socket
    // address, which can hold a `SOCKADDR_UN`. We fully write that structure
    // and report a length (3) that lies inside it.
    let initialised = unsafe {
        SockAddr::try_init(|storage, len| {
            storage.cast::<SOCKADDR_UN>().write(SOCKADDR_UN {
                sun_family: AF_UNIX,
                sun_path: [0; 108],
            });
            len.write(3);
            Ok(())
        })
    };
    // The initialiser above never returns `Err`, so this cannot fail.
    let ((), addr) = initialised.unwrap();
    addr
}
312
// The peer address returned after `ConnectEx` is buggy: its length covers
// bytes that do not belong to the address. A Unix socket path should not
// contain `\0` before its end, so the first NUL in `sun_path` marks where the
// path really stops and the length is recomputed from it.
#[cfg(windows)]
#[inline]
fn fix_unix_socket_length(addr: &mut SockAddr) {
    use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN;

    // SAFETY: the storage behind `SockAddr` is large and aligned enough for any
    // socket address, including a `SOCKADDR_UN`, so this read stays in bounds.
    // The only caller passes the peer address of a Unix stream socket.
    let unix_addr: &SOCKADDR_UN = unsafe { &*addr.as_ptr().cast() };
    // Bytes taken by `sun_family`, which precedes `sun_path`.
    let family_len = std::mem::size_of_val(&unix_addr.sun_family);
    let addr_len = match std::ffi::CStr::from_bytes_until_nul(&unix_addr.sun_path) {
        Ok(path) => family_len + path.to_bytes_with_nul().len(),
        // No NUL inside `sun_path`: the path fills the whole structure.
        Err(_) => std::mem::size_of::<SOCKADDR_UN>(),
    };
    // SAFETY: `addr_len` never exceeds `size_of::<SOCKADDR_UN>()`, which fits
    // within the address storage read above.
    unsafe {
        addr.set_length(addr_len as _);
    }
}
330}