hickory_proto/h3/
h3_client_stream.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::fmt::{self, Display};
9use std::future::Future;
10use std::net::SocketAddr;
11use std::pin::Pin;
12use std::str::FromStr;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15
16use bytes::{Buf, BufMut, Bytes, BytesMut};
17use futures_util::future::FutureExt;
18use futures_util::stream::Stream;
19use h3::client::{Connection, SendRequest};
20use h3_quinn::OpenStreams;
21use http::header::{self, CONTENT_LENGTH};
22use quinn::{ClientConfig, Endpoint, EndpointConfig, TransportConfig};
23use rustls::ClientConfig as TlsClientConfig;
24use tracing::debug;
25
26use crate::error::ProtoError;
27use crate::http::Version;
28use crate::op::Message;
29use crate::quic::quic_socket::QuinnAsyncUdpSocketAdapter;
30use crate::quic::QuicLocalAddr;
31use crate::udp::{DnsUdpSocket, UdpSocket};
32use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
33
34use super::ALPN_H3;
35
36/// A DNS client connection for DNS-over-HTTP/3
37#[must_use = "futures do nothing unless polled"]
38pub struct H3ClientStream {
39    // Corresponds to the dns-name of the HTTP/3 server
40    name_server_name: Arc<str>,
41    name_server: SocketAddr,
42    driver: Connection<h3_quinn::Connection, Bytes>,
43    send_request: SendRequest<OpenStreams, Bytes>,
44    is_shutdown: bool,
45}
46
47impl Display for H3ClientStream {
48    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
49        write!(
50            formatter,
51            "H3({},{})",
52            self.name_server, self.name_server_name
53        )
54    }
55}
56
57impl H3ClientStream {
58    /// Builder for H3ClientStream
59    pub fn builder() -> H3ClientStreamBuilder {
60        H3ClientStreamBuilder::default()
61    }
62
63    async fn inner_send(
64        mut h3: SendRequest<OpenStreams, Bytes>,
65        message: Bytes,
66        name_server_name: Arc<str>,
67    ) -> Result<DnsResponse, ProtoError> {
68        // build up the http request
69        let request =
70            crate::http::request::new(Version::Http3, &name_server_name, message.remaining());
71
72        let request =
73            request.map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
74
75        debug!("request: {:#?}", request);
76
77        // Send the request
78        let mut stream = h3
79            .send_request(request)
80            .await
81            .map_err(|err| ProtoError::from(format!("h3 send_request error: {err}")))?;
82
83        stream
84            .send_data(message)
85            .await
86            .map_err(|e| ProtoError::from(format!("h3 send_data error: {e}")))?;
87
88        stream
89            .finish()
90            .await
91            .map_err(|err| ProtoError::from(format!("received a stream error: {err}")))?;
92
93        let response = stream
94            .recv_response()
95            .await
96            .map_err(|err| ProtoError::from(format!("h3 recv_response error: {err}")))?;
97
98        debug!("got response: {:#?}", response);
99
100        // get the length of packet
101        let content_length = response
102            .headers()
103            .get(CONTENT_LENGTH)
104            .map(|v| v.to_str())
105            .transpose()
106            .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?
107            .map(usize::from_str)
108            .transpose()
109            .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
110
111        // TODO: what is a good max here?
112        // clamp(512, 4096) says make sure it is at least 512 bytes, and min 4096 says it is at most 4k
113        // just a little protection from malicious actors.
114        let mut response_bytes =
115            BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4096));
116
117        while let Some(partial_bytes) = stream
118            .recv_data()
119            .await
120            .map_err(|e| ProtoError::from(format!("h3 recv_data error: {e}")))?
121        {
122            debug!("got bytes: {}", partial_bytes.remaining());
123            response_bytes.put(partial_bytes);
124
125            // assert the length
126            if let Some(content_length) = content_length {
127                if response_bytes.len() >= content_length {
128                    break;
129                }
130            }
131        }
132
133        // assert the length
134        if let Some(content_length) = content_length {
135            if response_bytes.len() != content_length {
136                // TODO: make explicit error type
137                return Err(ProtoError::from(format!(
138                    "expected byte length: {}, got: {}",
139                    content_length,
140                    response_bytes.len()
141                )));
142            }
143        }
144
145        // Was it a successful request?
146        if !response.status().is_success() {
147            let error_string = String::from_utf8_lossy(response_bytes.as_ref());
148
149            // TODO: make explicit error type
150            return Err(ProtoError::from(format!(
151                "http unsuccessful code: {}, message: {}",
152                response.status(),
153                error_string
154            )));
155        } else {
156            // verify content type
157            {
158                // in the case that the ContentType is not specified, we assume it's the standard DNS format
159                let content_type = response
160                    .headers()
161                    .get(header::CONTENT_TYPE)
162                    .map(|h| {
163                        h.to_str().map_err(|err| {
164                            // TODO: make explicit error type
165                            ProtoError::from(format!("ContentType header not a string: {err}"))
166                        })
167                    })
168                    .unwrap_or(Ok(crate::http::MIME_APPLICATION_DNS))?;
169
170                if content_type != crate::http::MIME_APPLICATION_DNS {
171                    return Err(ProtoError::from(format!(
172                        "ContentType unsupported (must be '{}'): '{}'",
173                        crate::http::MIME_APPLICATION_DNS,
174                        content_type
175                    )));
176                }
177            }
178        };
179
180        // and finally convert the bytes into a DNS message
181        let message = Message::from_vec(&response_bytes)?;
182        Ok(DnsResponse::new(message, response_bytes.to_vec()))
183    }
184}
185
186impl DnsRequestSender for H3ClientStream {
187    /// This indicates that the HTTP message was successfully sent, and we now have the response.RecvStream
188    ///
189    /// If the request fails, this will return the error, and it should be assumed that the Stream portion of
190    ///   this will have no date.
191    ///
192    /// ```text
193    /// 5.2.  The HTTP Response
194    ///
195    ///    An HTTP response with a 2xx status code ([RFC7231] Section 6.3)
196    ///    indicates a valid DNS response to the query made in the HTTP request.
197    ///    A valid DNS response includes both success and failure responses.
198    ///    For example, a DNS failure response such as SERVFAIL or NXDOMAIN will
199    ///    be the message in a successful 2xx HTTP response even though there
200    ///    was a failure at the DNS layer.  Responses with non-successful HTTP
201    ///    status codes do not contain DNS answers to the question in the
202    ///    corresponding request.  Some of these non-successful HTTP responses
203    ///    (e.g., redirects or authentication failures) could mean that clients
204    ///    need to make new requests to satisfy the original question.
205    ///
206    ///    Different response media types will provide more or less information
207    ///    from a DNS response.  For example, one response type might include
208    ///    the information from the DNS header bytes while another might omit
209    ///    it.  The amount and type of information that a media type gives is
210    ///    solely up to the format, and not defined in this protocol.
211    ///
212    ///    The only response type defined in this document is "application/dns-
213    ///    message", but it is possible that other response formats will be
214    ///    defined in the future.
215    ///
216    ///    The DNS response for "application/dns-message" in Section 7 MAY have
217    ///    one or more EDNS options [RFC6891], depending on the extension
218    ///    definition of the extensions given in the DNS request.
219    ///
220    ///    Each DNS request-response pair is matched to one HTTP exchange.  The
221    ///    responses may be processed and transported in any order using HTTP's
222    ///    multi-streaming functionality ([RFC7540] Section 5).
223    ///
224    ///    Section 6.1 discusses the relationship between DNS and HTTP response
225    ///    caching.
226    ///
227    ///    A DNS API server MUST be able to process application/dns-message
228    ///    request messages.
229    ///
230    ///    A DNS API server SHOULD respond with HTTP status code 415
231    ///    (Unsupported Media Type) upon receiving a media type it is unable to
232    ///    process.
233    /// ```
234    fn send_message(&mut self, mut message: DnsRequest) -> DnsResponseStream {
235        if self.is_shutdown {
236            panic!("can not send messages after stream is shutdown")
237        }
238
239        // per the RFC, a zero id allows for the HTTP packet to be cached better
240        message.set_id(0);
241
242        let bytes = match message.to_vec() {
243            Ok(bytes) => bytes,
244            Err(err) => return err.into(),
245        };
246
247        Box::pin(Self::inner_send(
248            self.send_request.clone(),
249            Bytes::from(bytes),
250            Arc::clone(&self.name_server_name),
251        ))
252        .into()
253    }
254
255    fn shutdown(&mut self) {
256        self.is_shutdown = true;
257    }
258
259    fn is_shutdown(&self) -> bool {
260        self.is_shutdown
261    }
262}
263
264impl Stream for H3ClientStream {
265    type Item = Result<(), ProtoError>;
266
267    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
268        if self.is_shutdown {
269            return Poll::Ready(None);
270        }
271
272        // just checking if the connection is ok
273        match self.driver.poll_close(cx) {
274            Poll::Ready(Ok(())) => Poll::Ready(None),
275            Poll::Pending => Poll::Pending,
276            Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
277                "h3 stream errored: {e}",
278            ))))),
279        }
280    }
281}
282
283/// A H3 connection builder for DNS-over-HTTP/3
284#[derive(Clone)]
285pub struct H3ClientStreamBuilder {
286    crypto_config: TlsClientConfig,
287    transport_config: Arc<TransportConfig>,
288    bind_addr: Option<SocketAddr>,
289}
290
291impl H3ClientStreamBuilder {
292    /// Constructs a new H3ClientStreamBuilder with the associated ClientConfig
293    pub fn crypto_config(&mut self, crypto_config: TlsClientConfig) -> &mut Self {
294        self.crypto_config = crypto_config;
295        self
296    }
297
298    /// Sets the address to connect from.
299    pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
300        self.bind_addr = Some(bind_addr);
301    }
302
303    /// Creates a new H3Stream to the specified name_server
304    ///
305    /// # Arguments
306    ///
307    /// * `name_server` - IP and Port for the remote DNS resolver
308    /// * `dns_name` - The DNS name, Subject Public Key Info (SPKI) name, as associated to a certificate
309    pub fn build(self, name_server: SocketAddr, dns_name: String) -> H3ClientConnect {
310        H3ClientConnect(Box::pin(self.connect(name_server, dns_name)) as _)
311    }
312
313    /// Creates a new H3Stream with existing connection
314    pub fn build_with_future<S, F>(
315        self,
316        future: F,
317        name_server: SocketAddr,
318        dns_name: String,
319    ) -> H3ClientConnect
320    where
321        S: DnsUdpSocket + QuicLocalAddr + 'static,
322        F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
323    {
324        H3ClientConnect(Box::pin(self.connect_with_future(future, name_server, dns_name)) as _)
325    }
326
327    async fn connect_with_future<S, F>(
328        self,
329        future: F,
330        name_server: SocketAddr,
331        dns_name: String,
332    ) -> Result<H3ClientStream, ProtoError>
333    where
334        S: DnsUdpSocket + QuicLocalAddr + 'static,
335        F: Future<Output = std::io::Result<S>> + Send,
336    {
337        let socket = future.await?;
338        let wrapper = QuinnAsyncUdpSocketAdapter { io: socket };
339        let endpoint = Endpoint::new_with_abstract_socket(
340            EndpointConfig::default(),
341            None,
342            wrapper,
343            Arc::new(quinn::TokioRuntime),
344        )?;
345        self.connect_inner(endpoint, name_server, dns_name).await
346    }
347
348    async fn connect(
349        self,
350        name_server: SocketAddr,
351        dns_name: String,
352    ) -> Result<H3ClientStream, ProtoError> {
353        let connect = if let Some(bind_addr) = self.bind_addr {
354            <tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, bind_addr)
355        } else {
356            <tokio::net::UdpSocket as UdpSocket>::connect(name_server)
357        };
358
359        let socket = connect.await?;
360        let socket = socket.into_std()?;
361        let endpoint = Endpoint::new(
362            EndpointConfig::default(),
363            None,
364            socket,
365            Arc::new(quinn::TokioRuntime),
366        )?;
367        self.connect_inner(endpoint, name_server, dns_name).await
368    }
369
370    async fn connect_inner(
371        self,
372        mut endpoint: Endpoint,
373        name_server: SocketAddr,
374        dns_name: String,
375    ) -> Result<H3ClientStream, ProtoError> {
376        let mut crypto_config = self.crypto_config;
377        // ensure the ALPN protocol is set correctly
378        if crypto_config.alpn_protocols.is_empty() {
379            crypto_config.alpn_protocols = vec![ALPN_H3.to_vec()];
380        }
381        let early_data_enabled = crypto_config.enable_early_data;
382
383        let mut client_config = ClientConfig::new(Arc::new(crypto_config));
384        client_config.transport_config(self.transport_config.clone());
385
386        endpoint.set_default_client_config(client_config);
387
388        let connecting = endpoint.connect(name_server, &dns_name)?;
389        // TODO: for Client/Dynamic update, don't use RTT, for queries, do use it.
390
391        let quic_connection = if early_data_enabled {
392            match connecting.into_0rtt() {
393                Ok((new_connection, _)) => new_connection,
394                Err(connecting) => connecting.await?,
395            }
396        } else {
397            connecting.await?
398        };
399
400        let h3_connection = h3_quinn::Connection::new(quic_connection);
401        let (driver, send_request) = h3::client::new(h3_connection)
402            .await
403            .map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?;
404
405        Ok(H3ClientStream {
406            name_server_name: Arc::from(dns_name),
407            name_server,
408            driver,
409            send_request,
410            is_shutdown: false,
411        })
412    }
413}
414
415impl Default for H3ClientStreamBuilder {
416    fn default() -> Self {
417        Self {
418            crypto_config: super::client_config_tls13().unwrap(),
419            transport_config: Arc::new(super::transport()),
420            bind_addr: None,
421        }
422    }
423}
424
425/// A future that resolves to an H3ClientStream
426pub struct H3ClientConnect(
427    Pin<Box<dyn Future<Output = Result<H3ClientStream, ProtoError>> + Send>>,
428);
429
430impl Future for H3ClientConnect {
431    type Output = Result<H3ClientStream, ProtoError>;
432
433    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
434        self.0.poll_unpin(cx)
435    }
436}
437
438/// A future that resolves to
439pub struct H3ClientResponse(Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>);
440
441impl Future for H3ClientResponse {
442    type Output = Result<DnsResponse, ProtoError>;
443
444    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
445        self.0.as_mut().poll(cx).map_err(ProtoError::from)
446    }
447}
448
#[cfg(all(test, any(feature = "native-certs", feature = "webpki-roots")))]
mod tests {
    use std::net::SocketAddr;
    use std::str::FromStr;

    use rustls::KeyLogFile;
    use tokio::runtime::Runtime;

    use crate::op::{Edns, Message, Query};
    use crate::rr::rdata::A;
    use crate::rr::{Name, RData, RecordType};
    use crate::xfer::{DnsRequestOptions, FirstAnswer};

    use super::*;

    /// Google Public DNS, HTTP/3 port.
    fn google() -> SocketAddr {
        SocketAddr::from(([8, 8, 8, 8], 443))
    }

    /// Builds a recursion-desired query for `www.example.com.` of `record_type`. It carries
    /// an EDNS version 0 record advertising a 1232-byte maximum payload.
    fn recursive_query(record_type: RecordType) -> DnsRequest {
        let mut message = Message::new();
        message.add_query(Query::query(
            Name::from_str("www.example.com.").unwrap(),
            record_type,
        ));
        message.set_recursion_desired(true);

        let mut edns = Edns::new();
        edns.set_version(0);
        edns.set_max_payload(1232);
        *message.extensions_mut() = Some(edns);

        DnsRequest::new(message, DnsRequestOptions::default())
    }

    /// Returns a (not yet polled) H3 connect future for `server`, using the TLS 1.3 client
    /// config with a rustls `KeyLogFile` key logger installed.
    fn h3_connect(server: SocketAddr, dns_name: String) -> H3ClientConnect {
        let mut client_config = super::super::client_config_tls13().unwrap();
        client_config.key_log = Arc::new(KeyLogFile::new());

        let mut h3_builder = H3ClientStream::builder();
        h3_builder.crypto_config(client_config);
        h3_builder.build(server, dns_name)
    }

    /// Connects to Google Public DNS, verifying it as `dns_name`. It then checks that one A
    /// query and three AAAA queries all return answers over that single connection.
    fn run_google_queries(dns_name: String) {
        let a_request = recursive_query(RecordType::A);
        let connect = h3_connect(google(), dns_name);

        // tokio runtime stuff...
        let runtime = Runtime::new().expect("could not start runtime");
        let mut h3 = runtime.block_on(connect).expect("h3 connect failed");

        let response = runtime
            .block_on(h3.send_message(a_request).first_answer())
            .expect("send_message failed");

        assert!(response
            .answers()
            .iter()
            .any(|record| record.data().unwrap().as_a().is_some()));

        // the same connection must keep serving further queries
        let aaaa_request = recursive_query(RecordType::AAAA);

        for _ in 0..3 {
            let response = runtime
                .block_on(h3.send_message(aaaa_request.clone()).first_answer())
                .expect("send_message failed");

            assert!(response.answers().iter().any(|record| record
                .data()
                .unwrap()
                .as_aaaa()
                .is_some()));
        }
    }

    #[test]
    fn test_h3_google() {
        run_google_queries("dns.google".to_string());
    }

    #[test]
    fn test_h3_google_with_pure_ip_address_server() {
        run_google_queries(google().ip().to_string());
    }

    /// Currently fails, see <https://github.com/hyperium/h3/issues/206>.
    #[test]
    #[ignore] // cloudflare has been unreliable as a public test service.
    fn test_h3_cloudflare() {
        let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));

        // first query: plain A query without recursion-desired or EDNS
        let mut first = Message::new();
        first.add_query(Query::query(
            Name::from_str("www.example.com.").unwrap(),
            RecordType::A,
        ));
        let first = DnsRequest::new(first, DnsRequestOptions::default());

        let connect = h3_connect(cloudflare, "cloudflare-dns.com".to_string());

        // tokio runtime stuff...
        let runtime = Runtime::new().expect("could not start runtime");
        let mut h3 = runtime.block_on(connect).expect("h3 connect failed");

        let response = runtime
            .block_on(h3.send_message(first).first_answer())
            .expect("send_message failed");

        let addr = response.answers()[0]
            .data()
            .and_then(RData::as_a)
            .expect("invalid response, expected A record");

        assert_eq!(addr, &A::new(93, 184, 215, 14));

        // assert that the connection works for a second query
        let response = runtime
            .block_on(
                h3.send_message(recursive_query(RecordType::AAAA))
                    .first_answer(),
            )
            .expect("send_message failed");

        assert!(response
            .answers()
            .iter()
            .any(|record| record.data().unwrap().as_aaaa().is_some()));
    }
}