wasi_common/sync/
net.rs

1use crate::{
2    file::{FdFlags, FileType, RiFlags, RoFlags, SdFlags, SiFlags, WasiFile},
3    Error, ErrorExt,
4};
5#[cfg(windows)]
6use io_extras::os::windows::{AsRawHandleOrSocket, RawHandleOrSocket};
7use io_lifetimes::AsSocketlike;
8#[cfg(unix)]
9use io_lifetimes::{AsFd, BorrowedFd};
10#[cfg(windows)]
11use io_lifetimes::{AsSocket, BorrowedSocket};
12use std::any::Any;
13use std::io;
14#[cfg(unix)]
15use system_interface::fs::GetSetFdFlags;
16use system_interface::io::IoExt;
17use system_interface::io::IsReadWrite;
18use system_interface::io::ReadReady;
19
20pub enum Socket {
21    TcpListener(cap_std::net::TcpListener),
22    TcpStream(cap_std::net::TcpStream),
23    #[cfg(unix)]
24    UnixStream(cap_std::os::unix::net::UnixStream),
25    #[cfg(unix)]
26    UnixListener(cap_std::os::unix::net::UnixListener),
27}
28
29impl From<cap_std::net::TcpListener> for Socket {
30    fn from(listener: cap_std::net::TcpListener) -> Self {
31        Self::TcpListener(listener)
32    }
33}
34
35impl From<cap_std::net::TcpStream> for Socket {
36    fn from(stream: cap_std::net::TcpStream) -> Self {
37        Self::TcpStream(stream)
38    }
39}
40
41#[cfg(unix)]
42impl From<cap_std::os::unix::net::UnixListener> for Socket {
43    fn from(listener: cap_std::os::unix::net::UnixListener) -> Self {
44        Self::UnixListener(listener)
45    }
46}
47
48#[cfg(unix)]
49impl From<cap_std::os::unix::net::UnixStream> for Socket {
50    fn from(stream: cap_std::os::unix::net::UnixStream) -> Self {
51        Self::UnixStream(stream)
52    }
53}
54
55#[cfg(unix)]
56impl From<Socket> for Box<dyn WasiFile> {
57    fn from(listener: Socket) -> Self {
58        match listener {
59            Socket::TcpListener(l) => Box::new(crate::sync::net::TcpListener::from_cap_std(l)),
60            Socket::UnixListener(l) => Box::new(crate::sync::net::UnixListener::from_cap_std(l)),
61            Socket::TcpStream(l) => Box::new(crate::sync::net::TcpStream::from_cap_std(l)),
62            Socket::UnixStream(l) => Box::new(crate::sync::net::UnixStream::from_cap_std(l)),
63        }
64    }
65}
66
#[cfg(windows)]
impl From<Socket> for Box<dyn WasiFile> {
    /// Box whichever TCP socket kind `socket` holds inside this module's
    /// matching `WasiFile` wrapper type (only TCP variants exist on Windows).
    fn from(socket: Socket) -> Self {
        match socket {
            Socket::TcpStream(inner) => Box::new(TcpStream::from_cap_std(inner)),
            Socket::TcpListener(inner) => Box::new(TcpListener::from_cap_std(inner)),
        }
    }
}
76
// Implements `WasiFile` (plus the platform handle-borrowing traits) for a
// listening-socket newtype `$ty` whose field `.0` is the `cap_std` listener.
// Connections returned by `sock_accept` are wrapped with
// `<$stream>::from_cap_std`.
macro_rules! wasi_listen_write_impl {
    ($ty:ty, $stream:ty) => {
        #[wiggle::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }

            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd> {
                let fd = self.0.as_fd();
                Some(fd)
            }

            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                let raw = self.0.as_raw_handle_or_socket();
                Some(raw)
            }

            async fn sock_accept(&self, fdflags: FdFlags) -> Result<Box<dyn WasiFile>, Error> {
                // The peer address from `accept` is not reported to the caller.
                let (accepted, _peer) = self.0.accept()?;
                let mut wrapped = <$stream>::from_cap_std(accepted);
                // Apply the requested flags to the new connection before handing it out.
                wrapped.set_fdflags(fdflags).await?;
                Ok(Box::new(wrapped))
            }

            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }

            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                Ok(get_fd_flags(&self.0)?)
            }

            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                // Exactly NONBLOCK enables non-blocking mode, no flags disables it,
                // and any other combination is rejected.
                let nonblocking = if fdflags == crate::file::FdFlags::NONBLOCK {
                    true
                } else if fdflags.is_empty() {
                    false
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                };
                self.0.set_nonblocking(nonblocking)?;
                Ok(())
            }

            fn num_ready_bytes(&self) -> Result<u64, Error> {
                // A listener has no byte stream; this fixed value is always reported.
                Ok(1)
            }
        }

        #[cfg(unix)]
        impl AsFd for $ty {
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            #[inline]
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }
    };
}
147
148pub struct TcpListener(cap_std::net::TcpListener);
149
150impl TcpListener {
151    pub fn from_cap_std(cap_std: cap_std::net::TcpListener) -> Self {
152        TcpListener(cap_std)
153    }
154}
155wasi_listen_write_impl!(TcpListener, TcpStream);
156
157#[cfg(unix)]
158pub struct UnixListener(cap_std::os::unix::net::UnixListener);
159
160#[cfg(unix)]
161impl UnixListener {
162    pub fn from_cap_std(cap_std: cap_std::os::unix::net::UnixListener) -> Self {
163        UnixListener(cap_std)
164    }
165}
166
167#[cfg(unix)]
168wasi_listen_write_impl!(UnixListener, UnixStream);
169
// Implements `WasiFile` (plus the platform handle-borrowing traits) for a
// connected-stream newtype `$ty` whose field `.0` is the `cap_std` stream.
// `$std_ty` is the matching `std` stream type, used to borrow the socket as a
// `std` value for `Read`/`Write`-based vectored I/O.
macro_rules! wasi_stream_write_impl {
    ($ty:ty, $std_ty:ty) => {
        #[wiggle::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }
            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd> {
                Some(self.0.as_fd())
            }
            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                Some(self.0.as_raw_handle_or_socket())
            }
            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }
            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                let fdflags = get_fd_flags(&self.0)?;
                Ok(fdflags)
            }
            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                // Exactly NONBLOCK enables non-blocking mode, no flags disables it,
                // and any other combination is rejected.
                if fdflags == crate::file::FdFlags::NONBLOCK {
                    self.0.set_nonblocking(true)?;
                } else if fdflags.is_empty() {
                    self.0.set_nonblocking(false)?;
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                }
                Ok(())
            }
            async fn read_vectored<'a>(
                &self,
                bufs: &mut [io::IoSliceMut<'a>],
            ) -> Result<u64, Error> {
                use std::io::Read;
                // `Read` is implemented for `&$std_ty`, so read through a shared view.
                let n = Read::read_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            async fn write_vectored<'a>(&self, bufs: &[io::IoSlice<'a>]) -> Result<u64, Error> {
                use std::io::Write;
                // `Write` is implemented for `&$std_ty`, so write through a shared view.
                let n = Write::write_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            async fn peek(&self, buf: &mut [u8]) -> Result<u64, Error> {
                let n = self.0.peek(buf)?;
                Ok(n.try_into()?)
            }
            fn num_ready_bytes(&self) -> Result<u64, Error> {
                let val = self.as_socketlike_view::<$std_ty>().num_ready_bytes()?;
                Ok(val)
            }
            async fn readable(&self) -> Result<(), Error> {
                let (readable, _writeable) = is_read_write(&self.0)?;
                if readable {
                    Ok(())
                } else {
                    Err(Error::io())
                }
            }
            async fn writable(&self) -> Result<(), Error> {
                let (_readable, writeable) = is_read_write(&self.0)?;
                if writeable {
                    Ok(())
                } else {
                    Err(Error::io())
                }
            }

            async fn sock_recv<'a>(
                &self,
                ri_data: &mut [std::io::IoSliceMut<'a>],
                ri_flags: RiFlags,
            ) -> Result<(u64, RoFlags), Error> {
                // Only RECV_PEEK and RECV_WAITALL are supported.
                if (ri_flags & !(RiFlags::RECV_PEEK | RiFlags::RECV_WAITALL)) != RiFlags::empty() {
                    return Err(Error::not_supported());
                }

                // RECV_PEEK takes precedence over RECV_WAITALL.
                if ri_flags.contains(RiFlags::RECV_PEEK) {
                    // Peek into the first non-empty buffer, skipping zero-length
                    // entries as a vectored read would. Peeking into an empty
                    // leading buffer would report 0 bytes, which a caller could
                    // mistake for end-of-stream.
                    let n = match ri_data.iter_mut().find(|buf| !buf.is_empty()) {
                        Some(buf) => self.0.peek(buf)?,
                        None => 0,
                    };
                    return Ok((n as u64, RoFlags::empty()));
                }

                if ri_flags.contains(RiFlags::RECV_WAITALL) {
                    // Fill every buffer completely; on success the byte count is
                    // the total buffer length.
                    let n: usize = ri_data.iter().map(|buf| buf.len()).sum();
                    self.0.read_exact_vectored(ri_data)?;
                    return Ok((n as u64, RoFlags::empty()));
                }

                let n = self.0.read_vectored(ri_data)?;
                Ok((n as u64, RoFlags::empty()))
            }

            async fn sock_send<'a>(
                &self,
                si_data: &[std::io::IoSlice<'a>],
                si_flags: SiFlags,
            ) -> Result<u64, Error> {
                // No send flags are supported.
                if si_flags != SiFlags::empty() {
                    return Err(Error::not_supported());
                }

                let n = self.0.write_vectored(si_data)?;
                Ok(n as u64)
            }

            async fn sock_shutdown(&self, how: SdFlags) -> Result<(), Error> {
                // Map the WASI shutdown direction(s) to a `Shutdown`; an empty
                // set is rejected as invalid.
                let how = if how == SdFlags::RD | SdFlags::WR {
                    cap_std::net::Shutdown::Both
                } else if how == SdFlags::RD {
                    cap_std::net::Shutdown::Read
                } else if how == SdFlags::WR {
                    cap_std::net::Shutdown::Write
                } else {
                    return Err(Error::invalid_argument());
                };
                self.0.shutdown(how)?;
                Ok(())
            }
        }
        #[cfg(unix)]
        impl AsFd for $ty {
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            /// Borrows the socket.
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        // Implemented for `$ty` (previously hard-coded to `TcpStream`, which
        // would conflict or fail for any other instantiation of this macro).
        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }
    };
}
322
323pub struct TcpStream(cap_std::net::TcpStream);
324
325impl TcpStream {
326    pub fn from_cap_std(socket: cap_std::net::TcpStream) -> Self {
327        TcpStream(socket)
328    }
329}
330
331wasi_stream_write_impl!(TcpStream, std::net::TcpStream);
332
333#[cfg(unix)]
334pub struct UnixStream(cap_std::os::unix::net::UnixStream);
335
336#[cfg(unix)]
337impl UnixStream {
338    pub fn from_cap_std(socket: cap_std::os::unix::net::UnixStream) -> Self {
339        UnixStream(socket)
340    }
341}
342
343#[cfg(unix)]
344wasi_stream_write_impl!(UnixStream, std::os::unix::net::UnixStream);
345
346pub fn filetype_from(ft: &cap_std::fs::FileType) -> FileType {
347    use cap_fs_ext::FileTypeExt;
348    if ft.is_block_device() {
349        FileType::SocketDgram
350    } else {
351        FileType::SocketStream
352    }
353}
354
355/// Return the file-descriptor flags for a given file-like object.
356///
357/// This returns the flags needed to implement [`WasiFile::get_fdflags`].
358pub fn get_fd_flags<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<crate::file::FdFlags> {
359    // On Unix-family platforms, we can use the same system call that we'd use
360    // for files on sockets here.
361    #[cfg(not(windows))]
362    {
363        let mut out = crate::file::FdFlags::empty();
364        if f.get_fd_flags()?
365            .contains(system_interface::fs::FdFlags::NONBLOCK)
366        {
367            out |= crate::file::FdFlags::NONBLOCK;
368        }
369        Ok(out)
370    }
371
372    // On Windows, sockets are different, and there is no direct way to
373    // query for the non-blocking flag. We can get a sufficient approximation
374    // by testing whether a zero-length `recv` appears to block.
375    #[cfg(windows)]
376    match rustix::net::recv(f, &mut [], rustix::net::RecvFlags::empty()) {
377        Ok(_) => Ok(crate::file::FdFlags::empty()),
378        Err(rustix::io::Errno::WOULDBLOCK) => Ok(crate::file::FdFlags::NONBLOCK),
379        Err(e) => Err(e.into()),
380    }
381}
382
383/// Return the file-descriptor flags for a given file-like object.
384///
385/// This returns the flags needed to implement [`WasiFile::get_fdflags`].
386pub fn is_read_write<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<(bool, bool)> {
387    // On Unix-family platforms, we have an `IsReadWrite` impl.
388    #[cfg(not(windows))]
389    {
390        f.is_read_write()
391    }
392
393    // On Windows, we only have a `TcpStream` impl, so make a view first.
394    #[cfg(windows)]
395    {
396        f.as_socketlike_view::<std::net::TcpStream>()
397            .is_read_write()
398    }
399}