wasi_cap_std_sync/
net.rs

1#[cfg(windows)]
2use io_extras::os::windows::{AsRawHandleOrSocket, RawHandleOrSocket};
3use io_lifetimes::AsSocketlike;
4#[cfg(unix)]
5use io_lifetimes::{AsFd, BorrowedFd};
6#[cfg(windows)]
7use io_lifetimes::{AsSocket, BorrowedSocket};
8use std::any::Any;
9use std::convert::TryInto;
10use std::io;
11#[cfg(unix)]
12use system_interface::fs::GetSetFdFlags;
13use system_interface::io::IoExt;
14use system_interface::io::IsReadWrite;
15use system_interface::io::ReadReady;
16use wasi_common::{
17    file::{FdFlags, FileType, RiFlags, RoFlags, SdFlags, SiFlags, WasiFile},
18    Error, ErrorExt,
19};
20
21pub enum Socket {
22    TcpListener(cap_std::net::TcpListener),
23    TcpStream(cap_std::net::TcpStream),
24    #[cfg(unix)]
25    UnixStream(cap_std::os::unix::net::UnixStream),
26    #[cfg(unix)]
27    UnixListener(cap_std::os::unix::net::UnixListener),
28}
29
30impl From<cap_std::net::TcpListener> for Socket {
31    fn from(listener: cap_std::net::TcpListener) -> Self {
32        Self::TcpListener(listener)
33    }
34}
35
36impl From<cap_std::net::TcpStream> for Socket {
37    fn from(stream: cap_std::net::TcpStream) -> Self {
38        Self::TcpStream(stream)
39    }
40}
41
42#[cfg(unix)]
43impl From<cap_std::os::unix::net::UnixListener> for Socket {
44    fn from(listener: cap_std::os::unix::net::UnixListener) -> Self {
45        Self::UnixListener(listener)
46    }
47}
48
49#[cfg(unix)]
50impl From<cap_std::os::unix::net::UnixStream> for Socket {
51    fn from(stream: cap_std::os::unix::net::UnixStream) -> Self {
52        Self::UnixStream(stream)
53    }
54}
55
56#[cfg(unix)]
57impl From<Socket> for Box<dyn WasiFile> {
58    fn from(listener: Socket) -> Self {
59        match listener {
60            Socket::TcpListener(l) => Box::new(crate::net::TcpListener::from_cap_std(l)),
61            Socket::UnixListener(l) => Box::new(crate::net::UnixListener::from_cap_std(l)),
62            Socket::TcpStream(l) => Box::new(crate::net::TcpStream::from_cap_std(l)),
63            Socket::UnixStream(l) => Box::new(crate::net::UnixStream::from_cap_std(l)),
64        }
65    }
66}
67
/// Converts a [`Socket`] into the matching `WasiFile` wrapper from this
/// module, boxed as a trait object. Only the TCP variants exist on Windows.
#[cfg(windows)]
impl From<Socket> for Box<dyn WasiFile> {
    fn from(socket: Socket) -> Self {
        match socket {
            Socket::TcpStream(inner) => Box::new(crate::net::TcpStream::from_cap_std(inner)),
            Socket::TcpListener(inner) => Box::new(crate::net::TcpListener::from_cap_std(inner)),
        }
    }
}
77
/// Implement `WasiFile` and the platform handle-borrowing traits for a
/// listening-socket newtype.
///
/// * `$ty` — a tuple struct whose field `.0` is a `cap_std` listener.
/// * `$stream` — the wrapper type used for connections returned by
///   `sock_accept`; it must provide `from_cap_std` and implement `WasiFile`.
macro_rules! wasi_listen_write_impl {
    ($ty:ty, $stream:ty) => {
        #[async_trait::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }

            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd> {
                let fd = self.0.as_fd();
                Some(fd)
            }

            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                let raw = self.0.as_raw_handle_or_socket();
                Some(raw)
            }

            async fn sock_accept(&self, fdflags: FdFlags) -> Result<Box<dyn WasiFile>, Error> {
                // The peer address returned by `accept` is not used.
                let (accepted, _) = self.0.accept()?;
                let mut connection = <$stream>::from_cap_std(accepted);
                // Apply the requested flags to the new connection before
                // handing it out.
                connection.set_fdflags(fdflags).await?;
                Ok(Box::new(connection))
            }

            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }

            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                Ok(get_fd_flags(&self.0)?)
            }

            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                // Only toggling non-blocking mode is supported: empty flags
                // disable it, exactly `NONBLOCK` enables it, and anything
                // else is rejected.
                let nonblocking = if fdflags.is_empty() {
                    false
                } else if fdflags == wasi_common::file::FdFlags::NONBLOCK {
                    true
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                };
                self.0.set_nonblocking(nonblocking)?;
                Ok(())
            }

            fn num_ready_bytes(&self) -> Result<u64, Error> {
                // A listener has no byte stream, so a constant 1 is reported.
                // NOTE(review): presumably so pollers treat the listener as
                // ready — confirm against the callers of `num_ready_bytes`.
                Ok(1)
            }
        }

        #[cfg(unix)]
        impl AsFd for $ty {
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            #[inline]
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }
    };
}
148
149pub struct TcpListener(cap_std::net::TcpListener);
150
151impl TcpListener {
152    pub fn from_cap_std(cap_std: cap_std::net::TcpListener) -> Self {
153        TcpListener(cap_std)
154    }
155}
156wasi_listen_write_impl!(TcpListener, TcpStream);
157
158#[cfg(unix)]
159pub struct UnixListener(cap_std::os::unix::net::UnixListener);
160
161#[cfg(unix)]
162impl UnixListener {
163    pub fn from_cap_std(cap_std: cap_std::os::unix::net::UnixListener) -> Self {
164        UnixListener(cap_std)
165    }
166}
167
168#[cfg(unix)]
169wasi_listen_write_impl!(UnixListener, UnixStream);
170
/// Implement `WasiFile` and the platform handle-borrowing traits for a
/// connected-stream newtype.
///
/// * `$ty` — a tuple struct whose field `.0` is a `cap_std` stream.
/// * `$std_ty` — the corresponding `std` stream type; `$ty` is viewed as this
///   type (via `as_socketlike_view`) for vectored I/O and `num_ready_bytes`.
macro_rules! wasi_stream_write_impl {
    ($ty:ty, $std_ty:ty) => {
        #[async_trait::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }
            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd> {
                Some(self.0.as_fd())
            }
            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                Some(self.0.as_raw_handle_or_socket())
            }
            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }
            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                let fdflags = get_fd_flags(&self.0)?;
                Ok(fdflags)
            }
            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                // Only toggling non-blocking mode is supported: exactly
                // `NONBLOCK` enables it, empty flags disable it, and anything
                // else is rejected.
                if fdflags == wasi_common::file::FdFlags::NONBLOCK {
                    self.0.set_nonblocking(true)?;
                } else if fdflags.is_empty() {
                    self.0.set_nonblocking(false)?;
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                }
                Ok(())
            }
            async fn read_vectored<'a>(
                &self,
                bufs: &mut [io::IoSliceMut<'a>],
            ) -> Result<u64, Error> {
                use std::io::Read;
                // `Read` is implemented for `&$std_ty`, so read through a
                // shared reference to the borrowed view.
                let n = Read::read_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            async fn write_vectored<'a>(&self, bufs: &[io::IoSlice<'a>]) -> Result<u64, Error> {
                use std::io::Write;
                // `Write` is implemented for `&$std_ty`, mirroring `read_vectored`.
                let n = Write::write_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            async fn peek(&self, buf: &mut [u8]) -> Result<u64, Error> {
                let n = self.0.peek(buf)?;
                Ok(n.try_into()?)
            }
            fn num_ready_bytes(&self) -> Result<u64, Error> {
                let val = self.as_socketlike_view::<$std_ty>().num_ready_bytes()?;
                Ok(val)
            }
            async fn readable(&self) -> Result<(), Error> {
                let (readable, _writeable) = is_read_write(&self.0)?;
                if readable {
                    Ok(())
                } else {
                    Err(Error::io())
                }
            }
            async fn writable(&self) -> Result<(), Error> {
                let (_readable, writeable) = is_read_write(&self.0)?;
                if writeable {
                    Ok(())
                } else {
                    Err(Error::io())
                }
            }

            async fn sock_recv<'a>(
                &self,
                ri_data: &mut [std::io::IoSliceMut<'a>],
                ri_flags: RiFlags,
            ) -> Result<(u64, RoFlags), Error> {
                // Only `RECV_PEEK` and `RECV_WAITALL` are supported.
                if (ri_flags & !(RiFlags::RECV_PEEK | RiFlags::RECV_WAITALL)) != RiFlags::empty() {
                    return Err(Error::not_supported());
                }

                // `RECV_PEEK` takes precedence over `RECV_WAITALL` and only
                // fills the first buffer.
                if ri_flags.contains(RiFlags::RECV_PEEK) {
                    return match ri_data.first_mut() {
                        Some(first) => {
                            let n = self.0.peek(first)?;
                            Ok((n as u64, RoFlags::empty()))
                        }
                        None => Ok((0, RoFlags::empty())),
                    };
                }

                // `RECV_WAITALL`: fill every buffer completely; on success
                // the byte count is the total buffer length.
                if ri_flags.contains(RiFlags::RECV_WAITALL) {
                    let n: usize = ri_data.iter().map(|buf| buf.len()).sum();
                    self.0.read_exact_vectored(ri_data)?;
                    return Ok((n as u64, RoFlags::empty()));
                }

                let n = self.0.read_vectored(ri_data)?;
                Ok((n as u64, RoFlags::empty()))
            }

            async fn sock_send<'a>(
                &self,
                si_data: &[std::io::IoSlice<'a>],
                si_flags: SiFlags,
            ) -> Result<u64, Error> {
                // No send flags are supported.
                if si_flags != SiFlags::empty() {
                    return Err(Error::not_supported());
                }

                let n = self.0.write_vectored(si_data)?;
                Ok(n as u64)
            }

            async fn sock_shutdown(&self, how: SdFlags) -> Result<(), Error> {
                // Map the WASI shutdown flags onto `Shutdown`; any other
                // combination (including empty) is invalid.
                let how = if how == SdFlags::RD | SdFlags::WR {
                    cap_std::net::Shutdown::Both
                } else if how == SdFlags::RD {
                    cap_std::net::Shutdown::Read
                } else if how == SdFlags::WR {
                    cap_std::net::Shutdown::Write
                } else {
                    return Err(Error::invalid_argument());
                };
                self.0.shutdown(how)?;
                Ok(())
            }
        }

        #[cfg(unix)]
        impl AsFd for $ty {
            #[inline]
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            /// Borrows the socket.
            #[inline]
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        // Implemented for the macro's `$ty` rather than a hard-coded type, so
        // every instantiation gets its own impl.
        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }
    };
}
323
324pub struct TcpStream(cap_std::net::TcpStream);
325
326impl TcpStream {
327    pub fn from_cap_std(socket: cap_std::net::TcpStream) -> Self {
328        TcpStream(socket)
329    }
330}
331
332wasi_stream_write_impl!(TcpStream, std::net::TcpStream);
333
334#[cfg(unix)]
335pub struct UnixStream(cap_std::os::unix::net::UnixStream);
336
337#[cfg(unix)]
338impl UnixStream {
339    pub fn from_cap_std(socket: cap_std::os::unix::net::UnixStream) -> Self {
340        UnixStream(socket)
341    }
342}
343
344#[cfg(unix)]
345wasi_stream_write_impl!(UnixStream, std::os::unix::net::UnixStream);
346
347pub fn filetype_from(ft: &cap_std::fs::FileType) -> FileType {
348    use cap_fs_ext::FileTypeExt;
349    if ft.is_block_device() {
350        FileType::SocketDgram
351    } else {
352        FileType::SocketStream
353    }
354}
355
356/// Return the file-descriptor flags for a given file-like object.
357///
358/// This returns the flags needed to implement [`WasiFile::get_fdflags`].
359pub fn get_fd_flags<Socketlike: AsSocketlike>(
360    f: Socketlike,
361) -> io::Result<wasi_common::file::FdFlags> {
362    // On Unix-family platforms, we can use the same system call that we'd use
363    // for files on sockets here.
364    #[cfg(not(windows))]
365    {
366        let mut out = wasi_common::file::FdFlags::empty();
367        if f.get_fd_flags()?
368            .contains(system_interface::fs::FdFlags::NONBLOCK)
369        {
370            out |= wasi_common::file::FdFlags::NONBLOCK;
371        }
372        Ok(out)
373    }
374
375    // On Windows, sockets are different, and there is no direct way to
376    // query for the non-blocking flag. We can get a sufficient approximation
377    // by testing whether a zero-length `recv` appears to block.
378    #[cfg(windows)]
379    match rustix::net::recv(f, &mut [], rustix::net::RecvFlags::empty()) {
380        Ok(_) => Ok(wasi_common::file::FdFlags::empty()),
381        Err(rustix::io::Errno::WOULDBLOCK) => Ok(wasi_common::file::FdFlags::NONBLOCK),
382        Err(e) => Err(e.into()),
383    }
384}
385
386/// Return the file-descriptor flags for a given file-like object.
387///
388/// This returns the flags needed to implement [`WasiFile::get_fdflags`].
389pub fn is_read_write<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<(bool, bool)> {
390    // On Unix-family platforms, we have an `IsReadWrite` impl.
391    #[cfg(not(windows))]
392    {
393        f.is_read_write()
394    }
395
396    // On Windows, we only have a `TcpStream` impl, so make a view first.
397    #[cfg(windows)]
398    {
399        f.as_socketlike_view::<std::net::TcpStream>()
400            .is_read_write()
401    }
402}