use std::fmt::{self, Display};
use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::ops::DerefMut;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::{Buf, Bytes, BytesMut};
use futures_util::future::{FutureExt, TryFutureExt};
use futures_util::ready;
use futures_util::stream::Stream;
use h2::client::{Connection, SendRequest};
use http::header::{self, CONTENT_LENGTH};
use rustls::ClientConfig;
use tokio_rustls::{
client::TlsStream as TokioTlsClientStream, Connect as TokioTlsConnect, TlsConnector,
};
use tracing::{debug, warn};
use crate::error::ProtoError;
use crate::http::Version;
use crate::iocompat::AsyncIoStdAsTokio;
use crate::op::Message;
use crate::tcp::{Connect, DnsTcpStream};
use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
/// ALPN protocol identifier for HTTP/2, advertised by default when the caller's rustls
/// configuration does not list any ALPN protocols.
const ALPN_H2: &[u8] = &[b'h', b'2'];
/// A DNS-over-HTTPS client stream that sends each DNS request as its own HTTP/2 request.
///
/// Cloning copies the `SendRequest` handle and the shutdown flag; clones then track
/// shutdown independently.
#[derive(Clone)]
#[must_use = "futures do nothing unless polled"]
pub struct HttpsClientStream {
    /// Socket address of the remote DoH server.
    name_server: SocketAddr,
    /// Server name used for TLS and passed to the HTTP request builder.
    name_server_name: Arc<str>,
    /// Handle used to open new HTTP/2 streams on the established connection.
    h2: SendRequest<Bytes>,
    /// Set by `shutdown`; once true, sending further messages panics.
    is_shutdown: bool,
}
impl Display for HttpsClientStream {
    /// Renders the stream as `HTTPS(<socket address>,<server name>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            name_server,
            name_server_name,
            ..
        } = self;
        write!(f, "HTTPS({name_server},{name_server_name})")
    }
}
impl HttpsClientStream {
async fn inner_send(
h2: SendRequest<Bytes>,
message: Bytes,
name_server_name: Arc<str>,
) -> Result<DnsResponse, ProtoError> {
let mut h2 = match h2.ready().await {
Ok(h2) => h2,
Err(err) => {
return Err(ProtoError::from(format!("h2 send_request error: {err}")));
}
};
let request =
crate::http::request::new(Version::Http2, &name_server_name, message.remaining());
let request =
request.map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
debug!("request: {:#?}", request);
let (response_future, mut send_stream) = h2
.send_request(request, false)
.map_err(|err| ProtoError::from(format!("h2 send_request error: {err}")))?;
send_stream
.send_data(message, true)
.map_err(|e| ProtoError::from(format!("h2 send_data error: {e}")))?;
let mut response_stream = response_future
.await
.map_err(|err| ProtoError::from(format!("received a stream error: {err}")))?;
debug!("got response: {:#?}", response_stream);
let content_length = response_stream
.headers()
.get(CONTENT_LENGTH)
.map(|v| v.to_str())
.transpose()
.map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?
.map(usize::from_str)
.transpose()
.map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
let mut response_bytes =
BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4096));
while let Some(partial_bytes) = response_stream.body_mut().data().await {
let partial_bytes =
partial_bytes.map_err(|e| ProtoError::from(format!("bad http request: {e}")))?;
debug!("got bytes: {}", partial_bytes.len());
response_bytes.extend(partial_bytes);
if let Some(content_length) = content_length {
if response_bytes.len() >= content_length {
break;
}
}
}
if let Some(content_length) = content_length {
if response_bytes.len() != content_length {
return Err(ProtoError::from(format!(
"expected byte length: {}, got: {}",
content_length,
response_bytes.len()
)));
}
}
if !response_stream.status().is_success() {
let error_string = String::from_utf8_lossy(response_bytes.as_ref());
return Err(ProtoError::from(format!(
"http unsuccessful code: {}, message: {}",
response_stream.status(),
error_string
)));
} else {
{
let content_type = response_stream
.headers()
.get(header::CONTENT_TYPE)
.map(|h| {
h.to_str().map_err(|err| {
ProtoError::from(format!("ContentType header not a string: {err}"))
})
})
.unwrap_or(Ok(crate::http::MIME_APPLICATION_DNS))?;
if content_type != crate::http::MIME_APPLICATION_DNS {
return Err(ProtoError::from(format!(
"ContentType unsupported (must be '{}'): '{}'",
crate::http::MIME_APPLICATION_DNS,
content_type
)));
}
}
};
let message = Message::from_vec(&response_bytes)?;
Ok(DnsResponse::new(message, response_bytes.to_vec()))
}
}
impl DnsRequestSender for HttpsClientStream {
    /// Serializes `message` with its ID forced to 0 and sends it as a new HTTP/2 request.
    ///
    /// Serialization errors are returned through the resulting response stream.
    ///
    /// # Panics
    ///
    /// Panics if called after `shutdown`.
    fn send_message(&mut self, mut message: DnsRequest) -> DnsResponseStream {
        assert!(
            !self.is_shutdown,
            "can not send messages after stream is shutdown"
        );

        // RFC 8484 §4.1: DoH clients should use a DNS ID of 0 for HTTP cache friendliness.
        message.set_id(0);
        let wire = match message.to_vec() {
            Ok(wire) => wire,
            Err(err) => return err.into(),
        };

        let sender = self.h2.clone();
        let server_name = Arc::clone(&self.name_server_name);
        Box::pin(Self::inner_send(sender, Bytes::from(wire), server_name)).into()
    }

    /// Marks the stream as shut down; later `send_message` calls panic.
    fn shutdown(&mut self) {
        self.is_shutdown = true;
    }

    /// Whether `shutdown` has been called on this instance.
    fn is_shutdown(&self) -> bool {
        self.is_shutdown
    }
}
impl Stream for HttpsClientStream {
    type Item = Result<(), ProtoError>;

    /// Yields `Ok(())` when the HTTP/2 handle reports it is ready and an error if it failed.
    /// The stream ends once `shutdown` has been called.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.is_shutdown {
            return Poll::Ready(None);
        }
        self.h2.poll_ready(cx).map(|readiness| {
            Some(readiness.map_err(|e| ProtoError::from(format!("h2 stream errored: {e}",))))
        })
    }
}
/// Builder for DNS-over-HTTPS client connections.
#[derive(Clone)]
pub struct HttpsClientStreamBuilder {
    /// Optional local address the outgoing TCP socket is bound to.
    bind_addr: Option<SocketAddr>,
    /// rustls configuration used for the TLS handshake.
    client_config: Arc<ClientConfig>,
}
impl HttpsClientStreamBuilder {
pub fn with_client_config(client_config: Arc<ClientConfig>) -> Self {
Self {
client_config,
bind_addr: None,
}
}
pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
self.bind_addr = Some(bind_addr);
}
pub fn build<S: Connect>(
mut self,
name_server: SocketAddr,
dns_name: String,
) -> HttpsClientConnect<S> {
if self.client_config.alpn_protocols.is_empty() {
let mut client_config = (*self.client_config).clone();
client_config.alpn_protocols = vec![ALPN_H2.to_vec()];
self.client_config = Arc::new(client_config);
}
let tls = TlsConfig {
client_config: self.client_config,
dns_name: Arc::from(dns_name),
};
let connect = S::connect_with_bind(name_server, self.bind_addr);
HttpsClientConnect::<S>(HttpsClientConnectState::TcpConnecting {
connect,
name_server,
tls: Some(tls),
})
}
pub fn build_with_future<S, F>(
future: F,
mut client_config: Arc<ClientConfig>,
name_server: SocketAddr,
dns_name: String,
) -> HttpsClientConnect<S>
where
S: DnsTcpStream,
F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
{
if client_config.alpn_protocols.is_empty() {
let mut client_cfg = (*client_config).clone();
client_cfg.alpn_protocols = vec![ALPN_H2.to_vec()];
client_config = Arc::new(client_cfg);
}
let tls = TlsConfig {
client_config,
dns_name: Arc::from(dns_name),
};
HttpsClientConnect::<S>(HttpsClientConnectState::TcpConnecting {
connect: Box::pin(future),
name_server,
tls: Some(tls),
})
}
}
/// Future resolving to an [`HttpsClientStream`] once the TCP, TLS and HTTP/2 handshakes
/// have all completed.
pub struct HttpsClientConnect<S: DnsTcpStream>(HttpsClientConnectState<S>);
impl<S: DnsTcpStream> Future for HttpsClientConnect<S> {
    type Output = Result<HttpsClientStream, ProtoError>;

    /// Delegates to the inner connection state machine.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let state = &mut self.0;
        Pin::new(state).poll(cx)
    }
}
/// TLS parameters carried from the builder into the connection state machine.
struct TlsConfig {
    /// Name converted into the rustls server name for the handshake.
    dns_name: Arc<str>,
    /// rustls configuration used to create the TLS connector.
    client_config: Arc<ClientConfig>,
}
/// Boxed future produced by the `h2` client handshake over the TLS stream.
type H2HandshakeFuture<S> = Pin<
    Box<
        dyn Future<
                Output = Result<
                    (
                        SendRequest<Bytes>,
                        Connection<TokioTlsClientStream<AsyncIoStdAsTokio<S>>, Bytes>,
                    ),
                    h2::Error,
                >,
            > + Send,
    >,
>;

/// Stages of establishing a DNS-over-HTTPS connection.
#[allow(clippy::large_enum_variant)]
#[allow(clippy::type_complexity)]
enum HttpsClientConnectState<S>
where
    S: DnsTcpStream,
{
    /// Waiting for the TCP connection; `tls` is taken once it completes.
    TcpConnecting {
        connect: Pin<Box<dyn Future<Output = io::Result<S>> + Send>>,
        name_server: SocketAddr,
        tls: Option<TlsConfig>,
    },
    /// TCP is established; waiting for the TLS handshake.
    TlsConnecting {
        tls: TokioTlsConnect<AsyncIoStdAsTokio<S>>,
        name_server_name: Arc<str>,
        name_server: SocketAddr,
    },
    /// TLS is established; waiting for the HTTP/2 handshake.
    H2Handshake {
        handshake: H2HandshakeFuture<S>,
        name_server_name: Arc<str>,
        name_server: SocketAddr,
    },
    /// Handshakes finished; holds the stream until it is handed out once.
    Connected(Option<HttpsClientStream>),
    /// Setup failed; holds the error until it is handed out once.
    Errored(Option<ProtoError>),
}
/// Drives connection setup: TCP connect, then the TLS handshake, then the HTTP/2 handshake.
///
/// Each arm either returns early, or computes the next state, which replaces `self` before
/// the loop polls again. Early returns are `Poll::Pending` via `ready!`, or an error via `?`;
/// note that `?` errors are returned directly without moving to `Errored`. The terminal
/// `Connected` / `Errored` states hand out their payload exactly once; polling after that
/// panics with "cannot poll after complete".
impl<S> Future for HttpsClientConnectState<S>
where
    S: DnsTcpStream,
{
    type Output = Result<HttpsClientStream, ProtoError>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let next = match *self {
                Self::TcpConnecting {
                    ref mut connect,
                    name_server,
                    ref mut tls,
                } => {
                    let tcp = ready!(connect.poll_unpin(cx))?;
                    debug!("tcp connection established to: {}", name_server);
                    // `tls` is only taken here, after TCP succeeds, so it is always `Some`.
                    let tls = tls
                        .take()
                        .expect("programming error, tls should not be None here");
                    let name_server_name = Arc::clone(&tls.dns_name);
                    // Convert the configured name into the rustls server name; if the
                    // conversion fails, finish with an `Errored` state instead of connecting.
                    match tls.dns_name.as_ref().try_into() {
                        Ok(dns_name) => {
                            let tls = TlsConnector::from(tls.client_config);
                            let tls = tls.connect(dns_name, AsyncIoStdAsTokio(tcp));
                            Self::TlsConnecting {
                                name_server_name,
                                name_server,
                                tls,
                            }
                        }
                        Err(_) => Self::Errored(Some(ProtoError::from(format!(
                            "bad dns_name: {}",
                            &tls.dns_name
                        )))),
                    }
                }
                Self::TlsConnecting {
                    ref name_server_name,
                    name_server,
                    ref mut tls,
                } => {
                    let tls = ready!(tls.poll_unpin(cx))?;
                    debug!("tls connection established to: {}", name_server);
                    // Server push is disabled: this client only reads direct responses.
                    let mut handshake = h2::client::Builder::new();
                    handshake.enable_push(false);
                    let handshake = handshake.handshake(tls);
                    Self::H2Handshake {
                        name_server_name: Arc::clone(name_server_name),
                        name_server,
                        handshake: Box::pin(handshake),
                    }
                }
                Self::H2Handshake {
                    ref name_server_name,
                    name_server,
                    ref mut handshake,
                } => {
                    let (send_request, connection) = ready!(handshake
                        .poll_unpin(cx)
                        .map_err(|e| ProtoError::from(format!("h2 handshake error: {e}"))))?;
                    debug!("h2 connection established to: {}", name_server);
                    // The h2 `Connection` is driven as a detached background task; its
                    // failure is only logged, not propagated. NOTE: `tokio::spawn` requires
                    // an active Tokio runtime.
                    tokio::spawn(
                        connection
                            .map_err(|e| warn!("h2 connection failed: {e}"))
                            .map(|_: Result<(), ()>| ()),
                    );
                    Self::Connected(Some(HttpsClientStream {
                        name_server_name: Arc::clone(name_server_name),
                        name_server,
                        h2: send_request,
                        is_shutdown: false,
                    }))
                }
                // Terminal states: hand out the payload once.
                Self::Connected(ref mut conn) => {
                    return Poll::Ready(Ok(conn.take().expect("cannot poll after complete")))
                }
                Self::Errored(ref mut err) => {
                    return Poll::Ready(Err(err.take().expect("cannot poll after complete")))
                }
            };
            // Replace the current state and loop so the new state is polled immediately.
            *self.as_mut().deref_mut() = next;
        }
    }
}
/// A boxed future resolving to the result of a single DNS-over-HTTPS request.
pub struct HttpsClientResponse(
    Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>,
);
impl Future for HttpsClientResponse {
type Output = Result<DnsResponse, ProtoError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.as_mut().poll(cx).map_err(ProtoError::from)
}
}
#[cfg(any(feature = "webpki-roots", feature = "native-certs"))]
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::str::FromStr;

    use rustls::KeyLogFile;
    use tokio::net::TcpStream as TokioTcpStream;
    use tokio::runtime::Runtime;

    use crate::iocompat::AsyncIoTokioAsStd;
    use crate::op::{Message, Query, ResponseCode};
    use crate::rr::rdata::{A, AAAA};
    use crate::rr::{Name, RData, RecordType};
    use crate::xfer::{DnsRequestOptions, FirstAnswer};

    use super::*;

    /// Socket address of Google's public DoH endpoint.
    fn google_addr() -> SocketAddr {
        SocketAddr::from(([8, 8, 8, 8], 443))
    }

    /// Builds a request querying `www.example.com.` for `record_type`.
    fn example_com_request(record_type: RecordType) -> DnsRequest {
        let mut request = Message::new();
        let query = Query::query(Name::from_str("www.example.com.").unwrap(), record_type);
        request.add_query(query);
        DnsRequest::new(request, DnsRequestOptions::default())
    }

    /// Connects to Google's DoH endpoint, presenting `dns_name` as the TLS server name, and
    /// checks the A and AAAA answers for `www.example.com.`.
    fn check_google(dns_name: String) {
        let mut client_config = client_config_tls12();
        client_config.key_log = Arc::new(KeyLogFile::new());
        let https_builder = HttpsClientStreamBuilder::with_client_config(Arc::new(client_config));
        let connect =
            https_builder.build::<AsyncIoTokioAsStd<TokioTcpStream>>(google_addr(), dns_name);

        let runtime = Runtime::new().expect("could not start runtime");
        let mut https = runtime.block_on(connect).expect("https connect failed");

        let response = runtime
            .block_on(
                https
                    .send_message(example_com_request(RecordType::A))
                    .first_answer(),
            )
            .expect("send_message failed");
        let record = &response.answers()[0];
        let addr = record
            .data()
            .and_then(RData::as_a)
            .expect("Expected A record");
        assert_eq!(addr, &A::new(93, 184, 216, 34));

        // Upstream AAAA answers have been flaky: retry on SERVFAIL up to three times and stop
        // at the first real answer (all-SERVFAIL is tolerated, as before).
        let request = example_com_request(RecordType::AAAA);
        for _ in 0..3 {
            let response = runtime
                .block_on(https.send_message(request.clone()).first_answer())
                .expect("send_message failed");
            if response.response_code() == ResponseCode::ServFail {
                continue;
            }
            let record = &response.answers()[0];
            let addr = record
                .data()
                .and_then(RData::as_aaaa)
                .expect("invalid response, expected AAAA record");
            assert_eq!(
                addr,
                &AAAA::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
            );
            break;
        }
    }

    #[test]
    fn test_https_google() {
        check_google("dns.google".to_string());
    }

    #[test]
    fn test_https_google_with_pure_ip_address_server() {
        check_google(google_addr().ip().to_string());
    }

    #[test]
    #[ignore]
    fn test_https_cloudflare() {
        let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));
        let client_config = client_config_tls12();
        let https_builder = HttpsClientStreamBuilder::with_client_config(Arc::new(client_config));
        let connect = https_builder.build::<AsyncIoTokioAsStd<TokioTcpStream>>(
            cloudflare,
            "cloudflare-dns.com".to_string(),
        );

        let runtime = Runtime::new().expect("could not start runtime");
        let mut https = runtime.block_on(connect).expect("https connect failed");

        let response = runtime
            .block_on(
                https
                    .send_message(example_com_request(RecordType::A))
                    .first_answer(),
            )
            .expect("send_message failed");
        let record = &response.answers()[0];
        let addr = record
            .data()
            .and_then(RData::as_a)
            .expect("invalid response, expected A record");
        assert_eq!(addr, &A::new(93, 184, 216, 34));

        let response = runtime
            .block_on(
                https
                    .send_message(example_com_request(RecordType::AAAA))
                    .first_answer(),
            )
            .expect("send_message failed");
        let record = &response.answers()[0];
        let addr = record
            .data()
            .and_then(RData::as_aaaa)
            .expect("invalid response, expected AAAA record");
        assert_eq!(
            addr,
            &AAAA::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
        );
    }

    /// Builds a rustls client config that trusts the enabled root store(s) and advertises
    /// `h2` via ALPN.
    ///
    /// Despite the name, protocol versions come from rustls'
    /// `with_safe_default_protocol_versions`.
    fn client_config_tls12() -> ClientConfig {
        use rustls::RootCertStore;
        #[cfg_attr(
            not(any(feature = "native-certs", feature = "webpki-roots")),
            allow(unused_mut)
        )]
        let mut root_store = RootCertStore::empty();
        #[cfg(all(feature = "native-certs", not(feature = "webpki-roots")))]
        {
            let (added, ignored) = root_store
                .add_parsable_certificates(&rustls_native_certs::load_native_certs().unwrap());
            if ignored > 0 {
                warn!(
                    "failed to parse {} certificate(s) from the native root store",
                    ignored
                );
            }
            if added == 0 {
                panic!("no valid certificates found in the native root store");
            }
        }
        #[cfg(feature = "webpki-roots")]
        root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
            rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
                ta.subject,
                ta.spki,
                ta.name_constraints,
            )
        }));
        let mut client_config = ClientConfig::builder()
            .with_safe_default_cipher_suites()
            .with_safe_default_kx_groups()
            .with_safe_default_protocol_versions()
            .unwrap()
            .with_root_certificates(root_store)
            .with_no_client_auth();
        client_config.alpn_protocols = vec![ALPN_H2.to_vec()];
        client_config
    }
}