1use alloc::boxed::Box;
4#[cfg(feature = "__quic")]
5use alloc::sync::Arc;
6use core::future::Future;
7use core::marker::Send;
8use core::pin::Pin;
9use core::time::Duration;
10use std::io;
11use std::net::SocketAddr;
12
13use async_trait::async_trait;
14#[cfg(any(test, feature = "tokio"))]
15use tokio::runtime::Runtime;
16#[cfg(any(test, feature = "tokio"))]
17use tokio::task::JoinHandle;
18
19use crate::error::ProtoError;
20use crate::tcp::DnsTcpStream;
21use crate::udp::DnsUdpSocket;
22
/// Spawn `background` onto `runtime` and return the Tokio [`JoinHandle`] for its result.
#[cfg(any(test, feature = "tokio"))]
pub fn spawn_bg<F, R>(runtime: &Runtime, background: F) -> JoinHandle<R>
where
    F: Future<Output = R> + Send + 'static,
    R: Send + 'static,
{
    Runtime::spawn(runtime, background)
}
31
#[cfg(feature = "tokio")]
#[doc(hidden)]
pub mod iocompat {
    //! Adapters between Tokio's I/O traits and the `futures_io` I/O traits.

    use core::pin::Pin;
    use core::task::{Context, Poll};
    use std::io;

    use futures_io::{AsyncRead, AsyncWrite};
    use tokio::io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite, ReadBuf};

    /// Wraps a Tokio `AsyncRead + AsyncWrite` value so it can be driven through the
    /// `futures_io` [`AsyncRead`] / [`AsyncWrite`] traits.
    pub struct AsyncIoTokioAsStd<T: TokioAsyncRead + TokioAsyncWrite>(pub T);

    impl<T: TokioAsyncRead + TokioAsyncWrite + Unpin> Unpin for AsyncIoTokioAsStd<T> {}
    impl<R: TokioAsyncRead + TokioAsyncWrite + Unpin> AsyncRead for AsyncIoTokioAsStd<R> {
        /// Read into `buf`, returning how many bytes the wrapped Tokio reader filled.
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // `buf` is an initialized slice, so a fresh `ReadBuf` over it starts with
            // nothing filled; the filled length afterwards is the byte count to report.
            let mut buf = ReadBuf::new(buf);
            let polled = Pin::new(&mut self.0).poll_read(cx, &mut buf);

            polled.map_ok(|_| buf.filled().len())
        }
    }

    impl<W: TokioAsyncRead + TokioAsyncWrite + Unpin> AsyncWrite for AsyncIoTokioAsStd<W> {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Pin::new(&mut self.0).poll_write(cx, buf)
        }
        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.0).poll_flush(cx)
        }
        /// `futures_io` "close" maps onto Tokio's `poll_shutdown`.
        fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.0).poll_shutdown(cx)
        }
    }

    /// Wraps a `futures_io` `AsyncRead + AsyncWrite` value so it can be driven through
    /// Tokio's `AsyncRead` / `AsyncWrite` traits.
    pub struct AsyncIoStdAsTokio<T: AsyncRead + AsyncWrite>(pub T);

    impl<T: AsyncRead + AsyncWrite + Unpin> Unpin for AsyncIoStdAsTokio<T> {}
    impl<R: AsyncRead + AsyncWrite + Unpin> TokioAsyncRead for AsyncIoStdAsTokio<R> {
        /// Read into the unfilled part of `buf` and advance its filled cursor.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            // `futures_io` readers need an initialized `&mut [u8]`. Give them only the
            // *unfilled* tail (zero-initialized if necessary), so that:
            // - bytes already in `buf.filled()` are not overwritten;
            // - `advance(len)` stays within the initialized region;
            // - an uninitialized `ReadBuf` does not produce an empty slice (false EOF).
            Pin::new(&mut self.get_mut().0)
                .poll_read(cx, buf.initialize_unfilled())
                .map_ok(|len| buf.advance(len))
        }
    }

    impl<W: AsyncRead + AsyncWrite + Unpin> TokioAsyncWrite for AsyncIoStdAsTokio<W> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Pin::new(&mut self.get_mut().0).poll_flush(cx)
        }

        /// Tokio "shutdown" maps onto `futures_io`'s `poll_close`.
        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Pin::new(&mut self.get_mut().0).poll_close(cx)
        }
    }
}
112
#[cfg(feature = "tokio")]
#[allow(unreachable_pub)]
mod tokio_runtime {
    use alloc::sync::Arc;
    use std::sync::Mutex;

    use futures_util::FutureExt;
    #[cfg(feature = "__quic")]
    use quinn::Runtime;
    use tokio::net::{TcpSocket, TcpStream, UdpSocket as TokioUdpSocket};
    use tokio::task::JoinSet;
    use tokio::time::timeout;

    use super::iocompat::AsyncIoTokioAsStd;
    use super::*;
    use crate::xfer::CONNECT_TIMEOUT;

    /// Background tasks spawned through a [`TokioHandle`].
    type TaskSet = JoinSet<Result<(), ProtoError>>;

    /// Spawner for background tasks on Tokio.
    ///
    /// Tasks are tracked in a [`JoinSet`] behind `Arc<Mutex<_>>`, so every clone of a
    /// handle adds to, and reaps from, the same set.
    /// NOTE(review): per tokio's `JoinSet` docs, dropping the set aborts the tasks still in
    /// it, which here happens once the last clone is dropped — confirm this is intended.
    #[derive(Clone, Default)]
    pub struct TokioHandle {
        join_set: Arc<Mutex<TaskSet>>,
    }

    impl Spawn for TokioHandle {
        fn spawn_bg<F>(&mut self, future: F)
        where
            F: Future<Output = Result<(), ProtoError>> + Send + 'static,
        {
            let mut tasks = self.join_set.lock().unwrap();
            tasks.spawn(future);
            // Discard results of tasks that have already finished so completed entries
            // do not pile up in the shared set.
            reap_tasks(&mut tasks);
        }
    }

    /// [`RuntimeProvider`] implementation backed by Tokio.
    #[derive(Clone, Default)]
    pub struct TokioRuntimeProvider(TokioHandle);

    impl TokioRuntimeProvider {
        /// Create a provider whose handle starts with an empty task set.
        pub fn new() -> Self {
            Self(TokioHandle::default())
        }
    }

    impl RuntimeProvider for TokioRuntimeProvider {
        type Handle = TokioHandle;
        type Timer = TokioTime;
        type Udp = TokioUdpSocket;
        type Tcp = AsyncIoTokioAsStd<TcpStream>;

        /// Every handle shares this provider's task set.
        fn create_handle(&self) -> Self::Handle {
            TokioHandle::clone(&self.0)
        }

        /// Connect to `server_addr` with `TCP_NODELAY` enabled, optionally binding the local
        /// side to `bind_addr` first. Gives up with `ErrorKind::TimedOut` after `wait_for`,
        /// or after `CONNECT_TIMEOUT` when `wait_for` is `None`.
        fn connect_tcp(
            &self,
            server_addr: SocketAddr,
            bind_addr: Option<SocketAddr>,
            wait_for: Option<Duration>,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
            Box::pin(async move {
                // The socket's address family must match the destination's.
                let socket = if server_addr.is_ipv4() {
                    TcpSocket::new_v4()?
                } else {
                    TcpSocket::new_v6()?
                };

                if let Some(local_addr) = bind_addr {
                    socket.bind(local_addr)?;
                }
                socket.set_nodelay(true)?;

                let wait_for = wait_for.unwrap_or(CONNECT_TIMEOUT);
                match timeout(wait_for, socket.connect(server_addr)).await {
                    Ok(connected) => connected.map(AsyncIoTokioAsStd),
                    Err(_elapsed) => Err(io::Error::new(
                        io::ErrorKind::TimedOut,
                        format!("connection to {server_addr:?} timed out after {wait_for:?}"),
                    )),
                }
            })
        }

        /// Bind a Tokio UDP socket to `local_addr`; the server address is not used here.
        fn bind_udp(
            &self,
            local_addr: SocketAddr,
            _server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
            Box::pin(TokioUdpSocket::bind(local_addr))
        }

        #[cfg(feature = "__quic")]
        fn quic_binder(&self) -> Option<&dyn QuicSocketBinder> {
            Some(&TokioQuicSocketBinder)
        }
    }

    /// Remove every already-completed task from `join_set` without awaiting or blocking.
    fn reap_tasks(join_set: &mut TaskSet) {
        // `now_or_never` polls `join_next` once:
        // - `Some(Some(_))`: a finished task was removed, so keep going;
        // - `Some(None)` (the set is empty) or `None` (nothing is ready): stop.
        while let Some(Some(_finished)) = FutureExt::now_or_never(join_set.join_next()) {}
    }

    /// Binds QUIC sockets with `std::net::UdpSocket` and wraps them for quinn's Tokio runtime.
    #[cfg(feature = "__quic")]
    struct TokioQuicSocketBinder;

    #[cfg(feature = "__quic")]
    impl QuicSocketBinder for TokioQuicSocketBinder {
        fn bind_quic(
            &self,
            local_addr: SocketAddr,
            _server_addr: SocketAddr,
        ) -> Result<Arc<dyn quinn::AsyncUdpSocket>, io::Error> {
            let std_socket = std::net::UdpSocket::bind(local_addr)?;
            quinn::TokioRuntime.wrap_udp_socket(std_socket)
        }
    }
}
235
236#[cfg(feature = "tokio")]
237pub use tokio_runtime::{TokioHandle, TokioRuntimeProvider};
238
/// Abstraction over an async runtime: task spawning, timers, and socket creation.
pub trait RuntimeProvider: Clone + Send + Sync + Unpin + 'static {
    /// Handle used to spawn background tasks (see [`Spawn`]).
    type Handle: Clone + Send + Spawn + Sync + Unpin;

    /// Timer implementation (see [`Time`]).
    type Timer: Time + Send + Unpin;

    /// UDP socket type produced by [`RuntimeProvider::bind_udp`].
    type Udp: DnsUdpSocket + Send;

    /// TCP stream type produced by [`RuntimeProvider::connect_tcp`].
    type Tcp: DnsTcpStream;

    /// Create a handle for spawning background tasks on this runtime.
    fn create_handle(&self) -> Self::Handle;

    /// Open a TCP connection to `server_addr`, optionally binding the local end to
    /// `bind_addr` first. `timeout` bounds the connect attempt; what `None` means is up to
    /// the implementation (the Tokio provider falls back to `CONNECT_TIMEOUT`).
    fn connect_tcp(
        &self,
        server_addr: SocketAddr,
        bind_addr: Option<SocketAddr>,
        timeout: Option<Duration>,
    ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>>;

    /// Bind a UDP socket to `local_addr` for talking to `server_addr`. Implementations may
    /// ignore `server_addr` (the Tokio provider does).
    fn bind_udp(
        &self,
        local_addr: SocketAddr,
        server_addr: SocketAddr,
    ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>>;

    /// Return a binder for QUIC sockets if this runtime supports QUIC; defaults to `None`.
    fn quic_binder(&self) -> Option<&dyn QuicSocketBinder> {
        None
    }
}
280
/// Empty stand-in used when the `__quic` feature is disabled, so that
/// [`RuntimeProvider::quic_binder`] still has a type to name.
#[cfg(not(feature = "__quic"))]
pub trait QuicSocketBinder {}

/// Creates the UDP sockets that QUIC connections run over.
#[cfg(feature = "__quic")]
pub trait QuicSocketBinder {
    /// Bind a UDP socket at `_local_addr` (intended for talking to `_server_addr`) and
    /// return it wrapped as a `quinn::AsyncUdpSocket`.
    fn bind_quic(
        &self,
        _local_addr: SocketAddr,
        _server_addr: SocketAddr,
    ) -> Result<Arc<dyn quinn::AsyncUdpSocket>, io::Error>;
}
296
/// Something that can run fire-and-forget background futures.
pub trait Spawn {
    /// Start `future` in the background. The caller gets no handle to its result; the
    /// Tokio implementation tracks it in a shared `JoinSet` instead.
    fn spawn_bg<F>(&mut self, future: F)
    where
        F: Future<Output = Result<(), ProtoError>> + Send + 'static;
}
304
/// A runtime that can synchronously drive a future to completion.
pub trait Executor {
    /// Construct the executor. The Tokio implementation panics if the runtime cannot be built.
    fn new() -> Self;

    /// Run `future` until it completes and return its output (the Tokio implementation
    /// delegates to `Runtime::block_on`).
    fn block_on<F: Future>(&mut self, future: F) -> F::Output;
}
316
#[cfg(feature = "tokio")]
impl Executor for Runtime {
    /// Build a Tokio runtime, panicking if it cannot be created.
    fn new() -> Self {
        // The path `Runtime::new` resolves to the inherent constructor (which returns an
        // `io::Result`), not to this trait method; the annotation makes that explicit.
        let built: io::Result<Runtime> = Runtime::new();
        built.expect("failed to create tokio runtime")
    }

    /// Block the current thread until `future` completes.
    fn block_on<F: Future>(&mut self, future: F) -> F::Output {
        // Method-call syntax on `&mut Runtime` would select this trait method again and
        // recurse, so go through the inherent `&self` version explicitly.
        let shared: &Runtime = self;
        Runtime::block_on(shared, future)
    }
}
327
/// Timer operations a runtime must provide.
#[async_trait]
pub trait Time {
    /// Wait for `duration` to pass (the Tokio implementation uses `tokio::time::sleep`).
    async fn delay_for(duration: Duration);

    /// Await `future`, returning an `std::io::Error` instead of its output if it does not
    /// finish within `duration`.
    async fn timeout<F: 'static + Future + Send>(
        duration: Duration,
        future: F,
    ) -> Result<F::Output, std::io::Error>;
}
342
/// [`Time`] implementation backed by `tokio::time`.
#[cfg(any(test, feature = "tokio"))]
#[derive(Clone, Copy, Debug)]
pub struct TokioTime;

#[cfg(any(test, feature = "tokio"))]
#[async_trait]
impl Time for TokioTime {
    /// Sleep on the Tokio timer for `duration`.
    async fn delay_for(duration: Duration) {
        let sleep = tokio::time::sleep(duration);
        sleep.await;
    }

    /// Race `future` against `duration`. Expiry is reported as `ErrorKind::TimedOut`.
    async fn timeout<F: 'static + Future + Send>(
        duration: Duration,
        future: F,
    ) -> Result<F::Output, std::io::Error> {
        match tokio::time::timeout(duration, future).await {
            Ok(output) => Ok(output),
            Err(_elapsed) => Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "future timed out",
            )),
        }
    }
}
363}