1use alloc::boxed::Box;
9use alloc::sync::Arc;
10use core::future::Future;
11use core::pin::Pin;
12use core::task::{Context, Poll};
13use std::io;
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
15
16use futures_util::stream::{Stream, StreamExt};
17use futures_util::{FutureExt, TryFutureExt, future, ready};
18use once_cell::sync::Lazy;
19use socket2::{self, Socket};
20use tokio::net::UdpSocket;
21use tracing::{debug, trace};
22
23use crate::BufDnsStreamHandle;
24use crate::multicast::MdnsQueryType;
25use crate::runtime::TokioRuntimeProvider;
26use crate::udp::UdpStream;
27use crate::xfer::SerialMessage;
28
29pub(crate) const MDNS_PORT: u16 = 5353;
30pub static MDNS_IPV4: Lazy<SocketAddr> =
32 Lazy::new(|| SocketAddr::new(Ipv4Addr::new(224, 0, 0, 251).into(), MDNS_PORT));
33pub static MDNS_IPV6: Lazy<SocketAddr> = Lazy::new(|| {
35 SocketAddr::new(
36 Ipv6Addr::new(0xFF02, 0, 0, 0, 0, 0, 0, 0x00FB).into(),
37 MDNS_PORT,
38 )
39});
40
/// A stream of mDNS messages.
///
/// Combines an optional sending socket (`datagram`) with an optional socket joined to
/// the multicast group (`multicast`). Which of the two exist depends on the
/// `MdnsQueryType` passed to [`MdnsStream::new`].
#[must_use = "futures do nothing unless polled"]
pub struct MdnsStream {
    /// The multicast group address (and port) this stream was created for.
    multicast_addr: SocketAddr,
    /// Locally bound socket used for sending; it is also polled for incoming messages.
    /// `None` when the query type does not send.
    datagram: Option<UdpStream<TokioRuntimeProvider>>,
    /// Socket joined to the multicast group; `None` when the query type does not join.
    multicast: Option<Arc<UdpSocket>>,
    /// In-flight receive on the multicast socket, kept across polls until it completes.
    rcving_mcast: Option<Pin<Box<dyn Future<Output = io::Result<SerialMessage>> + Send>>>,
}
54
55impl MdnsStream {
56 pub fn new_ipv4(
58 mdns_query_type: MdnsQueryType,
59 packet_ttl: Option<u32>,
60 ipv4_if: Option<Ipv4Addr>,
61 ) -> (
62 Box<dyn Future<Output = Result<Self, io::Error>> + Send + Unpin>,
63 BufDnsStreamHandle,
64 ) {
65 Self::new(*MDNS_IPV4, mdns_query_type, packet_ttl, ipv4_if, None)
66 }
67
68 pub fn new_ipv6(
70 mdns_query_type: MdnsQueryType,
71 packet_ttl: Option<u32>,
72 ipv6_if: Option<u32>,
73 ) -> (
74 Box<dyn Future<Output = Result<Self, io::Error>> + Send + Unpin>,
75 BufDnsStreamHandle,
76 ) {
77 Self::new(*MDNS_IPV6, mdns_query_type, packet_ttl, None, ipv6_if)
78 }
79
80 pub fn multicast_addr(&self) -> SocketAddr {
82 self.multicast_addr
83 }
84
85 pub fn new(
107 multicast_addr: SocketAddr,
108 mdns_query_type: MdnsQueryType,
109 packet_ttl: Option<u32>,
110 ipv4_if: Option<Ipv4Addr>,
111 ipv6_if: Option<u32>,
112 ) -> (
113 Box<dyn Future<Output = Result<Self, io::Error>> + Send + Unpin>,
114 BufDnsStreamHandle,
115 ) {
116 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(multicast_addr);
117 let multicast_socket = match Self::join_multicast(&multicast_addr, mdns_query_type) {
118 Ok(socket) => socket,
119 Err(err) => return (Box::new(future::err(err)), message_sender),
120 };
121
122 let next_socket = Self::next_bound_local_address(
125 &multicast_addr,
126 mdns_query_type,
127 packet_ttl,
128 ipv4_if,
129 ipv6_if,
130 );
131
132 if let Some(ttl) = packet_ttl {
135 assert!(ttl > 0, "TTL must be greater than 0");
136 }
137
138 let stream = {
141 Box::new(
142 next_socket
143 .map(move |socket| match socket {
144 Ok(Some(socket)) => {
145 socket.set_nonblocking(true)?;
146 Ok(Some(UdpSocket::from_std(socket)?))
147 }
148 Ok(None) => Ok(None),
149 Err(err) => Err(err),
150 })
151 .map_ok(move |socket: Option<_>| {
152 let datagram: Option<_> =
153 socket.map(|socket| UdpStream::from_parts(socket, outbound_messages));
154 let multicast: Option<_> = multicast_socket.map(|multicast_socket| {
155 Arc::new(UdpSocket::from_std(multicast_socket).expect("bad handle?"))
156 });
157
158 Self {
159 multicast_addr,
160 datagram,
161 multicast,
162 rcving_mcast: None,
163 }
164 }),
165 )
166 };
167
168 (stream, message_sender)
169 }
170
171 #[cfg(windows)]
175 fn bind_multicast(socket: &Socket, multicast_addr: &SocketAddr) -> io::Result<()> {
176 let multicast_addr = match multicast_addr {
177 SocketAddr::V4(addr) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), addr.port()),
178 SocketAddr::V6(addr) => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), addr.port()),
179 };
180 socket.bind(&socket2::SockAddr::from(multicast_addr))
181 }
182
183 #[cfg(unix)]
185 fn bind_multicast(socket: &Socket, multicast_addr: &SocketAddr) -> io::Result<()> {
186 socket.bind(&socket2::SockAddr::from(*multicast_addr))
187 }
188
189 fn join_multicast(
191 multicast_addr: &SocketAddr,
192 mdns_query_type: MdnsQueryType,
193 ) -> Result<Option<std::net::UdpSocket>, io::Error> {
194 if !mdns_query_type.join_multicast() {
195 return Ok(None);
196 }
197
198 let ip_addr = multicast_addr.ip();
199 if !ip_addr.is_multicast() {
201 return Err(io::Error::new(
202 io::ErrorKind::Other,
203 format!("expected multicast address for binding: {ip_addr}"),
204 ));
205 }
206
207 let socket = match &ip_addr {
211 IpAddr::V4(mdns_v4) => {
212 let socket = Socket::new(
213 socket2::Domain::IPV4,
214 socket2::Type::DGRAM,
215 Some(socket2::Protocol::UDP),
216 )?;
217 socket.join_multicast_v4(mdns_v4, &Ipv4Addr::UNSPECIFIED)?;
218 socket
219 }
220 IpAddr::V6(mdns_v6) => {
221 let socket = Socket::new(
222 socket2::Domain::IPV6,
223 socket2::Type::DGRAM,
224 Some(socket2::Protocol::UDP),
225 )?;
226
227 socket.set_only_v6(true)?;
228 socket.join_multicast_v6(mdns_v6, 0)?;
229 socket
230 }
231 };
232
233 socket.set_nonblocking(true)?;
234 socket.set_reuse_address(true)?;
235 #[cfg(unix)] socket.set_reuse_port(true)?;
237 Self::bind_multicast(&socket, multicast_addr)?;
238
239 debug!("joined {multicast_addr}");
240 Ok(Some(std::net::UdpSocket::from(socket)))
241 }
242
243 fn next_bound_local_address(
245 multicast_addr: &SocketAddr,
246 mdns_query_type: MdnsQueryType,
247 packet_ttl: Option<u32>,
248 ipv4_if: Option<Ipv4Addr>,
249 ipv6_if: Option<u32>,
250 ) -> NextRandomUdpSocket {
251 let bind_address: IpAddr = match multicast_addr {
252 SocketAddr::V4(..) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
253 SocketAddr::V6(..) => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
254 };
255
256 NextRandomUdpSocket {
257 bind_address,
258 mdns_query_type,
259 packet_ttl,
260 ipv4_if,
261 ipv6_if,
262 }
263 }
264}
265
266impl Stream for MdnsStream {
267 type Item = io::Result<SerialMessage>;
268
269 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
270 assert!(self.datagram.is_some() || self.multicast.is_some());
271
272 if let Some(datagram) = self.as_mut().datagram.as_mut() {
274 match datagram.poll_next_unpin(cx) {
275 Poll::Ready(ready) => return Poll::Ready(ready),
276 Poll::Pending => (), }
278 }
279
280 loop {
281 let msg = if let Some(receiving) = self.rcving_mcast.as_mut() {
282 let msg = ready!(receiving.as_mut().poll_unpin(cx))?;
284
285 Some(Poll::Ready(Some(Ok(msg))))
286 } else {
287 None
288 };
289
290 self.rcving_mcast = None;
291
292 if let Some(msg) = msg {
293 return msg;
294 }
295
296 if let Some(socket) = &self.multicast {
298 let socket = Arc::clone(socket);
299 let receive_future = async {
300 let socket = socket;
301 let mut buf = [0u8; 2_048];
302 let (len, src) = socket.recv_from(&mut buf).await?;
303
304 Ok(SerialMessage::new(
305 buf.iter().take(len).cloned().collect(),
306 src,
307 ))
308 };
309
310 self.rcving_mcast = Some(Box::pin(receive_future.boxed()));
311 }
312 }
313 }
314}
315
316#[must_use = "futures do nothing unless polled"]
317struct NextRandomUdpSocket {
318 bind_address: IpAddr,
319 mdns_query_type: MdnsQueryType,
320 packet_ttl: Option<u32>,
321 ipv4_if: Option<Ipv4Addr>,
322 ipv6_if: Option<u32>,
323}
324
325impl NextRandomUdpSocket {
326 fn prepare_sender(&self, socket: std::net::UdpSocket) -> io::Result<std::net::UdpSocket> {
327 let addr = socket.local_addr()?;
328 debug!("preparing sender on: {addr}");
329
330 let socket = Socket::from(socket);
331
332 match addr {
334 SocketAddr::V4(..) => {
335 socket.set_multicast_loop_v4(true)?;
336 socket.set_multicast_if_v4(&self.ipv4_if.unwrap_or(Ipv4Addr::UNSPECIFIED))?;
337 if let Some(ttl) = self.packet_ttl {
338 socket.set_ttl(ttl)?;
339 socket.set_multicast_ttl_v4(ttl)?;
340 }
341 }
342 SocketAddr::V6(..) => {
343 let ipv6_if = self.ipv6_if.unwrap_or_else(|| {
344 panic!("for ipv6 multicasting the interface must be specified")
345 });
346
347 socket.set_multicast_loop_v6(true)?;
348 socket.set_multicast_if_v6(ipv6_if)?;
349 if let Some(ttl) = self.packet_ttl {
350 socket.set_unicast_hops_v6(ttl)?;
351 socket.set_multicast_hops_v6(ttl)?;
352 }
353 }
354 }
355
356 Ok(std::net::UdpSocket::from(socket))
357 }
358}
359
360impl Future for NextRandomUdpSocket {
361 type Output = io::Result<Option<std::net::UdpSocket>>;
363
364 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
368 if !self.mdns_query_type.sender() {
370 debug!("skipping sending stream");
371 Poll::Ready(Ok(None))
372 } else if self.mdns_query_type.bind_on_5353() {
373 let addr = SocketAddr::new(self.bind_address, MDNS_PORT);
374 debug!("binding sending stream to {}", addr);
375 let socket = std::net::UdpSocket::bind(addr)?;
376 let socket = self.prepare_sender(socket)?;
377
378 Poll::Ready(Ok(Some(socket)))
379 } else {
380 for attempt in 0..10 {
384 let port = rand::random_range(49152_u16..=u16::MAX);
389
390 if port == MDNS_PORT {
394 trace!("unlucky, got MDNS_PORT");
395 continue;
396 }
397
398 let addr = SocketAddr::new(self.bind_address, port);
399 debug!("binding sending stream to {}", addr);
400
401 match std::net::UdpSocket::bind(addr) {
402 Ok(socket) => {
403 let socket = self.prepare_sender(socket)?;
404 return Poll::Ready(Ok(Some(socket)));
405 }
406 Err(err) => debug!("unable to bind port, attempt: {}: {}", attempt, err),
407 }
408 }
409
410 debug!("could not get next random port, delaying");
411
412 cx.waker().wake_by_ref();
414 Poll::Pending
415 }
416 }
417}
418
#[cfg(test)]
pub(crate) mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]

    use alloc::string::ToString;
    use std::println;

    use futures_util::future::Either;
    use test_support::subscribe;
    use tokio::runtime;

    use super::*;
    use crate::xfer::dns_handle::DnsStreamHandle;

    /// First port used by these tests. Each test below uses a distinct offset from it.
    const BASE_TEST_PORT: u16 = 5379;

    /// Test-only IPv4 group (224.0.0.250), distinct from the real mDNS group.
    static TEST_MDNS_IPV4: Lazy<IpAddr> = Lazy::new(|| Ipv4Addr::new(224, 0, 0, 250).into());
    /// Test-only IPv6 group (ff02::fa), distinct from the real mDNS group.
    static TEST_MDNS_IPV6: Lazy<IpAddr> =
        Lazy::new(|| Ipv6Addr::new(0xFF02, 0, 0, 0, 0, 0, 0, 0x00FA).into());

    /// Creating a one-shot stream must bind a random sending port without error.
    #[tokio::test]
    async fn test_next_random_socket() {
        subscribe();

        let (stream, _) = MdnsStream::new(
            SocketAddr::new(*TEST_MDNS_IPV4, BASE_TEST_PORT),
            MdnsQueryType::OneShot,
            Some(1),
            None,
            None,
        );
        let result = stream.await;

        if let Err(error) = result {
            println!("Random address error: {error:#?}");
            panic!("failed to get next random address");
        }
    }

    #[ignore]
    #[test]
    fn test_one_shot_mdns_ipv4() {
        subscribe();
        one_shot_mdns_test(SocketAddr::new(*TEST_MDNS_IPV4, BASE_TEST_PORT + 1));
    }

    #[test]
    #[ignore]
    fn test_one_shot_mdns_ipv6() {
        subscribe();
        one_shot_mdns_test(SocketAddr::new(*TEST_MDNS_IPV6, BASE_TEST_PORT + 2));
    }

    /// Round-trip test. A `OneShot` client sends to the group, and a `OneShotJoin` server
    /// on another thread echoes each payload back to the sender's address. At least one
    /// echo must reach the client.
    fn one_shot_mdns_test(mdns_addr: SocketAddr) {
        use core::time::Duration;

        let client_done = alloc::sync::Arc::new(core::sync::atomic::AtomicBool::new(false));

        let test_bytes: &'static [u8; 8] = b"DEADBEEF";
        let send_recv_times = 10;
        let client_done_clone = client_done.clone();

        let server_handle = std::thread::Builder::new()
            .name("test_one_shot_mdns:server".to_string())
            .spawn(move || {
                let server_loop = runtime::Runtime::new().unwrap();
                let mut timeout = future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
                    .flatten()
                    .boxed();

                // The server joins the group so it sees the client's multicast sends.
                let (server_stream_future, mut server_sender) = MdnsStream::new(
                    mdns_addr,
                    MdnsQueryType::OneShotJoin,
                    Some(1),
                    None,
                    Some(5),
                );

                let mut server_stream = server_loop
                    .block_on(server_stream_future)
                    .expect("could not create mDNS listener")
                    .into_future();

                for _ in 0..=send_recv_times {
                    if client_done_clone.load(core::sync::atomic::Ordering::Relaxed) {
                        return;
                    }
                    // Race the next message against the timeout. The unfinished side is
                    // carried into the next iteration.
                    match server_loop.block_on(
                        future::lazy(|_| future::select(server_stream, timeout)).flatten(),
                    ) {
                        Either::Left((buffer_and_addr_stream_tmp, timeout_tmp)) => {
                            let (buffer_and_addr, stream_tmp): (
                                Option<Result<SerialMessage, io::Error>>,
                                MdnsStream,
                            ) = buffer_and_addr_stream_tmp;

                            server_stream = stream_tmp.into_future();
                            timeout = timeout_tmp;
                            let (buffer, addr) = buffer_and_addr
                                .expect("no msg received")
                                .expect("error receiving msg")
                                .into_parts();

                            assert_eq!(&buffer, test_bytes);
                            // Echo the payload back to the address it came from.
                            server_sender
                                .send(SerialMessage::new(test_bytes.to_vec(), addr))
                                .expect("could not send to client");
                        }
                        Either::Right(((), buffer_and_addr_stream_tmp)) => {
                            server_stream = buffer_and_addr_stream_tmp;
                            timeout =
                                future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
                                    .flatten()
                                    .boxed();
                        }
                    }

                    server_loop.block_on(tokio::time::sleep(Duration::from_millis(100)));
                }
            })
            .unwrap();

        let io_loop = runtime::Runtime::new().unwrap();

        // The client does not join the group; the server's echoes are addressed to its
        // sending socket.
        let (stream, mut sender) =
            MdnsStream::new(mdns_addr, MdnsQueryType::OneShot, Some(1), None, Some(5));
        let mut stream = io_loop.block_on(stream).ok().unwrap().into_future();
        let mut timeout = future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
            .flatten()
            .boxed();
        let mut successes = 0;

        for _ in 0..send_recv_times {
            sender
                .send(SerialMessage::new(test_bytes.to_vec(), mdns_addr))
                .unwrap();

            println!("client sending data!");

            match io_loop.block_on(future::lazy(|_| future::select(stream, timeout)).flatten()) {
                Either::Left((buffer_and_addr_stream_tmp, timeout_tmp)) => {
                    let (buffer_and_addr, stream_tmp) = buffer_and_addr_stream_tmp;
                    stream = stream_tmp.into_future();
                    timeout = timeout_tmp;

                    let (buffer, _addr) = buffer_and_addr
                        .expect("no msg received")
                        .expect("error receiving msg")
                        .into_parts();
                    println!("client got data!");

                    assert_eq!(&buffer, test_bytes);
                    successes += 1;
                }
                Either::Right(((), buffer_and_addr_stream_tmp)) => {
                    stream = buffer_and_addr_stream_tmp;
                    timeout = future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
                        .flatten()
                        .boxed();
                }
            }
        }

        client_done.store(true, core::sync::atomic::Ordering::Relaxed);
        println!("successes: {successes}");
        assert!(successes >= 1);
        server_handle.join().expect("server thread failed");
    }

    #[ignore]
    #[test]
    fn test_passive_mdns() {
        subscribe();
        passive_mdns_test(
            MdnsQueryType::Passive,
            SocketAddr::new(*TEST_MDNS_IPV4, BASE_TEST_PORT + 3),
        )
    }

    #[ignore]
    #[test]
    fn test_oneshot_join_mdns() {
        subscribe();
        passive_mdns_test(
            MdnsQueryType::OneShotJoin,
            SocketAddr::new(*TEST_MDNS_IPV4, BASE_TEST_PORT + 4),
        )
    }

    /// A server stream created with `mdns_query_type` must observe at least one of the
    /// client's multicast sends. Any stream item the server receives counts.
    fn passive_mdns_test(mdns_query_type: MdnsQueryType, mdns_addr: SocketAddr) {
        use core::time::Duration;

        let server_got_packet = alloc::sync::Arc::new(core::sync::atomic::AtomicBool::new(false));

        let test_bytes: &'static [u8; 8] = b"DEADBEEF";
        let send_recv_times = 10;
        let server_got_packet_clone = server_got_packet.clone();

        let _server_handle = std::thread::Builder::new()
            .name("test_passive_mdns:server".to_string())
            .spawn(move || {
                let io_loop = runtime::Runtime::new().unwrap();
                let mut timeout = future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
                    .flatten()
                    .boxed();

                let (server_stream_future, _server_sender) =
                    MdnsStream::new(mdns_addr, mdns_query_type, Some(1), None, Some(5));

                let mut server_stream = io_loop
                    .block_on(server_stream_future)
                    .expect("could not create mDNS listener")
                    .into_future();

                for _ in 0..=send_recv_times {
                    match io_loop.block_on(
                        future::lazy(|_| future::select(server_stream, timeout)).flatten(),
                    ) {
                        Either::Left((_buffer_and_addr_stream_tmp, _timeout_tmp)) => {
                            server_got_packet_clone
                                .store(true, core::sync::atomic::Ordering::Relaxed);
                            return;
                        }
                        Either::Right(((), buffer_and_addr_stream_tmp)) => {
                            server_stream = buffer_and_addr_stream_tmp;
                            timeout =
                                future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
                                    .flatten()
                                    .boxed();
                        }
                    }

                    io_loop.block_on(tokio::time::sleep(Duration::from_millis(100)));
                }
            })
            .unwrap();

        let io_loop = runtime::Runtime::new().unwrap();
        let (stream, mut sender) =
            MdnsStream::new(mdns_addr, MdnsQueryType::OneShot, Some(1), None, Some(5));
        let mut stream = io_loop.block_on(stream).ok().unwrap().into_future();
        let mut timeout = future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
            .flatten()
            .boxed();

        for _ in 0..send_recv_times {
            sender
                .send(SerialMessage::new(test_bytes.to_vec(), mdns_addr))
                .unwrap();

            println!("client sending data!");

            let run_result =
                io_loop.block_on(future::lazy(|_| future::select(stream, timeout)).flatten());

            // Success as soon as the server thread has flagged a received item.
            if server_got_packet.load(core::sync::atomic::Ordering::Relaxed) {
                return;
            }

            match run_result {
                Either::Left((buffer_and_addr_stream_tmp, timeout_tmp)) => {
                    let (_buffer_and_addr, stream_tmp) = buffer_and_addr_stream_tmp;
                    stream = stream_tmp.into_future();
                    timeout = timeout_tmp;
                }
                Either::Right(((), buffer_and_addr_stream_tmp)) => {
                    stream = buffer_and_addr_stream_tmp;
                    timeout = future::lazy(|_| tokio::time::sleep(Duration::from_millis(100)))
                        .flatten()
                        .boxed();
                }
            }
        }

        panic!("server never got packet.");
    }
}