1use core::sync::atomic;
2use std::net::{IpAddr, Ipv4Addr, SocketAddr};
3use std::sync::Arc;
4use std::time::Duration;
5
6use socket2::SockAddr;
7use tokio::net::{ToSocketAddrs, UdpSocket};
8use tokio::sync::{mpsc, Mutex};
9use util::ifaces;
10
11use crate::config::*;
12use crate::error::*;
13use crate::message::header::*;
14use crate::message::name::*;
15use crate::message::parser::*;
16use crate::message::question::*;
17use crate::message::resource::a::*;
18use crate::message::resource::*;
19use crate::message::*;
20
21mod conn_test;
22
/// Multicast group and port used as the destination for mDNS questions and answers.
pub const DEFAULT_DEST_ADDR: &str = "224.0.0.251:5353";

/// Receive buffer size; large enough for any UDP datagram.
const INBOUND_BUFFER_SIZE: usize = u16::MAX as usize;
/// Re-send interval used when the configured query interval is zero.
const DEFAULT_QUERY_INTERVAL: Duration = Duration::from_millis(1000);
/// Bound for the per-section record loops (`0..=MAX_MESSAGE_RECORDS`).
const MAX_MESSAGE_RECORDS: usize = 3;
/// TTL (seconds) placed on the A records this connection answers with.
const RESPONSE_TTL: u32 = 2 * 60;
29
30pub struct DnsConn {
32 socket: Arc<UdpSocket>,
33 dst_addr: SocketAddr,
34
35 query_interval: Duration,
36 queries: Arc<Mutex<Vec<Query>>>,
37
38 is_server_closed: Arc<atomic::AtomicBool>,
39 close_server: mpsc::Sender<()>,
40}
41
42struct Query {
43 name_with_suffix: String,
44 query_result_chan: mpsc::Sender<QueryResult>,
45}
46
47struct QueryResult {
48 answer: ResourceHeader,
49 addr: SocketAddr,
50}
51
52impl DnsConn {
53 pub fn server(addr: SocketAddr, config: Config) -> Result<Self> {
55 let socket = socket2::Socket::new(
56 socket2::Domain::IPV4,
57 socket2::Type::DGRAM,
58 Some(socket2::Protocol::UDP),
59 )?;
60
61 #[cfg(feature = "reuse_port")]
62 #[cfg(target_family = "unix")]
63 socket.set_reuse_port(true)?;
64
65 socket.set_reuse_address(true)?;
66 socket.set_broadcast(true)?;
67 socket.set_nonblocking(true)?;
68
69 socket.bind(&SockAddr::from(addr))?;
70 {
71 let mut join_error_count = 0;
72 let interfaces = match ifaces::ifaces() {
73 Ok(e) => e,
74 Err(e) => {
75 log::error!("Error getting interfaces: {:?}", e);
76 return Err(Error::Other(e.to_string()));
77 }
78 };
79
80 for interface in &interfaces {
81 if let Some(SocketAddr::V4(e)) = interface.addr {
82 if let Err(e) = socket.join_multicast_v4(&Ipv4Addr::new(224, 0, 0, 251), e.ip())
83 {
84 log::trace!("Error connecting multicast, error: {:?}", e);
85 join_error_count += 1;
86 continue;
87 }
88
89 log::trace!("Connected to interface address {:?}", e);
90 }
91 }
92
93 if join_error_count >= interfaces.len() {
94 return Err(Error::ErrJoiningMulticastGroup);
95 }
96 }
97
98 let socket = UdpSocket::from_std(socket.into())?;
99
100 let local_names = config
101 .local_names
102 .iter()
103 .map(|l| l.to_string() + ".")
104 .collect();
105
106 let dst_addr: SocketAddr = DEFAULT_DEST_ADDR.parse()?;
107
108 let is_server_closed = Arc::new(atomic::AtomicBool::new(false));
109
110 let (close_server_send, close_server_rcv) = mpsc::channel(1);
111
112 let c = DnsConn {
113 query_interval: if config.query_interval != Duration::from_secs(0) {
114 config.query_interval
115 } else {
116 DEFAULT_QUERY_INTERVAL
117 },
118
119 queries: Arc::new(Mutex::new(vec![])),
120 socket: Arc::new(socket),
121 dst_addr,
122 is_server_closed: Arc::clone(&is_server_closed),
123 close_server: close_server_send,
124 };
125
126 let queries = c.queries.clone();
127 let socket = Arc::clone(&c.socket);
128
129 tokio::spawn(async move {
130 DnsConn::start(
131 close_server_rcv,
132 is_server_closed,
133 socket,
134 local_names,
135 dst_addr,
136 queries,
137 )
138 .await
139 });
140
141 Ok(c)
142 }
143
144 pub async fn close(&self) -> Result<()> {
146 log::info!("Closing connection");
147 if self.is_server_closed.load(atomic::Ordering::SeqCst) {
148 return Err(Error::ErrConnectionClosed);
149 }
150
151 log::trace!("Sending close command to server");
152 match self.close_server.send(()).await {
153 Ok(_) => {
154 log::trace!("Close command sent");
155 Ok(())
156 }
157 Err(e) => {
158 log::warn!("Error sending close command to server: {:?}", e);
159 Err(Error::ErrConnectionClosed)
160 }
161 }
162 }
163
164 pub async fn query(
167 &self,
168 name: &str,
169 mut close_query_signal: mpsc::Receiver<()>,
170 ) -> Result<(ResourceHeader, SocketAddr)> {
171 if self.is_server_closed.load(atomic::Ordering::SeqCst) {
172 return Err(Error::ErrConnectionClosed);
173 }
174
175 let name_with_suffix = name.to_owned() + ".";
176
177 let (query_tx, mut query_rx) = mpsc::channel(1);
178 {
179 let mut queries = self.queries.lock().await;
180 queries.push(Query {
181 name_with_suffix: name_with_suffix.clone(),
182 query_result_chan: query_tx,
183 });
184 }
185
186 log::trace!("Sending query");
187 self.send_question(&name_with_suffix).await;
188
189 loop {
190 tokio::select! {
191 _ = tokio::time::sleep(self.query_interval) => {
192 log::trace!("Sending query");
193 self.send_question(&name_with_suffix).await
194 },
195
196 _ = close_query_signal.recv() => {
197 log::info!("Query close signal received.");
198 return Err(Error::ErrConnectionClosed)
199 },
200
201 res_opt = query_rx.recv() =>{
202 log::info!("Received query result");
203 if let Some(res) = res_opt{
204 return Ok((res.answer, res.addr));
205 }
206 }
207 }
208 }
209 }
210
211 async fn send_question(&self, name: &str) {
212 let packed_name = match Name::new(name) {
213 Ok(pn) => pn,
214 Err(err) => {
215 log::warn!("Failed to construct mDNS packet: {}", err);
216 return;
217 }
218 };
219
220 let raw_query = {
221 let mut msg = Message {
222 header: Header::default(),
223 questions: vec![Question {
224 typ: DnsType::A,
225 class: DNSCLASS_INET,
226 name: packed_name,
227 }],
228 ..Default::default()
229 };
230
231 match msg.pack() {
232 Ok(v) => v,
233 Err(err) => {
234 log::error!("Failed to construct mDNS packet {}", err);
235 return;
236 }
237 }
238 };
239
240 log::trace!("{:?} sending {:?}...", self.socket.local_addr(), raw_query);
241 if let Err(err) = self.socket.send_to(&raw_query, self.dst_addr).await {
242 log::error!("Failed to send mDNS packet {}", err);
243 }
244 }
245
246 async fn start(
247 mut closed_rx: mpsc::Receiver<()>,
248 close_server: Arc<atomic::AtomicBool>,
249 socket: Arc<UdpSocket>,
250 local_names: Vec<String>,
251 dst_addr: SocketAddr,
252 queries: Arc<Mutex<Vec<Query>>>,
253 ) -> Result<()> {
254 log::info!("Looping and listening {:?}", socket.local_addr());
255
256 let mut b = vec![0u8; INBOUND_BUFFER_SIZE];
257 let (mut n, mut src);
258
259 loop {
260 tokio::select! {
261 _ = closed_rx.recv() => {
262 log::info!("Closing server connection");
263 close_server.store(true, atomic::Ordering::SeqCst);
264
265 return Ok(());
266 }
267
268 result = socket.recv_from(&mut b) => {
269 match result{
270 Ok((len, addr)) => {
271 n = len;
272 src = addr;
273 log::trace!("Received new connection from {:?}", addr);
274 },
275
276 Err(err) => {
277 log::error!("Error receiving from socket connection: {:?}", err);
278 continue;
279 },
280 }
281 }
282 }
283
284 let mut p = Parser::default();
285 if let Err(err) = p.start(&b[..n]) {
286 log::error!("Failed to parse mDNS packet {}", err);
287 continue;
288 }
289
290 run(&mut p, &socket, &local_names, src, dst_addr, &queries).await
291 }
292 }
293}
294
295async fn run(
296 p: &mut Parser<'_>,
297 socket: &Arc<UdpSocket>,
298 local_names: &[String],
299 src: SocketAddr,
300 dst_addr: SocketAddr,
301 queries: &Arc<Mutex<Vec<Query>>>,
302) {
303 let mut interface_addr = None;
304 for _ in 0..=MAX_MESSAGE_RECORDS {
305 let q = match p.question() {
306 Ok(q) => q,
307 Err(err) => {
308 if Error::ErrSectionDone == err {
309 log::trace!("Parsing has completed");
310 break;
311 } else {
312 log::error!("Failed to parse mDNS packet {}", err);
313 return;
314 }
315 }
316 };
317
318 for local_name in local_names {
319 if *local_name == q.name.data {
320 let interface_addr = match interface_addr {
321 Some(addr) => addr,
322 None => match get_interface_addr_for_ip(src).await {
323 Ok(addr) => {
324 interface_addr.replace(addr);
325 addr
326 }
327 Err(e) => {
328 log::warn!(
329 "Failed to get local interface to communicate with {}: {:?}",
330 &src,
331 e
332 );
333 continue;
334 }
335 },
336 };
337
338 log::trace!(
339 "Found local name: {} to send answer, IP {}, interface addr {}",
340 local_name,
341 src.ip(),
342 interface_addr
343 );
344 if let Err(e) =
345 send_answer(socket, &interface_addr, &q.name.data, src.ip(), dst_addr).await
346 {
347 log::error!("Error sending answer to client: {:?}", e);
348 continue;
349 };
350 }
351 }
352 }
353
354 let _ = p.skip_all_questions();
356
357 for _ in 0..=MAX_MESSAGE_RECORDS {
358 let a = match p.answer_header() {
359 Ok(a) => a,
360 Err(err) => {
361 if Error::ErrSectionDone != err {
362 log::warn!("Failed to parse mDNS packet {}", err);
363 }
364 return;
365 }
366 };
367
368 if a.typ != DnsType::A && a.typ != DnsType::Aaaa {
369 continue;
370 }
371
372 let mut qs = queries.lock().await;
373 for j in (0..qs.len()).rev() {
374 if qs[j].name_with_suffix == a.name.data {
375 let _ = qs[j]
376 .query_result_chan
377 .send(QueryResult {
378 answer: a.clone(),
379 addr: src,
380 })
381 .await;
382 qs.remove(j);
383 }
384 }
385 }
386}
387
388async fn send_answer(
389 socket: &Arc<UdpSocket>,
390 interface_addr: &SocketAddr,
391 name: &str,
392 dst: IpAddr,
393 dst_addr: SocketAddr,
394) -> Result<()> {
395 let raw_answer = {
396 let mut msg = Message {
397 header: Header {
398 response: true,
399 authoritative: true,
400 ..Default::default()
401 },
402
403 answers: vec![Resource {
404 header: ResourceHeader {
405 typ: DnsType::A,
406 class: DNSCLASS_INET,
407 name: Name::new(name)?,
408 ttl: RESPONSE_TTL,
409 ..Default::default()
410 },
411 body: Some(Box::new(AResource {
412 a: match interface_addr.ip() {
413 IpAddr::V4(ip) => ip.octets(),
414 IpAddr::V6(_) => {
415 return Err(Error::Other("Unexpected IpV6 addr".to_owned()))
416 }
417 },
418 })),
419 }],
420 ..Default::default()
421 };
422
423 msg.pack()?
424 };
425
426 socket.send_to(&raw_answer, dst_addr).await?;
427 log::trace!("Sent answer to IP {}", dst);
428
429 Ok(())
430}
431
432async fn get_interface_addr_for_ip(addr: impl ToSocketAddrs) -> std::io::Result<SocketAddr> {
433 let socket = UdpSocket::bind("0.0.0.0:0").await?;
434 socket.connect(addr).await?;
435 socket.local_addr()
436}