webrtc_ice/udp_mux/
udp_mux_conn.rs1use std::collections::HashSet;
2use std::convert::TryInto;
3use std::io;
4use std::net::SocketAddr;
5use std::sync::{Arc, Weak};
6
7use async_trait::async_trait;
8use tokio::sync::watch;
9use util::sync::Mutex;
10use util::{Buffer, Conn, Error};
11
12use super::socket_addr_ext::{SocketAddrExt, MAX_ADDR_SIZE};
13use super::{normalize_socket_addr, RECEIVE_MTU};
14
15#[async_trait]
17pub trait UDPMuxWriter {
18 async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr);
20 async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result<usize, Error>;
24}
25
26pub struct UDPMuxConnParams {
28 pub local_addr: SocketAddr,
30 pub key: String,
32 pub udp_mux: Weak<dyn UDPMuxWriter + Send + Sync>,
36}
37
38type ConnResult<T> = Result<T, util::Error>;
39
40#[derive(Clone)]
42pub struct UDPMuxConn {
43 closed_watch_rx: watch::Receiver<bool>,
45
46 inner: Arc<UDPMuxConnInner>,
47}
48
49impl UDPMuxConn {
50 pub fn new(params: UDPMuxConnParams) -> Self {
52 let (closed_watch_tx, closed_watch_rx) = watch::channel(false);
53
54 Self {
55 closed_watch_rx,
56 inner: Arc::new(UDPMuxConnInner {
57 params,
58 closed_watch_tx: Mutex::new(Some(closed_watch_tx)),
59 addresses: Default::default(),
60 buffer: Buffer::new(0, 0),
61 }),
62 }
63 }
64
65 pub fn key(&self) -> &str {
67 &self.inner.params.key
68 }
69
70 pub async fn write_packet(&self, data: &[u8], addr: SocketAddr) -> ConnResult<()> {
73 let mut buffer = make_buffer();
75 let mut offset = 0;
76
77 if (data.len() + MAX_ADDR_SIZE) > (RECEIVE_MTU + MAX_ADDR_SIZE) {
78 return Err(Error::ErrBufferShort);
79 }
80
81 buffer[0..2].copy_from_slice(&(data.len() as u16).to_le_bytes()[..]);
87 offset += 2;
88
89 buffer[offset..offset + data.len()].copy_from_slice(data);
90 offset += data.len();
91
92 let len = addr.encode(&mut buffer[offset + 2..])?;
93 buffer[offset..offset + 2].copy_from_slice(&(len as u16).to_le_bytes()[..]);
94 offset += 2 + len;
95
96 self.inner.buffer.write(&buffer[..offset]).await?;
97
98 Ok(())
99 }
100
101 pub fn is_closed(&self) -> bool {
103 self.inner.is_closed()
104 }
105
106 pub fn close_rx(&self) -> watch::Receiver<bool> {
109 self.closed_watch_rx.clone()
110 }
111
112 pub fn close(&self) {
114 self.inner.close();
115 }
116
117 pub fn get_addresses(&self) -> Vec<SocketAddr> {
119 self.inner.get_addresses()
120 }
121
122 pub async fn add_address(&self, addr: SocketAddr) {
124 self.inner.add_address(addr);
125 if let Some(mux) = self.inner.params.udp_mux.upgrade() {
126 mux.register_conn_for_address(self, addr).await;
127 }
128 }
129
130 pub fn remove_address(&self, addr: &SocketAddr) {
132 self.inner.remove_address(addr)
133 }
134
135 pub fn contains_address(&self, addr: &SocketAddr) -> bool {
137 self.inner.contains_address(addr)
138 }
139}
140
141struct UDPMuxConnInner {
142 params: UDPMuxConnParams,
143
144 closed_watch_tx: Mutex<Option<watch::Sender<bool>>>,
146
147 addresses: Mutex<HashSet<SocketAddr>>,
149
150 buffer: Buffer,
151}
152
153impl UDPMuxConnInner {
154 async fn recv_from(&self, buf: &mut [u8]) -> ConnResult<(usize, SocketAddr)> {
156 let mut buffer = make_buffer();
158 let mut offset = 0;
159
160 let len = self.buffer.read(&mut buffer, None).await?;
161 if len < 11 {
167 return Err(Error::ErrBufferShort);
168 }
169
170 let data_len: usize = buffer[..2]
171 .try_into()
172 .map(u16::from_le_bytes)
173 .map(From::from)
174 .unwrap();
175 offset += 2;
176
177 let total = 2 + data_len + 2 + 7;
178 if data_len > buf.len() || total > len {
179 return Err(Error::ErrBufferShort);
180 }
181
182 buf[..data_len].copy_from_slice(&buffer[offset..offset + data_len]);
183 offset += data_len;
184
185 let address_len: usize = buffer[offset..offset + 2]
186 .try_into()
187 .map(u16::from_le_bytes)
188 .map(From::from)
189 .unwrap();
190 offset += 2;
191
192 let addr = SocketAddr::decode(&buffer[offset..offset + address_len])?;
193
194 Ok((data_len, addr))
195 }
196
197 async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> ConnResult<usize> {
198 if let Some(mux) = self.params.udp_mux.upgrade() {
199 mux.send_to(buf, target).await
200 } else {
201 Err(Error::Other(format!(
202 "wanted to send {} bytes to {}, but UDP mux is gone",
203 buf.len(),
204 target
205 )))
206 }
207 }
208
209 fn is_closed(&self) -> bool {
210 self.closed_watch_tx.lock().is_none()
211 }
212
213 fn close(self: &Arc<Self>) {
214 let mut closed_tx = self.closed_watch_tx.lock();
215
216 if let Some(tx) = closed_tx.take() {
217 let _ = tx.send(true);
218 drop(closed_tx);
219
220 let cloned_self = Arc::clone(self);
221
222 {
223 let mut addresses = self.addresses.lock();
224 *addresses = Default::default();
225 }
226
227 tokio::spawn(async move {
230 cloned_self.buffer.close().await;
231 });
232 }
233 }
234
235 fn local_addr(&self) -> SocketAddr {
236 self.params.local_addr
237 }
238
239 pub(super) fn get_addresses(&self) -> Vec<SocketAddr> {
241 let addresses = self.addresses.lock();
242
243 addresses.iter().copied().collect()
244 }
245
246 pub(super) fn add_address(self: &Arc<Self>, addr: SocketAddr) {
247 {
248 let mut addresses = self.addresses.lock();
249 addresses.insert(addr);
250 }
251 }
252
253 pub(super) fn remove_address(&self, addr: &SocketAddr) {
254 {
255 let mut addresses = self.addresses.lock();
256 addresses.remove(addr);
257 }
258 }
259
260 pub(super) fn contains_address(&self, addr: &SocketAddr) -> bool {
261 let addresses = self.addresses.lock();
262
263 addresses.contains(addr)
264 }
265}
266
267#[async_trait]
268impl Conn for UDPMuxConn {
269 async fn connect(&self, _addr: SocketAddr) -> ConnResult<()> {
270 Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into())
271 }
272
273 async fn recv(&self, _buf: &mut [u8]) -> ConnResult<usize> {
274 Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into())
275 }
276
277 async fn recv_from(&self, buf: &mut [u8]) -> ConnResult<(usize, SocketAddr)> {
278 self.inner.recv_from(buf).await
279 }
280
281 async fn send(&self, _buf: &[u8]) -> ConnResult<usize> {
282 Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into())
283 }
284
285 async fn send_to(&self, buf: &[u8], target: SocketAddr) -> ConnResult<usize> {
286 let normalized_target = normalize_socket_addr(&target, &self.inner.params.local_addr);
287
288 if !self.contains_address(&normalized_target) {
289 self.add_address(normalized_target).await;
290 }
291
292 self.inner.send_to(buf, &normalized_target).await
293 }
294
295 fn local_addr(&self) -> ConnResult<SocketAddr> {
296 Ok(self.inner.local_addr())
297 }
298
299 fn remote_addr(&self) -> Option<SocketAddr> {
300 None
301 }
302 async fn close(&self) -> ConnResult<()> {
303 self.inner.close();
304
305 Ok(())
306 }
307
308 fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
309 self
310 }
311}
312
313#[inline(always)]
314fn make_buffer() -> Vec<u8> {
317 vec![0u8; RECEIVE_MTU + MAX_ADDR_SIZE + 2 + 2]
320}