1use std::{
2 any::Any,
3 cell::RefCell,
4 collections::VecDeque,
5 future::{Future, poll_fn, ready},
6 io,
7 marker::PhantomData,
8 panic::AssertUnwindSafe,
9 rc::Rc,
10 sync::Arc,
11 task::{Context, Poll},
12 time::Duration,
13};
14
15use async_task::{Runnable, Task};
16use compio_buf::IntoInner;
17use compio_driver::{
18 AsRawFd, Key, NotifyHandle, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd, op::Asyncify,
19};
20use compio_log::{debug, instrument};
21use crossbeam_queue::SegQueue;
22use futures_util::{FutureExt, future::Either};
23
24pub(crate) mod op;
25#[cfg(feature = "time")]
26pub(crate) mod time;
27
28mod send_wrapper;
29use send_wrapper::SendWrapper;
30
31#[cfg(feature = "time")]
32use crate::runtime::time::{TimerFuture, TimerRuntime};
33use crate::{BufResult, runtime::op::OpFuture};
34
35scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);
36
37pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;
40
41struct RunnableQueue {
42 local_runnables: SendWrapper<RefCell<VecDeque<Runnable>>>,
43 sync_runnables: SegQueue<Runnable>,
44}
45
46impl RunnableQueue {
47 pub fn new() -> Self {
48 Self {
49 local_runnables: SendWrapper::new(RefCell::new(VecDeque::new())),
50 sync_runnables: SegQueue::new(),
51 }
52 }
53
54 pub fn schedule(&self, runnable: Runnable, handle: &NotifyHandle) {
55 if let Some(runnables) = self.local_runnables.get() {
56 runnables.borrow_mut().push_back(runnable);
57 } else {
58 self.sync_runnables.push(runnable);
59 handle.notify().ok();
60 }
61 }
62
63 pub unsafe fn run(&self, event_interval: usize) -> bool {
65 let local_runnables = self.local_runnables.get_unchecked();
66 for _i in 0..event_interval {
67 let next_task = local_runnables.borrow_mut().pop_front();
68 let has_local_task = next_task.is_some();
69 if let Some(task) = next_task {
70 task.run();
71 }
72 let has_sync_task = !self.sync_runnables.is_empty();
74 if has_sync_task {
75 if let Some(task) = self.sync_runnables.pop() {
76 task.run();
77 }
78 } else if !has_local_task {
79 break;
80 }
81 }
82 !(local_runnables.borrow_mut().is_empty() && self.sync_runnables.is_empty())
83 }
84}
85
86pub struct Runtime {
89 driver: RefCell<Proactor>,
90 runnables: Arc<RunnableQueue>,
91 #[cfg(feature = "time")]
92 timer_runtime: RefCell<TimerRuntime>,
93 event_interval: usize,
94 _p: PhantomData<Rc<VecDeque<Runnable>>>,
97}
98
99impl Runtime {
100 pub fn new() -> io::Result<Self> {
102 Self::builder().build()
103 }
104
105 pub fn builder() -> RuntimeBuilder {
107 RuntimeBuilder::new()
108 }
109
110 fn with_builder(builder: &RuntimeBuilder) -> io::Result<Self> {
111 Ok(Self {
112 driver: RefCell::new(builder.proactor_builder.build()?),
113 runnables: Arc::new(RunnableQueue::new()),
114 #[cfg(feature = "time")]
115 timer_runtime: RefCell::new(TimerRuntime::new()),
116 event_interval: builder.event_interval,
117 _p: PhantomData,
118 })
119 }
120
121 pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
124 if CURRENT_RUNTIME.is_set() {
125 Ok(CURRENT_RUNTIME.with(f))
126 } else {
127 Err(f)
128 }
129 }
130
131 pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
137 #[cold]
138 fn not_in_compio_runtime() -> ! {
139 panic!("not in a compio runtime")
140 }
141
142 if CURRENT_RUNTIME.is_set() {
143 CURRENT_RUNTIME.with(f)
144 } else {
145 not_in_compio_runtime()
146 }
147 }
148
149 pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
152 CURRENT_RUNTIME.set(self, f)
153 }
154
155 pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
161 let runnables = self.runnables.clone();
162 let handle = self
163 .driver
164 .borrow()
165 .handle()
166 .expect("cannot create notify handle of the proactor");
167 let schedule = move |runnable| {
168 runnables.schedule(runnable, &handle);
169 };
170 let (runnable, task) = async_task::spawn_unchecked(future, schedule);
171 runnable.schedule();
172 task
173 }
174
175 pub fn run(&self) -> bool {
181 unsafe { self.runnables.run(self.event_interval) }
183 }
184
185 pub fn block_on<F: Future>(&self, future: F) -> F::Output {
187 CURRENT_RUNTIME.set(self, || {
188 let mut result = None;
189 unsafe { self.spawn_unchecked(async { result = Some(future.await) }) }.detach();
190 loop {
191 let remaining_tasks = self.run();
192 if let Some(result) = result.take() {
193 return result;
194 }
195 if remaining_tasks {
196 self.poll_with(Some(Duration::ZERO));
197 } else {
198 self.poll();
199 }
200 }
201 })
202 }
203
204 pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
209 unsafe { self.spawn_unchecked(AssertUnwindSafe(future).catch_unwind()) }
210 }
211
212 pub fn spawn_blocking<T: Send + 'static>(
216 &self,
217 f: impl (FnOnce() -> T) + Send + Sync + 'static,
218 ) -> JoinHandle<T> {
219 let op = Asyncify::new(move || {
220 let res = std::panic::catch_unwind(AssertUnwindSafe(f));
221 BufResult(Ok(0), res)
222 });
223 let closure = async move {
224 let mut op = op;
225 loop {
226 match self.submit(op).await {
227 BufResult(Ok(_), rop) => break rop.into_inner(),
228 BufResult(Err(_), rop) => op = rop,
229 }
230 let mut yielded = false;
233 poll_fn(|cx| {
234 if yielded {
235 Poll::Ready(())
236 } else {
237 yielded = true;
238 cx.waker().wake_by_ref();
239 Poll::Pending
240 }
241 })
242 .await;
243 }
244 };
245 unsafe { self.spawn_unchecked(closure) }
248 }
249
250 pub fn attach(&self, fd: RawFd) -> io::Result<()> {
255 self.driver.borrow_mut().attach(fd)
256 }
257
258 fn submit_raw<T: OpCode + 'static>(&self, op: T) -> PushEntry<Key<T>, BufResult<usize, T>> {
259 self.driver.borrow_mut().push(op)
260 }
261
262 pub fn submit<T: OpCode + 'static>(&self, op: T) -> impl Future<Output = BufResult<usize, T>> {
266 self.submit_with_flags(op).map(|(res, _)| res)
267 }
268
269 pub fn submit_with_flags<T: OpCode + 'static>(
276 &self,
277 op: T,
278 ) -> impl Future<Output = (BufResult<usize, T>, u32)> {
279 match self.submit_raw(op) {
280 PushEntry::Pending(user_data) => Either::Left(OpFuture::new(user_data)),
281 PushEntry::Ready(res) => {
282 Either::Right(ready((res, 0)))
285 }
286 }
287 }
288
289 #[cfg(feature = "time")]
290 pub(crate) fn create_timer(&self, delay: std::time::Duration) -> impl Future<Output = ()> {
291 let mut timer_runtime = self.timer_runtime.borrow_mut();
292 if let Some(key) = timer_runtime.insert(delay) {
293 Either::Left(TimerFuture::new(key))
294 } else {
295 Either::Right(std::future::ready(()))
296 }
297 }
298
299 pub(crate) fn cancel_op<T: OpCode>(&self, op: Key<T>) {
300 self.driver.borrow_mut().cancel(op);
301 }
302
303 #[cfg(feature = "time")]
304 pub(crate) fn cancel_timer(&self, key: usize) {
305 self.timer_runtime.borrow_mut().cancel(key);
306 }
307
308 pub(crate) fn poll_task<T: OpCode>(
309 &self,
310 cx: &mut Context,
311 op: Key<T>,
312 ) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
313 instrument!(compio_log::Level::DEBUG, "poll_task", ?op);
314 let mut driver = self.driver.borrow_mut();
315 driver.pop(op).map_pending(|mut k| {
316 driver.update_waker(&mut k, cx.waker().clone());
317 k
318 })
319 }
320
321 #[cfg(feature = "time")]
322 pub(crate) fn poll_timer(&self, cx: &mut Context, key: usize) -> Poll<()> {
323 instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
324 let mut timer_runtime = self.timer_runtime.borrow_mut();
325 if !timer_runtime.is_completed(key) {
326 debug!("pending");
327 timer_runtime.update_waker(key, cx.waker().clone());
328 Poll::Pending
329 } else {
330 debug!("ready");
331 Poll::Ready(())
332 }
333 }
334
335 pub fn current_timeout(&self) -> Option<Duration> {
339 #[cfg(not(feature = "time"))]
340 let timeout = None;
341 #[cfg(feature = "time")]
342 let timeout = self.timer_runtime.borrow().min_timeout();
343 timeout
344 }
345
346 pub fn poll(&self) {
351 instrument!(compio_log::Level::DEBUG, "poll");
352 let timeout = self.current_timeout();
353 debug!("timeout: {:?}", timeout);
354 self.poll_with(timeout)
355 }
356
357 pub fn poll_with(&self, timeout: Option<Duration>) {
361 instrument!(compio_log::Level::DEBUG, "poll_with");
362
363 let mut driver = self.driver.borrow_mut();
364 match driver.poll(timeout) {
365 Ok(()) => {}
366 Err(e) => match e.kind() {
367 io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
368 debug!("expected error: {e}");
369 }
370 _ => panic!("{e:?}"),
371 },
372 }
373 #[cfg(feature = "time")]
374 self.timer_runtime.borrow_mut().wake();
375 }
376}
377
378impl Drop for Runtime {
379 fn drop(&mut self) {
380 self.enter(|| {
381 while self.runnables.sync_runnables.pop().is_some() {}
382 let local_runnables = unsafe { self.runnables.local_runnables.get_unchecked() };
383 loop {
384 let runnable = local_runnables.borrow_mut().pop_front();
385 if runnable.is_none() {
386 break;
387 }
388 }
389 })
390 }
391}
392
393impl AsRawFd for Runtime {
394 fn as_raw_fd(&self) -> RawFd {
395 self.driver.borrow().as_raw_fd()
396 }
397}
398
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    /// Delegates to the inherent [`Runtime::block_on`].
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(self, future)
    }
}
405
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    /// Delegates to the inherent [`Runtime::block_on`] of the referenced runtime.
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        let runtime: &Runtime = *self;
        runtime.block_on(future)
    }
}
412
413#[derive(Debug, Clone)]
415pub struct RuntimeBuilder {
416 proactor_builder: ProactorBuilder,
417 event_interval: usize,
418}
419
420impl Default for RuntimeBuilder {
421 fn default() -> Self {
422 Self::new()
423 }
424}
425
426impl RuntimeBuilder {
427 pub fn new() -> Self {
429 Self {
430 proactor_builder: ProactorBuilder::new(),
431 event_interval: 61,
432 }
433 }
434
435 pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
437 self.proactor_builder = builder;
438 self
439 }
440
441 pub fn event_interval(&mut self, val: usize) -> &mut Self {
446 self.event_interval = val;
447 self
448 }
449
450 pub fn build(&self) -> io::Result<Runtime> {
452 Runtime::with_builder(self)
453 }
454}
455
456pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
480 Runtime::with_current(|r| r.spawn(future))
481}
482
483pub fn spawn_blocking<T: Send + 'static>(
492 f: impl (FnOnce() -> T) + Send + Sync + 'static,
493) -> JoinHandle<T> {
494 Runtime::with_current(|r| r.spawn_blocking(f))
495}
496
497pub fn submit<T: OpCode + 'static>(op: T) -> impl Future<Output = BufResult<usize, T>> {
504 Runtime::with_current(|r| r.submit(op))
505}
506
507pub fn submit_with_flags<T: OpCode + 'static>(
515 op: T,
516) -> impl Future<Output = (BufResult<usize, T>, u32)> {
517 Runtime::with_current(|r| r.submit_with_flags(op))
518}