hickory_proto/udp/
udp_client_stream.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::borrow::Borrow;
9use std::fmt::{self, Display};
10use std::marker::PhantomData;
11use std::net::SocketAddr;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15use std::time::{Duration, SystemTime, UNIX_EPOCH};
16
17use futures_util::{future::Future, stream::Stream};
18use tracing::{debug, trace, warn};
19
20use crate::error::ProtoError;
21use crate::op::message::NoopMessageFinalizer;
22use crate::op::{Message, MessageFinalizer, MessageVerifier};
23use crate::udp::udp_stream::{NextRandomUdpSocket, UdpCreator, UdpSocket};
24use crate::udp::{DnsUdpSocket, MAX_RECEIVE_BUFFER_SIZE};
25use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream, SerialMessage};
26use crate::Time;
27
28/// A UDP client stream of DNS binary packets
29///
30/// This stream will create a new UDP socket for every request. This is to avoid potential cache
31///   poisoning during use by UDP based attacks.
32#[must_use = "futures do nothing unless polled"]
33pub struct UdpClientStream<S, MF = NoopMessageFinalizer>
34where
35    S: Send,
36    MF: MessageFinalizer,
37{
38    name_server: SocketAddr,
39    timeout: Duration,
40    is_shutdown: bool,
41    signer: Option<Arc<MF>>,
42    creator: UdpCreator<S>,
43    marker: PhantomData<S>,
44}
45
46impl<S: UdpSocket + Send + 'static> UdpClientStream<S, NoopMessageFinalizer> {
47    /// it is expected that the resolver wrapper will be responsible for creating and managing
48    ///  new UdpClients such that each new client would have a random port (reduce chance of cache
49    ///  poisoning)
50    ///
51    /// # Return
52    ///
53    /// a tuple of a Future Stream which will handle sending and receiving messages, and a
54    ///  handle which can be used to send messages into the stream.
55    #[allow(clippy::new_ret_no_self)]
56    pub fn new(name_server: SocketAddr) -> UdpClientConnect<S, NoopMessageFinalizer> {
57        Self::with_timeout(name_server, Duration::from_secs(5))
58    }
59
60    /// Constructs a new UdpStream for a client to the specified SocketAddr.
61    ///
62    /// # Arguments
63    ///
64    /// * `name_server` - the IP and Port of the DNS server to connect to
65    /// * `timeout` - connection timeout
66    pub fn with_timeout(
67        name_server: SocketAddr,
68        timeout: Duration,
69    ) -> UdpClientConnect<S, NoopMessageFinalizer> {
70        Self::with_bind_addr_and_timeout(name_server, None, timeout)
71    }
72
73    /// Constructs a new UdpStream for a client to the specified SocketAddr.
74    ///
75    /// # Arguments
76    ///
77    /// * `name_server` - the IP and Port of the DNS server to connect to
78    /// * `bind_addr` - the IP and port to connect from
79    /// * `timeout` - connection timeout
80    pub fn with_bind_addr_and_timeout(
81        name_server: SocketAddr,
82        bind_addr: Option<SocketAddr>,
83        timeout: Duration,
84    ) -> UdpClientConnect<S, NoopMessageFinalizer> {
85        Self::with_timeout_and_signer_and_bind_addr(name_server, timeout, None, bind_addr)
86    }
87}
88
89impl<S: UdpSocket + Send + 'static, MF: MessageFinalizer> UdpClientStream<S, MF> {
90    /// Constructs a new UdpStream for a client to the specified SocketAddr.
91    ///
92    /// # Arguments
93    ///
94    /// * `name_server` - the IP and Port of the DNS server to connect to
95    /// * `timeout` - connection timeout
96    pub fn with_timeout_and_signer(
97        name_server: SocketAddr,
98        timeout: Duration,
99        signer: Option<Arc<MF>>,
100    ) -> UdpClientConnect<S, MF> {
101        UdpClientConnect {
102            name_server,
103            timeout,
104            signer,
105            creator: Arc::new(|local_addr: _, server_addr: _| {
106                Box::pin(NextRandomUdpSocket::<S>::new(
107                    &server_addr,
108                    &Some(local_addr),
109                ))
110            }),
111            marker: PhantomData::<S>,
112        }
113    }
114
115    /// Constructs a new UdpStream for a client to the specified SocketAddr.
116    ///
117    /// # Arguments
118    ///
119    /// * `name_server` - the IP and Port of the DNS server to connect to
120    /// * `timeout` - connection timeout
121    /// * `bind_addr` - the IP address and port to connect from
122    pub fn with_timeout_and_signer_and_bind_addr(
123        name_server: SocketAddr,
124        timeout: Duration,
125        signer: Option<Arc<MF>>,
126        bind_addr: Option<SocketAddr>,
127    ) -> UdpClientConnect<S, MF> {
128        UdpClientConnect {
129            name_server,
130            timeout,
131            signer,
132            creator: Arc::new(move |local_addr: _, server_addr: _| {
133                Box::pin(NextRandomUdpSocket::<S>::new(
134                    &server_addr,
135                    &Some(bind_addr.unwrap_or(local_addr)),
136                ))
137            }),
138            marker: PhantomData::<S>,
139        }
140    }
141}
142
143impl<S: DnsUdpSocket + Send, MF: MessageFinalizer> UdpClientStream<S, MF> {
144    /// Constructs a new UdpStream for a client to the specified SocketAddr.
145    ///
146    /// # Arguments
147    ///
148    /// * `name_server` - the IP and Port of the DNS server to connect to
149    /// * `signer` - optional final amendment
150    /// * `timeout` - connection timeout
151    /// * `creator` - function that binds a local address to a newly created UDP socket
152    pub fn with_creator(
153        name_server: SocketAddr,
154        signer: Option<Arc<MF>>,
155        timeout: Duration,
156        creator: UdpCreator<S>,
157    ) -> UdpClientConnect<S, MF> {
158        UdpClientConnect {
159            name_server,
160            timeout,
161            signer,
162            creator,
163            marker: PhantomData::<S>,
164        }
165    }
166}
167
168impl<S: Send, MF: MessageFinalizer> Display for UdpClientStream<S, MF> {
169    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
170        write!(formatter, "UDP({})", self.name_server)
171    }
172}
173
174/// creates random query_id, each socket is unique, no need for global uniqueness
175fn random_query_id() -> u16 {
176    use rand::distributions::{Distribution, Standard};
177    let mut rand = rand::thread_rng();
178
179    Standard.sample(&mut rand)
180}
181
impl<S: DnsUdpSocket + Send + 'static, MF: MessageFinalizer> DnsRequestSender
    for UdpClientStream<S, MF>
{
    /// Sends `message` to the configured name server from a newly created UDP socket.
    ///
    /// The steps happen in this order:
    /// 1. A random query ID is assigned.
    /// 2. The optional finalizer (`signer`) runs if it asks to finalize this message. Any
    ///    verifier it returns is kept for checking the reply.
    /// 3. The message is serialized.
    /// 4. A future is built that creates a socket, sends the bytes and waits for the reply.
    ///    That future is bounded by `self.timeout`.
    ///
    /// Some failures happen before the future is built: the clock reading earlier than the
    /// Unix epoch, finalization errors, or serialization errors. These are converted into
    /// the returned `DnsResponseStream`.
    ///
    /// # Panics
    ///
    /// - Panics if the stream has already been shut down.
    /// - The `debug!` statement also panics (via `expect`) if the freshly serialized bytes
    ///   fail to parse back into a `Message`.
    fn send_message(&mut self, mut message: DnsRequest) -> DnsResponseStream {
        if self.is_shutdown {
            panic!("can not send messages after stream is shutdown")
        }

        // Associate the ID for this request. Because this connection is unique to its socket
        //   port, the ID does not need to be globally unique.
        message.set_id(random_query_id());

        // Seconds since the Unix epoch, passed to the finalizer below.
        let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(now) => now.as_secs(),
            Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
        };

        // TODO: truncates u64 to u32, error on overflow?
        // NOTE(review): `as u32` silently keeps only the low 32 bits of the seconds count.
        let now = now as u32;

        // Finalize (e.g. sign) the request when a finalizer is configured and wants to. Any
        //   verifier it returns is handed to the receive loop to check the response.
        let mut verifier = None;
        if let Some(ref signer) = self.signer {
            if signer.should_finalize_message(&message) {
                match message.finalize::<MF>(signer.borrow(), now) {
                    Ok(answer_verifier) => verifier = answer_verifier,
                    Err(e) => {
                        debug!("could not sign message: {}", e);
                        return e.into();
                    }
                }
            }
        }

        // Get an appropriate read buffer size: the smaller of MAX_RECEIVE_BUFFER_SIZE and the
        //   request's `max_payload()`.
        let recv_buf_size = MAX_RECEIVE_BUFFER_SIZE.min(message.max_payload() as usize);

        // Serialize after finalization (`finalize` is called on the mutable request above).
        let bytes = match message.to_vec() {
            Ok(bytes) => bytes,
            Err(err) => {
                return err.into();
            }
        };

        let message_id = message.id();
        let message = SerialMessage::new(bytes, self.name_server);

        // Log the outgoing request by parsing the serialized bytes back into a `Message`.
        debug!(
            "final message: {}",
            message
                .to_message()
                .expect("bizarre we just made this message")
        );
        // Owned copies for the `async move` block below, so it does not borrow `self`.
        let creator = self.creator.clone();
        let addr = message.addr();

        // Socket creation, send and receive are bounded together by `self.timeout`.
        S::Time::timeout::<Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>>(
            self.timeout,
            Box::pin(async move {
                // A new socket for every request, created via the configured `creator`.
                let socket: S = NextRandomUdpSocket::new_with_closure(&addr, creator).await?;
                send_serial_message_inner(message, message_id, verifier, socket, recv_buf_size)
                    .await
            }),
        )
        .into()
    }

    /// Marks the stream as shut down; later calls to `send_message` panic.
    fn shutdown(&mut self) {
        self.is_shutdown = true;
    }

    /// Returns `true` once `shutdown` has been called.
    fn is_shutdown(&self) -> bool {
        self.is_shutdown
    }
}
256
257// TODO: is this impl necessary? there's nothing being driven here...
258impl<S: Send, MF: MessageFinalizer> Stream for UdpClientStream<S, MF> {
259    type Item = Result<(), ProtoError>;
260
261    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
262        // Technically the Stream doesn't actually do anything.
263        if self.is_shutdown {
264            Poll::Ready(None)
265        } else {
266            Poll::Ready(Some(Ok(())))
267        }
268    }
269}
270
271/// A future that resolves to an UdpClientStream
272pub struct UdpClientConnect<S, MF = NoopMessageFinalizer>
273where
274    S: Send,
275    MF: MessageFinalizer,
276{
277    name_server: SocketAddr,
278    timeout: Duration,
279    signer: Option<Arc<MF>>,
280    creator: UdpCreator<S>,
281    marker: PhantomData<S>,
282}
283
284impl<S: Send + Unpin, MF: MessageFinalizer> Future for UdpClientConnect<S, MF> {
285    type Output = Result<UdpClientStream<S, MF>, ProtoError>;
286
287    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
288        // TODO: this doesn't need to be a future?
289        Poll::Ready(Ok(UdpClientStream::<S, MF> {
290            name_server: self.name_server,
291            is_shutdown: false,
292            timeout: self.timeout,
293            signer: self.signer.take(),
294            creator: self.creator.clone(),
295            marker: PhantomData,
296        }))
297    }
298}
299
300async fn send_serial_message_inner<S: DnsUdpSocket + Send>(
301    msg: SerialMessage,
302    msg_id: u16,
303    verifier: Option<MessageVerifier>,
304    socket: S,
305    recv_buf_size: usize,
306) -> Result<DnsResponse, ProtoError> {
307    let bytes = msg.bytes();
308    let addr = msg.addr();
309    let len_sent: usize = socket.send_to(bytes, addr).await?;
310
311    if bytes.len() != len_sent {
312        return Err(ProtoError::from(format!(
313            "Not all bytes of message sent, {} of {}",
314            len_sent,
315            bytes.len()
316        )));
317    }
318
319    // Create the receive buffer.
320    trace!("creating UDP receive buffer with size {recv_buf_size}");
321    let mut recv_buf = vec![0; recv_buf_size];
322
323    // TODO: limit the max number of attempted messages? this relies on a timeout to die...
324    loop {
325        let (len, src) = socket.recv_from(&mut recv_buf).await?;
326
327        // Copy the slice of read bytes.
328        let buffer: Vec<_> = Vec::from(&recv_buf[0..len]);
329
330        // compare expected src to received packet
331        let request_target = msg.addr();
332
333        if src != request_target {
334            warn!(
335                "ignoring response from {} because it does not match name_server: {}.",
336                src, request_target,
337            );
338
339            // await an answer from the correct NameServer
340            continue;
341        }
342
343        // TODO: match query strings from request and response?
344
345        match Message::from_vec(&buffer) {
346            Ok(message) => {
347                if msg_id == message.id() {
348                    debug!("received message id: {}", message.id());
349                    if let Some(mut verifier) = verifier {
350                        return verifier(&buffer);
351                    } else {
352                        return Ok(DnsResponse::new(message, buffer));
353                    }
354                } else {
355                    // on wrong id, attempted poison?
356                    warn!(
357                        "expected message id: {} got: {}, dropped",
358                        msg_id,
359                        message.id()
360                    );
361
362                    continue;
363                }
364            }
365            Err(e) => {
366                // on errors deserializing, continue
367                warn!(
368                    "dropped malformed message waiting for id: {} err: {}",
369                    msg_id, e
370                );
371
372                continue;
373            }
374        }
375    }
376}
377
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]
    use crate::tests::udp_client_stream_test;
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};
    use tokio::{net::UdpSocket as TokioUdpSocket, runtime::Runtime};

    /// Runs the shared UDP client stream test against `server_ip` on a fresh tokio runtime.
    fn run_against(server_ip: IpAddr) {
        let runtime = Runtime::new().expect("failed to create tokio runtime");
        udp_client_stream_test::<TokioUdpSocket, Runtime>(server_ip, runtime)
    }

    #[test]
    fn test_udp_client_stream_ipv4() {
        run_against(IpAddr::V4(Ipv4Addr::LOCALHOST))
    }

    #[test]
    #[cfg(not(target_os = "linux"))] // ignored until Travis-CI fixes IPv6
    fn test_udp_client_stream_ipv6() {
        run_against(IpAddr::V6(Ipv6Addr::LOCALHOST))
    }
}