// webrtc_ice/udp_mux/udp_mux_conn.rs

1use std::collections::HashSet;
2use std::convert::TryInto;
3use std::io;
4use std::net::SocketAddr;
5use std::sync::{Arc, Weak};
6
7use async_trait::async_trait;
8use tokio::sync::watch;
9use util::sync::Mutex;
10use util::{Buffer, Conn, Error};
11
12use super::socket_addr_ext::{SocketAddrExt, MAX_ADDR_SIZE};
13use super::{normalize_socket_addr, RECEIVE_MTU};
14
15/// A trait for a [`UDPMuxConn`] to communicate with an UDP mux.
16#[async_trait]
17pub trait UDPMuxWriter {
18    /// Registers an address for the given connection.
19    async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr);
20    /// Sends the content of the buffer to the given target.
21    ///
22    /// Returns the number of bytes sent or an error, if any.
23    async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result<usize, Error>;
24}
25
26/// Parameters for a [`UDPMuxConn`].
27pub struct UDPMuxConnParams {
28    /// Local socket address.
29    pub local_addr: SocketAddr,
30    /// Static key identifying the connection.
31    pub key: String,
32    /// A `std::sync::Weak` reference to the UDP mux.
33    ///
34    /// NOTE: a non-owning reference should be used to prevent possible cycles.
35    pub udp_mux: Weak<dyn UDPMuxWriter + Send + Sync>,
36}
37
38type ConnResult<T> = Result<T, util::Error>;
39
40/// A UDP mux connection.
41#[derive(Clone)]
42pub struct UDPMuxConn {
43    /// Close Receiver. A copy of this can be obtained via [`close_tx`].
44    closed_watch_rx: watch::Receiver<bool>,
45
46    inner: Arc<UDPMuxConnInner>,
47}
48
49impl UDPMuxConn {
50    /// Creates a new [`UDPMuxConn`].
51    pub fn new(params: UDPMuxConnParams) -> Self {
52        let (closed_watch_tx, closed_watch_rx) = watch::channel(false);
53
54        Self {
55            closed_watch_rx,
56            inner: Arc::new(UDPMuxConnInner {
57                params,
58                closed_watch_tx: Mutex::new(Some(closed_watch_tx)),
59                addresses: Default::default(),
60                buffer: Buffer::new(0, 0),
61            }),
62        }
63    }
64
65    /// Returns a key identifying this connection.
66    pub fn key(&self) -> &str {
67        &self.inner.params.key
68    }
69
70    /// Writes data to the given address. Returns an error if the buffer is too short or there's an
71    /// encoding error.
72    pub async fn write_packet(&self, data: &[u8], addr: SocketAddr) -> ConnResult<()> {
73        // NOTE: Pion/ice uses Sync.Pool to optimise this.
74        let mut buffer = make_buffer();
75        let mut offset = 0;
76
77        if (data.len() + MAX_ADDR_SIZE) > (RECEIVE_MTU + MAX_ADDR_SIZE) {
78            return Err(Error::ErrBufferShort);
79        }
80
81        // Format of buffer: | data len(2) | data bytes(dn) | addr len(2) | addr bytes(an) |
82        // Where the number in parenthesis indicate the number of bytes used
83        // `dn` and `an` are the length in bytes of data and addr respectively.
84
85        // SAFETY: `data.len()` is at most RECEIVE_MTU(8192) - MAX_ADDR_SIZE(27)
86        buffer[0..2].copy_from_slice(&(data.len() as u16).to_le_bytes()[..]);
87        offset += 2;
88
89        buffer[offset..offset + data.len()].copy_from_slice(data);
90        offset += data.len();
91
92        let len = addr.encode(&mut buffer[offset + 2..])?;
93        buffer[offset..offset + 2].copy_from_slice(&(len as u16).to_le_bytes()[..]);
94        offset += 2 + len;
95
96        self.inner.buffer.write(&buffer[..offset]).await?;
97
98        Ok(())
99    }
100
101    /// Returns true if this connection is closed.
102    pub fn is_closed(&self) -> bool {
103        self.inner.is_closed()
104    }
105
106    /// Gets a copy of the close [`tokio::sync::watch::Receiver`] that fires when this
107    /// connection is closed.
108    pub fn close_rx(&self) -> watch::Receiver<bool> {
109        self.closed_watch_rx.clone()
110    }
111
112    /// Closes this connection.
113    pub fn close(&self) {
114        self.inner.close();
115    }
116
117    /// Gets the list of the addresses associated with this connection.
118    pub fn get_addresses(&self) -> Vec<SocketAddr> {
119        self.inner.get_addresses()
120    }
121
122    /// Registers a new address for this connection.
123    pub async fn add_address(&self, addr: SocketAddr) {
124        self.inner.add_address(addr);
125        if let Some(mux) = self.inner.params.udp_mux.upgrade() {
126            mux.register_conn_for_address(self, addr).await;
127        }
128    }
129
130    /// Deregisters an address.
131    pub fn remove_address(&self, addr: &SocketAddr) {
132        self.inner.remove_address(addr)
133    }
134
135    /// Returns true if the given address is associated with this connection.
136    pub fn contains_address(&self, addr: &SocketAddr) -> bool {
137        self.inner.contains_address(addr)
138    }
139}
140
141struct UDPMuxConnInner {
142    params: UDPMuxConnParams,
143
144    /// Close Sender. We'll send a value on this channel when we close
145    closed_watch_tx: Mutex<Option<watch::Sender<bool>>>,
146
147    /// Remote addresses we've seen on this connection.
148    addresses: Mutex<HashSet<SocketAddr>>,
149
150    buffer: Buffer,
151}
152
153impl UDPMuxConnInner {
154    // Sending/Recieving
155    async fn recv_from(&self, buf: &mut [u8]) -> ConnResult<(usize, SocketAddr)> {
156        // NOTE: Pion/ice uses Sync.Pool to optimise this.
157        let mut buffer = make_buffer();
158        let mut offset = 0;
159
160        let len = self.buffer.read(&mut buffer, None).await?;
161        // We always have at least.
162        //
163        // * 2 bytes for data len
164        // * 2 bytes for addr len
165        // * 7 bytes for an Ipv4 addr
166        if len < 11 {
167            return Err(Error::ErrBufferShort);
168        }
169
170        let data_len: usize = buffer[..2]
171            .try_into()
172            .map(u16::from_le_bytes)
173            .map(From::from)
174            .unwrap();
175        offset += 2;
176
177        let total = 2 + data_len + 2 + 7;
178        if data_len > buf.len() || total > len {
179            return Err(Error::ErrBufferShort);
180        }
181
182        buf[..data_len].copy_from_slice(&buffer[offset..offset + data_len]);
183        offset += data_len;
184
185        let address_len: usize = buffer[offset..offset + 2]
186            .try_into()
187            .map(u16::from_le_bytes)
188            .map(From::from)
189            .unwrap();
190        offset += 2;
191
192        let addr = SocketAddr::decode(&buffer[offset..offset + address_len])?;
193
194        Ok((data_len, addr))
195    }
196
197    async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> ConnResult<usize> {
198        if let Some(mux) = self.params.udp_mux.upgrade() {
199            mux.send_to(buf, target).await
200        } else {
201            Err(Error::Other(format!(
202                "wanted to send {} bytes to {}, but UDP mux is gone",
203                buf.len(),
204                target
205            )))
206        }
207    }
208
209    fn is_closed(&self) -> bool {
210        self.closed_watch_tx.lock().is_none()
211    }
212
213    fn close(self: &Arc<Self>) {
214        let mut closed_tx = self.closed_watch_tx.lock();
215
216        if let Some(tx) = closed_tx.take() {
217            let _ = tx.send(true);
218            drop(closed_tx);
219
220            let cloned_self = Arc::clone(self);
221
222            {
223                let mut addresses = self.addresses.lock();
224                *addresses = Default::default();
225            }
226
227            // NOTE: Alternatively we could wait on the buffer closing here so that
228            // our caller can wait for things to fully settle down
229            tokio::spawn(async move {
230                cloned_self.buffer.close().await;
231            });
232        }
233    }
234
235    fn local_addr(&self) -> SocketAddr {
236        self.params.local_addr
237    }
238
239    // Address related methods
240    pub(super) fn get_addresses(&self) -> Vec<SocketAddr> {
241        let addresses = self.addresses.lock();
242
243        addresses.iter().copied().collect()
244    }
245
246    pub(super) fn add_address(self: &Arc<Self>, addr: SocketAddr) {
247        {
248            let mut addresses = self.addresses.lock();
249            addresses.insert(addr);
250        }
251    }
252
253    pub(super) fn remove_address(&self, addr: &SocketAddr) {
254        {
255            let mut addresses = self.addresses.lock();
256            addresses.remove(addr);
257        }
258    }
259
260    pub(super) fn contains_address(&self, addr: &SocketAddr) -> bool {
261        let addresses = self.addresses.lock();
262
263        addresses.contains(addr)
264    }
265}
266
267#[async_trait]
268impl Conn for UDPMuxConn {
269    async fn connect(&self, _addr: SocketAddr) -> ConnResult<()> {
270        Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into())
271    }
272
273    async fn recv(&self, _buf: &mut [u8]) -> ConnResult<usize> {
274        Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into())
275    }
276
277    async fn recv_from(&self, buf: &mut [u8]) -> ConnResult<(usize, SocketAddr)> {
278        self.inner.recv_from(buf).await
279    }
280
281    async fn send(&self, _buf: &[u8]) -> ConnResult<usize> {
282        Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into())
283    }
284
285    async fn send_to(&self, buf: &[u8], target: SocketAddr) -> ConnResult<usize> {
286        let normalized_target = normalize_socket_addr(&target, &self.inner.params.local_addr);
287
288        if !self.contains_address(&normalized_target) {
289            self.add_address(normalized_target).await;
290        }
291
292        self.inner.send_to(buf, &normalized_target).await
293    }
294
295    fn local_addr(&self) -> ConnResult<SocketAddr> {
296        Ok(self.inner.local_addr())
297    }
298
299    fn remote_addr(&self) -> Option<SocketAddr> {
300        None
301    }
302    async fn close(&self) -> ConnResult<()> {
303        self.inner.close();
304
305        Ok(())
306    }
307
308    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
309        self
310    }
311}
312
313#[inline(always)]
314/// Create a buffer of appropriate size to fit both a packet with max RECEIVE_MTU and the
315/// additional metadata used for muxing.
316fn make_buffer() -> Vec<u8> {
317    // The 4 extra bytes are used to encode the length of the data and address respectively.
318    // See [`write_packet`] for details.
319    vec![0u8; RECEIVE_MTU + MAX_ADDR_SIZE + 2 + 2]
320}