use std::collections::HashSet;
use std::fmt::{self, Display};
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use futures_util::{future::Future, stream::Stream};
use tracing::{debug, trace, warn};
use crate::error::ProtoError;
use crate::op::{Message, MessageFinalizer, MessageVerifier};
use crate::runtime::{RuntimeProvider, Time};
use crate::udp::udp_stream::NextRandomUdpSocket;
use crate::udp::{DnsUdpSocket, MAX_RECEIVE_BUFFER_SIZE};
use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream, SerialMessage};
/// Builder for a UDP DNS client; call `build` to obtain a [`UdpClientConnect`] future.
pub struct UdpClientStreamBuilder<P> {
    /// Address of the DNS server that requests are sent to.
    name_server: SocketAddr,
    /// Per-request timeout; `None` means the default chosen by `build` is used.
    timeout: Option<Duration>,
    /// Optional finalizer used to sign outgoing requests.
    signer: Option<Arc<dyn MessageFinalizer>>,
    /// Optional local address forwarded to the per-request socket factory.
    bind_addr: Option<SocketAddr>,
    /// Local ports forwarded to the per-request socket factory
    /// (presumably ports that must not be used as a source port — confirm).
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Flag forwarded to the per-request socket factory
    /// (presumably lets the OS choose the source port — confirm).
    os_port_selection: bool,
    /// Runtime provider used to create sockets and timers.
    provider: P,
}
impl<P> UdpClientStreamBuilder<P> {
    /// Sets the per-request timeout.
    ///
    /// Passing `None` makes [`Self::build`] fall back to a 5 second timeout.
    pub fn with_timeout(mut self, timeout: Option<Duration>) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets the optional message finalizer used to sign outgoing requests.
    pub fn with_signer(mut self, signer: Option<Arc<dyn MessageFinalizer>>) -> Self {
        self.signer = signer;
        self
    }

    /// Sets the optional local address that request sockets are bound to.
    pub fn with_bind_addr(mut self, bind_addr: Option<SocketAddr>) -> Self {
        self.bind_addr = bind_addr;
        self
    }

    /// Replaces the set of local ports handed to the per-request socket factory.
    pub fn avoid_local_ports(mut self, avoid_local_ports: Arc<HashSet<u16>>) -> Self {
        self.avoid_local_ports = avoid_local_ports;
        self
    }

    /// Sets the OS port-selection flag handed to the per-request socket factory.
    pub fn with_os_port_selection(mut self, os_port_selection: bool) -> Self {
        self.os_port_selection = os_port_selection;
        self
    }

    /// Consumes the builder and returns the connect future.
    ///
    /// A missing timeout defaults to 5 seconds.
    pub fn build(self) -> UdpClientConnect<P> {
        UdpClientConnect {
            name_server: self.name_server,
            timeout: self.timeout.unwrap_or(Duration::from_secs(5)),
            signer: self.signer,
            bind_addr: self.bind_addr,
            // `self` is owned here, so the Arc is moved rather than cloned.
            avoid_local_ports: self.avoid_local_ports,
            os_port_selection: self.os_port_selection,
            provider: self.provider,
        }
    }
}
/// A UDP client for sending DNS requests to a single name server.
///
/// Each call to `send_message` obtains a fresh socket via `NextRandomUdpSocket`
/// and performs one request/response exchange under `timeout`.
#[must_use = "futures do nothing unless polled"]
pub struct UdpClientStream<P> {
    /// Destination of every request; responses from any other address are ignored.
    name_server: SocketAddr,
    /// Deadline applied to socket setup plus the whole exchange.
    timeout: Duration,
    /// Set by `shutdown`; afterwards `send_message` panics and the stream yields `None`.
    is_shutdown: bool,
    /// Optional finalizer; when it accepts a request, the request is signed and the
    /// returned verifier (if any) is used to check the response.
    signer: Option<Arc<dyn MessageFinalizer>>,
    /// Optional local address forwarded to `NextRandomUdpSocket`.
    bind_addr: Option<SocketAddr>,
    /// Port set forwarded to `NextRandomUdpSocket` (presumably ports to avoid — confirm).
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Flag forwarded to `NextRandomUdpSocket` (presumably OS-chosen source port — confirm).
    os_port_selection: bool,
    /// Runtime provider, cloned into each request future.
    provider: P,
}
impl<P: RuntimeProvider> UdpClientStream<P> {
pub fn builder(name_server: SocketAddr, provider: P) -> UdpClientStreamBuilder<P> {
UdpClientStreamBuilder {
name_server,
timeout: None,
signer: None,
bind_addr: None,
avoid_local_ports: Arc::default(),
os_port_selection: false,
provider,
}
}
}
impl<P> Display for UdpClientStream<P> {
    /// Renders as `UDP(<name server address>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let server = self.name_server;
        write!(f, "UDP({server})")
    }
}
/// Returns a random 16-bit value to use as the DNS message id of an outgoing request.
fn random_query_id() -> u16 {
    rand::random::<u16>()
}
impl<P: RuntimeProvider> DnsRequestSender for UdpClientStream<P> {
fn send_message(&mut self, mut request: DnsRequest) -> DnsResponseStream {
if self.is_shutdown {
panic!("can not send messages after stream is shutdown")
}
request.set_id(random_query_id());
let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
Ok(now) => now.as_secs(),
Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
};
let now = now as u32;
let mut verifier = None;
if let Some(signer) = &self.signer {
if signer.should_finalize_message(&request) {
match request.finalize(&**signer, now) {
Ok(answer_verifier) => verifier = answer_verifier,
Err(e) => {
debug!("could not sign message: {}", e);
return e.into();
}
}
}
}
let recv_buf_size = MAX_RECEIVE_BUFFER_SIZE.min(request.max_payload() as usize);
let bytes = match request.to_vec() {
Ok(bytes) => bytes,
Err(err) => {
return err.into();
}
};
let message_id = request.id();
let message = SerialMessage::new(bytes, self.name_server);
debug!(
"final message: {}",
message
.to_message()
.expect("bizarre we just made this message")
);
let provider = self.provider.clone();
let addr = message.addr();
let bind_addr = self.bind_addr;
let avoid_local_ports = self.avoid_local_ports.clone();
let os_port_selection = self.os_port_selection;
P::Timer::timeout::<Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>>(
self.timeout,
Box::pin(async move {
let socket = NextRandomUdpSocket::new(
addr,
bind_addr,
avoid_local_ports,
os_port_selection,
provider,
)
.await?;
send_serial_message_inner(message, message_id, verifier, socket, recv_buf_size)
.await
}),
)
.into()
}
fn shutdown(&mut self) {
self.is_shutdown = true;
}
fn is_shutdown(&self) -> bool {
self.is_shutdown
}
}
impl<P> Stream for UdpClientStream<P> {
    type Item = Result<(), ProtoError>;

    /// Never pending: yields `Some(Ok(()))` while the stream is live and ends
    /// (`None`) once `shutdown` has been called.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let next = if self.is_shutdown { None } else { Some(Ok(())) };
        Poll::Ready(next)
    }
}
/// Future produced by `UdpClientStreamBuilder::build`; it resolves immediately to a
/// [`UdpClientStream`] carrying the configured settings.
pub struct UdpClientConnect<P> {
    /// Address of the DNS server that requests are sent to.
    name_server: SocketAddr,
    /// Per-request timeout (already defaulted by `build`).
    timeout: Duration,
    /// Optional finalizer used to sign outgoing requests.
    signer: Option<Arc<dyn MessageFinalizer>>,
    /// Optional local address forwarded to the per-request socket factory.
    bind_addr: Option<SocketAddr>,
    /// Port set forwarded to the per-request socket factory.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// OS port-selection flag forwarded to the per-request socket factory.
    os_port_selection: bool,
    /// Runtime provider handed to the resulting stream.
    provider: P,
}
impl<P: RuntimeProvider> Future for UdpClientConnect<P> {
    type Output = Result<UdpClientStream<P>, ProtoError>;

    /// Resolves immediately to a new, not-shut-down [`UdpClientStream`] with this
    /// future's settings. No I/O happens here; sockets are created per request.
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        Poll::Ready(Ok(UdpClientStream {
            name_server: self.name_server,
            is_shutdown: false,
            timeout: self.timeout,
            // Clone the Arc like every other field instead of `take()`-ing it. With `take()`,
            // a repeated poll would yield a stream with no signer, so requests would be sent
            // unsigned without any error.
            signer: self.signer.clone(),
            bind_addr: self.bind_addr,
            avoid_local_ports: self.avoid_local_ports.clone(),
            os_port_selection: self.os_port_selection,
            provider: self.provider.clone(),
        }))
    }
}
/// Sends `msg` on `socket` and waits for the matching response.
///
/// Received datagrams are dropped with a warning, and the wait continues, when they:
/// - do not come from `msg.addr()` (IP and port),
/// - fail to parse as a DNS message,
/// - carry an id other than `msg_id`, or
/// - contain a question that is not present in the request.
///
/// For the first datagram that passes every check, the raw bytes are handed to `verifier`
/// when one is present and its result is returned. Otherwise the parsed response is returned.
///
/// # Errors
///
/// Returns an error when:
/// - sending or receiving fails;
/// - fewer bytes than the whole message were sent;
/// - the request bytes cannot be parsed back into a `Message`;
/// - the verifier rejects the response.
async fn send_serial_message_inner<S: DnsUdpSocket + Send>(
    msg: SerialMessage,
    msg_id: u16,
    verifier: Option<MessageVerifier>,
    socket: S,
    recv_buf_size: usize,
) -> Result<DnsResponse, ProtoError> {
    let bytes = msg.bytes();
    let addr = msg.addr();
    let len_sent: usize = socket.send_to(bytes, addr).await?;

    if bytes.len() != len_sent {
        return Err(ProtoError::from(format!(
            "Not all bytes of message sent, {} of {}",
            len_sent,
            bytes.len()
        )));
    }

    // The request does not change while we wait. Parse it once here instead of once per
    // received datagram.
    let request_message = Message::from_vec(bytes)?;

    trace!("creating UDP receive buffer with size {recv_buf_size}");
    let mut recv_buf = vec![0; recv_buf_size];
    loop {
        let (len, src) = socket.recv_from(&mut recv_buf).await?;

        // Compare IP and port only. Full `SocketAddr` equality would also compare IPv6
        // flowinfo/scope_id.
        if src.ip() != addr.ip() || src.port() != addr.port() {
            warn!(
                "ignoring response from {} because it does not match name_server: {}.",
                src, addr,
            );
            continue;
        }

        let response_bytes = &recv_buf[0..len];
        // Copy the datagram only after the source check, so stray packets cost no allocation.
        let response = match DnsResponse::from_buffer(response_bytes.to_vec()) {
            Ok(response) => response,
            Err(e) => {
                warn!("dropped malformed message waiting for id: {msg_id} err: {e}");
                continue;
            }
        };

        if msg_id != response.id() {
            warn!(
                "expected message id: {} got: {}, dropped",
                msg_id,
                response.id()
            );
            continue;
        }

        // Every question in the response must have been asked in the request. Note that
        // this does not require every request question to be echoed back.
        let request_queries = request_message.queries();
        let response_queries = response.queries();
        if !response_queries
            .iter()
            .all(|elem| request_queries.contains(elem))
        {
            warn!("detected forged question section: we expected '{request_queries:?}', but received '{response_queries:?}' from server {src}");
            continue;
        }

        debug!("received message id: {}", response.id());
        return match verifier {
            Some(mut verify) => verify(response_bytes),
            None => Ok(response),
        };
    }
}
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]

    use crate::{runtime::TokioRuntimeProvider, tests::udp_client_stream_test};
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    use test_support::subscribe;

    /// Installs the test subscriber, then runs the shared UDP client scenario
    /// against a loopback server bound to `server_ip` on a Tokio runtime provider.
    async fn run_loopback_scenario(server_ip: IpAddr) {
        subscribe();
        udp_client_stream_test(server_ip, TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv4() {
        run_loopback_scenario(IpAddr::V4(Ipv4Addr::LOCALHOST)).await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv6() {
        run_loopback_scenario(IpAddr::V6(Ipv6Addr::LOCALHOST)).await;
    }
}