use std::io;
use std::marker::PhantomData;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};
use async_trait::async_trait;
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
use futures::stream::{Fuse, Peekable, Stream, StreamExt};
use futures::{ready, Future, FutureExt, TryFutureExt};
use log::debug;
use rand;
use rand::distributions::{uniform::Uniform, Distribution};
use crate::xfer::{BufStreamHandle, SerialMessage};
/// Asynchronous UDP socket operations required by [`UdpStream`] and
/// [`NextRandomUdpSocket`].
///
/// Implementors must be `Sized` (so `bind` can return `Self`) and `Unpin` (so the
/// stream can wrap a `&mut Self` in `Pin::new` while polling).
#[async_trait]
pub trait UdpSocket: Sized + Unpin {
    /// Creates a socket bound to `addr`.
    async fn bind(addr: &SocketAddr) -> io::Result<Self>;

    /// Reads a single datagram into `buf`, returning the number of bytes placed in
    /// `buf` and the address the datagram came from.
    async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)>;

    /// Sends `buf` to `target`; the returned `usize` is presumably the number of bytes
    /// sent (defined by the implementation).
    async fn send_to(&mut self, buf: &[u8], target: &SocketAddr) -> io::Result<usize>;
}
/// A stream of inbound [`SerialMessage`]s read from a single UDP socket.
///
/// Outbound messages are pulled from `outbound_messages` and written to the socket
/// each time the stream is polled, before any inbound datagram is read.
#[must_use = "futures do nothing unless polled"]
pub struct UdpStream<S>
where
    S: Send,
{
    /// The underlying socket used for both sending and receiving.
    socket: S,
    /// Queue of messages waiting to be sent; peekable so a message is only removed
    /// after it has been written successfully.
    outbound_messages: Peekable<Fuse<UnboundedReceiver<SerialMessage>>>,
}
impl<S: UdpSocket + Send + 'static> UdpStream<S> {
    /// Starts binding a socket on a random port and pairs it with an outbound queue.
    ///
    /// The socket binds to the unspecified address of the same IP family as
    /// `name_server`. Returns a boxed future that resolves to the stream once binding
    /// succeeds, together with the handle used to queue outbound messages.
    #[allow(clippy::type_complexity)]
    pub fn new(
        name_server: SocketAddr,
    ) -> (
        Box<dyn Future<Output = Result<UdpStream<S>, io::Error>> + Send + Unpin>,
        BufStreamHandle,
    ) {
        let (sender, receiver) = unbounded();
        let handle = BufStreamHandle::new(sender);

        let binding = NextRandomUdpSocket::new(&name_server)
            .map_ok(move |socket| Self::from_parts(socket, receiver));

        (Box::new(binding), handle)
    }

    /// Wraps an already-bound `socket`, returning the stream and the handle used to
    /// queue outbound messages.
    pub fn with_bound(socket: S) -> (Self, BufStreamHandle) {
        let (sender, receiver) = unbounded();
        let handle = BufStreamHandle::new(sender);
        (Self::from_parts(socket, receiver), handle)
    }

    /// Assembles a stream from a bound socket and the receiving half of an outbound
    /// message channel.
    #[allow(unused)]
    pub(crate) fn from_parts(
        socket: S,
        outbound_messages: UnboundedReceiver<SerialMessage>,
    ) -> Self {
        let outbound_messages = outbound_messages.fuse().peekable();
        UdpStream {
            socket,
            outbound_messages,
        }
    }
}
impl<S: Send> UdpStream<S> {
#[allow(clippy::type_complexity)]
fn pollable_split(
&mut self,
) -> (
&mut S,
&mut Peekable<Fuse<UnboundedReceiver<SerialMessage>>>,
) {
(&mut self.socket, &mut self.outbound_messages)
}
}
impl<S: UdpSocket + Send + 'static> Stream for UdpStream<S> {
    type Item = Result<SerialMessage, io::Error>;

    /// Sends every queued outbound message, then yields the next inbound datagram.
    ///
    /// Sending takes priority over receiving. Each message currently available in the
    /// outbound queue is written before the socket is read, so the stream is
    /// throttled by how fast outbound messages can be sent.
    ///
    /// # Errors
    ///
    /// An I/O error from `send_to` or `recv_from` is yielded as `Some(Err(_))`. A
    /// failed send leaves its message at the head of the (peekable) queue, so the same
    /// message is attempted again on the next poll.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // For OS-backed datagram sockets, recvfrom discards the bytes of a datagram
        // that do not fit in the buffer. 4096 matches the EDNS(0) payload size commonly
        // advertised (RFC 6891 §6.2.5). The previous 2048-byte buffer silently
        // truncated larger responses.
        const RECV_BUFFER_SIZE: usize = 4096;

        let (socket, outbound_messages) = self.pollable_split();
        let mut socket = Pin::new(socket);
        let mut outbound_messages = Pin::new(outbound_messages);

        while let Poll::Ready(Some(message)) = outbound_messages.as_mut().poll_peek(cx) {
            let addr = &message.addr();
            // Returns early (Pending or Err) without popping, leaving the message queued.
            ready!(socket.send_to(message.bytes(), addr).poll_unpin(cx))?;

            // The peeked message has been sent, so pop it. `poll_peek` just returned
            // `Ready(Some(_))`, so the item is buffered and `poll_next` must be ready.
            assert!(outbound_messages.as_mut().poll_next(cx).is_ready());
        }

        let mut buf = [0u8; RECV_BUFFER_SIZE];
        let (len, src) = ready!(socket.recv_from(&mut buf).poll_unpin(cx))?;
        Poll::Ready(Some(Ok(SerialMessage::new(buf[..len].to_vec(), src))))
    }
}
/// A future that binds an `S` socket to a randomly chosen port.
///
/// The socket is bound on `bind_address`, which is an unspecified (all-zero) address.
#[must_use = "futures do nothing unless polled"]
pub(crate) struct NextRandomUdpSocket<S> {
    /// Records the socket type this future produces without storing one.
    marker: PhantomData<S>,
    /// Unspecified IPv4 or IPv6 address; the family matches the name server.
    bind_address: IpAddr,
}
impl<S: UdpSocket> NextRandomUdpSocket<S> {
    /// Prepares to bind on the unspecified address of `name_server`'s IP family:
    /// `0.0.0.0` for an IPv4 name server, `::` for an IPv6 one.
    pub(crate) fn new(name_server: &SocketAddr) -> NextRandomUdpSocket<S> {
        let bind_address = if name_server.is_ipv4() {
            IpAddr::V4(Ipv4Addr::UNSPECIFIED)
        } else {
            IpAddr::V6(Ipv6Addr::UNSPECIFIED)
        };

        NextRandomUdpSocket {
            marker: PhantomData,
            bind_address,
        }
    }

    /// Binds a socket of type `S` to `zero_addr`. The address is taken by value so the
    /// returned future owns it.
    async fn bind(zero_addr: SocketAddr) -> Result<S, io::Error> {
        S::bind(&zero_addr).await
    }
}
impl<S: UdpSocket> Future for NextRandomUdpSocket<S> {
    type Output = Result<S, io::Error>;

    /// Tries up to ten random ports in `1025..=65535` on the configured bind address.
    ///
    /// Each attempt creates a bind future and polls it exactly once. An attempt that
    /// fails or is not immediately ready is logged and the next port is tried. If all
    /// ten attempts fail, the task wakes itself and returns `Pending`, so the executor
    /// polls again and a fresh round of ports is tried.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let mut rng = rand::thread_rng();
        let ports = Uniform::new_inclusive(1025_u16, u16::max_value());

        for attempt in 0..10 {
            let candidate = SocketAddr::new(self.bind_address, ports.sample(&mut rng));
            let mut binding = Box::pin(Self::bind(candidate));

            match binding.as_mut().poll(cx) {
                Poll::Ready(Ok(socket)) => {
                    debug!("created socket successfully");
                    return Poll::Ready(Ok(socket));
                }
                Poll::Ready(Err(err)) => {
                    debug!("unable to bind port, attempt: {}: {}", attempt, err)
                }
                Poll::Pending => debug!("unable to bind port, attempt: {}", attempt),
            }
        }

        debug!("could not get next random port, delaying");
        // Self-wake so the executor re-polls and another batch of ports is tried.
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}
/// [`UdpSocket`] implementation for `tokio::net::UdpSocket`, compiled only with the
/// `tokio-runtime` feature. Each method forwards to tokio's method of the same name.
///
/// NOTE(review): `self.recv_from(..)` / `self.send_to(..)` reach tokio only if tokio's
/// inherent methods take `&mut self` (as in tokio 0.2). Method resolution then prefers
/// the inherent method over this trait's method. If the inherent methods take `&self`,
/// these calls resolve to this trait's methods and recurse forever. Confirm against the
/// tokio version in use.
#[cfg(feature = "tokio-runtime")]
#[async_trait]
impl UdpSocket for tokio::net::UdpSocket {
    /// Binds via tokio's associated `bind` function.
    async fn bind(addr: &SocketAddr) -> io::Result<Self> {
        tokio::net::UdpSocket::bind(addr).await
    }
    /// Intended to forward to tokio's inherent `recv_from` (see the note above the impl).
    async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        self.recv_from(buf).await
    }
    /// Intended to forward to tokio's inherent `send_to` (see the note above the impl).
    async fn send_to(&mut self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
        self.send_to(buf, target).await
    }
}
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};

    use tokio::{net::UdpSocket as TokioUdpSocket, runtime::Runtime};

    use crate::tests::{next_random_socket_test, udp_stream_test};

    /// Builds a fresh tokio runtime for a single test.
    fn new_runtime() -> Runtime {
        Runtime::new().expect("failed to create tokio runtime")
    }

    #[test]
    fn test_next_random_socket() {
        next_random_socket_test::<TokioUdpSocket, Runtime>(new_runtime())
    }

    #[test]
    fn test_udp_stream_ipv4() {
        udp_stream_test::<TokioUdpSocket, Runtime>(
            IpAddr::V4(Ipv4Addr::LOCALHOST),
            new_runtime(),
        );
    }

    /// IPv6 loopback variant; excluded from Linux builds by the `cfg` below.
    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_udp_stream_ipv6() {
        udp_stream_test::<TokioUdpSocket, Runtime>(
            IpAddr::V6(Ipv6Addr::LOCALHOST),
            new_runtime(),
        );
    }
}