1use alloc::boxed::Box;
9use alloc::string::String;
10use alloc::sync::Arc;
11use core::fmt::{self, Display};
12use core::future::{Future, poll_fn};
13use core::pin::Pin;
14use core::str::FromStr;
15use core::task::{Context, Poll};
16use std::net::SocketAddr;
17
18use bytes::{Buf, BufMut, Bytes, BytesMut};
19use futures_util::future::FutureExt;
20use futures_util::stream::Stream;
21use h3::client::SendRequest;
22use h3_quinn::OpenStreams;
23use http::header::{self, CONTENT_LENGTH};
24use quinn::{Endpoint, EndpointConfig, TransportConfig};
25use tokio::sync::mpsc;
26use tracing::{debug, warn};
27
28use crate::error::ProtoError;
29use crate::http::Version;
30use crate::quic::connect_quic;
31use crate::rustls::client_config;
32use crate::udp::UdpSocket;
33use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
34
35use super::ALPN_H3;
36
/// A DNS client connection that sends queries as HTTP/3 requests.
///
/// Instances are produced by [`H3ClientStreamBuilder`]. The type is `Clone`;
/// note that `is_shutdown` is a plain `bool`, so every clone tracks its own
/// shutdown flag independently.
#[derive(Clone)]
#[must_use = "futures do nothing unless polled"]
pub struct H3ClientStream {
    /// Name of the remote server; handed to the HTTP request constructor and
    /// shown by the `Display` implementation.
    name_server_name: Arc<str>,
    /// Socket address of the remote server (only used for `Display` here).
    name_server: SocketAddr,
    /// Path component handed to the HTTP request constructor.
    query_path: Arc<str>,
    /// h3 request handle; cloned for every outgoing message.
    send_request: SendRequest<OpenStreams, Bytes>,
    /// Sender half of the channel whose receiver is owned by the spawned
    /// connection-driver task; `poll_next` uses `is_closed()` on it to detect
    /// that the driver has gone away.
    shutdown_tx: mpsc::Sender<()>,
    /// Set by `shutdown()`; once true the stream yields `None`.
    is_shutdown: bool,
}
49
50impl Display for H3ClientStream {
51 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
52 write!(
53 formatter,
54 "H3({},{})",
55 self.name_server, self.name_server_name
56 )
57 }
58}
59
60impl H3ClientStream {
61 pub fn builder() -> H3ClientStreamBuilder {
63 H3ClientStreamBuilder::default()
64 }
65
66 async fn inner_send(
67 mut h3: SendRequest<OpenStreams, Bytes>,
68 message: Bytes,
69 name_server_name: Arc<str>,
70 query_path: Arc<str>,
71 ) -> Result<DnsResponse, ProtoError> {
72 let request = crate::http::request::new(
74 Version::Http3,
75 &name_server_name,
76 &query_path,
77 message.remaining(),
78 );
79
80 let request =
81 request.map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
82
83 debug!("request: {:#?}", request);
84
85 let mut stream = h3
87 .send_request(request)
88 .await
89 .map_err(|err| ProtoError::from(format!("h3 send_request error: {err}")))?;
90
91 stream
92 .send_data(message)
93 .await
94 .map_err(|e| ProtoError::from(format!("h3 send_data error: {e}")))?;
95
96 stream
97 .finish()
98 .await
99 .map_err(|err| ProtoError::from(format!("received a stream error: {err}")))?;
100
101 let response = stream
102 .recv_response()
103 .await
104 .map_err(|err| ProtoError::from(format!("h3 recv_response error: {err}")))?;
105
106 debug!("got response: {:#?}", response);
107
108 let content_length = response
110 .headers()
111 .get(CONTENT_LENGTH)
112 .map(|v| v.to_str())
113 .transpose()
114 .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?
115 .map(usize::from_str)
116 .transpose()
117 .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
118
119 let mut response_bytes =
123 BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4_096));
124
125 while let Some(partial_bytes) = stream
126 .recv_data()
127 .await
128 .map_err(|e| ProtoError::from(format!("h3 recv_data error: {e}")))?
129 {
130 debug!("got bytes: {}", partial_bytes.remaining());
131 response_bytes.put(partial_bytes);
132
133 if let Some(content_length) = content_length {
135 if response_bytes.len() >= content_length {
136 break;
137 }
138 }
139 }
140
141 if let Some(content_length) = content_length {
143 if response_bytes.len() != content_length {
144 return Err(ProtoError::from(format!(
146 "expected byte length: {}, got: {}",
147 content_length,
148 response_bytes.len()
149 )));
150 }
151 }
152
153 if !response.status().is_success() {
155 let error_string = String::from_utf8_lossy(response_bytes.as_ref());
156
157 return Err(ProtoError::from(format!(
159 "http unsuccessful code: {}, message: {}",
160 response.status(),
161 error_string
162 )));
163 } else {
164 {
166 let content_type = response
168 .headers()
169 .get(header::CONTENT_TYPE)
170 .map(|h| {
171 h.to_str().map_err(|err| {
172 ProtoError::from(format!("ContentType header not a string: {err}"))
174 })
175 })
176 .unwrap_or(Ok(crate::http::MIME_APPLICATION_DNS))?;
177
178 if content_type != crate::http::MIME_APPLICATION_DNS {
179 return Err(ProtoError::from(format!(
180 "ContentType unsupported (must be '{}'): '{}'",
181 crate::http::MIME_APPLICATION_DNS,
182 content_type
183 )));
184 }
185 }
186 };
187
188 DnsResponse::from_buffer(response_bytes.to_vec())
190 }
191}
192
193impl DnsRequestSender for H3ClientStream {
194 fn send_message(&mut self, mut request: DnsRequest) -> DnsResponseStream {
242 if self.is_shutdown {
243 panic!("can not send messages after stream is shutdown")
244 }
245
246 request.set_id(0);
248
249 let bytes = match request.to_vec() {
250 Ok(bytes) => bytes,
251 Err(err) => return err.into(),
252 };
253
254 Box::pin(Self::inner_send(
255 self.send_request.clone(),
256 Bytes::from(bytes),
257 Arc::clone(&self.name_server_name),
258 Arc::clone(&self.query_path),
259 ))
260 .into()
261 }
262
263 fn shutdown(&mut self) {
264 self.is_shutdown = true;
265 }
266
267 fn is_shutdown(&self) -> bool {
268 self.is_shutdown
269 }
270}
271
272impl Stream for H3ClientStream {
273 type Item = Result<(), ProtoError>;
274
275 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
276 if self.is_shutdown {
277 return Poll::Ready(None);
278 }
279
280 if self.shutdown_tx.is_closed() {
282 return Poll::Ready(Some(Err(ProtoError::from(
283 "h3 connection is already shutdown",
284 ))));
285 }
286
287 Poll::Ready(Some(Ok(())))
288 }
289}
290
/// Builder for [`H3ClientStream`] connections.
#[derive(Clone)]
pub struct H3ClientStreamBuilder {
    /// TLS client configuration handed to the QUIC connection setup.
    crypto_config: rustls::ClientConfig,
    /// QUIC transport configuration handed to the QUIC connection setup.
    transport_config: Arc<TransportConfig>,
    /// Optional local address to bind the UDP socket to; only consulted when
    /// the builder creates the socket itself.
    bind_addr: Option<SocketAddr>,
}
298
299impl H3ClientStreamBuilder {
300 pub fn crypto_config(&mut self, crypto_config: rustls::ClientConfig) -> &mut Self {
302 self.crypto_config = crypto_config;
303 self
304 }
305
306 pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
308 self.bind_addr = Some(bind_addr);
309 }
310
311 pub fn build(
318 self,
319 name_server: SocketAddr,
320 dns_name: String,
321 query_path: String,
322 ) -> H3ClientConnect {
323 H3ClientConnect(Box::pin(self.connect(name_server, dns_name, query_path)) as _)
324 }
325
326 pub fn build_with_future(
328 self,
329 socket: Arc<dyn quinn::AsyncUdpSocket>,
330 name_server: SocketAddr,
331 dns_name: String,
332 query_path: String,
333 ) -> H3ClientConnect {
334 H3ClientConnect(Box::pin(self.connect_with_future(
335 socket,
336 name_server,
337 dns_name,
338 query_path,
339 )) as _)
340 }
341
342 async fn connect_with_future(
343 self,
344 socket: Arc<dyn quinn::AsyncUdpSocket>,
345 name_server: SocketAddr,
346 server_name: String,
347 query_path: String,
348 ) -> Result<H3ClientStream, ProtoError> {
349 let endpoint = Endpoint::new_with_abstract_socket(
350 EndpointConfig::default(),
351 None,
352 socket,
353 Arc::new(quinn::TokioRuntime),
354 )?;
355 self.connect_inner(endpoint, name_server, server_name, query_path)
356 .await
357 }
358
359 async fn connect(
360 self,
361 name_server: SocketAddr,
362 dns_name: String,
363 query_path: String,
364 ) -> Result<H3ClientStream, ProtoError> {
365 let connect = if let Some(bind_addr) = self.bind_addr {
366 <tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, bind_addr)
367 } else {
368 <tokio::net::UdpSocket as UdpSocket>::connect(name_server)
369 };
370
371 let socket = connect.await?;
372 let socket = socket.into_std()?;
373 let endpoint = Endpoint::new(
374 EndpointConfig::default(),
375 None,
376 socket,
377 Arc::new(quinn::TokioRuntime),
378 )?;
379 self.connect_inner(endpoint, name_server, dns_name, query_path)
380 .await
381 }
382
383 async fn connect_inner(
384 self,
385 endpoint: Endpoint,
386 name_server: SocketAddr,
387 dns_name: String,
388 query_path: String,
389 ) -> Result<H3ClientStream, ProtoError> {
390 let quic_connection = connect_quic(
391 name_server,
392 &dns_name,
393 ALPN_H3,
394 self.crypto_config,
395 self.transport_config,
396 endpoint,
397 )
398 .await?;
399
400 let h3_connection = h3_quinn::Connection::new(quic_connection);
401 let (mut driver, send_request) = h3::client::new(h3_connection)
402 .await
403 .map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?;
404
405 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
406
407 debug!("h3 connection is ready: {}", name_server);
409 tokio::spawn(async move {
410 tokio::select! {
411 res = poll_fn(|cx| driver.poll_close(cx)) => {
412 res.map_err(|e| warn!("h3 connection failed: {e}"))
413 }
414 _ = shutdown_rx.recv() => {
415 debug!("h3 connection is shutting down: {}", name_server);
416 Ok(())
417 }
418 }
419 });
420
421 Ok(H3ClientStream {
422 name_server_name: Arc::from(dns_name),
423 name_server,
424 query_path: Arc::from(query_path),
425 send_request,
426 shutdown_tx,
427 is_shutdown: false,
428 })
429 }
430}
431
432impl Default for H3ClientStreamBuilder {
433 fn default() -> Self {
434 Self {
435 crypto_config: client_config(),
436 transport_config: Arc::new(super::transport()),
437 bind_addr: None,
438 }
439 }
440}
441
/// A future that resolves to an [`H3ClientStream`] once the connection has
/// been established, or to a [`ProtoError`] if connecting fails.
pub struct H3ClientConnect(
    // Boxed so that the different connect futures share one concrete type.
    Pin<Box<dyn Future<Output = Result<H3ClientStream, ProtoError>> + Send>>,
);
446
447impl Future for H3ClientConnect {
448 type Output = Result<H3ClientStream, ProtoError>;
449
450 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
451 self.0.poll_unpin(cx)
452 }
453}
454
/// A boxed future that resolves to a [`DnsResponse`] or a [`ProtoError`].
pub struct H3ClientResponse(Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>);
457
458impl Future for H3ClientResponse {
459 type Output = Result<DnsResponse, ProtoError>;
460
461 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
462 self.0.as_mut().poll(cx).map_err(ProtoError::from)
463 }
464}
465
#[cfg(all(
    test,
    any(feature = "rustls-platform-verifier", feature = "webpki-roots")
))]
mod tests {
    use alloc::string::ToString;
    use core::str::FromStr;
    use std::net::SocketAddr;
    use std::println;

    use rustls::KeyLogFile;
    use test_support::subscribe;
    use tokio::runtime::Runtime;
    use tokio::task::JoinSet;

    use crate::op::{Edns, Message, Query};
    use crate::rr::{Name, RecordType};
    use crate::xfer::{DnsRequestOptions, FirstAnswer};

    use super::*;

    /// Builds a request for `www.example.com.` with the given record type.
    ///
    /// With `recursive_with_edns` set, recursion is requested and an EDNS
    /// section (version 0, max payload 1232) is attached.
    fn example_request(record_type: RecordType, recursive_with_edns: bool) -> DnsRequest {
        let mut message = Message::new();
        message.add_query(Query::query(
            Name::from_str("www.example.com.").unwrap(),
            record_type,
        ));

        if recursive_with_edns {
            message.set_recursion_desired(true);
            let mut edns = Edns::new();
            edns.set_version(0);
            edns.set_max_payload(1232);
            *message.extensions_mut() = Some(edns);
        }

        DnsRequest::new(message, DnsRequestOptions::default())
    }

    /// Creates the connect future for `name_server`, using the default client
    /// configuration with key logging enabled and the `/dns-query` path.
    fn connector(name_server: SocketAddr, server_name: &str) -> H3ClientConnect {
        let mut crypto = client_config();
        crypto.key_log = Arc::new(KeyLogFile::new());

        let mut builder = H3ClientStream::builder();
        builder.crypto_config(crypto);
        builder.build(
            name_server,
            server_name.to_string(),
            "/dns-query".to_string(),
        )
    }

    #[tokio::test]
    async fn test_h3_google() {
        subscribe();

        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let mut h3 = connector(google, "dns.google")
            .await
            .expect("h3 connect failed");

        let response = h3
            .send_message(example_request(RecordType::A, true))
            .first_answer()
            .await
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_a().is_some())
        );

        // A second query over the same connection.
        let response = h3
            .send_message(example_request(RecordType::AAAA, true))
            .first_answer()
            .await
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_aaaa().is_some())
        );
    }

    #[tokio::test]
    async fn test_h3_google_with_pure_ip_address_server() {
        subscribe();

        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        // Use the bare IP address as the server name.
        let mut h3 = connector(google, &google.ip().to_string())
            .await
            .expect("h3 connect failed");

        let response = h3
            .send_message(example_request(RecordType::A, true))
            .first_answer()
            .await
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_a().is_some())
        );

        // A second query over the same connection.
        let response = h3
            .send_message(example_request(RecordType::AAAA, true))
            .first_answer()
            .await
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_aaaa().is_some())
        );
    }

    #[test]
    #[ignore = "cloudflare has been unreliable as a public test service"]
    fn test_h3_cloudflare() {
        subscribe();

        let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));
        let connect = connector(cloudflare, "cloudflare-dns.com");

        let runtime = Runtime::new().expect("could not start runtime");
        let mut h3 = runtime.block_on(connect).expect("h3 connect failed");

        let response = runtime
            .block_on(
                h3.send_message(example_request(RecordType::A, false))
                    .first_answer(),
            )
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_a().is_some())
        );

        // A second query over the same connection.
        let response = runtime
            .block_on(
                h3.send_message(example_request(RecordType::AAAA, false))
                    .first_answer(),
            )
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_aaaa().is_some())
        );
    }

    #[tokio::test]
    // Progress output is intentional in this test.
    #[allow(clippy::print_stdout)]
    async fn test_h3_client_stream_clonable() {
        subscribe();

        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let h3 = connector(google, "dns.google")
            .await
            .expect("h3 connect failed");

        let request = example_request(RecordType::AAAA, false);

        let mut join_set = JoinSet::new();

        for i in 0..50 {
            let mut h3 = h3.clone();
            let request = request.clone();

            join_set.spawn(async move {
                let start = std::time::Instant::now();
                h3.send_message(request)
                    .first_answer()
                    .await
                    .expect("send_message failed");
                println!("request[{i}] completed: {:?}", start.elapsed());
            });
        }

        let total = join_set.len();
        let mut completed = 0usize;
        while let Some(joined) = join_set.join_next().await {
            // A panic inside a spawned task (for example a failed request)
            // arrives here as a `JoinError`. Unwrap it so the failure fails
            // the test instead of being silently dropped.
            joined.expect("request task failed");
            completed += 1;
            println!("join_set completed {completed}/{total}");
        }
    }
}