1#[cfg(test)]
2mod nat_test;
3
4use std::collections::{HashMap, HashSet};
5use std::net::IpAddr;
6use std::ops::Add;
7use std::sync::atomic::Ordering;
8use std::sync::Arc;
9use std::time::SystemTime;
10
11use portable_atomic::AtomicU16;
12use tokio::sync::Mutex;
13use tokio::time::Duration;
14
15use crate::error::*;
16use crate::vnet::chunk::Chunk;
17use crate::vnet::net::UDP_STR;
18
/// Lifetime given to a normal-mode NAT binding when the configured `mapping_life_time` is zero.
const DEFAULT_NAT_MAPPING_LIFE_TIME: Duration = Duration::from_millis(30_000);
20
/// Which attributes of the remote endpoint a NAT behaviour depends on.
///
/// Used both for the mapping behaviour (which remote endpoints share one NAT binding) and for
/// the filtering behaviour (which remote endpoints may send traffic back through a binding).
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum EndpointDependencyType {
    /// The remote endpoint is ignored entirely. This is the default.
    #[default]
    EndpointIndependent,
    /// Only the remote IP address matters.
    EndpointAddrDependent,
    /// Both the remote IP address and port matter.
    EndpointAddrPortDependent,
}
37
/// Operating mode of the virtual NAT.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum NatMode {
    /// Port-translating NAT that keeps per-binding mapping/filtering state. This is the default.
    #[default]
    Normal,
    /// Static one-to-one address NAT: each local IP is paired with the mapped IP at the same
    /// index and ports are preserved; the per-binding settings of `NatType` are overridden.
    Nat1To1,
}
50
51#[derive(Default, Debug, Copy, Clone)]
53pub struct NatType {
54 pub mode: NatMode,
55 pub mapping_behavior: EndpointDependencyType,
56 pub filtering_behavior: EndpointDependencyType,
57 pub hair_pining: bool, pub port_preservation: bool, pub mapping_life_time: Duration,
60}
61
62#[derive(Default, Debug, Clone)]
63pub(crate) struct NatConfig {
64 pub(crate) name: String,
65 pub(crate) nat_type: NatType,
66 pub(crate) mapped_ips: Vec<IpAddr>, pub(crate) local_ips: Vec<IpAddr>, }
69
70#[derive(Debug, Clone)]
71pub(crate) struct Mapping {
72 proto: String, local: String, mapped: String, bound: String, filters: Arc<Mutex<HashSet<String>>>, expires: Arc<Mutex<SystemTime>>, }
79
80impl Default for Mapping {
81 fn default() -> Self {
82 Mapping {
83 proto: String::new(), local: String::new(), mapped: String::new(), bound: String::new(), filters: Arc::new(Mutex::new(HashSet::new())), expires: Arc::new(Mutex::new(SystemTime::now())), }
90 }
91}
92
93#[derive(Default, Debug, Clone)]
94pub(crate) struct NetworkAddressTranslator {
95 pub(crate) name: String,
96 pub(crate) nat_type: NatType,
97 pub(crate) mapped_ips: Vec<IpAddr>, pub(crate) local_ips: Vec<IpAddr>, pub(crate) outbound_map: Arc<Mutex<HashMap<String, Arc<Mapping>>>>, pub(crate) inbound_map: Arc<Mutex<HashMap<String, Arc<Mapping>>>>, pub(crate) udp_port_counter: Arc<AtomicU16>,
102}
103
104impl NetworkAddressTranslator {
105 pub(crate) fn new(config: NatConfig) -> Result<Self> {
106 let mut nat_type = config.nat_type;
107
108 if nat_type.mode == NatMode::Nat1To1 {
109 nat_type.mapping_behavior = EndpointDependencyType::EndpointIndependent;
111 nat_type.filtering_behavior = EndpointDependencyType::EndpointIndependent;
112 nat_type.port_preservation = true;
113 nat_type.mapping_life_time = Duration::from_secs(0);
114
115 if config.mapped_ips.is_empty() {
116 return Err(Error::ErrNatRequiresMapping);
117 }
118 if config.mapped_ips.len() != config.local_ips.len() {
119 return Err(Error::ErrMismatchLengthIp);
120 }
121 } else {
122 nat_type.mode = NatMode::Normal;
124 if nat_type.mapping_life_time == Duration::from_secs(0) {
125 nat_type.mapping_life_time = DEFAULT_NAT_MAPPING_LIFE_TIME;
126 }
127 }
128
129 Ok(NetworkAddressTranslator {
130 name: config.name,
131 nat_type,
132 mapped_ips: config.mapped_ips,
133 local_ips: config.local_ips,
134 outbound_map: Arc::new(Mutex::new(HashMap::new())),
135 inbound_map: Arc::new(Mutex::new(HashMap::new())),
136 udp_port_counter: Arc::new(AtomicU16::new(0)),
137 })
138 }
139
140 pub(crate) fn get_paired_mapped_ip(&self, loc_ip: &IpAddr) -> Option<&IpAddr> {
141 for (i, ip) in self.local_ips.iter().enumerate() {
142 if ip == loc_ip {
143 return self.mapped_ips.get(i);
144 }
145 }
146 None
147 }
148
149 pub(crate) fn get_paired_local_ip(&self, mapped_ip: &IpAddr) -> Option<&IpAddr> {
150 for (i, ip) in self.mapped_ips.iter().enumerate() {
151 if ip == mapped_ip {
152 return self.local_ips.get(i);
153 }
154 }
155 None
156 }
157
158 pub(crate) async fn translate_outbound(
159 &self,
160 from: &(dyn Chunk + Send + Sync),
161 ) -> Result<Option<Box<dyn Chunk + Send + Sync>>> {
162 let mut to = from.clone_to();
163
164 if from.network() == UDP_STR {
165 if self.nat_type.mode == NatMode::Nat1To1 {
166 let src_addr = from.source_addr();
168 if let Some(src_ip) = self.get_paired_mapped_ip(&src_addr.ip()) {
169 to.set_source_addr(&format!("{}:{}", src_ip, src_addr.port()))?;
170 } else {
171 log::debug!(
172 "[{}] drop outbound chunk {} with not route",
173 self.name,
174 from
175 );
176 return Ok(None); }
178 } else {
179 let bound = match self.nat_type.mapping_behavior {
181 EndpointDependencyType::EndpointIndependent => "".to_owned(),
182 EndpointDependencyType::EndpointAddrDependent => {
183 from.get_destination_ip().to_string()
184 }
185 EndpointDependencyType::EndpointAddrPortDependent => {
186 from.destination_addr().to_string()
187 }
188 };
189
190 let filter_key = match self.nat_type.filtering_behavior {
191 EndpointDependencyType::EndpointIndependent => "".to_owned(),
192 EndpointDependencyType::EndpointAddrDependent => {
193 from.get_destination_ip().to_string()
194 }
195 EndpointDependencyType::EndpointAddrPortDependent => {
196 from.destination_addr().to_string()
197 }
198 };
199
200 let o_key = format!("udp:{}:{}", from.source_addr(), bound);
201 let name = self.name.clone();
202
203 let m_mapped = if let Some(m) = self.find_outbound_mapping(&o_key).await {
204 let mut filters = m.filters.lock().await;
205 if !filters.contains(&filter_key) {
206 log::debug!(
207 "[{}] permit access from {} to {}",
208 name,
209 filter_key,
210 m.mapped
211 );
212 filters.insert(filter_key);
213 }
214 m.mapped.clone()
215 } else {
216 let udp_port_counter = self.udp_port_counter.load(Ordering::SeqCst);
218 let mapped_port = 0xC000 + udp_port_counter;
219 if udp_port_counter == 0xFFFF - 0xC000 {
220 self.udp_port_counter.store(0, Ordering::SeqCst);
221 } else {
222 self.udp_port_counter.fetch_add(1, Ordering::SeqCst);
223 }
224
225 let m = if let Some(mapped_ips_first) = self.mapped_ips.first() {
226 Mapping {
227 proto: "udp".to_owned(),
228 local: from.source_addr().to_string(),
229 bound,
230 mapped: format!("{mapped_ips_first}:{mapped_port}"),
231 filters: Arc::new(Mutex::new(HashSet::new())),
232 expires: Arc::new(Mutex::new(
233 SystemTime::now().add(self.nat_type.mapping_life_time),
234 )),
235 }
236 } else {
237 return Err(Error::ErrNatRequiresMapping);
238 };
239
240 {
241 let mut outbound_map = self.outbound_map.lock().await;
242 outbound_map.insert(o_key.clone(), Arc::new(m.clone()));
243 }
244
245 let i_key = format!("udp:{}", m.mapped);
246
247 log::debug!(
248 "[{}] created a new NAT binding oKey={} i_key={}",
249 self.name,
250 o_key,
251 i_key
252 );
253 log::debug!(
254 "[{}] permit access from {} to {}",
255 self.name,
256 filter_key,
257 m.mapped
258 );
259
260 {
261 let mut filters = m.filters.lock().await;
262 filters.insert(filter_key);
263 }
264
265 let m_mapped = m.mapped.clone();
266 {
267 let mut inbound_map = self.inbound_map.lock().await;
268 inbound_map.insert(i_key, Arc::new(m));
269 }
270 m_mapped
271 };
272
273 to.set_source_addr(&m_mapped)?;
274 }
275
276 log::debug!(
277 "[{}] translate outbound chunk from {} to {}",
278 self.name,
279 from,
280 to
281 );
282
283 return Ok(Some(to));
284 }
285
286 Err(Error::ErrNonUdpTranslationNotSupported)
287 }
288
289 pub(crate) async fn translate_inbound(
290 &self,
291 from: &(dyn Chunk + Send + Sync),
292 ) -> Result<Option<Box<dyn Chunk + Send + Sync>>> {
293 let mut to = from.clone_to();
294
295 if from.network() == UDP_STR {
296 if self.nat_type.mode == NatMode::Nat1To1 {
297 let dst_addr = from.destination_addr();
299 if let Some(dst_ip) = self.get_paired_local_ip(&dst_addr.ip()) {
300 let dst_port = from.destination_addr().port();
301 to.set_destination_addr(&format!("{dst_ip}:{dst_port}"))?;
302 } else {
303 return Err(Error::Other(format!(
304 "drop {from} as {:?}",
305 Error::ErrNoAssociatedLocalAddress
306 )));
307 }
308 } else {
309 let filter_key = match self.nat_type.filtering_behavior {
311 EndpointDependencyType::EndpointIndependent => "".to_owned(),
312 EndpointDependencyType::EndpointAddrDependent => {
313 from.get_source_ip().to_string()
314 }
315 EndpointDependencyType::EndpointAddrPortDependent => {
316 from.source_addr().to_string()
317 }
318 };
319
320 let i_key = format!("udp:{}", from.destination_addr());
321 if let Some(m) = self.find_inbound_mapping(&i_key).await {
322 {
323 let filters = m.filters.lock().await;
324 if !filters.contains(&filter_key) {
325 return Err(Error::Other(format!(
326 "drop {} as the remote {} {:?}",
327 from,
328 filter_key,
329 Error::ErrHasNoPermission
330 )));
331 }
332 }
333
334 to.set_destination_addr(&m.local)?;
343 } else {
344 return Err(Error::Other(format!(
345 "drop {} as {:?}",
346 from,
347 Error::ErrNoNatBindingFound
348 )));
349 }
350 }
351
352 log::debug!(
353 "[{}] translate inbound chunk from {} to {}",
354 self.name,
355 from,
356 to
357 );
358
359 return Ok(Some(to));
360 }
361
362 Err(Error::ErrNonUdpTranslationNotSupported)
363 }
364
365 pub(crate) async fn find_outbound_mapping(&self, o_key: &str) -> Option<Arc<Mapping>> {
367 let mapping_life_time = self.nat_type.mapping_life_time;
368 let mut expired = false;
369 let (in_key, out_key) = {
370 let outbound_map = self.outbound_map.lock().await;
371 if let Some(m) = outbound_map.get(o_key) {
372 let now = SystemTime::now();
373
374 {
375 let mut expires = m.expires.lock().await;
376 if now.duration_since(*expires).is_ok() {
378 expired = true;
379 } else {
380 *expires = now.add(mapping_life_time);
381 }
382 }
383 (
384 NetworkAddressTranslator::get_inbound_map_key(m),
385 NetworkAddressTranslator::get_outbound_map_key(m),
386 )
387 } else {
388 (String::new(), String::new())
389 }
390 };
391
392 if expired {
393 {
394 let mut inbound_map = self.inbound_map.lock().await;
395 inbound_map.remove(&in_key);
396 }
397 {
398 let mut outbound_map = self.outbound_map.lock().await;
399 outbound_map.remove(&out_key);
400 }
401 }
402
403 let outbound_map = self.outbound_map.lock().await;
404 outbound_map.get(o_key).cloned()
405 }
406
407 pub(crate) async fn find_inbound_mapping(&self, i_key: &str) -> Option<Arc<Mapping>> {
409 let mut expired = false;
410 let (in_key, out_key) = {
411 let inbound_map = self.inbound_map.lock().await;
412 if let Some(m) = inbound_map.get(i_key) {
413 let now = SystemTime::now();
414
415 {
416 let expires = m.expires.lock().await;
417 if now.duration_since(*expires).is_ok() {
419 expired = true;
420 }
421 }
422 (
423 NetworkAddressTranslator::get_inbound_map_key(m),
424 NetworkAddressTranslator::get_outbound_map_key(m),
425 )
426 } else {
427 (String::new(), String::new())
428 }
429 };
430
431 if expired {
432 {
433 let mut inbound_map = self.inbound_map.lock().await;
434 inbound_map.remove(&in_key);
435 }
436 {
437 let mut outbound_map = self.outbound_map.lock().await;
438 outbound_map.remove(&out_key);
439 }
440 }
441
442 let inbound_map = self.inbound_map.lock().await;
443 inbound_map.get(i_key).cloned()
444 }
445
446 fn get_outbound_map_key(m: &Mapping) -> String {
448 format!("{}:{}:{}", m.proto, m.local, m.bound)
449 }
450
451 fn get_inbound_map_key(m: &Mapping) -> String {
452 format!("{}:{}", m.proto, m.mapped)
453 }
454
455 async fn inbound_map_len(&self) -> usize {
456 let inbound_map = self.inbound_map.lock().await;
457 inbound_map.len()
458 }
459
460 async fn outbound_map_len(&self) -> usize {
461 let outbound_map = self.outbound_map.lock().await;
462 outbound_map.len()
463 }
464}