1use alloc::boxed::Box;
9use alloc::sync::Arc;
10use alloc::vec::Vec;
11use core::fmt::{self, Display};
12use core::pin::Pin;
13use core::task::{Context, Poll};
14use core::time::Duration;
15use std::collections::HashSet;
16use std::net::SocketAddr;
17use std::time::{SystemTime, UNIX_EPOCH};
18
19use futures_util::{future::Future, stream::Stream};
20use tracing::{debug, trace, warn};
21
22use crate::error::{ProtoError, ProtoErrorKind};
23use crate::op::{Message, MessageFinalizer, MessageVerifier, Query};
24use crate::runtime::{RuntimeProvider, Time};
25use crate::udp::udp_stream::NextRandomUdpSocket;
26use crate::udp::{DnsUdpSocket, MAX_RECEIVE_BUFFER_SIZE};
27use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream, SerialMessage};
28
/// Builder for a UDP DNS client.
///
/// Created by `UdpClientStream::builder` and consumed by `build`, which produces a
/// [`UdpClientConnect`].
pub struct UdpClientStreamBuilder<P> {
    /// Address every request is sent to; responses from any other address are ignored.
    name_server: SocketAddr,
    /// Per-request timeout; `None` means `build` falls back to 5 seconds.
    timeout: Option<Duration>,
    /// Optional finalizer (e.g. a signer) applied to outgoing requests it elects to finalize.
    signer: Option<Arc<dyn MessageFinalizer>>,
    /// Optional local address for the per-request socket, forwarded to `NextRandomUdpSocket`.
    bind_addr: Option<SocketAddr>,
    /// Local ports forwarded to `NextRandomUdpSocket`; presumably ports it must not bind
    /// to — TODO confirm against `NextRandomUdpSocket`.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Flag forwarded to `NextRandomUdpSocket`; presumably lets the OS choose the local
    /// port — TODO confirm.
    os_port_selection: bool,
    /// Runtime provider; cloned into each request to create its socket.
    provider: P,
}
41
42impl<P> UdpClientStreamBuilder<P> {
43 pub fn with_timeout(mut self, timeout: Option<Duration>) -> Self {
45 self.timeout = timeout;
46 self
47 }
48
49 pub fn with_signer(self, signer: Option<Arc<dyn MessageFinalizer>>) -> Self {
51 Self {
52 name_server: self.name_server,
53 timeout: self.timeout,
54 signer,
55 bind_addr: self.bind_addr,
56 avoid_local_ports: self.avoid_local_ports,
57 os_port_selection: self.os_port_selection,
58 provider: self.provider,
59 }
60 }
61
62 pub fn with_bind_addr(mut self, bind_addr: Option<SocketAddr>) -> Self {
67 self.bind_addr = bind_addr;
68 self
69 }
70
71 pub fn avoid_local_ports(mut self, avoid_local_ports: Arc<HashSet<u16>>) -> Self {
74 self.avoid_local_ports = avoid_local_ports;
75 self
76 }
77
78 pub fn with_os_port_selection(mut self, os_port_selection: bool) -> Self {
80 self.os_port_selection = os_port_selection;
81 self
82 }
83
84 pub fn build(self) -> UdpClientConnect<P> {
88 UdpClientConnect {
89 name_server: self.name_server,
90 timeout: self.timeout.unwrap_or(Duration::from_secs(5)),
91 signer: self.signer,
92 bind_addr: self.bind_addr,
93 avoid_local_ports: self.avoid_local_ports.clone(),
94 os_port_selection: self.os_port_selection,
95 provider: self.provider,
96 }
97 }
98}
99
/// DNS request sender that talks to a single name server over UDP.
///
/// Each request creates its own socket through `NextRandomUdpSocket`. This type therefore
/// holds only configuration plus a shutdown flag.
#[must_use = "futures do nothing unless polled"]
pub struct UdpClientStream<P> {
    /// Server every request is sent to and the only accepted source of responses.
    name_server: SocketAddr,
    /// Bound on the socket creation + send + receive sequence of a single request.
    timeout: Duration,
    /// Set by `shutdown`. Sending afterwards panics, and the `Stream` impl ends.
    is_shutdown: bool,
    /// Optional finalizer (e.g. a signer) applied to outgoing requests.
    signer: Option<Arc<dyn MessageFinalizer>>,
    /// Optional local bind address forwarded to `NextRandomUdpSocket`.
    bind_addr: Option<SocketAddr>,
    /// Local ports forwarded to `NextRandomUdpSocket` (presumably ports to avoid).
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Flag forwarded to `NextRandomUdpSocket`.
    os_port_selection: bool,
    /// Runtime provider cloned into each request.
    provider: P,
}
116
117impl<P: RuntimeProvider> UdpClientStream<P> {
118 pub fn builder(name_server: SocketAddr, provider: P) -> UdpClientStreamBuilder<P> {
120 UdpClientStreamBuilder {
121 name_server,
122 timeout: None,
123 signer: None,
124 bind_addr: None,
125 avoid_local_ports: Arc::default(),
126 os_port_selection: false,
127 provider,
128 }
129 }
130}
131
132impl<P> Display for UdpClientStream<P> {
133 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
134 write!(formatter, "UDP({})", self.name_server)
135 }
136}
137
138fn random_query_id() -> u16 {
140 rand::random()
141}
142
impl<P: RuntimeProvider> DnsRequestSender for UdpClientStream<P> {
    /// Sends `request` to the configured name server over a newly created UDP socket.
    ///
    /// The request is processed in this order:
    /// 1. assign a random ID,
    /// 2. optionally finalize it with the configured signer,
    /// 3. serialize it.
    ///
    /// Socket creation, sending and waiting for the matching response then run inside
    /// `P::Timer::timeout` bounded by `self.timeout`.
    ///
    /// # Panics
    ///
    /// Panics if called after `shutdown`.
    fn send_message(&mut self, mut request: DnsRequest) -> DnsResponseStream {
        if self.is_shutdown {
            panic!("can not send messages after stream is shutdown")
        }

        let case_randomization = request.options().case_randomization;

        // Fresh random ID per request; responses with any other ID are dropped while waiting.
        request.set_id(random_query_id());

        let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(now) => now.as_secs(),
            Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
        };

        // NOTE(review): `as` silently truncates seconds-since-epoch to 32 bits (wraps in
        // 2106); presumably the width is dictated by `finalize`'s signature — confirm.
        let now = now as u32;

        // If a signer is configured and wants this request, finalize it now. `finalize`
        // may return a verifier that is later applied to the raw response bytes.
        let mut verifier = None;
        if let Some(signer) = &self.signer {
            if signer.should_finalize_message(&request) {
                match request.finalize(&**signer, now) {
                    Ok(answer_verifier) => verifier = answer_verifier,
                    Err(e) => {
                        debug!("could not sign message: {}", e);
                        return e.into();
                    }
                }
            }
        }

        // Receive buffer size: the request's `max_payload()` (presumably its advertised
        // UDP payload size), capped at MAX_RECEIVE_BUFFER_SIZE.
        let recv_buf_size = MAX_RECEIVE_BUFFER_SIZE.min(request.max_payload() as usize);

        let bytes = match request.to_vec() {
            Ok(bytes) => bytes,
            Err(err) => {
                return err.into();
            }
        };

        let message_id = request.id();
        let message = SerialMessage::new(bytes, self.name_server);

        // The message was serialized just above, so decoding it back is expected to succeed.
        debug!(
            "final message: {}",
            message
                .to_message()
                .expect("bizarre we just made this message")
        );

        // Copy everything the async block needs so it does not borrow `self`.
        let provider = self.provider.clone();
        let addr = message.addr();
        let bind_addr = self.bind_addr;
        let avoid_local_ports = self.avoid_local_ports.clone();
        let os_port_selection = self.os_port_selection;

        // Socket creation and the full send/receive exchange share a single timeout.
        P::Timer::timeout::<Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>>(
            self.timeout,
            Box::pin(async move {
                let socket = NextRandomUdpSocket::new(
                    addr,
                    bind_addr,
                    avoid_local_ports,
                    os_port_selection,
                    provider,
                )
                .await?;
                send_serial_message_inner(
                    message,
                    message_id,
                    verifier,
                    socket,
                    recv_buf_size,
                    case_randomization,
                    request.original_query(),
                )
                .await
            }),
        )
        .into()
    }

    /// Marks the stream as shut down. Later `send_message` calls panic.
    fn shutdown(&mut self) {
        self.is_shutdown = true;
    }

    /// Returns whether `shutdown` has been called.
    fn is_shutdown(&self) -> bool {
        self.is_shutdown
    }
}
235
236impl<P> Stream for UdpClientStream<P> {
238 type Item = Result<(), ProtoError>;
239
240 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
241 if self.is_shutdown {
243 Poll::Ready(None)
244 } else {
245 Poll::Ready(Some(Ok(())))
246 }
247 }
248}
249
/// Future returned by `UdpClientStreamBuilder::build`.
///
/// It resolves immediately to a configured `UdpClientStream`.
pub struct UdpClientConnect<P> {
    /// Server the resulting stream sends to.
    name_server: SocketAddr,
    /// Per-request timeout (already defaulted by `build`).
    timeout: Duration,
    /// Optional finalizer; moved into the stream on the first poll.
    signer: Option<Arc<dyn MessageFinalizer>>,
    /// Optional local bind address for per-request sockets.
    bind_addr: Option<SocketAddr>,
    /// Local ports forwarded to the socket layer (presumably ports to avoid).
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Flag forwarded to the socket layer.
    os_port_selection: bool,
    /// Runtime provider cloned into the resulting stream.
    provider: P,
}
260
261impl<P: RuntimeProvider> Future for UdpClientConnect<P> {
262 type Output = Result<UdpClientStream<P>, ProtoError>;
263
264 fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
265 Poll::Ready(Ok(UdpClientStream {
267 name_server: self.name_server,
268 is_shutdown: false,
269 timeout: self.timeout,
270 signer: self.signer.take(),
271 bind_addr: self.bind_addr,
272 avoid_local_ports: self.avoid_local_ports.clone(),
273 os_port_selection: self.os_port_selection,
274 provider: self.provider.clone(),
275 }))
276 }
277}
278
279async fn send_serial_message_inner<S: DnsUdpSocket + Send>(
280 msg: SerialMessage,
281 msg_id: u16,
282 verifier: Option<MessageVerifier>,
283 socket: S,
284 recv_buf_size: usize,
285 case_randomization: bool,
286 original_query: Option<&Query>,
287) -> Result<DnsResponse, ProtoError> {
288 let bytes = msg.bytes();
289 let addr = msg.addr();
290 let len_sent: usize = socket.send_to(bytes, addr).await?;
291
292 if bytes.len() != len_sent {
293 return Err(ProtoError::from(format!(
294 "Not all bytes of message sent, {} of {}",
295 len_sent,
296 bytes.len()
297 )));
298 }
299
300 trace!("creating UDP receive buffer with size {recv_buf_size}");
302 let mut recv_buf = vec![0; recv_buf_size];
303
304 loop {
306 let (len, src) = socket.recv_from(&mut recv_buf).await?;
307
308 let response_bytes = &recv_buf[0..len];
310 let response_buffer = Vec::from(response_bytes);
311
312 let request_target = msg.addr();
314
315 if src.ip() != request_target.ip() || src.port() != request_target.port() {
317 warn!(
318 "ignoring response from {} because it does not match name_server: {}.",
319 src, request_target,
320 );
321
322 continue;
324 }
325
326 let mut response = match DnsResponse::from_buffer(response_buffer) {
327 Ok(response) => response,
328 Err(e) => {
329 warn!("dropped malformed message waiting for id: {msg_id} err: {e}");
331 continue;
332 }
333 };
334
335 if msg_id != response.id() {
337 warn!(
339 "expected message id: {} got: {}, dropped",
340 msg_id,
341 response.id()
342 );
343
344 continue;
345 }
346
347 let request_message = Message::from_vec(msg.bytes())?;
373 let request_queries = request_message.queries();
374 let response_queries = response.queries_mut();
375
376 let question_matches = response_queries
377 .iter()
378 .all(|elem| request_queries.contains(elem));
379 if case_randomization
380 && question_matches
381 && !response_queries.iter().all(|elem| {
382 request_queries
383 .iter()
384 .any(|req_q| req_q == elem && req_q.name().eq_case(elem.name()))
385 })
386 {
387 warn!(
388 "case of question section did not match: we expected '{request_queries:?}', but received '{response_queries:?}' from server {src}"
389 );
390 return Err(ProtoErrorKind::QueryCaseMismatch.into());
391 }
392 if !question_matches {
393 warn!(
394 "detected forged question section: we expected '{request_queries:?}', but received '{response_queries:?}' from server {src}"
395 );
396 continue;
397 }
398
399 if case_randomization {
401 if let Some(original_query) = original_query {
402 for response_query in response_queries.iter_mut() {
403 if response_query == original_query {
404 *response_query = original_query.clone();
405 }
406 }
407 }
408 }
409
410 debug!("received message id: {}", response.id());
411 if let Some(mut verifier) = verifier {
412 return verifier(response_bytes);
413 } else {
414 return Ok(response);
415 }
416 }
417}
418
#[cfg(test)]
#[cfg(feature = "tokio")]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]
    use crate::{runtime::TokioRuntimeProvider, tests::udp_client_stream_test};
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    use test_support::subscribe;

    /// Runs the shared UDP client stream scenario against the IPv4 loopback address.
    #[tokio::test]
    async fn test_udp_client_stream_ipv4() {
        subscribe();
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        udp_client_stream_test(loopback, TokioRuntimeProvider::new()).await;
    }

    /// Runs the shared UDP client stream scenario against the IPv6 loopback address (`::1`).
    #[tokio::test]
    async fn test_udp_client_stream_ipv6() {
        subscribe();
        let loopback = IpAddr::V6(Ipv6Addr::LOCALHOST);
        udp_client_stream_test(loopback, TokioRuntimeProvider::new()).await;
    }
}