hickory_proto/h3/
h3_client_stream.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use alloc::boxed::Box;
9use alloc::string::String;
10use alloc::sync::Arc;
11use core::fmt::{self, Display};
12use core::future::{Future, poll_fn};
13use core::pin::Pin;
14use core::str::FromStr;
15use core::task::{Context, Poll};
16use std::net::SocketAddr;
17
18use bytes::{Buf, BufMut, Bytes, BytesMut};
19use futures_util::future::FutureExt;
20use futures_util::stream::Stream;
21use h3::client::SendRequest;
22use h3_quinn::OpenStreams;
23use http::header::{self, CONTENT_LENGTH};
24use quinn::{Endpoint, EndpointConfig, TransportConfig};
25use tokio::sync::mpsc;
26use tracing::{debug, warn};
27
28use crate::error::ProtoError;
29use crate::http::Version;
30use crate::quic::connect_quic;
31use crate::rustls::client_config;
32use crate::udp::UdpSocket;
33use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
34
35use super::ALPN_H3;
36
37/// A DNS client connection for DNS-over-HTTP/3
38#[derive(Clone)]
39#[must_use = "futures do nothing unless polled"]
40pub struct H3ClientStream {
41    // Corresponds to the dns-name of the HTTP/3 server
42    name_server_name: Arc<str>,
43    name_server: SocketAddr,
44    query_path: Arc<str>,
45    send_request: SendRequest<OpenStreams, Bytes>,
46    shutdown_tx: mpsc::Sender<()>,
47    is_shutdown: bool,
48}
49
50impl Display for H3ClientStream {
51    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
52        write!(
53            formatter,
54            "H3({},{})",
55            self.name_server, self.name_server_name
56        )
57    }
58}
59
60impl H3ClientStream {
61    /// Builder for H3ClientStream
62    pub fn builder() -> H3ClientStreamBuilder {
63        H3ClientStreamBuilder::default()
64    }
65
66    async fn inner_send(
67        mut h3: SendRequest<OpenStreams, Bytes>,
68        message: Bytes,
69        name_server_name: Arc<str>,
70        query_path: Arc<str>,
71    ) -> Result<DnsResponse, ProtoError> {
72        // build up the http request
73        let request = crate::http::request::new(
74            Version::Http3,
75            &name_server_name,
76            &query_path,
77            message.remaining(),
78        );
79
80        let request =
81            request.map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
82
83        debug!("request: {:#?}", request);
84
85        // Send the request
86        let mut stream = h3
87            .send_request(request)
88            .await
89            .map_err(|err| ProtoError::from(format!("h3 send_request error: {err}")))?;
90
91        stream
92            .send_data(message)
93            .await
94            .map_err(|e| ProtoError::from(format!("h3 send_data error: {e}")))?;
95
96        stream
97            .finish()
98            .await
99            .map_err(|err| ProtoError::from(format!("received a stream error: {err}")))?;
100
101        let response = stream
102            .recv_response()
103            .await
104            .map_err(|err| ProtoError::from(format!("h3 recv_response error: {err}")))?;
105
106        debug!("got response: {:#?}", response);
107
108        // get the length of packet
109        let content_length = response
110            .headers()
111            .get(CONTENT_LENGTH)
112            .map(|v| v.to_str())
113            .transpose()
114            .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?
115            .map(usize::from_str)
116            .transpose()
117            .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
118
119        // TODO: what is a good max here?
120        // clamp(512, 4096) says make sure it is at least 512 bytes, and min 4096 says it is at most 4k
121        // just a little protection from malicious actors.
122        let mut response_bytes =
123            BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4_096));
124
125        while let Some(partial_bytes) = stream
126            .recv_data()
127            .await
128            .map_err(|e| ProtoError::from(format!("h3 recv_data error: {e}")))?
129        {
130            debug!("got bytes: {}", partial_bytes.remaining());
131            response_bytes.put(partial_bytes);
132
133            // assert the length
134            if let Some(content_length) = content_length {
135                if response_bytes.len() >= content_length {
136                    break;
137                }
138            }
139        }
140
141        // assert the length
142        if let Some(content_length) = content_length {
143            if response_bytes.len() != content_length {
144                // TODO: make explicit error type
145                return Err(ProtoError::from(format!(
146                    "expected byte length: {}, got: {}",
147                    content_length,
148                    response_bytes.len()
149                )));
150            }
151        }
152
153        // Was it a successful request?
154        if !response.status().is_success() {
155            let error_string = String::from_utf8_lossy(response_bytes.as_ref());
156
157            // TODO: make explicit error type
158            return Err(ProtoError::from(format!(
159                "http unsuccessful code: {}, message: {}",
160                response.status(),
161                error_string
162            )));
163        } else {
164            // verify content type
165            {
166                // in the case that the ContentType is not specified, we assume it's the standard DNS format
167                let content_type = response
168                    .headers()
169                    .get(header::CONTENT_TYPE)
170                    .map(|h| {
171                        h.to_str().map_err(|err| {
172                            // TODO: make explicit error type
173                            ProtoError::from(format!("ContentType header not a string: {err}"))
174                        })
175                    })
176                    .unwrap_or(Ok(crate::http::MIME_APPLICATION_DNS))?;
177
178                if content_type != crate::http::MIME_APPLICATION_DNS {
179                    return Err(ProtoError::from(format!(
180                        "ContentType unsupported (must be '{}'): '{}'",
181                        crate::http::MIME_APPLICATION_DNS,
182                        content_type
183                    )));
184                }
185            }
186        };
187
188        // and finally convert the bytes into a DNS message
189        DnsResponse::from_buffer(response_bytes.to_vec())
190    }
191}
192
193impl DnsRequestSender for H3ClientStream {
194    /// This indicates that the HTTP message was successfully sent, and we now have the response.RecvStream
195    ///
196    /// If the request fails, this will return the error, and it should be assumed that the Stream portion of
197    ///   this will have no date.
198    ///
199    /// ```text
200    /// 5.2.  The HTTP Response
201    ///
202    ///    An HTTP response with a 2xx status code ([RFC7231] Section 6.3)
203    ///    indicates a valid DNS response to the query made in the HTTP request.
204    ///    A valid DNS response includes both success and failure responses.
205    ///    For example, a DNS failure response such as SERVFAIL or NXDOMAIN will
206    ///    be the message in a successful 2xx HTTP response even though there
207    ///    was a failure at the DNS layer.  Responses with non-successful HTTP
208    ///    status codes do not contain DNS answers to the question in the
209    ///    corresponding request.  Some of these non-successful HTTP responses
210    ///    (e.g., redirects or authentication failures) could mean that clients
211    ///    need to make new requests to satisfy the original question.
212    ///
213    ///    Different response media types will provide more or less information
214    ///    from a DNS response.  For example, one response type might include
215    ///    the information from the DNS header bytes while another might omit
216    ///    it.  The amount and type of information that a media type gives is
217    ///    solely up to the format, and not defined in this protocol.
218    ///
219    ///    The only response type defined in this document is "application/dns-
220    ///    message", but it is possible that other response formats will be
221    ///    defined in the future.
222    ///
223    ///    The DNS response for "application/dns-message" in Section 7 MAY have
224    ///    one or more EDNS options [RFC6891], depending on the extension
225    ///    definition of the extensions given in the DNS request.
226    ///
227    ///    Each DNS request-response pair is matched to one HTTP exchange.  The
228    ///    responses may be processed and transported in any order using HTTP's
229    ///    multi-streaming functionality ([RFC7540] Section 5).
230    ///
231    ///    Section 6.1 discusses the relationship between DNS and HTTP response
232    ///    caching.
233    ///
234    ///    A DNS API server MUST be able to process application/dns-message
235    ///    request messages.
236    ///
237    ///    A DNS API server SHOULD respond with HTTP status code 415
238    ///    (Unsupported Media Type) upon receiving a media type it is unable to
239    ///    process.
240    /// ```
241    fn send_message(&mut self, mut request: DnsRequest) -> DnsResponseStream {
242        if self.is_shutdown {
243            panic!("can not send messages after stream is shutdown")
244        }
245
246        // per the RFC, a zero id allows for the HTTP packet to be cached better
247        request.set_id(0);
248
249        let bytes = match request.to_vec() {
250            Ok(bytes) => bytes,
251            Err(err) => return err.into(),
252        };
253
254        Box::pin(Self::inner_send(
255            self.send_request.clone(),
256            Bytes::from(bytes),
257            Arc::clone(&self.name_server_name),
258            Arc::clone(&self.query_path),
259        ))
260        .into()
261    }
262
263    fn shutdown(&mut self) {
264        self.is_shutdown = true;
265    }
266
267    fn is_shutdown(&self) -> bool {
268        self.is_shutdown
269    }
270}
271
272impl Stream for H3ClientStream {
273    type Item = Result<(), ProtoError>;
274
275    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
276        if self.is_shutdown {
277            return Poll::Ready(None);
278        }
279
280        // just checking if the connection is ok
281        if self.shutdown_tx.is_closed() {
282            return Poll::Ready(Some(Err(ProtoError::from(
283                "h3 connection is already shutdown",
284            ))));
285        }
286
287        Poll::Ready(Some(Ok(())))
288    }
289}
290
291/// A H3 connection builder for DNS-over-HTTP/3
292#[derive(Clone)]
293pub struct H3ClientStreamBuilder {
294    crypto_config: rustls::ClientConfig,
295    transport_config: Arc<TransportConfig>,
296    bind_addr: Option<SocketAddr>,
297}
298
299impl H3ClientStreamBuilder {
300    /// Constructs a new H3ClientStreamBuilder with the associated ClientConfig
301    pub fn crypto_config(&mut self, crypto_config: rustls::ClientConfig) -> &mut Self {
302        self.crypto_config = crypto_config;
303        self
304    }
305
306    /// Sets the address to connect from.
307    pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
308        self.bind_addr = Some(bind_addr);
309    }
310
311    /// Creates a new H3Stream to the specified name_server
312    ///
313    /// # Arguments
314    ///
315    /// * `name_server` - IP and Port for the remote DNS resolver
316    /// * `dns_name` - The DNS name associated with a certificate
317    pub fn build(
318        self,
319        name_server: SocketAddr,
320        dns_name: String,
321        query_path: String,
322    ) -> H3ClientConnect {
323        H3ClientConnect(Box::pin(self.connect(name_server, dns_name, query_path)) as _)
324    }
325
326    /// Creates a new H3Stream with existing connection
327    pub fn build_with_future(
328        self,
329        socket: Arc<dyn quinn::AsyncUdpSocket>,
330        name_server: SocketAddr,
331        dns_name: String,
332        query_path: String,
333    ) -> H3ClientConnect {
334        H3ClientConnect(Box::pin(self.connect_with_future(
335            socket,
336            name_server,
337            dns_name,
338            query_path,
339        )) as _)
340    }
341
342    async fn connect_with_future(
343        self,
344        socket: Arc<dyn quinn::AsyncUdpSocket>,
345        name_server: SocketAddr,
346        server_name: String,
347        query_path: String,
348    ) -> Result<H3ClientStream, ProtoError> {
349        let endpoint = Endpoint::new_with_abstract_socket(
350            EndpointConfig::default(),
351            None,
352            socket,
353            Arc::new(quinn::TokioRuntime),
354        )?;
355        self.connect_inner(endpoint, name_server, server_name, query_path)
356            .await
357    }
358
359    async fn connect(
360        self,
361        name_server: SocketAddr,
362        dns_name: String,
363        query_path: String,
364    ) -> Result<H3ClientStream, ProtoError> {
365        let connect = if let Some(bind_addr) = self.bind_addr {
366            <tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, bind_addr)
367        } else {
368            <tokio::net::UdpSocket as UdpSocket>::connect(name_server)
369        };
370
371        let socket = connect.await?;
372        let socket = socket.into_std()?;
373        let endpoint = Endpoint::new(
374            EndpointConfig::default(),
375            None,
376            socket,
377            Arc::new(quinn::TokioRuntime),
378        )?;
379        self.connect_inner(endpoint, name_server, dns_name, query_path)
380            .await
381    }
382
383    async fn connect_inner(
384        self,
385        endpoint: Endpoint,
386        name_server: SocketAddr,
387        dns_name: String,
388        query_path: String,
389    ) -> Result<H3ClientStream, ProtoError> {
390        let quic_connection = connect_quic(
391            name_server,
392            &dns_name,
393            ALPN_H3,
394            self.crypto_config,
395            self.transport_config,
396            endpoint,
397        )
398        .await?;
399
400        let h3_connection = h3_quinn::Connection::new(quic_connection);
401        let (mut driver, send_request) = h3::client::new(h3_connection)
402            .await
403            .map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?;
404
405        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
406
407        // TODO: hand this back for others to run rather than spawning here?
408        debug!("h3 connection is ready: {}", name_server);
409        tokio::spawn(async move {
410            tokio::select! {
411                res = poll_fn(|cx| driver.poll_close(cx)) => {
412                    res.map_err(|e| warn!("h3 connection failed: {e}"))
413                }
414                _ = shutdown_rx.recv() => {
415                    debug!("h3 connection is shutting down: {}", name_server);
416                    Ok(())
417                }
418            }
419        });
420
421        Ok(H3ClientStream {
422            name_server_name: Arc::from(dns_name),
423            name_server,
424            query_path: Arc::from(query_path),
425            send_request,
426            shutdown_tx,
427            is_shutdown: false,
428        })
429    }
430}
431
432impl Default for H3ClientStreamBuilder {
433    fn default() -> Self {
434        Self {
435            crypto_config: client_config(),
436            transport_config: Arc::new(super::transport()),
437            bind_addr: None,
438        }
439    }
440}
441
442/// A future that resolves to an H3ClientStream
443pub struct H3ClientConnect(
444    Pin<Box<dyn Future<Output = Result<H3ClientStream, ProtoError>> + Send>>,
445);
446
447impl Future for H3ClientConnect {
448    type Output = Result<H3ClientStream, ProtoError>;
449
450    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
451        self.0.poll_unpin(cx)
452    }
453}
454
455/// A future that resolves to
456pub struct H3ClientResponse(Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>);
457
458impl Future for H3ClientResponse {
459    type Output = Result<DnsResponse, ProtoError>;
460
461    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
462        self.0.as_mut().poll(cx).map_err(ProtoError::from)
463    }
464}
465
#[cfg(all(
    test,
    any(feature = "rustls-platform-verifier", feature = "webpki-roots")
))]
mod tests {
    use alloc::string::ToString;
    use core::str::FromStr;
    use std::net::SocketAddr;
    use std::println;

    use rustls::KeyLogFile;
    use test_support::subscribe;
    use tokio::runtime::Runtime;
    use tokio::task::JoinSet;

    use crate::op::{Edns, Message, Query};
    use crate::rr::{Name, RecordType};
    use crate::xfer::{DnsRequestOptions, FirstAnswer};

    use super::*;

    /// Google Public DNS at 8.8.8.8, HTTPS port.
    fn google() -> SocketAddr {
        SocketAddr::from(([8, 8, 8, 8], 443))
    }

    /// Returns a builder whose rustls config has a `KeyLogFile` key logger installed, so test
    /// traffic can be inspected when debugging.
    fn builder_with_key_log() -> H3ClientStreamBuilder {
        let mut client_config = client_config();
        client_config.key_log = Arc::new(KeyLogFile::new());

        let mut h3_builder = H3ClientStream::builder();
        h3_builder.crypto_config(client_config);
        h3_builder
    }

    /// Builds a message containing a single `www.example.com.` query for `record_type`.
    fn example_message(record_type: RecordType) -> Message {
        let mut message = Message::new();
        message.add_query(Query::query(
            Name::from_str("www.example.com.").unwrap(),
            record_type,
        ));
        message
    }

    /// A `www.example.com.` query with default request options and no other flags set.
    fn plain_query(record_type: RecordType) -> DnsRequest {
        DnsRequest::new(example_message(record_type), DnsRequestOptions::default())
    }

    /// A `www.example.com.` query with recursion desired and an EDNS version-0 record that
    /// advertises a 1232-byte payload.
    fn recursive_edns_query(record_type: RecordType) -> DnsRequest {
        let mut message = example_message(record_type);
        message.set_recursion_desired(true);
        let mut edns = Edns::new();
        edns.set_version(0);
        edns.set_max_payload(1232);
        *message.extensions_mut() = Some(edns);

        DnsRequest::new(message, DnsRequestOptions::default())
    }

    /// Connects to Google using `server_name` as the certificate name.
    ///
    /// Then checks that an A query followed by an AAAA query both receive answers over the
    /// same connection.
    async fn check_google_a_then_aaaa(server_name: String) {
        let connect =
            builder_with_key_log().build(google(), server_name, "/dns-query".to_string());
        let mut h3 = connect.await.expect("h3 connect failed");

        let response = h3
            .send_message(recursive_edns_query(RecordType::A))
            .first_answer()
            .await
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_a().is_some())
        );

        // assert that the connection works for a second query
        let response = h3
            .send_message(recursive_edns_query(RecordType::AAAA))
            .first_answer()
            .await
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_aaaa().is_some())
        );
    }

    #[tokio::test]
    async fn test_h3_google() {
        subscribe();
        check_google_a_then_aaaa("dns.google".to_string()).await;
    }

    #[tokio::test]
    async fn test_h3_google_with_pure_ip_address_server() {
        subscribe();
        check_google_a_then_aaaa(google().ip().to_string()).await;
    }

    /// Currently fails, see <https://github.com/hyperium/h3/issues/206>.
    #[test]
    #[ignore = "cloudflare has been unreliable as a public test service"]
    fn test_h3_cloudflare() {
        subscribe();

        let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));
        let connect = builder_with_key_log().build(
            cloudflare,
            "cloudflare-dns.com".to_string(),
            "/dns-query".to_string(),
        );

        // tokio runtime stuff...
        let runtime = Runtime::new().expect("could not start runtime");
        let mut h3 = runtime.block_on(connect).expect("h3 connect failed");

        let response = runtime
            .block_on(h3.send_message(plain_query(RecordType::A)).first_answer())
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_a().is_some())
        );

        // assert that the connection works for a second query
        let response = runtime
            .block_on(h3.send_message(plain_query(RecordType::AAAA)).first_answer())
            .expect("send_message failed");

        assert!(
            response
                .answers()
                .iter()
                .any(|record| record.data().as_aaaa().is_some())
        );
    }

    #[tokio::test]
    #[allow(clippy::print_stdout)] // per-request timing output is the point of this test
    async fn test_h3_client_stream_clonable() {
        subscribe();

        let connect = builder_with_key_log().build(
            google(),
            "dns.google".to_string(),
            "/dns-query".to_string(),
        );
        let h3 = connect.await.expect("h3 connect failed");

        // prepare request
        let request = plain_query(RecordType::AAAA);

        let mut join_set = JoinSet::new();

        for i in 0..50 {
            let mut h3 = h3.clone();
            let request = request.clone();

            join_set.spawn(async move {
                let start = std::time::Instant::now();
                h3.send_message(request)
                    .first_answer()
                    .await
                    .expect("send_message failed");
                println!("request[{i}] completed: {:?}", start.elapsed());
            });
        }

        let total = join_set.len();
        let mut completed = 0usize;
        while let Some(result) = join_set.join_next().await {
            // A panic inside a spawned task (e.g. the `expect` above) comes back here as a
            // `JoinError`. Propagate it so the test fails instead of silently passing.
            result.expect("h3 request task failed");
            completed += 1;
            println!("join_set completed {completed}/{total}");
        }
    }
}