1use alloc::{boxed::Box, string::String, sync::Arc};
9use core::{
10 fmt::{self, Display},
11 future::Future,
12 pin::Pin,
13 task::{Context, Poll},
14};
15use std::{io, net::SocketAddr};
16
17use futures_util::{future::FutureExt, stream::Stream};
18use quinn::{
19 ClientConfig, Connection, Endpoint, TransportConfig, VarInt, crypto::rustls::QuicClientConfig,
20};
21use tokio::time::timeout;
22
23use crate::{
24 error::ProtoError,
25 quic::quic_stream::{DoqErrorCode, QuicStream},
26 rustls::client_config,
27 udp::UdpSocket,
28 xfer::{CONNECT_TIMEOUT, DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream},
29};
30
31use super::{quic_config, quic_stream};
32
33#[must_use = "futures do nothing unless polled"]
35#[derive(Clone)]
36pub struct QuicClientStream {
37 quic_connection: Connection,
38 name_server_name: Arc<str>,
39 name_server: SocketAddr,
40 is_shutdown: bool,
41}
42
43impl Display for QuicClientStream {
44 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
45 write!(
46 formatter,
47 "QUIC({},{})",
48 self.name_server, self.name_server_name
49 )
50 }
51}
52
53impl QuicClientStream {
54 pub fn builder() -> QuicClientStreamBuilder {
56 QuicClientStreamBuilder::default()
57 }
58
59 async fn inner_send(
60 connection: Connection,
61 message: DnsRequest,
62 ) -> Result<DnsResponse, ProtoError> {
63 let (send_stream, recv_stream) = connection.open_bi().await?;
64
65 let mut stream = QuicStream::new(send_stream, recv_stream);
68
69 stream.send(message.into_parts().0).await?;
70
71 stream.finish().await?;
74
75 stream.receive().await
76 }
77}
78
79impl DnsRequestSender for QuicClientStream {
80 fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
120 if self.is_shutdown {
121 panic!("can not send messages after stream is shutdown")
122 }
123
124 Box::pin(Self::inner_send(self.quic_connection.clone(), request)).into()
125 }
126
127 fn shutdown(&mut self) {
128 self.is_shutdown = true;
129 self.quic_connection
130 .close(DoqErrorCode::NoError.into(), b"Shutdown");
131 }
132
133 fn is_shutdown(&self) -> bool {
134 self.is_shutdown
135 }
136}
137
138impl Stream for QuicClientStream {
139 type Item = Result<(), ProtoError>;
140
141 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
142 if self.is_shutdown {
143 Poll::Ready(None)
144 } else {
145 Poll::Ready(Some(Ok(())))
146 }
147 }
148}
149
150#[derive(Clone)]
152pub struct QuicClientStreamBuilder {
153 crypto_config: Option<rustls::ClientConfig>,
154 transport_config: Arc<TransportConfig>,
155 bind_addr: Option<SocketAddr>,
156}
157
158impl QuicClientStreamBuilder {
159 pub fn crypto_config(&mut self, crypto_config: rustls::ClientConfig) -> &mut Self {
161 self.crypto_config = Some(crypto_config);
162 self
163 }
164
165 pub fn bind_addr(&mut self, bind_addr: SocketAddr) -> &mut Self {
167 self.bind_addr = Some(bind_addr);
168 self
169 }
170
171 pub fn build(self, name_server: SocketAddr, dns_name: String) -> QuicClientConnect {
178 QuicClientConnect(Box::pin(self.connect(name_server, dns_name)) as _)
179 }
180
181 pub fn build_with_future(
183 self,
184 socket: Arc<dyn quinn::AsyncUdpSocket>,
185 name_server: SocketAddr,
186 dns_name: String,
187 ) -> QuicClientConnect {
188 QuicClientConnect(Box::pin(self.connect_with_future(socket, name_server, dns_name)) as _)
189 }
190
191 async fn connect_with_future(
192 self,
193 socket: Arc<dyn quinn::AsyncUdpSocket>,
194 name_server: SocketAddr,
195 dns_name: String,
196 ) -> Result<QuicClientStream, ProtoError> {
197 let endpoint_config = quic_config::endpoint();
198 let endpoint = Endpoint::new_with_abstract_socket(
199 endpoint_config,
200 None,
201 socket,
202 Arc::new(quinn::TokioRuntime),
203 )?;
204 self.connect_inner(endpoint, name_server, dns_name).await
205 }
206
207 async fn connect(
208 self,
209 name_server: SocketAddr,
210 dns_name: String,
211 ) -> Result<QuicClientStream, ProtoError> {
212 let connect = if let Some(bind_addr) = self.bind_addr {
213 <tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, bind_addr)
214 } else {
215 <tokio::net::UdpSocket as UdpSocket>::connect(name_server)
216 };
217
218 let socket = connect.await?;
219 let socket = socket.into_std()?;
220 let endpoint_config = quic_config::endpoint();
221 let endpoint = Endpoint::new(endpoint_config, None, socket, Arc::new(quinn::TokioRuntime))?;
222 self.connect_inner(endpoint, name_server, dns_name).await
223 }
224
225 async fn connect_inner(
226 self,
227 endpoint: Endpoint,
228 name_server: SocketAddr,
229 dns_name: String,
230 ) -> Result<QuicClientStream, ProtoError> {
231 let crypto_config = if let Some(crypto_config) = self.crypto_config {
233 crypto_config
234 } else {
235 client_config()
236 };
237
238 let quic_connection = connect_quic(
239 name_server,
240 &dns_name,
241 quic_stream::DOQ_ALPN,
242 crypto_config,
243 self.transport_config,
244 endpoint,
245 )
246 .await?;
247
248 Ok(QuicClientStream {
249 quic_connection,
250 name_server_name: Arc::from(dns_name),
251 name_server,
252 is_shutdown: false,
253 })
254 }
255}
256
257pub(crate) async fn connect_quic(
258 addr: SocketAddr,
259 server_name: &str,
260 protocol: &[u8],
261 mut crypto_config: rustls::ClientConfig,
262 transport_config: Arc<TransportConfig>,
263 mut endpoint: Endpoint,
264) -> Result<Connection, ProtoError> {
265 if crypto_config.alpn_protocols.is_empty() {
266 crypto_config.alpn_protocols = vec![protocol.to_vec()];
267 }
268 let early_data_enabled = crypto_config.enable_early_data;
269
270 let mut client_config = ClientConfig::new(Arc::new(QuicClientConfig::try_from(crypto_config)?));
271 client_config.transport_config(transport_config.clone());
272
273 endpoint.set_default_client_config(client_config);
274
275 let connecting = endpoint.connect(addr, server_name)?;
276 Ok(if early_data_enabled {
279 match connecting.into_0rtt() {
280 Ok((new_connection, _)) => new_connection,
281 Err(connecting) => connect_with_timeout(connecting).await?,
282 }
283 } else {
284 connect_with_timeout(connecting).await?
285 })
286}
287
288async fn connect_with_timeout(connecting: quinn::Connecting) -> Result<Connection, io::Error> {
289 match timeout(CONNECT_TIMEOUT, connecting).await {
290 Ok(Ok(connection)) => Ok(connection),
291 Ok(Err(e)) => Err(e.into()),
292 Err(_) => Err(io::Error::new(
293 io::ErrorKind::TimedOut,
294 format!("QUIC handshake timed out after {CONNECT_TIMEOUT:?}",),
295 )),
296 }
297}
298
299impl Default for QuicClientStreamBuilder {
300 fn default() -> Self {
301 let mut transport_config = quic_config::transport();
302 transport_config.max_concurrent_bidi_streams(VarInt::from_u32(0));
304
305 Self {
306 crypto_config: None,
307 transport_config: Arc::new(transport_config),
308 bind_addr: None,
309 }
310 }
311}
312
313pub struct QuicClientConnect(
315 Pin<Box<dyn Future<Output = Result<QuicClientStream, ProtoError>> + Send>>,
316);
317
318impl Future for QuicClientConnect {
319 type Output = Result<QuicClientStream, ProtoError>;
320
321 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
322 self.0.poll_unpin(cx)
323 }
324}
325
326pub struct QuicClientResponse(
328 Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>,
329);
330
331impl Future for QuicClientResponse {
332 type Output = Result<DnsResponse, ProtoError>;
333
334 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
335 self.0.as_mut().poll(cx).map_err(ProtoError::from)
336 }
337}