use std::fmt::{self, Display};
use std::io;
use std::mem;
use std::net::SocketAddr;
use std::ops::DerefMut;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut};
use futures::{future, ready, Future, FutureExt, Stream, TryFutureExt};
use h2;
use h2::client::{Connection, SendRequest};
use http::{self, header};
use log::{debug, warn};
use rustls::ClientConfig;
use tokio;
use tokio::net::TcpStream as TokioTcpStream;
use tokio_rustls::{client::TlsStream as TokioTlsClientStream, Connect, TlsConnector};
use typed_headers::{ContentLength, HeaderMapExt};
use webpki::DNSNameRef;
use trust_dns_proto::error::ProtoError;
use trust_dns_proto::xfer::{DnsRequest, DnsRequestSender, DnsResponse, SerialMessage};
use trust_dns_proto::Time;
/// ALPN protocol identifier advertised in the TLS handshake to negotiate HTTP/2.
const ALPN_H2: &[u8] = b"h2";

/// A DNS-over-HTTPS request sender bound to one HTTP/2 connection.
///
/// Cloning copies the h2 `SendRequest` handle and shares the `Arc`'d server name.
#[derive(Clone)]
#[must_use = "futures do nothing unless polled"]
pub struct HttpsClientStream {
    /// h2 handle used to open one stream per DNS request.
    h2: SendRequest<Bytes>,
    /// Socket address of the DoH server.
    name_server: SocketAddr,
    /// Server DNS name, passed to `crate::request::new` when building each request.
    name_server_name: Arc<String>,
    /// Set by `DnsRequestSender::shutdown`; `send_message` panics once this is true.
    is_shutdown: bool,
}
impl Display for HttpsClientStream {
fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> {
write!(
formatter,
"HTTPS({},{})",
self.name_server, self.name_server_name
)
}
}
impl HttpsClientStream {
    /// Sends one serialized DNS message on a new h2 stream and awaits the DNS response.
    ///
    /// The request head comes from `crate::request::new(name_server_name, body_len)`, and the
    /// message bytes are sent as the (final) request body. The response body is collected,
    /// honouring `Content-Length` when present, and parsed as a DNS message.
    ///
    /// # Errors
    ///
    /// Returns a `ProtoError` if:
    /// - the h2 connection is not ready, or the request cannot be built or sent;
    /// - the response body cannot be read;
    /// - the body length differs from `Content-Length`;
    /// - the HTTP status is not a success;
    /// - `Content-Type` is not `crate::MIME_APPLICATION_DNS`;
    /// - the body is not a valid DNS message.
    async fn inner_send(
        h2: SendRequest<Bytes>,
        message: SerialMessage,
        name_server_name: Arc<String>,
        name_server: SocketAddr,
    ) -> Result<DnsResponse, ProtoError> {
        let mut h2 = h2
            .ready()
            .await
            .map_err(|err| ProtoError::from(format!("h2 send_request error: {}", err)))?;

        let bytes = BytesMut::from(message.bytes());
        let request = crate::request::new(&name_server_name, bytes.len())
            .map_err(|err| ProtoError::from(format!("bad http request: {}", err)))?;

        debug!("request: {:#?}", request);

        // `false`: the request body (the DNS message) still follows.
        let (response_future, mut send_stream) = h2
            .send_request(request, false)
            .map_err(|err| ProtoError::from(format!("h2 send_request error: {}", err)))?;
        send_stream
            .send_data(bytes.freeze(), true)
            .map_err(|e| ProtoError::from(format!("h2 send_data error: {}", e)))?;

        let mut response_stream = response_future
            .await
            .map_err(|err| ProtoError::from(format!("received a stream error: {}", err)))?;

        debug!("got response: {:#?}", response_stream);

        let content_length: Option<usize> = response_stream
            .headers()
            .typed_get()
            .map_err(|e| ProtoError::from(format!("bad headers received: {}", e)))?
            .map(|c: ContentLength| *c as usize);

        // Pre-size from Content-Length, clamped to [512, 4096] so a hostile header cannot
        // force a large allocation up front.
        let mut response_bytes =
            BytesMut::with_capacity(content_length.unwrap_or(512).max(512).min(4096));

        while let Some(partial_bytes) = response_stream.body_mut().data().await {
            let partial_bytes = partial_bytes
                .map_err(|e| ProtoError::from(format!("bad http response: {}", e)))?;

            debug!("got bytes: {}", partial_bytes.len());
            response_bytes.extend_from_slice(&partial_bytes);

            // Hand the consumed bytes back to the h2 flow-control window; otherwise the peer's
            // connection-level send window is never replenished for this data. Releasing
            // exactly what was just received cannot exceed the claimed capacity, so the
            // result is deliberately ignored.
            let _ = response_stream
                .body_mut()
                .flow_control()
                .release_capacity(partial_bytes.len());

            // Stop once the advertised body has arrived; excess data is rejected below.
            if content_length.map_or(false, |len| response_bytes.len() >= len) {
                break;
            }
        }

        if let Some(content_length) = content_length {
            if response_bytes.len() != content_length {
                return Err(ProtoError::from(format!(
                    "expected byte length: {}, got: {}",
                    content_length,
                    response_bytes.len()
                )));
            }
        }

        if !response_stream.status().is_success() {
            let error_string = String::from_utf8_lossy(response_bytes.as_ref());
            return Err(ProtoError::from(format!(
                "http unsuccessful code: {}, message: {}",
                response_stream.status(),
                error_string
            )));
        }

        // A missing Content-Type header is accepted as the DNS wire format.
        let content_type = response_stream
            .headers()
            .get(header::CONTENT_TYPE)
            .map(|h| {
                h.to_str().map_err(|err| {
                    ProtoError::from(format!("ContentType header not a string: {}", err))
                })
            })
            .unwrap_or(Ok(crate::MIME_APPLICATION_DNS))?;

        if content_type != crate::MIME_APPLICATION_DNS {
            return Err(ProtoError::from(format!(
                "ContentType unsupported (must be '{}'): '{}'",
                crate::MIME_APPLICATION_DNS,
                content_type
            )));
        }

        let message = SerialMessage::new(response_bytes.to_vec(), name_server).to_message()?;
        Ok(message.into())
    }
}
impl DnsRequestSender for HttpsClientStream {
type DnsResponseFuture = HttpsClientResponse;
fn send_message<TE: Time>(
&mut self,
mut message: DnsRequest,
_cx: &mut Context,
) -> Self::DnsResponseFuture {
if self.is_shutdown {
panic!("can not send messages after stream is shutdown")
}
message.set_id(0);
let bytes = match message.to_vec() {
Ok(bytes) => bytes,
Err(err) => return HttpsClientResponse(Box::pin(future::err(err))),
};
let message = SerialMessage::new(bytes, self.name_server);
HttpsClientResponse(Box::pin(Self::inner_send(
self.h2.clone(),
message,
Arc::clone(&self.name_server_name),
self.name_server,
)))
}
fn error_response<TE: Time>(error: ProtoError) -> Self::DnsResponseFuture {
HttpsClientResponse(Box::pin(future::err(error)))
}
fn shutdown(&mut self) {
self.is_shutdown = true;
}
fn is_shutdown(&self) -> bool {
self.is_shutdown
}
}
impl Stream for HttpsClientStream {
type Item = Result<(), ProtoError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
if self.is_shutdown {
return Poll::Ready(None);
}
match self.h2.poll_ready(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Some(Ok(()))),
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
"h2 stream errored: {}",
e
))))),
}
}
}
/// Builder for DoH connections, holding the TLS client configuration to use.
#[derive(Clone)]
pub struct HttpsClientStreamBuilder {
    /// TLS settings for every connection built; `build` asserts they advertise `h2` via ALPN.
    client_config: Arc<ClientConfig>,
}
impl HttpsClientStreamBuilder {
    /// Creates a builder using `ClientConfig::new()` with the `h2` ALPN protocol added.
    ///
    /// NOTE(review): no trust anchors are added here; callers needing specific roots should
    /// use `with_client_config`.
    pub fn new() -> HttpsClientStreamBuilder {
        let mut client_config = ClientConfig::new();
        client_config.alpn_protocols.push(ALPN_H2.to_vec());

        HttpsClientStreamBuilder {
            client_config: Arc::new(client_config),
        }
    }

    /// Creates a builder from a caller-supplied TLS configuration.
    ///
    /// The configuration must list `h2` in `alpn_protocols`; `build` asserts this.
    pub fn with_client_config(client_config: Arc<ClientConfig>) -> Self {
        HttpsClientStreamBuilder { client_config }
    }

    /// Returns a future that connects to `name_server` and performs the TLS and h2 handshakes.
    ///
    /// `dns_name` is the server name handed to the TLS connector. It must parse as a DNS
    /// name, otherwise the returned future resolves to an error. It is also passed to
    /// `crate::request::new` for every request.
    ///
    /// # Panics
    ///
    /// Panics if the client configuration does not include `h2` in `alpn_protocols`.
    pub fn build(self, name_server: SocketAddr, dns_name: String) -> HttpsClientConnect {
        // Compare slices directly instead of allocating `ALPN_H2.to_vec()` per entry.
        assert!(
            self.client_config
                .alpn_protocols
                .iter()
                .any(|protocol| protocol.as_slice() == ALPN_H2),
            "HttpsClientStreamBuilder requires the h2 ALPN protocol in ClientConfig::alpn_protocols"
        );

        let tls = TlsConfig {
            client_config: self.client_config,
            dns_name: Arc::new(dns_name),
        };

        HttpsClientConnect(HttpsClientConnectState::ConnectTcp {
            name_server,
            tls: Some(tls),
        })
    }
}
impl Default for HttpsClientStreamBuilder {
fn default() -> Self {
Self::new()
}
}
/// Future resolving to a connected `HttpsClientStream`.
///
/// The inner state machine performs the TCP connect, the TLS handshake and the h2 handshake.
pub struct HttpsClientConnect(HttpsClientConnectState);

impl Future for HttpsClientConnect {
    type Output = Result<HttpsClientStream, ProtoError>;

    /// Delegates to the inner connection state machine.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let state = &mut self.0;
        Pin::new(state).poll(cx)
    }
}
/// TLS parameters carried through the connection state machine until the TLS handshake.
struct TlsConfig {
    /// rustls configuration converted into a `TlsConnector`.
    client_config: Arc<ClientConfig>,
    /// Server name, parsed as a `DNSNameRef` for the TLS connect call.
    dns_name: Arc<String>,
}
/// Boxed future of the h2 client handshake.
///
/// It yields the request handle and the connection driver, which the state machine
/// spawns onto the runtime.
type H2HandshakeFuture = Pin<
    Box<
        dyn Future<
                Output = Result<
                    (
                        SendRequest<Bytes>,
                        Connection<TokioTlsClientStream<TokioTcpStream>, Bytes>,
                    ),
                    h2::Error,
                >,
            > + Send,
    >,
>;

/// States of the DoH connection process.
///
/// The process runs TCP connect, then the TLS handshake, then the h2 handshake, and ends
/// in a terminal success or error state.
// NOTE(review): variant sizes differ widely (in-flight TLS connect vs. small terminal
// states); the lint is silenced rather than boxing the large variants.
#[allow(clippy::large_enum_variant)]
enum HttpsClientConnectState {
    /// Initial state; the TCP connect has not been started.
    ConnectTcp {
        name_server: SocketAddr,
        tls: Option<TlsConfig>,
    },
    /// Waiting for the TCP connection to be established.
    TcpConnecting {
        connect: Pin<Box<dyn Future<Output = io::Result<TokioTcpStream>> + Send>>,
        name_server: SocketAddr,
        tls: Option<TlsConfig>,
    },
    /// Waiting for the TLS handshake to complete.
    TlsConnecting {
        tls: Connect<TokioTcpStream>,
        name_server_name: Arc<String>,
        name_server: SocketAddr,
    },
    /// Waiting for the HTTP/2 handshake to complete.
    H2Handshake {
        handshake: H2HandshakeFuture,
        name_server_name: Arc<String>,
        name_server: SocketAddr,
    },
    /// Connected; holds the stream until `poll` hands it out.
    Connected(Option<HttpsClientStream>),
    /// Failed; holds the error until `poll` hands it out.
    Errored(Option<ProtoError>),
}
impl Future for HttpsClientConnectState {
    type Output = Result<HttpsClientStream, ProtoError>;

    /// Advances the connection state machine as far as it can without blocking.
    ///
    /// Each loop iteration does one of two things:
    /// - returns `Pending`, an error, or a terminal `Ready`;
    /// - computes the next state and stores it in place before looping again.
    ///
    /// TCP and TLS I/O errors and h2 handshake errors are returned as `ProtoError`. Once
    /// the h2 handshake succeeds, the connection driver is spawned with `tokio::spawn`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Ready` from the `Connected` or `Errored`
    /// state.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        loop {
            let next = match *self {
                HttpsClientConnectState::ConnectTcp {
                    name_server,
                    ref mut tls,
                } => {
                    debug!("tcp connecting to: {}", name_server);
                    let connect = Box::pin(TokioTcpStream::connect(name_server));
                    HttpsClientConnectState::TcpConnecting {
                        connect,
                        name_server,
                        tls: tls.take(),
                    }
                }
                HttpsClientConnectState::TcpConnecting {
                    ref mut connect,
                    name_server,
                    ref mut tls,
                } => {
                    let tcp = ready!(connect.poll_unpin(cx))?;
                    debug!("tcp connection established to: {}", name_server);
                    let tls = tls
                        .take()
                        .expect("programming error, tls should not be None here");

                    let dns_name = tls.dns_name;
                    let name_server_name = Arc::clone(&dns_name);

                    match DNSNameRef::try_from_ascii_str(&dns_name) {
                        Ok(dns_name) => {
                            let tls = TlsConnector::from(tls.client_config);
                            let tls = tls.connect(dns_name, tcp);
                            HttpsClientConnectState::TlsConnecting {
                                name_server_name,
                                name_server,
                                tls,
                            }
                        }
                        // `dns_name` here is the outer `Arc<String>`; the `Ok` arm's
                        // shadowing does not reach this arm.
                        Err(_) => HttpsClientConnectState::Errored(Some(ProtoError::from(
                            format!("bad dns_name: {}", dns_name),
                        ))),
                    }
                }
                HttpsClientConnectState::TlsConnecting {
                    ref name_server_name,
                    name_server,
                    ref mut tls,
                } => {
                    let tls = ready!(tls.poll_unpin(cx))?;
                    debug!("tls connection established to: {}", name_server);
                    let mut handshake = h2::client::Builder::new();
                    handshake.enable_push(false);

                    let handshake = handshake.handshake(tls);
                    HttpsClientConnectState::H2Handshake {
                        name_server_name: Arc::clone(name_server_name),
                        name_server,
                        handshake: Box::pin(handshake),
                    }
                }
                HttpsClientConnectState::H2Handshake {
                    ref name_server_name,
                    name_server,
                    ref mut handshake,
                } => {
                    let (send_request, connection) = ready!(handshake
                        .poll_unpin(cx)
                        .map_err(|e| ProtoError::from(format!("h2 handshake error: {}", e))))?;

                    debug!("h2 connection established to: {}", name_server);
                    // The connection driver runs detached; failures are only logged.
                    tokio::spawn(
                        connection
                            .map_err(|e| warn!("h2 connection failed: {}", e))
                            .map(|_: Result<(), ()>| ()),
                    );

                    HttpsClientConnectState::Connected(Some(HttpsClientStream {
                        name_server_name: Arc::clone(name_server_name),
                        name_server,
                        h2: send_request,
                        is_shutdown: false,
                    }))
                }
                HttpsClientConnectState::Connected(ref mut conn) => {
                    return Poll::Ready(Ok(conn.take().expect("cannot poll after complete")))
                }
                HttpsClientConnectState::Errored(ref mut err) => {
                    return Poll::Ready(Err(err.take().expect("cannot poll after complete")))
                }
            };

            // Assign directly: the previous state is not needed, and discarding the
            // `#[must_use]` result of `mem::replace` triggers a compiler warning.
            *self.as_mut().deref_mut() = next;
        }
    }
}
pub struct HttpsClientResponse(
Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>,
);
impl Future for HttpsClientResponse {
type Output = Result<DnsResponse, ProtoError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
self.0.as_mut().poll(cx).map_err(ProtoError::from)
}
}
#[cfg(test)]
mod tests {
    extern crate env_logger;
    extern crate tokio;

    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
    use std::str::FromStr;

    use rustls::{ClientConfig, ProtocolVersion, RootCertStore};
    use tokio::runtime::Runtime;
    use webpki_roots;

    use trust_dns_proto::op::{Message, Query};
    use trust_dns_proto::rr::{Name, RData, RecordType};
    use trust_dns_proto::TokioTime;

    use super::*;

    /// Builds a request with a single `www.example.com.` query of `record_type`.
    fn example_request(record_type: RecordType) -> DnsRequest {
        let mut message = Message::new();
        message.add_query(Query::query(
            Name::from_str("www.example.com.").unwrap(),
            record_type,
        ));
        DnsRequest::new(message, Default::default())
    }

    /// TLS 1.2-only client config trusting the webpki roots and advertising `h2`.
    fn test_client_config() -> ClientConfig {
        let mut root_store = RootCertStore::empty();
        root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);

        let mut client_config = ClientConfig::new();
        client_config.root_store = root_store;
        client_config.versions = vec![ProtocolVersion::TLSv1_2];
        client_config.alpn_protocols.push(ALPN_H2.to_vec());
        client_config
    }

    /// Sends `request` over `https` on `runtime` and waits for the response.
    fn send_and_wait(
        runtime: &mut Runtime,
        https: &mut HttpsClientStream,
        request: DnsRequest,
    ) -> DnsResponse {
        let sending = runtime.block_on(future::lazy(|cx| {
            https.send_message::<TokioTime>(request, cx)
        }));
        runtime.block_on(sending).expect("send_message failed")
    }

    /// Connects to `name_server` (TLS name `dns_name`).
    ///
    /// Checks the first A and AAAA answers for `www.example.com.`.
    fn check_example_com(name_server: SocketAddr, dns_name: &str) {
        let https_builder =
            HttpsClientStreamBuilder::with_client_config(Arc::new(test_client_config()));
        let connect = https_builder.build(name_server, dns_name.to_string());

        let mut runtime = Runtime::new().expect("could not start runtime");
        let mut https = runtime.block_on(connect).expect("https connect failed");

        let response = send_and_wait(&mut runtime, &mut https, example_request(RecordType::A));
        let record = &response.answers()[0];
        let addr = if let RData::A(addr) = record.rdata() {
            addr
        } else {
            panic!("invalid response, expected A record");
        };
        assert_eq!(addr, &Ipv4Addr::new(93, 184, 216, 34));

        let response =
            send_and_wait(&mut runtime, &mut https, example_request(RecordType::AAAA));
        let record = &response.answers()[0];
        let addr = if let RData::AAAA(addr) = record.rdata() {
            addr
        } else {
            panic!("invalid response, expected AAAA record");
        };
        assert_eq!(
            addr,
            &Ipv6Addr::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
        );
    }

    #[test]
    fn test_https_google() {
        check_example_com(SocketAddr::from(([8, 8, 8, 8], 443)), "dns.google");
    }

    #[test]
    #[ignore]
    fn test_https_cloudflare() {
        check_example_com(SocketAddr::from(([1, 1, 1, 1], 443)), "cloudflare-dns.com");
    }
}