1use std::error::Error;
13use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use std::time::Instant;
18
19use futures_util::{future, future::Either, future::Future, FutureExt};
20
21use proto::error::ProtoError;
22use proto::op::Query;
23use proto::rr::{Name, RData, Record, RecordType};
24use proto::xfer::{DnsHandle, DnsRequestOptions};
25use tracing::debug;
26
27use crate::caching_client::CachingClient;
28use crate::config::LookupIpStrategy;
29use crate::dns_lru::MAX_TTL;
30use crate::error::*;
31use crate::hosts::Hosts;
32use crate::lookup::{Lookup, LookupIntoIter, LookupIter};
33
34#[derive(Debug, Clone)]
38pub struct LookupIp(Lookup);
39
40impl LookupIp {
41 pub fn iter(&self) -> LookupIpIter<'_> {
43 LookupIpIter(self.0.iter())
44 }
45
46 pub fn query(&self) -> &Query {
48 self.0.query()
49 }
50
51 pub fn valid_until(&self) -> Instant {
53 self.0.valid_until()
54 }
55
56 pub fn as_lookup(&self) -> &Lookup {
60 &self.0
61 }
62}
63
64impl From<Lookup> for LookupIp {
65 fn from(lookup: Lookup) -> Self {
66 Self(lookup)
67 }
68}
69
70impl From<LookupIp> for Lookup {
71 fn from(lookup: LookupIp) -> Self {
72 lookup.0
73 }
74}
75
76pub struct LookupIpIter<'i>(pub(crate) LookupIter<'i>);
78
79impl<'i> Iterator for LookupIpIter<'i> {
80 type Item = IpAddr;
81
82 fn next(&mut self) -> Option<Self::Item> {
83 let iter: &mut _ = &mut self.0;
84 iter.filter_map(|rdata| match *rdata {
85 RData::A(ip) => Some(IpAddr::from(Ipv4Addr::from(ip))),
86 RData::AAAA(ip) => Some(IpAddr::from(Ipv6Addr::from(ip))),
87 _ => None,
88 })
89 .next()
90 }
91}
92
93impl IntoIterator for LookupIp {
94 type Item = IpAddr;
95 type IntoIter = LookupIpIntoIter;
96
97 fn into_iter(self) -> Self::IntoIter {
100 LookupIpIntoIter(self.0.into_iter())
101 }
102}
103
104pub struct LookupIpIntoIter(LookupIntoIter);
106
107impl Iterator for LookupIpIntoIter {
108 type Item = IpAddr;
109
110 fn next(&mut self) -> Option<Self::Item> {
111 let iter: &mut _ = &mut self.0;
112 iter.filter_map(|rdata| match rdata {
113 RData::A(ip) => Some(IpAddr::from(Ipv4Addr::from(ip))),
114 RData::AAAA(ip) => Some(IpAddr::from(Ipv6Addr::from(ip))),
115 _ => None,
116 })
117 .next()
118 }
119}
120
121pub struct LookupIpFuture<C, E>
125where
126 C: DnsHandle<Error = E> + 'static,
127 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
128{
129 client_cache: CachingClient<C, E>,
130 names: Vec<Name>,
131 strategy: LookupIpStrategy,
132 options: DnsRequestOptions,
133 query: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>>,
134 hosts: Option<Arc<Hosts>>,
135 finally_ip_addr: Option<RData>,
136}
137
138impl<C, E> Future for LookupIpFuture<C, E>
139where
140 C: DnsHandle<Error = E> + 'static,
141 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
142{
143 type Output = Result<LookupIp, ResolveError>;
144
145 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
146 loop {
147 let query = self.query.as_mut().poll(cx);
149
150 let should_retry = match query {
152 Poll::Pending => return Poll::Pending,
154 Poll::Ready(Ok(ref lookup)) => lookup.is_empty(),
158 Poll::Ready(Err(_)) => true,
160 };
161
162 if should_retry {
163 if let Some(name) = self.names.pop() {
164 self.query = strategic_lookup(
167 name,
168 self.strategy,
169 self.client_cache.clone(),
170 self.options,
171 self.hosts.clone(),
172 )
173 .boxed();
174 continue;
177 } else if let Some(ip_addr) = self.finally_ip_addr.take() {
178 let record = Record::from_rdata(Name::new(), MAX_TTL, ip_addr);
181 let lookup = Lookup::new_with_max_ttl(Query::new(), Arc::from([record]));
182 return Poll::Ready(Ok(lookup.into()));
183 }
184 };
185
186 return query.map(|f| f.map(LookupIp::from));
190 }
195 }
196}
197
198impl<C, E> LookupIpFuture<C, E>
199where
200 C: DnsHandle<Error = E> + 'static,
201 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
202{
203 pub fn lookup(
211 names: Vec<Name>,
212 strategy: LookupIpStrategy,
213 client_cache: CachingClient<C, E>,
214 options: DnsRequestOptions,
215 hosts: Option<Arc<Hosts>>,
216 finally_ip_addr: Option<RData>,
217 ) -> Self {
218 let empty =
219 ResolveError::from(ResolveErrorKind::Message("can not lookup IPs for no names"));
220 Self {
221 names,
222 strategy,
223 client_cache,
224 query: future::err(empty).boxed(),
227 options,
228 hosts,
229 finally_ip_addr,
230 }
231 }
232}
233
234async fn strategic_lookup<C, E>(
236 name: Name,
237 strategy: LookupIpStrategy,
238 client: CachingClient<C, E>,
239 options: DnsRequestOptions,
240 hosts: Option<Arc<Hosts>>,
241) -> Result<Lookup, ResolveError>
242where
243 C: DnsHandle<Error = E> + 'static,
244 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
245{
246 match strategy {
247 LookupIpStrategy::Ipv4Only => ipv4_only(name, client, options, hosts).await,
248 LookupIpStrategy::Ipv6Only => ipv6_only(name, client, options, hosts).await,
249 LookupIpStrategy::Ipv4AndIpv6 => ipv4_and_ipv6(name, client, options, hosts).await,
250 LookupIpStrategy::Ipv6thenIpv4 => ipv6_then_ipv4(name, client, options, hosts).await,
251 LookupIpStrategy::Ipv4thenIpv6 => ipv4_then_ipv6(name, client, options, hosts).await,
252 }
253}
254
255async fn hosts_lookup<C, E>(
257 query: Query,
258 mut client: CachingClient<C, E>,
259 options: DnsRequestOptions,
260 hosts: Option<Arc<Hosts>>,
261) -> Result<Lookup, ResolveError>
262where
263 C: DnsHandle<Error = E> + 'static,
264 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
265{
266 if let Some(hosts) = hosts {
267 if let Some(lookup) = hosts.lookup_static_host(&query) {
268 return Ok(lookup);
269 };
270 }
271
272 client.lookup(query, options).await
273}
274
275async fn ipv4_only<C, E>(
277 name: Name,
278 client: CachingClient<C, E>,
279 options: DnsRequestOptions,
280 hosts: Option<Arc<Hosts>>,
281) -> Result<Lookup, ResolveError>
282where
283 C: DnsHandle<Error = E> + 'static,
284 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
285{
286 hosts_lookup(Query::query(name, RecordType::A), client, options, hosts).await
287}
288
289async fn ipv6_only<C, E>(
291 name: Name,
292 client: CachingClient<C, E>,
293 options: DnsRequestOptions,
294 hosts: Option<Arc<Hosts>>,
295) -> Result<Lookup, ResolveError>
296where
297 C: DnsHandle<Error = E> + 'static,
298 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
299{
300 hosts_lookup(Query::query(name, RecordType::AAAA), client, options, hosts).await
301}
302
303async fn ipv4_and_ipv6<C, E>(
306 name: Name,
307 client: CachingClient<C, E>,
308 options: DnsRequestOptions,
309 hosts: Option<Arc<Hosts>>,
310) -> Result<Lookup, ResolveError>
311where
312 C: DnsHandle<Error = E> + 'static,
313 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
314{
315 let sel_res = future::select(
316 hosts_lookup(
317 Query::query(name.clone(), RecordType::A),
318 client.clone(),
319 options,
320 hosts.clone(),
321 )
322 .boxed(),
323 hosts_lookup(Query::query(name, RecordType::AAAA), client, options, hosts).boxed(),
324 )
325 .await;
326
327 let (ips, remaining_query) = match sel_res {
328 Either::Left(ips_and_remaining) => ips_and_remaining,
329 Either::Right(ips_and_remaining) => ips_and_remaining,
330 };
331
332 let next_ips = remaining_query.await;
333
334 match (ips, next_ips) {
335 (Ok(ips), Ok(next_ips)) => {
336 let ips = ips.append(next_ips);
338 Ok(ips)
339 }
340 (Ok(ips), Err(e)) | (Err(e), Ok(ips)) => {
341 debug!(
342 "one of ipv4 or ipv6 lookup failed in ipv4_and_ipv6 strategy: {}",
343 e
344 );
345 Ok(ips)
346 }
347 (Err(e1), Err(e2)) => {
348 debug!(
349 "both of ipv4 or ipv6 lookup failed in ipv4_and_ipv6 strategy e1: {}, e2: {}",
350 e1, e2
351 );
352 Err(e1)
353 }
354 }
355}
356
357async fn ipv6_then_ipv4<C, E>(
359 name: Name,
360 client: CachingClient<C, E>,
361 options: DnsRequestOptions,
362 hosts: Option<Arc<Hosts>>,
363) -> Result<Lookup, ResolveError>
364where
365 C: DnsHandle<Error = E> + 'static,
366 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
367{
368 rt_then_swap(
369 name,
370 client,
371 RecordType::AAAA,
372 RecordType::A,
373 options,
374 hosts,
375 )
376 .await
377}
378
379async fn ipv4_then_ipv6<C, E>(
381 name: Name,
382 client: CachingClient<C, E>,
383 options: DnsRequestOptions,
384 hosts: Option<Arc<Hosts>>,
385) -> Result<Lookup, ResolveError>
386where
387 C: DnsHandle<Error = E> + 'static,
388 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
389{
390 rt_then_swap(
391 name,
392 client,
393 RecordType::A,
394 RecordType::AAAA,
395 options,
396 hosts,
397 )
398 .await
399}
400
401async fn rt_then_swap<C, E>(
403 name: Name,
404 client: CachingClient<C, E>,
405 first_type: RecordType,
406 second_type: RecordType,
407 options: DnsRequestOptions,
408 hosts: Option<Arc<Hosts>>,
409) -> Result<Lookup, ResolveError>
410where
411 C: DnsHandle<Error = E> + 'static,
412 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
413{
414 let or_client = client.clone();
415 let res = hosts_lookup(
416 Query::query(name.clone(), first_type),
417 client,
418 options,
419 hosts.clone(),
420 )
421 .await;
422
423 match res {
424 Ok(ips) => {
425 if ips.is_empty() {
426 hosts_lookup(
428 Query::query(name.clone(), second_type),
429 or_client,
430 options,
431 hosts,
432 )
433 .await
434 } else {
435 Ok(ips)
436 }
437 }
438 Err(_) => {
439 hosts_lookup(
440 Query::query(name.clone(), second_type),
441 or_client,
442 options,
443 hosts,
444 )
445 .await
446 }
447 }
448}
449
#[cfg(test)]
pub mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    use std::sync::{Arc, Mutex};

    use futures_executor::block_on;
    use futures_util::future;

    use proto::op::Message;
    use proto::rr::{Name, RData, Record};
    use proto::xfer::{DnsHandle, DnsRequest, DnsResponse};

    use futures_util::stream::{once, Stream};

    use super::*;
    use crate::error::ResolveError;

    /// `127.0.0.1`, the address carried by [`v4_message`].
    const LOCALHOST_V4: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
    /// `::1`, the address carried by [`v6_message`].
    const LOCALHOST_V6: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));

    /// Test double for a DNS connection that replays canned responses.
    ///
    /// Each `send` pops the last entry of `messages`. Once the list is
    /// exhausted, it answers with [`empty`].
    #[derive(Clone)]
    pub struct MockDnsHandle {
        messages: Arc<Mutex<Vec<Result<DnsResponse, ResolveError>>>>,
    }

    impl DnsHandle for MockDnsHandle {
        type Response =
            Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send + Unpin>>;
        type Error = ResolveError;

        fn send<R: Into<DnsRequest>>(&mut self, _: R) -> Self::Response {
            let reply = self.messages.lock().unwrap().pop().unwrap_or_else(empty);
            Box::pin(once(future::ready(reply)))
        }
    }

    /// Builds a response for the root name carrying a single answer record.
    fn single_answer(
        record_type: RecordType,
        rdata: RData,
    ) -> Result<DnsResponse, ResolveError> {
        let mut message = Message::new();
        message.add_query(Query::query(Name::root(), record_type));
        message.insert_answers(vec![Record::from_rdata(Name::root(), 86400, rdata)]);

        let resp = DnsResponse::from_message(message).unwrap();
        assert!(resp.contains_answer());
        Ok(resp)
    }

    /// Response answering an `A` query for the root name with `127.0.0.1`.
    pub fn v4_message() -> Result<DnsResponse, ResolveError> {
        single_answer(RecordType::A, RData::A(Ipv4Addr::new(127, 0, 0, 1).into()))
    }

    /// Response answering an `AAAA` query for the root name with `::1`.
    pub fn v6_message() -> Result<DnsResponse, ResolveError> {
        single_answer(
            RecordType::AAAA,
            RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()),
        )
    }

    /// Response built from a fresh, answerless message.
    pub fn empty() -> Result<DnsResponse, ResolveError> {
        Ok(DnsResponse::from_message(Message::new()).unwrap())
    }

    /// A forced failure.
    pub fn error() -> Result<DnsResponse, ResolveError> {
        Err(ResolveError::from("forced test failure"))
    }

    /// Creates a mock handle that replays `messages` from last to first.
    pub fn mock(messages: Vec<Result<DnsResponse, ResolveError>>) -> MockDnsHandle {
        let messages = Arc::new(Mutex::new(messages));
        MockDnsHandle { messages }
    }

    /// Wraps `mock(messages)` in a `CachingClient` configured as every test
    /// in this module uses it.
    fn mock_client(
        messages: Vec<Result<DnsResponse, ResolveError>>,
    ) -> CachingClient<MockDnsHandle, ResolveError> {
        CachingClient::new(0, mock(messages), false)
    }

    /// Unwraps a lookup result and collects its record data as IP addresses.
    fn collect_ips(result: Result<Lookup, ResolveError>) -> Vec<IpAddr> {
        result
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect()
    }

    #[test]
    fn test_ipv4_only_strategy() {
        let found = collect_ips(block_on(ipv4_only(
            Name::root(),
            mock_client(vec![v4_message()]),
            DnsRequestOptions::default(),
            None,
        )));
        assert_eq!(found, vec![LOCALHOST_V4]);
    }

    #[test]
    fn test_ipv6_only_strategy() {
        let found = collect_ips(block_on(ipv6_only(
            Name::root(),
            mock_client(vec![v6_message()]),
            DnsRequestOptions::default(),
            None,
        )));
        assert_eq!(found, vec![LOCALHOST_V6]);
    }

    #[test]
    fn test_ipv4_and_ipv6_strategy() {
        // The mock pops from the back, so the last response feeds the A query.
        let cases = vec![
            (vec![v6_message(), v4_message()], vec![LOCALHOST_V4, LOCALHOST_V6]),
            (vec![empty(), v4_message()], vec![LOCALHOST_V4]),
            (vec![error(), v4_message()], vec![LOCALHOST_V4]),
            (vec![v6_message(), empty()], vec![LOCALHOST_V6]),
            (vec![v6_message(), error()], vec![LOCALHOST_V6]),
        ];

        for (responses, expected) in cases {
            let found = collect_ips(block_on(ipv4_and_ipv6(
                Name::root(),
                mock_client(responses),
                DnsRequestOptions::default(),
                None,
            )));
            assert_eq!(found, expected);
        }
    }

    #[test]
    fn test_ipv6_then_ipv4_strategy() {
        let cases = vec![
            (vec![v6_message()], vec![LOCALHOST_V6]),
            (vec![v4_message(), empty()], vec![LOCALHOST_V4]),
            (vec![v4_message(), error()], vec![LOCALHOST_V4]),
        ];

        for (responses, expected) in cases {
            let found = collect_ips(block_on(ipv6_then_ipv4(
                Name::root(),
                mock_client(responses),
                DnsRequestOptions::default(),
                None,
            )));
            assert_eq!(found, expected);
        }
    }

    #[test]
    fn test_ipv4_then_ipv6_strategy() {
        let cases = vec![
            (vec![v4_message()], vec![LOCALHOST_V4]),
            (vec![v6_message(), empty()], vec![LOCALHOST_V6]),
            (vec![v6_message(), error()], vec![LOCALHOST_V6]),
        ];

        for (responses, expected) in cases {
            let found = collect_ips(block_on(ipv4_then_ipv6(
                Name::root(),
                mock_client(responses),
                DnsRequestOptions::default(),
                None,
            )));
            assert_eq!(found, expected);
        }
    }
}