1use alloc::{boxed::Box, sync::Arc};
11use core::{
12 borrow::Borrow,
13 fmt::{self, Display},
14 marker::Unpin,
15 pin::Pin,
16 task::{Context, Poll},
17 time::Duration,
18};
19use std::{
20 collections::{HashMap, hash_map::Entry},
21 time::{SystemTime, UNIX_EPOCH},
22};
23
24use futures_channel::mpsc;
25use futures_util::{
26 FutureExt,
27 future::Future,
28 ready,
29 stream::{Stream, StreamExt},
30};
31use rand::Rng;
32use tracing::debug;
33
34use crate::{
35 DnsStreamHandle,
36 error::{ProtoError, ProtoErrorKind},
37 op::{MessageFinalizer, MessageVerifier},
38 runtime::Time,
39 xfer::{
40 BufDnsStreamHandle, CHANNEL_BUFFER_SIZE, DnsClientStream, DnsRequest, DnsRequestSender,
41 DnsResponse, DnsResponseStream, SerialMessage, ignore_send,
42 },
43};
44
45const QOS_MAX_RECEIVE_MSGS: usize = 100; struct ActiveRequest {
48 completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
50 request_id: u16,
51 timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
52 verifier: Option<MessageVerifier>,
53}
54
55impl ActiveRequest {
56 fn new(
57 completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
58 request_id: u16,
59 timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
60 verifier: Option<MessageVerifier>,
61 ) -> Self {
62 Self {
63 completion,
64 request_id,
65 timeout,
67 verifier,
68 }
69 }
70
71 fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Poll<()> {
73 self.timeout.poll_unpin(cx)
74 }
75
76 fn is_canceled(&self) -> bool {
78 self.completion.is_closed()
79 }
80
81 fn request_id(&self) -> u16 {
83 self.request_id
84 }
85
86 fn complete_with_error(mut self, error: ProtoError) {
88 ignore_send(self.completion.try_send(Err(error)));
89 }
90}
91
92#[must_use = "futures do nothing unless polled"]
98pub struct DnsMultiplexer<S>
99where
100 S: DnsClientStream + 'static,
101{
102 stream: S,
103 timeout_duration: Duration,
104 stream_handle: BufDnsStreamHandle,
105 active_requests: HashMap<u16, ActiveRequest>,
106 signer: Option<Arc<dyn MessageFinalizer>>,
107 is_shutdown: bool,
108}
109
110impl<S> DnsMultiplexer<S>
111where
112 S: DnsClientStream + Unpin + 'static,
113{
114 #[allow(clippy::new_ret_no_self)]
123 pub fn new<F>(
124 stream: F,
125 stream_handle: BufDnsStreamHandle,
126 signer: Option<Arc<dyn MessageFinalizer>>,
127 ) -> DnsMultiplexerConnect<F, S>
128 where
129 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
130 {
131 Self::with_timeout(stream, stream_handle, Duration::from_secs(5), signer)
132 }
133
134 pub fn with_timeout<F>(
145 stream: F,
146 stream_handle: BufDnsStreamHandle,
147 timeout_duration: Duration,
148 signer: Option<Arc<dyn MessageFinalizer>>,
149 ) -> DnsMultiplexerConnect<F, S>
150 where
151 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
152 {
153 DnsMultiplexerConnect {
154 stream,
155 stream_handle: Some(stream_handle),
156 timeout_duration,
157 signer,
158 }
159 }
160
161 fn drop_cancelled(&mut self, cx: &mut Context<'_>) {
164 let mut canceled = HashMap::<u16, ProtoError>::new();
165 for (&id, active_req) in &mut self.active_requests {
166 if active_req.is_canceled() {
167 canceled.insert(id, ProtoError::from("requestor canceled"));
168 }
169
170 match active_req.poll_timeout(cx) {
172 Poll::Ready(()) => {
173 debug!("request timed out: {}", id);
174 canceled.insert(id, ProtoError::from(ProtoErrorKind::Timeout));
175 }
176 Poll::Pending => (),
177 }
178 }
179
180 for (id, error) in canceled {
182 if let Some(active_request) = self.active_requests.remove(&id) {
183 active_request.complete_with_error(error);
185 }
186 }
187 }
188
189 fn next_random_query_id(&self) -> Result<u16, ProtoError> {
191 let mut rand = rand::rng();
192
193 for _ in 0..100 {
194 let id: u16 = rand.random(); if !self.active_requests.contains_key(&id) {
197 return Ok(id);
198 }
199 }
200
201 Err(ProtoError::from(
202 "id space exhausted, consider filing an issue",
203 ))
204 }
205
206 fn stream_closed_close_all(&mut self, error: ProtoError) {
208 debug!(error = error.as_dyn(), stream = %self.stream);
209
210 for (_, active_request) in self.active_requests.drain() {
211 active_request.complete_with_error(error.clone());
213 }
214 }
215}
216
217#[must_use = "futures do nothing unless polled"]
219pub struct DnsMultiplexerConnect<F, S>
220where
221 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
222 S: Stream<Item = Result<SerialMessage, ProtoError>> + Unpin,
223{
224 stream: F,
225 stream_handle: Option<BufDnsStreamHandle>,
226 timeout_duration: Duration,
227 signer: Option<Arc<dyn MessageFinalizer>>,
228}
229
230impl<F, S> Future for DnsMultiplexerConnect<F, S>
231where
232 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
233 S: DnsClientStream + Unpin + 'static,
234{
235 type Output = Result<DnsMultiplexer<S>, ProtoError>;
236
237 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
238 let stream: S = ready!(self.stream.poll_unpin(cx))?;
239
240 Poll::Ready(Ok(DnsMultiplexer {
241 stream,
242 timeout_duration: self.timeout_duration,
243 stream_handle: self
244 .stream_handle
245 .take()
246 .expect("must not poll after complete"),
247 active_requests: HashMap::new(),
248 signer: self.signer.clone(),
249 is_shutdown: false,
250 }))
251 }
252}
253
254impl<S> Display for DnsMultiplexer<S>
255where
256 S: DnsClientStream + 'static,
257{
258 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
259 write!(formatter, "{}", self.stream)
260 }
261}
262
263impl<S> DnsRequestSender for DnsMultiplexer<S>
264where
265 S: DnsClientStream + Unpin + 'static,
266{
267 fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
268 if self.is_shutdown {
269 panic!("can not send messages after stream is shutdown")
270 }
271
272 if self.active_requests.len() > CHANNEL_BUFFER_SIZE {
273 return ProtoError::from(ProtoErrorKind::Busy).into();
274 }
275
276 let query_id = match self.next_random_query_id() {
277 Ok(id) => id,
278 Err(e) => return e.into(),
279 };
280
281 let (mut request, _) = request.into_parts();
282 request.set_id(query_id);
283
284 let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
285 Ok(now) => now.as_secs(),
286 Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
287 };
288
289 let now = now as u32;
291
292 let mut verifier = None;
293 if let Some(signer) = &self.signer {
294 if signer.should_finalize_message(&request) {
295 match request.finalize(signer.borrow(), now) {
296 Ok(answer_verifier) => verifier = answer_verifier,
297 Err(e) => {
298 debug!("could not sign message: {}", e);
299 return e.into();
300 }
301 }
302 }
303 }
304
305 let timeout = S::Time::delay_for(self.timeout_duration);
307
308 let (complete, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
309
310 let active_request =
312 ActiveRequest::new(complete, request.id(), Box::new(timeout), verifier);
313
314 match request.to_vec() {
315 Ok(buffer) => {
316 debug!(id = %active_request.request_id(), "sending message");
317 let serial_message = SerialMessage::new(buffer, self.stream.name_server_addr());
318
319 debug!(
320 "final message: {}",
321 serial_message
322 .to_message()
323 .expect("bizarre we just made this message")
324 );
325
326 match self.stream_handle.send(serial_message) {
329 Ok(()) => self
330 .active_requests
331 .insert(active_request.request_id(), active_request),
332 Err(err) => return err.into(),
333 };
334 }
335 Err(e) => {
336 debug!(
337 id = %active_request.request_id(),
338 error = e.as_dyn(),
339 "error message"
340 );
341 return e.into();
343 }
344 }
345
346 receiver.into()
347 }
348
349 fn shutdown(&mut self) {
350 self.is_shutdown = true;
351 }
352
353 fn is_shutdown(&self) -> bool {
354 self.is_shutdown
355 }
356}
357
358impl<S> Stream for DnsMultiplexer<S>
359where
360 S: DnsClientStream + Unpin + 'static,
361{
362 type Item = Result<(), ProtoError>;
363
364 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
365 self.drop_cancelled(cx);
367
368 if self.is_shutdown && self.active_requests.is_empty() {
369 debug!("stream is done: {}", self);
370 return Poll::Ready(None);
371 }
372
373 let mut messages_received = 0;
377 for i in 0..QOS_MAX_RECEIVE_MSGS {
378 match self.stream.poll_next_unpin(cx) {
379 Poll::Ready(Some(Ok(buffer))) => {
380 messages_received = i;
381
382 match DnsResponse::from_buffer(buffer.into_parts().0) {
384 Ok(response) => match self.active_requests.entry(response.id()) {
385 Entry::Occupied(mut request_entry) => {
386 let active_request = request_entry.get_mut();
388 if let Some(verifier) = &mut active_request.verifier {
389 ignore_send(
390 active_request
391 .completion
392 .try_send(verifier(response.as_buffer())),
393 );
394 } else {
395 ignore_send(active_request.completion.try_send(Ok(response)));
396 }
397 }
398 Entry::Vacant(..) => debug!("unexpected request_id: {}", response.id()),
399 },
400 Err(error) => debug!(error = error.as_dyn(), "error decoding message"),
402 }
403 }
404 Poll::Ready(err) => {
405 let err = match err {
406 Some(Err(e)) => e,
407 None => ProtoError::from("stream closed"),
408 _ => unreachable!(),
409 };
410
411 self.stream_closed_close_all(err);
412 self.is_shutdown = true;
413 return Poll::Ready(None);
414 }
415 Poll::Pending => break,
416 }
417 }
418
419 if messages_received == QOS_MAX_RECEIVE_MSGS {
423 cx.waker().wake_by_ref();
425 }
426
427 Poll::Pending
429 }
430}
431
#[cfg(test)]
mod test {
    use alloc::vec::Vec;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

    use futures_util::future;
    use futures_util::stream::TryStreamExt;
    use test_support::subscribe;

    use super::*;
    use crate::op::op_code::OpCode;
    use crate::op::{Message, MessageType, Query};
    use crate::rr::record_type::RecordType;
    use crate::rr::{DNSClass, Name, RData, Record};
    use crate::serialize::binary::BinEncodable;
    use crate::xfer::StreamReceiver;
    use crate::xfer::{DnsClientStream, DnsRequestOptions};

    /// Scripted client stream: once the first outbound request has been seen on `receiver`,
    /// it replays the canned `messages`, each stamped with that request's id.
    struct MockClientStream {
        /// Canned replies, stored reversed so they can be popped in their original order.
        messages: Vec<Message>,
        /// Address reported as the name server and attached to every reply.
        addr: SocketAddr,
        /// Id of the first observed request; `None` until one has been received.
        id: Option<u16>,
        /// Receiving end of the multiplexer's send handle; set after construction.
        receiver: Option<StreamReceiver>,
    }

    impl MockClientStream {
        /// Returns an already-completed connect future wrapping the mock stream.
        fn new(
            mut messages: Vec<Message>,
            addr: SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<Self, ProtoError>> + Send>> {
            // Replies are popped from the back, so flip them to keep the caller's order.
            messages.reverse();

            let mock = Self {
                messages,
                addr,
                id: None,
                receiver: None,
            };
            Box::pin(future::ok(mock))
        }
    }

    impl fmt::Display for MockClientStream {
        fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            formatter.write_str("TestClientStream")
        }
    }

    impl Stream for MockClientStream {
        type Item = Result<SerialMessage, ProtoError>;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Learn the request id from the first outbound message before replying.
            let id = match self.id {
                Some(id) => id,
                None => {
                    let receiver = self
                        .receiver
                        .as_mut()
                        .expect("should only be polled after receiver has been set");
                    let serial = ready!(receiver.poll_next_unpin(cx));
                    let id = serial.unwrap().to_message().unwrap().id();
                    self.id = Some(id);
                    id
                }
            };

            match self.messages.pop() {
                Some(mut message) => {
                    message.set_id(id);
                    let bytes = message.to_bytes().unwrap();
                    Poll::Ready(Some(Ok(SerialMessage::new(bytes, self.addr))))
                }
                // Out of canned replies: stay pending instead of ending the stream.
                None => Poll::Pending,
            }
        }
    }

    impl DnsClientStream for MockClientStream {
        type Time = crate::runtime::TokioTime;

        fn name_server_addr(&self) -> SocketAddr {
            self.addr
        }
    }

    /// Builds a multiplexer over a [`MockClientStream`] that replays `mock_response`, using a
    /// 100 ms request timeout and no signer.
    async fn get_mocked_multiplexer(
        mock_response: Vec<Message>,
    ) -> DnsMultiplexer<MockClientStream> {
        let addr = SocketAddr::from(([127, 0, 0, 1], 1234));
        let (handler, receiver) = BufDnsStreamHandle::new(addr);

        let connect = DnsMultiplexer::with_timeout(
            MockClientStream::new(mock_response, addr),
            handler,
            Duration::from_millis(100),
            None,
        );
        let mut multiplexer = connect.await.unwrap();

        // Let the mock observe what the multiplexer sends.
        multiplexer.stream.receiver = Some(receiver);
        multiplexer
    }

    /// Builds a recursion-desired standard query for `name` / `record_type` in class IN.
    fn standard_query(name: Name, record_type: RecordType) -> Message {
        let mut question = Query::query(name, record_type);
        question.set_query_class(DNSClass::IN);

        let mut message = Message::new();
        message.add_query(question);
        message.set_message_type(MessageType::Query);
        message.set_op_code(OpCode::Query);
        message.set_recursion_desired(true);
        message
    }

    /// Creates a class-IN record owned by `name`.
    fn in_record(name: Name, ttl: u32, rdata: RData) -> Record {
        let mut record = Record::from_rdata(name, ttl, rdata);
        record.set_dns_class(DNSClass::IN);
        record
    }

    /// Copies `query`, marks the copy as a response and inserts `answers` into it.
    fn response_to(query: &Message, answers: Vec<Record>) -> Message {
        let mut response = query.clone();
        response.set_message_type(MessageType::Response);
        response.insert_answers(answers);
        response
    }

    /// Sends `query` through a multiplexer replaying `answers` and collects every response
    /// delivered on the request's response stream.
    async fn exchange(query: DnsRequest, answers: Vec<Message>) -> Vec<DnsResponse> {
        let mut multiplexer = get_mocked_multiplexer(answers).await;
        let responses = multiplexer.send_message(query);

        tokio::select! {
            _ = multiplexer.next() => panic!("should never end"),
            collected = responses.try_collect::<Vec<_>>() => collected.unwrap(),
        }
    }

    /// An A query for `www.example.com.` together with its single-message answer.
    fn a_query_answer() -> (DnsRequest, Vec<Message>) {
        let name = Name::from_ascii("www.example.com.").unwrap();
        let query = standard_query(name.clone(), RecordType::A);

        let mut answer = query.clone();
        answer.set_message_type(MessageType::Response);
        answer.add_answer(in_record(
            name,
            86400,
            RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
        ));

        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![answer],
        )
    }

    /// An AXFR query for `example.com.`.
    fn axfr_query() -> Message {
        standard_query(Name::from_ascii("example.com.").unwrap(), RecordType::AXFR)
    }

    /// The records of the transferred zone, opening and closing with the same SOA record.
    fn axfr_response() -> Vec<Record> {
        use crate::rr::rdata::*;

        let origin = Name::from_ascii("example.com.").unwrap();
        let name_server = |host: &str| RData::NS(NS(Name::parse(host, None).unwrap()));

        let soa = in_record(
            origin.clone(),
            3600,
            RData::SOA(SOA::new(
                Name::parse("sns.dns.icann.org.", None).unwrap(),
                Name::parse("noc.dns.icann.org.", None).unwrap(),
                2015082403,
                7200,
                3600,
                1209600,
                3600,
            )),
        );
        let ns_a = in_record(origin.clone(), 86400, name_server("a.iana-servers.net."));
        let ns_b = in_record(origin.clone(), 86400, name_server("b.iana-servers.net."));
        let a = in_record(
            origin.clone(),
            86400,
            RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
        );
        let aaaa = in_record(
            origin,
            86400,
            RData::AAAA(
                Ipv6Addr::new(
                    0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c,
                )
                .into(),
            ),
        );

        vec![soa.clone(), ns_a, ns_b, a, aaaa, soa]
    }

    /// The AXFR query with the whole zone returned in a single response message.
    fn axfr_query_answer() -> (DnsRequest, Vec<Message>) {
        let query = axfr_query();
        let answers = vec![response_to(&query, axfr_response())];

        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            answers,
        )
    }

    /// The AXFR query with the zone split across two response messages: the first three
    /// records, then the rest.
    fn axfr_query_answer_multi() -> (DnsRequest, Vec<Message>) {
        let query = axfr_query();

        let mut head = axfr_response();
        let tail = head.split_off(3);
        let answers = vec![response_to(&query, head), response_to(&query, tail)];

        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            answers,
        )
    }

    #[tokio::test]
    async fn test_multiplexer_a() {
        subscribe();
        let (query, answer) = a_query_answer();

        let responses = exchange(query, answer).await;

        assert_eq!(responses.len(), 1);
    }

    #[tokio::test]
    async fn test_multiplexer_axfr() {
        subscribe();
        let (query, answer) = axfr_query_answer();

        let responses = exchange(query, answer).await;

        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].answers().len(), axfr_response().len());
    }

    #[tokio::test]
    async fn test_multiplexer_axfr_multi() {
        subscribe();
        let (query, answer) = axfr_query_answer_multi();

        let responses = exchange(query, answer).await;

        assert_eq!(responses.len(), 2);
        let total_answers: usize = responses.iter().map(|m| m.answers().len()).sum();
        assert_eq!(total_answers, axfr_response().len());
    }
}