compio_runtime/runtime/
mod.rs

1use std::{
2    any::Any,
3    cell::RefCell,
4    collections::VecDeque,
5    future::{Future, poll_fn, ready},
6    io,
7    marker::PhantomData,
8    panic::AssertUnwindSafe,
9    rc::Rc,
10    sync::Arc,
11    task::{Context, Poll},
12    time::Duration,
13};
14
15use async_task::{Runnable, Task};
16use compio_buf::IntoInner;
17use compio_driver::{
18    AsRawFd, Key, NotifyHandle, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd, op::Asyncify,
19};
20use compio_log::{debug, instrument};
21use crossbeam_queue::SegQueue;
22use futures_util::{FutureExt, future::Either};
23
24pub(crate) mod op;
25#[cfg(feature = "time")]
26pub(crate) mod time;
27
28mod send_wrapper;
29use send_wrapper::SendWrapper;
30
31#[cfg(feature = "time")]
32use crate::runtime::time::{TimerFuture, TimerRuntime};
33use crate::{BufResult, runtime::op::OpFuture};
34
// The runtime installed for the current thread by `Runtime::enter` and
// `Runtime::block_on`; read back by `Runtime::with_current` and
// `Runtime::try_with_current`.
scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);
36
37/// Type alias for `Task<Result<T, Box<dyn Any + Send>>>`, which resolves to an
38/// `Err` when the spawned future panicked.
39pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;
40
41struct RunnableQueue {
42    local_runnables: SendWrapper<RefCell<VecDeque<Runnable>>>,
43    sync_runnables: SegQueue<Runnable>,
44}
45
46impl RunnableQueue {
47    pub fn new() -> Self {
48        Self {
49            local_runnables: SendWrapper::new(RefCell::new(VecDeque::new())),
50            sync_runnables: SegQueue::new(),
51        }
52    }
53
54    pub fn schedule(&self, runnable: Runnable, handle: &NotifyHandle) {
55        if let Some(runnables) = self.local_runnables.get() {
56            runnables.borrow_mut().push_back(runnable);
57        } else {
58            self.sync_runnables.push(runnable);
59            handle.notify().ok();
60        }
61    }
62
63    /// SAFETY: call in the main thread
64    pub unsafe fn run(&self, event_interval: usize) -> bool {
65        let local_runnables = self.local_runnables.get_unchecked();
66        for _i in 0..event_interval {
67            let next_task = local_runnables.borrow_mut().pop_front();
68            let has_local_task = next_task.is_some();
69            if let Some(task) = next_task {
70                task.run();
71            }
72            // Cheaper than pop.
73            let has_sync_task = !self.sync_runnables.is_empty();
74            if has_sync_task {
75                if let Some(task) = self.sync_runnables.pop() {
76                    task.run();
77                }
78            } else if !has_local_task {
79                break;
80            }
81        }
82        !(local_runnables.borrow_mut().is_empty() && self.sync_runnables.is_empty())
83    }
84}
85
86/// The async runtime of compio. It is a thread local runtime, and cannot be
87/// sent to other threads.
88pub struct Runtime {
89    driver: RefCell<Proactor>,
90    runnables: Arc<RunnableQueue>,
91    #[cfg(feature = "time")]
92    timer_runtime: RefCell<TimerRuntime>,
93    event_interval: usize,
94    // Other fields don't make it !Send, but actually `local_runnables` implies it should be !Send,
95    // otherwise it won't be valid if the runtime is sent to other threads.
96    _p: PhantomData<Rc<VecDeque<Runnable>>>,
97}
98
99impl Runtime {
100    /// Create [`Runtime`] with default config.
101    pub fn new() -> io::Result<Self> {
102        Self::builder().build()
103    }
104
105    /// Create a builder for [`Runtime`].
106    pub fn builder() -> RuntimeBuilder {
107        RuntimeBuilder::new()
108    }
109
110    fn with_builder(builder: &RuntimeBuilder) -> io::Result<Self> {
111        Ok(Self {
112            driver: RefCell::new(builder.proactor_builder.build()?),
113            runnables: Arc::new(RunnableQueue::new()),
114            #[cfg(feature = "time")]
115            timer_runtime: RefCell::new(TimerRuntime::new()),
116            event_interval: builder.event_interval,
117            _p: PhantomData,
118        })
119    }
120
121    /// Try to perform a function on the current runtime, and if no runtime is
122    /// running, return the function back.
123    pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
124        if CURRENT_RUNTIME.is_set() {
125            Ok(CURRENT_RUNTIME.with(f))
126        } else {
127            Err(f)
128        }
129    }
130
131    /// Perform a function on the current runtime.
132    ///
133    /// ## Panics
134    ///
135    /// This method will panic if there are no running [`Runtime`].
136    pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
137        #[cold]
138        fn not_in_compio_runtime() -> ! {
139            panic!("not in a compio runtime")
140        }
141
142        if CURRENT_RUNTIME.is_set() {
143            CURRENT_RUNTIME.with(f)
144        } else {
145            not_in_compio_runtime()
146        }
147    }
148
149    /// Set this runtime as current runtime, and perform a function in the
150    /// current scope.
151    pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
152        CURRENT_RUNTIME.set(self, f)
153    }
154
155    /// Spawns a new asynchronous task, returning a [`Task`] for it.
156    ///
157    /// # Safety
158    ///
159    /// The caller should ensure the captured lifetime long enough.
160    pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
161        let runnables = self.runnables.clone();
162        let handle = self
163            .driver
164            .borrow()
165            .handle()
166            .expect("cannot create notify handle of the proactor");
167        let schedule = move |runnable| {
168            runnables.schedule(runnable, &handle);
169        };
170        let (runnable, task) = async_task::spawn_unchecked(future, schedule);
171        runnable.schedule();
172        task
173    }
174
175    /// Low level API to control the runtime.
176    ///
177    /// Run the scheduled tasks.
178    ///
179    /// The return value indicates whether there are still tasks in the queue.
180    pub fn run(&self) -> bool {
181        // SAFETY: self is !Send + !Sync.
182        unsafe { self.runnables.run(self.event_interval) }
183    }
184
185    /// Block on the future till it completes.
186    pub fn block_on<F: Future>(&self, future: F) -> F::Output {
187        CURRENT_RUNTIME.set(self, || {
188            let mut result = None;
189            unsafe { self.spawn_unchecked(async { result = Some(future.await) }) }.detach();
190            loop {
191                let remaining_tasks = self.run();
192                if let Some(result) = result.take() {
193                    return result;
194                }
195                if remaining_tasks {
196                    self.poll_with(Some(Duration::ZERO));
197                } else {
198                    self.poll();
199                }
200            }
201        })
202    }
203
204    /// Spawns a new asynchronous task, returning a [`Task`] for it.
205    ///
206    /// Spawning a task enables the task to execute concurrently to other tasks.
207    /// There is no guarantee that a spawned task will execute to completion.
208    pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
209        unsafe { self.spawn_unchecked(AssertUnwindSafe(future).catch_unwind()) }
210    }
211
212    /// Spawns a blocking task in a new thread, and wait for it.
213    ///
214    /// The task will not be cancelled even if the future is dropped.
215    pub fn spawn_blocking<T: Send + 'static>(
216        &self,
217        f: impl (FnOnce() -> T) + Send + Sync + 'static,
218    ) -> JoinHandle<T> {
219        let op = Asyncify::new(move || {
220            let res = std::panic::catch_unwind(AssertUnwindSafe(f));
221            BufResult(Ok(0), res)
222        });
223        let closure = async move {
224            let mut op = op;
225            loop {
226                match self.submit(op).await {
227                    BufResult(Ok(_), rop) => break rop.into_inner(),
228                    BufResult(Err(_), rop) => op = rop,
229                }
230                // Possible error: thread pool is full, or failed to create notify handle.
231                // Push the future to the back of the queue.
232                let mut yielded = false;
233                poll_fn(|cx| {
234                    if yielded {
235                        Poll::Ready(())
236                    } else {
237                        yielded = true;
238                        cx.waker().wake_by_ref();
239                        Poll::Pending
240                    }
241                })
242                .await;
243            }
244        };
245        // SAFETY: the closure catches the shared reference of self, which is in an Rc
246        // so it won't be moved.
247        unsafe { self.spawn_unchecked(closure) }
248    }
249
250    /// Attach a raw file descriptor/handle/socket to the runtime.
251    ///
252    /// You only need this when authoring your own high-level APIs. High-level
253    /// resources in this crate are attached automatically.
254    pub fn attach(&self, fd: RawFd) -> io::Result<()> {
255        self.driver.borrow_mut().attach(fd)
256    }
257
258    fn submit_raw<T: OpCode + 'static>(&self, op: T) -> PushEntry<Key<T>, BufResult<usize, T>> {
259        self.driver.borrow_mut().push(op)
260    }
261
262    /// Submit an operation to the runtime.
263    ///
264    /// You only need this when authoring your own [`OpCode`].
265    pub fn submit<T: OpCode + 'static>(&self, op: T) -> impl Future<Output = BufResult<usize, T>> {
266        self.submit_with_flags(op).map(|(res, _)| res)
267    }
268
269    /// Submit an operation to the runtime.
270    ///
271    /// The difference between [`Runtime::submit`] is this method will return
272    /// the flags
273    ///
274    /// You only need this when authoring your own [`OpCode`].
275    pub fn submit_with_flags<T: OpCode + 'static>(
276        &self,
277        op: T,
278    ) -> impl Future<Output = (BufResult<usize, T>, u32)> {
279        match self.submit_raw(op) {
280            PushEntry::Pending(user_data) => Either::Left(OpFuture::new(user_data)),
281            PushEntry::Ready(res) => {
282                // submit_flags won't be ready immediately, if ready, it must be error without
283                // flags
284                Either::Right(ready((res, 0)))
285            }
286        }
287    }
288
289    #[cfg(feature = "time")]
290    pub(crate) fn create_timer(&self, delay: std::time::Duration) -> impl Future<Output = ()> {
291        let mut timer_runtime = self.timer_runtime.borrow_mut();
292        if let Some(key) = timer_runtime.insert(delay) {
293            Either::Left(TimerFuture::new(key))
294        } else {
295            Either::Right(std::future::ready(()))
296        }
297    }
298
299    pub(crate) fn cancel_op<T: OpCode>(&self, op: Key<T>) {
300        self.driver.borrow_mut().cancel(op);
301    }
302
303    #[cfg(feature = "time")]
304    pub(crate) fn cancel_timer(&self, key: usize) {
305        self.timer_runtime.borrow_mut().cancel(key);
306    }
307
308    pub(crate) fn poll_task<T: OpCode>(
309        &self,
310        cx: &mut Context,
311        op: Key<T>,
312    ) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
313        instrument!(compio_log::Level::DEBUG, "poll_task", ?op);
314        let mut driver = self.driver.borrow_mut();
315        driver.pop(op).map_pending(|mut k| {
316            driver.update_waker(&mut k, cx.waker().clone());
317            k
318        })
319    }
320
321    #[cfg(feature = "time")]
322    pub(crate) fn poll_timer(&self, cx: &mut Context, key: usize) -> Poll<()> {
323        instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
324        let mut timer_runtime = self.timer_runtime.borrow_mut();
325        if !timer_runtime.is_completed(key) {
326            debug!("pending");
327            timer_runtime.update_waker(key, cx.waker().clone());
328            Poll::Pending
329        } else {
330            debug!("ready");
331            Poll::Ready(())
332        }
333    }
334
335    /// Low level API to control the runtime.
336    ///
337    /// Get the timeout value to be passed to [`Proactor::poll`].
338    pub fn current_timeout(&self) -> Option<Duration> {
339        #[cfg(not(feature = "time"))]
340        let timeout = None;
341        #[cfg(feature = "time")]
342        let timeout = self.timer_runtime.borrow().min_timeout();
343        timeout
344    }
345
346    /// Low level API to control the runtime.
347    ///
348    /// Poll the inner proactor. It is equal to calling [`Runtime::poll_with`]
349    /// with [`Runtime::current_timeout`].
350    pub fn poll(&self) {
351        instrument!(compio_log::Level::DEBUG, "poll");
352        let timeout = self.current_timeout();
353        debug!("timeout: {:?}", timeout);
354        self.poll_with(timeout)
355    }
356
357    /// Low level API to control the runtime.
358    ///
359    /// Poll the inner proactor with a custom timeout.
360    pub fn poll_with(&self, timeout: Option<Duration>) {
361        instrument!(compio_log::Level::DEBUG, "poll_with");
362
363        let mut driver = self.driver.borrow_mut();
364        match driver.poll(timeout) {
365            Ok(()) => {}
366            Err(e) => match e.kind() {
367                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
368                    debug!("expected error: {e}");
369                }
370                _ => panic!("{e:?}"),
371            },
372        }
373        #[cfg(feature = "time")]
374        self.timer_runtime.borrow_mut().wake();
375    }
376}
377
378impl Drop for Runtime {
379    fn drop(&mut self) {
380        self.enter(|| {
381            while self.runnables.sync_runnables.pop().is_some() {}
382            let local_runnables = unsafe { self.runnables.local_runnables.get_unchecked() };
383            loop {
384                let runnable = local_runnables.borrow_mut().pop_front();
385                if runnable.is_none() {
386                    break;
387                }
388            }
389        })
390    }
391}
392
393impl AsRawFd for Runtime {
394    fn as_raw_fd(&self) -> RawFd {
395        self.driver.borrow().as_raw_fd()
396    }
397}
398
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    /// Delegates to the inherent [`Runtime::block_on`].
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(self, future)
    }
}
405
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    /// Delegates to the inherent [`Runtime::block_on`] of the borrowed runtime.
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(*self, future)
    }
}
412
413/// Builder for [`Runtime`].
414#[derive(Debug, Clone)]
415pub struct RuntimeBuilder {
416    proactor_builder: ProactorBuilder,
417    event_interval: usize,
418}
419
420impl Default for RuntimeBuilder {
421    fn default() -> Self {
422        Self::new()
423    }
424}
425
426impl RuntimeBuilder {
427    /// Create the builder with default config.
428    pub fn new() -> Self {
429        Self {
430            proactor_builder: ProactorBuilder::new(),
431            event_interval: 61,
432        }
433    }
434
435    /// Replace proactor builder.
436    pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
437        self.proactor_builder = builder;
438        self
439    }
440
441    /// Sets the number of scheduler ticks after which the scheduler will poll
442    /// for external events (timers, I/O, and so on).
443    ///
444    /// A scheduler “tick” roughly corresponds to one poll invocation on a task.
445    pub fn event_interval(&mut self, val: usize) -> &mut Self {
446        self.event_interval = val;
447        self
448    }
449
450    /// Build [`Runtime`].
451    pub fn build(&self) -> io::Result<Runtime> {
452        Runtime::with_builder(self)
453    }
454}
455
456/// Spawns a new asynchronous task, returning a [`Task`] for it.
457///
458/// Spawning a task enables the task to execute concurrently to other tasks.
459/// There is no guarantee that a spawned task will execute to completion.
460///
461/// ```
462/// # compio_runtime::Runtime::new().unwrap().block_on(async {
463/// let task = compio_runtime::spawn(async {
464///     println!("Hello from a spawned task!");
465///     42
466/// });
467///
468/// assert_eq!(
469///     task.await.unwrap_or_else(|e| std::panic::resume_unwind(e)),
470///     42
471/// );
472/// # })
473/// ```
474///
475/// ## Panics
476///
477/// This method doesn't create runtime. It tries to obtain the current runtime
478/// by [`Runtime::with_current`].
479pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
480    Runtime::with_current(|r| r.spawn(future))
481}
482
483/// Spawns a blocking task in a new thread, and wait for it.
484///
485/// The task will not be cancelled even if the future is dropped.
486///
487/// ## Panics
488///
489/// This method doesn't create runtime. It tries to obtain the current runtime
490/// by [`Runtime::with_current`].
491pub fn spawn_blocking<T: Send + 'static>(
492    f: impl (FnOnce() -> T) + Send + Sync + 'static,
493) -> JoinHandle<T> {
494    Runtime::with_current(|r| r.spawn_blocking(f))
495}
496
497/// Submit an operation to the current runtime, and return a future for it.
498///
499/// ## Panics
500///
501/// This method doesn't create runtime. It tries to obtain the current runtime
502/// by [`Runtime::with_current`].
503pub fn submit<T: OpCode + 'static>(op: T) -> impl Future<Output = BufResult<usize, T>> {
504    Runtime::with_current(|r| r.submit(op))
505}
506
507/// Submit an operation to the current runtime, and return a future for it with
508/// flags.
509///
510/// ## Panics
511///
512/// This method doesn't create runtime. It tries to obtain the current runtime
513/// by [`Runtime::with_current`].
514pub fn submit_with_flags<T: OpCode + 'static>(
515    op: T,
516) -> impl Future<Output = (BufResult<usize, T>, u32)> {
517    Runtime::with_current(|r| r.submit_with_flags(op))
518}