1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7#[cfg(feature = "_rt-async-std")]
8pub mod rt_async_std;
9
10#[cfg(feature = "_rt-tokio")]
11pub mod rt_tokio;
12
13#[derive(Debug, thiserror::Error)]
14#[error("operation timed out")]
15pub struct TimeoutError(());
16
/// Handle to a task started by [`spawn`] or [`spawn_blocking`]; awaiting it yields the
/// task's output.
///
/// The variant records which runtime accepted the task: Tokio when that feature is
/// compiled in and a Tokio runtime handle was available at spawn time, otherwise
/// async-std (when that feature is enabled).
pub enum JoinHandle<T> {
    /// Task handed to async-std.
    #[cfg(feature = "_rt-async-std")]
    AsyncStd(async_std::task::JoinHandle<T>),
    /// Task spawned on the Tokio runtime that was current when it was spawned.
    #[cfg(feature = "_rt-tokio")]
    Tokio(tokio::task::JoinHandle<T>),
    // Keeps the type parameter `T` in use when both runtime variants are compiled out.
    // `PhantomData<fn() -> T>` marks `T` as used without implying the enum owns a `T`.
    // `spawn`/`spawn_blocking` never construct this variant; polling it hits
    // `unreachable!`.
    _Phantom(PhantomData<fn() -> T>),
}
25
26pub async fn timeout<F: Future>(duration: Duration, f: F) -> Result<F::Output, TimeoutError> {
27 #[cfg(feature = "_rt-tokio")]
28 if rt_tokio::available() {
29 return tokio::time::timeout(duration, f)
30 .await
31 .map_err(|_| TimeoutError(()));
32 }
33
34 #[cfg(feature = "_rt-async-std")]
35 {
36 async_std::future::timeout(duration, f)
37 .await
38 .map_err(|_| TimeoutError(()))
39 }
40
41 #[cfg(not(feature = "_rt-async-std"))]
42 missing_rt((duration, f))
43}
44
45pub async fn sleep(duration: Duration) {
46 #[cfg(feature = "_rt-tokio")]
47 if rt_tokio::available() {
48 return tokio::time::sleep(duration).await;
49 }
50
51 #[cfg(feature = "_rt-async-std")]
52 {
53 async_std::task::sleep(duration).await
54 }
55
56 #[cfg(not(feature = "_rt-async-std"))]
57 missing_rt(duration)
58}
59
60#[track_caller]
61pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
62where
63 F: Future + Send + 'static,
64 F::Output: Send + 'static,
65{
66 #[cfg(feature = "_rt-tokio")]
67 if let Ok(handle) = tokio::runtime::Handle::try_current() {
68 return JoinHandle::Tokio(handle.spawn(fut));
69 }
70
71 #[cfg(feature = "_rt-async-std")]
72 {
73 JoinHandle::AsyncStd(async_std::task::spawn(fut))
74 }
75
76 #[cfg(not(feature = "_rt-async-std"))]
77 missing_rt(fut)
78}
79
80#[track_caller]
81pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
82where
83 F: FnOnce() -> R + Send + 'static,
84 R: Send + 'static,
85{
86 #[cfg(feature = "_rt-tokio")]
87 if let Ok(handle) = tokio::runtime::Handle::try_current() {
88 return JoinHandle::Tokio(handle.spawn_blocking(f));
89 }
90
91 #[cfg(feature = "_rt-async-std")]
92 {
93 JoinHandle::AsyncStd(async_std::task::spawn_blocking(f))
94 }
95
96 #[cfg(not(feature = "_rt-async-std"))]
97 missing_rt(f)
98}
99
100pub async fn yield_now() {
101 #[cfg(feature = "_rt-tokio")]
102 if rt_tokio::available() {
103 return tokio::task::yield_now().await;
104 }
105
106 #[cfg(feature = "_rt-async-std")]
107 {
108 async_std::task::yield_now().await;
109 }
110
111 #[cfg(not(feature = "_rt-async-std"))]
112 missing_rt(())
113}
114
115#[track_caller]
116pub fn test_block_on<F: Future>(f: F) -> F::Output {
117 #[cfg(feature = "_rt-tokio")]
118 {
119 return tokio::runtime::Builder::new_current_thread()
120 .enable_all()
121 .build()
122 .expect("failed to start Tokio runtime")
123 .block_on(f);
124 }
125
126 #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
127 {
128 async_std::task::block_on(f)
129 }
130
131 #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]
132 {
133 missing_rt(f)
134 }
135}
136
/// Panics because no usable async runtime is available; never returns.
///
/// `_unused` is only consumed so callers can pass along values they would otherwise leave
/// unused in their fallback branch. The message depends on whether `_rt-tokio` is enabled:
/// if it is, the problem is reported as a missing Tokio context; otherwise, as missing
/// runtime features.
#[track_caller]
pub fn missing_rt<T>(_unused: T) -> ! {
    if cfg!(feature = "_rt-tokio") {
        panic!("this functionality requires a Tokio context")
    } else {
        panic!("either the `runtime-async-std` or `runtime-tokio` feature must be enabled")
    }
}
145
146impl<T: Send + 'static> Future for JoinHandle<T> {
147 type Output = T;
148
149 #[track_caller]
150 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
151 match &mut *self {
152 #[cfg(feature = "_rt-async-std")]
153 Self::AsyncStd(handle) => Pin::new(handle).poll(cx),
154 #[cfg(feature = "_rt-tokio")]
155 Self::Tokio(handle) => Pin::new(handle)
156 .poll(cx)
157 .map(|res| res.expect("spawned task panicked")),
158 Self::_Phantom(_) => {
159 let _ = cx;
160 unreachable!("runtime should have been checked on spawn")
161 }
162 }
163 }
164}