hickory_proto/
runtime.rs

1//! Abstractions to deal with different async runtimes.
2
3use alloc::boxed::Box;
4#[cfg(feature = "__quic")]
5use alloc::sync::Arc;
6use core::future::Future;
7use core::marker::Send;
8use core::pin::Pin;
9use core::time::Duration;
10use std::io;
11use std::net::SocketAddr;
12
13use async_trait::async_trait;
14#[cfg(any(test, feature = "tokio"))]
15use tokio::runtime::Runtime;
16#[cfg(any(test, feature = "tokio"))]
17use tokio::task::JoinHandle;
18
19use crate::error::ProtoError;
20use crate::tcp::DnsTcpStream;
21use crate::udp::DnsUdpSocket;
22
/// Spawn `background` as a task on the given Tokio `runtime`.
///
/// The returned [`JoinHandle`] can be awaited to obtain the future's output.
#[cfg(any(test, feature = "tokio"))]
pub fn spawn_bg<F, R>(runtime: &Runtime, background: F) -> JoinHandle<R>
where
    F: Future<Output = R> + Send + 'static,
    R: Send + 'static,
{
    Runtime::spawn(runtime, background)
}
31
#[cfg(feature = "tokio")]
#[doc(hidden)]
pub mod iocompat {
    //! Adapters between the `tokio::io` and `futures_io` read/write traits.

    use core::pin::Pin;
    use core::task::{Context, Poll};
    use std::io;

    use futures_io::{AsyncRead, AsyncWrite};
    use tokio::io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite, ReadBuf};

    /// Conversion from `tokio::io::{AsyncRead, AsyncWrite}` to `futures_io::{AsyncRead, AsyncWrite}`
    pub struct AsyncIoTokioAsStd<T: TokioAsyncRead + TokioAsyncWrite>(pub T);

    impl<T: TokioAsyncRead + TokioAsyncWrite + Unpin> Unpin for AsyncIoTokioAsStd<T> {}
    impl<R: TokioAsyncRead + TokioAsyncWrite + Unpin> AsyncRead for AsyncIoTokioAsStd<R> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // A fresh `ReadBuf` over `buf` starts with nothing filled, so after the
            // poll its filled length is exactly the number of bytes just read.
            let mut buf = ReadBuf::new(buf);
            let polled = Pin::new(&mut self.0).poll_read(cx, &mut buf);

            polled.map_ok(|_| buf.filled().len())
        }
    }

    impl<W: TokioAsyncRead + TokioAsyncWrite + Unpin> AsyncWrite for AsyncIoTokioAsStd<W> {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Pin::new(&mut self.0).poll_write(cx, buf)
        }
        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.0).poll_flush(cx)
        }
        /// `futures_io`'s `close` maps onto tokio's `shutdown`.
        fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.0).poll_shutdown(cx)
        }
    }

    /// Conversion from `futures_io::{AsyncRead, AsyncWrite}` to `tokio::io::{AsyncRead, AsyncWrite}`
    pub struct AsyncIoStdAsTokio<T: AsyncRead + AsyncWrite>(pub T);

    impl<T: AsyncRead + AsyncWrite + Unpin> Unpin for AsyncIoStdAsTokio<T> {}
    impl<R: AsyncRead + AsyncWrite + Unpin> TokioAsyncRead for AsyncIoStdAsTokio<R> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            // `futures_io` readers need a plain `&mut [u8]`, so hand over only the
            // *unfilled* tail of `buf`, zero-initializing it first. Passing the
            // initialized region instead would overwrite bytes that are already
            // filled (and can make `advance` panic). For a buffer that starts out
            // uninitialized it would also pass an empty slice, which reads as EOF.
            let unfilled = buf.initialize_unfilled();
            Pin::new(&mut self.get_mut().0)
                .poll_read(cx, unfilled)
                .map_ok(|len| buf.advance(len))
        }
    }

    impl<W: AsyncRead + AsyncWrite + Unpin> TokioAsyncWrite for AsyncIoStdAsTokio<W> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Pin::new(&mut self.get_mut().0).poll_flush(cx)
        }

        /// tokio's `shutdown` maps onto `futures_io`'s `close`.
        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Pin::new(&mut self.get_mut().0).poll_close(cx)
        }
    }
}
112
#[cfg(feature = "tokio")]
#[allow(unreachable_pub)]
mod tokio_runtime {
    use alloc::sync::Arc;
    use std::sync::Mutex;

    use futures_util::FutureExt;
    #[cfg(feature = "__quic")]
    use quinn::Runtime;
    use tokio::net::{TcpSocket, TcpStream, UdpSocket as TokioUdpSocket};
    use tokio::task::JoinSet;
    use tokio::time::timeout;

    use super::iocompat::AsyncIoTokioAsStd;
    use super::*;
    use crate::xfer::CONNECT_TIMEOUT;

    /// A cloneable handle to the Tokio runtime.
    ///
    /// Every clone shares the same `JoinSet`, so background tasks spawned through
    /// any clone are tracked in one place.
    #[derive(Clone, Default)]
    pub struct TokioHandle {
        join_set: Arc<Mutex<JoinSet<Result<(), ProtoError>>>>,
    }

    impl Spawn for TokioHandle {
        fn spawn_bg<F>(&mut self, future: F)
        where
            F: Future<Output = Result<(), ProtoError>> + Send + 'static,
        {
            let mut guard = self.join_set.lock().unwrap();
            guard.spawn(future);
            // Drop the results of any tasks that already finished, so the set only
            // keeps entries for tasks that may still be running.
            reap_tasks(&mut *guard);
        }
    }

    /// [`RuntimeProvider`] backed by the Tokio runtime.
    #[derive(Clone, Default)]
    pub struct TokioRuntimeProvider(TokioHandle);

    impl TokioRuntimeProvider {
        /// Create a Tokio runtime provider with an empty task set.
        pub fn new() -> Self {
            Self(TokioHandle::default())
        }
    }

    impl RuntimeProvider for TokioRuntimeProvider {
        type Handle = TokioHandle;
        type Timer = TokioTime;
        type Udp = TokioUdpSocket;
        type Tcp = AsyncIoTokioAsStd<TcpStream>;

        fn create_handle(&self) -> Self::Handle {
            TokioHandle::clone(&self.0)
        }

        fn connect_tcp(
            &self,
            server_addr: SocketAddr,
            bind_addr: Option<SocketAddr>,
            wait_for: Option<Duration>,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
            Box::pin(async move {
                // The socket's address family must match the server address.
                let socket = if server_addr.is_ipv4() {
                    TcpSocket::new_v4()?
                } else {
                    TcpSocket::new_v6()?
                };

                if let Some(local) = bind_addr {
                    socket.bind(local)?;
                }
                socket.set_nodelay(true)?;

                let wait_for = wait_for.unwrap_or(CONNECT_TIMEOUT);
                match timeout(wait_for, socket.connect(server_addr)).await {
                    Err(_elapsed) => Err(io::Error::new(
                        io::ErrorKind::TimedOut,
                        format!("connection to {server_addr:?} timed out after {wait_for:?}"),
                    )),
                    Ok(connected) => connected.map(AsyncIoTokioAsStd),
                }
            })
        }

        fn bind_udp(
            &self,
            local_addr: SocketAddr,
            _server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
            Box::pin(TokioUdpSocket::bind(local_addr))
        }

        #[cfg(feature = "__quic")]
        fn quic_binder(&self) -> Option<&dyn QuicSocketBinder> {
            Some(&TokioQuicSocketBinder)
        }
    }

    /// Discard every already-completed task in `join_set` without blocking.
    ///
    /// Stops as soon as the next task is still pending (`now_or_never` yields
    /// `None`) or the set is empty (`join_next` yields `None`).
    fn reap_tasks(join_set: &mut JoinSet<Result<(), ProtoError>>) {
        while let Some(Some(_)) = FutureExt::now_or_never(join_set.join_next()) {}
    }

    /// QUIC socket binder that wraps a std UDP socket for quinn's Tokio runtime.
    #[cfg(feature = "__quic")]
    struct TokioQuicSocketBinder;

    #[cfg(feature = "__quic")]
    impl QuicSocketBinder for TokioQuicSocketBinder {
        fn bind_quic(
            &self,
            local_addr: SocketAddr,
            _server_addr: SocketAddr,
        ) -> Result<Arc<dyn quinn::AsyncUdpSocket>, io::Error> {
            let std_socket = std::net::UdpSocket::bind(local_addr)?;
            quinn::TokioRuntime.wrap_udp_socket(std_socket)
        }
    }
}
235
236#[cfg(feature = "tokio")]
237pub use tokio_runtime::{TokioHandle, TokioRuntimeProvider};
238
239/// RuntimeProvider defines which async runtime that handles IO and timers.
240pub trait RuntimeProvider: Clone + Send + Sync + Unpin + 'static {
241    /// Handle to the executor;
242    type Handle: Clone + Send + Spawn + Sync + Unpin;
243
244    /// Timer
245    type Timer: Time + Send + Unpin;
246
247    /// UdpSocket
248    type Udp: DnsUdpSocket + Send;
249
250    /// TcpStream
251    type Tcp: DnsTcpStream;
252
253    /// Create a runtime handle
254    fn create_handle(&self) -> Self::Handle;
255
256    /// Create a TCP connection with custom configuration.
257    fn connect_tcp(
258        &self,
259        server_addr: SocketAddr,
260        bind_addr: Option<SocketAddr>,
261        timeout: Option<Duration>,
262    ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>>;
263
264    /// Create a UDP socket bound to `local_addr`. The returned value should **not** be connected to `server_addr`.
265    /// *Notice: the future should be ready once returned at best effort. Otherwise UDP DNS may need much more retries.*
266    fn bind_udp(
267        &self,
268        local_addr: SocketAddr,
269        server_addr: SocketAddr,
270    ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>>;
271
272    /// Yields an object that knows how to bind a QUIC socket.
273    //
274    // Use some indirection here to avoid exposing the `quinn` crate in the public API
275    // even for runtimes that might not (want to) provide QUIC support.
276    fn quic_binder(&self) -> Option<&dyn QuicSocketBinder> {
277        None
278    }
279}
280
/// Placeholder QUIC socket binder, used when the `quinn` dependency is not available.
///
/// It has no methods; it exists so that [`RuntimeProvider::quic_binder`] keeps the
/// same signature whether or not QUIC support is compiled in.
#[cfg(not(feature = "__quic"))]
pub trait QuicSocketBinder {}
284
/// Binds UDP sockets for QUIC connections.
///
/// Implement this trait to customize how the sockets used for QUIC are created.
#[cfg(feature = "__quic")]
pub trait QuicSocketBinder {
    /// Bind a UDP socket suitable for QUIC, given the local address and the
    /// address of the server to be contacted.
    fn bind_quic(
        &self,
        _local_addr: SocketAddr,
        _server_addr: SocketAddr,
    ) -> Result<Arc<dyn quinn::AsyncUdpSocket>, io::Error>;
}
296
297/// A type defines the Handle which can spawn future.
298pub trait Spawn {
299    /// Spawn a future in the background
300    fn spawn_bg<F>(&mut self, future: F)
301    where
302        F: Future<Output = Result<(), ProtoError>> + Send + 'static;
303}
304
/// A generic executor that can drive a future to completion.
// Exists so the tests in the tests module can run on different executors.
// Fuchsia OS depends on it, so take care when changing it.
pub trait Executor {
    /// Construct the executor.
    fn new() -> Self;

    /// Run `future` until it completes, and return its output.
    fn block_on<F: Future>(&mut self, future: F) -> F::Output;
}
316
/// Lets a Tokio [`Runtime`] serve as an [`Executor`].
#[cfg(feature = "tokio")]
impl Executor for Runtime {
    /// Build a Tokio runtime via `Runtime::new`, panicking if it cannot be created.
    fn new() -> Self {
        Runtime::new().expect("failed to create tokio runtime")
    }

    /// Delegate to the runtime's own `block_on`.
    fn block_on<F: Future>(&mut self, future: F) -> F::Output {
        Runtime::block_on(self, future)
    }
}
327
328/// Generic Time for Delay and Timeout.
329// This trait is created to allow to use different types of time systems. It's used in Fuchsia OS, please be mindful when update it.
330#[async_trait]
331pub trait Time {
332    /// Return a type that implements `Future` that will wait until the specified duration has
333    /// elapsed.
334    async fn delay_for(duration: Duration);
335
336    /// Return a type that implement `Future` to complete before the specified duration has elapsed.
337    async fn timeout<F: 'static + Future + Send>(
338        duration: Duration,
339        future: F,
340    ) -> Result<F::Output, std::io::Error>;
341}
342
/// [`Time`] implementation built on `tokio::time::{sleep, timeout}`.
#[cfg(any(test, feature = "tokio"))]
#[derive(Debug, Clone, Copy)]
pub struct TokioTime;
347
#[cfg(any(test, feature = "tokio"))]
#[async_trait]
impl Time for TokioTime {
    async fn delay_for(duration: Duration) {
        let pause = tokio::time::sleep(duration);
        pause.await;
    }

    async fn timeout<F: 'static + Future + Send>(
        duration: Duration,
        future: F,
    ) -> Result<F::Output, std::io::Error> {
        // Translate tokio's `Elapsed` into an I/O `TimedOut` error.
        match tokio::time::timeout(duration, future).await {
            Ok(output) => Ok(output),
            Err(_elapsed) => Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "future timed out",
            )),
        }
    }
}
363}