use std::borrow::Cow;
use std::error::Error;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use std::time::Instant;
use futures_util::future::Future;
use proto::error::ProtoError;
use proto::op::{Message, Query, ResponseCode};
use proto::rr::domain::usage::{
ResolverUsage, DEFAULT, INVALID, IN_ADDR_ARPA_127, IP6_ARPA_1, LOCAL,
LOCALHOST as LOCALHOST_usage,
};
use proto::rr::rdata::SOA;
use proto::rr::{DNSClass, Name, RData, Record, RecordType};
use proto::xfer::{DnsHandle, DnsRequestOptions, DnsResponse};
use crate::dns_lru;
use crate::dns_lru::DnsLru;
use crate::error::*;
use crate::lookup::Lookup;
/// Maximum value the shared `query_depth` counter may reach before a CNAME/SRV
/// chain stops issuing follow-up queries and reports "no records found" instead.
const MAX_QUERY_DEPTH: u8 = 8;
lazy_static! {
    /// PTR answer handed out for reverse lookups that fall in a loopback zone.
    static ref LOCALHOST: RData = RData::PTR(Name::from_ascii("localhost.").unwrap());
    /// A answer (127.0.0.1) handed out for loopback names.
    static ref LOCALHOST_V4: RData = RData::A(Ipv4Addr::LOCALHOST);
    /// AAAA answer (::1) handed out for loopback names.
    static ref LOCALHOST_V6: RData = RData::AAAA(Ipv6Addr::LOCALHOST);
}
/// Guard that accounts for one in-flight lookup on a shared depth counter.
///
/// [`DepthTracker::track`] bumps the counter by one; dropping the guard takes
/// that one back off again.
struct DepthTracker {
    /// Counter the guard incremented and will decrement on drop.
    query_depth: Arc<AtomicU8>,
}

impl DepthTracker {
    /// Increments `query_depth` and returns the guard that owns that increment.
    fn track(query_depth: Arc<AtomicU8>) -> Self {
        let guard = Self { query_depth };
        guard.query_depth.fetch_add(1, Ordering::Release);
        guard
    }
}

impl Drop for DepthTracker {
    /// Releases the depth slot taken by [`DepthTracker::track`].
    fn drop(&mut self) {
        let counter: &AtomicU8 = &self.query_depth;
        counter.fetch_sub(1, Ordering::Release);
    }
}
/// A DNS client wrapper that answers from an LRU cache when possible and
/// otherwise forwards the query to the wrapped `DnsHandle`.
///
/// Cloning shares the `query_depth` counter (it is behind an `Arc`).
#[derive(Clone, Debug)]
#[doc(hidden)]
pub struct CachingClient<C, E>
where
C: DnsHandle<Error = E>,
E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
/// Cache consulted before, and updated after, every upstream lookup.
lru: DnsLru,
/// Handle used to perform the lookup when the cache has no entry.
client: C,
/// Number of lookups currently in flight; compared against `MAX_QUERY_DEPTH`
/// before following a CNAME/SRV chain.
query_depth: Arc<AtomicU8>,
/// When true, CNAME records seen while resolving A/AAAA queries are kept
/// in the returned record set.
preserve_intermediates: bool,
}
impl<C, E> CachingClient<C, E>
where
C: DnsHandle<Error = E> + Send + 'static,
E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
/// Creates a caching client with a new cache of `max_size` entries and the
/// default TTL configuration.
#[doc(hidden)]
pub fn new(max_size: usize, client: C, preserve_intermediates: bool) -> Self {
    let lru = DnsLru::new(max_size, Default::default());
    Self::with_cache(lru, client, preserve_intermediates)
}
pub(crate) fn with_cache(lru: DnsLru, client: C, preserve_intermediates: bool) -> Self {
let query_depth = Arc::new(AtomicU8::new(0));
CachingClient {
lru,
client,
query_depth,
preserve_intermediates,
}
}
/// Starts a lookup for `query` and returns it as a boxed, `Send` future.
///
/// The future works on a clone of this client and starts with no preserved
/// records.
pub fn lookup(
    &mut self,
    query: Query,
    options: DnsRequestOptions,
) -> Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> {
    let owned_client = self.clone();
    let pending = Self::inner_lookup(query, options, owned_client, Vec::new());
    Box::pin(pending)
}
async fn inner_lookup(
query: Query,
options: DnsRequestOptions,
mut client: Self,
preserved_records: Vec<(Record, u32)>,
) -> Result<Lookup, ResolveError> {
if query.query_class() == DNSClass::IN {
let usage = match query.name() {
n if LOCALHOST_usage.zone_of(n) => &*LOCALHOST_usage,
n if IN_ADDR_ARPA_127.zone_of(n) => &*LOCALHOST_usage,
n if IP6_ARPA_1.zone_of(n) => &*LOCALHOST_usage,
n if INVALID.zone_of(n) => &*INVALID,
n if LOCAL.zone_of(n) => &*LOCAL,
_ => &*DEFAULT,
};
match usage.resolver() {
ResolverUsage::Loopback => match query.query_type() {
RecordType::A => return Ok(Lookup::from_rdata(query, LOCALHOST_V4.clone())),
RecordType::AAAA => return Ok(Lookup::from_rdata(query, LOCALHOST_V6.clone())),
RecordType::PTR => return Ok(Lookup::from_rdata(query, LOCALHOST.clone())),
_ => {
return Err(ResolveError::nx_error(
query,
None,
None,
ResponseCode::NoError,
false,
))
}
},
#[cfg(feature = "mdns")]
ResolverUsage::LinkLocal => (),
#[cfg(not(feature = "mdns"))]
ResolverUsage::LinkLocal => (),
ResolverUsage::NxDomain => {
return Err(ResolveError::nx_error(
query,
None,
None,
ResponseCode::NXDomain,
false,
))
}
ResolverUsage::Normal => (),
}
}
let _tracker = DepthTracker::track(client.query_depth.clone());
let is_dnssec = client.client.is_verifying_dnssec();
if let Some(cached_lookup) = client.from_cache(&query) {
return cached_lookup;
};
let response_message = client
.client
.lookup(query.clone(), options)
.await
.map_err(E::into);
let response_message = if let Ok(response) = response_message {
ResolveError::from_response(response, false)
} else {
response_message
};
let records: Result<Records, ResolveError> = match response_message {
Err(ResolveError {
kind:
ResolveErrorKind::NoRecordsFound {
query,
soa,
negative_ttl,
response_code,
trusted,
},
..
}) => {
Err(Self::handle_nxdomain(
is_dnssec,
false,
query,
soa,
negative_ttl,
response_code,
trusted,
))
}
Err(e) => return Err(e),
Ok(response_message) => {
let records = Self::handle_noerror(
&mut client,
options,
is_dnssec,
&query,
response_message,
preserved_records,
)?;
Ok(records)
}
};
match records {
Ok(Records::CnameChain {
next: future,
min_ttl: ttl,
}) => client.cname(future.await?, query, ttl),
Ok(Records::Exists(rdata)) => client.cache(query, Ok(rdata)),
Err(e) => client.cache(query, Err(e)),
}
}
/// Returns the cache's entry for `query` as of now, if it has one.
fn from_cache(&self, query: &Query) -> Option<Result<Lookup, ResolveError>> {
    let now = Instant::now();
    self.lru.get(query, now)
}
/// Builds the `NoRecordsFound` error for a negative response.
///
/// When the response had valid NSEC records, or the client is not verifying
/// DNSSEC, the error is marked trusted and keeps `negative_ttl`. Otherwise the
/// TTL is dropped and the caller's `trusted` flag is passed through.
fn handle_nxdomain(
    is_dnssec: bool,
    valid_nsec: bool,
    query: Query,
    soa: Option<SOA>,
    negative_ttl: Option<u32>,
    response_code: ResponseCode,
    trusted: bool,
) -> ResolveError {
    let (negative_ttl, trusted) = if valid_nsec || !is_dnssec {
        (negative_ttl, true)
    } else {
        (None, trusted)
    };

    ResolveErrorKind::NoRecordsFound {
        query,
        soa,
        negative_ttl,
        response_code,
        trusted,
    }
    .into()
}
/// Extracts the records answering `query` from a successful `response`.
///
/// Returns `Records::Exists` when matching records were found (prefixed by
/// `preserved_records`), `Records::CnameChain` when the answer only pointed at
/// another name and the depth limit allows a follow-up query, and otherwise a
/// `NoRecordsFound` error built by `handle_nxdomain`.
fn handle_noerror(
client: &mut Self,
options: DnsRequestOptions,
is_dnssec: bool,
query: &Query,
mut response: DnsResponse,
mut preserved_records: Vec<(Record, u32)>,
) -> Result<Records, ResolveError> {
// Starting TTL for the chain; every hop can only lower it (see `min` below).
const INITIAL_TTL: u32 = dns_lru::MAX_TTL;
// Read up front, before the record sections are taken out of the response below.
let soa = response.soa();
let negative_ttl = response.negative_ttl();
let response_code = response.response_code();
let (search_name, cname_ttl, was_cname, preserved_records) = {
// ANY and CNAME queries are answered for the query name itself; no chain is followed.
let (search_name, cname_ttl, was_cname) =
if query.query_type().is_any() || query.query_type().is_cname() {
(Cow::Borrowed(query.name()), INITIAL_TTL, false)
} else {
// Walk the answers in order, moving `search_name` along the chain and
// keeping the smallest TTL seen on the hops taken.
response.messages().flat_map(Message::answers).fold(
(Cow::Borrowed(query.name()), INITIAL_TTL, false),
|(search_name, cname_ttl, was_cname), r| {
match *r.rdata() {
RData::CNAME(ref cname) => {
let ttl = cname_ttl.min(r.ttl());
debug_assert_eq!(r.rr_type(), RecordType::CNAME);
// Only follow a CNAME owned by the name currently being searched.
if search_name.as_ref() == r.name() {
return (Cow::Owned(cname.clone()), ttl, true);
}
}
RData::SRV(ref srv) => {
let ttl = cname_ttl.min(r.ttl());
debug_assert_eq!(r.rr_type(), RecordType::SRV);
// Any SRV record redirects the search to its target, regardless of
// owner name; with several SRV records the last one wins.
return (Cow::Owned(srv.target().clone()), ttl, true);
}
_ => (),
}
(search_name, cname_ttl, was_cname)
},
)
};
// Drain all three record sections of every message in the response.
let answers: Vec<Record> = response
.messages_mut()
.flat_map(Message::take_answers)
.collect();
let additionals: Vec<Record> = response
.messages_mut()
.flat_map(Message::take_additionals)
.collect();
let name_servers: Vec<Record> = response
.messages_mut()
.flat_map(Message::take_name_servers)
.collect();
// Set when at least one record actually answers the query (not just a kept CNAME).
let mut found_name = false;
let records = answers
.into_iter()
.chain(additionals.into_iter())
.chain(name_servers.into_iter())
.filter_map(|r| {
// Each kept record is paired with the lower of its own TTL and the chain TTL.
let ttl = cname_ttl.min(r.ttl());
if query.query_class() == r.dns_class() {
// Direct match: requested type (or ANY) owned by the search name or the query name.
#[allow(clippy::suspicious_operation_groupings)]
if (query.query_type().is_any() || query.query_type() == r.rr_type())
&& (search_name.as_ref() == r.name() || query.name() == r.name())
{
found_name = true;
return Some((r, ttl));
}
// Intermediate CNAMEs are kept for A/AAAA queries when requested; they do
// not count as having found the name.
if client.preserve_intermediates
&& r.rr_type() == RecordType::CNAME
&& (query.query_type() == RecordType::A
|| query.query_type() == RecordType::AAAA)
{
return Some((r, ttl));
}
// SRV queries also accept address records owned by the SRV target.
if query.query_type().is_srv()
&& r.rr_type().is_ip_addr()
&& search_name.as_ref() == r.name()
{
found_name = true;
Some((r, ttl))
} else {
None
}
} else {
None
}
})
.collect::<Vec<_>>();
// Records from earlier hops come first, followed by this response's records.
preserved_records.extend(records);
if !preserved_records.is_empty() && found_name {
return Ok(Records::Exists(preserved_records));
}
(
search_name.into_owned(),
cname_ttl,
was_cname,
preserved_records,
)
};
// Nothing answered the query directly: follow the chain if there is one and the
// number of in-flight lookups is still below the limit.
if was_cname && client.query_depth.load(Ordering::Acquire) < MAX_QUERY_DEPTH {
// Note: the follow-up query is built from the name and type only.
let next_query = Query::query(search_name, query.query_type());
Ok(Records::CnameChain {
next: Box::pin(CachingClient::inner_lookup(
next_query,
options,
client.clone(),
preserved_records,
)),
min_ttl: cname_ttl,
})
} else {
// `valid_nsec` is passed as true here, so the resulting error is always
// marked trusted and keeps its negative TTL.
Err(Self::handle_nxdomain(
is_dnssec,
true,
query.clone(),
soa,
negative_ttl,
response_code,
false,
))
}
}
/// Stores the result of a followed chain under the original `query` by asking
/// the cache to duplicate `lookup` with `cname_ttl`, and returns that entry.
#[allow(clippy::unnecessary_wraps)]
fn cname(&self, lookup: Lookup, query: Query, cname_ttl: u32) -> Result<Lookup, ResolveError> {
    let now = Instant::now();
    let duplicated = self.lru.duplicate(query, lookup, cname_ttl, now);
    Ok(duplicated)
}
/// Hands the outcome of a lookup to the cache and returns what the cache gives back.
///
/// Successful record sets go through `DnsLru::insert`, errors through
/// `DnsLru::negative`.
fn cache(
    &self,
    query: Query,
    records: Result<Vec<(Record, u32)>, ResolveError>,
) -> Result<Lookup, ResolveError> {
    let now = Instant::now();
    match records {
        Err(failure) => Err(self.lru.negative(query, failure, now)),
        Ok(found) => Ok(self.lru.insert(query, found, now)),
    }
}
}
/// Outcome of inspecting a successful response in `handle_noerror`.
enum Records {
/// Records answering the query, each paired with the TTL to cache it with.
Exists(Vec<(Record, u32)>),
/// The response only pointed at another name; the answer comes from a follow-up lookup.
CnameChain {
/// Future performing the follow-up lookup for the next name in the chain.
next: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>>,
/// Smallest TTL seen along the chain so far.
min_ttl: u32,
},
}
#[cfg(test)]
mod tests {
use std::net::*;
use std::str::FromStr;
use std::time::*;
use futures_executor::block_on;
use proto::op::{Message, Query};
use proto::rr::rdata::SRV;
use proto::rr::{Name, Record};
use super::*;
use crate::lookup_ip::tests::*;
/// An empty upstream response with an empty cache yields `NoRecordsFound`
/// for the query, without a negative TTL.
#[test]
fn test_empty_cache() {
    let lru = DnsLru::new(1, dns_lru::TtlConfig::default());
    let caching = CachingClient::with_cache(lru, mock(vec![empty()]), false);

    let outcome = block_on(CachingClient::inner_lookup(
        Query::new(),
        Default::default(),
        caching,
        vec![],
    ));
    let failure = outcome.unwrap_err();

    match failure.kind() {
        ResolveErrorKind::NoRecordsFound {
            query,
            negative_ttl,
            ..
        } => {
            assert_eq!(*query, Query::new());
            assert_eq!(*negative_ttl, None);
        }
        _ => panic!("wrong error received"),
    }
}
/// A record already in the cache is returned even though upstream is empty.
#[test]
fn test_from_cache() {
    let lru = DnsLru::new(1, dns_lru::TtlConfig::default());
    let query = Query::new();
    let loopback = RData::A(Ipv4Addr::new(127, 0, 0, 1));

    // Seed the cache with a single, effectively non-expiring record.
    let seeded = Record::from_rdata(query.name().clone(), u32::MAX, loopback.clone());
    lru.insert(query.clone(), vec![(seeded, u32::MAX)], Instant::now());

    let caching = CachingClient::with_cache(lru, mock(vec![empty()]), false);
    let lookup = block_on(CachingClient::inner_lookup(
        Query::new(),
        Default::default(),
        caching,
        vec![],
    ))
    .unwrap();

    let found: Vec<RData> = lookup.iter().cloned().collect();
    assert_eq!(found, vec![loopback]);
}
/// The first lookup populates the shared cache; the second is served from it
/// although its upstream only returns an empty response.
#[test]
fn test_no_cache_insert() {
    let lru = DnsLru::new(1, dns_lru::TtlConfig::default());
    let expected = vec![RData::A(Ipv4Addr::new(127, 0, 0, 1))];

    let populating = CachingClient::with_cache(lru.clone(), mock(vec![v4_message()]), false);
    let first = block_on(CachingClient::inner_lookup(
        Query::new(),
        Default::default(),
        populating,
        vec![],
    ))
    .unwrap();
    assert_eq!(first.iter().cloned().collect::<Vec<_>>(), expected);

    let reading = CachingClient::with_cache(lru, mock(vec![empty()]), false);
    let second = block_on(CachingClient::inner_lookup(
        Query::new(),
        Default::default(),
        reading,
        vec![],
    ))
    .unwrap();
    assert_eq!(second.iter().cloned().collect::<Vec<_>>(), expected);
}
/// Response to an A query for `www.example.com.` whose only answer is a CNAME
/// to `actual.example.com.`.
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn cname_message() -> Result<DnsResponse, ResolveError> {
    let alias = Name::from_str("www.example.com.").unwrap();
    let canonical = Name::from_str("actual.example.com.").unwrap();

    let mut message = Message::new();
    message.add_query(Query::query(alias.clone(), RecordType::A));
    message.insert_answers(vec![Record::from_rdata(
        alias,
        86400,
        RData::CNAME(canonical),
    )]);
    Ok(message.into())
}
/// Response to an SRV query for `_443._tcp.www.example.com.` with one SRV
/// answer targeting `www.example.com.` on port 443.
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn srv_message() -> Result<DnsResponse, ResolveError> {
    let service = Name::from_str("_443._tcp.www.example.com.").unwrap();
    let target = Name::from_str("www.example.com.").unwrap();

    let mut message = Message::new();
    message.add_query(Query::query(service.clone(), RecordType::SRV));
    message.insert_answers(vec![Record::from_rdata(
        service,
        86400,
        RData::SRV(SRV::new(1, 2, 443, target)),
    )]);
    Ok(message.into())
}
fn no_recursion_on_query_test(query_type: RecordType) {
let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
let client = mock(vec![error(), cname_message()]);
let client = CachingClient::with_cache(cache, client, false);
let ips = block_on(CachingClient::inner_lookup(
Query::query(Name::from_str("www.example.com.").unwrap(), query_type),
Default::default(),
client,
vec![],
))
.expect("lookup failed");
assert_eq!(
ips.iter().cloned().collect::<Vec<_>>(),
vec![RData::CNAME(Name::from_str("actual.example.com.").unwrap())]
);
}
/// A CNAME query must return the CNAME record rather than following it.
#[test]
fn test_no_recursion_on_cname_query() {
no_recursion_on_query_test(RecordType::CNAME);
}
/// An ANY query must return the CNAME record rather than following it.
#[test]
fn test_no_recursion_on_all_query() {
no_recursion_on_query_test(RecordType::ANY);
}
/// An SRV response without address records returns just the SRV record.
#[test]
fn test_non_recursive_srv_query() {
    let lru = DnsLru::new(1, dns_lru::TtlConfig::default());
    let caching = CachingClient::with_cache(lru, mock(vec![error(), srv_message()]), false);
    let query = Query::query(
        Name::from_str("_443._tcp.www.example.com.").unwrap(),
        RecordType::SRV,
    );

    let lookup = block_on(CachingClient::inner_lookup(
        query,
        Default::default(),
        caching,
        vec![],
    ))
    .expect("lookup failed");

    let expected_srv = SRV::new(1, 2, 443, Name::from_str("www.example.com.").unwrap());
    assert_eq!(
        lookup.iter().cloned().collect::<Vec<_>>(),
        vec![RData::SRV(expected_srv)]
    );
}
/// An SRV response that also carries a CNAME for the target and address
/// records for the canonical name returns the SRV record plus both addresses.
#[test]
fn test_single_srv_query_response() {
    let lru = DnsLru::new(1, dns_lru::TtlConfig::default());
    let host = Name::from_str("www.example.com.").unwrap();
    let canonical = Name::from_str("actual.example.com.").unwrap();
    let v4 = RData::A(Ipv4Addr::new(127, 0, 0, 1));
    let v6 = RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));

    let mut message = srv_message().unwrap();
    message.add_answer(Record::from_rdata(
        host.clone(),
        86400,
        RData::CNAME(canonical.clone()),
    ));
    message.insert_additionals(vec![
        Record::from_rdata(canonical.clone(), 86400, v4.clone()),
        Record::from_rdata(canonical, 86400, v6.clone()),
    ]);

    let caching = CachingClient::with_cache(lru, mock(vec![error(), Ok(message)]), false);
    let query = Query::query(
        Name::from_str("_443._tcp.www.example.com.").unwrap(),
        RecordType::SRV,
    );
    let lookup = block_on(CachingClient::inner_lookup(
        query,
        Default::default(),
        caching,
        vec![],
    ))
    .expect("lookup failed");

    assert_eq!(
        lookup.iter().cloned().collect::<Vec<_>>(),
        vec![RData::SRV(SRV::new(1, 2, 443, host)), v4, v6]
    );
}
/// Feeds `handle_noerror` a CNAME (TTL `first`) plus the target's A record
/// (TTL `second`) and checks every non-CNAME record comes back with TTL 1.
fn cname_ttl_test(first: u32, second: u32) {
    let lru = DnsLru::new(1, dns_lru::TtlConfig::default());
    let mut client = CachingClient::with_cache(lru, mock(vec![error()]), false);

    let alias = Name::from_str("ttl.example.com.").unwrap();
    let canonical = Name::from_str("actual.example.com.").unwrap();

    let mut message = Message::new();
    message.insert_answers(vec![Record::from_rdata(
        alias.clone(),
        first,
        RData::CNAME(canonical.clone()),
    )]);
    message.insert_additionals(vec![Record::from_rdata(
        canonical,
        second,
        RData::A(Ipv4Addr::new(127, 0, 0, 1)),
    )]);

    let outcome = CachingClient::handle_noerror(
        &mut client,
        Default::default(),
        false,
        &Query::query(alias, RecordType::A),
        message.into(),
        vec![],
    );

    let records = match outcome {
        Ok(Records::Exists(records)) => records,
        Ok(_) => panic!("records don't exist"),
        Err(_) => panic!("error getting records"),
    };

    let answers = records
        .iter()
        .filter(|(record, _)| record.record_type() != RecordType::CNAME);
    for (_, ttl) in answers {
        assert_eq!(ttl, &1);
    }
}
/// The minimum of the CNAME TTL and the address TTL must win in either order.
#[test]
fn test_cname_ttl() {
cname_ttl_test(1, 2);
cname_ttl_test(2, 1);
}
/// Loopback names and reverse loopback addresses are answered locally for
/// A/AAAA/PTR and rejected for other record types.
#[test]
fn test_early_return_localhost() {
    let lru = DnsLru::new(0, dns_lru::TtlConfig::default());
    let mut client = CachingClient::with_cache(lru, mock(vec![empty()]), false);

    // (query, the single record the lookup must contain)
    let answered = vec![
        (
            Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::A),
            LOCALHOST_V4.clone(),
        ),
        (
            Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::AAAA),
            LOCALHOST_V6.clone(),
        ),
        (
            Query::query(Name::from(Ipv4Addr::new(127, 0, 0, 1)), RecordType::PTR),
            LOCALHOST.clone(),
        ),
        (
            Query::query(
                Name::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
                RecordType::PTR,
            ),
            LOCALHOST.clone(),
        ),
    ];
    for (query, expected) in answered {
        let lookup = block_on(client.lookup(query.clone(), Default::default()))
            .expect("should have returned localhost");
        assert_eq!(lookup.query(), &query);
        assert_eq!(lookup.iter().cloned().collect::<Vec<_>>(), vec![expected]);
    }

    // MX is not one of the locally answered types for any loopback name.
    let rejected = vec![
        Name::from_ascii("localhost.").unwrap(),
        Name::from(Ipv4Addr::new(127, 0, 0, 1)),
        Name::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
    ];
    for name in rejected {
        let outcome = block_on(client.lookup(
            Query::query(name, RecordType::MX),
            Default::default(),
        ));
        assert!(outcome.is_err());
    }
}
/// Names under `invalid.` fail without consulting upstream successfully.
#[test]
fn test_early_return_invalid() {
    let lru = DnsLru::new(0, dns_lru::TtlConfig::default());
    let mut client = CachingClient::with_cache(lru, mock(vec![empty()]), false);

    let query = Query::query(
        Name::from_ascii("horrible.invalid.").unwrap(),
        RecordType::A,
    );
    let outcome = block_on(client.lookup(query, Default::default()));
    assert!(outcome.is_err());
}
/// A `.local.` name is resolved through the normal path and succeeds when
/// upstream has an answer for it.
#[test]
fn test_no_error_on_dot_local_no_mdns() {
    let lru = DnsLru::new(1, dns_lru::TtlConfig::default());
    let local_name = Name::from_ascii("www.example.local.").unwrap();

    let mut message = srv_message().unwrap();
    message.add_query(Query::query(local_name.clone(), RecordType::A));
    message.add_answer(Record::from_rdata(
        Name::from_str("www.example.local.").unwrap(),
        86400,
        RData::A(Ipv4Addr::new(127, 0, 0, 1)),
    ));

    let mut client = CachingClient::with_cache(lru, mock(vec![error(), Ok(message)]), false);
    let outcome = block_on(client.lookup(
        Query::query(local_name, RecordType::A),
        Default::default(),
    ));
    assert!(outcome.is_ok());
}
}