1use alloc::boxed::Box;
9use alloc::string::String;
10use alloc::sync::Arc;
11use core::fmt::{self, Display};
12use core::future::Future;
13use core::ops::DerefMut;
14use core::pin::Pin;
15use core::str::FromStr;
16use core::task::{Context, Poll};
17use std::io;
18use std::net::SocketAddr;
19
20use bytes::{Buf, Bytes, BytesMut};
21use futures_util::future::{FutureExt, TryFutureExt};
22use futures_util::ready;
23use futures_util::stream::Stream;
24use h2::client::{Connection, SendRequest};
25use http::header::{self, CONTENT_LENGTH};
26use rustls::ClientConfig;
27use rustls::pki_types::ServerName;
28use tokio::time::{error, timeout};
29use tokio_rustls::{TlsConnector, client::TlsStream as TokioTlsClientStream};
30use tracing::{debug, warn};
31
32use crate::error::ProtoError;
33use crate::http::Version;
34use crate::runtime::RuntimeProvider;
35use crate::runtime::iocompat::AsyncIoStdAsTokio;
36use crate::tcp::DnsTcpStream;
37use crate::xfer::{CONNECT_TIMEOUT, DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
38
/// ALPN protocol identifier advertised during the TLS handshake to negotiate HTTP/2.
const ALPN_H2: &[u8] = &[b'h', b'2'];
40
41#[derive(Clone)]
43#[must_use = "futures do nothing unless polled"]
44pub struct HttpsClientStream {
45 name_server_name: Arc<str>,
47 query_path: Arc<str>,
48 name_server: SocketAddr,
49 h2: SendRequest<Bytes>,
50 is_shutdown: bool,
51}
52
53impl Display for HttpsClientStream {
54 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
55 write!(
56 formatter,
57 "HTTPS({},{})",
58 self.name_server, self.name_server_name
59 )
60 }
61}
62
63impl HttpsClientStream {
64 async fn inner_send(
65 h2: SendRequest<Bytes>,
66 message: Bytes,
67 name_server_name: Arc<str>,
68 query_path: Arc<str>,
69 ) -> Result<DnsResponse, ProtoError> {
70 let mut h2 = match h2.ready().await {
71 Ok(h2) => h2,
72 Err(err) => {
73 return Err(ProtoError::from(format!("h2 send_request error: {err}")));
75 }
76 };
77
78 let request = crate::http::request::new(
80 Version::Http2,
81 &name_server_name,
82 &query_path,
83 message.remaining(),
84 );
85
86 let request =
87 request.map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
88
89 debug!("request: {:#?}", request);
90
91 let (response_future, mut send_stream) = h2
93 .send_request(request, false)
94 .map_err(|err| ProtoError::from(format!("h2 send_request error: {err}")))?;
95
96 send_stream
97 .send_data(message, true)
98 .map_err(|e| ProtoError::from(format!("h2 send_data error: {e}")))?;
99
100 let mut response_stream = response_future
101 .await
102 .map_err(|err| ProtoError::from(format!("received a stream error: {err}")))?;
103
104 debug!("got response: {:#?}", response_stream);
105
106 let content_length = response_stream
108 .headers()
109 .get(CONTENT_LENGTH)
110 .map(|v| v.to_str())
111 .transpose()
112 .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?
113 .map(usize::from_str)
114 .transpose()
115 .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
116
117 let mut response_bytes =
121 BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4_096));
122
123 while let Some(partial_bytes) = response_stream.body_mut().data().await {
124 let partial_bytes =
125 partial_bytes.map_err(|e| ProtoError::from(format!("bad http request: {e}")))?;
126
127 debug!("got bytes: {}", partial_bytes.len());
128 response_bytes.extend(partial_bytes);
129
130 if let Some(content_length) = content_length {
132 if response_bytes.len() >= content_length {
133 break;
134 }
135 }
136 }
137
138 if let Some(content_length) = content_length {
140 if response_bytes.len() != content_length {
141 return Err(ProtoError::from(format!(
143 "expected byte length: {}, got: {}",
144 content_length,
145 response_bytes.len()
146 )));
147 }
148 }
149
150 if !response_stream.status().is_success() {
152 let error_string = String::from_utf8_lossy(response_bytes.as_ref());
153
154 return Err(ProtoError::from(format!(
156 "http unsuccessful code: {}, message: {}",
157 response_stream.status(),
158 error_string
159 )));
160 } else {
161 {
163 let content_type = response_stream
165 .headers()
166 .get(header::CONTENT_TYPE)
167 .map(|h| {
168 h.to_str().map_err(|err| {
169 ProtoError::from(format!("ContentType header not a string: {err}"))
171 })
172 })
173 .unwrap_or(Ok(crate::http::MIME_APPLICATION_DNS))?;
174
175 if content_type != crate::http::MIME_APPLICATION_DNS {
176 return Err(ProtoError::from(format!(
177 "ContentType unsupported (must be '{}'): '{}'",
178 crate::http::MIME_APPLICATION_DNS,
179 content_type
180 )));
181 }
182 }
183 };
184
185 DnsResponse::from_buffer(response_bytes.to_vec())
187 }
188}
189
190impl DnsRequestSender for HttpsClientStream {
191 fn send_message(&mut self, mut request: DnsRequest) -> DnsResponseStream {
239 if self.is_shutdown {
240 panic!("can not send messages after stream is shutdown")
241 }
242
243 request.set_id(0);
245
246 let bytes = match request.to_vec() {
247 Ok(bytes) => bytes,
248 Err(err) => return err.into(),
249 };
250
251 Box::pin(Self::inner_send(
252 self.h2.clone(),
253 Bytes::from(bytes),
254 Arc::clone(&self.name_server_name),
255 Arc::clone(&self.query_path),
256 ))
257 .into()
258 }
259
260 fn shutdown(&mut self) {
261 self.is_shutdown = true;
262 }
263
264 fn is_shutdown(&self) -> bool {
265 self.is_shutdown
266 }
267}
268
269impl Stream for HttpsClientStream {
270 type Item = Result<(), ProtoError>;
271
272 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
273 if self.is_shutdown {
274 return Poll::Ready(None);
275 }
276
277 match self.h2.poll_ready(cx) {
279 Poll::Ready(Ok(())) => Poll::Ready(Some(Ok(()))),
280 Poll::Pending => Poll::Pending,
281 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
282 "h2 stream errored: {e}",
283 ))))),
284 }
285 }
286}
287
288#[derive(Clone)]
290pub struct HttpsClientStreamBuilder<P> {
291 provider: P,
292 client_config: Arc<ClientConfig>,
293 bind_addr: Option<SocketAddr>,
294}
295
296impl<P: RuntimeProvider> HttpsClientStreamBuilder<P> {
297 pub fn with_client_config(client_config: Arc<ClientConfig>, provider: P) -> Self {
299 Self {
300 provider,
301 client_config,
302 bind_addr: None,
303 }
304 }
305
306 pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
308 self.bind_addr = Some(bind_addr);
309 }
310
311 pub fn build(
319 mut self,
320 name_server: SocketAddr,
321 dns_name: String,
322 http_endpoint: String,
323 ) -> HttpsClientConnect<P::Tcp> {
324 if self.client_config.alpn_protocols.is_empty() {
326 let mut client_config = (*self.client_config).clone();
327 client_config.alpn_protocols = vec![ALPN_H2.to_vec()];
328
329 self.client_config = Arc::new(client_config);
330 }
331
332 let tls = TlsConfig {
333 client_config: self.client_config,
334 dns_name: Arc::from(dns_name),
335 http_endpoint: Arc::from(http_endpoint),
336 };
337
338 let connect = self.provider.connect_tcp(name_server, self.bind_addr, None);
339 HttpsClientConnect(HttpsClientConnectState::TcpConnecting {
340 connect,
341 name_server,
342 tls: Some(tls),
343 })
344 }
345}
346
347pub struct HttpsClientConnect<S>(HttpsClientConnectState<S>)
349where
350 S: DnsTcpStream;
351
352impl<S: DnsTcpStream> HttpsClientConnect<S> {
353 pub fn new<F>(
355 future: F,
356 mut client_config: Arc<ClientConfig>,
357 name_server: SocketAddr,
358 dns_name: String,
359 http_endpoint: String,
360 ) -> Self
361 where
362 S: DnsTcpStream,
363 F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
364 {
365 if client_config.alpn_protocols.is_empty() {
367 let mut client_cfg = (*client_config).clone();
368 client_cfg.alpn_protocols = vec![ALPN_H2.to_vec()];
369
370 client_config = Arc::new(client_cfg);
371 }
372
373 let tls = TlsConfig {
374 client_config,
375 dns_name: Arc::from(dns_name),
376 http_endpoint: Arc::from(http_endpoint),
377 };
378
379 Self(HttpsClientConnectState::TcpConnecting {
380 connect: Box::pin(future),
381 name_server,
382 tls: Some(tls),
383 })
384 }
385}
386
387impl<S> Future for HttpsClientConnect<S>
388where
389 S: DnsTcpStream,
390{
391 type Output = Result<HttpsClientStream, ProtoError>;
392
393 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
394 self.0.poll_unpin(cx)
395 }
396}
397
398struct TlsConfig {
399 client_config: Arc<ClientConfig>,
400 dns_name: Arc<str>,
401 http_endpoint: Arc<str>,
402}
403
404#[allow(clippy::large_enum_variant)]
405#[allow(clippy::type_complexity)]
406enum HttpsClientConnectState<S>
407where
408 S: DnsTcpStream,
409{
410 TcpConnecting {
411 connect: Pin<Box<dyn Future<Output = io::Result<S>> + Send>>,
412 name_server: SocketAddr,
413 tls: Option<TlsConfig>,
414 },
415 TlsConnecting {
416 tls: Pin<
418 Box<
419 dyn Future<
420 Output = Result<
421 Result<TokioTlsClientStream<AsyncIoStdAsTokio<S>>, io::Error>,
422 error::Elapsed,
423 >,
424 > + Send,
425 >,
426 >,
427 name_server_name: Arc<str>,
428 name_server: SocketAddr,
429 query_path: Arc<str>,
430 },
431 H2Handshake {
432 handshake: Pin<
433 Box<
434 dyn Future<
435 Output = Result<
436 (
437 SendRequest<Bytes>,
438 Connection<TokioTlsClientStream<AsyncIoStdAsTokio<S>>, Bytes>,
439 ),
440 h2::Error,
441 >,
442 > + Send,
443 >,
444 >,
445 name_server_name: Arc<str>,
446 name_server: SocketAddr,
447 query_path: Arc<str>,
448 },
449 Connected(Option<HttpsClientStream>),
450 Errored(Option<ProtoError>),
451}
452
/// Drives connection setup: TCP connect, then TLS handshake, then HTTP/2 handshake.
///
/// Resolves to a ready-to-use [`HttpsClientStream`] or the first error encountered.
impl<S> Future for HttpsClientConnectState<S>
where
    S: DnsTcpStream,
{
    type Output = Result<HttpsClientStream, ProtoError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Each arm either returns (Pending / Ready) or produces the next state. After the state
        // is replaced, the loop runs again so the new state's future is polled immediately with
        // the same `cx`.
        loop {
            let next = match &mut *self.as_mut() {
                Self::TcpConnecting {
                    connect,
                    name_server,
                    tls,
                } => {
                    // NOTE(review): errors returned via `?` here (and in the arms below) leave the
                    // state unchanged; callers must not poll again after a Ready result.
                    let tcp = ready!(connect.poll_unpin(cx))?;

                    debug!("tcp connection established to: {}", name_server);
                    // `tls` is an Option only so its contents can be moved out from behind `&mut`.
                    let tls = tls
                        .take()
                        .expect("programming error, tls should not be None here");
                    let name_server_name = Arc::clone(&tls.dns_name);
                    let query_path = Arc::clone(&tls.http_endpoint);

                    // An unparsable server name becomes an Errored state, reported on the next
                    // loop iteration.
                    match ServerName::try_from(&*tls.dns_name) {
                        Ok(dns_name) => Self::TlsConnecting {
                            name_server_name,
                            name_server: *name_server,
                            // The TLS handshake is bounded by CONNECT_TIMEOUT.
                            tls: Box::pin(timeout(
                                CONNECT_TIMEOUT,
                                TlsConnector::from(tls.client_config)
                                    .connect(dns_name.to_owned(), AsyncIoStdAsTokio(tcp)),
                            )),
                            query_path,
                        },
                        Err(_) => Self::Errored(Some(ProtoError::from(format!(
                            "bad dns_name: {}",
                            &tls.dns_name
                        )))),
                    }
                }
                Self::TlsConnecting {
                    name_server_name,
                    name_server,
                    query_path,
                    tls,
                } => {
                    // Outer Err: the timeout elapsed before the handshake completed.
                    let Ok(res) = ready!(tls.poll_unpin(cx)) else {
                        return Poll::Ready(Err(format!(
                            "TLS handshake timed out after {CONNECT_TIMEOUT:?}"
                        )
                        .into()));
                    };
                    // Inner Err: the TLS handshake itself failed (io::Error).
                    let tls = res?;
                    debug!("tls connection established to: {}", name_server);
                    let mut handshake = h2::client::Builder::new();
                    // Server push is disabled for this client.
                    handshake.enable_push(false);

                    let handshake = handshake.handshake(tls);
                    Self::H2Handshake {
                        name_server_name: Arc::clone(name_server_name),
                        name_server: *name_server,
                        query_path: Arc::clone(query_path),
                        handshake: Box::pin(handshake),
                    }
                }
                Self::H2Handshake {
                    name_server_name,
                    name_server,
                    query_path,
                    handshake,
                } => {
                    let (send_request, connection) = ready!(
                        handshake
                            .poll_unpin(cx)
                            .map_err(|e| ProtoError::from(format!("h2 handshake error: {e}")))
                    )?;

                    debug!("h2 connection established to: {}", name_server);
                    // h2's `Connection` future performs the socket I/O for every request made
                    // through `send_request`, so it is spawned as a detached background task.
                    // Failures are only logged.
                    // NOTE(review): this uses `tokio::spawn` directly rather than the
                    // RuntimeProvider abstraction.
                    tokio::spawn(
                        connection
                            .map_err(|e| warn!("h2 connection failed: {e}"))
                            .map(|_: Result<(), ()>| ()),
                    );

                    Self::Connected(Some(HttpsClientStream {
                        name_server_name: Arc::clone(name_server_name),
                        name_server: *name_server,
                        query_path: Arc::clone(query_path),
                        h2: send_request,
                        is_shutdown: false,
                    }))
                }
                // Terminal states: move the result out; polling again panics via `expect`.
                Self::Connected(conn) => {
                    return Poll::Ready(Ok(conn.take().expect("cannot poll after complete")));
                }
                Self::Errored(err) => {
                    return Poll::Ready(Err(err.take().expect("cannot poll after complete")));
                }
            };

            *self.as_mut().deref_mut() = next;
        }
    }
}
558
559pub struct HttpsClientResponse(
561 Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>,
562);
563
564impl Future for HttpsClientResponse {
565 type Output = Result<DnsResponse, ProtoError>;
566
567 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
568 self.0.as_mut().poll(cx).map_err(ProtoError::from)
569 }
570}
571
#[cfg(any(feature = "webpki-roots", feature = "rustls-platform-verifier"))]
#[cfg(test)]
mod tests {
    use alloc::string::ToString;
    use std::net::SocketAddr;

    use rustls::KeyLogFile;
    use test_support::subscribe;

    use crate::op::{Edns, Message, Query};
    use crate::rr::{Name, RecordType};
    use crate::runtime::TokioRuntimeProvider;
    use crate::rustls::client_config;
    use crate::xfer::{DnsRequestOptions, FirstAnswer};

    use super::*;

    #[tokio::test]
    async fn test_https_google() {
        subscribe();

        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let mut https = connect(google, "dns.google".to_string(), true).await;

        query_and_expect(&mut https, RecordType::A).await;
        query_and_expect(&mut https, RecordType::AAAA).await;
    }

    #[tokio::test]
    async fn test_https_google_with_pure_ip_address_server() {
        subscribe();

        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let mut https = connect(google, google.ip().to_string(), true).await;

        query_and_expect(&mut https, RecordType::A).await;
        query_and_expect(&mut https, RecordType::AAAA).await;
    }

    #[tokio::test]
    #[ignore = "cloudflare has been unreliable as a public test service"]
    async fn test_https_cloudflare() {
        subscribe();

        let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));
        let mut https = connect(cloudflare, "cloudflare-dns.com".to_string(), false).await;

        query_and_expect(&mut https, RecordType::A).await;
        query_and_expect(&mut https, RecordType::AAAA).await;
    }

    /// Builds a recursive `www.example.com.` query with EDNS(0) and a 1232-byte max payload.
    fn example_query(record_type: RecordType) -> DnsRequest {
        let mut message = Message::new();
        message.add_query(Query::query(
            Name::from_str("www.example.com.").unwrap(),
            record_type,
        ));
        message.set_recursion_desired(true);

        let mut edns = Edns::new();
        edns.set_version(0);
        edns.set_max_payload(1232);
        *message.extensions_mut() = Some(edns);

        DnsRequest::new(message, DnsRequestOptions::default())
    }

    /// Connects to `server` over DoH at `/dns-query`, optionally installing a `KeyLogFile`.
    async fn connect(server: SocketAddr, dns_name: String, log_keys: bool) -> HttpsClientStream {
        let mut config = client_config_h2();
        if log_keys {
            config.key_log = Arc::new(KeyLogFile::new());
        }

        HttpsClientStreamBuilder::with_client_config(Arc::new(config), TokioRuntimeProvider::new())
            .build(server, dns_name, "/dns-query".to_string())
            .await
            .expect("https connect failed")
    }

    /// Sends an `example_query` and asserts that some answer carries the requested record type.
    async fn query_and_expect(https: &mut HttpsClientStream, record_type: RecordType) {
        let response = https
            .send_message(example_query(record_type))
            .first_answer()
            .await
            .expect("send_message failed");

        let has_expected = response.answers().iter().any(|record| match record_type {
            RecordType::A => record.data().as_a().is_some(),
            RecordType::AAAA => record.data().as_aaaa().is_some(),
            _ => unreachable!("tests only query A and AAAA records"),
        });
        assert!(has_expected);
    }

    /// Client config from `crate::rustls::client_config` with ALPN fixed to `h2`.
    fn client_config_h2() -> ClientConfig {
        let mut config = client_config();
        config.alpn_protocols = vec![ALPN_H2.to_vec()];
        config
    }
}