1use std::{future::Future, io, path::Path};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::impl_raw_fd;
5use compio_io::{AsyncRead, AsyncWrite};
6use socket2::{SockAddr, Socket as Socket2, Type};
7
8use crate::{OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, WriteHalf};
9
10#[derive(Debug, Clone)]
39pub struct UnixListener {
40 inner: Socket,
41}
42
43impl UnixListener {
44 pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
48 Self::bind_addr(&SockAddr::unix(path)?).await
49 }
50
51 pub async fn bind_addr(addr: &SockAddr) -> io::Result<Self> {
55 if !addr.is_unix() {
56 return Err(io::Error::new(
57 io::ErrorKind::InvalidInput,
58 "addr is not unix socket address",
59 ));
60 }
61
62 let socket = Socket::bind(addr, Type::STREAM, None).await?;
63 socket.listen(1024)?;
64 Ok(UnixListener { inner: socket })
65 }
66
67 pub fn close(self) -> impl Future<Output = io::Result<()>> {
70 self.inner.close()
71 }
72
73 pub async fn accept(&self) -> io::Result<(UnixStream, SockAddr)> {
79 let (socket, addr) = self.inner.accept().await?;
80 let stream = UnixStream { inner: socket };
81 Ok((stream, addr))
82 }
83
84 pub fn local_addr(&self) -> io::Result<SockAddr> {
86 self.inner.local_addr()
87 }
88}
89
90impl_raw_fd!(UnixListener, socket2::Socket, inner, socket);
91
92#[derive(Debug, Clone)]
112pub struct UnixStream {
113 inner: Socket,
114}
115
116impl UnixStream {
117 pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
121 Self::connect_addr(&SockAddr::unix(path)?).await
122 }
123
124 pub async fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
128 if !addr.is_unix() {
129 return Err(io::Error::new(
130 io::ErrorKind::InvalidInput,
131 "addr is not unix socket address",
132 ));
133 }
134
135 #[cfg(windows)]
136 let socket = {
137 let new_addr = empty_unix_socket();
138 Socket::bind(&new_addr, Type::STREAM, None).await?
139 };
140 #[cfg(unix)]
141 let socket = {
142 use socket2::Domain;
143 Socket::new(Domain::UNIX, Type::STREAM, None).await?
144 };
145 socket.connect_async(addr).await?;
146 let unix_stream = UnixStream { inner: socket };
147 Ok(unix_stream)
148 }
149
150 #[cfg(unix)]
151 pub fn from_std(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
153 Ok(Self {
154 inner: Socket::from_socket2(Socket2::from(stream))?,
155 })
156 }
157
158 pub fn close(self) -> impl Future<Output = io::Result<()>> {
161 self.inner.close()
162 }
163
164 pub fn peer_addr(&self) -> io::Result<SockAddr> {
166 #[allow(unused_mut)]
167 let mut addr = self.inner.peer_addr()?;
168 #[cfg(windows)]
169 {
170 fix_unix_socket_length(&mut addr);
171 }
172 Ok(addr)
173 }
174
175 pub fn local_addr(&self) -> io::Result<SockAddr> {
177 self.inner.local_addr()
178 }
179
180 pub fn split(&self) -> (ReadHalf<Self>, WriteHalf<Self>) {
187 crate::split(self)
188 }
189
190 pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
196 crate::into_split(self)
197 }
198
199 pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
201 self.inner.to_poll_fd()
202 }
203
204 pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
206 self.inner.into_poll_fd()
207 }
208}
209
210impl AsyncRead for UnixStream {
211 #[inline]
212 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
213 (&*self).read(buf).await
214 }
215
216 #[inline]
217 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
218 (&*self).read_vectored(buf).await
219 }
220}
221
222impl AsyncRead for &UnixStream {
223 #[inline]
224 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
225 self.inner.recv(buf).await
226 }
227
228 #[inline]
229 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
230 self.inner.recv_vectored(buf).await
231 }
232}
233
234impl AsyncWrite for UnixStream {
235 #[inline]
236 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
237 (&*self).write(buf).await
238 }
239
240 #[inline]
241 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
242 (&*self).write_vectored(buf).await
243 }
244
245 #[inline]
246 async fn flush(&mut self) -> io::Result<()> {
247 (&*self).flush().await
248 }
249
250 #[inline]
251 async fn shutdown(&mut self) -> io::Result<()> {
252 (&*self).shutdown().await
253 }
254}
255
256impl AsyncWrite for &UnixStream {
257 #[inline]
258 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
259 self.inner.send(buf).await
260 }
261
262 #[inline]
263 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
264 self.inner.send_vectored(buf).await
265 }
266
267 #[inline]
268 async fn flush(&mut self) -> io::Result<()> {
269 Ok(())
270 }
271
272 #[inline]
273 async fn shutdown(&mut self) -> io::Result<()> {
274 self.inner.shutdown().await
275 }
276}
277
278impl_raw_fd!(UnixStream, socket2::Socket, inner, socket);
279
/// Builds an unnamed Unix socket address: family `AF_UNIX` with an empty path.
///
/// On Windows, `UnixStream::connect_addr` binds its socket to this address
/// before connecting.
#[cfg(windows)]
#[inline]
fn empty_unix_socket() -> SockAddr {
    use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN};

    // SAFETY: `try_init` requires the address family, contents, and length
    // written into its storage to agree.
    // - We write a fully initialised `SOCKADDR_UN` with family `AF_UNIX`.
    // - We report a length that covers only its `sun_family` field and the
    //   single NUL byte of the empty path.
    // - socket2's storage is a full socket-address storage buffer, which is
    //   large enough to hold a `SOCKADDR_UN`.
    let (_, addr) = unsafe {
        SockAddr::try_init(|storage, len| {
            let unix_addr = SOCKADDR_UN {
                sun_family: AF_UNIX,
                sun_path: [0; 108],
            };
            // `sun_path` starts right after `sun_family`. An empty path is just
            // its terminating NUL, so the length is size_of(sun_family) + 1.
            let addr_len = std::mem::size_of_val(&unix_addr.sun_family) + 1;
            std::ptr::write(storage.cast::<SOCKADDR_UN>(), unix_addr);
            std::ptr::write(len, addr_len as _);
            Ok(())
        })
    }
    .expect("initialising an empty unix socket address cannot fail: the initialiser always returns Ok");
    addr
}
301
/// Recomputes the stored length of a Windows Unix socket address from its path.
///
/// The new length is the offset of `sun_path` plus the path bytes up to and
/// including the first NUL. If `sun_path` contains no NUL, the full size of
/// `SOCKADDR_UN` is used.
///
/// `addr` is reinterpreted as a `SOCKADDR_UN`, so it is expected to hold an
/// `AF_UNIX` address.
#[cfg(windows)]
#[inline]
fn fix_unix_socket_length(addr: &mut SockAddr) {
    use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN;

    let addr_len = {
        // SAFETY: `SockAddr` is backed by a socket-address storage buffer that
        // is large enough for `SOCKADDR_UN`. Callers only pass `AF_UNIX`
        // addresses. The reference is confined to this block, so it is gone
        // before `addr` is mutated below.
        let unix_addr: &SOCKADDR_UN = unsafe { &*addr.as_ptr().cast() };
        // `sun_path` immediately follows `sun_family`.
        let path_offset = std::mem::size_of_val(&unix_addr.sun_family);
        match std::ffi::CStr::from_bytes_until_nul(&unix_addr.sun_path) {
            Ok(path) => path_offset + path.to_bytes_with_nul().len(),
            // No terminator inside `sun_path`: the path fills the whole buffer.
            Err(_) => std::mem::size_of::<SOCKADDR_UN>(),
        }
    };
    // SAFETY: `addr_len` never exceeds `size_of::<SOCKADDR_UN>()`, which fits
    // inside the address storage, and every byte it covers is already part of
    // the address held by `addr`.
    unsafe {
        addr.set_length(addr_len as _);
    }
}