use std::{
borrow::Borrow,
collections::{hash_map::Entry, HashMap},
fmt::{self, Display},
marker::Unpin,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use futures_channel::mpsc;
use futures_util::{
future::Future,
ready,
stream::{Stream, StreamExt},
FutureExt,
};
use rand::{
self,
distributions::{Distribution, Standard},
};
use tracing::debug;
use crate::{
error::{ProtoError, ProtoErrorKind},
op::{MessageFinalizer, MessageVerifier},
xfer::{
ignore_send, BufDnsStreamHandle, DnsClientStream, DnsRequest, DnsRequestSender,
DnsResponse, DnsResponseStream, SerialMessage, CHANNEL_BUFFER_SIZE,
},
DnsStreamHandle, Time,
};
const QOS_MAX_RECEIVE_MSGS: usize = 100; struct ActiveRequest {
completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
request_id: u16,
timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
verifier: Option<MessageVerifier>,
}
impl ActiveRequest {
fn new(
completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
request_id: u16,
timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
verifier: Option<MessageVerifier>,
) -> Self {
Self {
completion,
request_id,
timeout,
verifier,
}
}
fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Poll<()> {
self.timeout.poll_unpin(cx)
}
fn is_canceled(&self) -> bool {
self.completion.is_closed()
}
fn request_id(&self) -> u16 {
self.request_id
}
fn complete_with_error(mut self, error: ProtoError) {
ignore_send(self.completion.try_send(Err(error)));
}
}
/// Multiplexes many concurrent DNS requests over one underlying client stream.
///
/// Outbound queries are written through `stream_handle`; inbound messages are read
/// from `stream` and routed back to the waiting request by DNS message id.
#[must_use = "futures do nothing unless polled"]
pub struct DnsMultiplexer<S: DnsClientStream + 'static, MF: MessageFinalizer> {
    /// Source of inbound (response) messages.
    stream: S,
    /// How long each request may stay outstanding before it is failed with a timeout.
    timeout_duration: Duration,
    /// Sink through which serialized requests are sent.
    stream_handle: BufDnsStreamHandle,
    /// Outstanding requests keyed by their DNS message id.
    active_requests: HashMap<u16, ActiveRequest>,
    /// Optional finalizer (e.g. signer) applied to outgoing messages.
    signer: Option<Arc<MF>>,
    /// Set once no further requests may be sent.
    is_shutdown: bool,
}
impl<S, MF> DnsMultiplexer<S, MF>
where
    S: DnsClientStream + Unpin + 'static,
    MF: MessageFinalizer,
{
    /// Starts building a multiplexer with the default request timeout of 5 seconds.
    ///
    /// * `stream` - future resolving to the connected client stream
    /// * `stream_handle` - handle used to write outbound requests
    /// * `signer` - optional finalizer applied to outgoing messages
    ///
    /// Returns a future that yields the multiplexer once `stream` resolves.
    // `new` intentionally returns a connecting future rather than `Self`.
    #[allow(clippy::new_ret_no_self)]
    pub fn new<F>(
        stream: F,
        stream_handle: BufDnsStreamHandle,
        signer: Option<Arc<MF>>,
    ) -> DnsMultiplexerConnect<F, S, MF>
    where
        F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
    {
        let default_timeout = Duration::from_secs(5);
        Self::with_timeout(stream, stream_handle, default_timeout, signer)
    }

    /// Like [`DnsMultiplexer::new`], but with an explicit per-request `timeout_duration`.
    pub fn with_timeout<F>(
        stream: F,
        stream_handle: BufDnsStreamHandle,
        timeout_duration: Duration,
        signer: Option<Arc<MF>>,
    ) -> DnsMultiplexerConnect<F, S, MF>
    where
        F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
    {
        let stream_handle = Some(stream_handle);
        DnsMultiplexerConnect {
            stream,
            stream_handle,
            timeout_duration,
            signer,
        }
    }

    /// Removes every request whose requester has gone away or whose timeout has
    /// elapsed, completing each one with the matching error.
    ///
    /// Every request's timeout is polled with `cx`, so the task is woken when any of
    /// them fires.
    fn drop_cancelled(&mut self, cx: &mut Context<'_>) {
        let mut expired: HashMap<u16, ProtoError> = HashMap::new();

        for (id, request) in self.active_requests.iter_mut() {
            let id = *id;
            if request.is_canceled() {
                expired.insert(id, ProtoError::from("requestor canceled"));
            }
            // Timeouts are polled even for canceled requests; a timeout error then
            // replaces the cancellation error for the same id.
            if request.poll_timeout(cx).is_ready() {
                debug!("request timed out: {}", id);
                expired.insert(id, ProtoError::from(ProtoErrorKind::Timeout));
            }
        }

        for (id, error) in expired {
            if let Some(request) = self.active_requests.remove(&id) {
                request.complete_with_error(error);
            }
        }
    }

    /// Picks a random `u16` message id that is not currently in use, giving up after
    /// 100 colliding draws.
    fn next_random_query_id(&self) -> Result<u16, ProtoError> {
        let mut rng = rand::thread_rng();
        // Samples over the full u16 range; `find` stops at the first free id.
        (0..100)
            .map(|_| -> u16 { Standard.sample(&mut rng) })
            .find(|candidate| !self.active_requests.contains_key(candidate))
            .ok_or_else(|| ProtoError::from("id space exhausted, consider filing an issue"))
    }

    /// Fails every outstanding request with a clone of `error` and empties the table.
    fn stream_closed_close_all(&mut self, error: ProtoError) {
        debug!(error = error.as_dyn(), stream = %self.stream);
        self.active_requests
            .drain()
            .for_each(|(_, request)| request.complete_with_error(error.clone()));
    }
}
/// Future that resolves to a [`DnsMultiplexer`] once the underlying stream connects.
///
/// Produced by `DnsMultiplexer::new` and `DnsMultiplexer::with_timeout`.
#[must_use = "futures do nothing unless polled"]
pub struct DnsMultiplexerConnect<
    F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
    S: Stream<Item = Result<SerialMessage, ProtoError>> + Unpin,
    MF: MessageFinalizer + Send + Sync + 'static,
> {
    /// Future yielding the connected client stream.
    stream: F,
    /// Outbound handle; taken (left `None`) when the multiplexer is built.
    stream_handle: Option<BufDnsStreamHandle>,
    /// Per-request timeout handed to the resulting multiplexer.
    timeout_duration: Duration,
    /// Optional message finalizer handed to the resulting multiplexer.
    signer: Option<Arc<MF>>,
}
impl<F, S, MF> Future for DnsMultiplexerConnect<F, S, MF>
where
    F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
    S: DnsClientStream + Unpin + 'static,
    MF: MessageFinalizer + Send + Sync + 'static,
{
    type Output = Result<DnsMultiplexer<S, MF>, ProtoError>;

    /// Waits for the stream future, then assembles a multiplexer with no active
    /// requests. A connection error is propagated unchanged.
    ///
    /// # Panics
    /// Panics if it completes successfully a second time, because the stream handle
    /// has already been taken.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let stream = match self.stream.poll_unpin(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
            Poll::Ready(Ok(stream)) => stream,
        };

        let stream_handle = self
            .stream_handle
            .take()
            .expect("must not poll after complete");

        Poll::Ready(Ok(DnsMultiplexer {
            stream,
            timeout_duration: self.timeout_duration,
            stream_handle,
            active_requests: HashMap::new(),
            signer: self.signer.clone(),
            is_shutdown: false,
        }))
    }
}
impl<S, MF> Display for DnsMultiplexer<S, MF>
where
    S: DnsClientStream + 'static,
    MF: MessageFinalizer + Send + Sync + 'static,
{
    /// Displays the multiplexer as its underlying stream's display text.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        write!(formatter, "{name}", name = self.stream)
    }
}
impl<S, MF> DnsRequestSender for DnsMultiplexer<S, MF>
where
    S: DnsClientStream + Unpin + 'static,
    MF: MessageFinalizer + Send + Sync + 'static,
{
    /// Assigns a fresh random id to `request`, finalizes (signs) it when the signer
    /// asks for that, sends it, and registers it so responses can be routed back.
    ///
    /// Failures (too many requests in flight, id exhaustion, a clock before the Unix
    /// epoch, signing, serialization or send errors) are returned as an
    /// already-failed response stream rather than as a `Result`.
    ///
    /// # Panics
    /// Panics if called after `shutdown`.
    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
        assert!(!self.is_shutdown, "can not send messages after stream is shutdown");

        // Back-pressure: refuse new work once too many requests are outstanding.
        if self.active_requests.len() > CHANNEL_BUFFER_SIZE {
            return ProtoError::from(ProtoErrorKind::Busy).into();
        }

        let query_id = match self.next_random_query_id() {
            Ok(id) => id,
            Err(e) => return e.into(),
        };

        let (mut message, _) = request.into_parts();
        message.set_id(query_id);

        // Seconds since the epoch, truncated to 32 bits for the finalizer.
        let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs() as u32,
            Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
        };

        let verifier = match self.signer.as_ref() {
            Some(signer) if signer.should_finalize_message(&message) => {
                match message.finalize::<MF>(signer.borrow(), now) {
                    Ok(answer_verifier) => answer_verifier,
                    Err(e) => {
                        debug!("could not sign message: {}", e);
                        return e.into();
                    }
                }
            }
            _ => None,
        };

        let timeout = Box::new(S::Time::delay_for(self.timeout_duration));
        let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
        let active_request = ActiveRequest::new(sender, message.id(), timeout, verifier);

        let buffer = match message.to_vec() {
            Ok(buffer) => buffer,
            Err(e) => {
                debug!(
                    id = %active_request.request_id(),
                    error = e.as_dyn(),
                    "error message"
                );
                return e.into();
            }
        };

        debug!(id = %active_request.request_id(), "sending message");
        let serial_message = SerialMessage::new(buffer, self.stream.name_server_addr());
        debug!(
            "final message: {}",
            serial_message
                .to_message()
                .expect("bizarre we just made this message")
        );

        // Only track the request once it has actually been handed to the stream.
        if let Err(err) = self.stream_handle.send(serial_message) {
            return err.into();
        }
        self.active_requests
            .insert(active_request.request_id(), active_request);

        receiver.into()
    }

    /// Stops accepting new requests; outstanding ones are still driven to completion.
    fn shutdown(&mut self) {
        self.is_shutdown = true;
    }

    /// Whether `shutdown` has been requested (or the stream has closed).
    fn is_shutdown(&self) -> bool {
        self.is_shutdown
    }
}
impl<S, MF> Stream for DnsMultiplexer<S, MF>
where
    S: DnsClientStream + Unpin + 'static,
    MF: MessageFinalizer + Send + Sync + 'static,
{
    type Item = Result<(), ProtoError>;

    /// Drives the multiplexer.
    ///
    /// First expires canceled or timed-out requests. It then routes up to
    /// `QOS_MAX_RECEIVE_MSGS` inbound messages to the request with the matching id.
    ///
    /// Returns `Poll::Ready(None)` in two cases:
    /// * the multiplexer is shut down with nothing outstanding;
    /// * the underlying stream errors or ends. Every outstanding request then
    ///   receives that error, and the multiplexer marks itself shut down.
    ///
    /// Otherwise it returns `Poll::Pending` and never yields an item.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.drop_cancelled(cx);

        if self.is_shutdown && self.active_requests.is_empty() {
            debug!("stream is done: {}", self);
            return Poll::Ready(None);
        }

        // Number of messages actually received in this call (0..=QOS_MAX_RECEIVE_MSGS).
        let mut messages_received = 0;
        for _ in 0..QOS_MAX_RECEIVE_MSGS {
            let error = match self.stream.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(buffer))) => {
                    messages_received += 1;
                    match buffer.to_message() {
                        Ok(message) => match self.active_requests.entry(message.id()) {
                            Entry::Occupied(mut request_entry) => {
                                let active_request = request_entry.get_mut();
                                if let Some(ref mut verifier) = active_request.verifier {
                                    ignore_send(
                                        active_request
                                            .completion
                                            .try_send(verifier(buffer.bytes())),
                                    );
                                } else {
                                    ignore_send(active_request.completion.try_send(Ok(
                                        DnsResponse::new(message, buffer.into_parts().0),
                                    )));
                                }
                            }
                            Entry::Vacant(..) => debug!("unexpected request_id: {}", message.id()),
                        },
                        Err(error) => debug!(error = error.as_dyn(), "error decoding message"),
                    }
                    continue;
                }
                Poll::Ready(Some(Err(error))) => error,
                Poll::Ready(None) => ProtoError::from("stream closed"),
                // The inner stream has registered our waker; nothing more to read now.
                Poll::Pending => break,
            };

            // The underlying stream failed or ended: fail everything and terminate.
            self.stream_closed_close_all(error);
            self.is_shutdown = true;
            return Poll::Ready(None);
        }

        // The whole receive budget was consumed without the inner stream returning
        // `Pending`, so no waker is registered for the messages that may still be
        // buffered. Schedule another poll ourselves instead of stalling.
        // (Previously this compared the last loop index, which is at most
        // `QOS_MAX_RECEIVE_MSGS - 1`, so the wake-up never happened.)
        if messages_received == QOS_MAX_RECEIVE_MSGS {
            cx.waker().wake_by_ref();
        }
        Poll::Pending
    }
}
// Tests that drive `DnsMultiplexer` end to end against `MockClientStream`, an
// in-memory stream that replays canned response messages.
#[cfg(test)]
mod test {
    use super::*;
    use crate::op::message::NoopMessageFinalizer;
    use crate::op::op_code::OpCode;
    use crate::op::{Message, MessageType, Query};
    use crate::rr::record_type::RecordType;
    use crate::rr::{DNSClass, Name, RData, Record};
    use crate::serialize::binary::BinEncodable;
    use crate::xfer::StreamReceiver;
    use crate::xfer::{DnsClientStream, DnsRequestOptions};
    use futures_util::future;
    use futures_util::stream::TryStreamExt;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

    /// Mock client stream that replays a fixed list of response messages.
    struct MockClientStream {
        // Remaining responses, stored reversed so `pop` yields them in original order.
        messages: Vec<Message>,
        // Address reported as the name server and attached to replayed messages.
        addr: SocketAddr,
        // Id of the first query read from `receiver`; stamped onto every response.
        id: Option<u16>,
        // Receiving side of the multiplexer's outbound handle; set by `get_mocked_multiplexer`.
        receiver: Option<StreamReceiver>,
    }

    impl MockClientStream {
        /// Returns an already-resolved connect future that yields the mock stream.
        fn new(
            mut messages: Vec<Message>,
            addr: SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<Self, ProtoError>> + Send>> {
            // Reverse so that `Vec::pop` in `poll_next` returns messages first-to-last.
            messages.reverse();
            Box::pin(future::ok(Self {
                messages,
                addr,
                id: None,
                receiver: None,
            }))
        }
    }

    impl fmt::Display for MockClientStream {
        fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
            write!(formatter, "TestClientStream")
        }
    }

    impl Stream for MockClientStream {
        type Item = Result<SerialMessage, ProtoError>;

        /// On the first poll, waits for the multiplexer's outbound query to learn its
        /// id. After that it yields one queued response per poll with that id set.
        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let id = if let Some(id) = self.id {
                id
            } else {
                // Not yet seen a query: read the first one the multiplexer sent.
                let serial = ready!(self
                    .receiver
                    .as_mut()
                    .expect("should only be polled after receiver has been set")
                    .poll_next_unpin(cx));
                let message = serial.unwrap().to_message().unwrap();
                self.id = Some(message.id());
                message.id()
            };
            if let Some(mut message) = self.messages.pop() {
                // Echo the query id so the multiplexer routes this response to it.
                message.set_id(id);
                Poll::Ready(Some(Ok(SerialMessage::new(
                    message.to_bytes().unwrap(),
                    self.addr,
                ))))
            } else {
                // NOTE(review): returns Pending without registering a waker; the tests
                // rely on the response future finishing first inside `select!`.
                Poll::Pending
            }
        }
    }

    impl DnsClientStream for MockClientStream {
        type Time = crate::TokioTime;
        fn name_server_addr(&self) -> SocketAddr {
            self.addr
        }
    }

    /// Builds a multiplexer over a `MockClientStream` that will replay `mock_response`,
    /// with a 100 ms request timeout and no signer.
    async fn get_mocked_multiplexer(
        mock_response: Vec<Message>,
    ) -> DnsMultiplexer<MockClientStream, NoopMessageFinalizer> {
        let addr = SocketAddr::from(([127, 0, 0, 1], 1234));
        let mock_response = MockClientStream::new(mock_response, addr);
        let (handler, receiver) = BufDnsStreamHandle::new(addr);
        let mut multiplexer =
            DnsMultiplexer::with_timeout(mock_response, handler, Duration::from_millis(100), None)
                .await
                .unwrap();
        // Hand the outbound receiver to the mock so it can observe the query id.
        multiplexer.stream.receiver = Some(receiver);
        multiplexer
    }

    /// An `A` query for `www.example.com` plus a single matching response message.
    fn a_query_answer() -> (DnsRequest, Vec<Message>) {
        let name = Name::from_ascii("www.example.com").unwrap();
        let mut msg = Message::new();
        msg.add_query({
            let mut query = Query::query(name.clone(), RecordType::A);
            query.set_query_class(DNSClass::IN);
            query
        })
        .set_message_type(MessageType::Query)
        .set_op_code(OpCode::Query)
        .set_recursion_desired(true);
        let query = msg.clone();
        msg.set_message_type(MessageType::Response).add_answer(
            Record::from_rdata(
                name,
                86400,
                RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
            )
            .set_dns_class(DNSClass::IN)
            .clone(),
        );
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg],
        )
    }

    /// An AXFR query message for `example.com`.
    fn axfr_query() -> Message {
        let name = Name::from_ascii("example.com").unwrap();
        let mut msg = Message::new();
        msg.add_query({
            let mut query = Query::query(name, RecordType::AXFR);
            query.set_query_class(DNSClass::IN);
            query
        })
        .set_message_type(MessageType::Query)
        .set_op_code(OpCode::Query)
        .set_recursion_desired(true);
        msg
    }

    /// Zone records for the `example.com` AXFR answer. The SOA record appears both
    /// first and last.
    fn axfr_response() -> Vec<Record> {
        use crate::rr::rdata::*;
        let origin = Name::from_ascii("example.com").unwrap();
        let soa = Record::from_rdata(
            origin.clone(),
            3600,
            RData::SOA(SOA::new(
                Name::parse("sns.dns.icann.org.", None).unwrap(),
                Name::parse("noc.dns.icann.org.", None).unwrap(),
                2015082403,
                7200,
                3600,
                1209600,
                3600,
            )),
        )
        .set_dns_class(DNSClass::IN)
        .clone();
        vec![
            soa.clone(),
            Record::from_rdata(
                origin.clone(),
                86400,
                RData::NS(NS(Name::parse("a.iana-servers.net.", None).unwrap())),
            )
            .set_dns_class(DNSClass::IN)
            .clone(),
            Record::from_rdata(
                origin.clone(),
                86400,
                RData::NS(NS(Name::parse("b.iana-servers.net.", None).unwrap())),
            )
            .set_dns_class(DNSClass::IN)
            .clone(),
            Record::from_rdata(
                origin.clone(),
                86400,
                RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
            )
            .set_dns_class(DNSClass::IN)
            .clone(),
            Record::from_rdata(
                origin,
                86400,
                RData::AAAA(
                    Ipv6Addr::new(
                        0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c,
                    )
                    .into(),
                ),
            )
            .set_dns_class(DNSClass::IN)
            .clone(),
            soa,
        ]
    }

    /// An AXFR query whose entire answer arrives in one response message.
    fn axfr_query_answer() -> (DnsRequest, Vec<Message>) {
        let mut msg = axfr_query();
        let query = msg.clone();
        msg.set_message_type(MessageType::Response)
            .insert_answers(axfr_response());
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg],
        )
    }

    /// An AXFR query whose answer is split across two response messages: the first
    /// three records, then the remainder.
    fn axfr_query_answer_multi() -> (DnsRequest, Vec<Message>) {
        let base = axfr_query();
        let query = base.clone();
        let mut rr = axfr_response();
        let rr2 = rr.split_off(3);
        let mut msg1 = base.clone();
        msg1.set_message_type(MessageType::Response)
            .insert_answers(rr);
        let mut msg2 = base;
        msg2.set_message_type(MessageType::Response)
            .insert_answers(rr2);
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg1, msg2],
        )
    }

    // Each test races driving the multiplexer, which must not finish, against
    // collecting the response stream for the sent query.

    #[tokio::test]
    async fn test_multiplexer_a() {
        let (query, answer) = a_query_answer();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 1);
    }

    #[tokio::test]
    async fn test_multiplexer_axfr() {
        let (query, answer) = axfr_query_answer();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 1);
        assert_eq!(response[0].answers().len(), axfr_response().len());
    }

    #[tokio::test]
    async fn test_multiplexer_axfr_multi() {
        let (query, answer) = axfr_query_answer_multi();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 2);
        // Answers from both messages together must make up the full record set.
        assert_eq!(
            response.iter().map(|m| m.answers().len()).sum::<usize>(),
            axfr_response().len()
        );
    }
}