// wasmtime_wasi/network.rs

use crate::bindings::sockets::network::{ErrorCode, Ipv4Address, Ipv6Address};
use crate::TrappableError;
use std::fmt;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;

8pub struct Network {
9    pub socket_addr_check: SocketAddrCheck,
10    pub allow_ip_name_lookup: bool,
11}
12
13impl Network {
14    pub async fn check_socket_addr(
15        &self,
16        addr: SocketAddr,
17        reason: SocketAddrUse,
18    ) -> std::io::Result<()> {
19        self.socket_addr_check.check(addr, reason).await
20    }
21}
22
23/// A check that will be called for each socket address that is used of whether the address is permitted.
24#[derive(Clone)]
25pub struct SocketAddrCheck(
26    pub(crate)  Arc<
27        dyn Fn(SocketAddr, SocketAddrUse) -> Pin<Box<dyn Future<Output = bool> + Send + Sync>>
28            + Send
29            + Sync,
30    >,
31);
32
33impl SocketAddrCheck {
34    pub async fn check(&self, addr: SocketAddr, reason: SocketAddrUse) -> std::io::Result<()> {
35        if (self.0)(addr, reason).await {
36            Ok(())
37        } else {
38            Err(std::io::Error::new(
39                std::io::ErrorKind::PermissionDenied,
40                "An address was not permitted by the socket address check.",
41            ))
42        }
43    }
44}
45
46impl Default for SocketAddrCheck {
47    fn default() -> Self {
48        Self(Arc::new(|_, _| Box::pin(async { false })))
49    }
50}
51
/// The purpose for which a socket address is being used.
///
/// Passed to [`SocketAddrCheck`] so a policy can, for example, allow binding but forbid
/// connecting. Comparable and hashable, so policies can test it with `==` or keep it in sets
/// and maps.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SocketAddrUse {
    /// Binding a TCP socket
    TcpBind,
    /// Connecting a TCP socket
    TcpConnect,
    /// Binding a UDP socket
    UdpBind,
    /// Connecting a UDP socket
    UdpConnect,
    /// Sending a datagram on a non-connected UDP socket
    UdpOutgoingDatagram,
}

67pub type SocketResult<T> = Result<T, SocketError>;
68
69pub type SocketError = TrappableError<ErrorCode>;
70
71impl From<wasmtime::component::ResourceTableError> for SocketError {
72    fn from(error: wasmtime::component::ResourceTableError) -> Self {
73        Self::trap(error)
74    }
75}
76
77impl From<std::io::Error> for SocketError {
78    fn from(error: std::io::Error) -> Self {
79        ErrorCode::from(error).into()
80    }
81}
82
83impl From<rustix::io::Errno> for SocketError {
84    fn from(error: rustix::io::Errno) -> Self {
85        ErrorCode::from(error).into()
86    }
87}
88
/// IP address family of a socket.
///
/// Derives the common comparison, hashing and debug traits so the family can be matched with
/// `==`, used as a map key, and printed in diagnostics.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum SocketAddressFamily {
    /// IPv4
    Ipv4,
    /// IPv6
    Ipv6,
}

95pub(crate) fn to_ipv4_addr(addr: Ipv4Address) -> std::net::Ipv4Addr {
96    let (x0, x1, x2, x3) = addr;
97    std::net::Ipv4Addr::new(x0, x1, x2, x3)
98}
99
100pub(crate) fn from_ipv4_addr(addr: std::net::Ipv4Addr) -> Ipv4Address {
101    let [x0, x1, x2, x3] = addr.octets();
102    (x0, x1, x2, x3)
103}
104
105pub(crate) fn to_ipv6_addr(addr: Ipv6Address) -> std::net::Ipv6Addr {
106    let (x0, x1, x2, x3, x4, x5, x6, x7) = addr;
107    std::net::Ipv6Addr::new(x0, x1, x2, x3, x4, x5, x6, x7)
108}
109
110pub(crate) fn from_ipv6_addr(addr: std::net::Ipv6Addr) -> Ipv6Address {
111    let [x0, x1, x2, x3, x4, x5, x6, x7] = addr.segments();
112    (x0, x1, x2, x3, x4, x5, x6, x7)
113}