1use std;
9use std::io;
10use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14
15use futures_util::stream::{Stream, StreamExt};
16use futures_util::{future, future::Future, ready, FutureExt, TryFutureExt};
17use once_cell::sync::Lazy;
18use rand;
19use rand::distributions::{uniform::Uniform, Distribution};
20use socket2::{self, Socket};
21use tokio::net::UdpSocket;
22use tracing::{debug, trace};
23
24use crate::multicast::MdnsQueryType;
25use crate::udp::UdpStream;
26use crate::xfer::SerialMessage;
27use crate::BufDnsStreamHandle;
28
29pub(crate) const MDNS_PORT: u16 = 5353;
30pub static MDNS_IPV4: Lazy<SocketAddr> =
32 Lazy::new(|| SocketAddr::new(Ipv4Addr::new(224, 0, 0, 251).into(), MDNS_PORT));
33pub static MDNS_IPV6: Lazy<SocketAddr> = Lazy::new(|| {
35 SocketAddr::new(
36 Ipv6Addr::new(0xFF02, 0, 0, 0, 0, 0, 0, 0x00FB).into(),
37 MDNS_PORT,
38 )
39});
40
/// A `Stream` of mDNS messages received on a multicast group socket and/or on a local
/// socket used for sending.
///
/// Built with [`MdnsStream::new`] (or `new_ipv4` / `new_ipv6`), which return a future that
/// resolves to the stream. At least one of `datagram` / `multicast` is expected to be present
/// (asserted when the stream is polled).
#[must_use = "futures do nothing unless polled"]
pub struct MdnsStream {
    /// The multicast group address (including port) this stream was created for.
    multicast_addr: SocketAddr,
    /// Local socket that sends queued outbound messages and receives datagrams addressed to
    /// it; `None` when the query type does not send.
    datagram: Option<UdpStream<UdpSocket>>,
    /// Socket that joined the multicast group; `None` when the query type does not join.
    multicast: Option<Arc<UdpSocket>>,
    /// In-flight receive on `multicast`, kept across polls until it completes.
    rcving_mcast: Option<Pin<Box<dyn Future<Output = io::Result<SerialMessage>> + Send>>>,
}
54
55impl MdnsStream {
56 pub fn new_ipv4(
58 mdns_query_type: MdnsQueryType,
59 packet_ttl: Option<u32>,
60 ipv4_if: Option<Ipv4Addr>,
61 ) -> (
62 Box<dyn Future<Output = Result<Self, io::Error>> + Send + Unpin>,
63 BufDnsStreamHandle,
64 ) {
65 Self::new(*MDNS_IPV4, mdns_query_type, packet_ttl, ipv4_if, None)
66 }
67
68 pub fn new_ipv6(
70 mdns_query_type: MdnsQueryType,
71 packet_ttl: Option<u32>,
72 ipv6_if: Option<u32>,
73 ) -> (
74 Box<dyn Future<Output = Result<Self, io::Error>> + Send + Unpin>,
75 BufDnsStreamHandle,
76 ) {
77 Self::new(*MDNS_IPV6, mdns_query_type, packet_ttl, None, ipv6_if)
78 }
79
80 pub fn multicast_addr(&self) -> SocketAddr {
82 self.multicast_addr
83 }
84
85 pub fn new(
107 multicast_addr: SocketAddr,
108 mdns_query_type: MdnsQueryType,
109 packet_ttl: Option<u32>,
110 ipv4_if: Option<Ipv4Addr>,
111 ipv6_if: Option<u32>,
112 ) -> (
113 Box<dyn Future<Output = Result<Self, io::Error>> + Send + Unpin>,
114 BufDnsStreamHandle,
115 ) {
116 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(multicast_addr);
117 let multicast_socket = match Self::join_multicast(&multicast_addr, mdns_query_type) {
118 Ok(socket) => socket,
119 Err(err) => return (Box::new(future::err(err)), message_sender),
120 };
121
122 let next_socket = Self::next_bound_local_address(
125 &multicast_addr,
126 mdns_query_type,
127 packet_ttl,
128 ipv4_if,
129 ipv6_if,
130 );
131
132 if let Some(ttl) = packet_ttl {
135 assert!(ttl > 0, "TTL must be greater than 0");
136 }
137
138 let stream = {
141 Box::new(
142 next_socket
143 .map(move |socket| match socket {
144 Ok(Some(socket)) => Ok(Some(UdpSocket::from_std(socket)?)),
145 Ok(None) => Ok(None),
146 Err(err) => Err(err),
147 })
148 .map_ok(move |socket: Option<_>| {
149 let datagram: Option<_> =
150 socket.map(|socket| UdpStream::from_parts(socket, outbound_messages));
151 let multicast: Option<_> = multicast_socket.map(|multicast_socket| {
152 Arc::new(UdpSocket::from_std(multicast_socket).expect("bad handle?"))
153 });
154
155 Self {
156 multicast_addr,
157 datagram,
158 multicast,
159 rcving_mcast: None,
160 }
161 }),
162 )
163 };
164
165 (stream, message_sender)
166 }
167
168 #[cfg(windows)]
172 #[cfg_attr(docsrs, doc(cfg(windows)))]
173 fn bind_multicast(socket: &Socket, multicast_addr: &SocketAddr) -> io::Result<()> {
174 let multicast_addr = match *multicast_addr {
175 SocketAddr::V4(addr) => SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), addr.port()),
176 SocketAddr::V6(addr) => {
177 SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), addr.port())
178 }
179 };
180 socket.bind(&socket2::SockAddr::from(multicast_addr))
181 }
182
183 #[cfg(unix)]
185 #[cfg_attr(docsrs, doc(cfg(unix)))]
186 fn bind_multicast(socket: &Socket, multicast_addr: &SocketAddr) -> io::Result<()> {
187 socket.bind(&socket2::SockAddr::from(*multicast_addr))
188 }
189
190 fn join_multicast(
192 multicast_addr: &SocketAddr,
193 mdns_query_type: MdnsQueryType,
194 ) -> Result<Option<std::net::UdpSocket>, io::Error> {
195 if !mdns_query_type.join_multicast() {
196 return Ok(None);
197 }
198
199 let ip_addr = multicast_addr.ip();
200 if !ip_addr.is_multicast() {
202 return Err(io::Error::new(
203 io::ErrorKind::Other,
204 format!("expected multicast address for binding: {ip_addr}"),
205 ));
206 }
207
208 let socket = match ip_addr {
212 IpAddr::V4(ref mdns_v4) => {
213 let socket = Socket::new(
214 socket2::Domain::IPV4,
215 socket2::Type::DGRAM,
216 Some(socket2::Protocol::UDP),
217 )?;
218 socket.join_multicast_v4(mdns_v4, &Ipv4Addr::new(0, 0, 0, 0))?;
219 socket
220 }
221 IpAddr::V6(ref mdns_v6) => {
222 let socket = Socket::new(
223 socket2::Domain::IPV6,
224 socket2::Type::DGRAM,
225 Some(socket2::Protocol::UDP),
226 )?;
227
228 socket.set_only_v6(true)?;
229 socket.join_multicast_v6(mdns_v6, 0)?;
230 socket
231 }
232 };
233
234 socket.set_nonblocking(true)?;
235 socket.set_reuse_address(true)?;
236 #[cfg(unix)] socket.set_reuse_port(true)?;
238 Self::bind_multicast(&socket, multicast_addr)?;
239
240 debug!("joined {multicast_addr}");
241 Ok(Some(std::net::UdpSocket::from(socket)))
242 }
243
244 fn next_bound_local_address(
246 multicast_addr: &SocketAddr,
247 mdns_query_type: MdnsQueryType,
248 packet_ttl: Option<u32>,
249 ipv4_if: Option<Ipv4Addr>,
250 ipv6_if: Option<u32>,
251 ) -> NextRandomUdpSocket {
252 let bind_address: IpAddr = match *multicast_addr {
253 SocketAddr::V4(..) => IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
254 SocketAddr::V6(..) => IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)),
255 };
256
257 NextRandomUdpSocket {
258 bind_address,
259 mdns_query_type,
260 packet_ttl,
261 ipv4_if,
262 ipv6_if,
263 }
264 }
265}
266
267impl Stream for MdnsStream {
268 type Item = io::Result<SerialMessage>;
269
270 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
271 assert!(self.datagram.is_some() || self.multicast.is_some());
272
273 if let Some(ref mut datagram) = self.as_mut().datagram {
275 match datagram.poll_next_unpin(cx) {
276 Poll::Ready(ready) => return Poll::Ready(ready),
277 Poll::Pending => (), }
279 }
280
281 loop {
282 let msg = if let Some(ref mut receiving) = self.rcving_mcast {
283 let msg = ready!(receiving.as_mut().poll_unpin(cx))?;
285
286 Some(Poll::Ready(Some(Ok(msg))))
287 } else {
288 None
289 };
290
291 self.rcving_mcast = None;
292
293 if let Some(msg) = msg {
294 return msg;
295 }
296
297 if let Some(ref socket) = self.multicast {
299 let socket = Arc::clone(socket);
300 let receive_future = async {
301 let socket = socket;
302 let mut buf = [0u8; 2048];
303 let (len, src) = socket.recv_from(&mut buf).await?;
304
305 Ok(SerialMessage::new(
306 buf.iter().take(len).cloned().collect(),
307 src,
308 ))
309 };
310
311 self.rcving_mcast = Some(Box::pin(receive_future.boxed()));
312 }
313 }
314 }
315}
316
317#[must_use = "futures do nothing unless polled"]
318struct NextRandomUdpSocket {
319 bind_address: IpAddr,
320 mdns_query_type: MdnsQueryType,
321 packet_ttl: Option<u32>,
322 ipv4_if: Option<Ipv4Addr>,
323 ipv6_if: Option<u32>,
324}
325
326impl NextRandomUdpSocket {
327 fn prepare_sender(&self, socket: std::net::UdpSocket) -> io::Result<std::net::UdpSocket> {
328 let addr = socket.local_addr()?;
329 debug!("preparing sender on: {addr}");
330
331 let socket = Socket::from(socket);
332
333 match addr {
335 SocketAddr::V4(..) => {
336 socket.set_multicast_loop_v4(true)?;
337 socket.set_multicast_if_v4(
338 &self.ipv4_if.unwrap_or_else(|| Ipv4Addr::new(0, 0, 0, 0)),
339 )?;
340 if let Some(ttl) = self.packet_ttl {
341 socket.set_ttl(ttl)?;
342 socket.set_multicast_ttl_v4(ttl)?;
343 }
344 }
345 SocketAddr::V6(..) => {
346 let ipv6_if = self.ipv6_if.unwrap_or_else(|| {
347 panic!("for ipv6 multicasting the interface must be specified")
348 });
349
350 socket.set_multicast_loop_v6(true)?;
351 socket.set_multicast_if_v6(ipv6_if)?;
352 if let Some(ttl) = self.packet_ttl {
353 socket.set_unicast_hops_v6(ttl)?;
354 socket.set_multicast_hops_v6(ttl)?;
355 }
356 }
357 }
358
359 Ok(std::net::UdpSocket::from(socket))
360 }
361}
362
363impl Future for NextRandomUdpSocket {
364 type Output = io::Result<Option<std::net::UdpSocket>>;
366
367 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
371 if !self.mdns_query_type.sender() {
373 debug!("skipping sending stream");
374 Poll::Ready(Ok(None))
375 } else if self.mdns_query_type.bind_on_5353() {
376 let addr = SocketAddr::new(self.bind_address, MDNS_PORT);
377 debug!("binding sending stream to {}", addr);
378 let socket = std::net::UdpSocket::bind(addr)?;
379 let socket = self.prepare_sender(socket)?;
380
381 Poll::Ready(Ok(Some(socket)))
382 } else {
383 let rand_port_range = Uniform::new_inclusive(49152_u16, u16::MAX);
391 let mut rand = rand::thread_rng();
392
393 for attempt in 0..10 {
394 let port = rand_port_range.sample(&mut rand);
395
396 if port == MDNS_PORT {
400 trace!("unlucky, got MDNS_PORT");
401 continue;
402 }
403
404 let addr = SocketAddr::new(self.bind_address, port);
405 debug!("binding sending stream to {}", addr);
406
407 match std::net::UdpSocket::bind(addr) {
408 Ok(socket) => {
409 let socket = self.prepare_sender(socket)?;
410 return Poll::Ready(Ok(Some(socket)));
411 }
412 Err(err) => debug!("unable to bind port, attempt: {}: {}", attempt, err),
413 }
414 }
415
416 debug!("could not get next random port, delaying");
417
418 cx.waker().wake_by_ref();
420 Poll::Pending
421 }
422 }
423}
424
#[cfg(test)]
pub(crate) mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]

    use super::*;
    use crate::xfer::dns_handle::DnsStreamHandle;
    use futures_util::future::Either;
    use tokio::runtime;

    // Base port for these tests; each test adds its own offset, presumably so tests running
    // in parallel don't collide on the same group/port.
    const BASE_TEST_PORT: u16 = 5379;

    // Test multicast groups 224.0.0.250 / ff02::fa, next to (not equal to) the real mDNS
    // groups 224.0.0.251 / ff02::fb used by MDNS_IPV4 / MDNS_IPV6.
    static TEST_MDNS_IPV4: Lazy<IpAddr> = Lazy::new(|| Ipv4Addr::new(224, 0, 0, 250).into());
    static TEST_MDNS_IPV6: Lazy<IpAddr> =
        Lazy::new(|| Ipv6Addr::new(0xFF02, 0, 0, 0, 0, 0, 0, 0x00FA).into());

    // Checks that a OneShot stream can be created, i.e. that its construction future
    // (including binding the sending socket) resolves without error.
    #[test]
    fn test_next_random_socket() {
        let io_loop = runtime::Runtime::new().unwrap();
        let (stream, _) = MdnsStream::new(
            SocketAddr::new(*TEST_MDNS_IPV4, BASE_TEST_PORT),
            MdnsQueryType::OneShot,
            Some(1),
            None,
            None,
        );
        let result = io_loop.block_on(stream);

        if let Err(error) = result {
            println!("Random address error: {error:#?}");
            panic!("failed to get next random address");
        }
    }

    // Ignored by default; NOTE(review): presumably needs a host/network with working
    // multicast — confirm.
    #[ignore]
    #[test]
    fn test_one_shot_mdns_ipv4() {
        one_shot_mdns_test(SocketAddr::new(*TEST_MDNS_IPV4, BASE_TEST_PORT + 1));
    }

    #[test]
    #[ignore]
    fn test_one_shot_mdns_ipv6() {
        one_shot_mdns_test(SocketAddr::new(*TEST_MDNS_IPV6, BASE_TEST_PORT + 2));
    }

    // Round-trip test. A server thread joins the group (OneShotJoin) and answers every
    // received payload by sending `test_bytes` back to the message's source address. The
    // client (OneShot) sends to the group up to `send_recv_times` times and requires at least
    // one reply. Every wait races the stream against a 100ms timeout.
    // NOTE(review): `Some(5)` is a hard-coded IPv6 interface index and is host-specific.
    fn one_shot_mdns_test(mdns_addr: SocketAddr) {
        use std::time::Duration;

        // Set by the client when it is done, so the server loop can stop early.
        let client_done = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));

        let test_bytes: &'static [u8; 8] = b"DEADBEEF";
        let send_recv_times = 10;
        let client_done_clone = client_done.clone();

        let server_handle = std::thread::Builder::new()
            .name("test_one_shot_mdns:server".to_string())
            .spawn(move || {
                let server_loop = runtime::Runtime::new().unwrap();
                let mut timeout = future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
                    .flatten()
                    .boxed();

                let (server_stream_future, mut server_sender) = MdnsStream::new(
                    mdns_addr,
                    MdnsQueryType::OneShotJoin,
                    Some(1),
                    None,
                    Some(5),
                );

                // `into_future` turns the stream into a future of (next item, rest of stream).
                let mut server_stream = server_loop
                    .block_on(server_stream_future)
                    .expect("could not create mDNS listener")
                    .into_future();

                for _ in 0..=send_recv_times {
                    if client_done_clone.load(std::sync::atomic::Ordering::Relaxed) {
                        return;
                    }
                    // `select` consumes both futures and hands back the one that lost the race,
                    // so each branch reassigns `server_stream` / `timeout` for the next round.
                    match server_loop.block_on(
                        future::lazy(|_| future::select(server_stream, timeout)).flatten(),
                    ) {
                        // A message arrived before the timeout.
                        Either::Left((buffer_and_addr_stream_tmp, timeout_tmp)) => {
                            let (buffer_and_addr, stream_tmp): (
                                Option<Result<SerialMessage, io::Error>>,
                                MdnsStream,
                            ) = buffer_and_addr_stream_tmp;

                            server_stream = stream_tmp.into_future();
                            timeout = timeout_tmp;
                            let (buffer, addr) = buffer_and_addr
                                .expect("no msg received")
                                .expect("error receiving msg")
                                .into_parts();

                            assert_eq!(&buffer, test_bytes);
                            // Reply to whoever sent the message.
                            server_sender
                                .send(SerialMessage::new(test_bytes.to_vec(), addr))
                                .expect("could not send to client");
                        }
                        // Timed out: keep the pending stream future and start a new timeout.
                        Either::Right(((), buffer_and_addr_stream_tmp)) => {
                            server_stream = buffer_and_addr_stream_tmp;
                            timeout =
                                future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
                                    .flatten()
                                    .boxed();
                        }
                    }

                    server_loop.block_on(tokio::time::sleep(Duration::from_millis(100)));
                }
            })
            .unwrap();

        let io_loop = runtime::Runtime::new().unwrap();

        let (stream, mut sender) =
            MdnsStream::new(mdns_addr, MdnsQueryType::OneShot, Some(1), None, Some(5));
        let mut stream = io_loop.block_on(stream).ok().unwrap().into_future();
        let mut timeout = future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
            .flatten()
            .boxed();
        let mut successes = 0;

        for _ in 0..send_recv_times {
            sender
                .send(SerialMessage::new(test_bytes.to_vec(), mdns_addr))
                .unwrap();

            println!("client sending data!");

            match io_loop.block_on(future::lazy(|_| future::select(stream, timeout)).flatten()) {
                // A reply arrived before the timeout.
                Either::Left((buffer_and_addr_stream_tmp, timeout_tmp)) => {
                    let (buffer_and_addr, stream_tmp) = buffer_and_addr_stream_tmp;
                    stream = stream_tmp.into_future();
                    timeout = timeout_tmp;

                    let (buffer, _addr) = buffer_and_addr
                        .expect("no msg received")
                        .expect("error receiving msg")
                        .into_parts();
                    println!("client got data!");

                    assert_eq!(&buffer, test_bytes);
                    successes += 1;
                }
                // Timed out: keep the pending stream future and start a new timeout.
                Either::Right(((), buffer_and_addr_stream_tmp)) => {
                    stream = buffer_and_addr_stream_tmp;
                    timeout = future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
                        .flatten()
                        .boxed();
                }
            }
        }

        client_done.store(true, std::sync::atomic::Ordering::Relaxed);
        println!("successes: {successes}");
        assert!(successes >= 1);
        server_handle.join().expect("server thread failed");
    }

    #[ignore]
    #[test]
    fn test_passive_mdns() {
        passive_mdns_test(
            MdnsQueryType::Passive,
            SocketAddr::new(*TEST_MDNS_IPV4, BASE_TEST_PORT + 3),
        )
    }

    #[ignore]
    #[test]
    fn test_oneshot_join_mdns() {
        passive_mdns_test(
            MdnsQueryType::OneShotJoin,
            SocketAddr::new(*TEST_MDNS_IPV4, BASE_TEST_PORT + 4),
        )
    }

    // Receive-only test. A server thread listening with `mdns_query_type` sets a flag on the
    // first message it receives. The client (OneShot) keeps sending to the group until the
    // flag is observed, and panics after `send_recv_times` sends without it.
    // NOTE(review): the server thread name below was copied from `one_shot_mdns_test`, and
    // the thread is never joined.
    fn passive_mdns_test(mdns_query_type: MdnsQueryType, mdns_addr: SocketAddr) {
        use std::time::Duration;

        let server_got_packet = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));

        let test_bytes: &'static [u8; 8] = b"DEADBEEF";
        let send_recv_times = 10;
        let server_got_packet_clone = server_got_packet.clone();

        let _server_handle = std::thread::Builder::new()
            .name("test_one_shot_mdns:server".to_string())
            .spawn(move || {
                let io_loop = runtime::Runtime::new().unwrap();
                let mut timeout = future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
                    .flatten()
                    .boxed();

                let (server_stream_future, _server_sender) =
                    MdnsStream::new(mdns_addr, mdns_query_type, Some(1), None, Some(5));

                let mut server_stream = io_loop
                    .block_on(server_stream_future)
                    .expect("could not create mDNS listener")
                    .into_future();

                for _ in 0..=send_recv_times {
                    match io_loop.block_on(
                        future::lazy(|_| future::select(server_stream, timeout)).flatten(),
                    ) {
                        // Any message counts; the payload is not inspected.
                        Either::Left((_buffer_and_addr_stream_tmp, _timeout_tmp)) => {
                            server_got_packet_clone
                                .store(true, std::sync::atomic::Ordering::Relaxed);
                            return;
                        }
                        Either::Right(((), buffer_and_addr_stream_tmp)) => {
                            server_stream = buffer_and_addr_stream_tmp;
                            timeout =
                                future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
                                    .flatten()
                                    .boxed();
                        }
                    }

                    io_loop.block_on(tokio::time::sleep(Duration::from_millis(100)));
                }
            })
            .unwrap();

        let io_loop = runtime::Runtime::new().unwrap();
        let (stream, mut sender) =
            MdnsStream::new(mdns_addr, MdnsQueryType::OneShot, Some(1), None, Some(5));
        let mut stream = io_loop.block_on(stream).ok().unwrap().into_future();
        let mut timeout = future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
            .flatten()
            .boxed();

        for _ in 0..send_recv_times {
            sender
                .send(SerialMessage::new(test_bytes.to_vec(), mdns_addr))
                .unwrap();

            println!("client sending data!");

            let run_result =
                io_loop.block_on(future::lazy(|_| future::select(stream, timeout)).flatten());

            // Success is signalled by the server thread, not by anything the client receives.
            if server_got_packet.load(std::sync::atomic::Ordering::Relaxed) {
                return;
            }

            match run_result {
                Either::Left((buffer_and_addr_stream_tmp, timeout_tmp)) => {
                    let (_buffer_and_addr, stream_tmp) = buffer_and_addr_stream_tmp;
                    stream = stream_tmp.into_future();
                    timeout = timeout_tmp;
                }
                Either::Right(((), buffer_and_addr_stream_tmp)) => {
                    stream = buffer_and_addr_stream_tmp;
                    timeout = future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
                        .flatten()
                        .boxed();
                }
            }
        }

        panic!("server never got packet.");
    }
}