use std::future::Future;
use std::io;
use std::marker::Send;
use std::net::SocketAddr;
use std::pin::Pin;
#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
#[cfg(any(test, feature = "tokio-runtime"))]
use tokio::runtime::Runtime;
#[cfg(any(test, feature = "tokio-runtime"))]
use tokio::task::JoinHandle;
use crate::error::ProtoError;
use crate::tcp::DnsTcpStream;
use crate::udp::DnsUdpSocket;
/// Spawns `background` onto the given Tokio `runtime` and returns the join handle of the task.
///
/// Only compiled for tests or when the `tokio-runtime` feature is enabled.
#[cfg(any(test, feature = "tokio-runtime"))]
pub fn spawn_bg<F, R>(runtime: &Runtime, background: F) -> JoinHandle<R>
where
    F: Future<Output = R> + Send + 'static,
    R: Send + 'static,
{
    Runtime::spawn(runtime, background)
}
#[cfg(feature = "tokio-runtime")]
#[doc(hidden)]
pub mod iocompat {
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite, ReadBuf};
/// Adapter that exposes a Tokio I/O object through the `futures_io`
/// `AsyncRead`/`AsyncWrite` traits.
pub struct AsyncIoTokioAsStd<T>(pub T)
where
    T: TokioAsyncRead + TokioAsyncWrite;

// The wrapper holds nothing but the inner value, so it is `Unpin` whenever the inner value is.
impl<T> Unpin for AsyncIoTokioAsStd<T> where T: TokioAsyncRead + TokioAsyncWrite + Unpin {}
impl<R: TokioAsyncRead + TokioAsyncWrite + Unpin> AsyncRead for AsyncIoTokioAsStd<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut buf = ReadBuf::new(buf);
let polled = Pin::new(&mut self.0).poll_read(cx, &mut buf);
polled.map_ok(|_| buf.filled().len())
}
}
impl<W: TokioAsyncRead + TokioAsyncWrite + Unpin> AsyncWrite for AsyncIoTokioAsStd<W> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
/// Adapter that exposes a `futures_io` I/O object through Tokio's
/// `AsyncRead`/`AsyncWrite` traits.
pub struct AsyncIoStdAsTokio<T>(pub T)
where
    T: AsyncRead + AsyncWrite;

// The wrapper holds nothing but the inner value, so it is `Unpin` whenever the inner value is.
impl<T> Unpin for AsyncIoStdAsTokio<T> where T: AsyncRead + AsyncWrite + Unpin {}
impl<R: AsyncRead + AsyncWrite + Unpin> TokioAsyncRead for AsyncIoStdAsTokio<R> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().0)
.poll_read(cx, buf.initialized_mut())
.map_ok(|len| buf.advance(len))
}
}
impl<W: AsyncRead + AsyncWrite + Unpin> TokioAsyncWrite for AsyncIoStdAsTokio<W> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Pin::new(&mut self.get_mut().0).poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
Pin::new(&mut self.get_mut().0).poll_close(cx)
}
}
}
#[cfg(feature = "tokio-runtime")]
#[allow(unreachable_pub)]
mod tokio_runtime {
use super::iocompat::AsyncIoTokioAsStd;
use super::*;
use futures_util::FutureExt;
#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
use quinn::Runtime;
use std::sync::{Arc, Mutex};
use tokio::net::{TcpSocket, TcpStream, UdpSocket as TokioUdpSocket};
use tokio::task::JoinSet;
use tokio::time::timeout;
/// Spawn handle backed by a shared Tokio `JoinSet`.
///
/// Clones share the same set, because the set lives behind an `Arc<Mutex<_>>`.
#[derive(Clone, Default)]
pub struct TokioHandle {
    join_set: Arc<Mutex<JoinSet<Result<(), ProtoError>>>>,
}

impl Spawn for TokioHandle {
    /// Adds `future` to the shared join set, then reaps tasks that have already finished.
    ///
    /// # Panics
    ///
    /// Panics if the mutex guarding the join set is poisoned.
    fn spawn_bg<F>(&mut self, future: F)
    where
        F: Future<Output = Result<(), ProtoError>> + Send + 'static,
    {
        let mut tasks = self.join_set.lock().unwrap();
        tasks.spawn(future);
        // Opportunistically drop finished tasks so the set does not keep growing.
        reap_tasks(&mut tasks);
    }
}
#[derive(Clone, Default)]
pub struct TokioRuntimeProvider(TokioHandle);
impl TokioRuntimeProvider {
pub fn new() -> Self {
Self::default()
}
}
impl RuntimeProvider for TokioRuntimeProvider {
type Handle = TokioHandle;
type Timer = TokioTime;
type Udp = TokioUdpSocket;
type Tcp = AsyncIoTokioAsStd<TcpStream>;
fn create_handle(&self) -> Self::Handle {
self.0.clone()
}
fn connect_tcp(
&self,
server_addr: SocketAddr,
bind_addr: Option<SocketAddr>,
wait_for: Option<Duration>,
) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
Box::pin(async move {
let socket = match server_addr {
SocketAddr::V4(_) => TcpSocket::new_v4(),
SocketAddr::V6(_) => TcpSocket::new_v6(),
}?;
if let Some(bind_addr) = bind_addr {
socket.bind(bind_addr)?;
}
socket.set_nodelay(true)?;
let future = socket.connect(server_addr);
let wait_for = wait_for.unwrap_or_else(|| Duration::from_secs(5));
match timeout(wait_for, future).await {
Ok(Ok(socket)) => Ok(AsyncIoTokioAsStd(socket)),
Ok(Err(e)) => Err(e),
Err(_) => Err(io::Error::new(
io::ErrorKind::TimedOut,
format!("connection to {server_addr:?} timed out after {wait_for:?}"),
)),
}
})
}
fn bind_udp(
&self,
local_addr: SocketAddr,
_server_addr: SocketAddr,
) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
Box::pin(tokio::net::UdpSocket::bind(local_addr))
}
#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
fn quic_binder(&self) -> Option<&dyn QuicSocketBinder> {
Some(&TokioQuicSocketBinder)
}
}
fn reap_tasks(join_set: &mut JoinSet<Result<(), ProtoError>>) {
while FutureExt::now_or_never(join_set.join_next())
.flatten()
.is_some()
{}
}
/// QUIC socket binder that wraps a std UDP socket with quinn's Tokio runtime.
#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
struct TokioQuicSocketBinder;

#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
impl QuicSocketBinder for TokioQuicSocketBinder {
    /// Binds a UDP socket to `local_addr` and hands it to quinn; the server address is not used.
    fn bind_quic(
        &self,
        local_addr: SocketAddr,
        _server_addr: SocketAddr,
    ) -> Result<Arc<dyn quinn::AsyncUdpSocket>, io::Error> {
        std::net::UdpSocket::bind(local_addr)
            .and_then(|socket| quinn::TokioRuntime.wrap_udp_socket(socket))
    }
}
}
#[cfg(feature = "tokio-runtime")]
pub use tokio_runtime::{TokioHandle, TokioRuntimeProvider};
/// Bundles the runtime-specific pieces this crate needs: a spawn handle, a timer,
/// and UDP/TCP socket creation.
pub trait RuntimeProvider: Clone + Send + Sync + Unpin + 'static {
/// Handle used to spawn background tasks.
type Handle: Clone + Send + Spawn + Sync + Unpin;
/// Timer implementation providing delays and timeouts.
type Timer: Time + Send + Unpin;
/// UDP socket type produced by `bind_udp`.
type Udp: DnsUdpSocket + Send;
/// TCP stream type produced by `connect_tcp`.
type Tcp: DnsTcpStream;
/// Creates a handle for spawning background tasks.
fn create_handle(&self) -> Self::Handle;
/// Returns a future that connects a TCP stream to `server_addr`.
///
/// `bind_addr` optionally selects the local address to bind before connecting.
/// `timeout` optionally limits the connection attempt; what happens for `None`
/// is up to the implementation.
fn connect_tcp(
&self,
server_addr: SocketAddr,
bind_addr: Option<SocketAddr>,
timeout: Option<Duration>,
) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>>;
/// Returns a future that binds a UDP socket to `local_addr`.
///
/// `server_addr` is the intended peer; implementations may ignore it.
fn bind_udp(
&self,
local_addr: SocketAddr,
server_addr: SocketAddr,
) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>>;
/// Returns the binder used to create QUIC sockets, if the provider has one.
///
/// The default implementation returns `None`.
fn quic_binder(&self) -> Option<&dyn QuicSocketBinder> {
None
}
}
/// Method-less stand-in used when neither QUIC feature is enabled, so that
/// `RuntimeProvider::quic_binder` can name this trait in every configuration.
#[cfg(not(any(feature = "dns-over-quic", feature = "dns-over-h3")))]
pub trait QuicSocketBinder {}
/// Creates the UDP sockets used for QUIC connections.
#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
pub trait QuicSocketBinder {
/// Creates a quinn-compatible UDP socket for the given local and server addresses.
fn bind_quic(
&self,
_local_addr: SocketAddr,
_server_addr: SocketAddr,
) -> Result<Arc<dyn quinn::AsyncUdpSocket>, io::Error>;
}
/// Something that can run futures in the background.
pub trait Spawn {
/// Spawns `future` as a background task; its result is not returned to the caller.
fn spawn_bg<F>(&mut self, future: F)
where
F: Future<Output = Result<(), ProtoError>> + Send + 'static;
}
/// A minimal executor interface: construct one, then drive a future with it.
pub trait Executor {
/// Creates a new executor.
fn new() -> Self;
/// Runs `future` on this executor and returns its output.
fn block_on<F: Future>(&mut self, future: F) -> F::Output;
}
#[cfg(feature = "tokio-runtime")]
impl Executor for Runtime {
    /// Creates a Tokio runtime via the inherent `Runtime::new`.
    ///
    /// # Panics
    ///
    /// Panics if the runtime cannot be created.
    fn new() -> Self {
        let created: io::Result<Self> = Runtime::new();
        created.expect("failed to create tokio runtime")
    }

    /// Delegates to the inherent `Runtime::block_on`.
    fn block_on<F: Future>(&mut self, future: F) -> F::Output {
        Runtime::block_on(self, future)
    }
}
/// Timer operations supplied by the runtime.
#[async_trait]
pub trait Time {
/// Completes after `duration` has elapsed.
async fn delay_for(duration: Duration);
/// Runs `future` with a time limit of `duration`.
///
/// Returns the future's output, or an `io::Error` if the limit is reached first.
async fn timeout<F: 'static + Future + Send>(
duration: Duration,
future: F,
) -> Result<F::Output, std::io::Error>;
}
/// `Time` implementation backed by `tokio::time`.
#[cfg(any(test, feature = "tokio-runtime"))]
#[derive(Clone, Copy, Debug)]
pub struct TokioTime;

#[cfg(any(test, feature = "tokio-runtime"))]
#[async_trait]
impl Time for TokioTime {
    /// Sleeps for `duration` using `tokio::time::sleep`.
    async fn delay_for(duration: Duration) {
        tokio::time::sleep(duration).await;
    }

    /// Runs `future` under `tokio::time::timeout`.
    ///
    /// An elapsed timer is reported as an `io::ErrorKind::TimedOut` error.
    async fn timeout<F: 'static + Future + Send>(
        duration: Duration,
        future: F,
    ) -> Result<F::Output, std::io::Error> {
        match tokio::time::timeout(duration, future).await {
            Ok(output) => Ok(output),
            Err(_elapsed) => Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "future timed out",
            )),
        }
    }
}