use std::borrow::Borrow;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::{self, Display};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use futures::stream::Stream;
use futures::sync::oneshot;
use futures::{task, Async, Future, Poll};
use rand;
use rand::distributions::{Distribution, Standard};
use smallvec::SmallVec;
use tokio_timer::Delay;
use error::*;
use op::{Message, MessageFinalizer, OpCode};
use xfer::{
ignore_send, DnsClientStream, DnsRequest, DnsRequestOptions, DnsRequestSender, DnsResponse,
SerialMessage,
};
use DnsStreamHandle;
/// Upper bound on how many messages `DnsMultiplexer::poll` reads from the underlying stream in a
/// single call; it is the iteration limit of the receive loop.
const QOS_MAX_RECEIVE_MSGS: usize = 100;
/// Bookkeeping for a single in-flight request that is waiting for its response(s).
struct ActiveRequest {
/// Channel on which the final outcome (collected responses or an error) is sent to the requester.
completion: oneshot::Sender<Result<DnsResponse, ProtoError>>,
/// Query id the request was sent with; also used as its key in the multiplexer's request map.
request_id: u16,
/// Options supplied with the request; `expects_multiple_responses` decides whether the request
/// is completed on the first response.
request_options: DnsRequestOptions,
/// Responses received so far for this request.
responses: SmallVec<[Message; 1]>,
/// Timer polled to detect that the request has timed out.
timeout: Delay,
}
impl ActiveRequest {
    /// Creates the bookkeeping entry for a freshly issued request.
    ///
    /// The entry starts with no recorded responses. `completion` is the channel on which the
    /// final result is delivered, and `timeout` is the timer bounding the request's lifetime.
    fn new(
        completion: oneshot::Sender<Result<DnsResponse, ProtoError>>,
        request_id: u16,
        request_options: DnsRequestOptions,
        timeout: Delay,
    ) -> Self {
        let responses = SmallVec::new();
        ActiveRequest {
            completion,
            request_id,
            request_options,
            responses,
            timeout,
        }
    }

    /// Polls the timeout timer, converting a timer error into a `ProtoError`.
    fn poll_timeout(&mut self) -> Poll<(), ProtoError> {
        match self.timeout.poll() {
            Ok(state) => Ok(state),
            Err(timer_error) => Err(ProtoError::from(timer_error)),
        }
    }

    /// Returns `true` when the receiving side of the completion channel has been dropped.
    fn is_canceled(&self) -> bool {
        let sender = &self.completion;
        sender.is_canceled()
    }

    /// Records one more response for this request.
    fn add_response(&mut self, message: Message) {
        let collected = &mut self.responses;
        collected.push(message);
    }

    /// The query id associated with this request.
    fn request_id(&self) -> u16 {
        let id = self.request_id;
        id
    }

    /// The options the request was issued with.
    fn request_options(&self) -> &DnsRequestOptions {
        let options = &self.request_options;
        options
    }

    /// Finishes the request by sending `error` to the requester.
    ///
    /// A failed send (the requester is gone) is deliberately ignored.
    fn complete_with_error(self, error: ProtoError) {
        let outcome = Err(error);
        ignore_send(self.completion.send(outcome));
    }

    /// Finishes the request by sending all collected responses to the requester.
    ///
    /// If nothing was collected, an error is sent instead. A failed send is ignored.
    fn complete(self) {
        if self.responses.is_empty() {
            return self.complete_with_error("no responses received, should have timedout".into());
        }

        let ActiveRequest {
            completion,
            responses,
            ..
        } = self;
        ignore_send(completion.send(Ok(responses.into())));
    }
}
/// Multiplexes DNS requests over a single `DnsClientStream`, matching each incoming message to an
/// outstanding request by its message id.
#[must_use = "futures do nothing unless polled"]
pub struct DnsMultiplexer<S, MF, D = Box<DnsStreamHandle>>
where
D: Send + 'static,
S: DnsClientStream + 'static,
MF: MessageFinalizer,
{
/// Stream from which serialized responses are read.
stream: S,
/// Duration after which an outstanding request is failed with a timeout.
timeout_duration: Duration,
/// Handle through which serialized requests are sent.
stream_handle: D,
/// In-flight requests keyed by their query id.
active_requests: HashMap<u16, ActiveRequest>,
/// Optional finalizer applied to `Update` messages before they are sent.
signer: Option<Arc<MF>>,
/// Set by `shutdown`; once set and no requests remain, polling reports end of stream.
is_shutdown: bool,
}
impl<S, MF> DnsMultiplexer<S, MF, Box<DnsStreamHandle>>
where
S: DnsClientStream + 'static,
MF: MessageFinalizer,
{
pub fn new<F>(
stream: F,
stream_handle: Box<DnsStreamHandle>,
signer: Option<Arc<MF>>,
) -> DnsMultiplexerConnect<F, S, MF>
where
F: Future<Item = S, Error = ProtoError> + Send + 'static,
{
Self::with_timeout(stream, stream_handle, Duration::from_secs(5), signer)
}
pub fn with_timeout<F>(
stream: F,
stream_handle: Box<DnsStreamHandle>,
timeout_duration: Duration,
signer: Option<Arc<MF>>,
) -> DnsMultiplexerConnect<F, S, MF>
where
F: Future<Item = S, Error = ProtoError> + Send + 'static,
{
DnsMultiplexerConnect {
stream,
stream_handle: Some(stream_handle),
timeout_duration,
signer: signer,
}
}
fn drop_cancelled(&mut self) {
let mut canceled = HashMap::<u16, ProtoError>::new();
for (&id, ref mut active_req) in &mut self.active_requests {
if active_req.is_canceled() {
canceled.insert(id, ProtoError::from("requestor canceled"));
}
match active_req.poll_timeout() {
Ok(Async::Ready(_)) => {
debug!("request timed out: {}", id);
canceled.insert(id, ProtoError::from(ProtoErrorKind::Timeout));
}
Ok(Async::NotReady) => (),
Err(e) => {
error!("unexpected error from timeout: {}", e);
canceled.insert(id, ProtoError::from("error registering timeout"));
}
}
}
for (id, error) in canceled {
if let Some(active_request) = self.active_requests.remove(&id) {
if active_request.responses.is_empty() {
active_request.complete_with_error(error);
} else {
active_request.complete();
}
}
}
}
fn next_random_query_id(&self) -> Async<u16> {
let mut rand = rand::thread_rng();
for _ in 0..100 {
let id: u16 = Standard.sample(&mut rand);
if !self.active_requests.contains_key(&id) {
return Async::Ready(id);
}
}
warn!("could not get next random query id, delaying");
task::current().notify();
Async::NotReady
}
fn stream_closed_close_all(&mut self) {
if !self.active_requests.is_empty() {
warn!(
"stream closed before response received: {}",
self.stream.name_server_addr()
);
}
let error = ProtoError::from("stream closed before response received");
for (_, mut active_request) in self.active_requests.drain() {
if active_request.responses.is_empty() {
active_request.complete_with_error(error.clone());
} else {
active_request.complete();
}
}
}
}
/// Future returned by `DnsMultiplexer::new` / `DnsMultiplexer::with_timeout`; it resolves to a
/// `DnsMultiplexer` once the wrapped stream future completes.
#[must_use = "futures do nothing unless polled"]
pub struct DnsMultiplexerConnect<F, S, MF>
where
F: Future<Item = S, Error = ProtoError> + Send + 'static,
S: Stream<Item = SerialMessage, Error = ProtoError>,
MF: MessageFinalizer + Send + Sync + 'static,
{
/// Future producing the stream the multiplexer will read from.
stream: F,
/// Send handle handed to the multiplexer on completion; taken (left `None`) when that happens.
stream_handle: Option<Box<DnsStreamHandle>>,
/// Per-request timeout passed on to the multiplexer.
timeout_duration: Duration,
/// Optional message finalizer passed on to the multiplexer.
signer: Option<Arc<MF>>,
}
impl<F, S, MF> Future for DnsMultiplexerConnect<F, S, MF>
where
    F: Future<Item = S, Error = ProtoError> + Send + 'static,
    S: DnsClientStream + 'static,
    MF: MessageFinalizer + Send + Sync + 'static,
{
    type Item = DnsMultiplexer<S, MF, Box<DnsStreamHandle>>;
    type Error = ProtoError;

    /// Drives the wrapped stream future and, once it resolves, assembles the multiplexer.
    ///
    /// # Panics
    ///
    /// Panics if polled again after having resolved, because the stream handle has already
    /// been moved into the multiplexer.
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        let stream: S = match self.stream.poll() {
            Ok(Async::Ready(connected)) => connected,
            Ok(Async::NotReady) => return Ok(Async::NotReady),
            Err(error) => return Err(error),
        };

        let stream_handle = self
            .stream_handle
            .take()
            .expect("must not poll after complete");

        let multiplexer = DnsMultiplexer {
            stream,
            timeout_duration: self.timeout_duration,
            stream_handle,
            active_requests: HashMap::new(),
            signer: self.signer.clone(),
            is_shutdown: false,
        };
        Ok(Async::Ready(multiplexer))
    }
}
impl<S, MF> Display for DnsMultiplexer<S, MF>
where
    S: DnsClientStream + 'static,
    MF: MessageFinalizer + Send + Sync + 'static,
{
    /// Renders the multiplexer as the `Display` output of its underlying stream.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        formatter.write_fmt(format_args!("{}", self.stream))
    }
}
impl<S, MF> DnsRequestSender for DnsMultiplexer<S, MF>
where
S: DnsClientStream + 'static,
MF: MessageFinalizer + Send + Sync + 'static,
{
type DnsResponseFuture = DnsMultiplexerSerialResponse;
fn send_message(&mut self, request: DnsRequest) -> Self::DnsResponseFuture {
if self.is_shutdown {
panic!("can not send messages after stream is shutdown")
}
let query_id: u16 = match self.next_random_query_id() {
Async::Ready(id) => id,
Async::NotReady => {
return DnsMultiplexerSerialResponseInner::Err(Some(ProtoError::from(
"id space exhausted, consider filing an issue",
)))
.into()
}
};
let (mut request, request_options) = request.unwrap();
request.set_id(query_id);
let now = match SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|_| ProtoErrorKind::Message("Current time is before the Unix epoch.").into())
{
Ok(now) => now.as_secs(),
Err(err) => return DnsMultiplexerSerialResponseInner::Err(Some(err)).into(),
};
let now = now as u32;
if let OpCode::Update = request.op_code() {
if let Some(ref signer) = self.signer {
if let Err(e) = request.finalize::<MF>(signer.borrow(), now) {
debug!("could not sign message: {}", e);
return DnsMultiplexerSerialResponseInner::Err(Some(e.into())).into();
}
}
}
let timeout = Delay::new(Instant::now() + self.timeout_duration);
let (complete, receiver) = oneshot::channel();
let active_request = ActiveRequest::new(complete, request.id(), request_options, timeout);
match request.to_vec() {
Ok(buffer) => {
debug!("sending message id: {}", active_request.request_id());
let serial_message = SerialMessage::new(buffer, self.stream.name_server_addr());
match self.stream_handle.send(serial_message) {
Ok(()) => self
.active_requests
.insert(active_request.request_id(), active_request),
Err(err) => {
return DnsMultiplexerSerialResponseInner::Err(Some(err.into())).into()
}
};
}
Err(e) => {
debug!(
"error message id: {} error: {}",
active_request.request_id(),
e
);
return DnsMultiplexerSerialResponseInner::Err(Some(e)).into();
}
}
DnsMultiplexerSerialResponseInner::Completion(receiver).into()
}
fn error_response(error: ProtoError) -> Self::DnsResponseFuture {
DnsMultiplexerSerialResponseInner::Err(Some(error)).into()
}
fn shutdown(&mut self) {
self.is_shutdown = true;
}
fn is_shutdown(&self) -> bool {
self.is_shutdown
}
}
impl<S, MF> Stream for DnsMultiplexer<S, MF>
where
    S: DnsClientStream + 'static,
    MF: MessageFinalizer + Send + Sync + 'static,
{
    type Item = ();
    type Error = ProtoError;

    /// Drives the multiplexer: expires canceled/timed-out requests, then reads up to
    /// `QOS_MAX_RECEIVE_MSGS` messages from the stream and routes each to its active request.
    ///
    /// Returns:
    /// * `Ok(Async::Ready(None))` once shut down with no active requests, or when the underlying
    ///   stream ends (all active requests are completed first);
    /// * `Ok(Async::NotReady)` otherwise, keeping the driving task alive;
    /// * `Err(_)` when polling the underlying stream fails.
    ///
    /// Messages that fail to decode, or whose id matches no active request, are logged and
    /// dropped.
    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        // Fail requests whose requester went away or whose timer fired.
        self.drop_cancelled();

        if self.is_shutdown && self.active_requests.is_empty() {
            debug!("stream is done: {}", self);
            return Ok(Async::Ready(None));
        }

        // Counts messages actually pulled off the stream during this poll. Tracking the loop
        // index instead would top out at QOS_MAX_RECEIVE_MSGS - 1 and the yield below would
        // never trigger.
        let mut messages_received = 0;
        for _ in 0..QOS_MAX_RECEIVE_MSGS {
            match self.stream.poll()? {
                Async::Ready(Some(buffer)) => {
                    messages_received += 1;
                    match buffer.to_message() {
                        Ok(message) => match self.active_requests.entry(message.id()) {
                            Entry::Occupied(mut request_entry) => {
                                // Scoped so the borrow of the entry ends before it is removed.
                                let complete = {
                                    let active_request = request_entry.get_mut();
                                    active_request.add_response(message);
                                    !active_request.request_options().expects_multiple_responses
                                };
                                if complete {
                                    let active_request = request_entry.remove();
                                    active_request.complete();
                                }
                            }
                            Entry::Vacant(..) => debug!("unexpected request_id: {}", message.id()),
                        },
                        Err(e) => debug!("error decoding message: {}", e),
                    }
                }
                Async::Ready(None) => {
                    debug!("io_stream closed by other side: {}", self.stream);
                    self.stream_closed_close_all();
                    return Ok(Async::Ready(None));
                }
                Async::NotReady => break,
            }
        }

        // The receive budget was exhausted without the stream reporting `NotReady`, so the
        // stream has not registered a wake-up for us. Notify the task explicitly so the
        // remaining messages are processed on the next turn.
        if messages_received == QOS_MAX_RECEIVE_MSGS {
            task::current().notify();
        }

        Ok(Async::NotReady)
    }
}
/// Future returned by `DnsMultiplexer::send_message`; resolves to the `DnsResponse` for the
/// request or fails with a `ProtoError`.
#[must_use = "futures do nothing unless polled"]
pub struct DnsMultiplexerSerialResponse(DnsMultiplexerSerialResponseInner);
impl DnsMultiplexerSerialResponse {
pub fn completion(complete: oneshot::Receiver<ProtoResult<DnsResponse>>) -> Self {
DnsMultiplexerSerialResponseInner::Completion(complete).into()
}
}
impl Future for DnsMultiplexerSerialResponse {
type Item = DnsResponse;
type Error = ProtoError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.0.poll()
}
}
impl From<DnsMultiplexerSerialResponseInner> for DnsMultiplexerSerialResponse {
    /// Wraps the internal state in the public response future type.
    fn from(state: DnsMultiplexerSerialResponseInner) -> Self {
        let response = DnsMultiplexerSerialResponse(state);
        response
    }
}
/// Internal state behind `DnsMultiplexerSerialResponse`.
enum DnsMultiplexerSerialResponseInner {
/// Waiting on a oneshot channel for the outcome of the request.
Completion(oneshot::Receiver<ProtoResult<DnsResponse>>),
/// An immediate failure; the stored error is taken out and returned on the first poll, and
/// polling again panics.
Err(Option<ProtoError>),
}
impl Future for DnsMultiplexerSerialResponseInner {
type Item = DnsResponse;
type Error = ProtoError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self {
DnsMultiplexerSerialResponseInner::Completion(complete) => match try_ready!(complete
.poll()
.map_err(|_| ProtoError::from("the completion was canceled")))
{
Ok(response) => Ok(Async::Ready(response)),
Err(err) => Err(err),
},
DnsMultiplexerSerialResponseInner::Err(err) => {
Err(err.take().expect("cannot poll after complete"))
}
}
}
}