1use std::fmt::{self, Display};
9use std::future::Future;
10use std::io;
11use std::net::SocketAddr;
12use std::ops::DerefMut;
13use std::pin::Pin;
14use std::str::FromStr;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17
18use bytes::{Buf, Bytes, BytesMut};
19use futures_util::future::{FutureExt, TryFutureExt};
20use futures_util::ready;
21use futures_util::stream::Stream;
22use h2::client::{Connection, SendRequest};
23use http::header::{self, CONTENT_LENGTH};
24use rustls::ClientConfig;
25use tokio_rustls::{
26 client::TlsStream as TokioTlsClientStream, Connect as TokioTlsConnect, TlsConnector,
27};
28use tracing::{debug, warn};
29
30use crate::error::ProtoError;
31use crate::http::Version;
32use crate::iocompat::AsyncIoStdAsTokio;
33use crate::op::Message;
34use crate::tcp::{Connect, DnsTcpStream};
35use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
36
/// ALPN protocol identifier for HTTP/2 (`h2`).
///
/// Installed as the only ALPN protocol when the supplied TLS client configuration does not
/// list any protocols of its own.
const ALPN_H2: &[u8] = b"h2";
38
/// A DNS client connection that sends each DNS message as a request over an established
/// HTTP/2 connection.
///
/// Cloning yields another handle that holds a clone of the same `SendRequest`.
#[derive(Clone)]
#[must_use = "futures do nothing unless polled"]
pub struct HttpsClientStream {
    /// Name of the HTTPS server; it is handed to the HTTP request builder for every query and
    /// is the same name that was used for the TLS connection.
    name_server_name: Arc<str>,
    /// Socket address of the server this stream is connected to.
    name_server: SocketAddr,
    /// Handle used to open a new HTTP/2 request stream for each DNS message.
    h2: SendRequest<Bytes>,
    /// Set once `shutdown` has been called; no further messages may be sent afterwards.
    is_shutdown: bool,
}
49
50impl Display for HttpsClientStream {
51 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
52 write!(
53 formatter,
54 "HTTPS({},{})",
55 self.name_server, self.name_server_name
56 )
57 }
58}
59
60impl HttpsClientStream {
61 async fn inner_send(
62 h2: SendRequest<Bytes>,
63 message: Bytes,
64 name_server_name: Arc<str>,
65 ) -> Result<DnsResponse, ProtoError> {
66 let mut h2 = match h2.ready().await {
67 Ok(h2) => h2,
68 Err(err) => {
69 return Err(ProtoError::from(format!("h2 send_request error: {err}")));
71 }
72 };
73
74 let request =
76 crate::http::request::new(Version::Http2, &name_server_name, message.remaining());
77
78 let request =
79 request.map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
80
81 debug!("request: {:#?}", request);
82
83 let (response_future, mut send_stream) = h2
85 .send_request(request, false)
86 .map_err(|err| ProtoError::from(format!("h2 send_request error: {err}")))?;
87
88 send_stream
89 .send_data(message, true)
90 .map_err(|e| ProtoError::from(format!("h2 send_data error: {e}")))?;
91
92 let mut response_stream = response_future
93 .await
94 .map_err(|err| ProtoError::from(format!("received a stream error: {err}")))?;
95
96 debug!("got response: {:#?}", response_stream);
97
98 let content_length = response_stream
100 .headers()
101 .get(CONTENT_LENGTH)
102 .map(|v| v.to_str())
103 .transpose()
104 .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?
105 .map(usize::from_str)
106 .transpose()
107 .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
108
109 let mut response_bytes =
113 BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4096));
114
115 while let Some(partial_bytes) = response_stream.body_mut().data().await {
116 let partial_bytes =
117 partial_bytes.map_err(|e| ProtoError::from(format!("bad http request: {e}")))?;
118
119 debug!("got bytes: {}", partial_bytes.len());
120 response_bytes.extend(partial_bytes);
121
122 if let Some(content_length) = content_length {
124 if response_bytes.len() >= content_length {
125 break;
126 }
127 }
128 }
129
130 if let Some(content_length) = content_length {
132 if response_bytes.len() != content_length {
133 return Err(ProtoError::from(format!(
135 "expected byte length: {}, got: {}",
136 content_length,
137 response_bytes.len()
138 )));
139 }
140 }
141
142 if !response_stream.status().is_success() {
144 let error_string = String::from_utf8_lossy(response_bytes.as_ref());
145
146 return Err(ProtoError::from(format!(
148 "http unsuccessful code: {}, message: {}",
149 response_stream.status(),
150 error_string
151 )));
152 } else {
153 {
155 let content_type = response_stream
157 .headers()
158 .get(header::CONTENT_TYPE)
159 .map(|h| {
160 h.to_str().map_err(|err| {
161 ProtoError::from(format!("ContentType header not a string: {err}"))
163 })
164 })
165 .unwrap_or(Ok(crate::http::MIME_APPLICATION_DNS))?;
166
167 if content_type != crate::http::MIME_APPLICATION_DNS {
168 return Err(ProtoError::from(format!(
169 "ContentType unsupported (must be '{}'): '{}'",
170 crate::http::MIME_APPLICATION_DNS,
171 content_type
172 )));
173 }
174 }
175 };
176
177 let message = Message::from_vec(&response_bytes)?;
179 Ok(DnsResponse::new(message, response_bytes.to_vec()))
180 }
181}
182
183impl DnsRequestSender for HttpsClientStream {
184 fn send_message(&mut self, mut message: DnsRequest) -> DnsResponseStream {
232 if self.is_shutdown {
233 panic!("can not send messages after stream is shutdown")
234 }
235
236 message.set_id(0);
238
239 let bytes = match message.to_vec() {
240 Ok(bytes) => bytes,
241 Err(err) => return err.into(),
242 };
243
244 Box::pin(Self::inner_send(
245 self.h2.clone(),
246 Bytes::from(bytes),
247 Arc::clone(&self.name_server_name),
248 ))
249 .into()
250 }
251
252 fn shutdown(&mut self) {
253 self.is_shutdown = true;
254 }
255
256 fn is_shutdown(&self) -> bool {
257 self.is_shutdown
258 }
259}
260
261impl Stream for HttpsClientStream {
262 type Item = Result<(), ProtoError>;
263
264 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
265 if self.is_shutdown {
266 return Poll::Ready(None);
267 }
268
269 match self.h2.poll_ready(cx) {
271 Poll::Ready(Ok(())) => Poll::Ready(Some(Ok(()))),
272 Poll::Pending => Poll::Pending,
273 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
274 "h2 stream errored: {e}",
275 ))))),
276 }
277 }
278}
279
/// Builder for the connection future that produces an `HttpsClientStream`.
#[derive(Clone)]
pub struct HttpsClientStreamBuilder {
    /// TLS client configuration used for the connection.
    client_config: Arc<ClientConfig>,
    /// Optional local address passed to the transport when connecting.
    bind_addr: Option<SocketAddr>,
}
286
287impl HttpsClientStreamBuilder {
288 pub fn with_client_config(client_config: Arc<ClientConfig>) -> Self {
290 Self {
291 client_config,
292 bind_addr: None,
293 }
294 }
295
296 pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
298 self.bind_addr = Some(bind_addr);
299 }
300
301 pub fn build<S: Connect>(
308 mut self,
309 name_server: SocketAddr,
310 dns_name: String,
311 ) -> HttpsClientConnect<S> {
312 if self.client_config.alpn_protocols.is_empty() {
314 let mut client_config = (*self.client_config).clone();
315 client_config.alpn_protocols = vec![ALPN_H2.to_vec()];
316
317 self.client_config = Arc::new(client_config);
318 }
319
320 let tls = TlsConfig {
321 client_config: self.client_config,
322 dns_name: Arc::from(dns_name),
323 };
324
325 let connect = S::connect_with_bind(name_server, self.bind_addr);
326
327 HttpsClientConnect::<S>(HttpsClientConnectState::TcpConnecting {
328 connect,
329 name_server,
330 tls: Some(tls),
331 })
332 }
333
334 pub fn build_with_future<S, F>(
336 future: F,
337 mut client_config: Arc<ClientConfig>,
338 name_server: SocketAddr,
339 dns_name: String,
340 ) -> HttpsClientConnect<S>
341 where
342 S: DnsTcpStream,
343 F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
344 {
345 if client_config.alpn_protocols.is_empty() {
347 let mut client_cfg = (*client_config).clone();
348 client_cfg.alpn_protocols = vec![ALPN_H2.to_vec()];
349
350 client_config = Arc::new(client_cfg);
351 }
352
353 let tls = TlsConfig {
354 client_config,
355 dns_name: Arc::from(dns_name),
356 };
357
358 HttpsClientConnect::<S>(HttpsClientConnectState::TcpConnecting {
359 connect: Box::pin(future),
360 name_server,
361 tls: Some(tls),
362 })
363 }
364}
365
366pub struct HttpsClientConnect<S>(HttpsClientConnectState<S>)
368where
369 S: DnsTcpStream;
370
371impl<S> Future for HttpsClientConnect<S>
372where
373 S: DnsTcpStream,
374{
375 type Output = Result<HttpsClientStream, ProtoError>;
376
377 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
378 self.0.poll_unpin(cx)
379 }
380}
381
/// TLS parameters held by the connecting state until the TCP connection is established.
struct TlsConfig {
    /// Client configuration the TLS connector is created from.
    client_config: Arc<ClientConfig>,
    /// Server name used for the TLS connection; it is also kept as the stream's server name.
    dns_name: Arc<str>,
}
386
387#[allow(clippy::large_enum_variant)]
388#[allow(clippy::type_complexity)]
389enum HttpsClientConnectState<S>
390where
391 S: DnsTcpStream,
392{
393 TcpConnecting {
394 connect: Pin<Box<dyn Future<Output = io::Result<S>> + Send>>,
395 name_server: SocketAddr,
396 tls: Option<TlsConfig>,
397 },
398 TlsConnecting {
399 tls: TokioTlsConnect<AsyncIoStdAsTokio<S>>,
401 name_server_name: Arc<str>,
402 name_server: SocketAddr,
403 },
404 H2Handshake {
405 handshake: Pin<
406 Box<
407 dyn Future<
408 Output = Result<
409 (
410 SendRequest<Bytes>,
411 Connection<TokioTlsClientStream<AsyncIoStdAsTokio<S>>, Bytes>,
412 ),
413 h2::Error,
414 >,
415 > + Send,
416 >,
417 >,
418 name_server_name: Arc<str>,
419 name_server: SocketAddr,
420 },
421 Connected(Option<HttpsClientStream>),
422 Errored(Option<ProtoError>),
423}
424
425impl<S> Future for HttpsClientConnectState<S>
426where
427 S: DnsTcpStream,
428{
429 type Output = Result<HttpsClientStream, ProtoError>;
430
431 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
432 loop {
433 let next = match *self {
434 Self::TcpConnecting {
435 ref mut connect,
436 name_server,
437 ref mut tls,
438 } => {
439 let tcp = ready!(connect.poll_unpin(cx))?;
440
441 debug!("tcp connection established to: {}", name_server);
442 let tls = tls
443 .take()
444 .expect("programming error, tls should not be None here");
445 let name_server_name = Arc::clone(&tls.dns_name);
446
447 match tls.dns_name.as_ref().try_into() {
448 Ok(dns_name) => {
449 let tls = TlsConnector::from(tls.client_config);
450 let tls = tls.connect(dns_name, AsyncIoStdAsTokio(tcp));
451 Self::TlsConnecting {
452 name_server_name,
453 name_server,
454 tls,
455 }
456 }
457 Err(_) => Self::Errored(Some(ProtoError::from(format!(
458 "bad dns_name: {}",
459 &tls.dns_name
460 )))),
461 }
462 }
463 Self::TlsConnecting {
464 ref name_server_name,
465 name_server,
466 ref mut tls,
467 } => {
468 let tls = ready!(tls.poll_unpin(cx))?;
469 debug!("tls connection established to: {}", name_server);
470 let mut handshake = h2::client::Builder::new();
471 handshake.enable_push(false);
472
473 let handshake = handshake.handshake(tls);
474 Self::H2Handshake {
475 name_server_name: Arc::clone(name_server_name),
476 name_server,
477 handshake: Box::pin(handshake),
478 }
479 }
480 Self::H2Handshake {
481 ref name_server_name,
482 name_server,
483 ref mut handshake,
484 } => {
485 let (send_request, connection) = ready!(handshake
486 .poll_unpin(cx)
487 .map_err(|e| ProtoError::from(format!("h2 handshake error: {e}"))))?;
488
489 debug!("h2 connection established to: {}", name_server);
491 tokio::spawn(
492 connection
493 .map_err(|e| warn!("h2 connection failed: {e}"))
494 .map(|_: Result<(), ()>| ()),
495 );
496
497 Self::Connected(Some(HttpsClientStream {
498 name_server_name: Arc::clone(name_server_name),
499 name_server,
500 h2: send_request,
501 is_shutdown: false,
502 }))
503 }
504 Self::Connected(ref mut conn) => {
505 return Poll::Ready(Ok(conn.take().expect("cannot poll after complete")))
506 }
507 Self::Errored(ref mut err) => {
508 return Poll::Ready(Err(err.take().expect("cannot poll after complete")))
509 }
510 };
511
512 *self.as_mut().deref_mut() = next;
513 }
514 }
515}
516
/// A boxed future that resolves to the response of a DNS-over-HTTPS request.
pub struct HttpsClientResponse(
    Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>,
);
521
522impl Future for HttpsClientResponse {
523 type Output = Result<DnsResponse, ProtoError>;
524
525 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
526 self.0.as_mut().poll(cx).map_err(ProtoError::from)
527 }
528}
529
#[cfg(any(feature = "webpki-roots", feature = "native-certs"))]
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::str::FromStr;

    use rustls::KeyLogFile;
    use tokio::net::TcpStream as TokioTcpStream;
    use tokio::runtime::Runtime;

    use crate::iocompat::AsyncIoTokioAsStd;
    use crate::op::{Edns, Message, Query};
    use crate::rr::rdata::A;
    use crate::rr::{Name, RData, RecordType};
    use crate::xfer::{DnsRequestOptions, FirstAnswer};

    use super::*;

    /// Transport used by every test in this module.
    type TestStream = AsyncIoTokioAsStd<TokioTcpStream>;

    /// Address of the Google resolver the Google tests talk to.
    fn google_addr() -> SocketAddr {
        SocketAddr::from(([8, 8, 8, 8], 443))
    }

    /// Builds a request for `www.example.com.` with the given record type.
    ///
    /// With `recursive_with_edns` set, recursion is requested and an EDNS section (version 0,
    /// max payload 1232) is attached; otherwise the message carries only the query.
    fn example_request(record_type: RecordType, recursive_with_edns: bool) -> DnsRequest {
        let name = Name::from_str("www.example.com.").unwrap();
        let mut message = Message::new();
        message.add_query(Query::query(name, record_type));

        if recursive_with_edns {
            message.set_recursion_desired(true);
            let mut edns = Edns::new();
            edns.set_version(0);
            edns.set_max_payload(1232);
            *message.extensions_mut() = Some(edns);
        }

        DnsRequest::new(message, DnsRequestOptions::default())
    }

    /// Connects to the Google resolver using `dns_name` as the server name, then checks that
    /// one A query and three consecutive AAAA queries on the same stream are answered.
    fn assert_google_answers(dns_name: String) {
        let a_request = example_request(RecordType::A, true);

        let mut client_config = client_config_tls12();
        client_config.key_log = Arc::new(KeyLogFile::new());

        let connect = HttpsClientStreamBuilder::with_client_config(Arc::new(client_config))
            .build::<TestStream>(google_addr(), dns_name);

        let runtime = Runtime::new().expect("could not start runtime");
        let mut https = runtime.block_on(connect).expect("https connect failed");

        let response = runtime
            .block_on(https.send_message(a_request).first_answer())
            .expect("send_message failed");

        assert!(response
            .answers()
            .iter()
            .any(|record| record.data().unwrap().as_a().is_some()));

        let aaaa_request = example_request(RecordType::AAAA, true);

        for _ in 0..3 {
            let response = runtime
                .block_on(https.send_message(aaaa_request.clone()).first_answer())
                .expect("send_message failed");

            assert!(response
                .answers()
                .iter()
                .any(|record| record.data().unwrap().as_aaaa().is_some()));
        }
    }

    #[test]
    fn test_https_google() {
        assert_google_answers("dns.google".to_string());
    }

    #[test]
    fn test_https_google_with_pure_ip_address_server() {
        // The server name is the textual IP address rather than a host name.
        assert_google_answers(google_addr().ip().to_string());
    }

    #[test]
    #[ignore]
    fn test_https_cloudflare() {
        let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));

        // The first query is sent without recursion-desired and without EDNS.
        let a_request = example_request(RecordType::A, false);

        let client_config = client_config_tls12();
        let connect = HttpsClientStreamBuilder::with_client_config(Arc::new(client_config))
            .build::<TestStream>(cloudflare, "cloudflare-dns.com".to_string());

        let runtime = Runtime::new().expect("could not start runtime");
        let mut https = runtime.block_on(connect).expect("https connect failed");

        let response = runtime
            .block_on(https.send_message(a_request).first_answer())
            .expect("send_message failed");

        let record = &response.answers()[0];
        let addr = record
            .data()
            .and_then(RData::as_a)
            .expect("invalid response, expected A record");

        assert_eq!(addr, &A::new(93, 184, 215, 14));

        let aaaa_request = example_request(RecordType::AAAA, true);

        let response = runtime
            .block_on(https.send_message(aaaa_request).first_answer())
            .expect("send_message failed");

        assert!(response
            .answers()
            .iter()
            .any(|record| record.data().unwrap().as_aaaa().is_some()));
    }

    /// Builds the TLS client configuration shared by the tests.
    ///
    /// Roots come from the native store or from `webpki-roots`, depending on the enabled
    /// features; `h2` is set as the only ALPN protocol.
    fn client_config_tls12() -> ClientConfig {
        use rustls::RootCertStore;
        #[cfg_attr(
            not(any(feature = "native-certs", feature = "webpki-roots")),
            allow(unused_mut)
        )]
        let mut root_store = RootCertStore::empty();
        #[cfg(all(feature = "native-certs", not(feature = "webpki-roots")))]
        {
            let (added, ignored) = root_store
                .add_parsable_certificates(&rustls_native_certs::load_native_certs().unwrap());

            if ignored > 0 {
                warn!(
                    "failed to parse {} certificate(s) from the native root store",
                    ignored
                );
            }

            if added == 0 {
                panic!("no valid certificates found in the native root store");
            }
        }
        #[cfg(feature = "webpki-roots")]
        root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
            rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
                ta.subject,
                ta.spki,
                ta.name_constraints,
            )
        }));

        let mut client_config = ClientConfig::builder()
            .with_safe_default_cipher_suites()
            .with_safe_default_kx_groups()
            .with_safe_default_protocol_versions()
            .unwrap()
            .with_root_certificates(root_store)
            .with_no_client_auth();

        client_config.alpn_protocols = vec![ALPN_H2.to_vec()];
        client_config
    }
}