hickory_proto/udp/
udp_stream.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use alloc::boxed::Box;
9use alloc::sync::Arc;
10use core::future::poll_fn;
11use core::pin::Pin;
12use core::task::{Context, Poll};
13use std::collections::HashSet;
14use std::io;
15use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
16
17use async_trait::async_trait;
18use futures_util::stream::Stream;
19use futures_util::{TryFutureExt, future::Future, ready};
20use tracing::{debug, trace, warn};
21
22use crate::runtime::{RuntimeProvider, Time};
23use crate::udp::MAX_RECEIVE_BUFFER_SIZE;
24use crate::xfer::{BufDnsStreamHandle, SerialMessage, StreamReceiver};
25
26/// Trait for DnsUdpSocket
27#[async_trait]
28pub trait DnsUdpSocket
29where
30    Self: Send + Sync + Sized + Unpin,
31{
32    /// Time implementation used for this type
33    type Time: Time;
34
35    /// Poll once Receive data from the socket and returns the number of bytes read and the address from
36    /// where the data came on success.
37    fn poll_recv_from(
38        &self,
39        cx: &mut Context<'_>,
40        buf: &mut [u8],
41    ) -> Poll<io::Result<(usize, SocketAddr)>>;
42
43    /// Receive data from the socket and returns the number of bytes read and the address from
44    /// where the data came on success.
45    async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
46        poll_fn(|cx| self.poll_recv_from(cx, buf)).await
47    }
48
49    /// Poll once to send data to the given address.
50    fn poll_send_to(
51        &self,
52        cx: &mut Context<'_>,
53        buf: &[u8],
54        target: SocketAddr,
55    ) -> Poll<io::Result<usize>>;
56
57    /// Send data to the given address.
58    async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
59        poll_fn(|cx| self.poll_send_to(cx, buf, target)).await
60    }
61}
62
63/// Trait for UdpSocket
64#[async_trait]
65pub trait UdpSocket: DnsUdpSocket {
66    /// setups up a "client" udp connection that will only receive packets from the associated address
67    async fn connect(addr: SocketAddr) -> io::Result<Self>;
68
69    /// same as connect, but binds to the specified local address for sending address
70    async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self>;
71
72    /// a "server" UDP socket, that bind to the local listening address, and unbound remote address (can receive from anything)
73    async fn bind(addr: SocketAddr) -> io::Result<Self>;
74}
75
76/// A UDP stream of DNS binary packets
77#[must_use = "futures do nothing unless polled"]
78pub struct UdpStream<P: RuntimeProvider> {
79    socket: P::Udp,
80    outbound_messages: StreamReceiver,
81}
82
83impl<P: RuntimeProvider> UdpStream<P> {
84    /// This method is intended for client connections, see [`Self::with_bound`] for a method better
85    ///  for straight listening. It is expected that the resolver wrapper will be responsible for
86    ///  creating and managing new UdpStreams such that each new client would have a random port
87    ///  (reduce chance of cache poisoning). This will return a randomly assigned local port, unless
88    ///  a nonzero port number is specified in `bind_addr`.
89    ///
90    /// # Arguments
91    ///
92    /// * `remote_addr` - socket address for the remote connection (used to determine IPv4 or IPv6)
93    /// * `bind_addr` - optional local socket address to connect from (if a nonzero port number is
94    ///                 specified, it will be used instead of randomly selecting a port)
95    /// * `os_port_selection` - Boolean parameter to specify whether to use the operating system's
96    ///                         standard UDP port selection logic instead of Hickory's logic to
97    ///                         securely select a random source port. We do not recommend using
98    ///                         this option unless absolutely necessary, as the operating system
99    ///                         may select ephemeral ports from a smaller range than Hickory, which
100    ///                         can make response poisoning attacks easier to conduct. Some
101    ///                         operating systems (notably, Windows) might display a user-prompt to
102    ///                         allow a Hickory-specified port to be used, and setting this option
103    ///                         will prevent those prompts from being displayed. If os_port_selection
104    ///                         is true, avoid_local_udp_ports will be ignored.
105    /// * `provider` - async runtime provider, for I/O and timers
106    ///
107    /// # Return
108    ///
109    /// A tuple of a Future of a Stream which will handle sending and receiving messages, and a
110    ///  handle which can be used to send messages into the stream.
111    #[allow(clippy::type_complexity)]
112    pub fn new(
113        remote_addr: SocketAddr,
114        bind_addr: Option<SocketAddr>,
115        avoid_local_ports: Option<Arc<HashSet<u16>>>,
116        os_port_selection: bool,
117        provider: P,
118    ) -> (
119        Box<dyn Future<Output = Result<Self, io::Error>> + Send + Unpin>,
120        BufDnsStreamHandle,
121    ) {
122        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
123
124        // constructs a future for getting the next randomly bound port to a UdpSocket
125        let next_socket = NextRandomUdpSocket::new(
126            remote_addr,
127            bind_addr,
128            avoid_local_ports.unwrap_or_default(),
129            os_port_selection,
130            provider,
131        );
132
133        // This set of futures collapses the next udp socket into a stream which can be used for
134        //  sending and receiving udp packets.
135        let stream = Box::new(next_socket.map_ok(move |socket| Self {
136            socket,
137            outbound_messages,
138        }));
139
140        (stream, message_sender)
141    }
142}
143
144impl<P: RuntimeProvider> UdpStream<P> {
145    /// Initialize the Stream with an already bound socket. Generally this should be only used for
146    ///  server listening sockets. See [`Self::new`] for a client oriented socket. Specifically,
147    ///  this requires there is already a bound socket, whereas `new` makes sure to randomize ports
148    ///  for additional cache poison prevention.
149    ///
150    /// # Arguments
151    ///
152    /// * `socket` - an already bound UDP socket
153    /// * `remote_addr` - remote side of this connection
154    ///
155    /// # Return
156    ///
157    /// A tuple of a Stream which will handle sending and receiving messages, and a handle which can
158    ///  be used to send messages into the stream.
159    pub fn with_bound(socket: P::Udp, remote_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
160        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
161        let stream = Self {
162            socket,
163            outbound_messages,
164        };
165
166        (stream, message_sender)
167    }
168
169    #[allow(unused)]
170    pub(crate) fn from_parts(socket: P::Udp, outbound_messages: StreamReceiver) -> Self {
171        Self {
172            socket,
173            outbound_messages,
174        }
175    }
176}
177
178impl<P: RuntimeProvider> UdpStream<P> {
179    #[allow(clippy::type_complexity)]
180    fn pollable_split(&mut self) -> (&mut P::Udp, &mut StreamReceiver) {
181        (&mut self.socket, &mut self.outbound_messages)
182    }
183}
184
185impl<P: RuntimeProvider> Stream for UdpStream<P> {
186    type Item = Result<SerialMessage, io::Error>;
187
188    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
189        let (socket, outbound_messages) = self.pollable_split();
190        let socket = Pin::new(socket);
191        let mut outbound_messages = Pin::new(outbound_messages);
192
193        // this will not accept incoming data while there is data to send
194        //  makes this self throttling.
195        while let Poll::Ready(Some(message)) = outbound_messages.as_mut().poll_peek(cx) {
196            // first try to send
197            let addr = message.addr();
198
199            // this will return if not ready,
200            //   meaning that sending will be preferred over receiving...
201
202            // TODO: shouldn't this return the error to send to the sender?
203            if let Err(e) = ready!(socket.poll_send_to(cx, message.bytes(), addr)) {
204                // Drop the UDP packet and continue
205                warn!(
206                    "error sending message to {} on udp_socket, dropping response: {}",
207                    addr, e
208                );
209            }
210
211            // message sent, need to pop the message
212            assert!(outbound_messages.as_mut().poll_next(cx).is_ready());
213        }
214
215        // For QoS, this will only accept one message and output that
216        // receive all inbound messages
217
218        // TODO: this should match edns settings
219        let mut buf = [0u8; MAX_RECEIVE_BUFFER_SIZE];
220        let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?;
221
222        let serial_message = SerialMessage::new(buf.iter().take(len).cloned().collect(), src);
223        Poll::Ready(Some(Ok(serial_message)))
224    }
225}
226
227#[must_use = "futures do nothing unless polled"]
228pub(crate) struct NextRandomUdpSocket<P: RuntimeProvider> {
229    name_server: SocketAddr,
230    bind_address: SocketAddr,
231    provider: P,
232    /// Number of unsuccessful attempts to pick a port.
233    attempted: usize,
234    #[allow(clippy::type_complexity)]
235    future: Option<Pin<Box<dyn Send + Future<Output = io::Result<P::Udp>>>>>,
236    avoid_local_ports: Arc<HashSet<u16>>,
237    os_port_selection: bool,
238}
239
240impl<P: RuntimeProvider> NextRandomUdpSocket<P> {
241    /// Creates a future for randomly binding to a local socket address for client connections,
242    /// if no port is specified.
243    ///
244    /// If a port is specified in the bind address it is used.
245    pub(crate) fn new(
246        name_server: SocketAddr,
247        bind_addr: Option<SocketAddr>,
248        avoid_local_ports: Arc<HashSet<u16>>,
249        os_port_selection: bool,
250        provider: P,
251    ) -> Self {
252        let bind_address = match bind_addr {
253            Some(ba) => ba,
254            None => match name_server {
255                SocketAddr::V4(..) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
256                SocketAddr::V6(..) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
257            },
258        };
259
260        Self {
261            name_server,
262            bind_address,
263            provider,
264            attempted: 0,
265            future: None,
266            avoid_local_ports,
267            os_port_selection,
268        }
269    }
270}
271
272impl<P: RuntimeProvider> Future for NextRandomUdpSocket<P> {
273    type Output = Result<P::Udp, io::Error>;
274
275    /// polls until there is an available next random UDP port,
276    /// if no port has been specified in bind_addr.
277    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
278        let this = self.get_mut();
279        loop {
280            this.future = match this.future.take() {
281                Some(mut future) => match future.as_mut().poll(cx) {
282                    Poll::Ready(Ok(socket)) => {
283                        debug!("created socket successfully");
284                        return Poll::Ready(Ok(socket));
285                    }
286                    Poll::Ready(Err(err)) => match err.kind() {
287                        io::ErrorKind::PermissionDenied | io::ErrorKind::AddrInUse
288                            if this.attempted < ATTEMPT_RANDOM + 1 =>
289                        {
290                            debug!("unable to bind port, attempt: {}: {err}", this.attempted);
291                            this.attempted += 1;
292                            None
293                        }
294                        _ => {
295                            debug!("failed to bind port: {}", err);
296                            return Poll::Ready(Err(err));
297                        }
298                    },
299                    Poll::Pending => {
300                        debug!("unable to bind port, attempt: {}", this.attempted);
301                        this.future = Some(future);
302                        return Poll::Pending;
303                    }
304                },
305                None => {
306                    let mut bind_addr = this.bind_address;
307
308                    if !this.os_port_selection && bind_addr.port() == 0 {
309                        while this.attempted < ATTEMPT_RANDOM {
310                            // Per RFC 6056 Section 3.2:
311                            //
312                            // As mentioned in Section 2.1, the dynamic ports consist of the range
313                            // 49152-65535.  However, ephemeral port selection algorithms should use
314                            // the whole range 1024-65535.
315                            let port = rand::random_range(1024..=u16::MAX);
316                            if this.avoid_local_ports.contains(&port) {
317                                // Count this against the total number of attempts to pick a port.
318                                // RFC 6056 Section 3.3.2 notes that this algorithm should find a
319                                // suitable port in one or two attempts with high probability in
320                                // common scenarios. If `avoid_local_ports` is pathologically large,
321                                // then incrementing the counter here will prevent an infinite loop.
322                                this.attempted += 1;
323                                continue;
324                            } else {
325                                bind_addr = SocketAddr::new(bind_addr.ip(), port);
326                                break;
327                            }
328                        }
329                    }
330
331                    trace!(port = bind_addr.port(), "binding UDP socket");
332                    Some(Box::pin(
333                        this.provider.bind_udp(bind_addr, this.name_server),
334                    ))
335                }
336            }
337        }
338    }
339}
340
/// Number of attempts at picking a random local port (including candidates skipped because they
/// are in `avoid_local_ports`) before falling back to letting the operating system choose.
const ATTEMPT_RANDOM: usize = 10;
342
#[cfg(feature = "tokio")]
#[async_trait]
impl UdpSocket for tokio::net::UdpSocket {
    /// Sets up a "client" UDP socket for `addr`.
    ///
    /// The local end binds to the unspecified address of the same family as `addr` (`0.0.0.0:0`
    /// for IPv4, `[::]:0` for IPv6).
    async fn connect(addr: SocketAddr) -> io::Result<Self> {
        let unspecified: IpAddr = if addr.is_ipv4() {
            Ipv4Addr::UNSPECIFIED.into()
        } else {
            Ipv6Addr::UNSPECIFIED.into()
        };

        Self::connect_with_bind(addr, SocketAddr::new(unspecified, 0)).await
    }

    /// Binds the socket to `bind_addr`. The socket is not actually connected to `_addr`.
    async fn connect_with_bind(_addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self> {
        // TODO: research connect more; calling `socket.connect(addr)` here appears to break UDP
        // receiving tests, etc...
        Self::bind(bind_addr).await
    }

    /// Binds a "server" socket to the local address `addr`.
    async fn bind(addr: SocketAddr) -> io::Result<Self> {
        // Path resolution prefers inherent associated functions, so this calls tokio's
        // `UdpSocket::bind` rather than recursing into this trait method.
        Self::bind(addr).await
    }
}
372
#[cfg(feature = "tokio")]
#[async_trait]
impl DnsUdpSocket for tokio::net::UdpSocket {
    type Time = crate::runtime::TokioTime;

    /// Receives via tokio's inherent `poll_recv_from`.
    ///
    /// tokio's version fills a `ReadBuf` and returns the source address. This converts that
    /// into the `(length, source)` pair the trait expects.
    fn poll_recv_from(
        &self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<(usize, SocketAddr)>> {
        let mut read_buf = tokio::io::ReadBuf::new(buf);
        Self::poll_recv_from(self, cx, &mut read_buf)
            .map_ok(|source| (read_buf.filled().len(), source))
    }

    /// Sends via tokio's inherent `poll_send_to`, which has the same shape as the trait method.
    fn poll_send_to(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
        target: SocketAddr,
    ) -> Poll<io::Result<usize>> {
        Self::poll_send_to(self, cx, buf, target)
    }
}
399
#[cfg(test)]
#[cfg(feature = "tokio")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::{
        runtime::TokioRuntimeProvider,
        tests::{next_random_socket_test, udp_stream_test},
    };

    #[tokio::test]
    async fn test_next_random_socket() {
        subscribe();
        next_random_socket_test(TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_udp_stream_ipv4() {
        subscribe();
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        udp_stream_test(loopback, TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_udp_stream_ipv6() {
        subscribe();
        // `Ipv6Addr::LOCALHOST` is `::1`.
        let loopback = IpAddr::V6(Ipv6Addr::LOCALHOST);
        udp_stream_test(loopback, TokioRuntimeProvider::new()).await;
    }
}
432}