hickory_proto/udp/
udp_stream.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::io;
9use std::marker::PhantomData;
10use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14
15use async_trait::async_trait;
16use futures_util::stream::Stream;
17use futures_util::{future::Future, ready, TryFutureExt};
18use rand;
19use rand::distributions::{uniform::Uniform, Distribution};
20use tracing::{debug, warn};
21
22use crate::udp::MAX_RECEIVE_BUFFER_SIZE;
23use crate::xfer::{BufDnsStreamHandle, SerialMessage, StreamReceiver};
24use crate::Time;
25
/// Shared factory for sockets.
///
/// Called with `(local address to bind, server address)`; returns a boxed, `Send` future that
/// resolves to the created socket or the I/O error encountered while creating it.
pub(crate) type UdpCreator<S> = Arc<
    dyn Fn(SocketAddr, SocketAddr) -> Pin<Box<dyn Future<Output = io::Result<S>> + Send>>
        + Send
        + Sync,
>;
34
35/// Trait for DnsUdpSocket
36#[async_trait]
37pub trait DnsUdpSocket
38where
39    Self: Send + Sync + Sized + Unpin,
40{
41    /// Time implementation used for this type
42    type Time: Time;
43
44    /// Poll once Receive data from the socket and returns the number of bytes read and the address from
45    /// where the data came on success.
46    fn poll_recv_from(
47        &self,
48        cx: &mut Context<'_>,
49        buf: &mut [u8],
50    ) -> Poll<io::Result<(usize, SocketAddr)>>;
51
52    /// Receive data from the socket and returns the number of bytes read and the address from
53    /// where the data came on success.
54    async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
55        futures_util::future::poll_fn(|cx| self.poll_recv_from(cx, buf)).await
56    }
57
58    /// Poll once to send data to the given address.
59    fn poll_send_to(
60        &self,
61        cx: &mut Context<'_>,
62        buf: &[u8],
63        target: SocketAddr,
64    ) -> Poll<io::Result<usize>>;
65
66    /// Send data to the given address.
67    async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
68        futures_util::future::poll_fn(|cx| self.poll_send_to(cx, buf, target)).await
69    }
70}
71
72/// Trait for UdpSocket
73#[async_trait]
74pub trait UdpSocket: DnsUdpSocket {
75    /// setups up a "client" udp connection that will only receive packets from the associated address
76    async fn connect(addr: SocketAddr) -> io::Result<Self>;
77
78    /// same as connect, but binds to the specified local address for sending address
79    async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self>;
80
81    /// a "server" UDP socket, that bind to the local listening address, and unbound remote address (can receive from anything)
82    async fn bind(addr: SocketAddr) -> io::Result<Self>;
83}
84
85/// A UDP stream of DNS binary packets
86#[must_use = "futures do nothing unless polled"]
87pub struct UdpStream<S: Send> {
88    socket: S,
89    outbound_messages: StreamReceiver,
90}
91
/// Access to a socket's local address.
///
/// Implementing `quinn::AsyncUdpSocket` requires a custom socket that can report its local
/// address; this trait provides that capability.
pub trait QuicLocalAddr {
    /// Returns the socket's local address.
    fn local_addr(&self) -> io::Result<SocketAddr>;
}
97
98#[cfg(feature = "tokio-runtime")]
99use tokio::net::UdpSocket as TokioUdpSocket;
100
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
#[allow(unreachable_pub)]
impl QuicLocalAddr for TokioUdpSocket {
    /// Delegates to tokio's own `UdpSocket::local_addr`.
    fn local_addr(&self) -> std::io::Result<SocketAddr> {
        // Inherent associated functions take precedence over trait items in path resolution,
        // so this calls tokio's method rather than recursing into this trait impl.
        TokioUdpSocket::local_addr(self)
    }
}
109
110impl<S: UdpSocket + Send + 'static> UdpStream<S> {
111    /// This method is intended for client connections, see `with_bound` for a method better for
112    ///  straight listening. It is expected that the resolver wrapper will be responsible for
113    ///  creating and managing new UdpStreams such that each new client would have a random port
114    ///  (reduce chance of cache poisoning). This will return a randomly assigned local port.
115    ///
116    /// # Arguments
117    ///
118    /// * `remote_addr` - socket address for the remote connection (used to determine IPv4 or IPv6)
119    ///
120    /// # Return
121    ///
122    /// a tuple of a Future Stream which will handle sending and receiving messages, and a
123    ///  handle which can be used to send messages into the stream.
124    #[allow(clippy::type_complexity)]
125    pub fn new(
126        remote_addr: SocketAddr,
127        bind_addr: Option<SocketAddr>,
128    ) -> (
129        Box<dyn Future<Output = Result<Self, io::Error>> + Send + Unpin>,
130        BufDnsStreamHandle,
131    ) {
132        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
133
134        // TODO: allow the bind address to be specified...
135        // constructs a future for getting the next randomly bound port to a UdpSocket
136        let next_socket = NextRandomUdpSocket::new(&remote_addr, &bind_addr);
137
138        // This set of futures collapses the next udp socket into a stream which can be used for
139        //  sending and receiving udp packets.
140        let stream = Box::new(next_socket.map_ok(move |socket| Self {
141            socket,
142            outbound_messages,
143        }));
144
145        (stream, message_sender)
146    }
147}
148
149impl<S: DnsUdpSocket + Send + 'static> UdpStream<S> {
150    /// Initialize the Stream with an already bound socket. Generally this should be only used for
151    ///  server listening sockets. See `new` for a client oriented socket. Specifically, this there
152    ///  is already a bound socket in this context, whereas `new` makes sure to randomize ports
153    ///  for additional cache poison prevention.
154    ///
155    /// # Arguments
156    ///
157    /// * `socket` - an already bound UDP socket
158    /// * `remote_addr` - remote side of this connection
159    ///
160    /// # Return
161    ///
162    /// a tuple of a Future Stream which will handle sending and receiving messages, and a
163    ///  handle which can be used to send messages into the stream.
164    pub fn with_bound(socket: S, remote_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
165        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
166        let stream = Self {
167            socket,
168            outbound_messages,
169        };
170
171        (stream, message_sender)
172    }
173
174    #[allow(unused)]
175    pub(crate) fn from_parts(socket: S, outbound_messages: StreamReceiver) -> Self {
176        Self {
177            socket,
178            outbound_messages,
179        }
180    }
181}
182
183impl<S: Send> UdpStream<S> {
184    #[allow(clippy::type_complexity)]
185    fn pollable_split(&mut self) -> (&mut S, &mut StreamReceiver) {
186        (&mut self.socket, &mut self.outbound_messages)
187    }
188}
189
190impl<S: DnsUdpSocket + Send + 'static> Stream for UdpStream<S> {
191    type Item = Result<SerialMessage, io::Error>;
192
193    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
194        let (socket, outbound_messages) = self.pollable_split();
195        let socket = Pin::new(socket);
196        let mut outbound_messages = Pin::new(outbound_messages);
197
198        // this will not accept incoming data while there is data to send
199        //  makes this self throttling.
200        while let Poll::Ready(Some(message)) = outbound_messages.as_mut().poll_peek(cx) {
201            // first try to send
202            let addr = message.addr();
203
204            // this will return if not ready,
205            //   meaning that sending will be preferred over receiving...
206
207            // TODO: shouldn't this return the error to send to the sender?
208            if let Err(e) = ready!(socket.poll_send_to(cx, message.bytes(), addr)) {
209                // Drop the UDP packet and continue
210                warn!(
211                    "error sending message to {} on udp_socket, dropping response: {}",
212                    addr, e
213                );
214            }
215
216            // message sent, need to pop the message
217            assert!(outbound_messages.as_mut().poll_next(cx).is_ready());
218        }
219
220        // For QoS, this will only accept one message and output that
221        // receive all inbound messages
222
223        // TODO: this should match edns settings
224        let mut buf = [0u8; MAX_RECEIVE_BUFFER_SIZE];
225        let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?;
226
227        let serial_message = SerialMessage::new(buf.iter().take(len).cloned().collect(), src);
228        Poll::Ready(Some(Ok(serial_message)))
229    }
230}
231
232#[must_use = "futures do nothing unless polled"]
233pub(crate) struct NextRandomUdpSocket<S> {
234    name_server: SocketAddr,
235    bind_address: SocketAddr,
236    closure: UdpCreator<S>,
237    marker: PhantomData<S>,
238}
239
240impl<S: UdpSocket + 'static> NextRandomUdpSocket<S> {
241    /// Creates a future for randomly binding to a local socket address for client connections,
242    /// if no port is specified.
243    ///
244    /// If a port is specified in the bind address it is used.
245    pub(crate) fn new(name_server: &SocketAddr, bind_addr: &Option<SocketAddr>) -> Self {
246        let bind_address = match bind_addr {
247            Some(ba) => *ba,
248            None => match *name_server {
249                SocketAddr::V4(..) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
250                SocketAddr::V6(..) => {
251                    SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
252                }
253            },
254        };
255
256        Self {
257            name_server: *name_server,
258            bind_address,
259            closure: Arc::new(|local_addr: _, _server_addr: _| S::bind(local_addr)),
260            marker: PhantomData,
261        }
262    }
263}
264
265impl<S: DnsUdpSocket> NextRandomUdpSocket<S> {
266    /// Create a future with generator
267    pub(crate) fn new_with_closure(name_server: &SocketAddr, func: UdpCreator<S>) -> Self {
268        let bind_address = match *name_server {
269            SocketAddr::V4(..) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
270            SocketAddr::V6(..) => {
271                SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
272            }
273        };
274        Self {
275            name_server: *name_server,
276            bind_address,
277            closure: func,
278            marker: PhantomData,
279        }
280    }
281}
282
283impl<S: DnsUdpSocket + Send> Future for NextRandomUdpSocket<S> {
284    type Output = Result<S, io::Error>;
285
286    /// polls until there is an available next random UDP port,
287    /// if no port has been specified in bind_addr.
288    ///
289    /// if there is no port available after 10 attempts, returns NotReady
290    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
291        if self.bind_address.port() == 0 {
292            // Per RFC 6056 Section 2.1:
293            //
294            //    The dynamic port range defined by IANA consists of the 49152-65535
295            //    range, and is meant for the selection of ephemeral ports.
296            let rand_port_range = Uniform::new_inclusive(49152_u16, u16::MAX);
297            let mut rand = rand::thread_rng();
298
299            for attempt in 0..10 {
300                let port = rand_port_range.sample(&mut rand);
301                let bind_addr = SocketAddr::new(self.bind_address.ip(), port);
302
303                // TODO: allow TTL to be adjusted...
304                // TODO: this immediate poll might be wrong in some cases...
305                match (*self.closure)(bind_addr, self.name_server)
306                    .as_mut()
307                    .poll(cx)
308                {
309                    Poll::Ready(Ok(socket)) => {
310                        debug!("created socket successfully");
311                        return Poll::Ready(Ok(socket));
312                    }
313                    Poll::Ready(Err(err)) => match err.kind() {
314                        io::ErrorKind::AddrInUse => {
315                            debug!("unable to bind port, attempt: {}: {}", attempt, err);
316                        }
317                        _ => {
318                            debug!("failed to bind port: {}", err);
319                            return Poll::Ready(Err(err));
320                        }
321                    },
322                    Poll::Pending => debug!("unable to bind port, attempt: {}", attempt),
323                }
324            }
325
326            debug!("could not get next random port, delaying");
327
328            // TODO: because no interest is registered anywhere, we must awake.
329            cx.waker().wake_by_ref();
330
331            // returning NotReady here, perhaps the next poll there will be some more socket available.
332            Poll::Pending
333        } else {
334            // Use port that was specified in bind address.
335            (*self.closure)(self.bind_address, self.name_server)
336                .as_mut()
337                .poll(cx)
338        }
339    }
340}
341
#[cfg(feature = "tokio-runtime")]
#[async_trait]
impl UdpSocket for tokio::net::UdpSocket {
    /// Sets up a "client" UDP socket for talking to `addr`.
    ///
    /// The local side is bound to the unspecified address of `addr`'s family with port 0:
    /// `0.0.0.0:0` for IPv4, `[::]:0` for IPv6.
    async fn connect(addr: SocketAddr) -> io::Result<Self> {
        let unspecified: IpAddr = if addr.is_ipv4() {
            Ipv4Addr::UNSPECIFIED.into()
        } else {
            Ipv6Addr::UNSPECIFIED.into()
        };

        Self::connect_with_bind(addr, SocketAddr::new(unspecified, 0)).await
    }

    /// Same as `connect`, but binds the local (sending) side to `bind_addr`.
    ///
    /// NOTE(review): the remote address is currently ignored. The socket is only bound, not
    /// connected, so it is not restricted to receiving from that peer.
    async fn connect_with_bind(_addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self> {
        // TODO: research connect more, it appears to break UDP receiving tests, etc...
        // socket.connect(addr).await?;
        Self::bind(bind_addr).await
    }

    /// Binds a "server" socket to `addr`.
    async fn bind(addr: SocketAddr) -> io::Result<Self> {
        // Resolves to tokio's inherent `UdpSocket::bind`. Inherent associated functions take
        // precedence over trait items, so this does not recurse into this trait method.
        tokio::net::UdpSocket::bind(addr).await
    }
}
371
#[cfg(feature = "tokio-runtime")]
#[async_trait]
impl DnsUdpSocket for tokio::net::UdpSocket {
    type Time = crate::TokioTime;

    /// Adapts tokio's `ReadBuf`-based receive to the `(len, addr)` shape of this trait.
    fn poll_recv_from(
        &self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<(usize, SocketAddr)>> {
        // tokio fills a `ReadBuf`; the filled length is the number of bytes received.
        let mut read_buf = tokio::io::ReadBuf::new(buf);
        tokio::net::UdpSocket::poll_recv_from(self, cx, &mut read_buf)
            .map_ok(|from| (read_buf.filled().len(), from))
    }

    /// Forwards directly to tokio's inherent `poll_send_to`.
    fn poll_send_to(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
        target: SocketAddr,
    ) -> Poll<io::Result<usize>> {
        tokio::net::UdpSocket::poll_send_to(self, cx, buf, target)
    }
}
398
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #[cfg(not(target_os = "linux"))] // ignored until Travis-CI fixes IPv6
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};

    use tokio::{net::UdpSocket as TokioUdpSocket, runtime::Runtime};

    use crate::tests::{next_random_socket_test, udp_stream_test};

    /// Creates the tokio runtime a single test runs on.
    fn new_runtime() -> Runtime {
        Runtime::new().expect("failed to create tokio runtime")
    }

    #[test]
    fn test_next_random_socket() {
        next_random_socket_test::<TokioUdpSocket, Runtime>(new_runtime())
    }

    #[test]
    fn test_udp_stream_ipv4() {
        let localhost = IpAddr::V4(Ipv4Addr::LOCALHOST);
        new_runtime().block_on(udp_stream_test::<TokioUdpSocket>(localhost));
    }

    #[test]
    #[cfg(not(target_os = "linux"))] // ignored until Travis-CI fixes IPv6
    fn test_udp_stream_ipv6() {
        let localhost = IpAddr::V6(Ipv6Addr::LOCALHOST);
        new_runtime().block_on(udp_stream_test::<TokioUdpSocket>(localhost));
    }
}