1#[cfg(windows)]
2use io_extras::os::windows::{AsRawHandleOrSocket, RawHandleOrSocket};
3use io_lifetimes::AsSocketlike;
4#[cfg(unix)]
5use io_lifetimes::{AsFd, BorrowedFd};
6#[cfg(windows)]
7use io_lifetimes::{AsSocket, BorrowedSocket};
8use std::any::Any;
9use std::convert::TryInto;
10use std::io;
11#[cfg(unix)]
12use system_interface::fs::GetSetFdFlags;
13use system_interface::io::IoExt;
14use system_interface::io::IsReadWrite;
15use system_interface::io::ReadReady;
16use wasi_common::{
17 file::{FdFlags, FileType, RiFlags, RoFlags, SdFlags, SiFlags, WasiFile},
18 Error, ErrorExt,
19};
20
/// A socket that can be converted into a boxed `WasiFile`.
///
/// Unix-domain variants only exist on Unix targets.
pub enum Socket {
    /// A TCP socket accepting incoming connections.
    TcpListener(cap_std::net::TcpListener),
    /// A TCP stream socket.
    TcpStream(cap_std::net::TcpStream),
    /// A Unix-domain stream socket (Unix only).
    #[cfg(unix)]
    UnixStream(cap_std::os::unix::net::UnixStream),
    /// A Unix-domain socket accepting incoming connections (Unix only).
    #[cfg(unix)]
    UnixListener(cap_std::os::unix::net::UnixListener),
}
29
30impl From<cap_std::net::TcpListener> for Socket {
31 fn from(listener: cap_std::net::TcpListener) -> Self {
32 Self::TcpListener(listener)
33 }
34}
35
36impl From<cap_std::net::TcpStream> for Socket {
37 fn from(stream: cap_std::net::TcpStream) -> Self {
38 Self::TcpStream(stream)
39 }
40}
41
42#[cfg(unix)]
43impl From<cap_std::os::unix::net::UnixListener> for Socket {
44 fn from(listener: cap_std::os::unix::net::UnixListener) -> Self {
45 Self::UnixListener(listener)
46 }
47}
48
49#[cfg(unix)]
50impl From<cap_std::os::unix::net::UnixStream> for Socket {
51 fn from(stream: cap_std::os::unix::net::UnixStream) -> Self {
52 Self::UnixStream(stream)
53 }
54}
55
56#[cfg(unix)]
57impl From<Socket> for Box<dyn WasiFile> {
58 fn from(listener: Socket) -> Self {
59 match listener {
60 Socket::TcpListener(l) => Box::new(crate::net::TcpListener::from_cap_std(l)),
61 Socket::UnixListener(l) => Box::new(crate::net::UnixListener::from_cap_std(l)),
62 Socket::TcpStream(l) => Box::new(crate::net::TcpStream::from_cap_std(l)),
63 Socket::UnixStream(l) => Box::new(crate::net::UnixStream::from_cap_std(l)),
64 }
65 }
66}
67
/// Converts a [`Socket`] into the matching `crate::net` wrapper, boxed as a
/// `WasiFile` trait object (Windows: TCP sockets only).
#[cfg(windows)]
impl From<Socket> for Box<dyn WasiFile> {
    fn from(socket: Socket) -> Self {
        match socket {
            Socket::TcpStream(inner) => Box::new(crate::net::TcpStream::from_cap_std(inner)),
            Socket::TcpListener(inner) => Box::new(crate::net::TcpListener::from_cap_std(inner)),
        }
    }
}
77
/// Implements `WasiFile` (plus the platform handle-access traits) for a
/// newtype wrapping a listening socket.
///
/// `$ty` is the listener wrapper type; `$stream` is the wrapper type that
/// `sock_accept` uses for each accepted connection (it must provide
/// `from_cap_std` and implement `WasiFile`).
macro_rules! wasi_listen_write_impl {
    ($ty:ty, $stream:ty) => {
        #[cfg(unix)]
        impl AsFd for $ty {
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            #[inline]
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }

        #[async_trait::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }

            // Polling uses the wrapped socket's own descriptor.
            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd> {
                let fd = self.0.as_fd();
                Some(fd)
            }

            // Polling uses the wrapped socket's own raw socket handle.
            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                let handle = self.0.as_raw_handle_or_socket();
                Some(handle)
            }

            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }

            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                Ok(get_fd_flags(&self.0)?)
            }

            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                // Only non-blocking mode can be toggled: exactly NONBLOCK turns it
                // on, an empty set turns it off, and anything else is rejected.
                let nonblocking = if fdflags == wasi_common::file::FdFlags::NONBLOCK {
                    true
                } else if fdflags.is_empty() {
                    false
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                };
                self.0.set_nonblocking(nonblocking)?;
                Ok(())
            }

            async fn sock_accept(&self, fdflags: FdFlags) -> Result<Box<dyn WasiFile>, Error> {
                // Take the accepted connection (the peer address is not used),
                // wrap it, and apply the requested flags before returning it.
                let connection = self.0.accept()?.0;
                let mut accepted = <$stream>::from_cap_std(connection);
                accepted.set_fdflags(fdflags).await?;
                Ok(Box::new(accepted))
            }

            // Listeners report a constant 1 instead of querying the socket.
            fn num_ready_bytes(&self) -> Result<u64, Error> {
                Ok(1)
            }
        }
    };
}
148
149pub struct TcpListener(cap_std::net::TcpListener);
150
151impl TcpListener {
152 pub fn from_cap_std(cap_std: cap_std::net::TcpListener) -> Self {
153 TcpListener(cap_std)
154 }
155}
156wasi_listen_write_impl!(TcpListener, TcpStream);
157
158#[cfg(unix)]
159pub struct UnixListener(cap_std::os::unix::net::UnixListener);
160
161#[cfg(unix)]
162impl UnixListener {
163 pub fn from_cap_std(cap_std: cap_std::os::unix::net::UnixListener) -> Self {
164 UnixListener(cap_std)
165 }
166}
167
168#[cfg(unix)]
169wasi_listen_write_impl!(UnixListener, UnixStream);
170
/// Implements `WasiFile` (plus the platform handle-access traits) for a
/// newtype wrapping a connected stream socket.
///
/// `$ty` is the wrapper type (its field `.0` is the `cap_std` socket);
/// `$std_ty` is the `std` socket type used to borrow a view of the same
/// socket for vectored reads/writes and the ready-bytes query.
macro_rules! wasi_stream_write_impl {
    ($ty:ty, $std_ty:ty) => {
        #[async_trait::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }
            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd> {
                Some(self.0.as_fd())
            }
            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                Some(self.0.as_raw_handle_or_socket())
            }
            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }
            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                let fdflags = get_fd_flags(&self.0)?;
                Ok(fdflags)
            }
            // Only non-blocking mode can be toggled: exactly NONBLOCK enables it,
            // an empty set disables it, anything else is rejected.
            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                if fdflags == wasi_common::file::FdFlags::NONBLOCK {
                    self.0.set_nonblocking(true)?;
                } else if fdflags.is_empty() {
                    self.0.set_nonblocking(false)?;
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                }
                Ok(())
            }
            async fn read_vectored<'a>(
                &self,
                bufs: &mut [io::IoSliceMut<'a>],
            ) -> Result<u64, Error> {
                use std::io::Read;
                let n = Read::read_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            async fn write_vectored<'a>(&self, bufs: &[io::IoSlice<'a>]) -> Result<u64, Error> {
                use std::io::Write;
                let n = Write::write_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            async fn peek(&self, buf: &mut [u8]) -> Result<u64, Error> {
                let n = self.0.peek(buf)?;
                Ok(n.try_into()?)
            }
            fn num_ready_bytes(&self) -> Result<u64, Error> {
                let val = self.as_socketlike_view::<$std_ty>().num_ready_bytes()?;
                Ok(val)
            }
            // Readiness is answered by a one-shot `is_read_write` query; a
            // socket that is not readable/writable yields an I/O error.
            async fn readable(&self) -> Result<(), Error> {
                let (readable, _writeable) = is_read_write(&self.0)?;
                if readable {
                    Ok(())
                } else {
                    Err(Error::io())
                }
            }
            async fn writable(&self) -> Result<(), Error> {
                let (_readable, writeable) = is_read_write(&self.0)?;
                if writeable {
                    Ok(())
                } else {
                    Err(Error::io())
                }
            }

            /// Receives into `ri_data`, honoring only `RECV_PEEK` and
            /// `RECV_WAITALL`; any other flag yields `not_supported`.
            async fn sock_recv<'a>(
                &self,
                ri_data: &mut [std::io::IoSliceMut<'a>],
                ri_flags: RiFlags,
            ) -> Result<(u64, RoFlags), Error> {
                if (ri_flags & !(RiFlags::RECV_PEEK | RiFlags::RECV_WAITALL)) != RiFlags::empty() {
                    return Err(Error::not_supported());
                }

                if ri_flags.contains(RiFlags::RECV_PEEK) {
                    // `peek` fills a single buffer. Skip leading zero-length
                    // buffers so an empty first iovec does not make the peek
                    // report 0 bytes (easily mistaken for end-of-stream) while
                    // later buffers still have room. Data lands at the same
                    // logical offset because the skipped buffers hold nothing.
                    return match ri_data.iter_mut().find(|buf| !buf.is_empty()) {
                        Some(target) => {
                            let n = self.0.peek(target)?;
                            Ok((n as u64, RoFlags::empty()))
                        }
                        None => Ok((0, RoFlags::empty())),
                    };
                }

                if ri_flags.contains(RiFlags::RECV_WAITALL) {
                    // Every buffer is filled completely, so the total length is
                    // the byte count.
                    let n: usize = ri_data.iter().map(|buf| buf.len()).sum();
                    self.0.read_exact_vectored(ri_data)?;
                    return Ok((n as u64, RoFlags::empty()));
                }

                let n = self.0.read_vectored(ri_data)?;
                Ok((n as u64, RoFlags::empty()))
            }

            /// Sends `si_data`; no send flags are supported.
            async fn sock_send<'a>(
                &self,
                si_data: &[std::io::IoSlice<'a>],
                si_flags: SiFlags,
            ) -> Result<u64, Error> {
                if si_flags != SiFlags::empty() {
                    return Err(Error::not_supported());
                }

                let n = self.0.write_vectored(si_data)?;
                Ok(n as u64)
            }

            /// Shuts down the read half, write half, or both; an empty flag set
            /// is rejected as an invalid argument.
            async fn sock_shutdown(&self, how: SdFlags) -> Result<(), Error> {
                let how = if how == SdFlags::RD | SdFlags::WR {
                    cap_std::net::Shutdown::Both
                } else if how == SdFlags::RD {
                    cap_std::net::Shutdown::Read
                } else if how == SdFlags::WR {
                    cap_std::net::Shutdown::Write
                } else {
                    return Err(Error::invalid_argument());
                };
                self.0.shutdown(how)?;
                Ok(())
            }
        }
        #[cfg(unix)]
        impl AsFd for $ty {
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            #[inline]
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        // Must target `$ty`, not a concrete type: a hard-coded name would
        // implement the trait on the wrong type for other invocations.
        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }
    };
}
323
324pub struct TcpStream(cap_std::net::TcpStream);
325
326impl TcpStream {
327 pub fn from_cap_std(socket: cap_std::net::TcpStream) -> Self {
328 TcpStream(socket)
329 }
330}
331
332wasi_stream_write_impl!(TcpStream, std::net::TcpStream);
333
334#[cfg(unix)]
335pub struct UnixStream(cap_std::os::unix::net::UnixStream);
336
337#[cfg(unix)]
338impl UnixStream {
339 pub fn from_cap_std(socket: cap_std::os::unix::net::UnixStream) -> Self {
340 UnixStream(socket)
341 }
342}
343
344#[cfg(unix)]
345wasi_stream_write_impl!(UnixStream, std::os::unix::net::UnixStream);
346
347pub fn filetype_from(ft: &cap_std::fs::FileType) -> FileType {
348 use cap_fs_ext::FileTypeExt;
349 if ft.is_block_device() {
350 FileType::SocketDgram
351 } else {
352 FileType::SocketStream
353 }
354}
355
/// Reports the WASI fd flags of a socket; the only flag ever reported is
/// `NONBLOCK`.
///
/// On non-Windows targets, the descriptor's flags are queried and `NONBLOCK`
/// is reported when the descriptor carries the non-blocking flag.
///
/// On Windows, a zero-length `recv` is issued instead. Success is reported as
/// "no flags", `WOULDBLOCK` as `NONBLOCK`, and any other error is returned.
///
/// # Errors
/// Returns the I/O error from the descriptor-flag query (non-Windows), or any
/// error from the probing `recv` other than `WOULDBLOCK` (Windows).
///
/// NOTE(review): the Windows probe is an approximation. Presumably a
/// non-blocking socket with data already pending makes the zero-length `recv`
/// succeed (and be reported as blocking). A blocking socket may also wait in
/// the probe. Confirm this is acceptable for callers.
pub fn get_fd_flags<Socketlike: AsSocketlike>(
    f: Socketlike,
) -> io::Result<wasi_common::file::FdFlags> {
    // Non-Windows: reuse the descriptor-flag query and translate the
    // `system_interface` NONBLOCK flag into the WASI one.
    #[cfg(not(windows))]
    {
        let mut out = wasi_common::file::FdFlags::empty();
        if f.get_fd_flags()?
            .contains(system_interface::fs::FdFlags::NONBLOCK)
        {
            out |= wasi_common::file::FdFlags::NONBLOCK;
        }
        Ok(out)
    }

    // Windows: infer non-blocking mode from whether a zero-length `recv`
    // fails with WOULDBLOCK.
    #[cfg(windows)]
    match rustix::net::recv(f, &mut [], rustix::net::RecvFlags::empty()) {
        Ok(_) => Ok(wasi_common::file::FdFlags::empty()),
        Err(rustix::io::Errno::WOULDBLOCK) => Ok(wasi_common::file::FdFlags::NONBLOCK),
        Err(e) => Err(e.into()),
    }
}
385
386pub fn is_read_write<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<(bool, bool)> {
390 #[cfg(not(windows))]
392 {
393 f.is_read_write()
394 }
395
396 #[cfg(windows)]
398 {
399 f.as_socketlike_view::<std::net::TcpStream>()
400 .is_read_write()
401 }
402}