1use std::fmt::{Debug, Formatter};
9use std::{
10 fmt::{self, Display},
11 future::Future,
12 net::SocketAddr,
13 pin::Pin,
14 sync::Arc,
15 task::{Context, Poll},
16};
17
18use futures_util::{future::FutureExt, stream::Stream};
19use quinn::{AsyncUdpSocket, ClientConfig, Connection, Endpoint, TransportConfig, VarInt};
20use rustls::{version::TLS13, ClientConfig as TlsClientConfig};
21
22use crate::udp::{DnsUdpSocket, QuicLocalAddr};
23use crate::{
24 error::ProtoError,
25 quic::quic_stream::{DoqErrorCode, QuicStream},
26 udp::UdpSocket,
27 xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream},
28};
29
30use super::{quic_config, quic_stream};
31
32#[must_use = "futures do nothing unless polled"]
34pub struct QuicClientStream {
35 quic_connection: Connection,
36 name_server_name: Arc<str>,
37 name_server: SocketAddr,
38 is_shutdown: bool,
39}
40
41impl Display for QuicClientStream {
42 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
43 write!(
44 formatter,
45 "QUIC({},{})",
46 self.name_server, self.name_server_name
47 )
48 }
49}
50
51impl QuicClientStream {
52 pub fn builder() -> QuicClientStreamBuilder {
54 QuicClientStreamBuilder::default()
55 }
56
57 async fn inner_send(
58 connection: Connection,
59 message: DnsRequest,
60 ) -> Result<DnsResponse, ProtoError> {
61 let (send_stream, recv_stream) = connection.open_bi().await?;
62
63 let mut stream = QuicStream::new(send_stream, recv_stream);
66
67 stream.send(message.into_parts().0).await?;
68
69 stream.finish().await?;
72
73 stream.receive().await
74 }
75}
76
77impl DnsRequestSender for QuicClientStream {
78 fn send_message(&mut self, message: DnsRequest) -> DnsResponseStream {
118 if self.is_shutdown {
119 panic!("can not send messages after stream is shutdown")
120 }
121
122 Box::pin(Self::inner_send(self.quic_connection.clone(), message)).into()
123 }
124
125 fn shutdown(&mut self) {
126 self.is_shutdown = true;
127 self.quic_connection
128 .close(DoqErrorCode::NoError.into(), b"Shutdown");
129 }
130
131 fn is_shutdown(&self) -> bool {
132 self.is_shutdown
133 }
134}
135
136impl Stream for QuicClientStream {
137 type Item = Result<(), ProtoError>;
138
139 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140 if self.is_shutdown {
141 Poll::Ready(None)
142 } else {
143 Poll::Ready(Some(Ok(())))
144 }
145 }
146}
147
148#[derive(Clone)]
150pub struct QuicClientStreamBuilder {
151 crypto_config: TlsClientConfig,
152 transport_config: Arc<TransportConfig>,
153 bind_addr: Option<SocketAddr>,
154}
155
156impl QuicClientStreamBuilder {
157 pub fn crypto_config(&mut self, crypto_config: TlsClientConfig) -> &mut Self {
159 self.crypto_config = crypto_config;
160 self
161 }
162
163 pub fn bind_addr(&mut self, bind_addr: SocketAddr) -> &mut Self {
165 self.bind_addr = Some(bind_addr);
166 self
167 }
168
169 pub fn build(self, name_server: SocketAddr, dns_name: String) -> QuicClientConnect {
176 QuicClientConnect(Box::pin(self.connect(name_server, dns_name)) as _)
177 }
178
179 pub fn build_with_future<S, F>(
181 self,
182 future: F,
183 name_server: SocketAddr,
184 dns_name: String,
185 ) -> QuicClientConnect
186 where
187 S: DnsUdpSocket + QuicLocalAddr + 'static,
188 F: Future<Output = std::io::Result<S>> + Send + 'static,
189 {
190 QuicClientConnect(Box::pin(self.connect_with_future(future, name_server, dns_name)) as _)
191 }
192
193 async fn connect_with_future<S, F>(
194 self,
195 future: F,
196 name_server: SocketAddr,
197 dns_name: String,
198 ) -> Result<QuicClientStream, ProtoError>
199 where
200 S: DnsUdpSocket + QuicLocalAddr + 'static,
201 F: Future<Output = std::io::Result<S>> + Send,
202 {
203 let socket = future.await?;
204 let endpoint_config = quic_config::endpoint();
205 let wrapper = QuinnAsyncUdpSocketAdapter { io: socket };
206 let endpoint = Endpoint::new_with_abstract_socket(
207 endpoint_config,
208 None,
209 wrapper,
210 Arc::new(quinn::TokioRuntime),
211 )?;
212 self.connect_inner(endpoint, name_server, dns_name).await
213 }
214
215 async fn connect(
216 self,
217 name_server: SocketAddr,
218 dns_name: String,
219 ) -> Result<QuicClientStream, ProtoError> {
220 let connect = if let Some(bind_addr) = self.bind_addr {
221 <tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, bind_addr)
222 } else {
223 <tokio::net::UdpSocket as UdpSocket>::connect(name_server)
224 };
225
226 let socket = connect.await?;
227 let socket = socket.into_std()?;
228 let endpoint_config = quic_config::endpoint();
229 let endpoint = Endpoint::new(endpoint_config, None, socket, Arc::new(quinn::TokioRuntime))?;
230 self.connect_inner(endpoint, name_server, dns_name).await
231 }
232
233 async fn connect_inner(
234 self,
235 mut endpoint: Endpoint,
236 name_server: SocketAddr,
237 dns_name: String,
238 ) -> Result<QuicClientStream, ProtoError> {
239 let mut crypto_config = self.crypto_config;
241 if crypto_config.alpn_protocols.is_empty() {
242 crypto_config.alpn_protocols = vec![quic_stream::DOQ_ALPN.to_vec()];
243 }
244 let early_data_enabled = crypto_config.enable_early_data;
245
246 let mut client_config = ClientConfig::new(Arc::new(crypto_config));
247 client_config.transport_config(self.transport_config.clone());
248
249 endpoint.set_default_client_config(client_config);
250
251 let connecting = endpoint.connect(name_server, &dns_name)?;
252 let quic_connection = if early_data_enabled {
255 match connecting.into_0rtt() {
256 Ok((new_connection, _)) => new_connection,
257 Err(connecting) => connecting.await?,
258 }
259 } else {
260 connecting.await?
261 };
262
263 Ok(QuicClientStream {
264 quic_connection,
265 name_server_name: Arc::from(dns_name),
266 name_server,
267 is_shutdown: false,
268 })
269 }
270}
271
272pub fn client_config_tls13_webpki_roots() -> TlsClientConfig {
274 use rustls::{OwnedTrustAnchor, RootCertStore};
275 let mut root_store = RootCertStore::empty();
276 root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
277 OwnedTrustAnchor::from_subject_spki_name_constraints(
278 ta.subject,
279 ta.spki,
280 ta.name_constraints,
281 )
282 }));
283
284 TlsClientConfig::builder()
285 .with_safe_default_cipher_suites()
286 .with_safe_default_kx_groups()
287 .with_protocol_versions(&[&TLS13])
288 .expect("TLS 1.3 not supported")
289 .with_root_certificates(root_store)
290 .with_no_client_auth()
291}
292
293impl Default for QuicClientStreamBuilder {
294 fn default() -> Self {
295 let mut transport_config = quic_config::transport();
296 transport_config.max_concurrent_bidi_streams(VarInt::from_u32(0));
298
299 let client_config = client_config_tls13_webpki_roots();
300
301 Self {
302 crypto_config: client_config,
303 transport_config: Arc::new(transport_config),
304 bind_addr: None,
305 }
306 }
307}
308
309pub struct QuicClientConnect(
311 Pin<Box<dyn Future<Output = Result<QuicClientStream, ProtoError>> + Send>>,
312);
313
314impl Future for QuicClientConnect {
315 type Output = Result<QuicClientStream, ProtoError>;
316
317 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
318 self.0.poll_unpin(cx)
319 }
320}
321
322pub struct QuicClientResponse(
324 Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>,
325);
326
327impl Future for QuicClientResponse {
328 type Output = Result<DnsResponse, ProtoError>;
329
330 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
331 self.0.as_mut().poll(cx).map_err(ProtoError::from)
332 }
333}
334
/// Adapts a trust-dns UDP socket so quinn can drive it as an `AsyncUdpSocket`.
///
/// The trait bounds live on the impl blocks that need them rather than on the struct,
/// so the wrapper itself imposes no requirements on `S`.
struct QuinnAsyncUdpSocketAdapter<S> {
    /// Underlying socket; sends, receives and `local_addr` are delegated to it.
    io: S,
}
339
340impl<S: DnsUdpSocket + QuicLocalAddr> Debug for QuinnAsyncUdpSocketAdapter<S> {
341 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
342 f.write_str("Wrapper for quinn::AsyncUdpSocket")
343 }
344}
345
346impl<S: DnsUdpSocket + QuicLocalAddr + 'static> AsyncUdpSocket for QuinnAsyncUdpSocketAdapter<S> {
348 fn poll_send(
349 &self,
350 _state: &quinn::udp::UdpState,
351 cx: &mut Context<'_>,
352 transmits: &[quinn::udp::Transmit],
353 ) -> Poll<std::io::Result<usize>> {
354 let io = &self.io;
356 let mut sent = 0;
357 for transmit in transmits {
358 match io.poll_send_to(cx, &transmit.contents, transmit.destination) {
359 Poll::Ready(ready) => match ready {
360 Ok(_) => {
361 sent += 1;
362 }
363 Err(_) if sent != 0 => return Poll::Ready(Ok(sent)),
367 Err(e) => {
368 if e.kind() == std::io::ErrorKind::WouldBlock {
369 return Poll::Ready(Err(e));
370 }
371
372 sent += 1;
380 }
381 },
382 Poll::Pending => {
383 return if sent == 0 {
384 Poll::Pending
385 } else {
386 Poll::Ready(Ok(sent))
387 }
388 }
389 }
390 }
391 Poll::Ready(Ok(sent))
392 }
393
394 fn poll_recv(
395 &self,
396 cx: &mut Context<'_>,
397 bufs: &mut [std::io::IoSliceMut<'_>],
398 meta: &mut [quinn::udp::RecvMeta],
399 ) -> Poll<std::io::Result<usize>> {
400 let io = &self.io;
403 let Some(buf) = bufs.get_mut(0)else {
404 return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidInput,"no buf")));
405 };
406 match io.poll_recv_from(cx, buf.as_mut()) {
407 Poll::Ready(res) => match res {
408 Ok((len, addr)) => {
409 meta[0] = quinn::udp::RecvMeta {
410 len,
411 stride: len,
412 addr,
413 ecn: None,
414 dst_ip: None,
415 };
416 Poll::Ready(Ok(1))
417 }
418 Err(err) => Poll::Ready(Err(err)),
419 },
420 Poll::Pending => Poll::Pending,
421 }
422 }
423
424 fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
425 self.io.local_addr()
426 }
427}