hickory_proto/xfer/
dns_multiplexer.rs

1// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! `DnsMultiplexer` and associated types implement the state machines for sending DNS messages while using the underlying streams.
9
10use alloc::{boxed::Box, sync::Arc};
11use core::{
12    borrow::Borrow,
13    fmt::{self, Display},
14    marker::Unpin,
15    pin::Pin,
16    task::{Context, Poll},
17    time::Duration,
18};
19use std::{
20    collections::{HashMap, hash_map::Entry},
21    time::{SystemTime, UNIX_EPOCH},
22};
23
24use futures_channel::mpsc;
25use futures_util::{
26    FutureExt,
27    future::Future,
28    ready,
29    stream::{Stream, StreamExt},
30};
31use rand::Rng;
32use tracing::debug;
33
34use crate::{
35    DnsStreamHandle,
36    error::{ProtoError, ProtoErrorKind},
37    op::{MessageFinalizer, MessageVerifier},
38    runtime::Time,
39    xfer::{
40        BufDnsStreamHandle, CHANNEL_BUFFER_SIZE, DnsClientStream, DnsRequest, DnsRequestSender,
41        DnsResponse, DnsResponseStream, SerialMessage, ignore_send,
42    },
43};
44
/// Maximum number of inbound messages processed by a single `poll_next` of the multiplexer.
///
/// The underlying stream may be any transport (UDP, TCP, ...), not only a UDP socket; bounding
/// the work done per poll keeps one busy stream from monopolizing the task.
const QOS_MAX_RECEIVE_MSGS: usize = 100;
46
47struct ActiveRequest {
48    // the completion is the channel for a response to the original request
49    completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
50    request_id: u16,
51    timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
52    verifier: Option<MessageVerifier>,
53}
54
55impl ActiveRequest {
56    fn new(
57        completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
58        request_id: u16,
59        timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
60        verifier: Option<MessageVerifier>,
61    ) -> Self {
62        Self {
63            completion,
64            request_id,
65            // request,
66            timeout,
67            verifier,
68        }
69    }
70
71    /// polls the timeout and converts the error
72    fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Poll<()> {
73        self.timeout.poll_unpin(cx)
74    }
75
76    /// Returns true of the other side canceled the request
77    fn is_canceled(&self) -> bool {
78        self.completion.is_closed()
79    }
80
81    /// the request id of the message that was sent
82    fn request_id(&self) -> u16 {
83        self.request_id
84    }
85
86    /// Sends an error
87    fn complete_with_error(mut self, error: ProtoError) {
88        ignore_send(self.completion.try_send(Err(error)));
89    }
90}
91
92/// A DNS Client implemented over futures-rs.
93///
94/// This Client is generic and capable of wrapping UDP, TCP, and other underlying DNS protocol
95///  implementations. This should be used for underlying protocols that do not natively support
96///  multiplexed sessions.
97#[must_use = "futures do nothing unless polled"]
98pub struct DnsMultiplexer<S>
99where
100    S: DnsClientStream + 'static,
101{
102    stream: S,
103    timeout_duration: Duration,
104    stream_handle: BufDnsStreamHandle,
105    active_requests: HashMap<u16, ActiveRequest>,
106    signer: Option<Arc<dyn MessageFinalizer>>,
107    is_shutdown: bool,
108}
109
110impl<S> DnsMultiplexer<S>
111where
112    S: DnsClientStream + Unpin + 'static,
113{
114    /// Spawns a new DnsMultiplexer Stream. This uses a default timeout of 5 seconds for all requests.
115    ///
116    /// # Arguments
117    ///
118    /// * `stream` - A stream of bytes that can be used to send/receive DNS messages
119    ///              (see TcpClientStream or UdpClientStream)
120    /// * `stream_handle` - The handle for the `stream` on which bytes can be sent/received.
121    /// * `signer` - An optional signer for requests, needed for Updates with Sig0, otherwise not needed
122    #[allow(clippy::new_ret_no_self)]
123    pub fn new<F>(
124        stream: F,
125        stream_handle: BufDnsStreamHandle,
126        signer: Option<Arc<dyn MessageFinalizer>>,
127    ) -> DnsMultiplexerConnect<F, S>
128    where
129        F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
130    {
131        Self::with_timeout(stream, stream_handle, Duration::from_secs(5), signer)
132    }
133
134    /// Spawns a new DnsMultiplexer Stream.
135    ///
136    /// # Arguments
137    ///
138    /// * `stream` - A stream of bytes that can be used to send/receive DNS messages
139    ///              (see TcpClientStream or UdpClientStream)
140    /// * `timeout_duration` - All requests may fail due to lack of response, this is the time to
141    ///                        wait for a response before canceling the request.
142    /// * `stream_handle` - The handle for the `stream` on which bytes can be sent/received.
143    /// * `signer` - An optional signer for requests, needed for Updates with Sig0, otherwise not needed
144    pub fn with_timeout<F>(
145        stream: F,
146        stream_handle: BufDnsStreamHandle,
147        timeout_duration: Duration,
148        signer: Option<Arc<dyn MessageFinalizer>>,
149    ) -> DnsMultiplexerConnect<F, S>
150    where
151        F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
152    {
153        DnsMultiplexerConnect {
154            stream,
155            stream_handle: Some(stream_handle),
156            timeout_duration,
157            signer,
158        }
159    }
160
161    /// loop over active_requests and remove cancelled requests
162    ///  this should free up space if we already had 4096 active requests
163    fn drop_cancelled(&mut self, cx: &mut Context<'_>) {
164        let mut canceled = HashMap::<u16, ProtoError>::new();
165        for (&id, active_req) in &mut self.active_requests {
166            if active_req.is_canceled() {
167                canceled.insert(id, ProtoError::from("requestor canceled"));
168            }
169
170            // check for timeouts...
171            match active_req.poll_timeout(cx) {
172                Poll::Ready(()) => {
173                    debug!("request timed out: {}", id);
174                    canceled.insert(id, ProtoError::from(ProtoErrorKind::Timeout));
175                }
176                Poll::Pending => (),
177            }
178        }
179
180        // drop all the canceled requests
181        for (id, error) in canceled {
182            if let Some(active_request) = self.active_requests.remove(&id) {
183                // complete the request, it's failed...
184                active_request.complete_with_error(error);
185            }
186        }
187    }
188
189    /// creates random query_id, validates against all active queries
190    fn next_random_query_id(&self) -> Result<u16, ProtoError> {
191        let mut rand = rand::rng();
192
193        for _ in 0..100 {
194            let id: u16 = rand.random(); // the range is [0 ... u16::max]
195
196            if !self.active_requests.contains_key(&id) {
197                return Ok(id);
198            }
199        }
200
201        Err(ProtoError::from(
202            "id space exhausted, consider filing an issue",
203        ))
204    }
205
206    /// Closes all outstanding completes with a closed stream error
207    fn stream_closed_close_all(&mut self, error: ProtoError) {
208        debug!(error = error.as_dyn(), stream = %self.stream);
209
210        for (_, active_request) in self.active_requests.drain() {
211            // complete the request, it's failed...
212            active_request.complete_with_error(error.clone());
213        }
214    }
215}
216
217/// A wrapper for a future DnsExchange connection
218#[must_use = "futures do nothing unless polled"]
219pub struct DnsMultiplexerConnect<F, S>
220where
221    F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
222    S: Stream<Item = Result<SerialMessage, ProtoError>> + Unpin,
223{
224    stream: F,
225    stream_handle: Option<BufDnsStreamHandle>,
226    timeout_duration: Duration,
227    signer: Option<Arc<dyn MessageFinalizer>>,
228}
229
230impl<F, S> Future for DnsMultiplexerConnect<F, S>
231where
232    F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
233    S: DnsClientStream + Unpin + 'static,
234{
235    type Output = Result<DnsMultiplexer<S>, ProtoError>;
236
237    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
238        let stream: S = ready!(self.stream.poll_unpin(cx))?;
239
240        Poll::Ready(Ok(DnsMultiplexer {
241            stream,
242            timeout_duration: self.timeout_duration,
243            stream_handle: self
244                .stream_handle
245                .take()
246                .expect("must not poll after complete"),
247            active_requests: HashMap::new(),
248            signer: self.signer.clone(),
249            is_shutdown: false,
250        }))
251    }
252}
253
254impl<S> Display for DnsMultiplexer<S>
255where
256    S: DnsClientStream + 'static,
257{
258    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
259        write!(formatter, "{}", self.stream)
260    }
261}
262
263impl<S> DnsRequestSender for DnsMultiplexer<S>
264where
265    S: DnsClientStream + Unpin + 'static,
266{
267    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
268        if self.is_shutdown {
269            panic!("can not send messages after stream is shutdown")
270        }
271
272        if self.active_requests.len() > CHANNEL_BUFFER_SIZE {
273            return ProtoError::from(ProtoErrorKind::Busy).into();
274        }
275
276        let query_id = match self.next_random_query_id() {
277            Ok(id) => id,
278            Err(e) => return e.into(),
279        };
280
281        let (mut request, _) = request.into_parts();
282        request.set_id(query_id);
283
284        let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
285            Ok(now) => now.as_secs(),
286            Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
287        };
288
289        // TODO: truncates u64 to u32, error on overflow?
290        let now = now as u32;
291
292        let mut verifier = None;
293        if let Some(signer) = &self.signer {
294            if signer.should_finalize_message(&request) {
295                match request.finalize(signer.borrow(), now) {
296                    Ok(answer_verifier) => verifier = answer_verifier,
297                    Err(e) => {
298                        debug!("could not sign message: {}", e);
299                        return e.into();
300                    }
301                }
302            }
303        }
304
305        // store a Timeout for this message before sending
306        let timeout = S::Time::delay_for(self.timeout_duration);
307
308        let (complete, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
309
310        // send the message
311        let active_request =
312            ActiveRequest::new(complete, request.id(), Box::new(timeout), verifier);
313
314        match request.to_vec() {
315            Ok(buffer) => {
316                debug!(id = %active_request.request_id(), "sending message");
317                let serial_message = SerialMessage::new(buffer, self.stream.name_server_addr());
318
319                debug!(
320                    "final message: {}",
321                    serial_message
322                        .to_message()
323                        .expect("bizarre we just made this message")
324                );
325
326                // add to the map -after- the client send b/c we don't want to put it in the map if
327                //  we ended up returning an error from the send.
328                match self.stream_handle.send(serial_message) {
329                    Ok(()) => self
330                        .active_requests
331                        .insert(active_request.request_id(), active_request),
332                    Err(err) => return err.into(),
333                };
334            }
335            Err(e) => {
336                debug!(
337                    id = %active_request.request_id(),
338                    error = e.as_dyn(),
339                    "error message"
340                );
341                // complete with the error, don't add to the map of active requests
342                return e.into();
343            }
344        }
345
346        receiver.into()
347    }
348
349    fn shutdown(&mut self) {
350        self.is_shutdown = true;
351    }
352
353    fn is_shutdown(&self) -> bool {
354        self.is_shutdown
355    }
356}
357
358impl<S> Stream for DnsMultiplexer<S>
359where
360    S: DnsClientStream + Unpin + 'static,
361{
362    type Item = Result<(), ProtoError>;
363
364    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
365        // Always drop the cancelled queries first
366        self.drop_cancelled(cx);
367
368        if self.is_shutdown && self.active_requests.is_empty() {
369            debug!("stream is done: {}", self);
370            return Poll::Ready(None);
371        }
372
373        // Collect all inbound requests, max 100 at a time for QoS
374        //   by having a max we will guarantee that the client can't be DOSed in this loop
375        // TODO: make the QoS configurable
376        let mut messages_received = 0;
377        for i in 0..QOS_MAX_RECEIVE_MSGS {
378            match self.stream.poll_next_unpin(cx) {
379                Poll::Ready(Some(Ok(buffer))) => {
380                    messages_received = i;
381
382                    //   deserialize or log decode_error
383                    match DnsResponse::from_buffer(buffer.into_parts().0) {
384                        Ok(response) => match self.active_requests.entry(response.id()) {
385                            Entry::Occupied(mut request_entry) => {
386                                // send the response, complete the request...
387                                let active_request = request_entry.get_mut();
388                                if let Some(verifier) = &mut active_request.verifier {
389                                    ignore_send(
390                                        active_request
391                                            .completion
392                                            .try_send(verifier(response.as_buffer())),
393                                    );
394                                } else {
395                                    ignore_send(active_request.completion.try_send(Ok(response)));
396                                }
397                            }
398                            Entry::Vacant(..) => debug!("unexpected request_id: {}", response.id()),
399                        },
400                        // TODO: return src address for diagnostics
401                        Err(error) => debug!(error = error.as_dyn(), "error decoding message"),
402                    }
403                }
404                Poll::Ready(err) => {
405                    let err = match err {
406                        Some(Err(e)) => e,
407                        None => ProtoError::from("stream closed"),
408                        _ => unreachable!(),
409                    };
410
411                    self.stream_closed_close_all(err);
412                    self.is_shutdown = true;
413                    return Poll::Ready(None);
414                }
415                Poll::Pending => break,
416            }
417        }
418
419        // If still active, then if the qos (for _ in 0..100 loop) limit
420        // was hit then "yield". This'll make sure that the future is
421        // woken up immediately on the next turn of the event loop.
422        if messages_received == QOS_MAX_RECEIVE_MSGS {
423            // FIXME: this was a task::current().notify(); is this right?
424            cx.waker().wake_by_ref();
425        }
426
427        // Finally, return not ready to keep the 'driver task' alive.
428        Poll::Pending
429    }
430}
431
#[cfg(test)]
mod test {
    use alloc::vec::Vec;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

    use futures_util::future;
    use futures_util::stream::TryStreamExt;
    use test_support::subscribe;

    use super::*;
    use crate::op::op_code::OpCode;
    use crate::op::{Message, MessageType, Query};
    use crate::rr::record_type::RecordType;
    use crate::rr::{DNSClass, Name, RData, Record};
    use crate::serialize::binary::BinEncodable;
    use crate::xfer::StreamReceiver;
    use crate::xfer::{DnsClientStream, DnsRequestOptions};

    /// In-memory stand-in for a network stream.
    ///
    /// It waits for the first request the multiplexer enqueues and remembers that request's
    /// id. It then replays the canned `messages` one per poll, each stamped with that id.
    struct MockClientStream {
        /// Canned responses, kept in reverse order so `pop()` yields them first-to-last.
        messages: Vec<Message>,
        addr: SocketAddr,
        /// Id of the first outbound request, once observed.
        id: Option<u16>,
        /// The multiplexer's outbound channel; installed after construction.
        receiver: Option<StreamReceiver>,
    }

    impl MockClientStream {
        fn new(
            messages: Vec<Message>,
            addr: SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<Self, ProtoError>> + Send>> {
            let messages = messages.into_iter().rev().collect();
            Box::pin(future::ok(Self {
                messages,
                addr,
                id: None,
                receiver: None,
            }))
        }
    }

    impl fmt::Display for MockClientStream {
        fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
            formatter.write_str("TestClientStream")
        }
    }

    impl Stream for MockClientStream {
        type Item = Result<SerialMessage, ProtoError>;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let id = match self.id {
                Some(id) => id,
                None => {
                    // learn the request id from the first message the multiplexer sends
                    let outbound = self
                        .receiver
                        .as_mut()
                        .expect("should only be polled after receiver has been set");
                    let serial = ready!(outbound.poll_next_unpin(cx));
                    let id = serial.unwrap().to_message().unwrap().id();
                    self.id = Some(id);
                    id
                }
            };

            let Some(mut message) = self.messages.pop() else {
                return Poll::Pending;
            };
            message.set_id(id);
            let bytes = message.to_bytes().unwrap();
            Poll::Ready(Some(Ok(SerialMessage::new(bytes, self.addr))))
        }
    }

    impl DnsClientStream for MockClientStream {
        type Time = crate::runtime::TokioTime;

        fn name_server_addr(&self) -> SocketAddr {
            self.addr
        }
    }

    /// Builds a multiplexer (100ms timeout) over a mock stream that replays `mock_response`.
    async fn get_mocked_multiplexer(
        mock_response: Vec<Message>,
    ) -> DnsMultiplexer<MockClientStream> {
        let addr = SocketAddr::from(([127, 0, 0, 1], 1234));
        let (handle, outbound) = BufDnsStreamHandle::new(addr);
        let connect = DnsMultiplexer::with_timeout(
            MockClientStream::new(mock_response, addr),
            handle,
            Duration::from_millis(100),
            None,
        );
        let mut multiplexer = connect.await.unwrap();

        // hand the mock the outbound channel so it can learn the request id
        multiplexer.stream.receiver = Some(outbound);
        multiplexer
    }

    /// Sends `query` through a multiplexer mocked with `answer` and collects all responses.
    /// The multiplexer is driven concurrently; it must never finish on its own.
    async fn send_and_collect(query: DnsRequest, answer: Vec<Message>) -> Vec<DnsResponse> {
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        tokio::select! {
            _ = multiplexer.next() => {
                // polling multiplexer to make it run
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        }
    }

    /// A recursion-desired, `IN`-class query message for `name` / `record_type`.
    fn query_message(name: Name, record_type: RecordType) -> Message {
        let mut query = Query::query(name, record_type);
        query.set_query_class(DNSClass::IN);

        let mut msg = Message::new();
        msg.add_query(query)
            .set_message_type(MessageType::Query)
            .set_op_code(OpCode::Query)
            .set_recursion_desired(true);
        msg
    }

    /// Returns `record` with its class set to `IN`.
    fn in_class(mut record: Record) -> Record {
        record.set_dns_class(DNSClass::IN);
        record
    }

    fn a_query_answer() -> (DnsRequest, Vec<Message>) {
        let name = Name::from_ascii("www.example.com.").unwrap();
        let query = query_message(name.clone(), RecordType::A);

        let answer = in_class(Record::from_rdata(
            name,
            86400,
            RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
        ));
        let mut response = query.clone();
        response
            .set_message_type(MessageType::Response)
            .add_answer(answer);

        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![response],
        )
    }

    fn axfr_query() -> Message {
        query_message(Name::from_ascii("example.com.").unwrap(), RecordType::AXFR)
    }

    fn axfr_response() -> Vec<Record> {
        use crate::rr::rdata::*;
        let origin = Name::from_ascii("example.com.").unwrap();
        let parse = |name: &str| Name::parse(name, None).unwrap();

        let soa = in_class(Record::from_rdata(
            origin.clone(),
            3600,
            RData::SOA(SOA::new(
                parse("sns.dns.icann.org."),
                parse("noc.dns.icann.org."),
                2015082403,
                7200,
                3600,
                1209600,
                3600,
            )),
        ));
        let ns = |server: &str| {
            in_class(Record::from_rdata(
                origin.clone(),
                86400,
                RData::NS(NS(parse(server))),
            ))
        };
        let a = in_class(Record::from_rdata(
            origin.clone(),
            86400,
            RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
        ));
        let aaaa = in_class(Record::from_rdata(
            origin.clone(),
            86400,
            RData::AAAA(
                Ipv6Addr::new(
                    0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c,
                )
                .into(),
            ),
        ));

        vec![
            soa.clone(),
            ns("a.iana-servers.net."),
            ns("b.iana-servers.net."),
            a,
            aaaa,
            soa,
        ]
    }

    fn axfr_query_answer() -> (DnsRequest, Vec<Message>) {
        let query = axfr_query();
        let mut response = query.clone();
        response
            .set_message_type(MessageType::Response)
            .insert_answers(axfr_response());

        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![response],
        )
    }

    fn axfr_query_answer_multi() -> (DnsRequest, Vec<Message>) {
        let query = axfr_query();
        let mut first_half = axfr_response();
        let second_half = first_half.split_off(3);

        let responses: Vec<Message> = vec![first_half, second_half]
            .into_iter()
            .map(|answers| {
                let mut msg = query.clone();
                msg.set_message_type(MessageType::Response)
                    .insert_answers(answers);
                msg
            })
            .collect();

        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            responses,
        )
    }

    #[tokio::test]
    async fn test_multiplexer_a() {
        subscribe();
        let (query, answer) = a_query_answer();
        let responses = send_and_collect(query, answer).await;
        assert_eq!(responses.len(), 1);
    }

    #[tokio::test]
    async fn test_multiplexer_axfr() {
        subscribe();
        let (query, answer) = axfr_query_answer();
        let responses = send_and_collect(query, answer).await;
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].answers().len(), axfr_response().len());
    }

    #[tokio::test]
    async fn test_multiplexer_axfr_multi() {
        subscribe();
        let (query, answer) = axfr_query_answer_multi();
        let responses = send_and_collect(query, answer).await;
        assert_eq!(responses.len(), 2);
        let total_answers: usize = responses.iter().map(|m| m.answers().len()).sum();
        assert_eq!(total_answers, axfr_response().len());
    }
}