webrtc_ice/udp_mux/
mod.rs1use std::collections::HashMap;
2use std::io::ErrorKind;
3use std::net::SocketAddr;
4use std::sync::{Arc, Weak};
5
6use async_trait::async_trait;
7use tokio::sync::{watch, Mutex};
8use util::sync::RwLock;
9use util::{Conn, Error};
10
11mod udp_mux_conn;
12pub use udp_mux_conn::{UDPMuxConn, UDPMuxConnParams, UDPMuxWriter};
13
14#[cfg(test)]
15mod udp_mux_test;
16
17mod socket_addr_ext;
18
19use stun::attributes::ATTR_USERNAME;
20use stun::message::{is_message as is_stun_message, Message as STUNMessage};
21
22use crate::candidate::RECEIVE_MTU;
23
/// Adapts `target` so it can be used as a destination on a socket bound to `socket_addr`.
///
/// When the local socket is IPv6 (e.g. a dual-stack socket) and the target is IPv4, the
/// target is rewritten to its IPv4-mapped IPv6 form (`::ffff:a.b.c.d`) with the same port.
/// Every other combination is returned unchanged — including an IPv6 target on an IPv4
/// socket, which presumably fails later when the send is attempted.
fn normalize_socket_addr(target: &SocketAddr, socket_addr: &SocketAddr) -> SocketAddr {
    match target {
        SocketAddr::V4(v4_target) if socket_addr.is_ipv6() => SocketAddr::new(
            std::net::IpAddr::V6(v4_target.ip().to_ipv6_mapped()),
            v4_target.port(),
        ),
        _ => *target,
    }
}
38
39#[async_trait]
40pub trait UDPMux {
41 async fn close(&self) -> Result<(), Error>;
43
44 async fn get_conn(self: Arc<Self>, ufrag: &str) -> Result<Arc<dyn Conn + Send + Sync>, Error>;
46
47 async fn remove_conn_by_ufrag(&self, ufrag: &str);
49}
50
51pub struct UDPMuxParams {
52 conn: Box<dyn Conn + Send + Sync>,
53}
54
55impl UDPMuxParams {
56 pub fn new<C>(conn: C) -> Self
57 where
58 C: Conn + Send + Sync + 'static,
59 {
60 Self {
61 conn: Box::new(conn),
62 }
63 }
64}
65
66pub struct UDPMuxDefault {
67 params: UDPMuxParams,
70
71 conns: Mutex<HashMap<String, UDPMuxConn>>,
73
74 address_map: RwLock<HashMap<SocketAddr, UDPMuxConn>>,
76
77 closed_watch_tx: Mutex<Option<watch::Sender<()>>>,
79
80 closed_watch_rx: watch::Receiver<()>,
82}
83
84impl UDPMuxDefault {
85 pub fn new(params: UDPMuxParams) -> Arc<Self> {
86 let (closed_watch_tx, closed_watch_rx) = watch::channel(());
87
88 let mux = Arc::new(Self {
89 params,
90 conns: Mutex::default(),
91 address_map: RwLock::default(),
92 closed_watch_tx: Mutex::new(Some(closed_watch_tx)),
93 closed_watch_rx: closed_watch_rx.clone(),
94 });
95
96 let cloned_mux = Arc::clone(&mux);
97 cloned_mux.start_conn_worker(closed_watch_rx);
98
99 mux
100 }
101
102 pub async fn is_closed(&self) -> bool {
103 self.closed_watch_tx.lock().await.is_none()
104 }
105
106 fn create_muxed_conn(self: &Arc<Self>, ufrag: &str) -> Result<UDPMuxConn, Error> {
108 let local_addr = self.params.conn.local_addr()?;
109
110 let params = UDPMuxConnParams {
111 local_addr,
112 key: ufrag.into(),
113 udp_mux: Arc::downgrade(self) as Weak<dyn UDPMuxWriter + Send + Sync>,
114 };
115
116 Ok(UDPMuxConn::new(params))
117 }
118
119 async fn conn_from_stun_message(&self, buffer: &[u8], addr: &SocketAddr) -> Option<UDPMuxConn> {
120 let (result, message) = {
121 let mut m = STUNMessage::new();
122
123 (m.unmarshal_binary(buffer), m)
124 };
125
126 match result {
127 Err(err) => {
128 log::warn!("Failed to handle decode ICE from {}: {}", addr, err);
129 None
130 }
131 Ok(_) => {
132 let (attr, found) = message.attributes.get(ATTR_USERNAME);
133 if !found {
134 log::warn!("No username attribute in STUN message from {}", &addr);
135 return None;
136 }
137
138 let s = match String::from_utf8(attr.value) {
139 Err(err) => {
142 log::warn!(
143 "Failed to decode USERNAME from STUN message as UTF-8: {}",
144 err
145 );
146 return None;
147 }
148 Ok(s) => s,
149 };
150
151 let conns = self.conns.lock().await;
152 let conn = s
153 .split(':')
154 .next()
155 .and_then(|ufrag| conns.get(ufrag))
156 .cloned();
157
158 conn
159 }
160 }
161 }
162
163 fn start_conn_worker(self: Arc<Self>, mut closed_watch_rx: watch::Receiver<()>) {
164 tokio::spawn(async move {
165 let mut buffer = [0u8; RECEIVE_MTU];
166
167 loop {
168 let loop_self = Arc::clone(&self);
169 let conn = &loop_self.params.conn;
170
171 tokio::select! {
172 res = conn.recv_from(&mut buffer) => {
173 match res {
174 Ok((len, addr)) => {
175 let conn = {
177 let address_map = loop_self
178 .address_map
179 .read();
180
181 address_map.get(&addr).cloned()
182 };
183
184 let conn = match conn {
185 None if is_stun_message(&buffer) => {
188 loop_self.conn_from_stun_message(&buffer, &addr).await
189 }
190 s @ Some(_) => s,
191 _ => None,
192 };
193
194 match conn {
195 None => {
196 log::trace!("Dropping packet from {}", &addr);
197 }
198 Some(conn) => {
199 if let Err(err) = conn.write_packet(&buffer[..len], addr).await {
200 log::error!("Failed to write packet: {}", err);
201 }
202 }
203 }
204 }
205 Err(Error::Io(err)) if err.0.kind() == ErrorKind::TimedOut => continue,
206 Err(err) => {
207 log::error!("Could not read udp packet: {}", err);
208 break;
209 }
210 }
211 }
212 _ = closed_watch_rx.changed() => {
213 return;
214 }
215 }
216 }
217 });
218 }
219}
220
221#[async_trait]
222impl UDPMux for UDPMuxDefault {
223 async fn close(&self) -> Result<(), Error> {
224 if self.is_closed().await {
225 return Err(Error::ErrAlreadyClosed);
226 }
227
228 let mut closed_tx = self.closed_watch_tx.lock().await;
229
230 if let Some(tx) = closed_tx.take() {
231 let _ = tx.send(());
232 drop(closed_tx);
233
234 let old_conns = {
235 let mut conns = self.conns.lock().await;
236
237 std::mem::take(&mut (*conns))
238 };
239
240 for (_, conn) in old_conns {
242 conn.close();
243 }
244
245 {
246 let mut address_map = self.address_map.write();
247
248 let _ = std::mem::take(&mut (*address_map));
251 }
252 }
253
254 Ok(())
255 }
256
257 async fn get_conn(self: Arc<Self>, ufrag: &str) -> Result<Arc<dyn Conn + Send + Sync>, Error> {
258 if self.is_closed().await {
259 return Err(Error::ErrUseClosedNetworkConn);
260 }
261
262 {
263 let mut conns = self.conns.lock().await;
264 if let Some(conn) = conns.get(ufrag) {
265 return Ok(Arc::new(conn.clone()) as Arc<dyn Conn + Send + Sync>);
268 }
269
270 let muxed_conn = self.create_muxed_conn(ufrag)?;
271 let mut close_rx = muxed_conn.close_rx();
272 let cloned_self = Arc::clone(&self);
273 let cloned_ufrag = ufrag.to_string();
274 tokio::spawn(async move {
275 let _ = close_rx.changed().await;
276
277 cloned_self.remove_conn_by_ufrag(&cloned_ufrag).await;
279 });
280
281 conns.insert(ufrag.into(), muxed_conn.clone());
282
283 Ok(Arc::new(muxed_conn) as Arc<dyn Conn + Send + Sync>)
284 }
285 }
286
287 async fn remove_conn_by_ufrag(&self, ufrag: &str) {
288 let removed_conn = {
292 let mut conns = self.conns.lock().await;
293 conns.remove(ufrag)
294 };
295
296 if let Some(conn) = removed_conn {
297 let mut address_map = self.address_map.write();
298
299 for address in conn.get_addresses() {
300 address_map.remove(&address);
301 }
302 }
303 }
304}
305
306#[async_trait]
307impl UDPMuxWriter for UDPMuxDefault {
308 async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr) {
309 if self.is_closed().await {
310 return;
311 }
312
313 let key = conn.key();
314 {
315 let mut addresses = self.address_map.write();
316
317 addresses
318 .entry(addr)
319 .and_modify(|e| {
320 if e.key() != key {
321 e.remove_address(&addr);
322 *e = conn.clone();
323 }
324 })
325 .or_insert_with(|| conn.clone());
326 }
327
328 log::debug!("Registered {} for {}", addr, key);
329 }
330
331 async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result<usize, Error> {
332 self.params
333 .conn
334 .send_to(buf, *target)
335 .await
336 .map_err(Into::into)
337 }
338}