1use std::{
11 cmp::min,
12 error::Error,
13 pin::Pin,
14 slice::Iter,
15 sync::Arc,
16 task::{Context, Poll},
17 time::{Duration, Instant},
18};
19
20use futures_util::{
21 future::{self, Future},
22 stream::Stream,
23 FutureExt,
24};
25
26use crate::{
27 caching_client::CachingClient,
28 dns_lru::MAX_TTL,
29 error::*,
30 lookup_ip::LookupIpIter,
31 name_server::{ConnectionProvider, NameServerPool},
32 proto::{
33 error::ProtoError,
34 op::Query,
35 rr::{
36 rdata::{self, A, AAAA, NS, PTR},
37 Name, RData, Record, RecordType,
38 },
39 xfer::{DnsRequest, DnsRequestOptions, DnsResponse},
40 DnsHandle, RetryDnsHandle,
41 },
42};
43
44#[cfg(feature = "dnssec")]
45use proto::DnssecDnsHandle;
46
/// Result of a DNS lookup: the query that was asked, the records held for it,
/// and the instant reported by [`Lookup::valid_until`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Lookup {
    /// The query this lookup belongs to.
    query: Query,
    /// Records stored for the query; the `Arc` lets clones share one allocation.
    records: Arc<[Record]>,
    /// Deadline associated with the records (set by the constructors).
    valid_until: Instant,
}
56
57impl Lookup {
58 pub fn from_rdata(query: Query, rdata: RData) -> Self {
60 let record = Record::from_rdata(query.name().clone(), MAX_TTL, rdata);
61 Self::new_with_max_ttl(query, Arc::from([record]))
62 }
63
64 pub fn new_with_max_ttl(query: Query, records: Arc<[Record]>) -> Self {
66 let valid_until = Instant::now() + Duration::from_secs(u64::from(MAX_TTL));
67 Self {
68 query,
69 records,
70 valid_until,
71 }
72 }
73
74 pub fn new_with_deadline(query: Query, records: Arc<[Record]>, valid_until: Instant) -> Self {
76 Self {
77 query,
78 records,
79 valid_until,
80 }
81 }
82
83 pub fn query(&self) -> &Query {
85 &self.query
86 }
87
88 pub fn iter(&self) -> LookupIter<'_> {
90 LookupIter(self.records.iter())
91 }
92
93 pub fn record_iter(&self) -> LookupRecordIter<'_> {
95 LookupRecordIter(self.records.iter())
96 }
97
98 pub fn valid_until(&self) -> Instant {
100 self.valid_until
101 }
102
103 #[doc(hidden)]
104 pub fn is_empty(&self) -> bool {
105 self.records.is_empty()
106 }
107
108 pub(crate) fn len(&self) -> usize {
109 self.records.len()
110 }
111
112 pub fn records(&self) -> &[Record] {
114 self.records.as_ref()
115 }
116
117 pub(crate) fn append(&self, other: Self) -> Self {
119 let mut records = Vec::with_capacity(self.len() + other.len());
120 records.extend_from_slice(&self.records);
121 records.extend_from_slice(&other.records);
122
123 let valid_until = min(self.valid_until(), other.valid_until());
125 Self::new_with_deadline(self.query.clone(), Arc::from(records), valid_until)
126 }
127}
128
129pub struct LookupIter<'a>(Iter<'a, Record>);
131
132impl<'a> Iterator for LookupIter<'a> {
133 type Item = &'a RData;
134
135 fn next(&mut self) -> Option<Self::Item> {
136 self.0.next().and_then(Record::data)
137 }
138}
139
140pub struct LookupRecordIter<'a>(Iter<'a, Record>);
142
143impl<'a> Iterator for LookupRecordIter<'a> {
144 type Item = &'a Record;
145
146 fn next(&mut self) -> Option<Self::Item> {
147 self.0.next()
148 }
149}
150
151impl IntoIterator for Lookup {
153 type Item = RData;
154 type IntoIter = LookupIntoIter;
155
156 fn into_iter(self) -> Self::IntoIter {
159 LookupIntoIter {
160 records: Arc::clone(&self.records),
161 index: 0,
162 }
163 }
164}
165
166pub struct LookupIntoIter {
170 records: Arc<[Record]>,
172 index: usize,
173}
174
175impl Iterator for LookupIntoIter {
176 type Item = RData;
177
178 fn next(&mut self) -> Option<Self::Item> {
179 let rdata = self.records.get(self.index).and_then(Record::data);
180 self.index += 1;
181 rdata.cloned()
182 }
183}
184
/// A DNS handle that is either a retrying handle over a name-server pool or,
/// when the `dnssec` feature is enabled, a DNSSEC handle wrapping that same
/// kind of retrying handle.
#[derive(Clone)]
#[doc(hidden)]
pub enum LookupEither<P: ConnectionProvider + Send> {
    /// Retrying handle over the name-server pool.
    Retry(RetryDnsHandle<NameServerPool<P>>),
    /// DNSSEC handle layered on top of the retrying handle (`dnssec` feature only).
    #[cfg(feature = "dnssec")]
    #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))]
    Secure(DnssecDnsHandle<RetryDnsHandle<NameServerPool<P>>>),
}
194
195impl<P: ConnectionProvider> DnsHandle for LookupEither<P> {
196 type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
197 type Error = ResolveError;
198
199 fn is_verifying_dnssec(&self) -> bool {
200 match *self {
201 Self::Retry(ref c) => c.is_verifying_dnssec(),
202 #[cfg(feature = "dnssec")]
203 Self::Secure(ref c) => c.is_verifying_dnssec(),
204 }
205 }
206
207 fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
208 match *self {
209 Self::Retry(ref mut c) => c.send(request),
210 #[cfg(feature = "dnssec")]
211 Self::Secure(ref mut c) => c.send(request),
212 }
213 }
214}
215
/// Future that tries a list of candidate names, one query at a time, until one
/// of them produces a non-empty [`Lookup`] or the candidates run out.
#[doc(hidden)]
pub struct LookupFuture<C, E>
where
    C: DnsHandle<Error = E> + 'static,
    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    /// Client through which every query is issued.
    client_cache: CachingClient<C, E>,
    /// Names not yet tried; they are taken from the back with `pop()`.
    names: Vec<Name>,
    /// Record type requested for every candidate name.
    record_type: RecordType,
    /// Request options passed along with every query.
    options: DnsRequestOptions,
    /// The query currently in flight.
    query: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>>,
}
229
230impl<C, E> LookupFuture<C, E>
231where
232 C: DnsHandle<Error = E> + 'static,
233 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
234{
235 #[doc(hidden)]
243 pub fn lookup(
244 mut names: Vec<Name>,
245 record_type: RecordType,
246 options: DnsRequestOptions,
247 mut client_cache: CachingClient<C, E>,
248 ) -> Self {
249 let name = names.pop().ok_or_else(|| {
250 ResolveError::from(ResolveErrorKind::Message("can not lookup for no names"))
251 });
252
253 let query: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> = match name {
254 Ok(name) => client_cache
255 .lookup(Query::query(name, record_type), options)
256 .boxed(),
257 Err(err) => future::err(err).boxed(),
258 };
259
260 Self {
261 client_cache,
262 names,
263 record_type,
264 options,
265 query,
266 }
267 }
268}
269
270impl<C, E> Future for LookupFuture<C, E>
271where
272 C: DnsHandle<Error = E> + 'static,
273 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
274{
275 type Output = Result<Lookup, ResolveError>;
276
277 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
278 loop {
279 let query = self.query.as_mut().poll_unpin(cx);
281
282 let should_retry = match query {
284 Poll::Pending => return Poll::Pending,
286 Poll::Ready(Ok(ref lookup)) => lookup.records.len() == 0,
290 Poll::Ready(Err(_)) => true,
292 };
293
294 if should_retry {
295 if let Some(name) = self.names.pop() {
296 let record_type = self.record_type;
297 let options = self.options;
298
299 self.query = self
302 .client_cache
303 .lookup(Query::query(name, record_type), options);
304 continue;
307 }
308 }
309 return query;
313 }
318 }
319}
320
/// Typed wrapper around a [`Lookup`] whose iterators yield only SRV data.
#[derive(Debug, Clone)]
pub struct SrvLookup(Lookup);
324
325impl SrvLookup {
326 pub fn iter(&self) -> SrvLookupIter<'_> {
328 SrvLookupIter(self.0.iter())
329 }
330
331 pub fn query(&self) -> &Query {
333 self.0.query()
334 }
335
336 pub fn ip_iter(&self) -> LookupIpIter<'_> {
340 LookupIpIter(self.0.iter())
341 }
342
343 pub fn as_lookup(&self) -> &Lookup {
347 &self.0
348 }
349}
350
351impl From<Lookup> for SrvLookup {
352 fn from(lookup: Lookup) -> Self {
353 Self(lookup)
354 }
355}
356
357pub struct SrvLookupIter<'i>(LookupIter<'i>);
359
360impl<'i> Iterator for SrvLookupIter<'i> {
361 type Item = &'i rdata::SRV;
362
363 fn next(&mut self) -> Option<Self::Item> {
364 let iter: &mut _ = &mut self.0;
365 iter.filter_map(|rdata| match *rdata {
366 RData::SRV(ref data) => Some(data),
367 _ => None,
368 })
369 .next()
370 }
371}
372
373impl IntoIterator for SrvLookup {
374 type Item = rdata::SRV;
375 type IntoIter = SrvLookupIntoIter;
376
377 fn into_iter(self) -> Self::IntoIter {
380 SrvLookupIntoIter(self.0.into_iter())
381 }
382}
383
384pub struct SrvLookupIntoIter(LookupIntoIter);
386
387impl Iterator for SrvLookupIntoIter {
388 type Item = rdata::SRV;
389
390 fn next(&mut self) -> Option<Self::Item> {
391 let iter: &mut _ = &mut self.0;
392 iter.filter_map(|rdata| match rdata {
393 RData::SRV(data) => Some(data),
394 _ => None,
395 })
396 .next()
397 }
398}
399
// Generates a typed wrapper around `Lookup` together with its borrowing and
// owning iterators.
//
// Arguments:
//   $l  - name of the wrapper type
//   $i  - name of the borrowing iterator
//   $ii - name of the owning iterator
//   $r  - `RData` variant path selecting the entries to yield
//   $t  - payload type carried by that variant
macro_rules! lookup_type {
    ($l:ident, $i:ident, $ii:ident, $r:path, $t:path) => {
        /// Typed wrapper around a [`Lookup`]; its iterators yield only the
        /// entries of one record-data variant.
        #[derive(Debug, Clone)]
        pub struct $l(Lookup);

        impl $l {
            /// Returns a borrowing iterator over the matching entries.
            pub fn iter(&self) -> $i<'_> {
                $i(self.as_lookup().iter())
            }

            /// Returns the query of the wrapped lookup.
            pub fn query(&self) -> &Query {
                self.as_lookup().query()
            }

            /// Returns the deadline of the wrapped lookup.
            pub fn valid_until(&self) -> Instant {
                self.as_lookup().valid_until()
            }

            /// Returns a reference to the wrapped, untyped lookup.
            pub fn as_lookup(&self) -> &Lookup {
                let Self(lookup) = self;
                lookup
            }
        }

        impl From<Lookup> for $l {
            /// Wraps `lookup` without inspecting or altering its records.
            fn from(lookup: Lookup) -> Self {
                Self(lookup)
            }
        }

        impl From<$l> for Lookup {
            /// Unwraps the typed wrapper back into the untyped lookup.
            fn from(typed: $l) -> Self {
                let $l(lookup) = typed;
                lookup
            }
        }

        /// Borrowing iterator over the matching entries of a lookup.
        pub struct $i<'i>(LookupIter<'i>);

        impl<'i> Iterator for $i<'i> {
            type Item = &'i $t;

            /// Advances the wrapped iterator until it produces a matching
            /// entry, or returns `None` when the wrapped iterator does.
            fn next(&mut self) -> Option<Self::Item> {
                self.0.find_map(|rdata| match rdata {
                    $r(data) => Some(data),
                    _ => None,
                })
            }
        }

        impl IntoIterator for $l {
            type Item = $t;
            type IntoIter = $ii;

            /// Consumes the wrapper and returns an iterator over owned entries.
            fn into_iter(self) -> Self::IntoIter {
                let $l(lookup) = self;
                $ii(lookup.into_iter())
            }
        }

        /// Owning iterator over the matching entries of a lookup.
        pub struct $ii(LookupIntoIter);

        impl Iterator for $ii {
            type Item = $t;

            /// Advances the wrapped iterator until it produces a matching
            /// entry, or returns `None` when the wrapped iterator does.
            fn next(&mut self) -> Option<Self::Item> {
                self.0.find_map(|rdata| match rdata {
                    $r(data) => Some(data),
                    _ => None,
                })
            }
        }
    };
}
487
// Typed lookup wrappers, one per supported record-data variant. Each
// invocation produces the wrapper type plus its borrowing and owning iterators.

// PTR data.
lookup_type!(
    ReverseLookup,
    ReverseLookupIter,
    ReverseLookupIntoIter,
    RData::PTR,
    PTR
);
// A data.
lookup_type!(Ipv4Lookup, Ipv4LookupIter, Ipv4LookupIntoIter, RData::A, A);
// AAAA data.
lookup_type!(
    Ipv6Lookup,
    Ipv6LookupIter,
    Ipv6LookupIntoIter,
    RData::AAAA,
    AAAA
);
// MX data.
lookup_type!(
    MxLookup,
    MxLookupIter,
    MxLookupIntoIter,
    RData::MX,
    rdata::MX
);
// TLSA data.
lookup_type!(
    TlsaLookup,
    TlsaLookupIter,
    TlsaLookupIntoIter,
    RData::TLSA,
    rdata::TLSA
);
// TXT data.
lookup_type!(
    TxtLookup,
    TxtLookupIter,
    TxtLookupIntoIter,
    RData::TXT,
    rdata::TXT
);
// SOA data.
lookup_type!(
    SoaLookup,
    SoaLookupIter,
    SoaLookupIntoIter,
    RData::SOA,
    rdata::SOA
);
// NS data.
lookup_type!(NsLookup, NsLookupIter, NsLookupIntoIter, RData::NS, NS);
533
#[cfg(test)]
pub mod tests {
    use std::net::{IpAddr, Ipv4Addr};
    use std::str::FromStr;
    use std::sync::{Arc, Mutex};

    use futures_executor::block_on;
    use futures_util::future;
    use futures_util::stream::once;

    use proto::op::{Message, Query};
    use proto::rr::{Name, RData, Record, RecordType};
    use proto::xfer::{DnsRequest, DnsRequestOptions};

    use super::*;
    use crate::error::ResolveError;

    /// Handle that replays canned responses instead of talking to a server.
    #[derive(Clone)]
    pub struct MockDnsHandle {
        /// Canned responses; each `send` takes one from the back.
        messages: Arc<Mutex<Vec<Result<DnsResponse, ResolveError>>>>,
    }

    impl DnsHandle for MockDnsHandle {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
        type Error = ResolveError;

        /// Ignores the request and answers with the next canned response,
        /// or with `empty()` once none are left.
        fn send<R: Into<DnsRequest>>(&mut self, _: R) -> Self::Response {
            let canned = self.messages.lock().unwrap().pop().unwrap_or_else(empty);
            Box::pin(once(future::ready(canned).boxed()))
        }
    }

    /// Response to a root `A` query with a single answer of 127.0.0.1.
    pub fn v4_message() -> Result<DnsResponse, ResolveError> {
        let answer = Record::from_rdata(Name::root(), 86400, RData::A(A::new(127, 0, 0, 1)));

        let mut message = Message::new();
        message.add_query(Query::query(Name::root(), RecordType::A));
        message.insert_answers(vec![answer]);

        let response = DnsResponse::from_message(message).unwrap();
        assert!(response.contains_answer());
        Ok(response)
    }

    /// Response built from a fresh, unmodified message.
    pub fn empty() -> Result<DnsResponse, ResolveError> {
        let blank = Message::new();
        Ok(DnsResponse::from_message(blank).unwrap())
    }

    /// Failed response wrapping an I/O error of kind `Other`.
    pub fn error() -> Result<DnsResponse, ResolveError> {
        let io = std::io::Error::from(std::io::ErrorKind::Other);
        Err(ResolveError::from(ProtoError::from(io)))
    }

    /// Creates a mock handle that replays `messages`.
    pub fn mock(messages: Vec<Result<DnsResponse, ResolveError>>) -> MockDnsHandle {
        let messages = Arc::new(Mutex::new(messages));
        MockDnsHandle { messages }
    }

    /// Runs a root `A` lookup to completion against a mock replaying `messages`.
    fn lookup_root_a(
        messages: Vec<Result<DnsResponse, ResolveError>>,
    ) -> Result<Lookup, ResolveError> {
        block_on(LookupFuture::lookup(
            vec![Name::root()],
            RecordType::A,
            DnsRequestOptions::default(),
            CachingClient::new(0, mock(messages), false),
        ))
    }

    #[test]
    fn test_lookup() {
        let lookup = lookup_root_a(vec![v4_message()]).unwrap();
        let addrs = lookup
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>();

        assert_eq!(addrs, vec![Ipv4Addr::new(127, 0, 0, 1)]);
    }

    #[test]
    fn test_lookup_slice() {
        let lookup = lookup_root_a(vec![v4_message()]).unwrap();
        let first = &lookup.records()[0];

        assert_eq!(
            Record::data(first).unwrap().ip_addr().unwrap(),
            Ipv4Addr::new(127, 0, 0, 1)
        );
    }

    #[test]
    fn test_lookup_into_iter() {
        let lookup = lookup_root_a(vec![v4_message()]).unwrap();
        let addrs = lookup
            .into_iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>();

        assert_eq!(addrs, vec![Ipv4Addr::new(127, 0, 0, 1)]);
    }

    #[test]
    fn test_error() {
        let outcome = lookup_root_a(vec![error()]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_empty_no_response() {
        let err = lookup_root_a(vec![empty()]).unwrap_err();

        match err.kind() {
            ResolveErrorKind::NoRecordsFound {
                query,
                negative_ttl,
                ..
            } => {
                assert_eq!(**query, Query::query(Name::root(), RecordType::A));
                assert_eq!(*negative_ttl, None);
            }
            _ => panic!("wrong error received"),
        }
    }

    #[test]
    fn test_lookup_into_iter_arc() {
        let first = Record::from_rdata(
            Name::from_str("www.example.com.").unwrap(),
            80,
            RData::A(A::new(127, 0, 0, 1)),
        );
        let second = Record::from_rdata(
            Name::from_str("www.example.com.").unwrap(),
            80,
            RData::A(A::new(127, 0, 0, 2)),
        );

        let mut lookup = LookupIntoIter {
            records: Arc::from([first, second]),
            index: 0,
        };

        assert_eq!(lookup.next().unwrap(), RData::A(A::new(127, 0, 0, 1)));
        assert_eq!(lookup.next().unwrap(), RData::A(A::new(127, 0, 0, 2)));
        assert_eq!(lookup.next(), None);
    }
}