trust_dns_resolver/name_server/
name_server_pool.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::cmp::Ordering;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use std::time::Duration;
13
14use futures_util::future::FutureExt;
15use futures_util::stream::{once, FuturesUnordered, Stream, StreamExt};
16use smallvec::SmallVec;
17
18use proto::xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer};
19use proto::Time;
20use tracing::debug;
21
22use rand::thread_rng as rng;
23use rand::Rng;
24
25use crate::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts, ServerOrderingStrategy};
26use crate::error::{ResolveError, ResolveErrorKind};
27#[cfg(feature = "mdns")]
28use crate::name_server;
29use crate::name_server::connection_provider::{ConnectionProvider, GenericConnector};
30use crate::name_server::name_server::NameServer;
31use crate::name_server::RuntimeProvider;
32#[cfg(test)]
33#[cfg(feature = "tokio-runtime")]
34use crate::name_server::TokioRuntimeProvider;
35
/// A pool of [`NameServer`]s that requests are dispatched across.
///
/// Datagram based and stream based name servers are kept in two separate
/// lists. The connection handling is abstracted behind the
/// [`ConnectionProvider`] `P`, which also allows it to be replaced for
/// mocking purposes.
#[derive(Clone)]
pub struct NameServerPool<P: ConnectionProvider + Send + 'static> {
    // TODO: switch to FuturesMutex (Mutex will have some undesirable locking)
    /// Name servers reached over a datagram protocol.
    /// All NameServers must be the same type.
    datagram_conns: Arc<[NameServer<P>]>,
    /// Name servers reached over a stream protocol.
    /// All NameServers must be the same type.
    stream_conns: Arc<[NameServer<P>]>,
    /// Name server used for mDNS queries (only with the `mdns` feature).
    /// All NameServers must be the same type.
    #[cfg(feature = "mdns")]
    mdns_conns: NameServer<P>,
    /// Resolver options applied to every request sent through this pool.
    options: ResolverOpts,
}
46
/// A pool of NameServers
///
/// This is a [`NameServerPool`] whose connections are created by a
/// [`GenericConnector`] wrapping the runtime provider `P`.
///
/// This is not expected to be used directly, see [crate::AsyncResolver].
pub type GenericNameServerPool<P> = NameServerPool<GenericConnector<P>>;
51
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
impl GenericNameServerPool<TokioRuntimeProvider> {
    /// Test helper: builds a pool from `config` and `options` whose connections
    /// are created through a [`GenericConnector`] wrapping the given Tokio
    /// `runtime` provider.
    pub(crate) fn tokio_from_config(
        config: &ResolverConfig,
        options: &ResolverOpts,
        runtime: TokioRuntimeProvider,
    ) -> Self {
        let connector = GenericConnector::new(runtime);
        Self::from_config_with_provider(config, options, connector)
    }
}
63
64impl<P> NameServerPool<P>
65where
66    P: ConnectionProvider + 'static,
67{
68    pub(crate) fn from_config_with_provider(
69        config: &ResolverConfig,
70        options: &ResolverOpts,
71        conn_provider: P,
72    ) -> Self {
73        let datagram_conns: Vec<NameServer<P>> = config
74            .name_servers()
75            .iter()
76            .filter(|ns_config| ns_config.protocol.is_datagram())
77            .map(|ns_config| {
78                #[cfg(feature = "dns-over-rustls")]
79                let ns_config = {
80                    let mut ns_config = ns_config.clone();
81                    ns_config.tls_config = config.client_config().clone();
82                    ns_config
83                };
84                #[cfg(not(feature = "dns-over-rustls"))]
85                let ns_config = { ns_config.clone() };
86
87                NameServer::new(ns_config, *options, conn_provider.clone())
88            })
89            .collect();
90
91        let stream_conns: Vec<NameServer<P>> = config
92            .name_servers()
93            .iter()
94            .filter(|ns_config| ns_config.protocol.is_stream())
95            .map(|ns_config| {
96                #[cfg(feature = "dns-over-rustls")]
97                let ns_config = {
98                    let mut ns_config = ns_config.clone();
99                    ns_config.tls_config = config.client_config().clone();
100                    ns_config
101                };
102                #[cfg(not(feature = "dns-over-rustls"))]
103                let ns_config = { ns_config.clone() };
104
105                NameServer::new(ns_config, *options, conn_provider.clone())
106            })
107            .collect();
108
109        Self {
110            datagram_conns: Arc::from(datagram_conns),
111            stream_conns: Arc::from(stream_conns),
112            #[cfg(feature = "mdns")]
113            mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
114            options: *options,
115        }
116    }
117
118    /// Construct a NameServerPool from a set of name server configs
119    pub fn from_config(
120        name_servers: NameServerConfigGroup,
121        options: &ResolverOpts,
122        conn_provider: P,
123    ) -> Self {
124        let map_config_to_ns =
125            |ns_config| NameServer::new(ns_config, *options, conn_provider.clone());
126
127        let (datagram, stream): (Vec<_>, Vec<_>) = name_servers
128            .into_inner()
129            .into_iter()
130            .partition(|ns| ns.protocol.is_datagram());
131
132        let datagram_conns: Vec<_> = datagram.into_iter().map(map_config_to_ns).collect();
133        let stream_conns: Vec<_> = stream.into_iter().map(map_config_to_ns).collect();
134
135        Self {
136            datagram_conns: Arc::from(datagram_conns),
137            stream_conns: Arc::from(stream_conns),
138            #[cfg(feature = "mdns")]
139            mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
140            options: *options,
141        }
142    }
143
144    #[doc(hidden)]
145    #[cfg(not(feature = "mdns"))]
146    pub fn from_nameservers(
147        options: &ResolverOpts,
148        datagram_conns: Vec<NameServer<P>>,
149        stream_conns: Vec<NameServer<P>>,
150    ) -> Self {
151        Self {
152            datagram_conns: Arc::from(datagram_conns),
153            stream_conns: Arc::from(stream_conns),
154            options: *options,
155        }
156    }
157
158    #[doc(hidden)]
159    #[cfg(feature = "mdns")]
160    pub fn from_nameservers(
161        options: &ResolverOpts,
162        datagram_conns: Vec<NameServer<P>>,
163        stream_conns: Vec<NameServer<P>>,
164        mdns_conns: NameServer<P>,
165    ) -> Self {
166        GenericNameServerPool {
167            datagram_conns: Arc::from(datagram_conns),
168            stream_conns: Arc::from(stream_conns),
169            mdns_conns,
170            options: *options,
171        }
172    }
173
174    #[cfg(test)]
175    #[cfg(not(feature = "mdns"))]
176    #[allow(dead_code)]
177    fn from_nameservers_test(
178        options: &ResolverOpts,
179        datagram_conns: Arc<[NameServer<P>]>,
180        stream_conns: Arc<[NameServer<P>]>,
181    ) -> Self {
182        Self {
183            datagram_conns,
184            stream_conns,
185            options: *options,
186        }
187    }
188
189    #[cfg(test)]
190    #[cfg(feature = "mdns")]
191    fn from_nameservers_test(
192        options: &ResolverOpts,
193        datagram_conns: Arc<[NameServer<P>]>,
194        stream_conns: Arc<[NameServer<P>]>,
195        mdns_conns: NameServer<P>,
196    ) -> Self {
197        GenericNameServerPool {
198            datagram_conns,
199            stream_conns,
200            mdns_conns,
201            options: *options,
202        }
203    }
204
205    async fn try_send(
206        opts: ResolverOpts,
207        conns: Arc<[NameServer<P>]>,
208        request: DnsRequest,
209    ) -> Result<DnsResponse, ResolveError> {
210        let mut conns: Vec<NameServer<P>> = conns.to_vec();
211
212        match opts.server_ordering_strategy {
213            // select the highest priority connection
214            //   reorder the connections based on current view...
215            //   this reorders the inner set
216            ServerOrderingStrategy::QueryStatistics => conns.sort_unstable(),
217            ServerOrderingStrategy::UserProvidedOrder => {}
218        }
219        let request_loop = request.clone();
220
221        parallel_conn_loop(conns, request_loop, opts).await
222    }
223}
224
225impl<P> DnsHandle for NameServerPool<P>
226where
227    P: ConnectionProvider + 'static,
228{
229    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
230    type Error = ResolveError;
231
232    fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
233        let opts = self.options;
234        let request = request.into();
235        let datagram_conns = Arc::clone(&self.datagram_conns);
236        let stream_conns = Arc::clone(&self.stream_conns);
237        // TODO: remove this clone, return the Message in the error?
238        let tcp_message = request.clone();
239
240        // if it's a .local. query, then we *only* query mDNS, these should never be sent on to upstream resolvers
241        #[cfg(feature = "mdns")]
242        let mdns = mdns::maybe_local(&mut self.mdns_conns, request);
243
244        // TODO: limited to only when mDNS is enabled, but this should probably always be enforced?
245        #[cfg(not(feature = "mdns"))]
246        let mdns = Local::NotMdns(request);
247
248        // local queries are queried through mDNS
249        if mdns.is_local() {
250            return mdns.take_stream();
251        }
252
253        // TODO: should we allow mDNS to be used for standard lookups as well?
254
255        // it wasn't a local query, continue with standard lookup path
256        let request = mdns.take_request();
257        Box::pin(once(async move {
258            debug!("sending request: {:?}", request.queries());
259
260            // First try the UDP connections
261            let udp_res = match Self::try_send(opts, datagram_conns, request).await {
262                Ok(response) if response.truncated() => {
263                    debug!("truncated response received, retrying over TCP");
264                    Ok(response)
265                }
266                Err(e) if opts.try_tcp_on_error || e.is_no_connections() => {
267                    debug!("error from UDP, retrying over TCP: {}", e);
268                    Err(e)
269                }
270                result => return result,
271            };
272
273            if stream_conns.is_empty() {
274                debug!("no TCP connections available");
275                return udp_res;
276            }
277
278            // Try query over TCP, as response to query over UDP was either truncated or was an
279            // error.
280            let tcp_res = Self::try_send(opts, stream_conns, tcp_message).await;
281
282            let tcp_err = match tcp_res {
283                res @ Ok(..) => return res,
284                Err(e) => e,
285            };
286
287            // Even if the UDP result was truncated, return that
288            let udp_err = match udp_res {
289                Ok(response) => return Ok(response),
290                Err(e) => e,
291            };
292
293            match udp_err.cmp_specificity(&tcp_err) {
294                Ordering::Greater => Err(udp_err),
295                _ => Err(tcp_err),
296            }
297        }))
298    }
299}
300
301// TODO: we should be able to have a self-referential future here with Pin and not require cloned conns
302/// An async function that will loop over all the conns with a max parallel request count of ops.num_concurrent_req
303async fn parallel_conn_loop<P>(
304    mut conns: Vec<NameServer<P>>,
305    request: DnsRequest,
306    opts: ResolverOpts,
307) -> Result<DnsResponse, ResolveError>
308where
309    P: ConnectionProvider + 'static,
310{
311    let mut err = ResolveError::no_connections();
312    // If the name server we're trying is giving us backpressure by returning ProtoErrorKind::Busy,
313    // we will first try the other name servers (as for other error types). However, if the other
314    // servers are also busy, we're going to wait for a little while and then retry each server that
315    // returned Busy in the previous round. If the server is still Busy, this continues, while
316    // the backoff increases exponentially (by a factor of 2), until it hits 300ms, in which case we
317    // give up. The request might still be retried by the caller (likely the DnsRetryHandle).
318    //
319    // TODO: more principled handling of timeouts. Currently, timeouts appear to be handled mostly
320    // close to the connection, which means the top level resolution might take substantially longer
321    // to fire than the timeout configured in `ResolverOpts`.
322    let mut backoff = Duration::from_millis(20);
323    let mut busy = SmallVec::<[NameServer<P>; 2]>::new();
324
325    loop {
326        let request_cont = request.clone();
327
328        // construct the parallel requests, 2 is the default
329        let mut par_conns = SmallVec::<[NameServer<P>; 2]>::new();
330        let count = conns.len().min(opts.num_concurrent_reqs.max(1));
331
332        // Shuffe DNS NameServers to avoid overloads to the first configured ones
333        if opts.shuffle_dns_servers {
334            for _ in 0..count {
335                let idx = rng().gen_range(0..conns.len());
336
337                // UNWRAP: swap_remove has an implicit panicking bounds check. This should
338                // never fail because we check that conns is not empty and generate the idx
339                // to explicitly be in range.
340                par_conns.push(conns.swap_remove(idx));
341            }
342        } else {
343            for conn in conns.drain(..count) {
344                par_conns.push(conn);
345            }
346        }
347
348        if par_conns.is_empty() {
349            if !busy.is_empty() && backoff < Duration::from_millis(300) {
350                <<P as ConnectionProvider>::RuntimeProvider as RuntimeProvider>::Timer::delay_for(
351                    backoff,
352                )
353                .await;
354                conns.extend(busy.drain(..));
355                backoff *= 2;
356                continue;
357            }
358            return Err(err);
359        }
360
361        let mut requests = par_conns
362            .into_iter()
363            .map(move |mut conn| {
364                conn.send(request_cont.clone())
365                    .first_answer()
366                    .map(|result| result.map_err(|e| (conn, e)))
367            })
368            .collect::<FuturesUnordered<_>>();
369
370        while let Some(result) = requests.next().await {
371            let (conn, e) = match result {
372                Ok(sent) => return Ok(sent),
373                Err((conn, e)) => (conn, e),
374            };
375
376            match e.kind() {
377                ResolveErrorKind::NoRecordsFound { trusted, .. } if *trusted => {
378                    return Err(e);
379                }
380                ResolveErrorKind::Proto(e) if e.is_busy() => {
381                    busy.push(conn);
382                }
383                _ if err.cmp_specificity(&e) == Ordering::Less => {
384                    err = e;
385                }
386                _ => {}
387            }
388        }
389    }
390}
391
#[cfg(feature = "mdns")]
mod mdns {
    use super::*;

    use proto::rr::domain::usage;
    use proto::DnsHandle;

    /// Decides whether `request` has to be answered via mDNS.
    ///
    /// If any query of the request is for a name inside the zone of
    /// `usage::LOCAL`, the request is sent to `name_server` and the resulting
    /// stream is returned as [`Local::ResolveStream`]. Otherwise the request is
    /// handed back untouched as [`Local::NotMdns`].
    // NOTE(review): `NameServer` takes a single type parameter everywhere else in
    // this file, so the former `NameServer<C, P>` signature was brought in line;
    // this assumes `NameServer<P>::Response` is the boxed stream stored in
    // `Local::ResolveStream` — confirm against `name_server.rs`.
    pub(crate) fn maybe_local<P>(
        name_server: &mut NameServer<P>,
        request: DnsRequest,
    ) -> Local
    where
        P: ConnectionProvider + 'static,
    {
        if request
            .queries()
            .iter()
            .any(|query| usage::LOCAL.name().zone_of(query.name()))
        {
            Local::ResolveStream(name_server.send(request))
        } else {
            Local::NotMdns(request)
        }
    }
}
420
/// Result of deciding whether a request is answered via mDNS or must be sent
/// to the regular name servers.
// NOTE(review): the allow silences clippy's complaint about the size difference
// between the two variants — confirm that boxing `DnsRequest` is not wanted.
#[allow(clippy::large_enum_variant)]
pub(crate) enum Local {
    /// The request was handed to the mDNS name server; holds its response stream.
    // In this file the variant is only constructed inside the `mdns` module,
    // which is compiled only with the `mdns` feature.
    #[allow(dead_code)]
    ResolveStream(Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>),
    /// The request is not an mDNS query; holds the request so it can be sent on.
    NotMdns(DnsRequest),
}
427
428impl Local {
429    fn is_local(&self) -> bool {
430        matches!(*self, Self::ResolveStream(..))
431    }
432
433    /// Takes the stream
434    ///
435    /// # Panics
436    ///
437    /// Panics if this is in fact a Local::NotMdns
438    fn take_stream(self) -> Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>> {
439        match self {
440            Self::ResolveStream(future) => future,
441            _ => panic!("non Local queries have no future, see take_message()"),
442        }
443    }
444
445    /// Takes the message
446    ///
447    /// # Panics
448    ///
449    /// Panics if this is in fact a Local::ResolveStream
450    fn take_request(self) -> DnsRequest {
451        match self {
452            Self::NotMdns(request) => request,
453            _ => panic!("Local queries must be polled, see take_future()"),
454        }
455    }
456}
457
458impl Stream for Local {
459    type Item = Result<DnsResponse, ResolveError>;
460
461    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
462        match *self {
463            Self::ResolveStream(ref mut ns) => ns.as_mut().poll_next(cx),
464            // TODO: making this a panic for now
465            Self::NotMdns(..) => panic!("Local queries that are not mDNS should not be polled"), //Local::NotMdns(message) => return Err(ResolveErrorKind::Message("not mDNS")),
466        }
467    }
468}
469
// Tests for the name server pool. Both tests talk to a real name server at
// 8.8.8.8:53, so they need network access.
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
    use std::str::FromStr;

    use tokio::runtime::Runtime;

    use proto::op::Query;
    use proto::rr::{Name, RecordType};
    use proto::xfer::{DnsHandle, DnsRequestOptions};
    use trust_dns_proto::rr::RData;

    use super::*;
    use crate::config::NameServerConfig;
    use crate::config::Protocol;
    use crate::name_server::TokioRuntimeProvider;
    use crate::name_server::{GenericNameServer, TokioConnectionProvider};

    // Ignored because it makes a real connection that needs a reasonable timeout.
    //
    // Builds a pool with two UDP name servers (127.0.0.252:253 and 8.8.8.8:53)
    // and expects the first two lookups to fail and the following ten to succeed.
    #[ignore]
    #[test]
    #[allow(clippy::uninlined_format_args)]
    fn test_failed_then_success_pool() {
        // NOTE(review): presumably nothing answers on this address, making it
        // the "failing" name server — confirm.
        let config1 = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)), 253),
            protocol: Protocol::Udp,
            tls_dns_name: None,
            trust_negative_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let config2 = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            protocol: Protocol::Udp,
            tls_dns_name: None,
            trust_negative_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let mut resolver_config = ResolverConfig::new();
        resolver_config.add_name_server(config1);
        resolver_config.add_name_server(config2);

        let io_loop = Runtime::new().unwrap();
        let mut pool = GenericNameServerPool::tokio_from_config(
            &resolver_config,
            &ResolverOpts::default(),
            TokioRuntimeProvider::new(),
        );

        let name = Name::parse("www.example.com.", None).unwrap();

        // TODO: it's not clear why there are two failures before the success
        for i in 0..2 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_err(),
                "iter: {}",
                i
            );
        }

        // every subsequent lookup is expected to succeed
        for i in 0..10 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_ok(),
                "iter: {}",
                i
            );
        }
    }

    // Checks that the pool uses the very name server instances it was given:
    // after each lookup through the pool, the test's own handle to the name
    // server must report being connected.
    //
    // The pool has no datagram connections and a single TCP name server at
    // 8.8.8.8:53; the expected addresses for www.example.com are hard-coded.
    #[test]
    fn test_multi_use_conns() {
        let io_loop = Runtime::new().unwrap();
        let conn_provider = TokioConnectionProvider::default();

        let tcp = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            protocol: Protocol::Tcp,
            tls_dns_name: None,
            trust_negative_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let opts = ResolverOpts {
            try_tcp_on_error: true,
            ..ResolverOpts::default()
        };
        let ns_config = { tcp };
        let name_server = GenericNameServer::new(ns_config, opts, conn_provider);
        // kept by the test so the connection state can be inspected afterwards
        let name_servers: Arc<[_]> = Arc::from([name_server]);

        #[cfg(not(feature = "mdns"))]
        let mut pool = GenericNameServerPool::from_nameservers_test(
            &opts,
            Arc::from([]),
            Arc::clone(&name_servers),
        );
        #[cfg(feature = "mdns")]
        let mut pool = GenericNameServerPool::from_nameservers_test(
            &opts,
            Arc::from([]),
            Arc::clone(&name_servers),
            name_server::mdns_nameserver(opts, TokioConnectionProvider::default(), false),
        );

        let name = Name::from_str("www.example.com.").unwrap();

        // first lookup (A record)
        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert_eq!(
            *response.answers()[0]
                .data()
                .and_then(RData::as_a)
                .expect("no a record available"),
            Ipv4Addr::new(93, 184, 216, 34).into()
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );

        // second lookup (AAAA record)
        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name, RecordType::AAAA),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert_eq!(
            *response.answers()[0]
                .data()
                .and_then(RData::as_aaaa)
                .expect("no aaaa record available"),
            Ipv6Addr::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946).into()
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );
    }
}