1use once_cell::sync::{Lazy, OnceCell};
27use rand::Rng;
28use std::sync::Arc;
29use std::thread::JoinHandle;
30use std::time::Duration;
31use thread_local::ThreadLocal;
32use tokio::runtime::{Builder, Handle};
33use tokio::sync::oneshot::{channel, Sender};
34
35pub enum Runtime {
41 Steal(tokio::runtime::Runtime),
42 NoSteal(NoStealRuntime),
43}
44
45impl Runtime {
46 pub fn new_steal(threads: usize, name: &str) -> Self {
48 Self::Steal(
49 Builder::new_multi_thread()
50 .enable_all()
51 .worker_threads(threads)
52 .thread_name(name)
53 .build()
54 .unwrap(),
55 )
56 }
57
58 pub fn new_no_steal(threads: usize, name: &str) -> Self {
60 Self::NoSteal(NoStealRuntime::new(threads, name))
61 }
62
63 pub fn get_handle(&self) -> &Handle {
69 match self {
70 Self::Steal(r) => r.handle(),
71 Self::NoSteal(r) => r.get_runtime(),
72 }
73 }
74
75 pub fn shutdown_timeout(self, timeout: Duration) {
78 match self {
79 Self::Steal(r) => r.shutdown_timeout(timeout),
80 Self::NoSteal(r) => r.shutdown_timeout(timeout),
81 }
82 }
83}
84
85static CURRENT_HANDLE: Lazy<ThreadLocal<Pools>> = Lazy::new(ThreadLocal::new);
87
88pub fn current_handle() -> Handle {
93 if let Some(pools) = CURRENT_HANDLE.get() {
94 let pools = pools.get().unwrap();
96 let mut rng = rand::thread_rng();
97 let index = rng.gen_range(0..pools.len());
98 pools[index].clone()
99 } else {
100 Handle::current()
102 }
103}
104
105type Control = (Sender<Duration>, JoinHandle<()>);
106type Pools = Arc<OnceCell<Box<[Handle]>>>;
107
108pub struct NoStealRuntime {
110 threads: usize,
111 name: String,
112 pools: Arc<OnceCell<Box<[Handle]>>>,
115 controls: OnceCell<Vec<Control>>,
116}
117
118impl NoStealRuntime {
119 pub fn new(threads: usize, name: &str) -> Self {
121 assert!(threads != 0);
122 NoStealRuntime {
123 threads,
124 name: name.to_string(),
125 pools: Arc::new(OnceCell::new()),
126 controls: OnceCell::new(),
127 }
128 }
129
130 fn init_pools(&self) -> (Box<[Handle]>, Vec<Control>) {
131 let mut pools = Vec::with_capacity(self.threads);
132 let mut controls = Vec::with_capacity(self.threads);
133 for _ in 0..self.threads {
134 let rt = Builder::new_current_thread().enable_all().build().unwrap();
135 let handler = rt.handle().clone();
136 let (tx, rx) = channel::<Duration>();
137 let pools_ref = self.pools.clone();
138 let join = std::thread::Builder::new()
139 .name(self.name.clone())
140 .spawn(move || {
141 CURRENT_HANDLE.get_or(|| pools_ref);
142 if let Ok(timeout) = rt.block_on(rx) {
143 rt.shutdown_timeout(timeout);
144 } })
146 .unwrap();
147 pools.push(handler);
148 controls.push((tx, join));
149 }
150
151 (pools.into_boxed_slice(), controls)
152 }
153
154 pub fn get_runtime(&self) -> &Handle {
156 let mut rng = rand::thread_rng();
157
158 let index = rng.gen_range(0..self.threads);
159 self.get_runtime_at(index)
160 }
161
162 pub fn threads(&self) -> usize {
164 self.threads
165 }
166
167 fn get_pools(&self) -> &[Handle] {
168 if let Some(p) = self.pools.get() {
169 p
170 } else {
171 let (pools, controls) = self.init_pools();
173 match self.pools.try_insert(pools) {
175 Ok(p) => {
176 self.controls.set(controls).unwrap();
178 p
179 }
180 Err((p, _my_pools)) => p,
182 }
183 }
184 }
185
186 pub fn get_runtime_at(&self, index: usize) -> &Handle {
188 let pools = self.get_pools();
189 &pools[index]
190 }
191
192 pub fn shutdown_timeout(mut self, timeout: Duration) {
195 if let Some(controls) = self.controls.take() {
196 let (txs, joins): (Vec<Sender<_>>, Vec<JoinHandle<()>>) = controls.into_iter().unzip();
197 for tx in txs {
198 let _ = tx.send(timeout); }
200 for join in joins {
201 let _ = join.join(); }
203 } }
205
206 }
208
/// A work-stealing runtime can run a future, and `current_handle()` inside it
/// can spawn a task that completes.
#[test]
fn test_steal_runtime() {
    use tokio::time::{sleep, Duration};

    let runtime = Runtime::new_steal(2, "test");
    let outcome = runtime.get_handle().block_on(async {
        sleep(Duration::from_secs(1)).await;
        let spawned = current_handle().spawn(async {
            sleep(Duration::from_secs(1)).await;
        });
        spawned.await.unwrap();
        1
    });

    assert_eq!(outcome, 1);
}
227
/// A non-stealing runtime can run a future, and `current_handle()` inside it
/// can spawn a task that completes.
#[test]
fn test_no_steal_runtime() {
    use tokio::time::{sleep, Duration};

    let runtime = Runtime::new_no_steal(2, "test");
    let outcome = runtime.get_handle().block_on(async {
        sleep(Duration::from_secs(1)).await;
        let spawned = current_handle().spawn(async {
            sleep(Duration::from_secs(1)).await;
        });
        spawned.await.unwrap();
        1
    });

    assert_eq!(outcome, 1);
}
246
/// A non-stealing runtime that has spawned its workers and run work can be
/// shut down with a timeout.
#[test]
fn test_no_steal_shutdown() {
    use tokio::time::{sleep, Duration};

    let runtime = Runtime::new_no_steal(2, "test");
    let outcome = runtime.get_handle().block_on(async {
        sleep(Duration::from_secs(1)).await;
        let spawned = current_handle().spawn(async {
            sleep(Duration::from_secs(1)).await;
        });
        spawned.await.unwrap();
        1
    });
    assert_eq!(outcome, 1);

    runtime.shutdown_timeout(Duration::from_secs(1));
}