hickory_proto/udp/
udp_client_stream.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use alloc::boxed::Box;
9use alloc::sync::Arc;
10use alloc::vec::Vec;
11use core::fmt::{self, Display};
12use core::pin::Pin;
13use core::task::{Context, Poll};
14use core::time::Duration;
15use std::collections::HashSet;
16use std::net::SocketAddr;
17use std::time::{SystemTime, UNIX_EPOCH};
18
19use futures_util::{future::Future, stream::Stream};
20use tracing::{debug, trace, warn};
21
22use crate::error::{ProtoError, ProtoErrorKind};
23use crate::op::{Message, MessageFinalizer, MessageVerifier, Query};
24use crate::runtime::{RuntimeProvider, Time};
25use crate::udp::udp_stream::NextRandomUdpSocket;
26use crate::udp::{DnsUdpSocket, MAX_RECEIVE_BUFFER_SIZE};
27use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream, SerialMessage};
28
/// A builder to create a UDP client stream.
///
/// This is created by [`UdpClientStream::builder`].
pub struct UdpClientStreamBuilder<P> {
    /// Address of the name server that queries are sent to.
    name_server: SocketAddr,
    /// Per-request timeout; `None` means the default applied by `build` (5 seconds).
    timeout: Option<Duration>,
    /// Optional finalizer used to sign outgoing queries.
    signer: Option<Arc<dyn MessageFinalizer>>,
    /// Local address to bind each per-request socket to, if any.
    bind_addr: Option<SocketAddr>,
    /// Local UDP ports that should not be used as the source port.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Whether the OS, rather than Hickory DNS, provides the ephemeral source port.
    os_port_selection: bool,
    /// Runtime provider; handed to the socket factory for each request.
    provider: P,
}
41
42impl<P> UdpClientStreamBuilder<P> {
43    /// Sets the connection timeout.
44    pub fn with_timeout(mut self, timeout: Option<Duration>) -> Self {
45        self.timeout = timeout;
46        self
47    }
48
49    /// Sets the message finalizer to be applied to queries.
50    pub fn with_signer(self, signer: Option<Arc<dyn MessageFinalizer>>) -> Self {
51        Self {
52            name_server: self.name_server,
53            timeout: self.timeout,
54            signer,
55            bind_addr: self.bind_addr,
56            avoid_local_ports: self.avoid_local_ports,
57            os_port_selection: self.os_port_selection,
58            provider: self.provider,
59        }
60    }
61
62    /// Sets the local socket address to connect from.
63    ///
64    /// If the port number is 0, a random port number will be chosen to defend against spoofing
65    /// attacks. If the port number is nonzero, it will be used instead.
66    pub fn with_bind_addr(mut self, bind_addr: Option<SocketAddr>) -> Self {
67        self.bind_addr = bind_addr;
68        self
69    }
70
71    /// Configures a list of local UDP ports that should not be used when making outgoing
72    /// connections.
73    pub fn avoid_local_ports(mut self, avoid_local_ports: Arc<HashSet<u16>>) -> Self {
74        self.avoid_local_ports = avoid_local_ports;
75        self
76    }
77
78    /// Configures that OS should provide the ephemeral port, not the Hickory DNS
79    pub fn with_os_port_selection(mut self, os_port_selection: bool) -> Self {
80        self.os_port_selection = os_port_selection;
81        self
82    }
83
84    /// Construct a new UDP client stream.
85    ///
86    /// Returns a future that outputs the client stream.
87    pub fn build(self) -> UdpClientConnect<P> {
88        UdpClientConnect {
89            name_server: self.name_server,
90            timeout: self.timeout.unwrap_or(Duration::from_secs(5)),
91            signer: self.signer,
92            bind_addr: self.bind_addr,
93            avoid_local_ports: self.avoid_local_ports.clone(),
94            os_port_selection: self.os_port_selection,
95            provider: self.provider,
96        }
97    }
98}
99
/// A UDP client stream of DNS binary packets.
///
/// Each call to `send_message` obtains a fresh socket from `NextRandomUdpSocket`, so every
/// request goes out on its own socket rather than sharing one. This is to avoid potential cache
/// poisoning due to UDP spoofing attacks.
#[must_use = "futures do nothing unless polled"]
pub struct UdpClientStream<P> {
    /// Address of the name server every query is sent to.
    name_server: SocketAddr,
    /// Upper bound on each request's exchange (socket creation, send and receive).
    timeout: Duration,
    /// Set by `shutdown`; `send_message` panics once this is `true`.
    is_shutdown: bool,
    /// Optional finalizer used to sign outgoing queries.
    signer: Option<Arc<dyn MessageFinalizer>>,
    /// Local address each per-request socket is bound to, if configured.
    bind_addr: Option<SocketAddr>,
    /// Local UDP ports that should not be used as the source port.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Whether the OS, rather than Hickory DNS, provides the ephemeral source port.
    os_port_selection: bool,
    /// Runtime provider; cloned into every request to create its socket.
    provider: P,
}
116
117impl<P: RuntimeProvider> UdpClientStream<P> {
118    /// Construct a new [`UdpClientStream`] via a [`UdpClientStreamBuilder`].
119    pub fn builder(name_server: SocketAddr, provider: P) -> UdpClientStreamBuilder<P> {
120        UdpClientStreamBuilder {
121            name_server,
122            timeout: None,
123            signer: None,
124            bind_addr: None,
125            avoid_local_ports: Arc::default(),
126            os_port_selection: false,
127            provider,
128        }
129    }
130}
131
132impl<P> Display for UdpClientStream<P> {
133    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
134        write!(formatter, "UDP({})", self.name_server)
135    }
136}
137
138/// creates random query_id, each socket is unique, no need for global uniqueness
139fn random_query_id() -> u16 {
140    rand::random()
141}
142
143impl<P: RuntimeProvider> DnsRequestSender for UdpClientStream<P> {
144    fn send_message(&mut self, mut request: DnsRequest) -> DnsResponseStream {
145        if self.is_shutdown {
146            panic!("can not send messages after stream is shutdown")
147        }
148
149        let case_randomization = request.options().case_randomization;
150
151        // associated the ID for this request, b/c this connection is unique to socket port, the ID
152        //   does not need to be globally unique
153        request.set_id(random_query_id());
154
155        let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
156            Ok(now) => now.as_secs(),
157            Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
158        };
159
160        // TODO: truncates u64 to u32, error on overflow?
161        let now = now as u32;
162
163        let mut verifier = None;
164        if let Some(signer) = &self.signer {
165            if signer.should_finalize_message(&request) {
166                match request.finalize(&**signer, now) {
167                    Ok(answer_verifier) => verifier = answer_verifier,
168                    Err(e) => {
169                        debug!("could not sign message: {}", e);
170                        return e.into();
171                    }
172                }
173            }
174        }
175
176        // Get an appropriate read buffer size.
177        let recv_buf_size = MAX_RECEIVE_BUFFER_SIZE.min(request.max_payload() as usize);
178
179        let bytes = match request.to_vec() {
180            Ok(bytes) => bytes,
181            Err(err) => {
182                return err.into();
183            }
184        };
185
186        let message_id = request.id();
187        let message = SerialMessage::new(bytes, self.name_server);
188
189        debug!(
190            "final message: {}",
191            message
192                .to_message()
193                .expect("bizarre we just made this message")
194        );
195        let provider = self.provider.clone();
196        let addr = message.addr();
197        let bind_addr = self.bind_addr;
198        let avoid_local_ports = self.avoid_local_ports.clone();
199        let os_port_selection = self.os_port_selection;
200
201        P::Timer::timeout::<Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>>(
202            self.timeout,
203            Box::pin(async move {
204                let socket = NextRandomUdpSocket::new(
205                    addr,
206                    bind_addr,
207                    avoid_local_ports,
208                    os_port_selection,
209                    provider,
210                )
211                .await?;
212                send_serial_message_inner(
213                    message,
214                    message_id,
215                    verifier,
216                    socket,
217                    recv_buf_size,
218                    case_randomization,
219                    request.original_query(),
220                )
221                .await
222            }),
223        )
224        .into()
225    }
226
227    fn shutdown(&mut self) {
228        self.is_shutdown = true;
229    }
230
231    fn is_shutdown(&self) -> bool {
232        self.is_shutdown
233    }
234}
235
236// TODO: is this impl necessary? there's nothing being driven here...
237impl<P> Stream for UdpClientStream<P> {
238    type Item = Result<(), ProtoError>;
239
240    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
241        // Technically the Stream doesn't actually do anything.
242        if self.is_shutdown {
243            Poll::Ready(None)
244        } else {
245            Poll::Ready(Some(Ok(())))
246        }
247    }
248}
249
/// A future that resolves to a [`UdpClientStream`].
///
/// Produced by [`UdpClientStreamBuilder::build`]; it completes on its first poll without
/// performing any I/O.
pub struct UdpClientConnect<P> {
    /// Address of the name server the stream will query.
    name_server: SocketAddr,
    /// Per-request timeout, already defaulted by the builder.
    timeout: Duration,
    /// Optional finalizer used to sign outgoing queries.
    signer: Option<Arc<dyn MessageFinalizer>>,
    /// Local address to bind each per-request socket to, if any.
    bind_addr: Option<SocketAddr>,
    /// Local UDP ports that should not be used as the source port.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Whether the OS, rather than Hickory DNS, provides the ephemeral source port.
    os_port_selection: bool,
    /// Runtime provider passed on to the resulting stream.
    provider: P,
}
260
261impl<P: RuntimeProvider> Future for UdpClientConnect<P> {
262    type Output = Result<UdpClientStream<P>, ProtoError>;
263
264    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
265        // TODO: this doesn't need to be a future?
266        Poll::Ready(Ok(UdpClientStream {
267            name_server: self.name_server,
268            is_shutdown: false,
269            timeout: self.timeout,
270            signer: self.signer.take(),
271            bind_addr: self.bind_addr,
272            avoid_local_ports: self.avoid_local_ports.clone(),
273            os_port_selection: self.os_port_selection,
274            provider: self.provider.clone(),
275        }))
276    }
277}
278
279async fn send_serial_message_inner<S: DnsUdpSocket + Send>(
280    msg: SerialMessage,
281    msg_id: u16,
282    verifier: Option<MessageVerifier>,
283    socket: S,
284    recv_buf_size: usize,
285    case_randomization: bool,
286    original_query: Option<&Query>,
287) -> Result<DnsResponse, ProtoError> {
288    let bytes = msg.bytes();
289    let addr = msg.addr();
290    let len_sent: usize = socket.send_to(bytes, addr).await?;
291
292    if bytes.len() != len_sent {
293        return Err(ProtoError::from(format!(
294            "Not all bytes of message sent, {} of {}",
295            len_sent,
296            bytes.len()
297        )));
298    }
299
300    // Create the receive buffer.
301    trace!("creating UDP receive buffer with size {recv_buf_size}");
302    let mut recv_buf = vec![0; recv_buf_size];
303
304    // TODO: limit the max number of attempted messages? this relies on a timeout to die...
305    loop {
306        let (len, src) = socket.recv_from(&mut recv_buf).await?;
307
308        // Copy the slice of read bytes.
309        let response_bytes = &recv_buf[0..len];
310        let response_buffer = Vec::from(response_bytes);
311
312        // compare expected src to received packet
313        let request_target = msg.addr();
314
315        // Comparing the IP and Port directly as internal information about the link is stored with the IpAddr, see https://github.com/hickory-dns/hickory-dns/issues/2081
316        if src.ip() != request_target.ip() || src.port() != request_target.port() {
317            warn!(
318                "ignoring response from {} because it does not match name_server: {}.",
319                src, request_target,
320            );
321
322            // await an answer from the correct NameServer
323            continue;
324        }
325
326        let mut response = match DnsResponse::from_buffer(response_buffer) {
327            Ok(response) => response,
328            Err(e) => {
329                // on errors deserializing, continue
330                warn!("dropped malformed message waiting for id: {msg_id} err: {e}");
331                continue;
332            }
333        };
334
335        // Validate the message id in the response matches the value chosen for the query.
336        if msg_id != response.id() {
337            // on wrong id, attempted poison?
338            warn!(
339                "expected message id: {} got: {}, dropped",
340                msg_id,
341                response.id()
342            );
343
344            continue;
345        }
346
347        // Validate the returned query name.
348        //
349        // This currently checks that each response query name was present in the original query, but not that
350        // every original question is present.
351        //
352        // References:
353        //
354        // RFC 1035 7.3:
355        //
356        // The next step is to match the response to a current resolver request.
357        // The recommended strategy is to do a preliminary matching using the ID
358        // field in the domain header, and then to verify that the question section
359        // corresponds to the information currently desired.
360        //
361        // RFC 1035 7.4:
362        //
363        // In general, we expect a resolver to cache all data which it receives in
364        // responses since it may be useful in answering future client requests.
365        // However, there are several types of data which should not be cached:
366        //
367        // ...
368        //
369        //  - RR data in responses of dubious reliability.  When a resolver
370        // receives unsolicited responses or RR data other than that
371        // requested, it should discard it without caching it.
372        let request_message = Message::from_vec(msg.bytes())?;
373        let request_queries = request_message.queries();
374        let response_queries = response.queries_mut();
375
376        let question_matches = response_queries
377            .iter()
378            .all(|elem| request_queries.contains(elem));
379        if case_randomization
380            && question_matches
381            && !response_queries.iter().all(|elem| {
382                request_queries
383                    .iter()
384                    .any(|req_q| req_q == elem && req_q.name().eq_case(elem.name()))
385            })
386        {
387            warn!(
388                "case of question section did not match: we expected '{request_queries:?}', but received '{response_queries:?}' from server {src}"
389            );
390            return Err(ProtoErrorKind::QueryCaseMismatch.into());
391        }
392        if !question_matches {
393            warn!(
394                "detected forged question section: we expected '{request_queries:?}', but received '{response_queries:?}' from server {src}"
395            );
396            continue;
397        }
398
399        // overwrite the query with the original query if case randomization may have been used
400        if case_randomization {
401            if let Some(original_query) = original_query {
402                for response_query in response_queries.iter_mut() {
403                    if response_query == original_query {
404                        *response_query = original_query.clone();
405                    }
406                }
407            }
408        }
409
410        debug!("received message id: {}", response.id());
411        if let Some(mut verifier) = verifier {
412            return verifier(response_bytes);
413        } else {
414            return Ok(response);
415        }
416    }
417}
418
#[cfg(test)]
#[cfg(feature = "tokio")]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]
    use crate::{runtime::TokioRuntimeProvider, tests::udp_client_stream_test};
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    use test_support::subscribe;

    /// Installs the test subscriber, then runs the shared `udp_client_stream_test` for
    /// `server_ip` on a fresh Tokio runtime provider.
    async fn run_against(server_ip: IpAddr) {
        subscribe();
        udp_client_stream_test(server_ip, TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv4() {
        run_against(Ipv4Addr::LOCALHOST.into()).await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv6() {
        run_against(Ipv6Addr::LOCALHOST.into()).await;
    }
}