trust_dns_proto/quic/
quic_client_stream.rs

1// Copyright 2015-2022 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::fmt::{Debug, Formatter};
9use std::{
10    fmt::{self, Display},
11    future::Future,
12    net::SocketAddr,
13    pin::Pin,
14    sync::Arc,
15    task::{Context, Poll},
16};
17
18use futures_util::{future::FutureExt, stream::Stream};
19use quinn::{AsyncUdpSocket, ClientConfig, Connection, Endpoint, TransportConfig, VarInt};
20use rustls::{version::TLS13, ClientConfig as TlsClientConfig};
21
22use crate::udp::{DnsUdpSocket, QuicLocalAddr};
23use crate::{
24    error::ProtoError,
25    quic::quic_stream::{DoqErrorCode, QuicStream},
26    udp::UdpSocket,
27    xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream},
28};
29
30use super::{quic_config, quic_stream};
31
32/// A DNS client connection for DNS-over-QUIC
33#[must_use = "futures do nothing unless polled"]
34pub struct QuicClientStream {
35    quic_connection: Connection,
36    name_server_name: Arc<str>,
37    name_server: SocketAddr,
38    is_shutdown: bool,
39}
40
41impl Display for QuicClientStream {
42    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
43        write!(
44            formatter,
45            "QUIC({},{})",
46            self.name_server, self.name_server_name
47        )
48    }
49}
50
51impl QuicClientStream {
52    /// Builder for QuicClientStream
53    pub fn builder() -> QuicClientStreamBuilder {
54        QuicClientStreamBuilder::default()
55    }
56
57    async fn inner_send(
58        connection: Connection,
59        message: DnsRequest,
60    ) -> Result<DnsResponse, ProtoError> {
61        let (send_stream, recv_stream) = connection.open_bi().await?;
62
63        // RFC: The mapping specified here requires that the client selects a separate
64        //  QUIC stream for each query. The server then uses the same stream to provide all the response messages for that query.
65        let mut stream = QuicStream::new(send_stream, recv_stream);
66
67        stream.send(message.into_parts().0).await?;
68
69        // The client MUST send the DNS query over the selected stream,
70        // and MUST indicate through the STREAM FIN mechanism that no further data will be sent on that stream.
71        stream.finish().await?;
72
73        stream.receive().await
74    }
75}
76
77impl DnsRequestSender for QuicClientStream {
78    /// The send loop for QUIC in DNS stipulates that a new QUIC "stream" should be opened and use for sending data.
79    ///
80    /// It should be closed after receiving the response. TODO: AXFR/IXFR support...
81    ///
82    /// ```text
83    /// 5.2. Stream Mapping and Usage
84    ///
85    /// The mapping of DNS traffic over QUIC streams takes advantage of the QUIC stream features detailed in Section 2 of [RFC9000],
86    /// the QUIC transport specification.
87    ///
88    /// DNS traffic follows a simple pattern in which the client sends a query, and the server provides one or more responses
89    /// (multiple responses can occur in zone transfers).The mapping specified here requires that the client selects a separate
90    /// QUIC stream for each query. The server then uses the same stream to provide all the response messages for that query. In
91    /// order that multiple responses can be parsed, a 2-octet length field is used in exactly the same way as the 2-octet length
92    /// field defined for DNS over TCP [RFC1035]. The practical result of this is that the content of each QUIC stream is exactly
93    /// the same as the content of a TCP connection that would manage exactly one query.All DNS messages (queries and responses)
94    /// sent over DoQ connections MUST be encoded as a 2-octet length field followed by the message content as specified in [RFC1035].
95    /// The client MUST select the next available client-initiated bidirectional stream for each subsequent query on a QUIC connection,
96    /// in conformance with the QUIC transport specification [RFC9000].The client MUST send the DNS query over the selected stream,
97    /// and MUST indicate through the STREAM FIN mechanism that no further data will be sent on that stream.The server MUST send the
98    /// response(s) on the same stream and MUST indicate, after the last response, through the STREAM FIN mechanism that no further
99    /// data will be sent on that stream.Therefore, a single DNS transaction consumes a single bidirectional client-initiated stream.
100    /// This means that the client's first query occurs on QUIC stream 0, the second on 4, and so on (see Section 2.1 of [RFC9000].
101    /// Servers MAY defer processing of a query until the STREAM FIN has been indicated on the stream selected by the client. Servers
102    /// and clients MAY monitor the number of "dangling" streams for which the expected queries or responses have been received but
103    /// not the STREAM FIN. Implementations MAY impose a limit on the number of such dangling streams. If limits are encountered,
104    /// implementations MAY close the connection.
105    ///
106    /// 5.2.1. DNS Message IDs
107    ///
108    /// When sending queries over a QUIC connection, the DNS Message ID MUST be set to zero. The stream mapping for DoQ allows for
109    /// unambiguous correlation of queries and responses and so the Message ID field is not required.
110    ///
111    /// This has implications for proxying DoQ message to and from other transports. For example, proxies may have to manage the
112    /// fact that DoQ can support a larger number of outstanding queries on a single connection than e.g., DNS over TCP because DoQ
113    /// is not limited by the Message ID space. This issue already exists for DoH, where a Message ID of 0 is recommended.When forwarding
114    /// a DNS message from DoQ over another transport, a DNS Message ID MUST be generated according to the rules of the protocol that is
115    /// in use. When forwarding a DNS message from another transport over DoQ, the Message ID MUST be set to zero.
116    /// ```
117    fn send_message(&mut self, message: DnsRequest) -> DnsResponseStream {
118        if self.is_shutdown {
119            panic!("can not send messages after stream is shutdown")
120        }
121
122        Box::pin(Self::inner_send(self.quic_connection.clone(), message)).into()
123    }
124
125    fn shutdown(&mut self) {
126        self.is_shutdown = true;
127        self.quic_connection
128            .close(DoqErrorCode::NoError.into(), b"Shutdown");
129    }
130
131    fn is_shutdown(&self) -> bool {
132        self.is_shutdown
133    }
134}
135
136impl Stream for QuicClientStream {
137    type Item = Result<(), ProtoError>;
138
139    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140        if self.is_shutdown {
141            Poll::Ready(None)
142        } else {
143            Poll::Ready(Some(Ok(())))
144        }
145    }
146}
147
148/// A QUIC connection builder for DNS-over-QUIC
149#[derive(Clone)]
150pub struct QuicClientStreamBuilder {
151    crypto_config: TlsClientConfig,
152    transport_config: Arc<TransportConfig>,
153    bind_addr: Option<SocketAddr>,
154}
155
156impl QuicClientStreamBuilder {
157    /// Constructs a new TlsStreamBuilder with the associated ClientConfig
158    pub fn crypto_config(&mut self, crypto_config: TlsClientConfig) -> &mut Self {
159        self.crypto_config = crypto_config;
160        self
161    }
162
163    /// Sets the address to connect from.
164    pub fn bind_addr(&mut self, bind_addr: SocketAddr) -> &mut Self {
165        self.bind_addr = Some(bind_addr);
166        self
167    }
168
169    /// Creates a new QuicStream to the specified name_server
170    ///
171    /// # Arguments
172    ///
173    /// * `name_server` - IP and Port for the remote DNS resolver
174    /// * `dns_name` - The DNS name, Subject Public Key Info (SPKI) name, as associated to a certificate
175    pub fn build(self, name_server: SocketAddr, dns_name: String) -> QuicClientConnect {
176        QuicClientConnect(Box::pin(self.connect(name_server, dns_name)) as _)
177    }
178
179    /// Create a QuicStream with existing connection
180    pub fn build_with_future<S, F>(
181        self,
182        future: F,
183        name_server: SocketAddr,
184        dns_name: String,
185    ) -> QuicClientConnect
186    where
187        S: DnsUdpSocket + QuicLocalAddr + 'static,
188        F: Future<Output = std::io::Result<S>> + Send + 'static,
189    {
190        QuicClientConnect(Box::pin(self.connect_with_future(future, name_server, dns_name)) as _)
191    }
192
193    async fn connect_with_future<S, F>(
194        self,
195        future: F,
196        name_server: SocketAddr,
197        dns_name: String,
198    ) -> Result<QuicClientStream, ProtoError>
199    where
200        S: DnsUdpSocket + QuicLocalAddr + 'static,
201        F: Future<Output = std::io::Result<S>> + Send,
202    {
203        let socket = future.await?;
204        let endpoint_config = quic_config::endpoint();
205        let wrapper = QuinnAsyncUdpSocketAdapter { io: socket };
206        let endpoint = Endpoint::new_with_abstract_socket(
207            endpoint_config,
208            None,
209            wrapper,
210            Arc::new(quinn::TokioRuntime),
211        )?;
212        self.connect_inner(endpoint, name_server, dns_name).await
213    }
214
215    async fn connect(
216        self,
217        name_server: SocketAddr,
218        dns_name: String,
219    ) -> Result<QuicClientStream, ProtoError> {
220        let connect = if let Some(bind_addr) = self.bind_addr {
221            <tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, bind_addr)
222        } else {
223            <tokio::net::UdpSocket as UdpSocket>::connect(name_server)
224        };
225
226        let socket = connect.await?;
227        let socket = socket.into_std()?;
228        let endpoint_config = quic_config::endpoint();
229        let endpoint = Endpoint::new(endpoint_config, None, socket, Arc::new(quinn::TokioRuntime))?;
230        self.connect_inner(endpoint, name_server, dns_name).await
231    }
232
233    async fn connect_inner(
234        self,
235        mut endpoint: Endpoint,
236        name_server: SocketAddr,
237        dns_name: String,
238    ) -> Result<QuicClientStream, ProtoError> {
239        // ensure the ALPN protocol is set correctly
240        let mut crypto_config = self.crypto_config;
241        if crypto_config.alpn_protocols.is_empty() {
242            crypto_config.alpn_protocols = vec![quic_stream::DOQ_ALPN.to_vec()];
243        }
244        let early_data_enabled = crypto_config.enable_early_data;
245
246        let mut client_config = ClientConfig::new(Arc::new(crypto_config));
247        client_config.transport_config(self.transport_config.clone());
248
249        endpoint.set_default_client_config(client_config);
250
251        let connecting = endpoint.connect(name_server, &dns_name)?;
252        // TODO: for Client/Dynamic update, don't use RTT, for queries, do use it.
253
254        let quic_connection = if early_data_enabled {
255            match connecting.into_0rtt() {
256                Ok((new_connection, _)) => new_connection,
257                Err(connecting) => connecting.await?,
258            }
259        } else {
260            connecting.await?
261        };
262
263        Ok(QuicClientStream {
264            quic_connection,
265            name_server_name: Arc::from(dns_name),
266            name_server,
267            is_shutdown: false,
268        })
269    }
270}
271
272/// Default crypto options for quic
273pub fn client_config_tls13_webpki_roots() -> TlsClientConfig {
274    use rustls::{OwnedTrustAnchor, RootCertStore};
275    let mut root_store = RootCertStore::empty();
276    root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
277        OwnedTrustAnchor::from_subject_spki_name_constraints(
278            ta.subject,
279            ta.spki,
280            ta.name_constraints,
281        )
282    }));
283
284    TlsClientConfig::builder()
285        .with_safe_default_cipher_suites()
286        .with_safe_default_kx_groups()
287        .with_protocol_versions(&[&TLS13])
288        .expect("TLS 1.3 not supported")
289        .with_root_certificates(root_store)
290        .with_no_client_auth()
291}
292
293impl Default for QuicClientStreamBuilder {
294    fn default() -> Self {
295        let mut transport_config = quic_config::transport();
296        // clients never accept new bidirectional streams
297        transport_config.max_concurrent_bidi_streams(VarInt::from_u32(0));
298
299        let client_config = client_config_tls13_webpki_roots();
300
301        Self {
302            crypto_config: client_config,
303            transport_config: Arc::new(transport_config),
304            bind_addr: None,
305        }
306    }
307}
308
309/// A future that resolves to an QuicClientStream
310pub struct QuicClientConnect(
311    Pin<Box<dyn Future<Output = Result<QuicClientStream, ProtoError>> + Send>>,
312);
313
314impl Future for QuicClientConnect {
315    type Output = Result<QuicClientStream, ProtoError>;
316
317    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
318        self.0.poll_unpin(cx)
319    }
320}
321
322/// A future that resolves to
323pub struct QuicClientResponse(
324    Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>,
325);
326
327impl Future for QuicClientResponse {
328    type Output = Result<DnsResponse, ProtoError>;
329
330    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
331        self.0.as_mut().poll(cx).map_err(ProtoError::from)
332    }
333}
334
/// Wrapper used for quinn::Endpoint::new_with_abstract_socket
///
/// Adapts a trust-dns `DnsUdpSocket` to quinn's `AsyncUdpSocket` trait. Trait bounds are
/// declared on the `impl` blocks that need them rather than on the struct itself.
struct QuinnAsyncUdpSocketAdapter<S> {
    /// The underlying socket that every send and receive is forwarded to.
    io: S,
}
339
340impl<S: DnsUdpSocket + QuicLocalAddr> Debug for QuinnAsyncUdpSocketAdapter<S> {
341    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
342        f.write_str("Wrapper for quinn::AsyncUdpSocket")
343    }
344}
345
346/// TODO: Naive implementation. Look forward to future improvements.
347impl<S: DnsUdpSocket + QuicLocalAddr + 'static> AsyncUdpSocket for QuinnAsyncUdpSocketAdapter<S> {
348    fn poll_send(
349        &self,
350        _state: &quinn::udp::UdpState,
351        cx: &mut Context<'_>,
352        transmits: &[quinn::udp::Transmit],
353    ) -> Poll<std::io::Result<usize>> {
354        // logics from quinn-udp::fallback.rs
355        let io = &self.io;
356        let mut sent = 0;
357        for transmit in transmits {
358            match io.poll_send_to(cx, &transmit.contents, transmit.destination) {
359                Poll::Ready(ready) => match ready {
360                    Ok(_) => {
361                        sent += 1;
362                    }
363                    // We need to report that some packets were sent in this case, so we rely on
364                    // errors being either harmlessly transient (in the case of WouldBlock) or
365                    // recurring on the next call.
366                    Err(_) if sent != 0 => return Poll::Ready(Ok(sent)),
367                    Err(e) => {
368                        if e.kind() == std::io::ErrorKind::WouldBlock {
369                            return Poll::Ready(Err(e));
370                        }
371
372                        // Other errors are ignored, since they will ususally be handled
373                        // by higher level retransmits and timeouts.
374                        // - PermissionDenied errors have been observed due to iptable rules.
375                        //   Those are not fatal errors, since the
376                        //   configuration can be dynamically changed.
377                        // - Destination unreachable errors have been observed for other
378                        // log_sendmsg_error(&mut self.last_send_error, e, transmit);
379                        sent += 1;
380                    }
381                },
382                Poll::Pending => {
383                    return if sent == 0 {
384                        Poll::Pending
385                    } else {
386                        Poll::Ready(Ok(sent))
387                    }
388                }
389            }
390        }
391        Poll::Ready(Ok(sent))
392    }
393
394    fn poll_recv(
395        &self,
396        cx: &mut Context<'_>,
397        bufs: &mut [std::io::IoSliceMut<'_>],
398        meta: &mut [quinn::udp::RecvMeta],
399    ) -> Poll<std::io::Result<usize>> {
400        // logics from quinn-udp::fallback.rs
401
402        let io = &self.io;
403        let Some(buf) = bufs.get_mut(0)else {
404            return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidInput,"no buf")));
405        };
406        match io.poll_recv_from(cx, buf.as_mut()) {
407            Poll::Ready(res) => match res {
408                Ok((len, addr)) => {
409                    meta[0] = quinn::udp::RecvMeta {
410                        len,
411                        stride: len,
412                        addr,
413                        ecn: None,
414                        dst_ip: None,
415                    };
416                    Poll::Ready(Ok(1))
417                }
418                Err(err) => Poll::Ready(Err(err)),
419            },
420            Poll::Pending => Poll::Pending,
421        }
422    }
423
424    fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
425        self.io.local_addr()
426    }
427}