use std::error::Error;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
use futures_util::{future, future::Either, future::Future, FutureExt};
use proto::error::ProtoError;
use proto::op::Query;
use proto::rr::{Name, RData, Record, RecordType};
use proto::xfer::{DnsHandle, DnsRequestOptions};
use tracing::debug;
use crate::caching_client::CachingClient;
use crate::config::LookupIpStrategy;
use crate::dns_lru::MAX_TTL;
use crate::error::*;
use crate::hosts::Hosts;
use crate::lookup::{Lookup, LookupIntoIter, LookupIter};
/// Result of an IP lookup: a thin wrapper around a [`Lookup`] whose iterators yield only the
/// addresses carried by `A` and `AAAA` records, as [`IpAddr`] values.
#[derive(Debug, Clone)]
pub struct LookupIp(Lookup);
impl LookupIp {
pub fn iter(&self) -> LookupIpIter<'_> {
LookupIpIter(self.0.iter())
}
pub fn query(&self) -> &Query {
self.0.query()
}
pub fn valid_until(&self) -> Instant {
self.0.valid_until()
}
pub fn as_lookup(&self) -> &Lookup {
&self.0
}
}
impl From<Lookup> for LookupIp {
fn from(lookup: Lookup) -> Self {
Self(lookup)
}
}
impl From<LookupIp> for Lookup {
fn from(lookup: LookupIp) -> Self {
lookup.0
}
}
/// Borrowing iterator over the IP addresses in a [`LookupIp`]. Records other than `A` and
/// `AAAA` are skipped.
pub struct LookupIpIter<'i>(pub(crate) LookupIter<'i>);
impl<'i> Iterator for LookupIpIter<'i> {
    type Item = IpAddr;

    /// Advances to the next `A` or `AAAA` record and returns its address. Every other record
    /// type is skipped.
    fn next(&mut self) -> Option<Self::Item> {
        // `find_map` is the idiomatic form of `filter_map(..).next()`. It pulls records from the
        // wrapped iterator until one converts to an address, or the records run out.
        self.0.find_map(|rdata| match *rdata {
            RData::A(ip) => Some(IpAddr::from(Ipv4Addr::from(ip))),
            RData::AAAA(ip) => Some(IpAddr::from(Ipv6Addr::from(ip))),
            _ => None,
        })
    }
}
impl IntoIterator for LookupIp {
type Item = IpAddr;
type IntoIter = LookupIpIntoIter;
fn into_iter(self) -> Self::IntoIter {
LookupIpIntoIter(self.0.into_iter())
}
}
/// Owning iterator over the IP addresses of a consumed [`LookupIp`]. Records other than `A`
/// and `AAAA` are skipped.
pub struct LookupIpIntoIter(LookupIntoIter);
impl Iterator for LookupIpIntoIter {
    type Item = IpAddr;

    /// Advances to the next owned `A` or `AAAA` record and returns its address. Every other
    /// record type is skipped.
    fn next(&mut self) -> Option<Self::Item> {
        // `find_map` replaces the `filter_map(..).next()` chain and the needless `&mut`
        // rebinding of the inner iterator.
        self.0.find_map(|rdata| match rdata {
            RData::A(ip) => Some(IpAddr::from(Ipv4Addr::from(ip))),
            RData::AAAA(ip) => Some(IpAddr::from(Ipv6Addr::from(ip))),
            _ => None,
        })
    }
}
/// Future that looks up IP addresses for a list of candidate names. It resolves to the first
/// name whose lookup is non-empty; if none is, it falls back to an optional static record.
pub struct LookupIpFuture<C, E>
where
    C: DnsHandle<Error = E> + 'static,
    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    /// Client that is cloned into every per-name lookup.
    client_cache: CachingClient<C, E>,
    /// Candidate names still to be tried. They are consumed from the back with `Vec::pop`.
    names: Vec<Name>,
    /// Selects which record types are queried for each name, and in what order.
    strategy: LookupIpStrategy,
    /// Request options forwarded to every lookup.
    options: DnsRequestOptions,
    /// The lookup currently being polled.
    query: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>>,
    /// Optional static hosts table, consulted before any DNS query is sent.
    hosts: Option<Arc<Hosts>>,
    /// Record returned as a last resort once every name has failed or come back empty.
    finally_ip_addr: Option<RData>,
}
impl<C, E> Future for LookupIpFuture<C, E>
where
    C: DnsHandle<Error = E> + 'static,
    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    type Output = Result<LookupIp, ResolveError>;

    /// Drives the in-flight lookup. If it errors or comes back empty, the future moves on to the
    /// next candidate name. Once the names are exhausted, it returns `finally_ip_addr` if one was
    /// supplied; otherwise it returns the last result (an empty lookup or the last error).
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            // Poll whichever lookup is currently in flight.
            let query = self.query.as_mut().poll(cx);
            // An error or an empty answer is a miss worth retrying; any other answer is final.
            let should_retry = match query {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Ok(ref lookup)) => lookup.is_empty(),
                Poll::Ready(Err(_)) => true,
            };
            if should_retry {
                // `pop` takes from the back of the vector, so candidates are tried last-to-first.
                if let Some(name) = self.names.pop() {
                    self.query = strategic_lookup(
                        name,
                        self.strategy,
                        self.client_cache.clone(),
                        self.options,
                        self.hosts.clone(),
                    )
                    .boxed();
                    // Poll the replacement lookup within this same call so it can make progress
                    // (or register the waker) before we return.
                    continue;
                } else if let Some(ip_addr) = self.finally_ip_addr.take() {
                    // No names left: synthesize a lookup that contains only the fallback record.
                    let record = Record::from_rdata(Name::new(), MAX_TTL, ip_addr);
                    let lookup = Lookup::new_with_max_ttl(Query::new(), Arc::from([record]));
                    return Poll::Ready(Ok(lookup.into()));
                }
            };
            // Either the answer was usable, or every retry option is exhausted. In both cases the
            // current result is returned as-is.
            return query.map(|f| f.map(LookupIp::from));
        }
    }
}
impl<C, E> LookupIpFuture<C, E>
where
    C: DnsHandle<Error = E> + 'static,
    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    /// Builds a future that tries each of `names` in turn using `strategy`. Because names are
    /// popped from the back, the last element is tried first.
    ///
    /// The in-flight query starts out as an already-failed future, so the first poll moves
    /// straight on to the first name. If `names` is empty and `finally_ip_addr` is `None`, the
    /// future resolves to that "no names" error.
    pub fn lookup(
        names: Vec<Name>,
        strategy: LookupIpStrategy,
        client_cache: CachingClient<C, E>,
        options: DnsRequestOptions,
        hosts: Option<Arc<Hosts>>,
        finally_ip_addr: Option<RData>,
    ) -> Self {
        let no_names = ResolveErrorKind::Message("can not lookup IPs for no names");
        let query = future::err(ResolveError::from(no_names)).boxed();
        Self {
            client_cache,
            names,
            strategy,
            options,
            query,
            hosts,
            finally_ip_addr,
        }
    }
}
/// Looks up `name` using the record-type policy selected by `strategy`.
async fn strategic_lookup<C, E>(
    name: Name,
    strategy: LookupIpStrategy,
    client: CachingClient<C, E>,
    options: DnsRequestOptions,
    hosts: Option<Arc<Hosts>>,
) -> Result<Lookup, ResolveError>
where
    C: DnsHandle<Error = E> + 'static,
    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    // Each helper is a distinct `async fn` with its own opaque future type, so every arm
    // awaits its future in place rather than the match yielding a future.
    match strategy {
        LookupIpStrategy::Ipv4AndIpv6 => ipv4_and_ipv6(name, client, options, hosts).await,
        LookupIpStrategy::Ipv4thenIpv6 => ipv4_then_ipv6(name, client, options, hosts).await,
        LookupIpStrategy::Ipv6thenIpv4 => ipv6_then_ipv4(name, client, options, hosts).await,
        LookupIpStrategy::Ipv4Only => ipv4_only(name, client, options, hosts).await,
        LookupIpStrategy::Ipv6Only => ipv6_only(name, client, options, hosts).await,
    }
}
/// Answers `query` from the static hosts table when it has an entry. Otherwise the query is sent
/// through `client`.
async fn hosts_lookup<C, E>(
    query: Query,
    mut client: CachingClient<C, E>,
    options: DnsRequestOptions,
    hosts: Option<Arc<Hosts>>,
) -> Result<Lookup, ResolveError>
where
    C: DnsHandle<Error = E> + 'static,
    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    // A static hosts hit short-circuits the DNS query entirely.
    let static_hit = hosts.and_then(|table| table.lookup_static_host(&query));
    match static_hit {
        Some(lookup) => Ok(lookup),
        None => client.lookup(query, options).await,
    }
}
/// Looks up only `A` records for `name`.
async fn ipv4_only<C, E>(
    name: Name,
    client: CachingClient<C, E>,
    options: DnsRequestOptions,
    hosts: Option<Arc<Hosts>>,
) -> Result<Lookup, ResolveError>
where
    C: DnsHandle<Error = E> + 'static,
    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    let query = Query::query(name, RecordType::A);
    hosts_lookup(query, client, options, hosts).await
}
/// Looks up only `AAAA` records for `name`.
async fn ipv6_only<C, E>(
    name: Name,
    client: CachingClient<C, E>,
    options: DnsRequestOptions,
    hosts: Option<Arc<Hosts>>,
) -> Result<Lookup, ResolveError>
where
    C: DnsHandle<Error = E> + 'static,
    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    let query = Query::query(name, RecordType::AAAA);
    hosts_lookup(query, client, options, hosts).await
}
/// Queries `A` and `AAAA` for `name` concurrently and merges whatever succeeds.
///
/// The result that finishes first comes first in the merged lookup. If exactly one query fails,
/// its error is logged and the other result is returned. If both fail, the error of the query
/// that finished first is returned.
async fn ipv4_and_ipv6<C, E>(
    name: Name,
    client: CachingClient<C, E>,
    options: DnsRequestOptions,
    hosts: Option<Arc<Hosts>>,
) -> Result<Lookup, ResolveError>
where
    C: DnsHandle<Error = E> + 'static,
    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    let a_lookup = hosts_lookup(
        Query::query(name.clone(), RecordType::A),
        client.clone(),
        options,
        hosts.clone(),
    )
    .boxed();
    let aaaa_lookup =
        hosts_lookup(Query::query(name, RecordType::AAAA), client, options, hosts).boxed();

    // Race the two queries. Keep whichever resolves first, then drive the other to completion.
    let (first, still_running) = match future::select(a_lookup, aaaa_lookup).await {
        Either::Left(done) | Either::Right(done) => done,
    };
    let second = still_running.await;

    match (first, second) {
        (Ok(first), Ok(second)) => Ok(first.append(second)),
        (Ok(ips), Err(e)) | (Err(e), Ok(ips)) => {
            debug!(
                "one of ipv4 or ipv6 lookup failed in ipv4_and_ipv6 strategy: {}",
                e
            );
            Ok(ips)
        }
        (Err(e1), Err(e2)) => {
            debug!(
                "both of ipv4 or ipv6 lookup failed in ipv4_and_ipv6 strategy e1: {}, e2: {}",
                e1, e2
            );
            Err(e1)
        }
    }
}
/// Prefers `AAAA` records for `name`. `A` is queried only when the `AAAA` lookup errors or
/// comes back empty.
async fn ipv6_then_ipv4<C, E>(
    name: Name,
    client: CachingClient<C, E>,
    options: DnsRequestOptions,
    hosts: Option<Arc<Hosts>>,
) -> Result<Lookup, ResolveError>
where
    C: DnsHandle<Error = E> + 'static,
    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    let (preferred, fallback) = (RecordType::AAAA, RecordType::A);
    rt_then_swap(name, client, preferred, fallback, options, hosts).await
}
/// Prefers `A` records for `name`. `AAAA` is queried only when the `A` lookup errors or comes
/// back empty.
async fn ipv4_then_ipv6<C, E>(
    name: Name,
    client: CachingClient<C, E>,
    options: DnsRequestOptions,
    hosts: Option<Arc<Hosts>>,
) -> Result<Lookup, ResolveError>
where
    C: DnsHandle<Error = E> + 'static,
    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    let (preferred, fallback) = (RecordType::A, RecordType::AAAA);
    rt_then_swap(name, client, preferred, fallback, options, hosts).await
}
/// Looks up `first_type` records for `name`. Falls back to a `second_type` lookup when the first
/// lookup errors or returns no records.
///
/// When the fallback runs, the first lookup's error (if any) is discarded and the caller gets
/// the fallback's result.
async fn rt_then_swap<C, E>(
    name: Name,
    client: CachingClient<C, E>,
    first_type: RecordType,
    second_type: RecordType,
    options: DnsRequestOptions,
    hosts: Option<Arc<Hosts>>,
) -> Result<Lookup, ResolveError>
where
    C: DnsHandle<Error = E> + 'static,
    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    // The first lookup consumes `client`, so keep a clone for the possible fallback.
    let or_client = client.clone();
    let first = hosts_lookup(
        Query::query(name.clone(), first_type),
        client,
        options,
        hosts.clone(),
    )
    .await;

    match first {
        Ok(ips) if !ips.is_empty() => Ok(ips),
        // An empty answer and an error are handled the same way, so a single fallback path
        // serves both. `name` is not used afterwards, so it is moved rather than cloned.
        _ => hosts_lookup(Query::query(name, second_type), or_client, options, hosts).await,
    }
}
/// Tests for the individual lookup strategies. A mock `DnsHandle` replays canned responses.
#[cfg(test)]
pub mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    use std::sync::{Arc, Mutex};
    use futures_executor::block_on;
    use futures_util::future;
    use proto::op::Message;
    use proto::rr::{Name, RData, Record};
    use proto::xfer::{DnsHandle, DnsRequest, DnsResponse};
    use futures_util::stream::{once, Stream};
    use super::*;
    use crate::error::ResolveError;

    /// A `DnsHandle` that ignores the request and replies with the next canned response.
    #[derive(Clone)]
    pub struct MockDnsHandle {
        // Canned responses, shared between clones. `send` pops from the END of the vector, so
        // the last entry answers the first request sent.
        messages: Arc<Mutex<Vec<Result<DnsResponse, ResolveError>>>>,
    }
    impl DnsHandle for MockDnsHandle {
        type Response =
            Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send + Unpin>>;
        type Error = ResolveError;
        // Returns a single-item stream holding the last queued response. Once the queue is
        // drained, it yields `empty()` instead.
        fn send<R: Into<DnsRequest>>(&self, _: R) -> Self::Response {
            Box::pin(once(future::ready(
                self.messages.lock().unwrap().pop().unwrap_or_else(empty),
            )))
        }
    }
    /// Response to an `A` query for the root name, carrying 127.0.0.1.
    pub fn v4_message() -> Result<DnsResponse, ResolveError> {
        let mut message = Message::new();
        message.add_query(Query::query(Name::root(), RecordType::A));
        message.insert_answers(vec![Record::from_rdata(
            Name::root(),
            86400,
            RData::A(Ipv4Addr::new(127, 0, 0, 1).into()),
        )]);
        let resp = DnsResponse::from_message(message).unwrap();
        assert!(resp.contains_answer());
        Ok(resp)
    }
    /// Response to an `AAAA` query for the root name, carrying ::1.
    pub fn v6_message() -> Result<DnsResponse, ResolveError> {
        let mut message = Message::new();
        message.add_query(Query::query(Name::root(), RecordType::AAAA));
        message.insert_answers(vec![Record::from_rdata(
            Name::root(),
            86400,
            RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()),
        )]);
        let resp = DnsResponse::from_message(message).unwrap();
        assert!(resp.contains_answer());
        Ok(resp)
    }
    /// A successful response built from an empty `Message`. It has no answers.
    pub fn empty() -> Result<DnsResponse, ResolveError> {
        Ok(DnsResponse::from_message(Message::new()).unwrap())
    }
    /// A failed response, used to exercise the error-tolerance paths.
    pub fn error() -> Result<DnsResponse, ResolveError> {
        Err(ResolveError::from("forced test failure"))
    }
    /// Builds a mock that replays `messages` back-to-front (the last element is served first).
    pub fn mock(messages: Vec<Result<DnsResponse, ResolveError>>) -> MockDnsHandle {
        MockDnsHandle {
            messages: Arc::new(Mutex::new(messages)),
        }
    }
    #[test]
    fn test_ipv4_only_strategy() {
        // A single A answer comes back as the only address.
        assert_eq!(
            block_on(ipv4_only(
                Name::root(),
                CachingClient::new(0, mock(vec![v4_message()]), false),
                DnsRequestOptions::default(),
                None,
            ))
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![Ipv4Addr::new(127, 0, 0, 1)]
        );
    }
    #[test]
    fn test_ipv6_only_strategy() {
        // A single AAAA answer comes back as the only address.
        assert_eq!(
            block_on(ipv6_only(
                Name::root(),
                CachingClient::new(0, mock(vec![v6_message()]), false),
                DnsRequestOptions::default(),
                None,
            ))
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)]
        );
    }
    #[test]
    fn test_ipv4_and_ipv6_strategy() {
        // Both families answer: the v4 and v6 addresses are merged, with v4 first.
        assert_eq!(
            block_on(ipv4_and_ipv6(
                Name::root(),
                CachingClient::new(0, mock(vec![v6_message(), v4_message()]), false),
                DnsRequestOptions::default(),
                None,
            ))
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![
                IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
            ]
        );
        // An empty response alongside the v4 one: only the v4 address is returned.
        assert_eq!(
            block_on(ipv4_and_ipv6(
                Name::root(),
                CachingClient::new(0, mock(vec![empty(), v4_message()]), false),
                DnsRequestOptions::default(),
                None,
            ))
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))]
        );
        // An error alongside the v4 response: the error is tolerated and v4 is returned.
        assert_eq!(
            block_on(ipv4_and_ipv6(
                Name::root(),
                CachingClient::new(0, mock(vec![error(), v4_message()]), false),
                DnsRequestOptions::default(),
                None,
            ))
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))]
        );
        // An empty response alongside the v6 one: only the v6 address is returned.
        assert_eq!(
            block_on(ipv4_and_ipv6(
                Name::root(),
                CachingClient::new(0, mock(vec![v6_message(), empty()]), false),
                DnsRequestOptions::default(),
                None,
            ))
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))]
        );
        // An error alongside the v6 response: the error is tolerated and v6 is returned.
        assert_eq!(
            block_on(ipv4_and_ipv6(
                Name::root(),
                CachingClient::new(0, mock(vec![v6_message(), error()]), false),
                DnsRequestOptions::default(),
                None,
            ))
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))]
        );
    }
    #[test]
    fn test_ipv6_then_ipv4_strategy() {
        // The preferred AAAA lookup answers, so its address is returned.
        assert_eq!(
            block_on(ipv6_then_ipv4(
                Name::root(),
                CachingClient::new(0, mock(vec![v6_message()]), false),
                DnsRequestOptions::default(),
                None,
            ))
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)]
        );
        // The first response popped is empty (presumably consumed by the AAAA query), so the
        // strategy falls back and returns the v4 address.
        assert_eq!(
            block_on(ipv6_then_ipv4(
                Name::root(),
                CachingClient::new(0, mock(vec![v4_message(), empty()]), false),
                DnsRequestOptions::default(),
                None,
            ))
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![Ipv4Addr::new(127, 0, 0, 1)]
        );
        // The first response popped is an error; the fallback still yields the v4 address.
        assert_eq!(
            block_on(ipv6_then_ipv4(
                Name::root(),
                CachingClient::new(0, mock(vec![v4_message(), error()]), false),
                DnsRequestOptions::default(),
                None,
            ))
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![Ipv4Addr::new(127, 0, 0, 1)]
        );
    }
    #[test]
    fn test_ipv4_then_ipv6_strategy() {
        // The preferred A lookup answers, so its address is returned.
        assert_eq!(
            block_on(ipv4_then_ipv6(
                Name::root(),
                CachingClient::new(0, mock(vec![v4_message()]), false),
                DnsRequestOptions::default(),
                None,
            ))
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![Ipv4Addr::new(127, 0, 0, 1)]
        );
        // The first response popped is empty (presumably consumed by the A query), so the
        // strategy falls back and returns the v6 address.
        assert_eq!(
            block_on(ipv4_then_ipv6(
                Name::root(),
                CachingClient::new(0, mock(vec![v6_message(), empty()]), false),
                DnsRequestOptions::default(),
                None,
            ))
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)]
        );
        // The first response popped is an error; the fallback still yields the v6 address.
        assert_eq!(
            block_on(ipv4_then_ipv6(
                Name::root(),
                CachingClient::new(0, mock(vec![v6_message(), error()]), false),
                DnsRequestOptions::default(),
                None,
            ))
            .unwrap()
            .iter()
            .map(|r| r.ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)]
        );
    }
}