// webrtc_mdns/conn/mod.rs

1use core::sync::atomic;
2use std::net::{IpAddr, Ipv4Addr, SocketAddr};
3use std::sync::Arc;
4use std::time::Duration;
5
6use socket2::SockAddr;
7use tokio::net::{ToSocketAddrs, UdpSocket};
8use tokio::sync::{mpsc, Mutex};
9use util::ifaces;
10
11use crate::config::*;
12use crate::error::*;
13use crate::message::header::*;
14use crate::message::name::*;
15use crate::message::parser::*;
16use crate::message::question::*;
17use crate::message::resource::a::*;
18use crate::message::resource::*;
19use crate::message::*;
20
// Unit tests for this module live in `conn_test.rs`; compile them only for test builds.
#[cfg(test)]
mod conn_test;
22
/// Multicast group and port (224.0.0.251:5353) that mDNS questions and answers
/// are sent to.
pub const DEFAULT_DEST_ADDR: &str = "224.0.0.251:5353";

/// Size in bytes of the receive buffer: 65535, the largest length a 16-bit UDP
/// length field can describe.
const INBOUND_BUFFER_SIZE: usize = u16::MAX as usize;

/// Interval between repeated questions when `Config::query_interval` is zero.
const DEFAULT_QUERY_INTERVAL: Duration = Duration::from_millis(1_000);

/// Limit on how many question and answer records are examined per received packet.
const MAX_MESSAGE_RECORDS: usize = 3;

/// TTL, in seconds, placed on the A records sent in answers.
const RESPONSE_TTL: u32 = 120;
29
30// Conn represents a mDNS Server
31pub struct DnsConn {
32    socket: Arc<UdpSocket>,
33    dst_addr: SocketAddr,
34
35    query_interval: Duration,
36    queries: Arc<Mutex<Vec<Query>>>,
37
38    is_server_closed: Arc<atomic::AtomicBool>,
39    close_server: mpsc::Sender<()>,
40}
41
42struct Query {
43    name_with_suffix: String,
44    query_result_chan: mpsc::Sender<QueryResult>,
45}
46
47struct QueryResult {
48    answer: ResourceHeader,
49    addr: SocketAddr,
50}
51
52impl DnsConn {
53    /// server establishes a mDNS connection over an existing connection
54    pub fn server(addr: SocketAddr, config: Config) -> Result<Self> {
55        let socket = socket2::Socket::new(
56            socket2::Domain::IPV4,
57            socket2::Type::DGRAM,
58            Some(socket2::Protocol::UDP),
59        )?;
60
61        #[cfg(feature = "reuse_port")]
62        #[cfg(target_family = "unix")]
63        socket.set_reuse_port(true)?;
64
65        socket.set_reuse_address(true)?;
66        socket.set_broadcast(true)?;
67        socket.set_nonblocking(true)?;
68
69        socket.bind(&SockAddr::from(addr))?;
70        {
71            let mut join_error_count = 0;
72            let interfaces = match ifaces::ifaces() {
73                Ok(e) => e,
74                Err(e) => {
75                    log::error!("Error getting interfaces: {:?}", e);
76                    return Err(Error::Other(e.to_string()));
77                }
78            };
79
80            for interface in &interfaces {
81                if let Some(SocketAddr::V4(e)) = interface.addr {
82                    if let Err(e) = socket.join_multicast_v4(&Ipv4Addr::new(224, 0, 0, 251), e.ip())
83                    {
84                        log::trace!("Error connecting multicast, error: {:?}", e);
85                        join_error_count += 1;
86                        continue;
87                    }
88
89                    log::trace!("Connected to interface address {:?}", e);
90                }
91            }
92
93            if join_error_count >= interfaces.len() {
94                return Err(Error::ErrJoiningMulticastGroup);
95            }
96        }
97
98        let socket = UdpSocket::from_std(socket.into())?;
99
100        let local_names = config
101            .local_names
102            .iter()
103            .map(|l| l.to_string() + ".")
104            .collect();
105
106        let dst_addr: SocketAddr = DEFAULT_DEST_ADDR.parse()?;
107
108        let is_server_closed = Arc::new(atomic::AtomicBool::new(false));
109
110        let (close_server_send, close_server_rcv) = mpsc::channel(1);
111
112        let c = DnsConn {
113            query_interval: if config.query_interval != Duration::from_secs(0) {
114                config.query_interval
115            } else {
116                DEFAULT_QUERY_INTERVAL
117            },
118
119            queries: Arc::new(Mutex::new(vec![])),
120            socket: Arc::new(socket),
121            dst_addr,
122            is_server_closed: Arc::clone(&is_server_closed),
123            close_server: close_server_send,
124        };
125
126        let queries = c.queries.clone();
127        let socket = Arc::clone(&c.socket);
128
129        tokio::spawn(async move {
130            DnsConn::start(
131                close_server_rcv,
132                is_server_closed,
133                socket,
134                local_names,
135                dst_addr,
136                queries,
137            )
138            .await
139        });
140
141        Ok(c)
142    }
143
144    /// Close closes the mDNS Conn
145    pub async fn close(&self) -> Result<()> {
146        log::info!("Closing connection");
147        if self.is_server_closed.load(atomic::Ordering::SeqCst) {
148            return Err(Error::ErrConnectionClosed);
149        }
150
151        log::trace!("Sending close command to server");
152        match self.close_server.send(()).await {
153            Ok(_) => {
154                log::trace!("Close command sent");
155                Ok(())
156            }
157            Err(e) => {
158                log::warn!("Error sending close command to server: {:?}", e);
159                Err(Error::ErrConnectionClosed)
160            }
161        }
162    }
163
164    /// Query sends mDNS Queries for the following name until
165    /// either there's a close signal or we get a result
166    pub async fn query(
167        &self,
168        name: &str,
169        mut close_query_signal: mpsc::Receiver<()>,
170    ) -> Result<(ResourceHeader, SocketAddr)> {
171        if self.is_server_closed.load(atomic::Ordering::SeqCst) {
172            return Err(Error::ErrConnectionClosed);
173        }
174
175        let name_with_suffix = name.to_owned() + ".";
176
177        let (query_tx, mut query_rx) = mpsc::channel(1);
178        {
179            let mut queries = self.queries.lock().await;
180            queries.push(Query {
181                name_with_suffix: name_with_suffix.clone(),
182                query_result_chan: query_tx,
183            });
184        }
185
186        log::trace!("Sending query");
187        self.send_question(&name_with_suffix).await;
188
189        loop {
190            tokio::select! {
191                _ = tokio::time::sleep(self.query_interval) => {
192                    log::trace!("Sending query");
193                    self.send_question(&name_with_suffix).await
194                },
195
196                _ = close_query_signal.recv() => {
197                    log::info!("Query close signal received.");
198                    return Err(Error::ErrConnectionClosed)
199                },
200
201                res_opt = query_rx.recv() =>{
202                    log::info!("Received query result");
203                    if let Some(res) = res_opt{
204                        return Ok((res.answer, res.addr));
205                    }
206                }
207            }
208        }
209    }
210
211    async fn send_question(&self, name: &str) {
212        let packed_name = match Name::new(name) {
213            Ok(pn) => pn,
214            Err(err) => {
215                log::warn!("Failed to construct mDNS packet: {}", err);
216                return;
217            }
218        };
219
220        let raw_query = {
221            let mut msg = Message {
222                header: Header::default(),
223                questions: vec![Question {
224                    typ: DnsType::A,
225                    class: DNSCLASS_INET,
226                    name: packed_name,
227                }],
228                ..Default::default()
229            };
230
231            match msg.pack() {
232                Ok(v) => v,
233                Err(err) => {
234                    log::error!("Failed to construct mDNS packet {}", err);
235                    return;
236                }
237            }
238        };
239
240        log::trace!("{:?} sending {:?}...", self.socket.local_addr(), raw_query);
241        if let Err(err) = self.socket.send_to(&raw_query, self.dst_addr).await {
242            log::error!("Failed to send mDNS packet {}", err);
243        }
244    }
245
246    async fn start(
247        mut closed_rx: mpsc::Receiver<()>,
248        close_server: Arc<atomic::AtomicBool>,
249        socket: Arc<UdpSocket>,
250        local_names: Vec<String>,
251        dst_addr: SocketAddr,
252        queries: Arc<Mutex<Vec<Query>>>,
253    ) -> Result<()> {
254        log::info!("Looping and listening {:?}", socket.local_addr());
255
256        let mut b = vec![0u8; INBOUND_BUFFER_SIZE];
257        let (mut n, mut src);
258
259        loop {
260            tokio::select! {
261                _ = closed_rx.recv() => {
262                    log::info!("Closing server connection");
263                    close_server.store(true, atomic::Ordering::SeqCst);
264
265                    return Ok(());
266                }
267
268                result = socket.recv_from(&mut b) => {
269                    match result{
270                        Ok((len, addr)) => {
271                            n = len;
272                            src = addr;
273                            log::trace!("Received new connection from {:?}", addr);
274                        },
275
276                        Err(err) => {
277                            log::error!("Error receiving from socket connection: {:?}", err);
278                            continue;
279                        },
280                    }
281                }
282            }
283
284            let mut p = Parser::default();
285            if let Err(err) = p.start(&b[..n]) {
286                log::error!("Failed to parse mDNS packet {}", err);
287                continue;
288            }
289
290            run(&mut p, &socket, &local_names, src, dst_addr, &queries).await
291        }
292    }
293}
294
295async fn run(
296    p: &mut Parser<'_>,
297    socket: &Arc<UdpSocket>,
298    local_names: &[String],
299    src: SocketAddr,
300    dst_addr: SocketAddr,
301    queries: &Arc<Mutex<Vec<Query>>>,
302) {
303    let mut interface_addr = None;
304    for _ in 0..=MAX_MESSAGE_RECORDS {
305        let q = match p.question() {
306            Ok(q) => q,
307            Err(err) => {
308                if Error::ErrSectionDone == err {
309                    log::trace!("Parsing has completed");
310                    break;
311                } else {
312                    log::error!("Failed to parse mDNS packet {}", err);
313                    return;
314                }
315            }
316        };
317
318        for local_name in local_names {
319            if *local_name == q.name.data {
320                let interface_addr = match interface_addr {
321                    Some(addr) => addr,
322                    None => match get_interface_addr_for_ip(src).await {
323                        Ok(addr) => {
324                            interface_addr.replace(addr);
325                            addr
326                        }
327                        Err(e) => {
328                            log::warn!(
329                                "Failed to get local interface to communicate with {}: {:?}",
330                                &src,
331                                e
332                            );
333                            continue;
334                        }
335                    },
336                };
337
338                log::trace!(
339                    "Found local name: {} to send answer, IP {}, interface addr {}",
340                    local_name,
341                    src.ip(),
342                    interface_addr
343                );
344                if let Err(e) =
345                    send_answer(socket, &interface_addr, &q.name.data, src.ip(), dst_addr).await
346                {
347                    log::error!("Error sending answer to client: {:?}", e);
348                    continue;
349                };
350            }
351        }
352    }
353
354    // There might be more than MAX_MESSAGE_RECORDS questions, so skip the rest
355    let _ = p.skip_all_questions();
356
357    for _ in 0..=MAX_MESSAGE_RECORDS {
358        let a = match p.answer_header() {
359            Ok(a) => a,
360            Err(err) => {
361                if Error::ErrSectionDone != err {
362                    log::warn!("Failed to parse mDNS packet {}", err);
363                }
364                return;
365            }
366        };
367
368        if a.typ != DnsType::A && a.typ != DnsType::Aaaa {
369            continue;
370        }
371
372        let mut qs = queries.lock().await;
373        for j in (0..qs.len()).rev() {
374            if qs[j].name_with_suffix == a.name.data {
375                let _ = qs[j]
376                    .query_result_chan
377                    .send(QueryResult {
378                        answer: a.clone(),
379                        addr: src,
380                    })
381                    .await;
382                qs.remove(j);
383            }
384        }
385    }
386}
387
388async fn send_answer(
389    socket: &Arc<UdpSocket>,
390    interface_addr: &SocketAddr,
391    name: &str,
392    dst: IpAddr,
393    dst_addr: SocketAddr,
394) -> Result<()> {
395    let raw_answer = {
396        let mut msg = Message {
397            header: Header {
398                response: true,
399                authoritative: true,
400                ..Default::default()
401            },
402
403            answers: vec![Resource {
404                header: ResourceHeader {
405                    typ: DnsType::A,
406                    class: DNSCLASS_INET,
407                    name: Name::new(name)?,
408                    ttl: RESPONSE_TTL,
409                    ..Default::default()
410                },
411                body: Some(Box::new(AResource {
412                    a: match interface_addr.ip() {
413                        IpAddr::V4(ip) => ip.octets(),
414                        IpAddr::V6(_) => {
415                            return Err(Error::Other("Unexpected IpV6 addr".to_owned()))
416                        }
417                    },
418                })),
419            }],
420            ..Default::default()
421        };
422
423        msg.pack()?
424    };
425
426    socket.send_to(&raw_answer, dst_addr).await?;
427    log::trace!("Sent answer to IP {}", dst);
428
429    Ok(())
430}
431
432async fn get_interface_addr_for_ip(addr: impl ToSocketAddrs) -> std::io::Result<SocketAddr> {
433    let socket = UdpSocket::bind("0.0.0.0:0").await?;
434    socket.connect(addr).await?;
435    socket.local_addr()
436}