1use std::{
9 fmt::{self, Display},
10 future::Future,
11 net::SocketAddr,
12 pin::Pin,
13 sync::Arc,
14 task::{Context, Poll},
15};
16
17use futures_util::{future::FutureExt, stream::Stream};
18use quinn::{ClientConfig, Connection, Endpoint, TransportConfig, VarInt};
19use rustls::{version::TLS13, ClientConfig as TlsClientConfig};
20
21use crate::udp::{DnsUdpSocket, QuicLocalAddr};
22use crate::{
23 error::ProtoError,
24 quic::quic_socket::QuinnAsyncUdpSocketAdapter,
25 quic::quic_stream::{DoqErrorCode, QuicStream},
26 udp::UdpSocket,
27 xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream},
28};
29
30use super::{quic_config, quic_stream};
31
32#[must_use = "futures do nothing unless polled"]
34pub struct QuicClientStream {
35 quic_connection: Connection,
36 name_server_name: Arc<str>,
37 name_server: SocketAddr,
38 is_shutdown: bool,
39}
40
41impl Display for QuicClientStream {
42 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
43 write!(
44 formatter,
45 "QUIC({},{})",
46 self.name_server, self.name_server_name
47 )
48 }
49}
50
51impl QuicClientStream {
52 pub fn builder() -> QuicClientStreamBuilder {
54 QuicClientStreamBuilder::default()
55 }
56
57 async fn inner_send(
58 connection: Connection,
59 message: DnsRequest,
60 ) -> Result<DnsResponse, ProtoError> {
61 let (send_stream, recv_stream) = connection.open_bi().await?;
62
63 let mut stream = QuicStream::new(send_stream, recv_stream);
66
67 stream.send(message.into_parts().0).await?;
68
69 stream.finish().await?;
72
73 stream.receive().await
74 }
75}
76
77impl DnsRequestSender for QuicClientStream {
78 fn send_message(&mut self, message: DnsRequest) -> DnsResponseStream {
118 if self.is_shutdown {
119 panic!("can not send messages after stream is shutdown")
120 }
121
122 Box::pin(Self::inner_send(self.quic_connection.clone(), message)).into()
123 }
124
125 fn shutdown(&mut self) {
126 self.is_shutdown = true;
127 self.quic_connection
128 .close(DoqErrorCode::NoError.into(), b"Shutdown");
129 }
130
131 fn is_shutdown(&self) -> bool {
132 self.is_shutdown
133 }
134}
135
136impl Stream for QuicClientStream {
137 type Item = Result<(), ProtoError>;
138
139 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140 if self.is_shutdown {
141 Poll::Ready(None)
142 } else {
143 Poll::Ready(Some(Ok(())))
144 }
145 }
146}
147
148#[derive(Clone)]
150pub struct QuicClientStreamBuilder {
151 crypto_config: Option<TlsClientConfig>,
152 transport_config: Arc<TransportConfig>,
153 bind_addr: Option<SocketAddr>,
154}
155
156impl QuicClientStreamBuilder {
157 pub fn crypto_config(&mut self, crypto_config: TlsClientConfig) -> &mut Self {
159 self.crypto_config = Some(crypto_config);
160 self
161 }
162
163 pub fn bind_addr(&mut self, bind_addr: SocketAddr) -> &mut Self {
165 self.bind_addr = Some(bind_addr);
166 self
167 }
168
169 pub fn build(self, name_server: SocketAddr, dns_name: String) -> QuicClientConnect {
176 QuicClientConnect(Box::pin(self.connect(name_server, dns_name)) as _)
177 }
178
179 pub fn build_with_future<S, F>(
181 self,
182 future: F,
183 name_server: SocketAddr,
184 dns_name: String,
185 ) -> QuicClientConnect
186 where
187 S: DnsUdpSocket + QuicLocalAddr + 'static,
188 F: Future<Output = std::io::Result<S>> + Send + 'static,
189 {
190 QuicClientConnect(Box::pin(self.connect_with_future(future, name_server, dns_name)) as _)
191 }
192
193 async fn connect_with_future<S, F>(
194 self,
195 future: F,
196 name_server: SocketAddr,
197 dns_name: String,
198 ) -> Result<QuicClientStream, ProtoError>
199 where
200 S: DnsUdpSocket + QuicLocalAddr + 'static,
201 F: Future<Output = std::io::Result<S>> + Send,
202 {
203 let socket = future.await?;
204 let endpoint_config = quic_config::endpoint();
205 let wrapper = QuinnAsyncUdpSocketAdapter { io: socket };
206 let endpoint = Endpoint::new_with_abstract_socket(
207 endpoint_config,
208 None,
209 wrapper,
210 Arc::new(quinn::TokioRuntime),
211 )?;
212 self.connect_inner(endpoint, name_server, dns_name).await
213 }
214
215 async fn connect(
216 self,
217 name_server: SocketAddr,
218 dns_name: String,
219 ) -> Result<QuicClientStream, ProtoError> {
220 let connect = if let Some(bind_addr) = self.bind_addr {
221 <tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, bind_addr)
222 } else {
223 <tokio::net::UdpSocket as UdpSocket>::connect(name_server)
224 };
225
226 let socket = connect.await?;
227 let socket = socket.into_std()?;
228 let endpoint_config = quic_config::endpoint();
229 let endpoint = Endpoint::new(endpoint_config, None, socket, Arc::new(quinn::TokioRuntime))?;
230 self.connect_inner(endpoint, name_server, dns_name).await
231 }
232
233 async fn connect_inner(
234 self,
235 mut endpoint: Endpoint,
236 name_server: SocketAddr,
237 dns_name: String,
238 ) -> Result<QuicClientStream, ProtoError> {
239 let mut crypto_config = if let Some(crypto_config) = self.crypto_config {
241 crypto_config
242 } else {
243 client_config_tls13()?
244 };
245 if crypto_config.alpn_protocols.is_empty() {
246 crypto_config.alpn_protocols = vec![quic_stream::DOQ_ALPN.to_vec()];
247 }
248 let early_data_enabled = crypto_config.enable_early_data;
249
250 let mut client_config = ClientConfig::new(Arc::new(crypto_config));
251 client_config.transport_config(self.transport_config.clone());
252
253 endpoint.set_default_client_config(client_config);
254
255 let connecting = endpoint.connect(name_server, &dns_name)?;
256 let quic_connection = if early_data_enabled {
259 match connecting.into_0rtt() {
260 Ok((new_connection, _)) => new_connection,
261 Err(connecting) => connecting.await?,
262 }
263 } else {
264 connecting.await?
265 };
266
267 Ok(QuicClientStream {
268 quic_connection,
269 name_server_name: Arc::from(dns_name),
270 name_server,
271 is_shutdown: false,
272 })
273 }
274}
275
276pub fn client_config_tls13() -> Result<TlsClientConfig, ProtoError> {
278 use rustls::RootCertStore;
279 #[cfg_attr(
280 not(any(feature = "native-certs", feature = "webpki-roots")),
281 allow(unused_mut)
282 )]
283 let mut root_store = RootCertStore::empty();
284 #[cfg(all(feature = "native-certs", not(feature = "webpki-roots")))]
285 {
286 use crate::error::ProtoErrorKind;
287
288 let (added, ignored) =
289 root_store.add_parsable_certificates(&rustls_native_certs::load_native_certs()?);
290
291 if ignored > 0 {
292 tracing::warn!(
293 "failed to parse {} certificate(s) from the native root store",
294 ignored,
295 );
296 }
297
298 if added == 0 {
299 return Err(ProtoErrorKind::NativeCerts.into());
300 }
301 }
302 #[cfg(feature = "webpki-roots")]
303 root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
304 rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
305 ta.subject,
306 ta.spki,
307 ta.name_constraints,
308 )
309 }));
310
311 Ok(TlsClientConfig::builder()
312 .with_safe_default_cipher_suites()
313 .with_safe_default_kx_groups()
314 .with_protocol_versions(&[&TLS13])
315 .expect("TLS 1.3 not supported")
316 .with_root_certificates(root_store)
317 .with_no_client_auth())
318}
319
320impl Default for QuicClientStreamBuilder {
321 fn default() -> Self {
322 let mut transport_config = quic_config::transport();
323 transport_config.max_concurrent_bidi_streams(VarInt::from_u32(0));
325
326 Self {
327 crypto_config: None,
328 transport_config: Arc::new(transport_config),
329 bind_addr: None,
330 }
331 }
332}
333
334pub struct QuicClientConnect(
336 Pin<Box<dyn Future<Output = Result<QuicClientStream, ProtoError>> + Send>>,
337);
338
339impl Future for QuicClientConnect {
340 type Output = Result<QuicClientStream, ProtoError>;
341
342 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
343 self.0.poll_unpin(cx)
344 }
345}
346
347pub struct QuicClientResponse(
349 Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>,
350);
351
352impl Future for QuicClientResponse {
353 type Output = Result<DnsResponse, ProtoError>;
354
355 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
356 self.0.as_mut().poll(cx).map_err(ProtoError::from)
357 }
358}