1use std::borrow::Borrow;
9use std::fmt::{self, Display};
10use std::marker::PhantomData;
11use std::net::SocketAddr;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15use std::time::{Duration, SystemTime, UNIX_EPOCH};
16
17use futures_util::{future::Future, stream::Stream};
18use tracing::{debug, trace, warn};
19
20use crate::error::ProtoError;
21use crate::op::message::NoopMessageFinalizer;
22use crate::op::{Message, MessageFinalizer, MessageVerifier};
23use crate::udp::udp_stream::{NextRandomUdpSocket, UdpCreator, UdpSocket};
24use crate::udp::{DnsUdpSocket, MAX_RECEIVE_BUFFER_SIZE};
25use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream, SerialMessage};
26use crate::Time;
27
28#[must_use = "futures do nothing unless polled"]
33pub struct UdpClientStream<S, MF = NoopMessageFinalizer>
34where
35 S: Send,
36 MF: MessageFinalizer,
37{
38 name_server: SocketAddr,
39 timeout: Duration,
40 is_shutdown: bool,
41 signer: Option<Arc<MF>>,
42 creator: UdpCreator<S>,
43 marker: PhantomData<S>,
44}
45
46impl<S: UdpSocket + Send + 'static> UdpClientStream<S, NoopMessageFinalizer> {
47 #[allow(clippy::new_ret_no_self)]
56 pub fn new(name_server: SocketAddr) -> UdpClientConnect<S, NoopMessageFinalizer> {
57 Self::with_timeout(name_server, Duration::from_secs(5))
58 }
59
60 pub fn with_timeout(
67 name_server: SocketAddr,
68 timeout: Duration,
69 ) -> UdpClientConnect<S, NoopMessageFinalizer> {
70 Self::with_bind_addr_and_timeout(name_server, None, timeout)
71 }
72
73 pub fn with_bind_addr_and_timeout(
81 name_server: SocketAddr,
82 bind_addr: Option<SocketAddr>,
83 timeout: Duration,
84 ) -> UdpClientConnect<S, NoopMessageFinalizer> {
85 Self::with_timeout_and_signer_and_bind_addr(name_server, timeout, None, bind_addr)
86 }
87}
88
89impl<S: UdpSocket + Send + 'static, MF: MessageFinalizer> UdpClientStream<S, MF> {
90 pub fn with_timeout_and_signer(
97 name_server: SocketAddr,
98 timeout: Duration,
99 signer: Option<Arc<MF>>,
100 ) -> UdpClientConnect<S, MF> {
101 UdpClientConnect {
102 name_server,
103 timeout,
104 signer,
105 creator: Arc::new(|local_addr: _, server_addr: _| {
106 Box::pin(NextRandomUdpSocket::<S>::new(
107 &server_addr,
108 &Some(local_addr),
109 ))
110 }),
111 marker: PhantomData::<S>,
112 }
113 }
114
115 pub fn with_timeout_and_signer_and_bind_addr(
123 name_server: SocketAddr,
124 timeout: Duration,
125 signer: Option<Arc<MF>>,
126 bind_addr: Option<SocketAddr>,
127 ) -> UdpClientConnect<S, MF> {
128 UdpClientConnect {
129 name_server,
130 timeout,
131 signer,
132 creator: Arc::new(move |local_addr: _, server_addr: _| {
133 Box::pin(NextRandomUdpSocket::<S>::new(
134 &server_addr,
135 &Some(bind_addr.unwrap_or(local_addr)),
136 ))
137 }),
138 marker: PhantomData::<S>,
139 }
140 }
141}
142
143impl<S: DnsUdpSocket + Send, MF: MessageFinalizer> UdpClientStream<S, MF> {
144 pub fn with_creator(
153 name_server: SocketAddr,
154 signer: Option<Arc<MF>>,
155 timeout: Duration,
156 creator: UdpCreator<S>,
157 ) -> UdpClientConnect<S, MF> {
158 UdpClientConnect {
159 name_server,
160 timeout,
161 signer,
162 creator,
163 marker: PhantomData::<S>,
164 }
165 }
166}
167
168impl<S: Send, MF: MessageFinalizer> Display for UdpClientStream<S, MF> {
169 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
170 write!(formatter, "UDP({})", self.name_server)
171 }
172}
173
174fn random_query_id() -> u16 {
176 use rand::distributions::{Distribution, Standard};
177 let mut rand = rand::thread_rng();
178
179 Standard.sample(&mut rand)
180}
181
182impl<S: DnsUdpSocket + Send + 'static, MF: MessageFinalizer> DnsRequestSender
183 for UdpClientStream<S, MF>
184{
185 fn send_message(&mut self, mut message: DnsRequest) -> DnsResponseStream {
186 if self.is_shutdown {
187 panic!("can not send messages after stream is shutdown")
188 }
189
190 message.set_id(random_query_id());
193
194 let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
195 Ok(now) => now.as_secs(),
196 Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
197 };
198
199 let now = now as u32;
201
202 let mut verifier = None;
203 if let Some(ref signer) = self.signer {
204 if signer.should_finalize_message(&message) {
205 match message.finalize::<MF>(signer.borrow(), now) {
206 Ok(answer_verifier) => verifier = answer_verifier,
207 Err(e) => {
208 debug!("could not sign message: {}", e);
209 return e.into();
210 }
211 }
212 }
213 }
214
215 let recv_buf_size = MAX_RECEIVE_BUFFER_SIZE.min(message.max_payload() as usize);
217
218 let bytes = match message.to_vec() {
219 Ok(bytes) => bytes,
220 Err(err) => {
221 return err.into();
222 }
223 };
224
225 let message_id = message.id();
226 let message = SerialMessage::new(bytes, self.name_server);
227
228 debug!(
229 "final message: {}",
230 message
231 .to_message()
232 .expect("bizarre we just made this message")
233 );
234 let creator = self.creator.clone();
235 let addr = message.addr();
236
237 S::Time::timeout::<Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>>(
238 self.timeout,
239 Box::pin(async move {
240 let socket: S = NextRandomUdpSocket::new_with_closure(&addr, creator).await?;
241 send_serial_message_inner(message, message_id, verifier, socket, recv_buf_size)
242 .await
243 }),
244 )
245 .into()
246 }
247
248 fn shutdown(&mut self) {
249 self.is_shutdown = true;
250 }
251
252 fn is_shutdown(&self) -> bool {
253 self.is_shutdown
254 }
255}
256
257impl<S: Send, MF: MessageFinalizer> Stream for UdpClientStream<S, MF> {
259 type Item = Result<(), ProtoError>;
260
261 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
262 if self.is_shutdown {
264 Poll::Ready(None)
265 } else {
266 Poll::Ready(Some(Ok(())))
267 }
268 }
269}
270
271pub struct UdpClientConnect<S, MF = NoopMessageFinalizer>
273where
274 S: Send,
275 MF: MessageFinalizer,
276{
277 name_server: SocketAddr,
278 timeout: Duration,
279 signer: Option<Arc<MF>>,
280 creator: UdpCreator<S>,
281 marker: PhantomData<S>,
282}
283
284impl<S: Send + Unpin, MF: MessageFinalizer> Future for UdpClientConnect<S, MF> {
285 type Output = Result<UdpClientStream<S, MF>, ProtoError>;
286
287 fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
288 Poll::Ready(Ok(UdpClientStream::<S, MF> {
290 name_server: self.name_server,
291 is_shutdown: false,
292 timeout: self.timeout,
293 signer: self.signer.take(),
294 creator: self.creator.clone(),
295 marker: PhantomData,
296 }))
297 }
298}
299
300async fn send_serial_message_inner<S: DnsUdpSocket + Send>(
301 msg: SerialMessage,
302 msg_id: u16,
303 verifier: Option<MessageVerifier>,
304 socket: S,
305 recv_buf_size: usize,
306) -> Result<DnsResponse, ProtoError> {
307 let bytes = msg.bytes();
308 let addr = msg.addr();
309 let len_sent: usize = socket.send_to(bytes, addr).await?;
310
311 if bytes.len() != len_sent {
312 return Err(ProtoError::from(format!(
313 "Not all bytes of message sent, {} of {}",
314 len_sent,
315 bytes.len()
316 )));
317 }
318
319 trace!("creating UDP receive buffer with size {recv_buf_size}");
321 let mut recv_buf = vec![0; recv_buf_size];
322
323 loop {
325 let (len, src) = socket.recv_from(&mut recv_buf).await?;
326
327 let buffer: Vec<_> = Vec::from(&recv_buf[0..len]);
329
330 let request_target = msg.addr();
332
333 if src != request_target {
334 warn!(
335 "ignoring response from {} because it does not match name_server: {}.",
336 src, request_target,
337 );
338
339 continue;
341 }
342
343 match Message::from_vec(&buffer) {
346 Ok(message) => {
347 if msg_id == message.id() {
348 debug!("received message id: {}", message.id());
349 if let Some(mut verifier) = verifier {
350 return verifier(&buffer);
351 } else {
352 return Ok(DnsResponse::new(message, buffer));
353 }
354 } else {
355 warn!(
357 "expected message id: {} got: {}, dropped",
358 msg_id,
359 message.id()
360 );
361
362 continue;
363 }
364 }
365 Err(e) => {
366 warn!(
368 "dropped malformed message waiting for id: {} err: {}",
369 msg_id, e
370 );
371
372 continue;
373 }
374 }
375 }
376}
377
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]
    use crate::tests::udp_client_stream_test;
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};
    use tokio::{net::UdpSocket as TokioUdpSocket, runtime::Runtime};

    /// Builds the tokio runtime each test drives the shared client-stream test on.
    fn new_runtime() -> Runtime {
        Runtime::new().expect("failed to create tokio runtime")
    }

    #[test]
    fn test_udp_client_stream_ipv4() {
        let io_loop = new_runtime();
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        udp_client_stream_test::<TokioUdpSocket, Runtime>(loopback, io_loop)
    }

    // The IPv6 loopback variant is compiled out on Linux.
    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_udp_client_stream_ipv6() {
        let io_loop = new_runtime();
        let loopback = IpAddr::V6(Ipv6Addr::LOCALHOST);
        udp_client_stream_test::<TokioUdpSocket, Runtime>(loopback, io_loop)
    }
}