sqlx_core/rt/
mod.rs

1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7#[cfg(feature = "_rt-async-std")]
8pub mod rt_async_std;
9
10#[cfg(feature = "_rt-tokio")]
11pub mod rt_tokio;
12
13#[derive(Debug, thiserror::Error)]
14#[error("operation timed out")]
15pub struct TimeoutError(());
16
/// Handle to a spawned task, wrapping whichever runtime-specific handle was
/// produced when the task was started.
///
/// Runtime variants only exist when the matching Cargo feature is enabled.
pub enum JoinHandle<T> {
    /// Task running on async-std.
    #[cfg(feature = "_rt-async-std")]
    AsyncStd(async_std::task::JoinHandle<T>),
    /// Task running on Tokio.
    #[cfg(feature = "_rt-tokio")]
    Tokio(tokio::task::JoinHandle<T>),
    // Always-present variant so that `T` is used even when no runtime variant
    // is compiled in. It holds `PhantomData<fn() -> T>` instead of
    // `PhantomData<T>` because the latter is only `Unpin` when `T: Unpin`,
    // while a function pointer is `Unpin` for every `T`.
    _Phantom(PhantomData<fn() -> T>),
}
25
26pub async fn timeout<F: Future>(duration: Duration, f: F) -> Result<F::Output, TimeoutError> {
27    #[cfg(feature = "_rt-tokio")]
28    if rt_tokio::available() {
29        return tokio::time::timeout(duration, f)
30            .await
31            .map_err(|_| TimeoutError(()));
32    }
33
34    #[cfg(feature = "_rt-async-std")]
35    {
36        async_std::future::timeout(duration, f)
37            .await
38            .map_err(|_| TimeoutError(()))
39    }
40
41    #[cfg(not(feature = "_rt-async-std"))]
42    missing_rt((duration, f))
43}
44
45pub async fn sleep(duration: Duration) {
46    #[cfg(feature = "_rt-tokio")]
47    if rt_tokio::available() {
48        return tokio::time::sleep(duration).await;
49    }
50
51    #[cfg(feature = "_rt-async-std")]
52    {
53        async_std::task::sleep(duration).await
54    }
55
56    #[cfg(not(feature = "_rt-async-std"))]
57    missing_rt(duration)
58}
59
60#[track_caller]
61pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
62where
63    F: Future + Send + 'static,
64    F::Output: Send + 'static,
65{
66    #[cfg(feature = "_rt-tokio")]
67    if let Ok(handle) = tokio::runtime::Handle::try_current() {
68        return JoinHandle::Tokio(handle.spawn(fut));
69    }
70
71    #[cfg(feature = "_rt-async-std")]
72    {
73        JoinHandle::AsyncStd(async_std::task::spawn(fut))
74    }
75
76    #[cfg(not(feature = "_rt-async-std"))]
77    missing_rt(fut)
78}
79
80#[track_caller]
81pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
82where
83    F: FnOnce() -> R + Send + 'static,
84    R: Send + 'static,
85{
86    #[cfg(feature = "_rt-tokio")]
87    if let Ok(handle) = tokio::runtime::Handle::try_current() {
88        return JoinHandle::Tokio(handle.spawn_blocking(f));
89    }
90
91    #[cfg(feature = "_rt-async-std")]
92    {
93        JoinHandle::AsyncStd(async_std::task::spawn_blocking(f))
94    }
95
96    #[cfg(not(feature = "_rt-async-std"))]
97    missing_rt(f)
98}
99
100pub async fn yield_now() {
101    #[cfg(feature = "_rt-tokio")]
102    if rt_tokio::available() {
103        return tokio::task::yield_now().await;
104    }
105
106    #[cfg(feature = "_rt-async-std")]
107    {
108        async_std::task::yield_now().await;
109    }
110
111    #[cfg(not(feature = "_rt-async-std"))]
112    missing_rt(())
113}
114
115#[track_caller]
116pub fn test_block_on<F: Future>(f: F) -> F::Output {
117    #[cfg(feature = "_rt-tokio")]
118    {
119        return tokio::runtime::Builder::new_current_thread()
120            .enable_all()
121            .build()
122            .expect("failed to start Tokio runtime")
123            .block_on(f);
124    }
125
126    #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
127    {
128        async_std::task::block_on(f)
129    }
130
131    #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]
132    {
133        missing_rt(f)
134    }
135}
136
/// Panics to report that no usable async runtime is available.
///
/// The message depends on how the crate was compiled:
/// * with `_rt-tokio` enabled, it says a Tokio context is required;
/// * otherwise, it says that one of the runtime features must be enabled.
///
/// `_unused` is accepted and dropped without being inspected; presumably it
/// exists so call sites can hand over arguments that would otherwise go
/// unused in the feature configuration that reaches this call.
///
/// # Panics
///
/// Always.
#[track_caller]
pub fn missing_rt<T>(_unused: T) -> ! {
    let tokio_compiled_in = cfg!(feature = "_rt-tokio");

    if !tokio_compiled_in {
        panic!("either the `runtime-async-std` or `runtime-tokio` feature must be enabled")
    }

    panic!("this functionality requires a Tokio context")
}
145
146impl<T: Send + 'static> Future for JoinHandle<T> {
147    type Output = T;
148
149    #[track_caller]
150    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
151        match &mut *self {
152            #[cfg(feature = "_rt-async-std")]
153            Self::AsyncStd(handle) => Pin::new(handle).poll(cx),
154            #[cfg(feature = "_rt-tokio")]
155            Self::Tokio(handle) => Pin::new(handle)
156                .poll(cx)
157                .map(|res| res.expect("spawned task panicked")),
158            Self::_Phantom(_) => {
159                let _ = cx;
160                unreachable!("runtime should have been checked on spawn")
161            }
162        }
163    }
164}