1use crate::{
2 file::{FdFlags, FileType, RiFlags, RoFlags, SdFlags, SiFlags, WasiFile},
3 Error, ErrorExt,
4};
5#[cfg(windows)]
6use io_extras::os::windows::{AsRawHandleOrSocket, RawHandleOrSocket};
7use io_lifetimes::AsSocketlike;
8#[cfg(unix)]
9use io_lifetimes::{AsFd, BorrowedFd};
10#[cfg(windows)]
11use io_lifetimes::{AsSocket, BorrowedSocket};
12use std::any::Any;
13use std::io;
14#[cfg(unix)]
15use system_interface::fs::GetSetFdFlags;
16use system_interface::io::IoExt;
17use system_interface::io::IsReadWrite;
18use system_interface::io::ReadReady;
19
20pub enum Socket {
21 TcpListener(cap_std::net::TcpListener),
22 TcpStream(cap_std::net::TcpStream),
23 #[cfg(unix)]
24 UnixStream(cap_std::os::unix::net::UnixStream),
25 #[cfg(unix)]
26 UnixListener(cap_std::os::unix::net::UnixListener),
27}
28
29impl From<cap_std::net::TcpListener> for Socket {
30 fn from(listener: cap_std::net::TcpListener) -> Self {
31 Self::TcpListener(listener)
32 }
33}
34
35impl From<cap_std::net::TcpStream> for Socket {
36 fn from(stream: cap_std::net::TcpStream) -> Self {
37 Self::TcpStream(stream)
38 }
39}
40
41#[cfg(unix)]
42impl From<cap_std::os::unix::net::UnixListener> for Socket {
43 fn from(listener: cap_std::os::unix::net::UnixListener) -> Self {
44 Self::UnixListener(listener)
45 }
46}
47
48#[cfg(unix)]
49impl From<cap_std::os::unix::net::UnixStream> for Socket {
50 fn from(stream: cap_std::os::unix::net::UnixStream) -> Self {
51 Self::UnixStream(stream)
52 }
53}
54
55#[cfg(unix)]
56impl From<Socket> for Box<dyn WasiFile> {
57 fn from(listener: Socket) -> Self {
58 match listener {
59 Socket::TcpListener(l) => Box::new(crate::sync::net::TcpListener::from_cap_std(l)),
60 Socket::UnixListener(l) => Box::new(crate::sync::net::UnixListener::from_cap_std(l)),
61 Socket::TcpStream(l) => Box::new(crate::sync::net::TcpStream::from_cap_std(l)),
62 Socket::UnixStream(l) => Box::new(crate::sync::net::UnixStream::from_cap_std(l)),
63 }
64 }
65}
66
67#[cfg(windows)]
68impl From<Socket> for Box<dyn WasiFile> {
69 fn from(listener: Socket) -> Self {
70 match listener {
71 Socket::TcpListener(l) => Box::new(crate::sync::net::TcpListener::from_cap_std(l)),
72 Socket::TcpStream(l) => Box::new(crate::sync::net::TcpStream::from_cap_std(l)),
73 }
74 }
75}
76
/// Implements `WasiFile` and the platform handle traits for a listening
/// socket newtype.
///
/// * `$ty` — the listener wrapper; its field `.0` is the wrapped cap-std
///   listener.
/// * `$stream` — the wrapper type accepted connections are returned as; it
///   must provide `from_cap_std` and implement `WasiFile`.
macro_rules! wasi_listen_write_impl {
    ($ty:ty, $stream:ty) => {
        #[wiggle::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }
            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd> {
                let fd = self.0.as_fd();
                Some(fd)
            }
            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                let raw = self.0.as_raw_handle_or_socket();
                Some(raw)
            }
            /// Accepts one connection, wraps it as `$stream`, and applies
            /// `fdflags` to it before handing it back.
            async fn sock_accept(&self, fdflags: FdFlags) -> Result<Box<dyn WasiFile>, Error> {
                // The peer address returned by `accept` is not exposed here.
                let (accepted, _peer_addr) = self.0.accept()?;
                let mut wrapped = <$stream>::from_cap_std(accepted);
                wrapped.set_fdflags(fdflags).await?;
                Ok(Box::new(wrapped))
            }
            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }
            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                Ok(get_fd_flags(&self.0)?)
            }
            /// Only `NONBLOCK` (switch to nonblocking) or no flags (switch
            /// to blocking) are accepted; anything else is rejected.
            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                let nonblocking = if fdflags == crate::file::FdFlags::NONBLOCK {
                    true
                } else if fdflags.is_empty() {
                    false
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                };
                self.0.set_nonblocking(nonblocking)?;
                Ok(())
            }
            // NOTE(review): listeners always report one ready byte —
            // presumably so readiness checks treat them as readable; confirm.
            fn num_ready_bytes(&self) -> Result<u64, Error> {
                Ok(1)
            }
        }

        #[cfg(unix)]
        impl AsFd for $ty {
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            #[inline]
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }
    };
}
147
148pub struct TcpListener(cap_std::net::TcpListener);
149
150impl TcpListener {
151 pub fn from_cap_std(cap_std: cap_std::net::TcpListener) -> Self {
152 TcpListener(cap_std)
153 }
154}
155wasi_listen_write_impl!(TcpListener, TcpStream);
156
157#[cfg(unix)]
158pub struct UnixListener(cap_std::os::unix::net::UnixListener);
159
160#[cfg(unix)]
161impl UnixListener {
162 pub fn from_cap_std(cap_std: cap_std::os::unix::net::UnixListener) -> Self {
163 UnixListener(cap_std)
164 }
165}
166
167#[cfg(unix)]
168wasi_listen_write_impl!(UnixListener, UnixStream);
169
/// Implements `WasiFile` and the platform handle traits for a connected
/// stream socket newtype.
///
/// * `$ty` — the stream wrapper; its field `.0` is the wrapped cap-std stream.
/// * `$std_ty` — the `std` stream type used to build a borrowed socketlike
///   view for vectored I/O and ready-byte queries.
macro_rules! wasi_stream_write_impl {
    ($ty:ty, $std_ty:ty) => {
        #[wiggle::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }
            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd> {
                Some(self.0.as_fd())
            }
            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                Some(self.0.as_raw_handle_or_socket())
            }
            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }
            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                let fdflags = get_fd_flags(&self.0)?;
                Ok(fdflags)
            }
            /// Only `NONBLOCK` (switch to nonblocking) or no flags (switch
            /// to blocking) are accepted; anything else is rejected.
            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                if fdflags == crate::file::FdFlags::NONBLOCK {
                    self.0.set_nonblocking(true)?;
                } else if fdflags.is_empty() {
                    self.0.set_nonblocking(false)?;
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                }
                Ok(())
            }
            async fn read_vectored<'a>(
                &self,
                bufs: &mut [io::IoSliceMut<'a>],
            ) -> Result<u64, Error> {
                use std::io::Read;
                // Read through a borrowed `std` view so `&self` suffices.
                let n = Read::read_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            async fn write_vectored<'a>(&self, bufs: &[io::IoSlice<'a>]) -> Result<u64, Error> {
                use std::io::Write;
                let n = Write::write_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            async fn peek(&self, buf: &mut [u8]) -> Result<u64, Error> {
                let n = self.0.peek(buf)?;
                Ok(n.try_into()?)
            }
            fn num_ready_bytes(&self) -> Result<u64, Error> {
                let val = self.as_socketlike_view::<$std_ty>().num_ready_bytes()?;
                Ok(val)
            }
            /// Succeeds if the socket is reported readable, else `Error::io()`.
            async fn readable(&self) -> Result<(), Error> {
                let (readable, _writeable) = is_read_write(&self.0)?;
                if readable {
                    Ok(())
                } else {
                    Err(Error::io())
                }
            }
            /// Succeeds if the socket is reported writable, else `Error::io()`.
            async fn writable(&self) -> Result<(), Error> {
                let (_readable, writeable) = is_read_write(&self.0)?;
                if writeable {
                    Ok(())
                } else {
                    Err(Error::io())
                }
            }

            /// Receives into `ri_data`.
            ///
            /// Supported flags: `RECV_PEEK` (peek into a single buffer
            /// without consuming data) and `RECV_WAITALL` (fill every
            /// buffer completely). Any other flag yields `not_supported`.
            async fn sock_recv<'a>(
                &self,
                ri_data: &mut [std::io::IoSliceMut<'a>],
                ri_flags: RiFlags,
            ) -> Result<(u64, RoFlags), Error> {
                if (ri_flags & !(RiFlags::RECV_PEEK | RiFlags::RECV_WAITALL)) != RiFlags::empty() {
                    return Err(Error::not_supported());
                }

                if ri_flags.contains(RiFlags::RECV_PEEK) {
                    // Peek into the first *non-empty* buffer, following the
                    // convention of std's default `read_vectored`. Peeking
                    // into a leading zero-length buffer would report 0 bytes,
                    // which callers cannot distinguish from end-of-stream.
                    let n = match ri_data.iter_mut().find(|buf| !buf.is_empty()) {
                        Some(buf) => self.0.peek(buf)?,
                        None => 0,
                    };
                    return Ok((n as u64, RoFlags::empty()));
                }

                if ri_flags.contains(RiFlags::RECV_WAITALL) {
                    // Every buffer is filled completely on success, so the
                    // byte count is the total buffer capacity.
                    let n: usize = ri_data.iter().map(|buf| buf.len()).sum();
                    self.0.read_exact_vectored(ri_data)?;
                    return Ok((n as u64, RoFlags::empty()));
                }

                let n = self.0.read_vectored(ri_data)?;
                Ok((n as u64, RoFlags::empty()))
            }

            /// Sends `si_data`; no send flags are supported.
            async fn sock_send<'a>(
                &self,
                si_data: &[std::io::IoSlice<'a>],
                si_flags: SiFlags,
            ) -> Result<u64, Error> {
                if si_flags != SiFlags::empty() {
                    return Err(Error::not_supported());
                }

                let n = self.0.write_vectored(si_data)?;
                Ok(n as u64)
            }

            /// Shuts down the read half, write half, or both, according to
            /// `how`. Any other flag combination (including none) is an
            /// invalid argument.
            async fn sock_shutdown(&self, how: SdFlags) -> Result<(), Error> {
                let how = if how == SdFlags::RD | SdFlags::WR {
                    cap_std::net::Shutdown::Both
                } else if how == SdFlags::RD {
                    cap_std::net::Shutdown::Read
                } else if how == SdFlags::WR {
                    cap_std::net::Shutdown::Write
                } else {
                    return Err(Error::invalid_argument());
                };
                self.0.shutdown(how)?;
                Ok(())
            }
        }

        #[cfg(unix)]
        impl AsFd for $ty {
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            #[inline]
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        // Implemented for `$ty` (previously hard-coded to `TcpStream`, which
        // would conflict or target the wrong type for any other instantiation).
        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }
    };
}
322
323pub struct TcpStream(cap_std::net::TcpStream);
324
325impl TcpStream {
326 pub fn from_cap_std(socket: cap_std::net::TcpStream) -> Self {
327 TcpStream(socket)
328 }
329}
330
331wasi_stream_write_impl!(TcpStream, std::net::TcpStream);
332
333#[cfg(unix)]
334pub struct UnixStream(cap_std::os::unix::net::UnixStream);
335
336#[cfg(unix)]
337impl UnixStream {
338 pub fn from_cap_std(socket: cap_std::os::unix::net::UnixStream) -> Self {
339 UnixStream(socket)
340 }
341}
342
343#[cfg(unix)]
344wasi_stream_write_impl!(UnixStream, std::os::unix::net::UnixStream);
345
346pub fn filetype_from(ft: &cap_std::fs::FileType) -> FileType {
347 use cap_fs_ext::FileTypeExt;
348 if ft.is_block_device() {
349 FileType::SocketDgram
350 } else {
351 FileType::SocketStream
352 }
353}
354
355pub fn get_fd_flags<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<crate::file::FdFlags> {
359 #[cfg(not(windows))]
362 {
363 let mut out = crate::file::FdFlags::empty();
364 if f.get_fd_flags()?
365 .contains(system_interface::fs::FdFlags::NONBLOCK)
366 {
367 out |= crate::file::FdFlags::NONBLOCK;
368 }
369 Ok(out)
370 }
371
372 #[cfg(windows)]
376 match rustix::net::recv(f, &mut [], rustix::net::RecvFlags::empty()) {
377 Ok(_) => Ok(crate::file::FdFlags::empty()),
378 Err(rustix::io::Errno::WOULDBLOCK) => Ok(crate::file::FdFlags::NONBLOCK),
379 Err(e) => Err(e.into()),
380 }
381}
382
383pub fn is_read_write<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<(bool, bool)> {
387 #[cfg(not(windows))]
389 {
390 f.is_read_write()
391 }
392
393 #[cfg(windows)]
395 {
396 f.as_socketlike_view::<std::net::TcpStream>()
397 .is_read_write()
398 }
399}