use std::cmp::Ordering;
use std::fmt::{self, Debug, Formatter};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant;
use futures_util::lock::Mutex;
use futures_util::stream::{once, Stream};
use proto::{
error::ProtoError,
xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer},
};
use tracing::debug;
use crate::config::{NameServerConfig, ResolverOpts};
use crate::name_server::connection_provider::{ConnectionProvider, GenericConnector};
use crate::name_server::{NameServerState, NameServerStats};
/// A single upstream DNS server: its configuration plus a connection slot,
/// health state and performance statistics.
///
/// The connection slot, state and stats are held behind `Arc`s, so every clone
/// produced by the derived `Clone` observes and updates the same values.
#[derive(Clone)]
pub struct NameServer<P: ConnectionProvider> {
    /// Settings for this server (socket address, protocol, ...).
    config: NameServerConfig,
    /// Resolver options passed to the connection provider and consulted by
    /// `is_verifying_dnssec`.
    options: ResolverOpts,
    /// The connection currently held, or `None` when no connection is held.
    client: Arc<Mutex<Option<<P as ConnectionProvider>::Conn>>>,
    /// Connection health (failed / established) shared between clones.
    state: Arc<NameServerState>,
    /// Round-trip-time and failure statistics; these drive the `Ord` impl.
    stats: Arc<NameServerStats>,
    /// Factory used to open new connections to `config`.
    connection_provider: P,
}

/// A [`NameServer`] whose connections come from [`GenericConnector`] parameterised by `R`.
pub type GenericNameServer<R> = NameServer<GenericConnector<R>>;
impl<P> Debug for NameServer<P>
where
    P: ConnectionProvider + Send,
{
    /// Renders only the configuration and options.
    ///
    /// The connection, state and statistics are deliberately left out of the
    /// output.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            config, options, ..
        } = self;
        write!(f, "config: {:?}, options: {:?}", config, options)
    }
}
impl<P> NameServer<P>
where
    P: ConnectionProvider + Send,
{
    /// Creates a name server for `config` that does not hold a connection yet.
    ///
    /// A connection is opened through `connection_provider` the first time a
    /// request is sent.
    pub fn new(config: NameServerConfig, options: ResolverOpts, connection_provider: P) -> Self {
        Self {
            config,
            options,
            client: Arc::new(Mutex::new(None)),
            state: Arc::new(NameServerState::init(None)),
            stats: Arc::new(NameServerStats::default()),
            connection_provider,
        }
    }

    /// Creates a name server around an already-established connection `client`.
    ///
    /// `connection_provider` is still retained so the server can reconnect if
    /// the supplied connection later fails.
    #[doc(hidden)]
    pub fn from_conn(
        config: NameServerConfig,
        options: ResolverOpts,
        client: P::Conn,
        connection_provider: P,
    ) -> Self {
        Self {
            config,
            options,
            client: Arc::new(Mutex::new(Some(client))),
            state: Arc::new(NameServerState::init(None)),
            stats: Arc::new(NameServerStats::default()),
            connection_provider,
        }
    }

    /// Test helper: `false` when the server is marked failed or holds no
    /// connection.
    ///
    /// If the client lock is currently held elsewhere, a connection is assumed
    /// to exist.
    #[cfg(test)]
    #[allow(dead_code)]
    pub(crate) fn is_connected(&self) -> bool {
        !self.state.is_failed()
            && if let Some(client) = self.client.try_lock() {
                client.is_some()
            } else {
                true
            }
    }

    /// Returns a clone of the held connection.
    ///
    /// Reconnects first when the server is marked failed or holds no
    /// connection. The client lock is held across the reconnect, so concurrent
    /// callers wait for it instead of racing.
    ///
    /// # Errors
    /// Propagates any error from creating or awaiting the new connection. The
    /// stale connection has already been discarded at that point, so a later
    /// call will attempt to connect again rather than reuse it.
    async fn connected_mut_client(&self) -> Result<P::Conn, ProtoError> {
        let mut client = self.client.lock().await;

        if !self.state.is_failed() {
            if let Some(conn) = &*client {
                debug!("existing connection: {:?}", self.config);
                return Ok(conn.clone());
            }
        }

        debug!("reconnecting: {:?}", self.config);
        // Discard the old (failed) connection before `reinit` clears the failed
        // flag. Otherwise an unsuccessful reconnect below would leave it in the
        // slot, and the next call would hand it out as if it were healthy.
        *client = None;
        self.state.reinit(None);

        let new_client = Box::pin(
            self.connection_provider
                .new_connection(&self.config, &self.options)?,
        )
        .await?;

        *client = Some(new_client.clone());
        Ok(new_client)
    }

    /// Sends `request` to this server and returns the first answer.
    ///
    /// On a received answer, the round-trip time is recorded before the answer
    /// is checked with `ProtoError::from_response`. If that check yields an
    /// error, it is returned and the EDNS state is not updated. Otherwise the
    /// server's EDNS options are stored in `state`.
    ///
    /// Failing to obtain a connection, or failing to get an answer over it,
    /// marks the server failed and counts a connection failure in `stats`.
    async fn inner_send<R: Into<DnsRequest> + Unpin + Send + 'static>(
        self,
        request: R,
    ) -> Result<DnsResponse, ProtoError> {
        let client = match self.connected_mut_client().await {
            Ok(client) => client,
            Err(error) => {
                // Connection setup failures must count against this server too;
                // otherwise an unreachable server is never penalised in ordering.
                debug!("name_server connection failure: {}", error);
                self.record_failure();
                return Err(error);
            }
        };

        let now = Instant::now();
        let response = client.send(request).first_answer().await;
        let rtt = now.elapsed();

        match response {
            Ok(response) => {
                self.stats.record_rtt(rtt);
                let response =
                    ProtoError::from_response(response, self.config.trust_negative_responses)?;
                let remote_edns = response.extensions().clone();
                self.state.establish(remote_edns);
                Ok(response)
            }
            Err(error) => {
                // Transport-level failure, not a lookup failure.
                debug!("name_server connection failure: {}", error);
                self.record_failure();
                Err(error)
            }
        }
    }

    /// Marks this server as failed as of now and records a connection failure.
    fn record_failure(&self) {
        self.state.fail(Instant::now());
        self.stats.record_connection_failure();
    }

    /// Returns the configured `trust_negative_responses` flag.
    pub fn trust_nx_responses(&self) -> bool {
        self.config.trust_negative_responses
    }
}
impl<P> DnsHandle for NameServer<P>
where
P: ConnectionProvider + Clone,
{
type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;
fn is_verifying_dnssec(&self) -> bool {
self.options.validate
}
fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
let this = self.clone();
Box::pin(once(this.inner_send(request)))
}
}
impl<P> Ord for NameServer<P>
where
    P: ConnectionProvider + Send,
{
    /// Servers that are `==` (same `config`) compare `Equal`.
    ///
    /// All others are ordered by their `stats`.
    ///
    /// NOTE(review): `==` looks only at `config`, but this falls back to
    /// `stats`. If two servers with different configs have stats that compare
    /// `Equal`, `cmp` returns `Equal` while `==` is `false`. That breaks the
    /// consistency `Ord` requires with `PartialEq`. Consider adding a
    /// tie-breaker — confirm against `NameServerStats`'s ordering.
    fn cmp(&self, other: &Self) -> Ordering {
        if self.eq(other) {
            Ordering::Equal
        } else {
            self.stats.cmp(&other.stats)
        }
    }
}
impl<P> PartialOrd for NameServer<P>
where
    P: ConnectionProvider + Send,
{
    /// Delegates to the total order defined by `Ord::cmp`; never returns `None`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl<P> PartialEq for NameServer<P>
where
    P: ConnectionProvider + Send,
{
    /// Two name servers are equal exactly when their configurations are equal.
    ///
    /// Options, connection, state and stats are not compared.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.config, &other.config);
        lhs == rhs
    }
}

/// Equality is defined purely by `config`, as in the `PartialEq` impl above.
impl<P> Eq for NameServer<P> where P: ConnectionProvider + Send {}
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::time::Duration;

    use futures_util::{future, FutureExt};
    use tokio::runtime::Runtime;

    use proto::op::{Query, ResponseCode};
    use proto::rr::{Name, RecordType};
    use proto::xfer::{DnsHandle, DnsRequestOptions, FirstAnswer};

    use super::*;
    use crate::config::Protocol;
    use crate::name_server::TokioConnectionProvider;

    /// Builds a plain UDP `NameServerConfig` targeting `ip:port`.
    fn udp_config(ip: Ipv4Addr, port: u16) -> NameServerConfig {
        NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(ip), port),
            protocol: Protocol::Udp,
            tls_dns_name: None,
            trust_negative_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        }
    }

    /// Sends an `A` query for `www.example.com.` to a freshly built server on a
    /// new Tokio runtime and returns the first answer.
    fn lookup_example(
        config: NameServerConfig,
        options: ResolverOpts,
    ) -> Result<DnsResponse, ProtoError> {
        let io_loop = Runtime::new().unwrap();
        let name_server = future::lazy(|_| {
            GenericNameServer::new(config, options, TokioConnectionProvider::default())
        });
        let name = Name::parse("www.example.com.", None).unwrap();
        io_loop.block_on(name_server.then(|name_server| {
            name_server
                .lookup(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                )
                .first_answer()
        }))
    }

    #[test]
    fn test_name_server() {
        let config = udp_config(Ipv4Addr::new(8, 8, 8, 8), 53);
        let response = lookup_example(config, ResolverOpts::default()).expect("query failed");
        assert_eq!(response.response_code(), ResponseCode::NoError);
    }

    #[test]
    fn test_failed_name_server() {
        let options = ResolverOpts {
            timeout: Duration::from_millis(1),
            ..ResolverOpts::default()
        };
        let config = udp_config(Ipv4Addr::new(127, 0, 0, 252), 252);
        assert!(lookup_example(config, options).is_err());
    }
}