hickory_proto/quic/
quic_client_stream.rs

1// Copyright 2015-2022 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use alloc::{boxed::Box, string::String, sync::Arc};
9use core::{
10    fmt::{self, Display},
11    future::Future,
12    pin::Pin,
13    task::{Context, Poll},
14};
15use std::{io, net::SocketAddr};
16
17use futures_util::{future::FutureExt, stream::Stream};
18use quinn::{
19    ClientConfig, Connection, Endpoint, TransportConfig, VarInt, crypto::rustls::QuicClientConfig,
20};
21use tokio::time::timeout;
22
23use crate::{
24    error::ProtoError,
25    quic::quic_stream::{DoqErrorCode, QuicStream},
26    rustls::client_config,
27    udp::UdpSocket,
28    xfer::{CONNECT_TIMEOUT, DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream},
29};
30
31use super::{quic_config, quic_stream};
32
33/// A DNS client connection for DNS-over-QUIC
34#[must_use = "futures do nothing unless polled"]
35#[derive(Clone)]
36pub struct QuicClientStream {
37    quic_connection: Connection,
38    name_server_name: Arc<str>,
39    name_server: SocketAddr,
40    is_shutdown: bool,
41}
42
43impl Display for QuicClientStream {
44    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
45        write!(
46            formatter,
47            "QUIC({},{})",
48            self.name_server, self.name_server_name
49        )
50    }
51}
52
53impl QuicClientStream {
54    /// Builder for QuicClientStream
55    pub fn builder() -> QuicClientStreamBuilder {
56        QuicClientStreamBuilder::default()
57    }
58
59    async fn inner_send(
60        connection: Connection,
61        message: DnsRequest,
62    ) -> Result<DnsResponse, ProtoError> {
63        let (send_stream, recv_stream) = connection.open_bi().await?;
64
65        // RFC: The mapping specified here requires that the client selects a separate
66        //  QUIC stream for each query. The server then uses the same stream to provide all the response messages for that query.
67        let mut stream = QuicStream::new(send_stream, recv_stream);
68
69        stream.send(message.into_parts().0).await?;
70
71        // The client MUST send the DNS query over the selected stream,
72        // and MUST indicate through the STREAM FIN mechanism that no further data will be sent on that stream.
73        stream.finish().await?;
74
75        stream.receive().await
76    }
77}
78
79impl DnsRequestSender for QuicClientStream {
80    /// The send loop for QUIC in DNS stipulates that a new QUIC "stream" should be opened and use for sending data.
81    ///
82    /// It should be closed after receiving the response. TODO: AXFR/IXFR support...
83    ///
84    /// ```text
85    /// 5.2. Stream Mapping and Usage
86    ///
87    /// The mapping of DNS traffic over QUIC streams takes advantage of the QUIC stream features detailed in Section 2 of [RFC9000],
88    /// the QUIC transport specification.
89    ///
90    /// DNS traffic follows a simple pattern in which the client sends a query, and the server provides one or more responses
91    /// (multiple responses can occur in zone transfers).The mapping specified here requires that the client selects a separate
92    /// QUIC stream for each query. The server then uses the same stream to provide all the response messages for that query. In
93    /// order that multiple responses can be parsed, a 2-octet length field is used in exactly the same way as the 2-octet length
94    /// field defined for DNS over TCP [RFC1035]. The practical result of this is that the content of each QUIC stream is exactly
95    /// the same as the content of a TCP connection that would manage exactly one query.All DNS messages (queries and responses)
96    /// sent over DoQ connections MUST be encoded as a 2-octet length field followed by the message content as specified in [RFC1035].
97    /// The client MUST select the next available client-initiated bidirectional stream for each subsequent query on a QUIC connection,
98    /// in conformance with the QUIC transport specification [RFC9000].The client MUST send the DNS query over the selected stream,
99    /// and MUST indicate through the STREAM FIN mechanism that no further data will be sent on that stream.The server MUST send the
100    /// response(s) on the same stream and MUST indicate, after the last response, through the STREAM FIN mechanism that no further
101    /// data will be sent on that stream.Therefore, a single DNS transaction consumes a single bidirectional client-initiated stream.
102    /// This means that the client's first query occurs on QUIC stream 0, the second on 4, and so on (see Section 2.1 of [RFC9000].
103    /// Servers MAY defer processing of a query until the STREAM FIN has been indicated on the stream selected by the client. Servers
104    /// and clients MAY monitor the number of "dangling" streams for which the expected queries or responses have been received but
105    /// not the STREAM FIN. Implementations MAY impose a limit on the number of such dangling streams. If limits are encountered,
106    /// implementations MAY close the connection.
107    ///
108    /// 5.2.1. DNS Message IDs
109    ///
110    /// When sending queries over a QUIC connection, the DNS Message ID MUST be set to zero. The stream mapping for DoQ allows for
111    /// unambiguous correlation of queries and responses and so the Message ID field is not required.
112    ///
113    /// This has implications for proxying DoQ message to and from other transports. For example, proxies may have to manage the
114    /// fact that DoQ can support a larger number of outstanding queries on a single connection than e.g., DNS over TCP because DoQ
115    /// is not limited by the Message ID space. This issue already exists for DoH, where a Message ID of 0 is recommended.When forwarding
116    /// a DNS message from DoQ over another transport, a DNS Message ID MUST be generated according to the rules of the protocol that is
117    /// in use. When forwarding a DNS message from another transport over DoQ, the Message ID MUST be set to zero.
118    /// ```
119    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
120        if self.is_shutdown {
121            panic!("can not send messages after stream is shutdown")
122        }
123
124        Box::pin(Self::inner_send(self.quic_connection.clone(), request)).into()
125    }
126
127    fn shutdown(&mut self) {
128        self.is_shutdown = true;
129        self.quic_connection
130            .close(DoqErrorCode::NoError.into(), b"Shutdown");
131    }
132
133    fn is_shutdown(&self) -> bool {
134        self.is_shutdown
135    }
136}
137
138impl Stream for QuicClientStream {
139    type Item = Result<(), ProtoError>;
140
141    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
142        if self.is_shutdown {
143            Poll::Ready(None)
144        } else {
145            Poll::Ready(Some(Ok(())))
146        }
147    }
148}
149
150/// A QUIC connection builder for DNS-over-QUIC
151#[derive(Clone)]
152pub struct QuicClientStreamBuilder {
153    crypto_config: Option<rustls::ClientConfig>,
154    transport_config: Arc<TransportConfig>,
155    bind_addr: Option<SocketAddr>,
156}
157
158impl QuicClientStreamBuilder {
159    /// Constructs a new TlsStreamBuilder with the associated ClientConfig
160    pub fn crypto_config(&mut self, crypto_config: rustls::ClientConfig) -> &mut Self {
161        self.crypto_config = Some(crypto_config);
162        self
163    }
164
165    /// Sets the address to connect from.
166    pub fn bind_addr(&mut self, bind_addr: SocketAddr) -> &mut Self {
167        self.bind_addr = Some(bind_addr);
168        self
169    }
170
171    /// Creates a new QuicStream to the specified name_server
172    ///
173    /// # Arguments
174    ///
175    /// * `name_server` - IP and Port for the remote DNS resolver
176    /// * `dns_name` - The DNS name associated with a certificate
177    pub fn build(self, name_server: SocketAddr, dns_name: String) -> QuicClientConnect {
178        QuicClientConnect(Box::pin(self.connect(name_server, dns_name)) as _)
179    }
180
181    /// Create a QuicStream with existing connection
182    pub fn build_with_future(
183        self,
184        socket: Arc<dyn quinn::AsyncUdpSocket>,
185        name_server: SocketAddr,
186        dns_name: String,
187    ) -> QuicClientConnect {
188        QuicClientConnect(Box::pin(self.connect_with_future(socket, name_server, dns_name)) as _)
189    }
190
191    async fn connect_with_future(
192        self,
193        socket: Arc<dyn quinn::AsyncUdpSocket>,
194        name_server: SocketAddr,
195        dns_name: String,
196    ) -> Result<QuicClientStream, ProtoError> {
197        let endpoint_config = quic_config::endpoint();
198        let endpoint = Endpoint::new_with_abstract_socket(
199            endpoint_config,
200            None,
201            socket,
202            Arc::new(quinn::TokioRuntime),
203        )?;
204        self.connect_inner(endpoint, name_server, dns_name).await
205    }
206
207    async fn connect(
208        self,
209        name_server: SocketAddr,
210        dns_name: String,
211    ) -> Result<QuicClientStream, ProtoError> {
212        let connect = if let Some(bind_addr) = self.bind_addr {
213            <tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, bind_addr)
214        } else {
215            <tokio::net::UdpSocket as UdpSocket>::connect(name_server)
216        };
217
218        let socket = connect.await?;
219        let socket = socket.into_std()?;
220        let endpoint_config = quic_config::endpoint();
221        let endpoint = Endpoint::new(endpoint_config, None, socket, Arc::new(quinn::TokioRuntime))?;
222        self.connect_inner(endpoint, name_server, dns_name).await
223    }
224
225    async fn connect_inner(
226        self,
227        endpoint: Endpoint,
228        name_server: SocketAddr,
229        dns_name: String,
230    ) -> Result<QuicClientStream, ProtoError> {
231        // ensure the ALPN protocol is set correctly
232        let crypto_config = if let Some(crypto_config) = self.crypto_config {
233            crypto_config
234        } else {
235            client_config()
236        };
237
238        let quic_connection = connect_quic(
239            name_server,
240            &dns_name,
241            quic_stream::DOQ_ALPN,
242            crypto_config,
243            self.transport_config,
244            endpoint,
245        )
246        .await?;
247
248        Ok(QuicClientStream {
249            quic_connection,
250            name_server_name: Arc::from(dns_name),
251            name_server,
252            is_shutdown: false,
253        })
254    }
255}
256
257pub(crate) async fn connect_quic(
258    addr: SocketAddr,
259    server_name: &str,
260    protocol: &[u8],
261    mut crypto_config: rustls::ClientConfig,
262    transport_config: Arc<TransportConfig>,
263    mut endpoint: Endpoint,
264) -> Result<Connection, ProtoError> {
265    if crypto_config.alpn_protocols.is_empty() {
266        crypto_config.alpn_protocols = vec![protocol.to_vec()];
267    }
268    let early_data_enabled = crypto_config.enable_early_data;
269
270    let mut client_config = ClientConfig::new(Arc::new(QuicClientConfig::try_from(crypto_config)?));
271    client_config.transport_config(transport_config.clone());
272
273    endpoint.set_default_client_config(client_config);
274
275    let connecting = endpoint.connect(addr, server_name)?;
276    // TODO: for Client/Dynamic update, don't use RTT, for queries, do use it.
277
278    Ok(if early_data_enabled {
279        match connecting.into_0rtt() {
280            Ok((new_connection, _)) => new_connection,
281            Err(connecting) => connect_with_timeout(connecting).await?,
282        }
283    } else {
284        connect_with_timeout(connecting).await?
285    })
286}
287
288async fn connect_with_timeout(connecting: quinn::Connecting) -> Result<Connection, io::Error> {
289    match timeout(CONNECT_TIMEOUT, connecting).await {
290        Ok(Ok(connection)) => Ok(connection),
291        Ok(Err(e)) => Err(e.into()),
292        Err(_) => Err(io::Error::new(
293            io::ErrorKind::TimedOut,
294            format!("QUIC handshake timed out after {CONNECT_TIMEOUT:?}",),
295        )),
296    }
297}
298
299impl Default for QuicClientStreamBuilder {
300    fn default() -> Self {
301        let mut transport_config = quic_config::transport();
302        // clients never accept new bidirectional streams
303        transport_config.max_concurrent_bidi_streams(VarInt::from_u32(0));
304
305        Self {
306            crypto_config: None,
307            transport_config: Arc::new(transport_config),
308            bind_addr: None,
309        }
310    }
311}
312
313/// A future that resolves to an QuicClientStream
314pub struct QuicClientConnect(
315    Pin<Box<dyn Future<Output = Result<QuicClientStream, ProtoError>> + Send>>,
316);
317
318impl Future for QuicClientConnect {
319    type Output = Result<QuicClientStream, ProtoError>;
320
321    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
322        self.0.poll_unpin(cx)
323    }
324}
325
326/// A future that resolves to
327pub struct QuicClientResponse(
328    Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>,
329);
330
331impl Future for QuicClientResponse {
332    type Output = Result<DnsResponse, ProtoError>;
333
334    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
335        self.0.as_mut().poll(cx).map_err(ProtoError::from)
336    }
337}