use std::fmt::{self, Display};
use std::future::{poll_fn, Future};
use std::net::SocketAddr;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures_util::future::FutureExt;
use futures_util::stream::Stream;
use h3::client::SendRequest;
use h3_quinn::OpenStreams;
use http::header::{self, CONTENT_LENGTH};
use quinn::crypto::rustls::QuicClientConfig;
use quinn::{ClientConfig, Endpoint, EndpointConfig, TransportConfig};
use rustls::ClientConfig as TlsClientConfig;
use tokio::sync::mpsc;
use tracing::{debug, warn};
use crate::error::ProtoError;
use crate::http::Version;
use crate::op::Message;
use crate::udp::UdpSocket;
use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
use super::ALPN_H3;
/// DNS-over-HTTP/3 client stream bound to a single name server.
///
/// Each query passed to `send_message` is sent on its own HTTP/3 request stream
/// opened through `send_request`. The type is `Clone`: clones share the `Arc<str>`
/// fields and each holds its own clone of the h3 request handle and of `shutdown_tx`.
#[derive(Clone)]
#[must_use = "futures do nothing unless polled"]
pub struct H3ClientStream {
    /// Server name handed to quinn when connecting (used for the TLS handshake) and
    /// passed to `crate::http::request::new` when building each request.
    name_server_name: Arc<str>,
    /// Socket address of the remote name server.
    name_server: SocketAddr,
    /// HTTP path of the DNS query endpoint (the tests in this file use `/dns-query`).
    query_path: Arc<str>,
    /// h3 handle used to open one request stream per query.
    send_request: SendRequest<OpenStreams, Bytes>,
    /// Sender half of the channel watched by the background connection task.
    /// The task finishes once every sender (i.e. every clone) has been dropped, and
    /// `is_closed()` becomes true once that task has exited and dropped the receiver.
    shutdown_tx: mpsc::Sender<()>,
    /// Set by `shutdown()`; afterwards `send_message` panics and the `Stream` impl ends.
    is_shutdown: bool,
}
impl Display for H3ClientStream {
    /// Renders the stream as `H3(<socket address>,<server name>)`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let Self {
            name_server,
            name_server_name,
            ..
        } = self;
        formatter.write_fmt(format_args!("H3({},{})", name_server, name_server_name))
    }
}
impl H3ClientStream {
    /// Largest response body accepted from the server.
    ///
    /// RFC 8484 §6 restricts a DNS message carried as `application/dns-message` to
    /// 65535 bytes; enforcing it keeps a misbehaving server from making the client
    /// buffer an unbounded body.
    const MAX_MESSAGE_LEN: usize = u16::MAX as usize;

    /// Returns a new [`H3ClientStreamBuilder`] with its default configuration.
    pub fn builder() -> H3ClientStreamBuilder {
        H3ClientStreamBuilder::default()
    }

    /// Sends one serialized DNS query on a fresh HTTP/3 request stream and waits for
    /// the answer.
    ///
    /// * `h3` - request handle used to open the stream.
    /// * `message` - the DNS query in wire format, sent as the request body.
    /// * `name_server_name` / `query_path` - passed to `crate::http::request::new`.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be built or exchanged, if the
    /// `Content-Length` header is malformed, exceeds [`Self::MAX_MESSAGE_LEN`] or does
    /// not match the received body, if the body itself exceeds that limit, if the HTTP
    /// status is not a success, if the `Content-Type` is unsupported, or if the body
    /// does not parse as a DNS message.
    async fn inner_send(
        mut h3: SendRequest<OpenStreams, Bytes>,
        message: Bytes,
        name_server_name: Arc<str>,
        query_path: Arc<str>,
    ) -> Result<DnsResponse, ProtoError> {
        let request = crate::http::request::new(
            Version::Http3,
            &name_server_name,
            &query_path,
            message.remaining(),
        )
        .map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
        debug!("request: {:#?}", request);

        let mut stream = h3
            .send_request(request)
            .await
            .map_err(|err| ProtoError::from(format!("h3 send_request error: {err}")))?;
        stream
            .send_data(message)
            .await
            .map_err(|e| ProtoError::from(format!("h3 send_data error: {e}")))?;
        // Close our sending side so the server sees the complete request body.
        stream
            .finish()
            .await
            .map_err(|err| ProtoError::from(format!("received a stream error: {err}")))?;

        let response = stream
            .recv_response()
            .await
            .map_err(|err| ProtoError::from(format!("h3 recv_response error: {err}")))?;
        debug!("got response: {:#?}", response);

        let content_length = Self::parse_content_length(response.headers())?;
        if let Some(content_length) = content_length {
            if content_length > Self::MAX_MESSAGE_LEN {
                return Err(ProtoError::from(format!(
                    "response too large: content-length {content_length} exceeds {} bytes",
                    Self::MAX_MESSAGE_LEN
                )));
            }
        }

        // Pre-size from Content-Length, clamped so an absent or small header does not
        // under-allocate and a large one does not reserve more than 4 KiB up front.
        let mut response_bytes =
            BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4_096));
        while let Some(partial_bytes) = stream
            .recv_data()
            .await
            .map_err(|e| ProtoError::from(format!("h3 recv_data error: {e}")))?
        {
            debug!("got bytes: {}", partial_bytes.remaining());
            response_bytes.put(partial_bytes);
            // Without this bound a server omitting Content-Length could stream forever.
            if response_bytes.len() > Self::MAX_MESSAGE_LEN {
                return Err(ProtoError::from(format!(
                    "response too large: body exceeds {} bytes",
                    Self::MAX_MESSAGE_LEN
                )));
            }
            if let Some(content_length) = content_length {
                if response_bytes.len() >= content_length {
                    break;
                }
            }
        }

        if let Some(content_length) = content_length {
            if response_bytes.len() != content_length {
                return Err(ProtoError::from(format!(
                    "expected byte length: {}, got: {}",
                    content_length,
                    response_bytes.len()
                )));
            }
        }

        if !response.status().is_success() {
            let error_string = String::from_utf8_lossy(response_bytes.as_ref());
            return Err(ProtoError::from(format!(
                "http unsuccessful code: {}, message: {}",
                response.status(),
                error_string
            )));
        }

        Self::verify_content_type(response.headers())?;

        let message = Message::from_vec(&response_bytes)?;
        Ok(DnsResponse::new(message, response_bytes.to_vec()))
    }

    /// Parses the optional `Content-Length` response header.
    ///
    /// Returns `Ok(None)` when the header is absent, and an error when its value is
    /// not a valid header string or not a decimal `usize`.
    fn parse_content_length(headers: &header::HeaderMap) -> Result<Option<usize>, ProtoError> {
        match headers.get(CONTENT_LENGTH) {
            None => Ok(None),
            Some(value) => {
                let text = value
                    .to_str()
                    .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
                let length = usize::from_str(text)
                    .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
                Ok(Some(length))
            }
        }
    }

    /// Checks that the response `Content-Type` equals `crate::http::MIME_APPLICATION_DNS`.
    ///
    /// A missing `Content-Type` header is accepted and treated as that type.
    fn verify_content_type(headers: &header::HeaderMap) -> Result<(), ProtoError> {
        let content_type = match headers.get(header::CONTENT_TYPE) {
            Some(value) => value.to_str().map_err(|err| {
                ProtoError::from(format!("ContentType header not a string: {err}"))
            })?,
            None => crate::http::MIME_APPLICATION_DNS,
        };
        if content_type != crate::http::MIME_APPLICATION_DNS {
            return Err(ProtoError::from(format!(
                "ContentType unsupported (must be '{}'): '{}'",
                crate::http::MIME_APPLICATION_DNS,
                content_type
            )));
        }
        Ok(())
    }
}
impl DnsRequestSender for H3ClientStream {
    /// Serializes `message` and sends it on a new HTTP/3 request stream.
    ///
    /// The query ID is forced to 0 first (RFC 8484 §4.1 recommends a DNS ID of 0 in
    /// DoH requests for HTTP cache friendliness). A serialization failure is returned
    /// through the response stream rather than sent.
    ///
    /// # Panics
    ///
    /// Panics if called after `shutdown()`.
    fn send_message(&mut self, mut message: DnsRequest) -> DnsResponseStream {
        assert!(
            !self.is_shutdown,
            "can not send messages after stream is shutdown"
        );

        message.set_id(0);
        match message.to_vec() {
            Err(err) => err.into(),
            Ok(wire) => {
                let exchange = Self::inner_send(
                    self.send_request.clone(),
                    Bytes::from(wire),
                    Arc::clone(&self.name_server_name),
                    Arc::clone(&self.query_path),
                );
                Box::pin(exchange).into()
            }
        }
    }

    /// Marks this handle as shut down.
    ///
    /// This only flips the local flag; it does not signal the background connection
    /// task, which ends when every clone's `shutdown_tx` has been dropped.
    fn shutdown(&mut self) {
        self.is_shutdown = true;
    }

    /// Returns whether `shutdown()` has been called on this handle.
    fn is_shutdown(&self) -> bool {
        self.is_shutdown
    }
}
impl Stream for H3ClientStream {
    type Item = Result<(), ProtoError>;

    /// Reports connection health instead of yielding data; never returns `Pending`.
    ///
    /// Ends (`None`) after `shutdown()`, yields an error once the background connection
    /// task has exited (its receiver is gone), and yields `Ok(())` otherwise.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let status = if self.is_shutdown {
            None
        } else if self.shutdown_tx.is_closed() {
            Some(Err(ProtoError::from("h3 connection is already shutdown")))
        } else {
            Some(Ok(()))
        };
        Poll::Ready(status)
    }
}
/// Builder for [`H3ClientStream`] connections.
#[derive(Clone)]
pub struct H3ClientStreamBuilder {
    /// rustls client configuration; if its ALPN list is empty at connect time it is
    /// set to `ALPN_H3`.
    crypto_config: TlsClientConfig,
    /// QUIC transport configuration installed on the quinn client config.
    transport_config: Arc<TransportConfig>,
    /// Optional local address for the UDP socket created by `build`; when `None`,
    /// `UdpSocket::connect` chooses the binding.
    bind_addr: Option<SocketAddr>,
}
impl H3ClientStreamBuilder {
    /// Replaces the rustls client configuration used for the QUIC handshake.
    ///
    /// If its `alpn_protocols` list is empty when connecting, it is filled with `ALPN_H3`.
    pub fn crypto_config(&mut self, crypto_config: TlsClientConfig) -> &mut Self {
        self.crypto_config = crypto_config;
        self
    }

    /// Sets the local address that the UDP socket created by [`Self::build`] binds to.
    pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
        self.bind_addr = Some(bind_addr);
    }

    /// Returns a future that creates a local UDP socket and then establishes the QUIC
    /// and HTTP/3 connection to `name_server`.
    ///
    /// `dns_name` is the server name given to quinn for TLS and is also used when
    /// building each request; `query_path` is the HTTP path the queries are sent to.
    pub fn build(
        self,
        name_server: SocketAddr,
        dns_name: String,
        query_path: String,
    ) -> H3ClientConnect {
        let connecting = self.connect(name_server, dns_name, query_path);
        H3ClientConnect(Box::pin(connecting) as _)
    }

    /// Like [`Self::build`], but runs QUIC over the caller-supplied `socket` instead of
    /// creating a new UDP socket.
    pub fn build_with_future(
        self,
        socket: Arc<dyn quinn::AsyncUdpSocket>,
        name_server: SocketAddr,
        dns_name: String,
        query_path: String,
    ) -> H3ClientConnect {
        let connecting = self.connect_with_future(socket, name_server, dns_name, query_path);
        H3ClientConnect(Box::pin(connecting) as _)
    }

    /// Wraps the provided abstract socket in a quinn endpoint and connects over it.
    async fn connect_with_future(
        self,
        socket: Arc<dyn quinn::AsyncUdpSocket>,
        name_server: SocketAddr,
        server_name: String,
        query_path: String,
    ) -> Result<H3ClientStream, ProtoError> {
        let endpoint = Endpoint::new_with_abstract_socket(
            EndpointConfig::default(),
            None,
            socket,
            Arc::new(quinn::TokioRuntime),
        )?;
        self.connect_inner(endpoint, name_server, server_name, query_path)
            .await
    }

    /// Creates a tokio UDP socket (honouring `bind_addr` when set), wraps it in a quinn
    /// endpoint and connects over it.
    async fn connect(
        self,
        name_server: SocketAddr,
        dns_name: String,
        query_path: String,
    ) -> Result<H3ClientStream, ProtoError> {
        let tokio_socket = match self.bind_addr {
            Some(local_addr) => {
                <tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, local_addr)
                    .await?
            }
            None => <tokio::net::UdpSocket as UdpSocket>::connect(name_server).await?,
        };
        let std_socket = tokio_socket.into_std()?;
        let endpoint = Endpoint::new(
            EndpointConfig::default(),
            None,
            std_socket,
            Arc::new(quinn::TokioRuntime),
        )?;
        self.connect_inner(endpoint, name_server, dns_name, query_path)
            .await
    }

    /// Configures `endpoint`, completes the QUIC handshake (trying 0-RTT when the TLS
    /// config enables early data), performs the HTTP/3 setup, and spawns a task that
    /// drives the h3 connection until it closes or every stream handle is dropped.
    async fn connect_inner(
        self,
        mut endpoint: Endpoint,
        name_server: SocketAddr,
        dns_name: String,
        query_path: String,
    ) -> Result<H3ClientStream, ProtoError> {
        let mut tls = self.crypto_config;
        if tls.alpn_protocols.is_empty() {
            // Advertise HTTP/3 when the caller did not pick any ALPN protocols.
            tls.alpn_protocols = vec![ALPN_H3.to_vec()];
        }
        let try_0rtt = tls.enable_early_data;

        let mut client_config = ClientConfig::new(Arc::new(QuicClientConfig::try_from(tls)?));
        client_config.transport_config(Arc::clone(&self.transport_config));
        endpoint.set_default_client_config(client_config);

        let connecting = endpoint.connect(name_server, &dns_name)?;
        let quic_connection = if !try_0rtt {
            connecting.await?
        } else {
            match connecting.into_0rtt() {
                Ok((early_connection, _)) => early_connection,
                // 0-RTT was not possible; quinn hands the handshake back to finish normally.
                Err(full_handshake) => full_handshake.await?,
            }
        };

        let (mut driver, send_request) =
            h3::client::new(h3_quinn::Connection::new(quic_connection))
                .await
                .map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?;

        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
        debug!("h3 connection is ready: {}", name_server);
        // The task ends when the connection closes, or when `recv` observes that every
        // `shutdown_tx` clone has been dropped.
        tokio::spawn(async move {
            tokio::select! {
                res = poll_fn(|cx| driver.poll_close(cx)) => {
                    res.map_err(|e| warn!("h3 connection failed: {e}"))
                }
                _ = shutdown_rx.recv() => {
                    debug!("h3 connection is shutting down: {}", name_server);
                    Ok(())
                }
            }
        });

        Ok(H3ClientStream {
            name_server_name: Arc::from(dns_name),
            name_server,
            query_path: Arc::from(query_path),
            send_request,
            shutdown_tx,
            is_shutdown: false,
        })
    }
}
impl Default for H3ClientStreamBuilder {
    /// Uses the TLS config from `super::client_config_tls13()`, the transport config
    /// from `super::transport()`, and no explicit bind address.
    ///
    /// # Panics
    ///
    /// Panics if `super::client_config_tls13()` returns an error.
    fn default() -> Self {
        let crypto_config = super::client_config_tls13().unwrap();
        let transport_config = Arc::new(super::transport());
        Self {
            crypto_config,
            transport_config,
            bind_addr: None,
        }
    }
}
/// Future resolving to a connected [`H3ClientStream`], returned by
/// `H3ClientStreamBuilder::build` and `H3ClientStreamBuilder::build_with_future`.
pub struct H3ClientConnect(
    Pin<Box<dyn Future<Output = Result<H3ClientStream, ProtoError>> + Send>>,
);
impl Future for H3ClientConnect {
    type Output = Result<H3ClientStream, ProtoError>;

    /// Delegates to the boxed connection future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The wrapper only holds a `Pin<Box<..>>`, so it is `Unpin` and can be unwrapped.
        let this = self.get_mut();
        this.0.poll_unpin(cx)
    }
}
/// Boxed future resolving to a single [`DnsResponse`].
///
/// NOTE(review): this type is not constructed in the code shown here (`send_message`
/// returns a `DnsResponseStream`); confirm whether other modules still rely on it.
pub struct H3ClientResponse(Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>);
impl Future for H3ClientResponse {
type Output = Result<DnsResponse, ProtoError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.as_mut().poll(cx).map_err(ProtoError::from)
}
}
#[cfg(all(test, any(feature = "native-certs", feature = "webpki-roots")))]
mod tests {
    use std::net::SocketAddr;
    use std::str::FromStr;

    use rustls::KeyLogFile;
    use test_support::subscribe;
    use tokio::runtime::Runtime;
    use tokio::task::JoinSet;

    use crate::op::{Message, Query, ResponseCode};
    use crate::rr::rdata::{A, AAAA};
    use crate::rr::{Name, RecordType};
    use crate::xfer::{DnsRequestOptions, FirstAnswer};

    use super::*;

    /// Google Public DNS address used by these tests.
    fn google_addr() -> SocketAddr {
        SocketAddr::from(([8, 8, 8, 8], 443))
    }

    /// Builds a request with a single `www.example.com.` question of `record_type`.
    fn example_request(record_type: RecordType) -> DnsRequest {
        let mut message = Message::new();
        message.add_query(Query::query(
            Name::from_str("www.example.com.").unwrap(),
            record_type,
        ));
        DnsRequest::new(message, DnsRequestOptions::default())
    }

    /// Returns a builder using the default TLS 1.3 client config with rustls'
    /// `KeyLogFile` installed as the key logger.
    fn key_logging_builder() -> H3ClientStreamBuilder {
        let mut client_config = super::super::client_config_tls13().unwrap();
        client_config.key_log = Arc::new(KeyLogFile::new());
        let mut h3_builder = H3ClientStream::builder();
        h3_builder.crypto_config(client_config);
        h3_builder
    }

    /// Connects to Google Public DNS presenting `server_name` for TLS, then checks the
    /// A and AAAA answers for `www.example.com.`.
    fn check_google_answers(server_name: String) {
        subscribe();
        let connect =
            key_logging_builder().build(google_addr(), server_name, "/dns-query".to_string());
        let runtime = Runtime::new().expect("could not start runtime");
        let mut h3 = runtime.block_on(connect).expect("h3 connect failed");

        let response = runtime
            .block_on(h3.send_message(example_request(RecordType::A)).first_answer())
            .expect("send_message failed");
        let record = &response.answers()[0];
        let addr = record.data().as_a().expect("Expected A record");
        assert_eq!(addr, &A::new(93, 184, 215, 14));

        let request = example_request(RecordType::AAAA);
        for _ in 0..3 {
            let response = runtime
                .block_on(h3.send_message(request.clone()).first_answer())
                .expect("send_message failed");
            // A SERVFAIL answer skips the assertions for this attempt.
            if response.response_code() == ResponseCode::ServFail {
                continue;
            }
            let record = &response.answers()[0];
            let addr = record
                .data()
                .as_aaaa()
                .expect("invalid response, expected AAAA record");
            assert_eq!(
                addr,
                &AAAA::new(0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c)
            );
        }
    }

    #[test]
    fn test_h3_google() {
        check_google_answers("dns.google".to_string());
    }

    #[test]
    fn test_h3_google_with_pure_ip_address_server() {
        check_google_answers(google_addr().ip().to_string());
    }

    #[test]
    #[ignore]
    fn test_h3_cloudflare() {
        subscribe();
        let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));
        let connect = key_logging_builder().build(
            cloudflare,
            "cloudflare-dns.com".to_string(),
            "/dns-query".to_string(),
        );
        let runtime = Runtime::new().expect("could not start runtime");
        let mut h3 = runtime.block_on(connect).expect("h3 connect failed");

        let response = runtime
            .block_on(h3.send_message(example_request(RecordType::A)).first_answer())
            .expect("send_message failed");
        let record = &response.answers()[0];
        let addr = record
            .data()
            .as_a()
            .expect("invalid response, expected A record");
        assert_eq!(addr, &A::new(93, 184, 215, 14));

        let response = runtime
            .block_on(h3.send_message(example_request(RecordType::AAAA)).first_answer())
            .expect("send_message failed");
        let record = &response.answers()[0];
        let addr = record
            .data()
            .as_aaaa()
            .expect("invalid response, expected AAAA record");
        assert_eq!(
            addr,
            &AAAA::new(0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c)
        );
    }

    #[test]
    #[allow(clippy::print_stdout)]
    fn test_h3_client_stream_clonable() {
        let connect = key_logging_builder().build(
            google_addr(),
            "dns.google".to_string(),
            "/dns-query".to_string(),
        );
        let runtime = Runtime::new().expect("could not start runtime");
        let h3 = runtime.block_on(connect).expect("h3 connect failed");
        let request = example_request(RecordType::AAAA);

        runtime.block_on(async move {
            let mut join_set = JoinSet::new();
            for i in 0..50 {
                let mut h3 = h3.clone();
                let request = request.clone();
                join_set.spawn(async move {
                    let start = std::time::Instant::now();
                    h3.send_message(request)
                        .first_answer()
                        .await
                        .expect("send_message failed");
                    println!("request[{i}] completed: {:?}", start.elapsed());
                });
            }

            let total = join_set.len();
            let mut completed = 0usize;
            while let Some(joined) = join_set.join_next().await {
                // A panic inside a spawned task surfaces only as a `JoinError` here;
                // without this check a failed query would not fail the test.
                joined.expect("request task panicked");
                completed += 1;
                println!("join_set completed {completed}/{total}");
            }
        });
    }
}