// hickory_resolver/name_server/connection_provider.rs
use std::future::Future;
use std::io;
use std::marker::Unpin;
#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};
#[cfg(feature = "dns-over-tls")]
use crate::proto::runtime::iocompat::AsyncIoStdAsTokio;
use crate::proto::runtime::Spawn;
#[cfg(feature = "tokio-runtime")]
use crate::proto::runtime::TokioRuntimeProvider;
use futures_util::future::FutureExt;
use futures_util::ready;
use futures_util::stream::{Stream, StreamExt};
#[cfg(all(feature = "dns-over-native-tls", not(feature = "dns-over-rustls")))]
use tokio_native_tls::TlsStream as TokioTlsStream;
#[cfg(all(
feature = "dns-over-openssl",
not(feature = "dns-over-rustls"),
not(feature = "dns-over-native-tls")
))]
use tokio_openssl::SslStream as TokioTlsStream;
#[cfg(feature = "dns-over-rustls")]
use tokio_rustls::client::TlsStream as TokioTlsStream;
use crate::config::{NameServerConfig, ResolverOpts};
#[cfg(any(feature = "dns-over-h3", feature = "dns-over-https-rustls"))]
use crate::proto;
#[cfg(feature = "dns-over-https-rustls")]
use crate::proto::h2::{HttpsClientConnect, HttpsClientStream};
#[cfg(feature = "dns-over-h3")]
use crate::proto::h3::{H3ClientConnect, H3ClientStream};
#[cfg(feature = "dns-over-quic")]
use crate::proto::quic::{QuicClientConnect, QuicClientStream};
#[cfg(feature = "dns-over-tls")]
use crate::proto::runtime::iocompat::AsyncIoTokioAsStd;
#[cfg(feature = "tokio-runtime")]
#[allow(unused_imports)] use crate::proto::runtime::TokioTime;
use crate::proto::{
runtime::RuntimeProvider,
tcp::TcpClientStream,
udp::{UdpClientConnect, UdpClientStream},
xfer::{
DnsExchange, DnsExchangeConnect, DnsExchangeSend, DnsHandle, DnsMultiplexer,
DnsMultiplexerConnect, DnsRequest, DnsResponse, Protocol,
},
ProtoError,
};
/// Factory for name-server connections.
///
/// An implementor turns a [`NameServerConfig`] plus the resolver-wide [`ResolverOpts`] into a
/// future that resolves to a connected [`DnsHandle`].
pub trait ConnectionProvider: Clone + Send + Sync + Unpin + 'static {
    /// Handle produced once [`Self::FutureConn`] completes successfully.
    type Conn: DnsHandle + Clone + Send + Sync + 'static;
    /// Future that drives the connection to completion.
    type FutureConn: Future<Output = Result<Self::Conn, ProtoError>> + Send + 'static;
    /// Runtime abstraction associated with this provider.
    type RuntimeProvider: RuntimeProvider;

    /// Returns a future that connects to the name server described by `config`.
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] when the connection cannot be initiated with the given
    /// configuration.
    fn new_connection(
        &self,
        config: &NameServerConfig,
        options: &ResolverOpts,
    ) -> Result<Self::FutureConn, io::Error>;
}
/// TLS-wrapped DNS stream over the runtime's TCP type `S`.
///
/// Layers, outermost first: `TcpClientStream`, `AsyncIoTokioAsStd`, the TLS stream type
/// selected by the enabled TLS backend feature, and `AsyncIoStdAsTokio` around `S`.
#[cfg(feature = "dns-over-tls")]
type TlsClientStream<S> =
    TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;
/// Connection establishment in progress, one variant per supported transport.
///
/// Every variant wraps a [`DnsExchangeConnect`]; the variants differ only in the concrete
/// connect future, sender, and timer types. The UDP and TCP variants use the runtime's own
/// timer (`R::Timer`); every other variant requires the `tokio-runtime` feature and uses
/// [`TokioTime`].
#[allow(clippy::large_enum_variant, clippy::type_complexity)]
pub(crate) enum ConnectionConnect<R: RuntimeProvider> {
    /// Plain DNS over UDP.
    Udp(DnsExchangeConnect<UdpClientConnect<R>, UdpClientStream<R>, R::Timer>),
    /// Plain DNS over TCP, driven through a [`DnsMultiplexer`].
    Tcp(
        DnsExchangeConnect<
            DnsMultiplexerConnect<
                Pin<Box<dyn Future<Output = Result<TcpClientStream<R::Tcp>, ProtoError>> + Send>>,
                TcpClientStream<R::Tcp>,
            >,
            DnsMultiplexer<TcpClientStream<R::Tcp>>,
            R::Timer,
        >,
    ),
    /// DNS over TLS, driven through a [`DnsMultiplexer`].
    #[cfg(all(feature = "dns-over-tls", feature = "tokio-runtime"))]
    Tls(
        DnsExchangeConnect<
            DnsMultiplexerConnect<
                Pin<
                    Box<
                        dyn Future<Output = Result<TlsClientStream<R::Tcp>, ProtoError>>
                            + Send
                            + 'static,
                    >,
                >,
                TlsClientStream<R::Tcp>,
            >,
            DnsMultiplexer<TlsClientStream<R::Tcp>>,
            TokioTime,
        >,
    ),
    /// DNS over HTTPS (HTTP/2).
    #[cfg(all(feature = "dns-over-https-rustls", feature = "tokio-runtime"))]
    Https(DnsExchangeConnect<HttpsClientConnect<R::Tcp>, HttpsClientStream, TokioTime>),
    /// DNS over QUIC.
    #[cfg(all(feature = "dns-over-quic", feature = "tokio-runtime"))]
    Quic(DnsExchangeConnect<QuicClientConnect, QuicClientStream, TokioTime>),
    /// DNS over HTTP/3.
    #[cfg(all(feature = "dns-over-h3", feature = "tokio-runtime"))]
    H3(DnsExchangeConnect<H3ClientConnect, H3ClientStream, TokioTime>),
}
/// Future that finishes establishing a name-server connection.
///
/// When the inner exchange connects, its background task is handed to `spawner` and the
/// exchange handle is returned wrapped in a [`GenericConnection`].
#[must_use = "futures do nothing unless polled"]
pub struct ConnectionFuture<R: RuntimeProvider> {
    /// Transport-specific connection state being driven to completion.
    pub(crate) connect: ConnectionConnect<R>,
    /// Handle used to spawn the exchange's background task.
    pub(crate) spawner: R::Handle,
}

impl<R: RuntimeProvider> Future for ConnectionFuture<R> {
    type Output = Result<GenericConnection, ProtoError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Each variant holds a distinct exchange type, so the shared "await the connect,
        // spawn the background task, wrap the handle" sequence is repeated per arm.
        //
        // The `cfg` on every arm must match the `cfg` on the corresponding
        // `ConnectionConnect` variant. Otherwise, enabling a transport feature without
        // `tokio-runtime` would name a variant that was compiled out.
        Poll::Ready(Ok(match &mut self.connect {
            ConnectionConnect::Udp(conn) => {
                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
                self.spawner.spawn_bg(bg);
                GenericConnection(conn)
            }
            ConnectionConnect::Tcp(conn) => {
                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
                self.spawner.spawn_bg(bg);
                GenericConnection(conn)
            }
            #[cfg(all(feature = "dns-over-tls", feature = "tokio-runtime"))]
            ConnectionConnect::Tls(conn) => {
                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
                self.spawner.spawn_bg(bg);
                GenericConnection(conn)
            }
            #[cfg(all(feature = "dns-over-https-rustls", feature = "tokio-runtime"))]
            ConnectionConnect::Https(conn) => {
                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
                self.spawner.spawn_bg(bg);
                GenericConnection(conn)
            }
            #[cfg(all(feature = "dns-over-quic", feature = "tokio-runtime"))]
            ConnectionConnect::Quic(conn) => {
                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
                self.spawner.spawn_bg(bg);
                GenericConnection(conn)
            }
            #[cfg(all(feature = "dns-over-h3", feature = "tokio-runtime"))]
            ConnectionConnect::H3(conn) => {
                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
                self.spawner.spawn_bg(bg);
                GenericConnection(conn)
            }
        }))
    }
}
/// Connection handle produced by [`ConnectionFuture`].
///
/// A newtype over [`DnsExchange`]; cloning it clones the wrapped exchange.
#[derive(Clone)]
pub struct GenericConnection(DnsExchange);

impl DnsHandle for GenericConnection {
    type Response = ConnectionResponse;

    /// Sends `request` through the wrapped exchange and returns its pending response stream.
    fn send<R>(&self, request: R) -> Self::Response
    where
        R: Into<DnsRequest> + Unpin + Send + 'static,
    {
        let pending = self.0.send(request);
        ConnectionResponse(pending)
    }
}
/// [`GenericConnector`] specialised to the tokio runtime provider.
#[cfg(feature = "tokio-runtime")]
pub type TokioConnectionProvider = GenericConnector<TokioRuntimeProvider>;

/// [`ConnectionProvider`] that builds every connection through a [`RuntimeProvider`].
#[derive(Clone)]
pub struct GenericConnector<P: RuntimeProvider> {
    runtime_provider: P,
}

impl<P: RuntimeProvider> GenericConnector<P> {
    /// Creates a connector that uses `runtime_provider` for sockets and task spawning.
    pub fn new(runtime_provider: P) -> Self {
        GenericConnector { runtime_provider }
    }
}

impl<P: RuntimeProvider + Default> Default for GenericConnector<P> {
    /// Creates a connector around `P::default()`.
    fn default() -> Self {
        Self::new(P::default())
    }
}
impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
    type Conn = GenericConnection;
    type FutureConn = ConnectionFuture<P>;
    type RuntimeProvider = P;

    /// Starts connecting to the name server described by `config`.
    ///
    /// The transport is selected by `config.protocol`.
    ///
    /// Bind address handling:
    /// - TCP, TLS, HTTPS, QUIC and HTTP/3 honor a configured `config.bind_addr`.
    /// - For QUIC and HTTP/3 without one, the unspecified address of the server's address
    ///   family with port 0 is used.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidInput`] error when the protocol is not compiled
    /// in, or when the runtime provides no QUIC binder for QUIC/HTTP/3. Any error from
    /// binding the QUIC socket is also propagated.
    fn new_connection(
        &self,
        config: &NameServerConfig,
        options: &ResolverOpts,
    ) -> Result<Self::FutureConn, io::Error> {
        // The `cfg` on each arm mirrors the `cfg` on the `ConnectionConnect` variant it
        // constructs. A transport feature enabled without `tokio-runtime` therefore falls
        // through to the "unsupported protocol" error instead of naming a missing variant.
        let dns_connect = match (config.protocol, self.runtime_provider.quic_binder()) {
            (Protocol::Udp, _) => {
                let provider_handle = self.runtime_provider.clone();
                let stream = UdpClientStream::builder(config.socket_addr, provider_handle)
                    .with_timeout(Some(options.timeout))
                    .avoid_local_ports(options.avoid_local_udp_ports.clone())
                    .build();
                let exchange = DnsExchange::connect(stream);
                ConnectionConnect::Udp(exchange)
            }
            (Protocol::Tcp, _) => {
                // Honor the configured local bind address, as the QUIC and H3 arms do.
                let (future, handle) = TcpClientStream::new(
                    config.socket_addr,
                    config.bind_addr,
                    Some(options.timeout),
                    self.runtime_provider.clone(),
                );
                let dns_conn = DnsMultiplexer::with_timeout(future, handle, options.timeout, None);
                let exchange = DnsExchange::connect(dns_conn);
                ConnectionConnect::Tcp(exchange)
            }
            #[cfg(all(feature = "dns-over-tls", feature = "tokio-runtime"))]
            (Protocol::Tls, _) => {
                let socket_addr = config.socket_addr;
                let timeout = options.timeout;
                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
                let tcp_future =
                    self.runtime_provider
                        .connect_tcp(socket_addr, config.bind_addr, None);
                #[cfg(feature = "dns-over-rustls")]
                let client_config = config.tls_config.clone();
                #[cfg(feature = "dns-over-rustls")]
                let (stream, handle) = {
                    crate::tls::new_tls_stream_with_future(
                        tcp_future,
                        socket_addr,
                        tls_dns_name,
                        client_config,
                    )
                };
                #[cfg(not(feature = "dns-over-rustls"))]
                let (stream, handle) = {
                    crate::tls::new_tls_stream_with_future(
                        tcp_future,
                        socket_addr,
                        tls_dns_name,
                        self.runtime_provider.clone(),
                    )
                };
                let dns_conn = DnsMultiplexer::with_timeout(stream, handle, timeout, None);
                let exchange = DnsExchange::connect(dns_conn);
                ConnectionConnect::Tls(exchange)
            }
            #[cfg(all(feature = "dns-over-https-rustls", feature = "tokio-runtime"))]
            (Protocol::Https, _) => {
                let socket_addr = config.socket_addr;
                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
                let http_endpoint = config
                    .http_endpoint
                    .clone()
                    .unwrap_or_else(|| proto::http::DEFAULT_DNS_QUERY_PATH.to_owned());
                #[cfg(feature = "dns-over-rustls")]
                let client_config = config.tls_config.clone();
                let tcp_future =
                    self.runtime_provider
                        .connect_tcp(socket_addr, config.bind_addr, None);
                let exchange = crate::h2::new_https_stream_with_future(
                    tcp_future,
                    socket_addr,
                    tls_dns_name,
                    http_endpoint,
                    client_config,
                );
                ConnectionConnect::Https(exchange)
            }
            #[cfg(all(feature = "dns-over-quic", feature = "tokio-runtime"))]
            (Protocol::Quic, Some(binder)) => {
                let socket_addr = config.socket_addr;
                // Without an explicit bind address, bind to the unspecified address of the
                // same family as the server so the socket can reach it.
                let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
                    SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
                    SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
                });
                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
                #[cfg(feature = "dns-over-rustls")]
                let client_config = config.tls_config.clone();
                let socket = binder.bind_quic(bind_addr, socket_addr)?;
                let exchange = crate::quic::new_quic_stream_with_future(
                    socket,
                    socket_addr,
                    tls_dns_name,
                    client_config,
                );
                ConnectionConnect::Quic(exchange)
            }
            #[cfg(all(feature = "dns-over-h3", feature = "tokio-runtime"))]
            (Protocol::H3, Some(binder)) => {
                let socket_addr = config.socket_addr;
                let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
                    SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
                    SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
                });
                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
                let http_endpoint = config
                    .http_endpoint
                    .clone()
                    .unwrap_or_else(|| proto::http::DEFAULT_DNS_QUERY_PATH.to_owned());
                let client_config = config.tls_config.clone();
                let socket = binder.bind_quic(bind_addr, socket_addr)?;
                let exchange = crate::h3::new_h3_stream_with_future(
                    socket,
                    socket_addr,
                    tls_dns_name,
                    http_endpoint,
                    client_config,
                );
                ConnectionConnect::H3(exchange)
            }
            // Protocols not compiled in, or QUIC/H3 when the runtime has no QUIC binder.
            (protocol, _) => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("unsupported protocol: {protocol:?}"),
                ));
            }
        };

        Ok(ConnectionFuture::<P> {
            connect: dns_connect,
            spawner: self.runtime_provider.create_handle(),
        })
    }
}
/// Stream of responses for one request sent through a [`GenericConnection`].
#[must_use = "streams do nothing unless polled"]
pub struct ConnectionResponse(DnsExchangeSend);

impl Stream for ConnectionResponse {
    type Item = Result<DnsResponse, ProtoError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Pure delegation: the wrapped stream's poll result is returned as-is, so there is
        // no need to unwrap it with `ready!` and re-wrap it in `Poll::Ready`.
        self.0.poll_next_unpin(cx)
    }
}