1use std::cmp::Ordering;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use std::time::Duration;
13
14use futures_util::future::FutureExt;
15use futures_util::stream::{once, FuturesUnordered, Stream, StreamExt};
16use smallvec::SmallVec;
17
18use proto::xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer};
19use proto::Time;
20use tracing::debug;
21
22use rand::thread_rng as rng;
23use rand::Rng;
24
25use crate::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts, ServerOrderingStrategy};
26use crate::error::{ResolveError, ResolveErrorKind};
27#[cfg(feature = "mdns")]
28use crate::name_server;
29use crate::name_server::connection_provider::{ConnectionProvider, GenericConnector};
30use crate::name_server::name_server::NameServer;
31use crate::name_server::RuntimeProvider;
32#[cfg(test)]
33#[cfg(feature = "tokio-runtime")]
34use crate::name_server::TokioRuntimeProvider;
35
36#[derive(Clone)]
38pub struct NameServerPool<P: ConnectionProvider + Send + 'static> {
39 datagram_conns: Arc<[NameServer<P>]>, stream_conns: Arc<[NameServer<P>]>, #[cfg(feature = "mdns")]
43 mdns_conns: NameServer<P>, options: ResolverOpts,
45}
46
47pub type GenericNameServerPool<P> = NameServerPool<GenericConnector<P>>;
51
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
impl GenericNameServerPool<TokioRuntimeProvider> {
    /// Test helper: builds a pool from `config`, wrapping the tokio `runtime`
    /// in a [`GenericConnector`] and delegating to `from_config_with_provider`.
    pub(crate) fn tokio_from_config(
        config: &ResolverConfig,
        options: &ResolverOpts,
        runtime: TokioRuntimeProvider,
    ) -> Self {
        let connector = GenericConnector::new(runtime);
        Self::from_config_with_provider(config, options, connector)
    }
}
63
64impl<P> NameServerPool<P>
65where
66 P: ConnectionProvider + 'static,
67{
68 pub(crate) fn from_config_with_provider(
69 config: &ResolverConfig,
70 options: &ResolverOpts,
71 conn_provider: P,
72 ) -> Self {
73 let datagram_conns: Vec<NameServer<P>> = config
74 .name_servers()
75 .iter()
76 .filter(|ns_config| ns_config.protocol.is_datagram())
77 .map(|ns_config| {
78 #[cfg(feature = "dns-over-rustls")]
79 let ns_config = {
80 let mut ns_config = ns_config.clone();
81 ns_config.tls_config = config.client_config().clone();
82 ns_config
83 };
84 #[cfg(not(feature = "dns-over-rustls"))]
85 let ns_config = { ns_config.clone() };
86
87 NameServer::new(ns_config, *options, conn_provider.clone())
88 })
89 .collect();
90
91 let stream_conns: Vec<NameServer<P>> = config
92 .name_servers()
93 .iter()
94 .filter(|ns_config| ns_config.protocol.is_stream())
95 .map(|ns_config| {
96 #[cfg(feature = "dns-over-rustls")]
97 let ns_config = {
98 let mut ns_config = ns_config.clone();
99 ns_config.tls_config = config.client_config().clone();
100 ns_config
101 };
102 #[cfg(not(feature = "dns-over-rustls"))]
103 let ns_config = { ns_config.clone() };
104
105 NameServer::new(ns_config, *options, conn_provider.clone())
106 })
107 .collect();
108
109 Self {
110 datagram_conns: Arc::from(datagram_conns),
111 stream_conns: Arc::from(stream_conns),
112 #[cfg(feature = "mdns")]
113 mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
114 options: *options,
115 }
116 }
117
118 pub fn from_config(
120 name_servers: NameServerConfigGroup,
121 options: &ResolverOpts,
122 conn_provider: P,
123 ) -> Self {
124 let map_config_to_ns =
125 |ns_config| NameServer::new(ns_config, *options, conn_provider.clone());
126
127 let (datagram, stream): (Vec<_>, Vec<_>) = name_servers
128 .into_inner()
129 .into_iter()
130 .partition(|ns| ns.protocol.is_datagram());
131
132 let datagram_conns: Vec<_> = datagram.into_iter().map(map_config_to_ns).collect();
133 let stream_conns: Vec<_> = stream.into_iter().map(map_config_to_ns).collect();
134
135 Self {
136 datagram_conns: Arc::from(datagram_conns),
137 stream_conns: Arc::from(stream_conns),
138 #[cfg(feature = "mdns")]
139 mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
140 options: *options,
141 }
142 }
143
144 #[doc(hidden)]
145 #[cfg(not(feature = "mdns"))]
146 pub fn from_nameservers(
147 options: &ResolverOpts,
148 datagram_conns: Vec<NameServer<P>>,
149 stream_conns: Vec<NameServer<P>>,
150 ) -> Self {
151 Self {
152 datagram_conns: Arc::from(datagram_conns),
153 stream_conns: Arc::from(stream_conns),
154 options: *options,
155 }
156 }
157
158 #[doc(hidden)]
159 #[cfg(feature = "mdns")]
160 pub fn from_nameservers(
161 options: &ResolverOpts,
162 datagram_conns: Vec<NameServer<P>>,
163 stream_conns: Vec<NameServer<P>>,
164 mdns_conns: NameServer<P>,
165 ) -> Self {
166 GenericNameServerPool {
167 datagram_conns: Arc::from(datagram_conns),
168 stream_conns: Arc::from(stream_conns),
169 mdns_conns,
170 options: *options,
171 }
172 }
173
174 #[cfg(test)]
175 #[cfg(not(feature = "mdns"))]
176 #[allow(dead_code)]
177 fn from_nameservers_test(
178 options: &ResolverOpts,
179 datagram_conns: Arc<[NameServer<P>]>,
180 stream_conns: Arc<[NameServer<P>]>,
181 ) -> Self {
182 Self {
183 datagram_conns,
184 stream_conns,
185 options: *options,
186 }
187 }
188
189 #[cfg(test)]
190 #[cfg(feature = "mdns")]
191 fn from_nameservers_test(
192 options: &ResolverOpts,
193 datagram_conns: Arc<[NameServer<P>]>,
194 stream_conns: Arc<[NameServer<P>]>,
195 mdns_conns: NameServer<P>,
196 ) -> Self {
197 GenericNameServerPool {
198 datagram_conns,
199 stream_conns,
200 mdns_conns,
201 options: *options,
202 }
203 }
204
205 async fn try_send(
206 opts: ResolverOpts,
207 conns: Arc<[NameServer<P>]>,
208 request: DnsRequest,
209 ) -> Result<DnsResponse, ResolveError> {
210 let mut conns: Vec<NameServer<P>> = conns.to_vec();
211
212 match opts.server_ordering_strategy {
213 ServerOrderingStrategy::QueryStatistics => conns.sort_unstable(),
217 ServerOrderingStrategy::UserProvidedOrder => {}
218 }
219 let request_loop = request.clone();
220
221 parallel_conn_loop(conns, request_loop, opts).await
222 }
223}
224
225impl<P> DnsHandle for NameServerPool<P>
226where
227 P: ConnectionProvider + 'static,
228{
229 type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
230 type Error = ResolveError;
231
232 fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
233 let opts = self.options;
234 let request = request.into();
235 let datagram_conns = Arc::clone(&self.datagram_conns);
236 let stream_conns = Arc::clone(&self.stream_conns);
237 let tcp_message = request.clone();
239
240 #[cfg(feature = "mdns")]
242 let mdns = mdns::maybe_local(&mut self.mdns_conns, request);
243
244 #[cfg(not(feature = "mdns"))]
246 let mdns = Local::NotMdns(request);
247
248 if mdns.is_local() {
250 return mdns.take_stream();
251 }
252
253 let request = mdns.take_request();
257 Box::pin(once(async move {
258 debug!("sending request: {:?}", request.queries());
259
260 let udp_res = match Self::try_send(opts, datagram_conns, request).await {
262 Ok(response) if response.truncated() => {
263 debug!("truncated response received, retrying over TCP");
264 Ok(response)
265 }
266 Err(e) if opts.try_tcp_on_error || e.is_no_connections() => {
267 debug!("error from UDP, retrying over TCP: {}", e);
268 Err(e)
269 }
270 result => return result,
271 };
272
273 if stream_conns.is_empty() {
274 debug!("no TCP connections available");
275 return udp_res;
276 }
277
278 let tcp_res = Self::try_send(opts, stream_conns, tcp_message).await;
281
282 let tcp_err = match tcp_res {
283 res @ Ok(..) => return res,
284 Err(e) => e,
285 };
286
287 let udp_err = match udp_res {
289 Ok(response) => return Ok(response),
290 Err(e) => e,
291 };
292
293 match udp_err.cmp_specificity(&tcp_err) {
294 Ordering::Greater => Err(udp_err),
295 _ => Err(tcp_err),
296 }
297 }))
298 }
299}
300
301async fn parallel_conn_loop<P>(
304 mut conns: Vec<NameServer<P>>,
305 request: DnsRequest,
306 opts: ResolverOpts,
307) -> Result<DnsResponse, ResolveError>
308where
309 P: ConnectionProvider + 'static,
310{
311 let mut err = ResolveError::no_connections();
312 let mut backoff = Duration::from_millis(20);
323 let mut busy = SmallVec::<[NameServer<P>; 2]>::new();
324
325 loop {
326 let request_cont = request.clone();
327
328 let mut par_conns = SmallVec::<[NameServer<P>; 2]>::new();
330 let count = conns.len().min(opts.num_concurrent_reqs.max(1));
331
332 if opts.shuffle_dns_servers {
334 for _ in 0..count {
335 let idx = rng().gen_range(0..conns.len());
336
337 par_conns.push(conns.swap_remove(idx));
341 }
342 } else {
343 for conn in conns.drain(..count) {
344 par_conns.push(conn);
345 }
346 }
347
348 if par_conns.is_empty() {
349 if !busy.is_empty() && backoff < Duration::from_millis(300) {
350 <<P as ConnectionProvider>::RuntimeProvider as RuntimeProvider>::Timer::delay_for(
351 backoff,
352 )
353 .await;
354 conns.extend(busy.drain(..));
355 backoff *= 2;
356 continue;
357 }
358 return Err(err);
359 }
360
361 let mut requests = par_conns
362 .into_iter()
363 .map(move |mut conn| {
364 conn.send(request_cont.clone())
365 .first_answer()
366 .map(|result| result.map_err(|e| (conn, e)))
367 })
368 .collect::<FuturesUnordered<_>>();
369
370 while let Some(result) = requests.next().await {
371 let (conn, e) = match result {
372 Ok(sent) => return Ok(sent),
373 Err((conn, e)) => (conn, e),
374 };
375
376 match e.kind() {
377 ResolveErrorKind::NoRecordsFound { trusted, .. } if *trusted => {
378 return Err(e);
379 }
380 ResolveErrorKind::Proto(e) if e.is_busy() => {
381 busy.push(conn);
382 }
383 _ if err.cmp_specificity(&e) == Ordering::Less => {
384 err = e;
385 }
386 _ => {}
387 }
388 }
389 }
390}
391
#[cfg(feature = "mdns")]
mod mdns {
    use super::*;

    use proto::rr::domain::usage;
    use proto::DnsHandle;

    /// Sends `request` to the mDNS `name_server` when any of its queries falls
    /// within the zone of `usage::LOCAL`; otherwise returns the request untouched.
    ///
    /// `NameServer` takes a single connection-provider parameter, as used
    /// throughout this file; the bound matches the one `parallel_conn_loop`
    /// relies on when calling `send`.
    pub(crate) fn maybe_local<P>(name_server: &mut NameServer<P>, request: DnsRequest) -> Local
    where
        P: ConnectionProvider + 'static,
    {
        let is_local = request
            .queries()
            .iter()
            .any(|query| usage::LOCAL.name().zone_of(query.name()));

        if is_local {
            Local::ResolveStream(name_server.send(request))
        } else {
            Local::NotMdns(request)
        }
    }
}
420
421#[allow(clippy::large_enum_variant)]
422pub(crate) enum Local {
423 #[allow(dead_code)]
424 ResolveStream(Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>),
425 NotMdns(DnsRequest),
426}
427
428impl Local {
429 fn is_local(&self) -> bool {
430 matches!(*self, Self::ResolveStream(..))
431 }
432
433 fn take_stream(self) -> Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>> {
439 match self {
440 Self::ResolveStream(future) => future,
441 _ => panic!("non Local queries have no future, see take_message()"),
442 }
443 }
444
445 fn take_request(self) -> DnsRequest {
451 match self {
452 Self::NotMdns(request) => request,
453 _ => panic!("Local queries must be polled, see take_future()"),
454 }
455 }
456}
457
458impl Stream for Local {
459 type Item = Result<DnsResponse, ResolveError>;
460
461 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
462 match *self {
463 Self::ResolveStream(ref mut ns) => ns.as_mut().poll_next(cx),
464 Self::NotMdns(..) => panic!("Local queries that are not mDNS should not be polled"), }
467 }
468}
469
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::str::FromStr;

    use tokio::runtime::Runtime;

    use proto::op::Query;
    use proto::rr::{Name, RData, RecordType};
    use proto::xfer::{DnsHandle, DnsRequestOptions};

    use super::*;
    use crate::config::NameServerConfig;
    use crate::config::Protocol;
    use crate::name_server::TokioRuntimeProvider;
    use crate::name_server::{GenericNameServer, TokioConnectionProvider};

    /// Pool with one unreachable UDP server and one public resolver.
    #[ignore]
    #[test]
    fn test_failed_then_success_pool() {
        let config1 = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)), 253),
            protocol: Protocol::Udp,
            tls_dns_name: None,
            trust_negative_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let config2 = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            protocol: Protocol::Udp,
            tls_dns_name: None,
            trust_negative_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let mut resolver_config = ResolverConfig::new();
        resolver_config.add_name_server(config1);
        resolver_config.add_name_server(config2);

        let io_loop = Runtime::new().unwrap();
        let mut pool = GenericNameServerPool::tokio_from_config(
            &resolver_config,
            &ResolverOpts::default(),
            TokioRuntimeProvider::new(),
        );

        let name = Name::parse("www.example.com.", None).unwrap();

        for i in 0..2 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_err(),
                "iter: {i}"
            );
        }

        for i in 0..10 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_ok(),
                "iter: {i}"
            );
        }
    }

    /// The pool must reuse the `NameServer`s it was given, so a connection made
    /// by one lookup is visible through the caller's handle.
    #[test]
    fn test_multi_use_conns() {
        let io_loop = Runtime::new().unwrap();
        let conn_provider = TokioConnectionProvider::default();

        let tcp = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            protocol: Protocol::Tcp,
            tls_dns_name: None,
            trust_negative_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let opts = ResolverOpts {
            try_tcp_on_error: true,
            ..ResolverOpts::default()
        };
        let name_server = GenericNameServer::new(tcp, opts, conn_provider);
        let name_servers: Arc<[_]> = Arc::from([name_server]);

        #[cfg(not(feature = "mdns"))]
        let mut pool = GenericNameServerPool::from_nameservers_test(
            &opts,
            Arc::from([]),
            Arc::clone(&name_servers),
        );
        #[cfg(feature = "mdns")]
        let mut pool = GenericNameServerPool::from_nameservers_test(
            &opts,
            Arc::from([]),
            Arc::clone(&name_servers),
            name_server::mdns_nameserver(opts, TokioConnectionProvider::default(), false),
        );

        let name = Name::from_str("www.example.com.").unwrap();

        // example.com's published addresses change over time, so only require
        // that a record of the requested type comes back rather than pinning IPs.
        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().and_then(RData::as_a).is_some()),
            "no a record available"
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );

        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name, RecordType::AAAA),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().and_then(RData::as_aaaa).is_some()),
            "no aaaa record available"
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );
    }
}