webrtc_ice/udp_mux/
mod.rs

1use std::collections::HashMap;
2use std::io::ErrorKind;
3use std::net::SocketAddr;
4use std::sync::{Arc, Weak};
5
6use async_trait::async_trait;
7use tokio::sync::{watch, Mutex};
8use util::sync::RwLock;
9use util::{Conn, Error};
10
11mod udp_mux_conn;
12pub use udp_mux_conn::{UDPMuxConn, UDPMuxConnParams, UDPMuxWriter};
13
14#[cfg(test)]
15mod udp_mux_test;
16
17mod socket_addr_ext;
18
19use stun::attributes::ATTR_USERNAME;
20use stun::message::{is_message as is_stun_message, Message as STUNMessage};
21
22use crate::candidate::RECEIVE_MTU;
23
/// Adapt `target` so it can be sent to from a socket bound to `socket_addr`.
///
/// When the local socket is IPv6 (e.g. a dual-stack socket) and the target is IPv4, the target
/// is rewritten to its IPv4-mapped IPv6 form (`::ffff:a.b.c.d`) with the same port. Every other
/// combination is returned unchanged; an IPv6 target on an IPv4 socket is left for the send to
/// reject later.
fn normalize_socket_addr(target: &SocketAddr, socket_addr: &SocketAddr) -> SocketAddr {
    match target {
        SocketAddr::V4(v4_target) if socket_addr.is_ipv6() => {
            SocketAddr::from((v4_target.ip().to_ipv6_mapped(), v4_target.port()))
        }
        unchanged => *unchanged,
    }
}
38
/// Multiplexes one underlying connection into several logical connections, one per ufrag.
#[async_trait]
pub trait UDPMux {
    /// Close the muxing.
    ///
    /// NOTE(review): `UDPMuxDefault` returns `ErrAlreadyClosed` when called on an
    /// already-closed mux; other implementations may differ.
    async fn close(&self) -> Result<(), Error>;

    /// Get the connection associated with a given ufrag.
    ///
    /// NOTE(review): `UDPMuxDefault` creates the connection on first request and returns the
    /// existing one afterwards.
    async fn get_conn(self: Arc<Self>, ufrag: &str) -> Result<Arc<dyn Conn + Send + Sync>, Error>;

    /// Remove the connection for a given ufrag (no-op if there is none).
    async fn remove_conn_by_ufrag(&self, ufrag: &str);
}
50
51pub struct UDPMuxParams {
52    conn: Box<dyn Conn + Send + Sync>,
53}
54
55impl UDPMuxParams {
56    pub fn new<C>(conn: C) -> Self
57    where
58        C: Conn + Send + Sync + 'static,
59    {
60        Self {
61            conn: Box::new(conn),
62        }
63    }
64}
65
/// Default [`UDPMux`] implementation: a background worker reads datagrams from one shared
/// connection and routes each to the [`UDPMuxConn`] it belongs to.
pub struct UDPMuxDefault {
    /// The params this instance is configured with.
    /// Contains the underlying UDP socket in use.
    params: UDPMuxParams,

    /// Maps from ufrag to the underlying connection.
    conns: Mutex<HashMap<String, UDPMuxConn>>,

    /// Maps from remote socket address (IP and port) to the connection registered for it;
    /// incoming packets are routed by their source address through this map first.
    address_map: RwLock<HashMap<SocketAddr, UDPMuxConn>>,

    /// Close sender; `None` once the mux has been closed.
    closed_watch_tx: Mutex<Option<watch::Sender<()>>>,

    /// Close receiver (a clone of the one given to the receive worker in `new`).
    closed_watch_rx: watch::Receiver<()>,
}
83
84impl UDPMuxDefault {
85    pub fn new(params: UDPMuxParams) -> Arc<Self> {
86        let (closed_watch_tx, closed_watch_rx) = watch::channel(());
87
88        let mux = Arc::new(Self {
89            params,
90            conns: Mutex::default(),
91            address_map: RwLock::default(),
92            closed_watch_tx: Mutex::new(Some(closed_watch_tx)),
93            closed_watch_rx: closed_watch_rx.clone(),
94        });
95
96        let cloned_mux = Arc::clone(&mux);
97        cloned_mux.start_conn_worker(closed_watch_rx);
98
99        mux
100    }
101
102    pub async fn is_closed(&self) -> bool {
103        self.closed_watch_tx.lock().await.is_none()
104    }
105
106    /// Create a muxed connection for a given ufrag.
107    fn create_muxed_conn(self: &Arc<Self>, ufrag: &str) -> Result<UDPMuxConn, Error> {
108        let local_addr = self.params.conn.local_addr()?;
109
110        let params = UDPMuxConnParams {
111            local_addr,
112            key: ufrag.into(),
113            udp_mux: Arc::downgrade(self) as Weak<dyn UDPMuxWriter + Send + Sync>,
114        };
115
116        Ok(UDPMuxConn::new(params))
117    }
118
119    async fn conn_from_stun_message(&self, buffer: &[u8], addr: &SocketAddr) -> Option<UDPMuxConn> {
120        let (result, message) = {
121            let mut m = STUNMessage::new();
122
123            (m.unmarshal_binary(buffer), m)
124        };
125
126        match result {
127            Err(err) => {
128                log::warn!("Failed to handle decode ICE from {}: {}", addr, err);
129                None
130            }
131            Ok(_) => {
132                let (attr, found) = message.attributes.get(ATTR_USERNAME);
133                if !found {
134                    log::warn!("No username attribute in STUN message from {}", &addr);
135                    return None;
136                }
137
138                let s = match String::from_utf8(attr.value) {
139                    // Per the RFC this shouldn't happen
140                    // https://datatracker.ietf.org/doc/html/rfc5389#section-15.3
141                    Err(err) => {
142                        log::warn!(
143                            "Failed to decode USERNAME from STUN message as UTF-8: {}",
144                            err
145                        );
146                        return None;
147                    }
148                    Ok(s) => s,
149                };
150
151                let conns = self.conns.lock().await;
152                let conn = s
153                    .split(':')
154                    .next()
155                    .and_then(|ufrag| conns.get(ufrag))
156                    .cloned();
157
158                conn
159            }
160        }
161    }
162
163    fn start_conn_worker(self: Arc<Self>, mut closed_watch_rx: watch::Receiver<()>) {
164        tokio::spawn(async move {
165            let mut buffer = [0u8; RECEIVE_MTU];
166
167            loop {
168                let loop_self = Arc::clone(&self);
169                let conn = &loop_self.params.conn;
170
171                tokio::select! {
172                    res = conn.recv_from(&mut buffer) => {
173                        match res {
174                            Ok((len, addr)) => {
175                                // Find connection based on previously having seen this source address
176                                let conn = {
177                                    let address_map = loop_self
178                                        .address_map
179                                        .read();
180
181                                    address_map.get(&addr).cloned()
182                                };
183
184                                let conn = match conn {
185                                    // If we couldn't find the connection based on source address, see if
186                                    // this is a STUN message and if so if we can find the connection based on ufrag.
187                                    None if is_stun_message(&buffer) => {
188                                        loop_self.conn_from_stun_message(&buffer, &addr).await
189                                    }
190                                    s @ Some(_) => s,
191                                    _ => None,
192                                };
193
194                                match conn {
195                                    None => {
196                                        log::trace!("Dropping packet from {}", &addr);
197                                    }
198                                    Some(conn) => {
199                                        if let Err(err) = conn.write_packet(&buffer[..len], addr).await {
200                                            log::error!("Failed to write packet: {}", err);
201                                        }
202                                    }
203                                }
204                            }
205                            Err(Error::Io(err)) if err.0.kind() == ErrorKind::TimedOut => continue,
206                            Err(err) => {
207                                log::error!("Could not read udp packet: {}", err);
208                                break;
209                            }
210                        }
211                    }
212                    _ = closed_watch_rx.changed() => {
213                        return;
214                    }
215                }
216            }
217        });
218    }
219}
220
221#[async_trait]
222impl UDPMux for UDPMuxDefault {
223    async fn close(&self) -> Result<(), Error> {
224        if self.is_closed().await {
225            return Err(Error::ErrAlreadyClosed);
226        }
227
228        let mut closed_tx = self.closed_watch_tx.lock().await;
229
230        if let Some(tx) = closed_tx.take() {
231            let _ = tx.send(());
232            drop(closed_tx);
233
234            let old_conns = {
235                let mut conns = self.conns.lock().await;
236
237                std::mem::take(&mut (*conns))
238            };
239
240            // NOTE: We don't wait for these closure to complete
241            for (_, conn) in old_conns {
242                conn.close();
243            }
244
245            {
246                let mut address_map = self.address_map.write();
247
248                // NOTE: This is important, we need to drop all instances of `UDPMuxConn` to
249                // avoid a retain cycle due to the use of [`std::sync::Arc`] on both sides.
250                let _ = std::mem::take(&mut (*address_map));
251            }
252        }
253
254        Ok(())
255    }
256
257    async fn get_conn(self: Arc<Self>, ufrag: &str) -> Result<Arc<dyn Conn + Send + Sync>, Error> {
258        if self.is_closed().await {
259            return Err(Error::ErrUseClosedNetworkConn);
260        }
261
262        {
263            let mut conns = self.conns.lock().await;
264            if let Some(conn) = conns.get(ufrag) {
265                // UDPMuxConn uses `Arc` internally so it's cheap to clone, but because
266                // we implement `Conn` we need to further wrap it in an `Arc` here.
267                return Ok(Arc::new(conn.clone()) as Arc<dyn Conn + Send + Sync>);
268            }
269
270            let muxed_conn = self.create_muxed_conn(ufrag)?;
271            let mut close_rx = muxed_conn.close_rx();
272            let cloned_self = Arc::clone(&self);
273            let cloned_ufrag = ufrag.to_string();
274            tokio::spawn(async move {
275                let _ = close_rx.changed().await;
276
277                // Arc needed
278                cloned_self.remove_conn_by_ufrag(&cloned_ufrag).await;
279            });
280
281            conns.insert(ufrag.into(), muxed_conn.clone());
282
283            Ok(Arc::new(muxed_conn) as Arc<dyn Conn + Send + Sync>)
284        }
285    }
286
287    async fn remove_conn_by_ufrag(&self, ufrag: &str) {
288        // Pion's ice implementation has both `RemoveConnByFrag` and `RemoveConn`, but since `conns`
289        // is keyed on `ufrag` their implementation is equivalent.
290
291        let removed_conn = {
292            let mut conns = self.conns.lock().await;
293            conns.remove(ufrag)
294        };
295
296        if let Some(conn) = removed_conn {
297            let mut address_map = self.address_map.write();
298
299            for address in conn.get_addresses() {
300                address_map.remove(&address);
301            }
302        }
303    }
304}
305
306#[async_trait]
307impl UDPMuxWriter for UDPMuxDefault {
308    async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr) {
309        if self.is_closed().await {
310            return;
311        }
312
313        let key = conn.key();
314        {
315            let mut addresses = self.address_map.write();
316
317            addresses
318                .entry(addr)
319                .and_modify(|e| {
320                    if e.key() != key {
321                        e.remove_address(&addr);
322                        *e = conn.clone();
323                    }
324                })
325                .or_insert_with(|| conn.clone());
326        }
327
328        log::debug!("Registered {} for {}", addr, key);
329    }
330
331    async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result<usize, Error> {
332        self.params
333            .conn
334            .send_to(buf, *target)
335            .await
336            .map_err(Into::into)
337    }
338}