1#[cfg(test)]
2mod util_test;
3
4use std::collections::HashSet;
5use std::net::{IpAddr, SocketAddr};
6use std::sync::Arc;
7
8use stun::agent::*;
9use stun::attributes::*;
10use stun::integrity::*;
11use stun::message::*;
12use stun::textattrs::*;
13use stun::xoraddr::*;
14use tokio::time::Duration;
15use util::vnet::net::*;
16use util::Conn;
17
18use crate::agent::agent_config::{InterfaceFilterFn, IpFilterFn};
19use crate::error::*;
20use crate::network_type::*;
21
22pub fn create_addr(_network: NetworkType, ip: IpAddr, port: u16) -> SocketAddr {
23 SocketAddr::new(ip, port)
29}
30
31pub fn assert_inbound_username(m: &Message, expected_username: &str) -> Result<()> {
32 let mut username = Username::new(ATTR_USERNAME, String::new());
33 username.get_from(m)?;
34
35 if username.to_string() != expected_username {
36 return Err(Error::Other(format!(
37 "{:?} expected({}) actual({})",
38 Error::ErrMismatchUsername,
39 expected_username,
40 username,
41 )));
42 }
43
44 Ok(())
45}
46
47pub fn assert_inbound_message_integrity(m: &mut Message, key: &[u8]) -> Result<()> {
48 let message_integrity_attr = MessageIntegrity(key.to_vec());
49 Ok(message_integrity_attr.check(m)?)
50}
51
52pub async fn get_xormapped_addr(
56 conn: &Arc<dyn Conn + Send + Sync>,
57 server_addr: SocketAddr,
58 deadline: Duration,
59) -> Result<XorMappedAddress> {
60 let resp = stun_request(conn, server_addr, deadline).await?;
61 let mut addr = XorMappedAddress::default();
62 addr.get_from(&resp)?;
63 Ok(addr)
64}
65
/// Size in bytes of the receive buffer that `stun_request` allocates for the STUN reply.
// NOTE(review): 1280 presumably mirrors the IPv6 minimum link MTU — confirm the intent.
const MAX_MESSAGE_SIZE: usize = 1280;
67
68pub async fn stun_request(
69 conn: &Arc<dyn Conn + Send + Sync>,
70 server_addr: SocketAddr,
71 deadline: Duration,
72) -> Result<Message> {
73 let mut request = Message::new();
74 request.build(&[Box::new(BINDING_REQUEST), Box::new(TransactionId::new())])?;
75
76 conn.send_to(&request.raw, server_addr).await?;
77 let mut bs = vec![0_u8; MAX_MESSAGE_SIZE];
78 let (n, _) = if deadline > Duration::from_secs(0) {
79 match tokio::time::timeout(deadline, conn.recv_from(&mut bs)).await {
80 Ok(result) => match result {
81 Ok((n, addr)) => (n, addr),
82 Err(err) => return Err(Error::Other(err.to_string())),
83 },
84 Err(err) => return Err(Error::Other(err.to_string())),
85 }
86 } else {
87 conn.recv_from(&mut bs).await?
88 };
89
90 let mut res = Message::new();
91 res.raw = bs[..n].to_vec();
92 res.decode()?;
93
94 Ok(res)
95}
96
97pub async fn local_interfaces(
98 vnet: &Arc<Net>,
99 interface_filter: &Option<InterfaceFilterFn>,
100 ip_filter: &Option<IpFilterFn>,
101 network_types: &[NetworkType],
102 include_loopback: bool,
103) -> HashSet<IpAddr> {
104 let mut ips = HashSet::new();
105 let interfaces = vnet.get_interfaces().await;
106
107 let (mut ipv4requested, mut ipv6requested) = (false, false);
108 for typ in network_types {
109 if typ.is_ipv4() {
110 ipv4requested = true;
111 }
112 if typ.is_ipv6() {
113 ipv6requested = true;
114 }
115 }
116
117 for iface in interfaces {
118 if let Some(filter) = interface_filter {
119 if !filter(iface.name()) {
120 continue;
121 }
122 }
123
124 for ipnet in iface.addrs() {
125 let ipaddr = ipnet.addr();
126
127 if (!ipaddr.is_loopback() || include_loopback)
128 && ((ipv4requested && ipaddr.is_ipv4()) || (ipv6requested && ipaddr.is_ipv6()))
129 && ip_filter
130 .as_ref()
131 .map(|filter| filter(ipaddr))
132 .unwrap_or(true)
133 {
134 ips.insert(ipaddr);
135 }
136 }
137 }
138
139 ips
140}
141
142pub async fn listen_udp_in_port_range(
143 vnet: &Arc<Net>,
144 port_max: u16,
145 port_min: u16,
146 laddr: SocketAddr,
147) -> Result<Arc<dyn Conn + Send + Sync>> {
148 if laddr.port() != 0 || (port_min == 0 && port_max == 0) {
149 return Ok(vnet.bind(laddr).await?);
150 }
151 let i = if port_min == 0 { 1 } else { port_min };
152 let j = if port_max == 0 { 0xFFFF } else { port_max };
153 if i > j {
154 return Err(Error::ErrPort);
155 }
156
157 let port_start = rand::random::<u16>() % (j - i + 1) + i;
158 let mut port_current = port_start;
159 loop {
160 let laddr = SocketAddr::new(laddr.ip(), port_current);
161 match vnet.bind(laddr).await {
162 Ok(c) => return Ok(c),
163 Err(err) => log::debug!("failed to listen {}: {}", laddr, err),
164 };
165
166 port_current += 1;
167 if port_current > j {
168 port_current = i;
169 }
170 if port_current == port_start {
171 break;
172 }
173 }
174
175 Err(Error::ErrPort)
176}