1use std::{
11 borrow::Borrow,
12 collections::{hash_map::Entry, HashMap},
13 fmt::{self, Display},
14 marker::Unpin,
15 pin::Pin,
16 sync::Arc,
17 task::{Context, Poll},
18 time::{Duration, SystemTime, UNIX_EPOCH},
19};
20
21use futures_channel::mpsc;
22use futures_util::{
23 future::Future,
24 ready,
25 stream::{Stream, StreamExt},
26 FutureExt,
27};
28use rand::{
29 self,
30 distributions::{Distribution, Standard},
31};
32use tracing::debug;
33
34use crate::{
35 error::{ProtoError, ProtoErrorKind},
36 op::{MessageFinalizer, MessageVerifier},
37 xfer::{
38 ignore_send, BufDnsStreamHandle, DnsClientStream, DnsRequest, DnsRequestSender,
39 DnsResponse, DnsResponseStream, SerialMessage, CHANNEL_BUFFER_SIZE,
40 },
41 DnsStreamHandle, Time,
42};
43
44const QOS_MAX_RECEIVE_MSGS: usize = 100; struct ActiveRequest {
47 completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
49 request_id: u16,
50 timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
51 verifier: Option<MessageVerifier>,
52}
53
54impl ActiveRequest {
55 fn new(
56 completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
57 request_id: u16,
58 timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
59 verifier: Option<MessageVerifier>,
60 ) -> Self {
61 Self {
62 completion,
63 request_id,
64 timeout,
66 verifier,
67 }
68 }
69
70 fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Poll<()> {
72 self.timeout.poll_unpin(cx)
73 }
74
75 fn is_canceled(&self) -> bool {
77 self.completion.is_closed()
78 }
79
80 fn request_id(&self) -> u16 {
82 self.request_id
83 }
84
85 fn complete_with_error(mut self, error: ProtoError) {
87 ignore_send(self.completion.try_send(Err(error)));
88 }
89}
90
91#[must_use = "futures do nothing unless polled"]
97pub struct DnsMultiplexer<S, MF>
98where
99 S: DnsClientStream + 'static,
100 MF: MessageFinalizer,
101{
102 stream: S,
103 timeout_duration: Duration,
104 stream_handle: BufDnsStreamHandle,
105 active_requests: HashMap<u16, ActiveRequest>,
106 signer: Option<Arc<MF>>,
107 is_shutdown: bool,
108}
109
110impl<S, MF> DnsMultiplexer<S, MF>
111where
112 S: DnsClientStream + Unpin + 'static,
113 MF: MessageFinalizer,
114{
115 #[allow(clippy::new_ret_no_self)]
124 pub fn new<F>(
125 stream: F,
126 stream_handle: BufDnsStreamHandle,
127 signer: Option<Arc<MF>>,
128 ) -> DnsMultiplexerConnect<F, S, MF>
129 where
130 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
131 {
132 Self::with_timeout(stream, stream_handle, Duration::from_secs(5), signer)
133 }
134
135 pub fn with_timeout<F>(
146 stream: F,
147 stream_handle: BufDnsStreamHandle,
148 timeout_duration: Duration,
149 signer: Option<Arc<MF>>,
150 ) -> DnsMultiplexerConnect<F, S, MF>
151 where
152 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
153 {
154 DnsMultiplexerConnect {
155 stream,
156 stream_handle: Some(stream_handle),
157 timeout_duration,
158 signer,
159 }
160 }
161
162 fn drop_cancelled(&mut self, cx: &mut Context<'_>) {
165 let mut canceled = HashMap::<u16, ProtoError>::new();
166 for (&id, ref mut active_req) in &mut self.active_requests {
167 if active_req.is_canceled() {
168 canceled.insert(id, ProtoError::from("requestor canceled"));
169 }
170
171 match active_req.poll_timeout(cx) {
173 Poll::Ready(()) => {
174 debug!("request timed out: {}", id);
175 canceled.insert(id, ProtoError::from(ProtoErrorKind::Timeout));
176 }
177 Poll::Pending => (),
178 }
179 }
180
181 for (id, error) in canceled {
183 if let Some(active_request) = self.active_requests.remove(&id) {
184 active_request.complete_with_error(error);
186 }
187 }
188 }
189
190 fn next_random_query_id(&self) -> Result<u16, ProtoError> {
192 let mut rand = rand::thread_rng();
193
194 for _ in 0..100 {
195 let id: u16 = Standard.sample(&mut rand); if !self.active_requests.contains_key(&id) {
198 return Ok(id);
199 }
200 }
201
202 Err(ProtoError::from(
203 "id space exhausted, consider filing an issue",
204 ))
205 }
206
207 fn stream_closed_close_all(&mut self, error: ProtoError) {
209 debug!(error = error.as_dyn(), stream = %self.stream);
210
211 for (_, active_request) in self.active_requests.drain() {
212 active_request.complete_with_error(error.clone());
214 }
215 }
216}
217
218#[must_use = "futures do nothing unless polled"]
220pub struct DnsMultiplexerConnect<F, S, MF>
221where
222 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
223 S: Stream<Item = Result<SerialMessage, ProtoError>> + Unpin,
224 MF: MessageFinalizer + Send + Sync + 'static,
225{
226 stream: F,
227 stream_handle: Option<BufDnsStreamHandle>,
228 timeout_duration: Duration,
229 signer: Option<Arc<MF>>,
230}
231
232impl<F, S, MF> Future for DnsMultiplexerConnect<F, S, MF>
233where
234 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
235 S: DnsClientStream + Unpin + 'static,
236 MF: MessageFinalizer + Send + Sync + 'static,
237{
238 type Output = Result<DnsMultiplexer<S, MF>, ProtoError>;
239
240 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
241 let stream: S = ready!(self.stream.poll_unpin(cx))?;
242
243 Poll::Ready(Ok(DnsMultiplexer {
244 stream,
245 timeout_duration: self.timeout_duration,
246 stream_handle: self
247 .stream_handle
248 .take()
249 .expect("must not poll after complete"),
250 active_requests: HashMap::new(),
251 signer: self.signer.clone(),
252 is_shutdown: false,
253 }))
254 }
255}
256
257impl<S, MF> Display for DnsMultiplexer<S, MF>
258where
259 S: DnsClientStream + 'static,
260 MF: MessageFinalizer + Send + Sync + 'static,
261{
262 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
263 write!(formatter, "{}", self.stream)
264 }
265}
266
267impl<S, MF> DnsRequestSender for DnsMultiplexer<S, MF>
268where
269 S: DnsClientStream + Unpin + 'static,
270 MF: MessageFinalizer + Send + Sync + 'static,
271{
272 fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
273 if self.is_shutdown {
274 panic!("can not send messages after stream is shutdown")
275 }
276
277 if self.active_requests.len() > CHANNEL_BUFFER_SIZE {
278 return ProtoError::from(ProtoErrorKind::Busy).into();
279 }
280
281 let query_id = match self.next_random_query_id() {
282 Ok(id) => id,
283 Err(e) => return e.into(),
284 };
285
286 let (mut request, _) = request.into_parts();
287 request.set_id(query_id);
288
289 let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
290 Ok(now) => now.as_secs(),
291 Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
292 };
293
294 let now = now as u32;
296
297 let mut verifier = None;
298 if let Some(ref signer) = self.signer {
299 if signer.should_finalize_message(&request) {
300 match request.finalize::<MF>(signer.borrow(), now) {
301 Ok(answer_verifier) => verifier = answer_verifier,
302 Err(e) => {
303 debug!("could not sign message: {}", e);
304 return e.into();
305 }
306 }
307 }
308 }
309
310 let timeout = S::Time::delay_for(self.timeout_duration);
312
313 let (complete, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
314
315 let active_request =
317 ActiveRequest::new(complete, request.id(), Box::new(timeout), verifier);
318
319 match request.to_vec() {
320 Ok(buffer) => {
321 debug!(id = %active_request.request_id(), "sending message");
322 let serial_message = SerialMessage::new(buffer, self.stream.name_server_addr());
323
324 debug!(
325 "final message: {}",
326 serial_message
327 .to_message()
328 .expect("bizarre we just made this message")
329 );
330
331 match self.stream_handle.send(serial_message) {
334 Ok(()) => self
335 .active_requests
336 .insert(active_request.request_id(), active_request),
337 Err(err) => return err.into(),
338 };
339 }
340 Err(e) => {
341 debug!(
342 id = %active_request.request_id(),
343 error = e.as_dyn(),
344 "error message"
345 );
346 return e.into();
348 }
349 }
350
351 receiver.into()
352 }
353
354 fn shutdown(&mut self) {
355 self.is_shutdown = true;
356 }
357
358 fn is_shutdown(&self) -> bool {
359 self.is_shutdown
360 }
361}
362
363impl<S, MF> Stream for DnsMultiplexer<S, MF>
364where
365 S: DnsClientStream + Unpin + 'static,
366 MF: MessageFinalizer + Send + Sync + 'static,
367{
368 type Item = Result<(), ProtoError>;
369
370 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
371 self.drop_cancelled(cx);
373
374 if self.is_shutdown && self.active_requests.is_empty() {
375 debug!("stream is done: {}", self);
376 return Poll::Ready(None);
377 }
378
379 let mut messages_received = 0;
383 for i in 0..QOS_MAX_RECEIVE_MSGS {
384 match self.stream.poll_next_unpin(cx) {
385 Poll::Ready(Some(Ok(buffer))) => {
386 messages_received = i;
387
388 match buffer.to_message() {
390 Ok(message) => match self.active_requests.entry(message.id()) {
391 Entry::Occupied(mut request_entry) => {
392 let active_request = request_entry.get_mut();
394 if let Some(ref mut verifier) = active_request.verifier {
395 ignore_send(
396 active_request
397 .completion
398 .try_send(verifier(buffer.bytes())),
399 );
400 } else {
401 ignore_send(active_request.completion.try_send(Ok(
402 DnsResponse::new(message, buffer.into_parts().0),
403 )));
404 }
405 }
406 Entry::Vacant(..) => debug!("unexpected request_id: {}", message.id()),
407 },
408 Err(error) => debug!(error = error.as_dyn(), "error decoding message"),
410 }
411 }
412 Poll::Ready(err) => {
413 let err = match err {
414 Some(Err(e)) => e,
415 None => ProtoError::from("stream closed"),
416 _ => unreachable!(),
417 };
418
419 self.stream_closed_close_all(err);
420 self.is_shutdown = true;
421 return Poll::Ready(None);
422 }
423 Poll::Pending => break,
424 }
425 }
426
427 if messages_received == QOS_MAX_RECEIVE_MSGS {
431 cx.waker().wake_by_ref();
433 }
434
435 Poll::Pending
437 }
438}
439
#[cfg(test)]
mod test {
    use super::*;
    use crate::op::message::NoopMessageFinalizer;
    use crate::op::op_code::OpCode;
    use crate::op::{Message, MessageType, Query};
    use crate::rr::record_type::RecordType;
    use crate::rr::{DNSClass, Name, RData, Record};
    use crate::serialize::binary::BinEncodable;
    use crate::xfer::StreamReceiver;
    use crate::xfer::{DnsClientStream, DnsRequestOptions};
    use futures_util::future;
    use futures_util::stream::TryStreamExt;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

    /// Fake transport.
    ///
    /// It learns the id of the first outgoing query from `receiver`, then
    /// replays canned responses stamped with that id.
    struct MockClientStream {
        /// Canned responses, stored reversed so `pop` yields them in order.
        messages: Vec<Message>,
        addr: SocketAddr,
        /// Id of the query observed on `receiver`, once known.
        id: Option<u16>,
        /// Receiving end of the multiplexer's outbound handle; set after
        /// construction.
        receiver: Option<StreamReceiver>,
    }

    impl MockClientStream {
        fn new(
            messages: Vec<Message>,
            addr: SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<Self, ProtoError>> + Send>> {
            // Reverse so `Vec::pop` hands the responses out in original order.
            let messages: Vec<Message> = messages.into_iter().rev().collect();
            Box::pin(future::ok(Self {
                messages,
                addr,
                id: None,
                receiver: None,
            }))
        }
    }

    impl fmt::Display for MockClientStream {
        fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
            formatter.write_str("TestClientStream")
        }
    }

    impl Stream for MockClientStream {
        type Item = Result<SerialMessage, ProtoError>;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let id = match self.id {
                Some(known) => known,
                None => {
                    let outbound = self
                        .receiver
                        .as_mut()
                        .expect("should only be polled after receiver has been set");
                    let serial = ready!(outbound.poll_next_unpin(cx));
                    let observed = serial.unwrap().to_message().unwrap().id();
                    self.id = Some(observed);
                    observed
                }
            };

            match self.messages.pop() {
                Some(mut message) => {
                    message.set_id(id);
                    let bytes = message.to_bytes().unwrap();
                    Poll::Ready(Some(Ok(SerialMessage::new(bytes, self.addr))))
                }
                None => Poll::Pending,
            }
        }
    }

    impl DnsClientStream for MockClientStream {
        type Time = crate::TokioTime;

        fn name_server_addr(&self) -> SocketAddr {
            self.addr
        }
    }

    /// Builds a multiplexer (100ms timeout) over a mock that will answer with
    /// `mock_response`.
    async fn get_mocked_multiplexer(
        mock_response: Vec<Message>,
    ) -> DnsMultiplexer<MockClientStream, NoopMessageFinalizer> {
        let addr = SocketAddr::from(([127, 0, 0, 1], 1234));
        let connecting = MockClientStream::new(mock_response, addr);
        let (handle, outbound) = BufDnsStreamHandle::new(addr);
        let mut multiplexer =
            DnsMultiplexer::with_timeout(connecting, handle, Duration::from_millis(100), None)
                .await
                .unwrap();

        // Give the mock the outbound side so it can discover the query id.
        multiplexer.stream.receiver = Some(outbound);
        multiplexer
    }

    /// Sends `query` through a mocked multiplexer and collects every response
    /// until the response stream ends.
    async fn exchange(query: DnsRequest, answer: Vec<Message>) -> Vec<DnsResponse> {
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let responses = multiplexer.send_message(query);
        tokio::select! {
            _ = multiplexer.next() => {
                panic!("should never end")
            },
            collected = responses.try_collect::<Vec<_>>() => collected.unwrap(),
        }
    }

    fn a_query_answer() -> (DnsRequest, Vec<Message>) {
        let name = Name::from_ascii("www.example.com").unwrap();

        let mut question = Query::query(name.clone(), RecordType::A);
        question.set_query_class(DNSClass::IN);

        let mut msg = Message::new();
        msg.add_query(question)
            .set_message_type(MessageType::Query)
            .set_op_code(OpCode::Query)
            .set_recursion_desired(true);

        let request = msg.clone();

        let mut answer = Record::new();
        answer
            .set_name(name)
            .set_ttl(86400)
            .set_rr_type(RecordType::A)
            .set_dns_class(DNSClass::IN)
            .set_data(Some(RData::A(Ipv4Addr::new(93, 184, 215, 14).into())));
        msg.set_message_type(MessageType::Response).add_answer(answer);

        (
            DnsRequest::new(request, DnsRequestOptions::default()),
            vec![msg],
        )
    }

    fn axfr_query() -> Message {
        let mut question = Query::query(Name::from_ascii("example.com").unwrap(), RecordType::AXFR);
        question.set_query_class(DNSClass::IN);

        let mut msg = Message::new();
        msg.add_query(question)
            .set_message_type(MessageType::Query)
            .set_op_code(OpCode::Query)
            .set_recursion_desired(true);
        msg
    }

    fn axfr_response() -> Vec<Record> {
        use crate::rr::rdata::*;

        let origin = Name::from_ascii("example.com").unwrap();

        // Every record in the zone is class IN at the origin name.
        let make_record = |ttl: u32, rr_type: RecordType, data: RData| -> Record {
            let mut rr = Record::new();
            rr.set_name(origin.clone())
                .set_ttl(ttl)
                .set_rr_type(rr_type)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(data));
            rr
        };
        let ns_data = |host: &str| RData::NS(NS(Name::parse(host, None).unwrap()));

        let soa = make_record(
            3600,
            RecordType::SOA,
            RData::SOA(SOA::new(
                Name::parse("sns.dns.icann.org.", None).unwrap(),
                Name::parse("noc.dns.icann.org.", None).unwrap(),
                2015082403,
                7200,
                3600,
                1209600,
                3600,
            )),
        );

        vec![
            soa.clone(),
            make_record(86400, RecordType::NS, ns_data("a.iana-servers.net.")),
            make_record(86400, RecordType::NS, ns_data("b.iana-servers.net.")),
            make_record(
                86400,
                RecordType::A,
                RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
            ),
            make_record(
                86400,
                RecordType::AAAA,
                RData::AAAA(
                    Ipv6Addr::new(
                        0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c,
                    )
                    .into(),
                ),
            ),
            soa,
        ]
    }

    fn axfr_query_answer() -> (DnsRequest, Vec<Message>) {
        let request = axfr_query();

        let mut response = request.clone();
        response
            .set_message_type(MessageType::Response)
            .insert_answers(axfr_response());

        (
            DnsRequest::new(request, DnsRequestOptions::default()),
            vec![response],
        )
    }

    fn axfr_query_answer_multi() -> (DnsRequest, Vec<Message>) {
        let request = axfr_query();

        let mut first_half = axfr_response();
        let second_half = first_half.split_off(3);

        let response_for = |answers: Vec<Record>| {
            let mut response = request.clone();
            response
                .set_message_type(MessageType::Response)
                .insert_answers(answers);
            response
        };
        let responses = vec![response_for(first_half), response_for(second_half)];

        (
            DnsRequest::new(request, DnsRequestOptions::default()),
            responses,
        )
    }

    #[tokio::test]
    async fn test_multiplexer_a() {
        let (query, answer) = a_query_answer();
        let responses = exchange(query, answer).await;
        assert_eq!(responses.len(), 1);
    }

    #[tokio::test]
    async fn test_multiplexer_axfr() {
        let (query, answer) = axfr_query_answer();
        let responses = exchange(query, answer).await;
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].answers().len(), axfr_response().len());
    }

    #[tokio::test]
    async fn test_multiplexer_axfr_multi() {
        let (query, answer) = axfr_query_answer_multi();
        let responses = exchange(query, answer).await;
        assert_eq!(responses.len(), 2);

        let total_answers: usize = responses.iter().map(|m| m.answers().len()).sum();
        assert_eq!(total_answers, axfr_response().len());
    }
}