1use std::fmt::{self, Display};
9use std::future::Future;
10use std::net::SocketAddr;
11use std::pin::Pin;
12use std::str::FromStr;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15
16use bytes::{Buf, BufMut, Bytes, BytesMut};
17use futures_util::future::FutureExt;
18use futures_util::stream::Stream;
19use h3::client::{Connection, SendRequest};
20use h3_quinn::OpenStreams;
21use http::header::{self, CONTENT_LENGTH};
22use quinn::{ClientConfig, Endpoint, EndpointConfig, TransportConfig};
23use rustls::ClientConfig as TlsClientConfig;
24use tracing::debug;
25
26use crate::error::ProtoError;
27use crate::http::Version;
28use crate::op::Message;
29use crate::quic::quic_socket::QuinnAsyncUdpSocketAdapter;
30use crate::quic::QuicLocalAddr;
31use crate::udp::{DnsUdpSocket, UdpSocket};
32use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
33
34use super::ALPN_H3;
35
36#[must_use = "futures do nothing unless polled"]
38pub struct H3ClientStream {
39 name_server_name: Arc<str>,
41 name_server: SocketAddr,
42 driver: Connection<h3_quinn::Connection, Bytes>,
43 send_request: SendRequest<OpenStreams, Bytes>,
44 is_shutdown: bool,
45}
46
47impl Display for H3ClientStream {
48 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
49 write!(
50 formatter,
51 "H3({},{})",
52 self.name_server, self.name_server_name
53 )
54 }
55}
56
57impl H3ClientStream {
58 pub fn builder() -> H3ClientStreamBuilder {
60 H3ClientStreamBuilder::default()
61 }
62
63 async fn inner_send(
64 mut h3: SendRequest<OpenStreams, Bytes>,
65 message: Bytes,
66 name_server_name: Arc<str>,
67 ) -> Result<DnsResponse, ProtoError> {
68 let request =
70 crate::http::request::new(Version::Http3, &name_server_name, message.remaining());
71
72 let request =
73 request.map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
74
75 debug!("request: {:#?}", request);
76
77 let mut stream = h3
79 .send_request(request)
80 .await
81 .map_err(|err| ProtoError::from(format!("h3 send_request error: {err}")))?;
82
83 stream
84 .send_data(message)
85 .await
86 .map_err(|e| ProtoError::from(format!("h3 send_data error: {e}")))?;
87
88 stream
89 .finish()
90 .await
91 .map_err(|err| ProtoError::from(format!("received a stream error: {err}")))?;
92
93 let response = stream
94 .recv_response()
95 .await
96 .map_err(|err| ProtoError::from(format!("h3 recv_response error: {err}")))?;
97
98 debug!("got response: {:#?}", response);
99
100 let content_length = response
102 .headers()
103 .get(CONTENT_LENGTH)
104 .map(|v| v.to_str())
105 .transpose()
106 .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?
107 .map(usize::from_str)
108 .transpose()
109 .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
110
111 let mut response_bytes =
115 BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4096));
116
117 while let Some(partial_bytes) = stream
118 .recv_data()
119 .await
120 .map_err(|e| ProtoError::from(format!("h3 recv_data error: {e}")))?
121 {
122 debug!("got bytes: {}", partial_bytes.remaining());
123 response_bytes.put(partial_bytes);
124
125 if let Some(content_length) = content_length {
127 if response_bytes.len() >= content_length {
128 break;
129 }
130 }
131 }
132
133 if let Some(content_length) = content_length {
135 if response_bytes.len() != content_length {
136 return Err(ProtoError::from(format!(
138 "expected byte length: {}, got: {}",
139 content_length,
140 response_bytes.len()
141 )));
142 }
143 }
144
145 if !response.status().is_success() {
147 let error_string = String::from_utf8_lossy(response_bytes.as_ref());
148
149 return Err(ProtoError::from(format!(
151 "http unsuccessful code: {}, message: {}",
152 response.status(),
153 error_string
154 )));
155 } else {
156 {
158 let content_type = response
160 .headers()
161 .get(header::CONTENT_TYPE)
162 .map(|h| {
163 h.to_str().map_err(|err| {
164 ProtoError::from(format!("ContentType header not a string: {err}"))
166 })
167 })
168 .unwrap_or(Ok(crate::http::MIME_APPLICATION_DNS))?;
169
170 if content_type != crate::http::MIME_APPLICATION_DNS {
171 return Err(ProtoError::from(format!(
172 "ContentType unsupported (must be '{}'): '{}'",
173 crate::http::MIME_APPLICATION_DNS,
174 content_type
175 )));
176 }
177 }
178 };
179
180 let message = Message::from_vec(&response_bytes)?;
182 Ok(DnsResponse::new(message, response_bytes.to_vec()))
183 }
184}
185
186impl DnsRequestSender for H3ClientStream {
187 fn send_message(&mut self, mut message: DnsRequest) -> DnsResponseStream {
235 if self.is_shutdown {
236 panic!("can not send messages after stream is shutdown")
237 }
238
239 message.set_id(0);
241
242 let bytes = match message.to_vec() {
243 Ok(bytes) => bytes,
244 Err(err) => return err.into(),
245 };
246
247 Box::pin(Self::inner_send(
248 self.send_request.clone(),
249 Bytes::from(bytes),
250 Arc::clone(&self.name_server_name),
251 ))
252 .into()
253 }
254
255 fn shutdown(&mut self) {
256 self.is_shutdown = true;
257 }
258
259 fn is_shutdown(&self) -> bool {
260 self.is_shutdown
261 }
262}
263
264impl Stream for H3ClientStream {
265 type Item = Result<(), ProtoError>;
266
267 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
268 if self.is_shutdown {
269 return Poll::Ready(None);
270 }
271
272 match self.driver.poll_close(cx) {
274 Poll::Ready(Ok(())) => Poll::Ready(None),
275 Poll::Pending => Poll::Pending,
276 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
277 "h3 stream errored: {e}",
278 ))))),
279 }
280 }
281}
282
283#[derive(Clone)]
285pub struct H3ClientStreamBuilder {
286 crypto_config: TlsClientConfig,
287 transport_config: Arc<TransportConfig>,
288 bind_addr: Option<SocketAddr>,
289}
290
291impl H3ClientStreamBuilder {
292 pub fn crypto_config(&mut self, crypto_config: TlsClientConfig) -> &mut Self {
294 self.crypto_config = crypto_config;
295 self
296 }
297
298 pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
300 self.bind_addr = Some(bind_addr);
301 }
302
303 pub fn build(self, name_server: SocketAddr, dns_name: String) -> H3ClientConnect {
310 H3ClientConnect(Box::pin(self.connect(name_server, dns_name)) as _)
311 }
312
313 pub fn build_with_future<S, F>(
315 self,
316 future: F,
317 name_server: SocketAddr,
318 dns_name: String,
319 ) -> H3ClientConnect
320 where
321 S: DnsUdpSocket + QuicLocalAddr + 'static,
322 F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
323 {
324 H3ClientConnect(Box::pin(self.connect_with_future(future, name_server, dns_name)) as _)
325 }
326
327 async fn connect_with_future<S, F>(
328 self,
329 future: F,
330 name_server: SocketAddr,
331 dns_name: String,
332 ) -> Result<H3ClientStream, ProtoError>
333 where
334 S: DnsUdpSocket + QuicLocalAddr + 'static,
335 F: Future<Output = std::io::Result<S>> + Send,
336 {
337 let socket = future.await?;
338 let wrapper = QuinnAsyncUdpSocketAdapter { io: socket };
339 let endpoint = Endpoint::new_with_abstract_socket(
340 EndpointConfig::default(),
341 None,
342 wrapper,
343 Arc::new(quinn::TokioRuntime),
344 )?;
345 self.connect_inner(endpoint, name_server, dns_name).await
346 }
347
348 async fn connect(
349 self,
350 name_server: SocketAddr,
351 dns_name: String,
352 ) -> Result<H3ClientStream, ProtoError> {
353 let connect = if let Some(bind_addr) = self.bind_addr {
354 <tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, bind_addr)
355 } else {
356 <tokio::net::UdpSocket as UdpSocket>::connect(name_server)
357 };
358
359 let socket = connect.await?;
360 let socket = socket.into_std()?;
361 let endpoint = Endpoint::new(
362 EndpointConfig::default(),
363 None,
364 socket,
365 Arc::new(quinn::TokioRuntime),
366 )?;
367 self.connect_inner(endpoint, name_server, dns_name).await
368 }
369
370 async fn connect_inner(
371 self,
372 mut endpoint: Endpoint,
373 name_server: SocketAddr,
374 dns_name: String,
375 ) -> Result<H3ClientStream, ProtoError> {
376 let mut crypto_config = self.crypto_config;
377 if crypto_config.alpn_protocols.is_empty() {
379 crypto_config.alpn_protocols = vec![ALPN_H3.to_vec()];
380 }
381 let early_data_enabled = crypto_config.enable_early_data;
382
383 let mut client_config = ClientConfig::new(Arc::new(crypto_config));
384 client_config.transport_config(self.transport_config.clone());
385
386 endpoint.set_default_client_config(client_config);
387
388 let connecting = endpoint.connect(name_server, &dns_name)?;
389 let quic_connection = if early_data_enabled {
392 match connecting.into_0rtt() {
393 Ok((new_connection, _)) => new_connection,
394 Err(connecting) => connecting.await?,
395 }
396 } else {
397 connecting.await?
398 };
399
400 let h3_connection = h3_quinn::Connection::new(quic_connection);
401 let (driver, send_request) = h3::client::new(h3_connection)
402 .await
403 .map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?;
404
405 Ok(H3ClientStream {
406 name_server_name: Arc::from(dns_name),
407 name_server,
408 driver,
409 send_request,
410 is_shutdown: false,
411 })
412 }
413}
414
415impl Default for H3ClientStreamBuilder {
416 fn default() -> Self {
417 Self {
418 crypto_config: super::client_config_tls13().unwrap(),
419 transport_config: Arc::new(super::transport()),
420 bind_addr: None,
421 }
422 }
423}
424
425pub struct H3ClientConnect(
427 Pin<Box<dyn Future<Output = Result<H3ClientStream, ProtoError>> + Send>>,
428);
429
430impl Future for H3ClientConnect {
431 type Output = Result<H3ClientStream, ProtoError>;
432
433 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
434 self.0.poll_unpin(cx)
435 }
436}
437
438pub struct H3ClientResponse(Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>);
440
441impl Future for H3ClientResponse {
442 type Output = Result<DnsResponse, ProtoError>;
443
444 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
445 self.0.as_mut().poll(cx).map_err(ProtoError::from)
446 }
447}
448
#[cfg(all(test, any(feature = "native-certs", feature = "webpki-roots")))]
mod tests {
    // These tests query public DNS-over-HTTP/3 resolvers over the real network.

    use std::net::SocketAddr;
    use std::str::FromStr;

    use rustls::KeyLogFile;
    use tokio::runtime::Runtime;

    use crate::op::{Edns, Message, Query};
    use crate::rr::{Name, RData, RecordType};
    use crate::xfer::{DnsRequestOptions, FirstAnswer};

    use super::*;

    /// Builds a `www.example.com.` query for `record_type`.
    ///
    /// When `recursive_edns` is true, the query also sets RD and carries EDNS version 0
    /// with a max payload of 1232.
    fn example_request(record_type: RecordType, recursive_edns: bool) -> DnsRequest {
        let mut request = Message::new();
        let query = Query::query(Name::from_str("www.example.com.").unwrap(), record_type);
        request.add_query(query);

        if recursive_edns {
            request.set_recursion_desired(true);
            let mut edns = Edns::new();
            edns.set_version(0);
            edns.set_max_payload(1232);
            *request.extensions_mut() = Some(edns);
        }

        DnsRequest::new(request, DnsRequestOptions::default())
    }

    /// Connects to `server` over HTTP/3, using `dns_name` as the server name.
    ///
    /// The connection uses a TLS 1.3 client config with rustls key logging enabled.
    fn connect_h3(runtime: &Runtime, server: SocketAddr, dns_name: String) -> H3ClientStream {
        let mut client_config = super::super::client_config_tls13().unwrap();
        client_config.key_log = Arc::new(KeyLogFile::new());

        let mut h3_builder = H3ClientStream::builder();
        h3_builder.crypto_config(client_config);
        runtime
            .block_on(h3_builder.build(server, dns_name))
            .expect("h3 connect failed")
    }

    /// Sends `request` on `h3` and blocks until the first response arrives.
    fn query(runtime: &Runtime, h3: &mut H3ClientStream, request: DnsRequest) -> DnsResponse {
        runtime
            .block_on(h3.send_message(request).first_answer())
            .expect("send_message failed")
    }

    /// Returns true if any answer record carries data accepted by `matches`.
    fn has_answer(response: &DnsResponse, matches: impl FnMut(&RData) -> bool) -> bool {
        response
            .answers()
            .iter()
            .filter_map(|record| record.data())
            .any(matches)
    }

    #[test]
    fn test_h3_google() {
        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let runtime = Runtime::new().expect("could not start runtime");
        let mut h3 = connect_h3(&runtime, google, "dns.google".to_string());

        let response = query(&runtime, &mut h3, example_request(RecordType::A, true));
        assert!(has_answer(&response, |data| data.as_a().is_some()));

        // Reuse the same connection for several queries.
        for _ in 0..3 {
            let response = query(&runtime, &mut h3, example_request(RecordType::AAAA, true));
            assert!(has_answer(&response, |data| data.as_aaaa().is_some()));
        }
    }

    #[test]
    fn test_h3_google_with_pure_ip_address_server() {
        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let runtime = Runtime::new().expect("could not start runtime");
        let mut h3 = connect_h3(&runtime, google, google.ip().to_string());

        let response = query(&runtime, &mut h3, example_request(RecordType::A, true));
        assert!(has_answer(&response, |data| data.as_a().is_some()));

        // Reuse the same connection for several queries.
        for _ in 0..3 {
            let response = query(&runtime, &mut h3, example_request(RecordType::AAAA, true));
            assert!(has_answer(&response, |data| data.as_aaaa().is_some()));
        }
    }

    #[test]
    #[ignore]
    fn test_h3_cloudflare() {
        let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));
        let runtime = Runtime::new().expect("could not start runtime");
        let mut h3 = connect_h3(&runtime, cloudflare, "cloudflare-dns.com".to_string());

        let response = query(&runtime, &mut h3, example_request(RecordType::A, false));
        // Check only that an A record came back. The concrete address of
        // www.example.com is operator-controlled and changes over time.
        assert!(
            has_answer(&response, |data| data.as_a().is_some()),
            "invalid response, expected A record"
        );

        let response = query(&runtime, &mut h3, example_request(RecordType::AAAA, true));
        assert!(has_answer(&response, |data| data.as_aaaa().is_some()));
    }
}