use std::{
cmp::min,
pin::Pin,
slice::Iter,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant},
};
use futures_util::{
future::{self, Future},
stream::Stream,
FutureExt,
};
use crate::{
caching_client::CachingClient,
dns_lru::MAX_TTL,
error::*,
lookup_ip::LookupIpIter,
name_server::{ConnectionProvider, NameServerPool},
proto::{
error::ProtoError,
op::Query,
rr::{
rdata::{self, A, AAAA, NS, PTR},
Name, RData, Record, RecordType,
},
xfer::{DnsRequest, DnsRequestOptions, DnsResponse},
DnsHandle, RetryDnsHandle,
},
};
#[cfg(feature = "dnssec")]
use proto::{rr::dnssec::Proven, DnssecDnsHandle};
/// Result of a DNS query: the originating query, the records obtained for it,
/// and the instant reported by [`Lookup::valid_until`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Lookup {
// The query this lookup was built for; exposed through `query()`.
query: Query,
// Records held by the lookup. Stored behind an `Arc`, so cloning a `Lookup`
// shares the slice instead of copying the records.
records: Arc<[Record]>,
// Deadline reported by `valid_until()`. The constructors set it either to
// "now + MAX_TTL" or to a caller-supplied instant.
valid_until: Instant,
}
impl Lookup {
pub fn from_rdata(query: Query, rdata: RData) -> Self {
let record = Record::from_rdata(query.name().clone(), MAX_TTL, rdata);
Self::new_with_max_ttl(query, Arc::from([record]))
}
pub fn new_with_max_ttl(query: Query, records: Arc<[Record]>) -> Self {
let valid_until = Instant::now() + Duration::from_secs(u64::from(MAX_TTL));
Self {
query,
records,
valid_until,
}
}
pub fn new_with_deadline(query: Query, records: Arc<[Record]>, valid_until: Instant) -> Self {
Self {
query,
records,
valid_until,
}
}
pub fn query(&self) -> &Query {
&self.query
}
pub fn iter(&self) -> LookupIter<'_> {
LookupIter(self.records.iter())
}
#[cfg(feature = "dnssec")]
pub fn dnssec_iter(&self) -> DnssecIter<'_> {
DnssecIter(self.dnssec_record_iter())
}
pub fn record_iter(&self) -> LookupRecordIter<'_> {
LookupRecordIter(self.records.iter())
}
#[cfg(feature = "dnssec")]
pub fn dnssec_record_iter(&self) -> DnssecLookupRecordIter<'_> {
DnssecLookupRecordIter(self.records.iter())
}
pub fn valid_until(&self) -> Instant {
self.valid_until
}
#[doc(hidden)]
pub fn is_empty(&self) -> bool {
self.records.is_empty()
}
pub(crate) fn len(&self) -> usize {
self.records.len()
}
pub fn records(&self) -> &[Record] {
self.records.as_ref()
}
pub(crate) fn append(&self, other: Self) -> Self {
let mut records = Vec::with_capacity(self.len() + other.len());
records.extend_from_slice(&self.records);
records.extend_from_slice(&other.records);
let valid_until = min(self.valid_until(), other.valid_until());
Self::new_with_deadline(self.query.clone(), Arc::from(records), valid_until)
}
}
/// Borrowing iterator over the data (`RData`) of each record in a [`Lookup`].
pub struct LookupIter<'a>(Iter<'a, Record>);

impl<'a> Iterator for LookupIter<'a> {
    type Item = &'a RData;

    /// Advances to the next record and yields its data.
    fn next(&mut self) -> Option<Self::Item> {
        let record = self.0.next()?;
        Some(Record::data(record))
    }
}
/// Borrowing iterator over record data, each item wrapped in `Proven`.
#[cfg(feature = "dnssec")]
pub struct DnssecIter<'a>(DnssecLookupRecordIter<'a>);

#[cfg(feature = "dnssec")]
impl<'a> Iterator for DnssecIter<'a> {
    type Item = Proven<&'a RData>;

    /// Takes the next `Proven` record and maps it to its data, keeping the `Proven` wrapper.
    fn next(&mut self) -> Option<Self::Item> {
        let proven_record = self.0.next()?;
        Some(proven_record.map(Record::data))
    }
}
/// Borrowing iterator over the full records of a [`Lookup`].
pub struct LookupRecordIter<'a>(Iter<'a, Record>);

impl<'a> Iterator for LookupRecordIter<'a> {
    type Item = &'a Record;

    /// Yields the next record from the underlying slice iterator.
    fn next(&mut self) -> Option<Self::Item> {
        let Self(records) = self;
        records.next()
    }
}
/// Borrowing iterator over the full records, each item converted into a `Proven`.
#[cfg(feature = "dnssec")]
pub struct DnssecLookupRecordIter<'a>(Iter<'a, Record>);

#[cfg(feature = "dnssec")]
impl<'a> Iterator for DnssecLookupRecordIter<'a> {
    type Item = Proven<&'a Record>;

    /// Yields the next record, converted with `Proven::from`.
    fn next(&mut self) -> Option<Self::Item> {
        let record = self.0.next()?;
        Some(Proven::from(record))
    }
}
impl IntoIterator for Lookup {
    type Item = RData;
    type IntoIter = LookupIntoIter;

    /// Consumes the lookup and returns an iterator yielding owned copies of
    /// each record's data, in record order.
    fn into_iter(self) -> Self::IntoIter {
        LookupIntoIter {
            // `self` is consumed here, so the shared slice is moved rather than
            // cloned: no reference-count increment/decrement pair is needed.
            records: self.records,
            index: 0,
        }
    }
}

/// Owning iterator over the data (`RData`) of the records of a [`Lookup`].
pub struct LookupIntoIter {
    // Shared record slice being iterated.
    records: Arc<[Record]>,
    // Position of the next record to yield.
    index: usize,
}

impl Iterator for LookupIntoIter {
    type Item = RData;

    /// Yields a clone of the next record's data, or `None` once every record
    /// has been visited.
    fn next(&mut self) -> Option<Self::Item> {
        let rdata = self.records.get(self.index).map(Record::data).cloned()?;
        // Advance only after a record was produced, so calling `next` again
        // after exhaustion leaves the cursor where it is.
        self.index += 1;
        Some(rdata)
    }

    /// Exact number of items left: every remaining record yields one item.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.records.len().saturating_sub(self.index);
        (remaining, Some(remaining))
    }
}
/// One of the DNS handle flavors used for lookups: a plain retrying handle,
/// or (with the `dnssec` feature) the same handle wrapped in a `DnssecDnsHandle`.
#[derive(Clone)]
#[doc(hidden)]
pub enum LookupEither<P: ConnectionProvider + Send> {
/// Retrying handle over the name server pool.
Retry(RetryDnsHandle<NameServerPool<P>>),
/// The retrying handle wrapped in a `DnssecDnsHandle`; only compiled when
/// the `dnssec` feature is enabled.
#[cfg(feature = "dnssec")]
#[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))]
Secure(DnssecDnsHandle<RetryDnsHandle<NameServerPool<P>>>),
}
impl<P: ConnectionProvider> DnsHandle for LookupEither<P> {
    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;

    /// Reports whether the wrapped handle verifies DNSSEC; the answer comes
    /// straight from whichever variant is active.
    fn is_verifying_dnssec(&self) -> bool {
        match self {
            Self::Retry(handle) => handle.is_verifying_dnssec(),
            #[cfg(feature = "dnssec")]
            Self::Secure(handle) => handle.is_verifying_dnssec(),
        }
    }

    /// Forwards `request` to the wrapped handle and returns its response stream.
    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
        match self {
            Self::Retry(handle) => handle.send(request),
            #[cfg(feature = "dnssec")]
            Self::Secure(handle) => handle.send(request),
        }
    }
}
/// Future that resolves a list of candidate names one at a time, moving on to
/// the next name when the current query fails or returns no records.
#[doc(hidden)]
pub struct LookupFuture<C>
where
C: DnsHandle + 'static,
{
// Client on which `lookup` is called for each query.
client_cache: CachingClient<C>,
// Names that have not been queried yet; candidates are popped from the end.
names: Vec<Name>,
// Record type used for every query issued by this future.
record_type: RecordType,
// Request options passed along with every query.
options: DnsRequestOptions,
// The query currently being polled.
query: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>>,
}
impl<C> LookupFuture<C>
where
    C: DnsHandle + 'static,
{
    /// Starts a lookup of `record_type` for the given candidate `names`.
    ///
    /// The last name in `names` is queried first; the rest are kept as
    /// fallbacks. When `names` is empty, the returned future resolves to a
    /// "can not lookup for no names" error.
    #[doc(hidden)]
    pub fn lookup(
        mut names: Vec<Name>,
        record_type: RecordType,
        options: DnsRequestOptions,
        mut client_cache: CachingClient<C>,
    ) -> Self {
        let query: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> =
            match names.pop() {
                Some(first) => client_cache
                    .lookup(Query::query(first, record_type), options)
                    .boxed(),
                None => {
                    let no_names = ResolveError::from(ResolveErrorKind::Message(
                        "can not lookup for no names",
                    ));
                    future::err(no_names).boxed()
                }
            };

        Self {
            client_cache,
            names,
            record_type,
            options,
            query,
        }
    }
}
impl<C> Future for LookupFuture<C>
where
    C: DnsHandle + 'static,
{
    type Output = Result<Lookup, ResolveError>;

    /// Drives the current query; on an error or an empty answer, switches to
    /// the next candidate name if one is left.
    ///
    /// Once the names are exhausted, the last outcome (an empty lookup or the
    /// last error) is returned as-is.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let outcome = match self.query.as_mut().poll_unpin(cx) {
                Poll::Ready(outcome) => outcome,
                Poll::Pending => return Poll::Pending,
            };

            // Only a successful, non-empty lookup ends the search early.
            let satisfied = matches!(&outcome, Ok(lookup) if !lookup.is_empty());
            if satisfied {
                return Poll::Ready(outcome);
            }

            match self.names.pop() {
                Some(name) => {
                    // Copy the plain fields out first so the query can be built
                    // without overlapping borrows of `self`.
                    let record_type = self.record_type;
                    let options = self.options;
                    // The replacement query is polled on the next loop iteration.
                    self.query = self
                        .client_cache
                        .lookup(Query::query(name, record_type), options);
                }
                // Nothing left to try: report whatever the last query produced.
                None => return Poll::Ready(outcome),
            }
        }
    }
}
/// Wrapper around a [`Lookup`] that exposes its SRV records.
#[derive(Debug, Clone)]
pub struct SrvLookup(Lookup);

impl SrvLookup {
    /// Returns an iterator over the SRV data held by the lookup.
    pub fn iter(&self) -> SrvLookupIter<'_> {
        let all_rdata = self.as_lookup().iter();
        SrvLookupIter(all_rdata)
    }

    /// Returns the query of the wrapped lookup.
    pub fn query(&self) -> &Query {
        self.as_lookup().query()
    }

    /// Returns a `LookupIpIter` built over the same record data.
    ///
    /// NOTE(review): presumably yields the IP addresses present among the
    /// records; confirm against `LookupIpIter`.
    pub fn ip_iter(&self) -> LookupIpIter<'_> {
        let all_rdata = self.as_lookup().iter();
        LookupIpIter(all_rdata)
    }

    /// Returns the wrapped lookup, giving access to every record it holds.
    pub fn as_lookup(&self) -> &Lookup {
        let Self(inner) = self;
        inner
    }
}

impl From<Lookup> for SrvLookup {
    /// Wraps `lookup` without modifying it.
    fn from(lookup: Lookup) -> Self {
        SrvLookup(lookup)
    }
}
/// Borrowing iterator over the SRV data of a lookup; other record data is skipped.
pub struct SrvLookupIter<'i>(LookupIter<'i>);

impl<'i> Iterator for SrvLookupIter<'i> {
    type Item = &'i rdata::SRV;

    /// Advances past non-SRV data and yields the next SRV entry, if any.
    fn next(&mut self) -> Option<Self::Item> {
        self.0.find_map(|rdata| {
            if let RData::SRV(srv) = rdata {
                Some(srv)
            } else {
                None
            }
        })
    }
}
impl IntoIterator for SrvLookup {
type Item = rdata::SRV;
type IntoIter = SrvLookupIntoIter;
fn into_iter(self) -> Self::IntoIter {
SrvLookupIntoIter(self.0.into_iter())
}
}
pub struct SrvLookupIntoIter(LookupIntoIter);
impl Iterator for SrvLookupIntoIter {
type Item = rdata::SRV;
fn next(&mut self) -> Option<Self::Item> {
let iter: &mut _ = &mut self.0;
iter.find_map(|rdata| match rdata {
RData::SRV(data) => Some(data),
_ => None,
})
}
}
// Generates a typed wrapper around `Lookup` together with its borrowing and
// owning iterators.
//
// Arguments:
//   $l  - name of the wrapper type
//   $i  - name of the borrowing iterator type
//   $ii - name of the owning iterator type
//   $r  - `RData` variant the iterators select
//   $t  - payload type carried by that variant
macro_rules! lookup_type {
    ($l:ident, $i:ident, $ii:ident, $r:path, $t:path) => {
        /// Typed wrapper around a [`Lookup`].
        #[derive(Debug, Clone)]
        pub struct $l(Lookup);

        impl $l {
            #[doc = stringify!(Returns an iterator over the records that match $r)]
            pub fn iter(&self) -> $i<'_> {
                $i(self.as_lookup().iter())
            }

            /// Returns the query of the wrapped lookup.
            pub fn query(&self) -> &Query {
                self.as_lookup().query()
            }

            /// Returns the deadline of the wrapped lookup.
            pub fn valid_until(&self) -> Instant {
                self.as_lookup().valid_until()
            }

            /// Returns the wrapped lookup, giving access to every record it holds.
            pub fn as_lookup(&self) -> &Lookup {
                &self.0
            }
        }

        impl From<Lookup> for $l {
            /// Wraps `lookup` without modifying it.
            fn from(lookup: Lookup) -> Self {
                Self(lookup)
            }
        }

        impl From<$l> for Lookup {
            /// Unwraps the typed lookup back into the plain [`Lookup`].
            fn from(typed: $l) -> Self {
                typed.0
            }
        }

        /// Borrowing iterator over the matching record data; other data is skipped.
        pub struct $i<'i>(LookupIter<'i>);

        impl<'i> Iterator for $i<'i> {
            type Item = &'i $t;

            /// Advances past non-matching data and yields the next match, if any.
            fn next(&mut self) -> Option<Self::Item> {
                self.0.find_map(|rdata| {
                    if let $r(data) = rdata {
                        Some(data)
                    } else {
                        None
                    }
                })
            }
        }

        impl IntoIterator for $l {
            type Item = $t;
            type IntoIter = $ii;

            /// Consumes the lookup and returns an iterator over owned matching data.
            fn into_iter(self) -> Self::IntoIter {
                let inner: Lookup = self.0;
                $ii(inner.into_iter())
            }
        }

        /// Owning iterator over the matching record data; other data is skipped.
        pub struct $ii(LookupIntoIter);

        impl Iterator for $ii {
            type Item = $t;

            /// Advances past non-matching data and yields the next match, if any.
            fn next(&mut self) -> Option<Self::Item> {
                self.0.find_map(|rdata| {
                    if let $r(data) = rdata {
                        Some(data)
                    } else {
                        None
                    }
                })
            }
        }
    };
}
// Typed lookup wrappers generated by `lookup_type!`. Each invocation produces
// the wrapper type, its borrowing iterator, and its owning iterator for one
// `RData` variant.

// PTR data (`RData::PTR`).
lookup_type!(
ReverseLookup,
ReverseLookupIter,
ReverseLookupIntoIter,
RData::PTR,
PTR
);
// A data (`RData::A`).
lookup_type!(Ipv4Lookup, Ipv4LookupIter, Ipv4LookupIntoIter, RData::A, A);
// AAAA data (`RData::AAAA`).
lookup_type!(
Ipv6Lookup,
Ipv6LookupIter,
Ipv6LookupIntoIter,
RData::AAAA,
AAAA
);
// MX data (`RData::MX`).
lookup_type!(
MxLookup,
MxLookupIter,
MxLookupIntoIter,
RData::MX,
rdata::MX
);
// TLSA data (`RData::TLSA`).
lookup_type!(
TlsaLookup,
TlsaLookupIter,
TlsaLookupIntoIter,
RData::TLSA,
rdata::TLSA
);
// TXT data (`RData::TXT`).
lookup_type!(
TxtLookup,
TxtLookupIter,
TxtLookupIntoIter,
RData::TXT,
rdata::TXT
);
// SOA data (`RData::SOA`).
lookup_type!(
SoaLookup,
SoaLookupIter,
SoaLookupIntoIter,
RData::SOA,
rdata::SOA
);
// NS data (`RData::NS`).
lookup_type!(NsLookup, NsLookupIter, NsLookupIntoIter, RData::NS, NS);
#[cfg(test)]
pub mod tests {
    use std::net::{IpAddr, Ipv4Addr};
    use std::str::FromStr;
    use std::sync::{Arc, Mutex};
    use futures_executor::block_on;
    use futures_util::future;
    use futures_util::stream::once;
    use hickory_proto::error::ProtoErrorKind;
    use proto::error::ProtoError;
    use proto::op::{Message, Query};
    use proto::rr::{Name, RData, Record, RecordType};
    use proto::xfer::{DnsRequest, DnsRequestOptions};
    use super::*;

    /// `DnsHandle` test double that replays canned responses.
    ///
    /// Responses are popped from the end of the list; once the list is
    /// exhausted, every request is answered with `empty()`.
    #[derive(Clone)]
    pub struct MockDnsHandle {
        messages: Arc<Mutex<Vec<Result<DnsResponse, ProtoError>>>>,
    }

    impl DnsHandle for MockDnsHandle {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;

        fn send<R: Into<DnsRequest>>(&self, _: R) -> Self::Response {
            // The request is ignored; only the queued responses matter.
            let canned = self.messages.lock().unwrap().pop().unwrap_or_else(empty);
            Box::pin(once(future::ready(canned).boxed()))
        }
    }

    /// Response answering an A query for the root name with 127.0.0.1.
    pub fn v4_message() -> Result<DnsResponse, ProtoError> {
        let answer = Record::from_rdata(Name::root(), 86400, RData::A(A::new(127, 0, 0, 1)));

        let mut message = Message::new();
        message.add_query(Query::query(Name::root(), RecordType::A));
        message.insert_answers(vec![answer]);

        let response = DnsResponse::from_message(message).unwrap();
        assert!(response.contains_answer());
        Ok(response)
    }

    /// Response built from a default, empty message.
    pub fn empty() -> Result<DnsResponse, ProtoError> {
        let blank = Message::new();
        Ok(DnsResponse::from_message(blank).unwrap())
    }

    /// Failed response carrying an I/O error of kind `Other`.
    pub fn error() -> Result<DnsResponse, ProtoError> {
        let io_error = std::io::Error::from(std::io::ErrorKind::Other);
        Err(ProtoError::from(io_error))
    }

    /// Builds a mock handle that replays `messages`, last element first.
    pub fn mock(messages: Vec<Result<DnsResponse, ProtoError>>) -> MockDnsHandle {
        let messages = Arc::new(Mutex::new(messages));
        MockDnsHandle { messages }
    }

    #[test]
    fn test_lookup() {
        let client = CachingClient::new(0, mock(vec![v4_message()]), false);
        let lookup = block_on(LookupFuture::lookup(
            vec![Name::root()],
            RecordType::A,
            DnsRequestOptions::default(),
            client,
        ))
        .unwrap();

        let addresses = lookup
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>();
        assert_eq!(addresses, vec![Ipv4Addr::new(127, 0, 0, 1)]);
    }

    #[test]
    fn test_lookup_slice() {
        let client = CachingClient::new(0, mock(vec![v4_message()]), false);
        let lookup = block_on(LookupFuture::lookup(
            vec![Name::root()],
            RecordType::A,
            DnsRequestOptions::default(),
            client,
        ))
        .unwrap();

        let first = Record::data(&lookup.records()[0]);
        assert_eq!(first.ip_addr().unwrap(), Ipv4Addr::new(127, 0, 0, 1));
    }

    #[test]
    fn test_lookup_into_iter() {
        let client = CachingClient::new(0, mock(vec![v4_message()]), false);
        let lookup = block_on(LookupFuture::lookup(
            vec![Name::root()],
            RecordType::A,
            DnsRequestOptions::default(),
            client,
        ))
        .unwrap();

        let addresses = lookup
            .into_iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>();
        assert_eq!(addresses, vec![Ipv4Addr::new(127, 0, 0, 1)]);
    }

    #[test]
    fn test_error() {
        let client = CachingClient::new(0, mock(vec![error()]), false);
        let outcome = block_on(LookupFuture::lookup(
            vec![Name::root()],
            RecordType::A,
            DnsRequestOptions::default(),
            client,
        ));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_empty_no_response() {
        let client = CachingClient::new(0, mock(vec![empty()]), false);
        let failure = block_on(LookupFuture::lookup(
            vec![Name::root()],
            RecordType::A,
            DnsRequestOptions::default(),
            client,
        ))
        .expect_err("this should have been a NoRecordsFound");

        let kind = failure
            .proto()
            .expect("it should have been a ProtoError")
            .kind();
        match kind {
            ProtoErrorKind::NoRecordsFound {
                query,
                negative_ttl,
                ..
            } => {
                assert_eq!(**query, Query::query(Name::root(), RecordType::A));
                assert_eq!(*negative_ttl, None);
            }
            _ => panic!("wrong error received"),
        }
    }

    #[test]
    fn test_lookup_into_iter_arc() {
        let www = Name::from_str("www.example.com.").unwrap();
        let records = [
            Record::from_rdata(www.clone(), 80, RData::A(A::new(127, 0, 0, 1))),
            Record::from_rdata(www, 80, RData::A(A::new(127, 0, 0, 2))),
        ];
        let mut lookup = LookupIntoIter {
            records: Arc::from(records),
            index: 0,
        };

        assert_eq!(lookup.next(), Some(RData::A(A::new(127, 0, 0, 1))));
        assert_eq!(lookup.next(), Some(RData::A(A::new(127, 0, 0, 2))));
        assert_eq!(lookup.next(), None);
    }

    #[test]
    #[cfg(feature = "dnssec")]
    fn test_dnssec_lookup() {
        use hickory_proto::rr::dnssec::Proof;

        let www = Name::from_str("www.example.com.").unwrap();

        let mut a1 = Record::from_rdata(www.clone(), 80, RData::A(A::new(127, 0, 0, 1)));
        a1.set_proof(Proof::Secure);

        let mut a2 = Record::from_rdata(www, 80, RData::A(A::new(127, 0, 0, 2)));
        a2.set_proof(Proof::Insecure);

        let lookup = Lookup::new_with_deadline(
            Query::default(),
            Arc::from([a1.clone(), a2.clone()]),
            Instant::now(),
        );
        let mut proven = lookup.dnssec_iter();

        assert_eq!(
            *proven.next().unwrap().require(Proof::Secure).unwrap(),
            *a1.data()
        );
        assert_eq!(
            *proven.next().unwrap().require(Proof::Insecure).unwrap(),
            *a2.data()
        );
        assert_eq!(proven.next(), None);
    }
}