1use alloc::boxed::Box;
9use alloc::sync::Arc;
10use core::future::poll_fn;
11use core::pin::Pin;
12use core::task::{Context, Poll};
13use std::collections::HashSet;
14use std::io;
15use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
16
17use async_trait::async_trait;
18use futures_util::stream::Stream;
19use futures_util::{TryFutureExt, future::Future, ready};
20use tracing::{debug, trace, warn};
21
22use crate::runtime::{RuntimeProvider, Time};
23use crate::udp::MAX_RECEIVE_BUFFER_SIZE;
24use crate::xfer::{BufDnsStreamHandle, SerialMessage, StreamReceiver};
25
26#[async_trait]
28pub trait DnsUdpSocket
29where
30 Self: Send + Sync + Sized + Unpin,
31{
32 type Time: Time;
34
35 fn poll_recv_from(
38 &self,
39 cx: &mut Context<'_>,
40 buf: &mut [u8],
41 ) -> Poll<io::Result<(usize, SocketAddr)>>;
42
43 async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
46 poll_fn(|cx| self.poll_recv_from(cx, buf)).await
47 }
48
49 fn poll_send_to(
51 &self,
52 cx: &mut Context<'_>,
53 buf: &[u8],
54 target: SocketAddr,
55 ) -> Poll<io::Result<usize>>;
56
57 async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
59 poll_fn(|cx| self.poll_send_to(cx, buf, target)).await
60 }
61}
62
63#[async_trait]
65pub trait UdpSocket: DnsUdpSocket {
66 async fn connect(addr: SocketAddr) -> io::Result<Self>;
68
69 async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self>;
71
72 async fn bind(addr: SocketAddr) -> io::Result<Self>;
74}
75
/// A stream of UDP datagrams over a runtime-provided socket.
///
/// Each poll of its `Stream` implementation first sends any queued outbound messages, then
/// yields at most one received datagram as a `SerialMessage`.
#[must_use = "futures do nothing unless polled"]
pub struct UdpStream<P: RuntimeProvider> {
    // Socket type supplied by the runtime provider `P`.
    socket: P::Udp,
    // Queue of messages waiting to be sent on `socket`.
    outbound_messages: StreamReceiver,
}
82
83impl<P: RuntimeProvider> UdpStream<P> {
84 #[allow(clippy::type_complexity)]
112 pub fn new(
113 remote_addr: SocketAddr,
114 bind_addr: Option<SocketAddr>,
115 avoid_local_ports: Option<Arc<HashSet<u16>>>,
116 os_port_selection: bool,
117 provider: P,
118 ) -> (
119 Box<dyn Future<Output = Result<Self, io::Error>> + Send + Unpin>,
120 BufDnsStreamHandle,
121 ) {
122 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
123
124 let next_socket = NextRandomUdpSocket::new(
126 remote_addr,
127 bind_addr,
128 avoid_local_ports.unwrap_or_default(),
129 os_port_selection,
130 provider,
131 );
132
133 let stream = Box::new(next_socket.map_ok(move |socket| Self {
136 socket,
137 outbound_messages,
138 }));
139
140 (stream, message_sender)
141 }
142}
143
144impl<P: RuntimeProvider> UdpStream<P> {
145 pub fn with_bound(socket: P::Udp, remote_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
160 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
161 let stream = Self {
162 socket,
163 outbound_messages,
164 };
165
166 (stream, message_sender)
167 }
168
169 #[allow(unused)]
170 pub(crate) fn from_parts(socket: P::Udp, outbound_messages: StreamReceiver) -> Self {
171 Self {
172 socket,
173 outbound_messages,
174 }
175 }
176}
177
178impl<P: RuntimeProvider> UdpStream<P> {
179 #[allow(clippy::type_complexity)]
180 fn pollable_split(&mut self) -> (&mut P::Udp, &mut StreamReceiver) {
181 (&mut self.socket, &mut self.outbound_messages)
182 }
183}
184
185impl<P: RuntimeProvider> Stream for UdpStream<P> {
186 type Item = Result<SerialMessage, io::Error>;
187
188 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
189 let (socket, outbound_messages) = self.pollable_split();
190 let socket = Pin::new(socket);
191 let mut outbound_messages = Pin::new(outbound_messages);
192
193 while let Poll::Ready(Some(message)) = outbound_messages.as_mut().poll_peek(cx) {
196 let addr = message.addr();
198
199 if let Err(e) = ready!(socket.poll_send_to(cx, message.bytes(), addr)) {
204 warn!(
206 "error sending message to {} on udp_socket, dropping response: {}",
207 addr, e
208 );
209 }
210
211 assert!(outbound_messages.as_mut().poll_next(cx).is_ready());
213 }
214
215 let mut buf = [0u8; MAX_RECEIVE_BUFFER_SIZE];
220 let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?;
221
222 let serial_message = SerialMessage::new(buf.iter().take(len).cloned().collect(), src);
223 Poll::Ready(Some(Ok(serial_message)))
224 }
225}
226
/// Future that binds a UDP socket, choosing a random local port when required.
#[must_use = "futures do nothing unless polled"]
pub(crate) struct NextRandomUdpSocket<P: RuntimeProvider> {
    // Remote address; passed along to `provider.bind_udp` with each bind.
    name_server: SocketAddr,
    // Local address to bind. A port of 0 triggers random port selection unless
    // `os_port_selection` is set.
    bind_address: SocketAddr,
    // Runtime provider used to create sockets.
    provider: P,
    // Number of random-port picks and bind retries consumed so far.
    attempted: usize,
    // The in-flight bind operation, if one has been started.
    // NOTE(review): presumably allowed because the boxed future type is long; confirm.
    #[allow(clippy::type_complexity)]
    future: Option<Pin<Box<dyn Send + Future<Output = io::Result<P::Udp>>>>>,
    // Local ports that must never be picked during random selection.
    avoid_local_ports: Arc<HashSet<u16>>,
    // When true, random selection is skipped and `bind_address` is bound as-is.
    os_port_selection: bool,
}
239
240impl<P: RuntimeProvider> NextRandomUdpSocket<P> {
241 pub(crate) fn new(
246 name_server: SocketAddr,
247 bind_addr: Option<SocketAddr>,
248 avoid_local_ports: Arc<HashSet<u16>>,
249 os_port_selection: bool,
250 provider: P,
251 ) -> Self {
252 let bind_address = match bind_addr {
253 Some(ba) => ba,
254 None => match name_server {
255 SocketAddr::V4(..) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
256 SocketAddr::V6(..) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
257 },
258 };
259
260 Self {
261 name_server,
262 bind_address,
263 provider,
264 attempted: 0,
265 future: None,
266 avoid_local_ports,
267 os_port_selection,
268 }
269 }
270}
271
272impl<P: RuntimeProvider> Future for NextRandomUdpSocket<P> {
273 type Output = Result<P::Udp, io::Error>;
274
275 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
278 let this = self.get_mut();
279 loop {
280 this.future = match this.future.take() {
281 Some(mut future) => match future.as_mut().poll(cx) {
282 Poll::Ready(Ok(socket)) => {
283 debug!("created socket successfully");
284 return Poll::Ready(Ok(socket));
285 }
286 Poll::Ready(Err(err)) => match err.kind() {
287 io::ErrorKind::PermissionDenied | io::ErrorKind::AddrInUse
288 if this.attempted < ATTEMPT_RANDOM + 1 =>
289 {
290 debug!("unable to bind port, attempt: {}: {err}", this.attempted);
291 this.attempted += 1;
292 None
293 }
294 _ => {
295 debug!("failed to bind port: {}", err);
296 return Poll::Ready(Err(err));
297 }
298 },
299 Poll::Pending => {
300 debug!("unable to bind port, attempt: {}", this.attempted);
301 this.future = Some(future);
302 return Poll::Pending;
303 }
304 },
305 None => {
306 let mut bind_addr = this.bind_address;
307
308 if !this.os_port_selection && bind_addr.port() == 0 {
309 while this.attempted < ATTEMPT_RANDOM {
310 let port = rand::random_range(1024..=u16::MAX);
316 if this.avoid_local_ports.contains(&port) {
317 this.attempted += 1;
323 continue;
324 } else {
325 bind_addr = SocketAddr::new(bind_addr.ip(), port);
326 break;
327 }
328 }
329 }
330
331 trace!(port = bind_addr.port(), "binding UDP socket");
332 Some(Box::pin(
333 this.provider.bind_udp(bind_addr, this.name_server),
334 ))
335 }
336 }
337 }
338 }
339}
340
/// Maximum number of random local-port picks made by `NextRandomUdpSocket` before it binds
/// with the configured port (0) unchanged; bind retries on `PermissionDenied`/`AddrInUse`
/// are limited to while fewer than `ATTEMPT_RANDOM + 1` attempts have been made.
const ATTEMPT_RANDOM: usize = 10;
342
#[cfg(feature = "tokio")]
#[async_trait]
impl UdpSocket for tokio::net::UdpSocket {
    /// Binds to the unspecified address of `addr`'s family on port 0.
    ///
    /// Delegates to `connect_with_bind`, which does not connect the socket to `addr`.
    async fn connect(addr: SocketAddr) -> io::Result<Self> {
        let wildcard = if addr.is_ipv4() {
            IpAddr::V4(Ipv4Addr::UNSPECIFIED)
        } else {
            IpAddr::V6(Ipv6Addr::UNSPECIFIED)
        };

        Self::connect_with_bind(addr, SocketAddr::new(wildcard, 0)).await
    }

    /// Binds a socket to `bind_addr`.
    ///
    /// The remote address is ignored: no `connect` call is made on the socket.
    async fn connect_with_bind(_addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self> {
        // `Self::bind` resolves to Tokio's inherent `UdpSocket::bind`, not this trait's method.
        Self::bind(bind_addr).await
    }

    /// Binds a socket to `addr` using Tokio's inherent `UdpSocket::bind`.
    async fn bind(addr: SocketAddr) -> io::Result<Self> {
        Self::bind(addr).await
    }
}
372
#[cfg(feature = "tokio")]
#[async_trait]
impl DnsUdpSocket for tokio::net::UdpSocket {
    type Time = crate::runtime::TokioTime;

    /// Adapts Tokio's `ReadBuf`-based receive to the `(len, src)` shape of the trait.
    fn poll_recv_from(
        &self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<(usize, SocketAddr)>> {
        let mut read_buf = tokio::io::ReadBuf::new(buf);
        // `Self::poll_recv_from` here is Tokio's inherent method, not this trait's.
        Self::poll_recv_from(self, cx, &mut read_buf)
            .map_ok(|source| (read_buf.filled().len(), source))
    }

    /// Forwards directly to Tokio's inherent `UdpSocket::poll_send_to`.
    fn poll_send_to(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
        target: SocketAddr,
    ) -> Poll<io::Result<usize>> {
        Self::poll_send_to(self, cx, buf, target)
    }
}
399
#[cfg(test)]
#[cfg(feature = "tokio")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::runtime::TokioRuntimeProvider;
    use crate::tests::{next_random_socket_test, udp_stream_test};

    // Runs the shared random-port binding test on the Tokio runtime provider.
    #[tokio::test]
    async fn test_next_random_socket() {
        subscribe();
        next_random_socket_test(TokioRuntimeProvider::new()).await;
    }

    // Runs the shared UDP stream test over IPv4 loopback.
    #[tokio::test]
    async fn test_udp_stream_ipv4() {
        subscribe();
        udp_stream_test(IpAddr::V4(Ipv4Addr::LOCALHOST), TokioRuntimeProvider::new()).await;
    }

    // Runs the shared UDP stream test over IPv6 loopback (`::1`).
    #[tokio::test]
    async fn test_udp_stream_ipv6() {
        subscribe();
        udp_stream_test(IpAddr::V6(Ipv6Addr::LOCALHOST), TokioRuntimeProvider::new()).await;
    }
}