trust_dns_resolver/
caching_client.rs

1// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! Caching related functionality for the Resolver.
9
10use std::{
11    borrow::Cow,
12    error::Error,
13    pin::Pin,
14    sync::{
15        atomic::{AtomicU8, Ordering},
16        Arc,
17    },
18    time::Instant,
19};
20
21use futures_util::future::Future;
22use once_cell::sync::Lazy;
23
24use crate::{
25    dns_lru::{self, DnsLru, TtlConfig},
26    error::{ResolveError, ResolveErrorKind},
27    lookup::Lookup,
28    proto::{
29        error::ProtoError,
30        op::{Query, ResponseCode},
31        rr::{
32            domain::usage::{
33                ResolverUsage, DEFAULT, INVALID, IN_ADDR_ARPA_127, IP6_ARPA_1, LOCAL,
34                LOCALHOST as LOCALHOST_usage, ONION,
35            },
36            rdata::{A, AAAA, CNAME, PTR, SOA},
37            resource::RecordRef,
38            DNSClass, Name, RData, Record, RecordType,
39        },
40        xfer::{DnsHandle, DnsRequestOptions, DnsResponse, FirstAnswer},
41    },
42};
43
44const MAX_QUERY_DEPTH: u8 = 8; // arbitrarily chosen number...
45
46static LOCALHOST: Lazy<RData> =
47    Lazy::new(|| RData::PTR(PTR(Name::from_ascii("localhost.").unwrap())));
48static LOCALHOST_V4: Lazy<RData> = Lazy::new(|| RData::A(A::new(127, 0, 0, 1)));
49static LOCALHOST_V6: Lazy<RData> = Lazy::new(|| RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)));
50
/// Guard that counts one in-flight lookup on a depth counter while it is alive.
///
/// The counter is incremented when the guard is created and decremented when it is dropped.
/// Callers compare the counter against `MAX_QUERY_DEPTH` before following a CNAME.
struct DepthTracker {
    query_depth: Arc<AtomicU8>,
    /// `false` when the counter was already at `u8::MAX` and was left untouched, in which
    /// case `drop` must not decrement it either.
    incremented: bool,
}

impl DepthTracker {
    /// Registers one more lookup on `query_depth`.
    ///
    /// The increment saturates at `u8::MAX` instead of wrapping. A wrap to 0 would make a
    /// deeply loaded counter look shallow and defeat the `MAX_QUERY_DEPTH` check.
    fn track(query_depth: Arc<AtomicU8>) -> Self {
        let incremented = query_depth
            .fetch_update(Ordering::Release, Ordering::Relaxed, |depth| {
                depth.checked_add(1)
            })
            .is_ok();
        Self {
            query_depth,
            incremented,
        }
    }
}

impl Drop for DepthTracker {
    fn drop(&mut self) {
        // only undo an increment this guard actually performed
        if self.incremented {
            self.query_depth.fetch_sub(1, Ordering::Release);
        }
    }
}
67
68// TODO: need to consider this storage type as it compares to Authority in server...
69//       should it just be an variation on Authority?
70#[derive(Clone, Debug)]
71#[doc(hidden)]
72pub struct CachingClient<C, E>
73where
74    C: DnsHandle<Error = E>,
75    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
76{
77    lru: DnsLru,
78    client: C,
79    query_depth: Arc<AtomicU8>,
80    preserve_intermediates: bool,
81}
82
83impl<C, E> CachingClient<C, E>
84where
85    C: DnsHandle<Error = E> + Send + 'static,
86    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
87{
88    #[doc(hidden)]
89    pub fn new(max_size: usize, client: C, preserve_intermediates: bool) -> Self {
90        Self::with_cache(
91            DnsLru::new(max_size, TtlConfig::default()),
92            client,
93            preserve_intermediates,
94        )
95    }
96
97    pub(crate) fn with_cache(lru: DnsLru, client: C, preserve_intermediates: bool) -> Self {
98        let query_depth = Arc::new(AtomicU8::new(0));
99        Self {
100            lru,
101            client,
102            query_depth,
103            preserve_intermediates,
104        }
105    }
106
107    /// Perform a lookup against this caching client, looking first in the cache for a result
108    pub fn lookup(
109        &mut self,
110        query: Query,
111        options: DnsRequestOptions,
112    ) -> Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> {
113        Box::pin(Self::inner_lookup(query, options, self.clone(), vec![]))
114    }
115
116    async fn inner_lookup(
117        query: Query,
118        options: DnsRequestOptions,
119        mut client: Self,
120        preserved_records: Vec<(Record, u32)>,
121    ) -> Result<Lookup, ResolveError> {
122        // see https://tools.ietf.org/html/rfc6761
123        //
124        // ```text
125        // Name resolution APIs and libraries SHOULD recognize localhost
126        // names as special and SHOULD always return the IP loopback address
127        // for address queries and negative responses for all other query
128        // types.  Name resolution APIs SHOULD NOT send queries for
129        // localhost names to their configured caching DNS server(s).
130        // ```
131        // special use rules only apply to the IN Class
132        if query.query_class() == DNSClass::IN {
133            let usage = match query.name() {
134                n if LOCALHOST_usage.zone_of(n) => &*LOCALHOST_usage,
135                n if IN_ADDR_ARPA_127.zone_of(n) => &*LOCALHOST_usage,
136                n if IP6_ARPA_1.zone_of(n) => &*LOCALHOST_usage,
137                n if INVALID.zone_of(n) => &*INVALID,
138                n if LOCAL.zone_of(n) => &*LOCAL,
139                n if ONION.zone_of(n) => &*ONION,
140                _ => &*DEFAULT,
141            };
142
143            match usage.resolver() {
144                ResolverUsage::Loopback => match query.query_type() {
145                    // TODO: look in hosts for these ips/names first...
146                    RecordType::A => return Ok(Lookup::from_rdata(query, LOCALHOST_V4.clone())),
147                    RecordType::AAAA => return Ok(Lookup::from_rdata(query, LOCALHOST_V6.clone())),
148                    RecordType::PTR => return Ok(Lookup::from_rdata(query, LOCALHOST.clone())),
149                    _ => {
150                        return Err(ResolveError::nx_error(
151                            query,
152                            None,
153                            None,
154                            ResponseCode::NoError,
155                            false,
156                        ))
157                    } // Are there any other types we can use?
158                },
159                // when mdns is enabled we will follow a standard query path
160                #[cfg(feature = "mdns")]
161                ResolverUsage::LinkLocal => (),
162                // TODO: this requires additional config, as Kubernetes and other systems misuse the .local. zone.
163                // when mdns is not enabled we will return errors on LinkLocal ("*.local.") names
164                #[cfg(not(feature = "mdns"))]
165                ResolverUsage::LinkLocal => (),
166                ResolverUsage::NxDomain => {
167                    return Err(ResolveError::nx_error(
168                        query,
169                        None,
170                        None,
171                        ResponseCode::NXDomain,
172                        false,
173                    ))
174                }
175                ResolverUsage::Normal => (),
176            }
177        }
178
179        let _tracker = DepthTracker::track(client.query_depth.clone());
180        let is_dnssec = client.client.is_verifying_dnssec();
181
182        // first transition any polling that is needed (mutable refs...)
183        if let Some(cached_lookup) = client.lookup_from_cache(&query) {
184            return cached_lookup;
185        };
186
187        let response_message = client
188            .client
189            .lookup(query.clone(), options)
190            .first_answer()
191            .await
192            .map_err(E::into);
193
194        // TODO: technically this might be duplicating work, as name_server already performs this evaluation.
195        //  we may want to create a new type, if evaluated... but this is most generic to support any impl in LookupState...
196        let response_message = if let Ok(response) = response_message {
197            ResolveError::from_response(response, false)
198        } else {
199            response_message
200        };
201
202        // TODO: take all records and cache them?
203        //  if it's DNSSEC they must be signed, otherwise?
204        let records: Result<Records, ResolveError> = match response_message {
205            // this is the only cacheable form
206            Err(ResolveError {
207                kind:
208                    ResolveErrorKind::NoRecordsFound {
209                        query,
210                        soa,
211                        negative_ttl,
212                        response_code,
213                        trusted,
214                    },
215                ..
216            }) => {
217                Err(Self::handle_nxdomain(
218                    is_dnssec,
219                    false, /*tbd*/
220                    *query,
221                    soa.map(|v| *v),
222                    negative_ttl,
223                    response_code,
224                    trusted,
225                ))
226            }
227            Err(e) => return Err(e),
228            Ok(response_message) => {
229                // allow the handle_noerror function to deal with any error codes
230                let records = Self::handle_noerror(
231                    &mut client,
232                    options,
233                    is_dnssec,
234                    &query,
235                    response_message,
236                    preserved_records,
237                )?;
238
239                Ok(records)
240            }
241        };
242
243        // after the request, evaluate if we have additional queries to perform
244        match records {
245            Ok(Records::CnameChain {
246                next: future,
247                min_ttl: ttl,
248            }) => client.cname(future.await?, query, ttl),
249            Ok(Records::Exists(rdata)) => client.cache(query, Ok(rdata)),
250            Err(e) => client.cache(query, Err(e)),
251        }
252    }
253
254    /// Check if this query is already cached
255    fn lookup_from_cache(&self, query: &Query) -> Option<Result<Lookup, ResolveError>> {
256        self.lru.get(query, Instant::now())
257    }
258
259    /// See https://tools.ietf.org/html/rfc2308
260    ///
261    /// For now we will regard NXDomain to strictly mean the query failed
262    ///  and a record for the name, regardless of CNAME presence, what have you
263    ///  ultimately does not exist.
264    ///
265    /// This also handles empty responses in the same way. When performing DNSSEC enabled queries, we should
266    ///  never enter here, and should never cache unless verified requests.
267    ///
268    /// TODO: should this should be expanded to do a forward lookup? Today, this will fail even if there are
269    ///   forwarding options.
270    ///
271    /// # Arguments
272    ///
273    /// * `message` - message to extract SOA, etc, from for caching failed requests
274    /// * `valid_nsec` - species that in DNSSEC mode, this request is safe to cache
275    /// * `negative_ttl` - this should be the SOA minimum for negative ttl
276    fn handle_nxdomain(
277        is_dnssec: bool,
278        valid_nsec: bool,
279        query: Query,
280        soa: Option<Record<SOA>>,
281        negative_ttl: Option<u32>,
282        response_code: ResponseCode,
283        trusted: bool,
284    ) -> ResolveError {
285        if valid_nsec || !is_dnssec {
286            // only trust if there were validated NSEC records
287            ResolveErrorKind::NoRecordsFound {
288                query: Box::new(query),
289                soa: soa.map(Box::new),
290                negative_ttl,
291                response_code,
292                trusted: true,
293            }
294            .into()
295        } else {
296            // not cacheable, no ttl...
297            ResolveErrorKind::NoRecordsFound {
298                query: Box::new(query),
299                soa: soa.map(Box::new),
300                negative_ttl: None,
301                response_code,
302                trusted,
303            }
304            .into()
305        }
306    }
307
308    /// Handle the case where there is no error returned
309    fn handle_noerror(
310        client: &mut Self,
311        options: DnsRequestOptions,
312        is_dnssec: bool,
313        query: &Query,
314        response: DnsResponse,
315        mut preserved_records: Vec<(Record, u32)>,
316    ) -> Result<Records, ResolveError> {
317        // initial ttl is what CNAMES for min usage
318        const INITIAL_TTL: u32 = dns_lru::MAX_TTL;
319
320        // need to capture these before the subsequent and destructive record processing
321        let soa = response.soa().as_ref().map(RecordRef::to_owned);
322        let negative_ttl = response.negative_ttl();
323        let response_code = response.response_code();
324
325        // seek out CNAMES, this is only performed if the query is not a CNAME, ANY, or SRV
326        // FIXME: for SRV this evaluation is inadequate. CNAME is a single chain to a single record
327        //   for SRV, there could be many different targets. The search_name needs to be enhanced to
328        //   be a list of names found for SRV records.
329        let (search_name, cname_ttl, was_cname, preserved_records) = {
330            // this will only search for CNAMEs if the request was not meant to be for one of the triggers for recursion
331            let (search_name, cname_ttl, was_cname) =
332                if query.query_type().is_any() || query.query_type().is_cname() {
333                    (Cow::Borrowed(query.name()), INITIAL_TTL, false)
334                } else {
335                    // Folds any cnames from the answers section, into the final cname in the answers section
336                    //   this works by folding the last CNAME found into the final folded result.
337                    //   it assumes that the CNAMEs are in chained order in the DnsResponse Message...
338                    // For SRV, the name added for the search becomes the target name.
339                    //
340                    // TODO: should this include the additionals?
341                    response.answers().iter().fold(
342                        (Cow::Borrowed(query.name()), INITIAL_TTL, false),
343                        |(search_name, cname_ttl, was_cname), r| {
344                            match r.data() {
345                                Some(RData::CNAME(CNAME(ref cname))) => {
346                                    // take the minimum TTL of the cname_ttl and the next record in the chain
347                                    let ttl = cname_ttl.min(r.ttl());
348                                    debug_assert_eq!(r.record_type(), RecordType::CNAME);
349                                    if search_name.as_ref() == r.name() {
350                                        return (Cow::Owned(cname.clone()), ttl, true);
351                                    }
352                                }
353                                Some(RData::SRV(ref srv)) => {
354                                    // take the minimum TTL of the cname_ttl and the next record in the chain
355                                    let ttl = cname_ttl.min(r.ttl());
356                                    debug_assert_eq!(r.record_type(), RecordType::SRV);
357
358                                    // the search name becomes the srv.target
359                                    return (Cow::Owned(srv.target().clone()), ttl, true);
360                                }
361                                _ => (),
362                            }
363
364                            (search_name, cname_ttl, was_cname)
365                        },
366                    )
367                };
368
369            // take all answers. // TODO: following CNAMES?
370            let mut response = response.into_message();
371            let answers = response.take_answers();
372            let additionals = response.take_additionals();
373            let name_servers = response.take_name_servers();
374
375            // set of names that still require resolution
376            // TODO: this needs to be enhanced for SRV
377            let mut found_name = false;
378
379            // After following all the CNAMES to the last one, try and lookup the final name
380            let records = answers
381                .into_iter()
382                // Chained records will generally exist in the additionals section
383                .chain(additionals.into_iter())
384                .chain(name_servers.into_iter())
385                .filter_map(|r| {
386                    // because this resolved potentially recursively, we want the min TTL from the chain
387                    let ttl = cname_ttl.min(r.ttl());
388                    // TODO: disable name validation with ResolverOpts? glibc feature...
389                    // restrict to the RData type requested
390                    if query.query_class() == r.dns_class() {
391                        // standard evaluation, it's an any type or it's the requested type and the search_name matches
392                        #[allow(clippy::suspicious_operation_groupings)]
393                        if (query.query_type().is_any() || query.query_type() == r.record_type())
394                            && (search_name.as_ref() == r.name() || query.name() == r.name())
395                        {
396                            found_name = true;
397                            return Some((r, ttl));
398                        }
399                        // CNAME evaluation, the record is from the CNAME lookup chain.
400                        if client.preserve_intermediates && r.record_type() == RecordType::CNAME {
401                            return Some((r, ttl));
402                        }
403                        // srv evaluation, it's an srv lookup and the srv_search_name/target matches this name
404                        //    and it's an IP
405                        if query.query_type().is_srv()
406                            && r.record_type().is_ip_addr()
407                            && search_name.as_ref() == r.name()
408                        {
409                            found_name = true;
410                            Some((r, ttl))
411                        } else if query.query_type().is_ns() && r.record_type().is_ip_addr() {
412                            Some((r, ttl))
413                        } else {
414                            None
415                        }
416                    } else {
417                        None
418                    }
419                })
420                .collect::<Vec<_>>();
421
422            // adding the newly collected records to the preserved records
423            preserved_records.extend(records);
424            if !preserved_records.is_empty() && found_name {
425                return Ok(Records::Exists(preserved_records));
426            }
427
428            (
429                search_name.into_owned(),
430                cname_ttl,
431                was_cname,
432                preserved_records,
433            )
434        };
435
436        // TODO: for SRV records we *could* do an implicit lookup, but, this requires knowing the type of IP desired
437        //    for now, we'll make the API require the user to perform a follow up to the lookups.
438        // It was a CNAME, but not included in the request...
439        if was_cname && client.query_depth.load(Ordering::Acquire) < MAX_QUERY_DEPTH {
440            let next_query = Query::query(search_name, query.query_type());
441            Ok(Records::CnameChain {
442                next: Box::pin(Self::inner_lookup(
443                    next_query,
444                    options,
445                    client.clone(),
446                    preserved_records,
447                )),
448                min_ttl: cname_ttl,
449            })
450        } else {
451            // TODO: review See https://tools.ietf.org/html/rfc2308 for NoData section
452            // Note on DNSSEC, in secure_client_handle, if verify_nsec fails then the request fails.
453            //   this will mean that no unverified negative caches will make it to this point and be stored
454            Err(Self::handle_nxdomain(
455                is_dnssec,
456                true,
457                query.clone(),
458                soa,
459                negative_ttl,
460                response_code,
461                false,
462            ))
463        }
464    }
465
466    #[allow(clippy::unnecessary_wraps)]
467    fn cname(&self, lookup: Lookup, query: Query, cname_ttl: u32) -> Result<Lookup, ResolveError> {
468        // this duplicates the cache entry under the original query
469        Ok(self.lru.duplicate(query, lookup, cname_ttl, Instant::now()))
470    }
471
472    fn cache(
473        &self,
474        query: Query,
475        records: Result<Vec<(Record, u32)>, ResolveError>,
476    ) -> Result<Lookup, ResolveError> {
477        // this will put this object into an inconsistent state, but no one should call poll again...
478        match records {
479            Ok(rdata) => Ok(self.lru.insert(query, rdata, Instant::now())),
480            Err(err) => Err(self.lru.negative(query, err, Instant::now())),
481        }
482    }
483
484    /// Flushes/Removes all entries from the cache
485    pub fn clear_cache(&self) {
486        self.lru.clear();
487    }
488}
489
490enum Records {
491    /// The records exists, a vec of rdata with ttl
492    Exists(Vec<(Record, u32)>),
493    /// Future lookup for recursive cname records
494    CnameChain {
495        next: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>>,
496        min_ttl: u32,
497    },
498}
499
500// see also the lookup_tests.rs in integration-tests crate
501#[cfg(test)]
502mod tests {
503    use std::net::*;
504    use std::str::FromStr;
505    use std::time::*;
506
507    use futures_executor::block_on;
508    use proto::op::{Message, Query};
509    use proto::rr::rdata::SRV;
510    use proto::rr::{Name, Record};
511    use trust_dns_proto::rr::rdata::NS;
512
513    use super::*;
514    use crate::lookup_ip::tests::*;
515
516    #[test]
517    fn test_empty_cache() {
518        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
519        let client = mock(vec![empty()]);
520        let client = CachingClient::with_cache(cache, client, false);
521
522        if let ResolveErrorKind::NoRecordsFound {
523            query,
524            negative_ttl,
525            ..
526        } = block_on(CachingClient::inner_lookup(
527            Query::new(),
528            DnsRequestOptions::default(),
529            client,
530            vec![],
531        ))
532        .unwrap_err()
533        .kind()
534        {
535            assert_eq!(**query, Query::new());
536            assert_eq!(*negative_ttl, None);
537        } else {
538            panic!("wrong error received")
539        }
540    }
541
542    #[test]
543    fn test_from_cache() {
544        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
545        let query = Query::new();
546        cache.insert(
547            query.clone(),
548            vec![(
549                Record::from_rdata(
550                    query.name().clone(),
551                    u32::max_value(),
552                    RData::A(A::new(127, 0, 0, 1)),
553                ),
554                u32::max_value(),
555            )],
556            Instant::now(),
557        );
558
559        let client = mock(vec![empty()]);
560        let client = CachingClient::with_cache(cache, client, false);
561
562        let ips = block_on(CachingClient::inner_lookup(
563            Query::new(),
564            DnsRequestOptions::default(),
565            client,
566            vec![],
567        ))
568        .unwrap();
569
570        assert_eq!(
571            ips.iter().cloned().collect::<Vec<_>>(),
572            vec![RData::A(A::new(127, 0, 0, 1))]
573        );
574    }
575
576    #[test]
577    fn test_no_cache_insert() {
578        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
579        // first should come from client...
580        let client = mock(vec![v4_message()]);
581        let client = CachingClient::with_cache(cache.clone(), client, false);
582
583        let ips = block_on(CachingClient::inner_lookup(
584            Query::new(),
585            DnsRequestOptions::default(),
586            client,
587            vec![],
588        ))
589        .unwrap();
590
591        assert_eq!(
592            ips.iter().cloned().collect::<Vec<_>>(),
593            vec![RData::A(A::new(127, 0, 0, 1))]
594        );
595
596        // next should come from cache...
597        let client = mock(vec![empty()]);
598        let client = CachingClient::with_cache(cache, client, false);
599
600        let ips = block_on(CachingClient::inner_lookup(
601            Query::new(),
602            DnsRequestOptions::default(),
603            client,
604            vec![],
605        ))
606        .unwrap();
607
608        assert_eq!(
609            ips.iter().cloned().collect::<Vec<_>>(),
610            vec![RData::A(A::new(127, 0, 0, 1))]
611        );
612    }
613
614    #[allow(clippy::unnecessary_wraps)]
615    pub(crate) fn cname_message() -> Result<DnsResponse, ResolveError> {
616        let mut message = Message::new();
617        message.add_query(Query::query(
618            Name::from_str("www.example.com.").unwrap(),
619            RecordType::A,
620        ));
621        message.insert_answers(vec![Record::from_rdata(
622            Name::from_str("www.example.com.").unwrap(),
623            86400,
624            RData::CNAME(CNAME(Name::from_str("actual.example.com.").unwrap())),
625        )]);
626        Ok(DnsResponse::from_message(message).unwrap())
627    }
628
629    #[allow(clippy::unnecessary_wraps)]
630    pub(crate) fn srv_message() -> Result<DnsResponse, ResolveError> {
631        let mut message = Message::new();
632        message.add_query(Query::query(
633            Name::from_str("_443._tcp.www.example.com.").unwrap(),
634            RecordType::SRV,
635        ));
636        message.insert_answers(vec![Record::from_rdata(
637            Name::from_str("_443._tcp.www.example.com.").unwrap(),
638            86400,
639            RData::SRV(SRV::new(
640                1,
641                2,
642                443,
643                Name::from_str("www.example.com.").unwrap(),
644            )),
645        )]);
646        Ok(DnsResponse::from_message(message).unwrap())
647    }
648
649    #[allow(clippy::unnecessary_wraps)]
650    pub(crate) fn ns_message() -> Result<DnsResponse, ResolveError> {
651        let mut message = Message::new();
652        message.add_query(Query::query(
653            Name::from_str("www.example.com.").unwrap(),
654            RecordType::NS,
655        ));
656        message.insert_answers(vec![Record::from_rdata(
657            Name::from_str("www.example.com.").unwrap(),
658            86400,
659            RData::NS(NS(Name::from_str("www.example.com.").unwrap())),
660        )]);
661        Ok(DnsResponse::from_message(message).unwrap())
662    }
663
664    fn no_recursion_on_query_test(query_type: RecordType) {
665        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
666
667        // the cname should succeed, we shouldn't query again after that, which would cause an error...
668        let client = mock(vec![error(), cname_message()]);
669        let client = CachingClient::with_cache(cache, client, false);
670
671        let ips = block_on(CachingClient::inner_lookup(
672            Query::query(Name::from_str("www.example.com.").unwrap(), query_type),
673            DnsRequestOptions::default(),
674            client,
675            vec![],
676        ))
677        .expect("lookup failed");
678
679        assert_eq!(
680            ips.iter().cloned().collect::<Vec<_>>(),
681            vec![RData::CNAME(CNAME(
682                Name::from_str("actual.example.com.").unwrap()
683            ))]
684        );
685    }
686
687    #[test]
688    fn test_no_recursion_on_cname_query() {
689        no_recursion_on_query_test(RecordType::CNAME);
690    }
691
692    #[test]
693    fn test_no_recursion_on_all_query() {
694        no_recursion_on_query_test(RecordType::ANY);
695    }
696
697    #[test]
698    fn test_non_recursive_srv_query() {
699        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
700
701        // the cname should succeed, we shouldn't query again after that, which would cause an error...
702        let client = mock(vec![error(), srv_message()]);
703        let client = CachingClient::with_cache(cache, client, false);
704
705        let ips = block_on(CachingClient::inner_lookup(
706            Query::query(
707                Name::from_str("_443._tcp.www.example.com.").unwrap(),
708                RecordType::SRV,
709            ),
710            DnsRequestOptions::default(),
711            client,
712            vec![],
713        ))
714        .expect("lookup failed");
715
716        assert_eq!(
717            ips.iter().cloned().collect::<Vec<_>>(),
718            vec![RData::SRV(SRV::new(
719                1,
720                2,
721                443,
722                Name::from_str("www.example.com.").unwrap(),
723            ))]
724        );
725    }
726
727    #[test]
728    fn test_single_srv_query_response() {
729        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
730
731        let mut message = srv_message().unwrap().into_message();
732        message.add_answer(Record::from_rdata(
733            Name::from_str("www.example.com.").unwrap(),
734            86400,
735            RData::CNAME(CNAME(Name::from_str("actual.example.com.").unwrap())),
736        ));
737        message.insert_additionals(vec![
738            Record::from_rdata(
739                Name::from_str("actual.example.com.").unwrap(),
740                86400,
741                RData::A(A::new(127, 0, 0, 1)),
742            ),
743            Record::from_rdata(
744                Name::from_str("actual.example.com.").unwrap(),
745                86400,
746                RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)),
747            ),
748        ]);
749
750        let client = mock(vec![
751            error(),
752            Ok(DnsResponse::from_message(message).unwrap()),
753        ]);
754        let client = CachingClient::with_cache(cache, client, false);
755
756        let ips = block_on(CachingClient::inner_lookup(
757            Query::query(
758                Name::from_str("_443._tcp.www.example.com.").unwrap(),
759                RecordType::SRV,
760            ),
761            DnsRequestOptions::default(),
762            client,
763            vec![],
764        ))
765        .expect("lookup failed");
766
767        assert_eq!(
768            ips.iter().cloned().collect::<Vec<_>>(),
769            vec![
770                RData::SRV(SRV::new(
771                    1,
772                    2,
773                    443,
774                    Name::from_str("www.example.com.").unwrap(),
775                )),
776                RData::A(A::new(127, 0, 0, 1)),
777                RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)),
778            ]
779        );
780    }
781
782    // TODO: if we ever enable recursive lookups for SRV, here are the tests...
783    // #[test]
784    // fn test_recursive_srv_query() {
785    //     let cache = Arc::new(Mutex::new(DnsLru::new(1)));
786
787    //     let mut message = Message::new();
788    //     message.add_answer(Record::from_rdata(
789    //         Name::from_str("www.example.com.").unwrap(),
790    //         86400,
791    //         RecordType::CNAME,
792    //         RData::CNAME(Name::from_str("actual.example.com.").unwrap()),
793    //     ));
794    //     message.insert_additionals(vec![
795    //         Record::from_rdata(
796    //             Name::from_str("actual.example.com.").unwrap(),
797    //             86400,
798    //             RecordType::A,
799    //             RData::A(Ipv4Addr::new(127, 0, 0, 1)),
800    //         ),
801    //     ]);
802
803    //     let mut client = mock(vec![error(), Ok(DnsResponse::from_message(message).unwrap()), srv_message()]);
804
805    //     let ips = QueryState::lookup(
806    //         Query::query(
807    //             Name::from_str("_443._tcp.www.example.com.").unwrap(),
808    //             RecordType::SRV,
809    //         ),
810    //         Default::default(),
811    //         &mut client,
812    //         cache.clone(),
813    //     ).wait()
814    //         .expect("lookup failed");
815
816    //     assert_eq!(
817    //         ips.iter().cloned().collect::<Vec<_>>(),
818    //         vec![
819    //             RData::SRV(SRV::new(
820    //                 1,
821    //                 2,
822    //                 443,
823    //                 Name::from_str("www.example.com.").unwrap(),
824    //             )),
825    //             RData::A(Ipv4Addr::new(127, 0, 0, 1)),
826    //             //RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
827    //         ]
828    //     );
829    // }
830
831    #[test]
832    fn test_single_ns_query_response() {
833        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
834
835        let mut message = ns_message().unwrap().into_message();
836        message.add_answer(Record::from_rdata(
837            Name::from_str("www.example.com.").unwrap(),
838            86400,
839            RData::CNAME(CNAME(Name::from_str("actual.example.com.").unwrap())),
840        ));
841        message.insert_additionals(vec![
842            Record::from_rdata(
843                Name::from_str("actual.example.com.").unwrap(),
844                86400,
845                RData::A(A::new(127, 0, 0, 1)),
846            ),
847            Record::from_rdata(
848                Name::from_str("actual.example.com.").unwrap(),
849                86400,
850                RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)),
851            ),
852        ]);
853
854        let client = mock(vec![
855            error(),
856            Ok(DnsResponse::from_message(message).unwrap()),
857        ]);
858        let client = CachingClient::with_cache(cache, client, false);
859
860        let ips = block_on(CachingClient::inner_lookup(
861            Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::NS),
862            DnsRequestOptions::default(),
863            client,
864            vec![],
865        ))
866        .expect("lookup failed");
867
868        assert_eq!(
869            ips.iter().cloned().collect::<Vec<_>>(),
870            vec![
871                RData::NS(NS(Name::from_str("www.example.com.").unwrap())),
872                RData::A(A::new(127, 0, 0, 1)),
873                RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)),
874            ]
875        );
876    }
877
878    fn cname_ttl_test(first: u32, second: u32) {
879        let lru = DnsLru::new(1, dns_lru::TtlConfig::default());
880        // expecting no queries to be performed
881        let mut client = CachingClient::with_cache(lru, mock(vec![error()]), false);
882
883        let mut message = Message::new();
884        message.insert_answers(vec![Record::from_rdata(
885            Name::from_str("ttl.example.com.").unwrap(),
886            first,
887            RData::CNAME(CNAME(Name::from_str("actual.example.com.").unwrap())),
888        )]);
889        message.insert_additionals(vec![Record::from_rdata(
890            Name::from_str("actual.example.com.").unwrap(),
891            second,
892            RData::A(A::new(127, 0, 0, 1)),
893        )]);
894
895        let records = CachingClient::handle_noerror(
896            &mut client,
897            DnsRequestOptions::default(),
898            false,
899            &Query::query(Name::from_str("ttl.example.com.").unwrap(), RecordType::A),
900            DnsResponse::from_message(message).unwrap(),
901            vec![],
902        );
903
904        if let Ok(records) = records {
905            if let Records::Exists(records) = records {
906                for (record, ttl) in records.iter() {
907                    if record.record_type() == RecordType::CNAME {
908                        continue;
909                    }
910                    assert_eq!(ttl, &1);
911                }
912            } else {
913                panic!("records don't exist");
914            }
915        } else {
916            panic!("error getting records");
917        }
918    }
919
920    #[test]
921    fn test_cname_ttl() {
922        cname_ttl_test(1, 2);
923        cname_ttl_test(2, 1);
924    }
925
926    #[test]
927    fn test_early_return_localhost() {
928        let cache = DnsLru::new(0, dns_lru::TtlConfig::default());
929        let client = mock(vec![empty()]);
930        let mut client = CachingClient::with_cache(cache, client, false);
931
932        {
933            let query = Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::A);
934            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
935                .expect("should have returned localhost");
936            assert_eq!(lookup.query(), &query);
937            assert_eq!(
938                lookup.iter().cloned().collect::<Vec<_>>(),
939                vec![LOCALHOST_V4.clone()]
940            );
941        }
942
943        {
944            let query = Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::AAAA);
945            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
946                .expect("should have returned localhost");
947            assert_eq!(lookup.query(), &query);
948            assert_eq!(
949                lookup.iter().cloned().collect::<Vec<_>>(),
950                vec![LOCALHOST_V6.clone()]
951            );
952        }
953
954        {
955            let query = Query::query(Name::from(Ipv4Addr::new(127, 0, 0, 1)), RecordType::PTR);
956            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
957                .expect("should have returned localhost");
958            assert_eq!(lookup.query(), &query);
959            assert_eq!(
960                lookup.iter().cloned().collect::<Vec<_>>(),
961                vec![LOCALHOST.clone()]
962            );
963        }
964
965        {
966            let query = Query::query(
967                Name::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
968                RecordType::PTR,
969            );
970            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
971                .expect("should have returned localhost");
972            assert_eq!(lookup.query(), &query);
973            assert_eq!(
974                lookup.iter().cloned().collect::<Vec<_>>(),
975                vec![LOCALHOST.clone()]
976            );
977        }
978
979        assert!(block_on(client.lookup(
980            Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::MX),
981            DnsRequestOptions::default()
982        ))
983        .is_err());
984
985        assert!(block_on(client.lookup(
986            Query::query(Name::from(Ipv4Addr::new(127, 0, 0, 1)), RecordType::MX),
987            DnsRequestOptions::default()
988        ))
989        .is_err());
990
991        assert!(block_on(client.lookup(
992            Query::query(
993                Name::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
994                RecordType::MX
995            ),
996            DnsRequestOptions::default()
997        ))
998        .is_err());
999    }
1000
1001    #[test]
1002    fn test_early_return_invalid() {
1003        let cache = DnsLru::new(0, dns_lru::TtlConfig::default());
1004        let client = mock(vec![empty()]);
1005        let mut client = CachingClient::with_cache(cache, client, false);
1006
1007        assert!(block_on(client.lookup(
1008            Query::query(
1009                Name::from_ascii("horrible.invalid.").unwrap(),
1010                RecordType::A,
1011            ),
1012            DnsRequestOptions::default()
1013        ))
1014        .is_err());
1015    }
1016
1017    #[test]
1018    fn test_no_error_on_dot_local_no_mdns() {
1019        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
1020
1021        let mut message = srv_message().unwrap().into_message();
1022        message.add_query(Query::query(
1023            Name::from_ascii("www.example.local.").unwrap(),
1024            RecordType::A,
1025        ));
1026        message.add_answer(Record::from_rdata(
1027            Name::from_str("www.example.local.").unwrap(),
1028            86400,
1029            RData::A(A::new(127, 0, 0, 1)),
1030        ));
1031
1032        let client = mock(vec![
1033            error(),
1034            Ok(DnsResponse::from_message(message).unwrap()),
1035        ]);
1036        let mut client = CachingClient::with_cache(cache, client, false);
1037
1038        assert!(block_on(client.lookup(
1039            Query::query(
1040                Name::from_ascii("www.example.local.").unwrap(),
1041                RecordType::A,
1042            ),
1043            DnsRequestOptions::default()
1044        ))
1045        .is_ok());
1046    }
1047}