1use std::{future::Future, io, path::Path};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::impl_raw_fd;
5use compio_io::{AsyncRead, AsyncWrite};
6use socket2::{SockAddr, Socket as Socket2, Type};
7
8use crate::{OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, WriteHalf};
9
10#[derive(Debug, Clone)]
39pub struct UnixListener {
40 inner: Socket,
41}
42
43impl UnixListener {
44 pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
48 Self::bind_addr(&SockAddr::unix(path)?).await
49 }
50
51 pub async fn bind_addr(addr: &SockAddr) -> io::Result<Self> {
55 if !addr.is_unix() {
56 return Err(io::Error::new(
57 io::ErrorKind::InvalidInput,
58 "addr is not unix socket address",
59 ));
60 }
61
62 let socket = Socket::bind(addr, Type::STREAM, None).await?;
63 socket.listen(1024)?;
64 Ok(UnixListener { inner: socket })
65 }
66
67 #[cfg(unix)]
68 pub fn from_std(stream: std::os::unix::net::UnixListener) -> io::Result<Self> {
70 Ok(Self {
71 inner: Socket::from_socket2(Socket2::from(stream))?,
72 })
73 }
74
75 pub fn close(self) -> impl Future<Output = io::Result<()>> {
78 self.inner.close()
79 }
80
81 pub async fn accept(&self) -> io::Result<(UnixStream, SockAddr)> {
87 let (socket, addr) = self.inner.accept().await?;
88 let stream = UnixStream { inner: socket };
89 Ok((stream, addr))
90 }
91
92 pub fn local_addr(&self) -> io::Result<SockAddr> {
94 self.inner.local_addr()
95 }
96}
97
// Raw OS handle access for `UnixListener`, generated by `compio_driver::impl_raw_fd!`
// from the `inner` field (wrapping a `socket2::Socket`).
// NOTE(review): the exact traits/conversions emitted are defined by the macro — confirm there.
impl_raw_fd!(UnixListener, socket2::Socket, inner, socket);
99
100#[derive(Debug, Clone)]
120pub struct UnixStream {
121 inner: Socket,
122}
123
124impl UnixStream {
125 pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
129 Self::connect_addr(&SockAddr::unix(path)?).await
130 }
131
132 pub async fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
136 if !addr.is_unix() {
137 return Err(io::Error::new(
138 io::ErrorKind::InvalidInput,
139 "addr is not unix socket address",
140 ));
141 }
142
143 #[cfg(windows)]
144 let socket = {
145 let new_addr = empty_unix_socket();
146 Socket::bind(&new_addr, Type::STREAM, None).await?
147 };
148 #[cfg(unix)]
149 let socket = {
150 use socket2::Domain;
151 Socket::new(Domain::UNIX, Type::STREAM, None).await?
152 };
153 socket.connect_async(addr).await?;
154 let unix_stream = UnixStream { inner: socket };
155 Ok(unix_stream)
156 }
157
158 #[cfg(unix)]
159 pub fn from_std(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
161 Ok(Self {
162 inner: Socket::from_socket2(Socket2::from(stream))?,
163 })
164 }
165
166 pub fn close(self) -> impl Future<Output = io::Result<()>> {
169 self.inner.close()
170 }
171
172 pub fn peer_addr(&self) -> io::Result<SockAddr> {
174 #[allow(unused_mut)]
175 let mut addr = self.inner.peer_addr()?;
176 #[cfg(windows)]
177 {
178 fix_unix_socket_length(&mut addr);
179 }
180 Ok(addr)
181 }
182
183 pub fn local_addr(&self) -> io::Result<SockAddr> {
185 self.inner.local_addr()
186 }
187
188 pub fn split(&self) -> (ReadHalf<Self>, WriteHalf<Self>) {
195 crate::split(self)
196 }
197
198 pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
204 crate::into_split(self)
205 }
206
207 pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
209 self.inner.to_poll_fd()
210 }
211
212 pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
214 self.inner.into_poll_fd()
215 }
216}
217
218impl AsyncRead for UnixStream {
219 #[inline]
220 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
221 (&*self).read(buf).await
222 }
223
224 #[inline]
225 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
226 (&*self).read_vectored(buf).await
227 }
228}
229
230impl AsyncRead for &UnixStream {
231 #[inline]
232 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
233 self.inner.recv(buf).await
234 }
235
236 #[inline]
237 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
238 self.inner.recv_vectored(buf).await
239 }
240}
241
242impl AsyncWrite for UnixStream {
243 #[inline]
244 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
245 (&*self).write(buf).await
246 }
247
248 #[inline]
249 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
250 (&*self).write_vectored(buf).await
251 }
252
253 #[inline]
254 async fn flush(&mut self) -> io::Result<()> {
255 (&*self).flush().await
256 }
257
258 #[inline]
259 async fn shutdown(&mut self) -> io::Result<()> {
260 (&*self).shutdown().await
261 }
262}
263
264impl AsyncWrite for &UnixStream {
265 #[inline]
266 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
267 self.inner.send(buf).await
268 }
269
270 #[inline]
271 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
272 self.inner.send_vectored(buf).await
273 }
274
275 #[inline]
276 async fn flush(&mut self) -> io::Result<()> {
277 Ok(())
278 }
279
280 #[inline]
281 async fn shutdown(&mut self) -> io::Result<()> {
282 self.inner.shutdown().await
283 }
284}
285
// Raw OS handle access for `UnixStream`, generated by `compio_driver::impl_raw_fd!`
// from the `inner` field (wrapping a `socket2::Socket`).
// NOTE(review): the exact traits/conversions emitted are defined by the macro — confirm there.
impl_raw_fd!(UnixStream, socket2::Socket, inner, socket);
287
/// Builds an unnamed Windows AF_UNIX address.
///
/// The family is `AF_UNIX`, `sun_path` is all zeros, and the length covers only
/// the `sun_family` field plus a single nul byte (an empty path).
/// `UnixStream::connect_addr` binds its client socket to this address on Windows
/// before connecting.
#[cfg(windows)]
#[inline]
fn empty_unix_socket() -> SockAddr {
    use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN};

    // SAFETY: per socket2's `try_init` contract, the closure receives a pointer to
    // the address storage (a `sockaddr_storage`, large and aligned enough for a
    // `SOCKADDR_UN`) and a pointer to its length. The closure writes one complete
    // `SOCKADDR_UN` into that storage and a length no larger than it.
    let (_, addr) = unsafe {
        SockAddr::try_init(|storage, len| {
            let unix_addr = SOCKADDR_UN {
                sun_family: AF_UNIX,
                sun_path: [0; 108],
            };
            // Family field followed by a lone terminating nul: an empty path (3 bytes).
            let unix_len = std::mem::size_of_val(&unix_addr.sun_family) + 1;
            std::ptr::write(storage.cast::<SOCKADDR_UN>(), unix_addr);
            std::ptr::write(len, unix_len as _);
            Ok(())
        })
    }
    .expect("the init closure never returns an error");
    addr
}
312
/// Recomputes the length of a Windows AF_UNIX address from its `sun_path`.
///
/// The new length is the size of the `sun_family` field plus the path up to and
/// including its first nul byte. If `sun_path` holds no nul at all, the full
/// `SOCKADDR_UN` size is used. `UnixStream::peer_addr` applies this on Windows.
/// NOTE(review): presumably the length reported there may not match the path — confirm.
#[cfg(windows)]
#[inline]
fn fix_unix_socket_length(addr: &mut SockAddr) {
    use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN;

    // SAFETY: `SockAddr` is backed by a `sockaddr_storage`, which is large and
    // aligned enough to be read as a `SOCKADDR_UN`. The reference is not used
    // after `set_length` below.
    let unix_addr: &SOCKADDR_UN = unsafe { &*addr.as_ptr().cast() };
    let family_len = std::mem::size_of_val(&unix_addr.sun_family);
    let addr_len = match std::ffi::CStr::from_bytes_until_nul(&unix_addr.sun_path) {
        Ok(path) => family_len + path.to_bytes_with_nul().len(),
        Err(_) => std::mem::size_of::<SOCKADDR_UN>(),
    };
    // SAFETY: `addr_len` is at most `size_of::<SOCKADDR_UN>()` (the path including
    // its nul fits in `sun_path`), which lies within the address storage.
    unsafe {
        addr.set_length(addr_len as _);
    }
}