use std::cmp::Ordering;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use futures_util::future::FutureExt;
use futures_util::stream::{once, FuturesUnordered, Stream, StreamExt};
use smallvec::SmallVec;
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
use crate::proto::runtime::TokioRuntimeProvider;
use crate::proto::runtime::{RuntimeProvider, Time};
use crate::proto::xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer};
use crate::proto::{ProtoError, ProtoErrorKind};
use tracing::debug;
use rand::thread_rng as rng;
use rand::Rng;
use crate::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts, ServerOrderingStrategy};
use crate::name_server::connection_provider::{ConnectionProvider, GenericConnector};
use crate::name_server::name_server::NameServer;
/// A set of name servers, split by transport, used to send DNS requests.
///
/// The server lists are held in `Arc`s, so cloning a pool shares the same `NameServer`
/// instances rather than copying them.
#[derive(Clone)]
pub struct NameServerPool<P: ConnectionProvider + Send + 'static> {
    /// Servers reached over datagram protocols (queried first by `send`).
    datagram_conns: Arc<[NameServer<P>]>,
    /// Servers reached over stream protocols (used as the fallback by `send`).
    stream_conns: Arc<[NameServer<P>]>,
    /// Resolver options cloned into every request sent through this pool.
    options: ResolverOpts,
}

/// A [`NameServerPool`] whose connections come from a [`GenericConnector`] over runtime `P`.
pub type GenericNameServerPool<P> = NameServerPool<GenericConnector<P>>;
#[cfg(all(test, feature = "tokio-runtime"))]
impl GenericNameServerPool<TokioRuntimeProvider> {
    /// Test helper: builds a pool from `config`, creating connections through a
    /// [`GenericConnector`] backed by the given Tokio runtime provider.
    pub(crate) fn tokio_from_config(
        config: &ResolverConfig,
        options: ResolverOpts,
        runtime: TokioRuntimeProvider,
    ) -> Self {
        let connector = GenericConnector::new(runtime);
        Self::from_config_with_provider(config, options, connector)
    }
}
impl<P> NameServerPool<P>
where
P: ConnectionProvider + 'static,
{
pub(crate) fn from_config_with_provider(
config: &ResolverConfig,
options: ResolverOpts,
conn_provider: P,
) -> Self {
let datagram_conns: Vec<NameServer<P>> = config
.name_servers()
.iter()
.filter(|ns_config| ns_config.protocol.is_datagram())
.map(|ns_config| {
#[cfg(feature = "dns-over-rustls")]
let ns_config = {
let mut ns_config = ns_config.clone();
ns_config.tls_config.clone_from(config.client_config());
ns_config
};
#[cfg(not(feature = "dns-over-rustls"))]
let ns_config = { ns_config.clone() };
NameServer::new(ns_config, options.clone(), conn_provider.clone())
})
.collect();
let stream_conns: Vec<NameServer<P>> = config
.name_servers()
.iter()
.filter(|ns_config| ns_config.protocol.is_stream())
.map(|ns_config| {
#[cfg(feature = "dns-over-rustls")]
let ns_config = {
let mut ns_config = ns_config.clone();
ns_config.tls_config.clone_from(config.client_config());
ns_config
};
#[cfg(not(feature = "dns-over-rustls"))]
let ns_config = { ns_config.clone() };
NameServer::new(ns_config, options.clone(), conn_provider.clone())
})
.collect();
Self {
datagram_conns: Arc::from(datagram_conns),
stream_conns: Arc::from(stream_conns),
options,
}
}
pub fn from_config(
name_servers: NameServerConfigGroup,
options: ResolverOpts,
conn_provider: P,
) -> Self {
let map_config_to_ns =
|ns_config| NameServer::new(ns_config, options.clone(), conn_provider.clone());
let (datagram, stream): (Vec<_>, Vec<_>) = name_servers
.into_inner()
.into_iter()
.partition(|ns| ns.protocol.is_datagram());
let datagram_conns: Vec<_> = datagram.into_iter().map(map_config_to_ns).collect();
let stream_conns: Vec<_> = stream.into_iter().map(map_config_to_ns).collect();
Self {
datagram_conns: Arc::from(datagram_conns),
stream_conns: Arc::from(stream_conns),
options,
}
}
#[doc(hidden)]
pub fn from_nameservers(
options: ResolverOpts,
datagram_conns: Vec<NameServer<P>>,
stream_conns: Vec<NameServer<P>>,
) -> Self {
Self {
datagram_conns: Arc::from(datagram_conns),
stream_conns: Arc::from(stream_conns),
options,
}
}
#[cfg(test)]
#[allow(dead_code)]
fn from_nameservers_test(
options: ResolverOpts,
datagram_conns: Arc<[NameServer<P>]>,
stream_conns: Arc<[NameServer<P>]>,
) -> Self {
Self {
datagram_conns,
stream_conns,
options,
}
}
async fn try_send(
opts: ResolverOpts,
conns: Arc<[NameServer<P>]>,
request: DnsRequest,
) -> Result<DnsResponse, ProtoError> {
let mut conns: Vec<NameServer<P>> = conns.to_vec();
match opts.server_ordering_strategy {
ServerOrderingStrategy::QueryStatistics => conns.sort_unstable(),
ServerOrderingStrategy::UserProvidedOrder => {}
}
let request_loop = request.clone();
parallel_conn_loop(conns, request_loop, opts).await
}
}
impl<P> DnsHandle for NameServerPool<P>
where
    P: ConnectionProvider + 'static,
{
    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;

    /// Sends `request` through the pool and yields a single-item stream with the outcome.
    ///
    /// Datagram servers are tried first. The stream servers are then tried when either:
    /// * the datagram response was truncated, or
    /// * the datagram attempt failed with an I/O error while `try_tcp_on_error` is set, or
    ///   failed with a "no connections" error.
    ///
    /// Every other datagram outcome is returned as-is. If the stream attempt also fails, a
    /// truncated datagram response is still returned. When both attempts produced errors, the
    /// more specific one (per `cmp_specificity`) is reported, preferring the stream error on ties.
    fn send<R: Into<DnsRequest>>(&self, request: R) -> Self::Response {
        let opts = self.options.clone();
        let request = request.into();
        let datagram_conns = Arc::clone(&self.datagram_conns);
        let stream_conns = Arc::clone(&self.stream_conns);
        // Kept for the possible stream retry; the datagram attempt consumes `request`.
        let tcp_message = request.clone();

        // Hook for locally-handled queries. Here the request is always wrapped as `NotMdns`,
        // so the early return is not taken.
        let mdns = Local::NotMdns(request);
        if mdns.is_local() {
            return mdns.take_stream();
        }
        let request = mdns.take_request();

        Box::pin(once(async move {
            debug!("sending request: {:?}", request.queries());
            let udp_res: Result<DnsResponse, ProtoError> =
                match Self::try_send(opts.clone(), datagram_conns, request).await {
                    Ok(response) if response.truncated() => {
                        debug!("truncated response received, retrying over TCP");
                        Ok(response)
                    }
                    Err(e) if (opts.try_tcp_on_error && e.is_io()) || e.is_no_connections() => {
                        debug!("error from UDP, retrying over TCP: {}", e);
                        Err(e)
                    }
                    result => return result,
                };

            if stream_conns.is_empty() {
                debug!("no TCP connections available");
                return udp_res;
            }

            let tcp_err = match Self::try_send(opts, stream_conns, tcp_message).await {
                Ok(response) => return Ok(response),
                Err(e) => e,
            };

            // The stream attempt failed; a truncated datagram answer still beats an error.
            let udp_err = match udp_res {
                Ok(response) => return Ok(response),
                Err(e) => e,
            };

            match udp_err.cmp_specificity(&tcp_err) {
                Ordering::Greater => Err(udp_err),
                _ => Err(tcp_err),
            }
        }))
    }
}
/// Sends `request` to the servers in `conns` and returns the first successful response.
///
/// Servers are taken in batches of `opts.num_concurrent_reqs` (at least 1, at most what is
/// left). A batch is taken either in order or, when `opts.shuffle_dns_servers` is set, picked
/// at random. The requests within a batch race each other.
///
/// A failed server is handled as follows:
/// * A `NoRecordsFound` error that is `trusted` or carries an `soa` or `ns` is returned at once.
/// * A busy error parks the server for a later retry.
/// * Any other error replaces the kept error if the kept one is still the initial
///   `NoConnections` or is less specific (`cmp_specificity`).
///
/// When `conns` runs out, parked servers are retried after a backoff. The backoff starts at
/// 20ms, doubles each round, and is only applied while it is below 300ms. When nothing is left
/// to try, the kept error is returned.
async fn parallel_conn_loop<P>(
    mut conns: Vec<NameServer<P>>,
    request: DnsRequest,
    opts: ResolverOpts,
) -> Result<DnsResponse, ProtoError>
where
    P: ConnectionProvider + 'static,
{
    let backoff_ceiling = Duration::from_millis(300);
    let mut best_err = ProtoError::from(ProtoErrorKind::NoConnections);
    let mut backoff = Duration::from_millis(20);
    let mut busy = SmallVec::<[NameServer<P>; 2]>::new();

    loop {
        let batch_request = request.clone();
        let batch_size = conns.len().min(opts.num_concurrent_reqs.max(1));
        let mut batch = SmallVec::<[NameServer<P>; 2]>::new();

        if opts.shuffle_dns_servers {
            // Random pick without replacement, so load does not always land on the first servers.
            while batch.len() < batch_size {
                let picked = rng().gen_range(0..conns.len());
                batch.push(conns.swap_remove(picked));
            }
        } else {
            batch.extend(conns.drain(..batch_size));
        }

        if batch.is_empty() {
            if busy.is_empty() || backoff >= backoff_ceiling {
                return Err(best_err);
            }
            // Every remaining server reported busy: wait, then give them another round.
            <<P as ConnectionProvider>::RuntimeProvider as RuntimeProvider>::Timer::delay_for(
                backoff,
            )
            .await;
            conns.extend(busy.drain(..));
            backoff *= 2;
            continue;
        }

        let mut in_flight: FuturesUnordered<_> = batch
            .into_iter()
            .map(move |server| {
                server
                    .send(batch_request.clone())
                    .first_answer()
                    .map(|outcome| outcome.map_err(|e| (server, e)))
            })
            .collect();

        while let Some(outcome) = in_flight.next().await {
            let (server, e) = match outcome {
                Ok(response) => return Ok(response),
                Err(failure) => failure,
            };

            let definitive_negative = matches!(
                e.kind(),
                ProtoErrorKind::NoRecordsFound { trusted, soa, ns, .. }
                    if *trusted || soa.is_some() || ns.is_some()
            );
            if definitive_negative {
                return Err(e);
            }

            if e.is_busy() {
                busy.push(server);
            } else if matches!(best_err.kind(), ProtoErrorKind::NoConnections)
                || best_err.cmp_specificity(&e) == Ordering::Less
            {
                best_err = e;
            }
        }
    }
}
/// Either an already-running stream for a query that is resolved locally, or a request that
/// still has to go to the name server pool.
///
/// Only `NotMdns` is constructed in the code visible here; `ResolveStream` appears to be a
/// hook for local (presumably mDNS) resolution — TODO confirm.
#[allow(clippy::large_enum_variant)]
pub(crate) enum Local {
    /// Response stream of a locally-handled query.
    #[allow(dead_code)]
    ResolveStream(Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>),
    /// A request that must be sent to the configured name servers.
    NotMdns(DnsRequest),
}
impl Local {
    /// Returns `true` if this holds a local resolution stream rather than a request.
    fn is_local(&self) -> bool {
        matches!(*self, Self::ResolveStream(..))
    }

    /// Extracts the local resolution stream.
    ///
    /// # Panics
    ///
    /// Panics if this is a `NotMdns` request; use `take_request()` for those.
    fn take_stream(self) -> Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>> {
        match self {
            Self::ResolveStream(stream) => stream,
            Self::NotMdns(..) => panic!("non Local queries have no future, see take_request()"),
        }
    }

    /// Extracts the request to be sent to the name servers.
    ///
    /// # Panics
    ///
    /// Panics if this is a `ResolveStream`; such values must be polled (see `take_stream()`).
    fn take_request(self) -> DnsRequest {
        match self {
            Self::NotMdns(request) => request,
            Self::ResolveStream(..) => panic!("Local queries must be polled, see take_stream()"),
        }
    }
}
impl Stream for Local {
    type Item = Result<DnsResponse, ProtoError>;

    /// Delegates polling to the wrapped local stream.
    ///
    /// # Panics
    ///
    /// Panics if polled on a `NotMdns` value, which holds a request rather than a stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if let Self::ResolveStream(inner) = self.get_mut() {
            return inner.as_mut().poll_next(cx);
        }
        panic!("Local queries that are not mDNS should not be polled")
    }
}
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::str::FromStr;

    use tokio::runtime::Runtime;

    use crate::proto::op::Query;
    use crate::proto::rr::{Name, RecordType};
    use crate::proto::xfer::{DnsHandle, DnsRequestOptions, Protocol};

    use super::*;
    use crate::config::NameServerConfig;
    use crate::name_server::connection_provider::TokioConnectionProvider;
    use crate::name_server::GenericNameServer;

    /// Builds a plain (no TLS, no HTTP endpoint) name server config for `socket_addr`.
    fn plain_ns_config(socket_addr: SocketAddr, protocol: Protocol) -> NameServerConfig {
        NameServerConfig {
            socket_addr,
            protocol,
            tls_dns_name: None,
            http_endpoint: None,
            trust_negative_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        }
    }

    #[ignore]
    #[test]
    #[allow(clippy::uninlined_format_args)]
    fn test_failed_then_success_pool() {
        let config1 = plain_ns_config(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)), 253),
            Protocol::Udp,
        );
        let config2 = plain_ns_config(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            Protocol::Udp,
        );
        let mut resolver_config = ResolverConfig::new();
        resolver_config.add_name_server(config1);
        resolver_config.add_name_server(config2);
        let io_loop = Runtime::new().unwrap();
        let pool = GenericNameServerPool::tokio_from_config(
            &resolver_config,
            ResolverOpts::default(),
            TokioRuntimeProvider::new(),
        );
        let name = Name::parse("www.example.com.", None).unwrap();
        for i in 0..2 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_err(),
                "iter: {}",
                i
            );
        }
        for i in 0..10 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_ok(),
                "iter: {}",
                i
            );
        }
    }

    /// Checks that the pool reuses the caller's `NameServer` instances across lookups.
    ///
    /// The addresses behind www.example.com change over time, and the answer section may start
    /// with a CNAME. So this only requires a record of the queried type somewhere in the
    /// answers, not a specific address at index 0.
    #[test]
    fn test_multi_use_conns() {
        let io_loop = Runtime::new().unwrap();
        let conn_provider = TokioConnectionProvider::default();
        let tcp = plain_ns_config(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            Protocol::Tcp,
        );
        let opts = ResolverOpts {
            try_tcp_on_error: true,
            ..ResolverOpts::default()
        };
        let name_server = GenericNameServer::new(tcp, opts.clone(), conn_provider);
        let name_servers: Arc<[_]> = Arc::from([name_server]);
        let pool = GenericNameServerPool::from_nameservers_test(
            opts,
            Arc::from([]),
            Arc::clone(&name_servers),
        );
        let name = Name::from_str("www.example.com.").unwrap();

        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");
        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.record_type() == RecordType::A),
            "no a record available"
        );
        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );

        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name, RecordType::AAAA),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");
        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.record_type() == RecordType::AAAA),
            "no aaaa record available"
        );
        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );
    }
}