compio_runtime/runtime/mod.rs

1use std::{
2    any::Any,
3    cell::RefCell,
4    collections::VecDeque,
5    future::{Future, ready},
6    io,
7    marker::PhantomData,
8    panic::AssertUnwindSafe,
9    rc::Rc,
10    sync::Arc,
11    task::{Context, Poll},
12    time::Duration,
13};
14
15use async_task::{Runnable, Task};
16use compio_buf::IntoInner;
17use compio_driver::{
18    AsRawFd, Key, NotifyHandle, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd, op::Asyncify,
19};
20use compio_log::{debug, instrument};
21use crossbeam_queue::SegQueue;
22use futures_util::{FutureExt, future::Either};
23
24pub(crate) mod op;
25#[cfg(feature = "time")]
26pub(crate) mod time;
27
28mod send_wrapper;
29use send_wrapper::SendWrapper;
30
31#[cfg(feature = "time")]
32use crate::runtime::time::{TimerFuture, TimerRuntime};
33use crate::{BufResult, runtime::op::OpFuture};
34
// The runtime currently entered on this thread. It is set for the duration of
// `Runtime::enter` / `Runtime::block_on` and read by `Runtime::with_current`
// and `Runtime::try_with_current`.
scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);

/// Type alias for `Task<Result<T, Box<dyn Any + Send>>>`, which resolves to an
/// `Err` when the spawned future panicked.
///
/// The `Err` payload is the value captured by `catch_unwind` in
/// [`Runtime::spawn`] / [`Runtime::spawn_blocking`].
pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;
40
41struct RunnableQueue {
42    local_runnables: SendWrapper<RefCell<VecDeque<Runnable>>>,
43    sync_runnables: SegQueue<Runnable>,
44}
45
46impl RunnableQueue {
47    pub fn new() -> Self {
48        Self {
49            local_runnables: SendWrapper::new(RefCell::new(VecDeque::new())),
50            sync_runnables: SegQueue::new(),
51        }
52    }
53
54    pub fn schedule(&self, runnable: Runnable, handle: &NotifyHandle) {
55        if let Some(runnables) = self.local_runnables.get() {
56            runnables.borrow_mut().push_back(runnable);
57        } else {
58            self.sync_runnables.push(runnable);
59            handle.notify().ok();
60        }
61    }
62
63    /// SAFETY: call in the main thread
64    pub unsafe fn run(&self, event_interval: usize) -> bool {
65        let local_runnables = self.local_runnables.get_unchecked();
66        for _i in 0..event_interval {
67            let next_task = local_runnables.borrow_mut().pop_front();
68            let has_local_task = next_task.is_some();
69            if let Some(task) = next_task {
70                task.run();
71            }
72            // Cheaper than pop.
73            let has_sync_task = !self.sync_runnables.is_empty();
74            if has_sync_task {
75                if let Some(task) = self.sync_runnables.pop() {
76                    task.run();
77                }
78            } else if !has_local_task {
79                break;
80            }
81        }
82        !(local_runnables.borrow_mut().is_empty() && self.sync_runnables.is_empty())
83    }
84}
85
/// The async runtime of compio. It is a thread-local runtime, and cannot be
/// sent to other threads.
pub struct Runtime {
    // NOTE: fields are dropped in declaration order (after `Drop::drop` has
    // run); keep that in mind before reordering them.
    /// The I/O driver. It is borrowed mutably to push, poll and cancel
    /// operations.
    driver: RefCell<Proactor>,
    /// Scheduled tasks. Shared through `Arc` with the schedule callback of
    /// every spawned task.
    runnables: Arc<RunnableQueue>,
    /// Timers registered by `create_timer`.
    #[cfg(feature = "time")]
    timer_runtime: RefCell<TimerRuntime>,
    /// Limits how many scheduling rounds [`Runtime::run`] performs before
    /// returning, so the driver gets polled regularly.
    event_interval: usize,
    // No other field makes `Runtime` !Send. However, `local_runnables` may
    // only be used on the thread that created it, so the runtime must not be
    // sent to (or shared with) other threads. The `Rc` makes this marker,
    // and therefore `Runtime`, !Send and !Sync.
    _p: PhantomData<Rc<VecDeque<Runnable>>>,
}
98
impl Runtime {
    /// Create [`Runtime`] with default config.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying proactor cannot be built.
    pub fn new() -> io::Result<Self> {
        Self::builder().build()
    }

    /// Create a builder for [`Runtime`].
    pub fn builder() -> RuntimeBuilder {
        RuntimeBuilder::new()
    }

    // Construct a runtime from `builder`. Building the proactor is the only
    // step here that can return an error.
    fn with_builder(builder: &RuntimeBuilder) -> io::Result<Self> {
        Ok(Self {
            driver: RefCell::new(builder.proactor_builder.build()?),
            runnables: Arc::new(RunnableQueue::new()),
            #[cfg(feature = "time")]
            timer_runtime: RefCell::new(TimerRuntime::new()),
            event_interval: builder.event_interval,
            _p: PhantomData,
        })
    }

    /// Try to perform a function on the current runtime, and if no runtime is
    /// running, return the function back.
    ///
    /// The current runtime is the one set by [`Runtime::enter`] or
    /// [`Runtime::block_on`].
    pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
        if CURRENT_RUNTIME.is_set() {
            Ok(CURRENT_RUNTIME.with(f))
        } else {
            Err(f)
        }
    }

    /// Perform a function on the current runtime.
    ///
    /// ## Panics
    ///
    /// This method will panic if there is no running [`Runtime`].
    pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
        // Marked `#[cold]` to keep the panic path out of the common path.
        #[cold]
        fn not_in_compio_runtime() -> ! {
            panic!("not in a compio runtime")
        }

        if CURRENT_RUNTIME.is_set() {
            CURRENT_RUNTIME.with(f)
        } else {
            not_in_compio_runtime()
        }
    }

    /// Set this runtime as the current runtime, and perform a function in the
    /// current scope.
    ///
    /// The runtime is only current while `f` runs.
    pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
        CURRENT_RUNTIME.set(self, f)
    }

    /// Spawns a new asynchronous task, returning a [`Task`] for it.
    ///
    /// Unlike [`Runtime::spawn`], the future need not be `'static`, and its
    /// output is not wrapped in `catch_unwind`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that everything the future borrows lives long
    /// enough, i.e. at least until the task has completed or been dropped.
    pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
        let runnables = self.runnables.clone();
        let handle = self.driver.borrow().handle();
        // The schedule callback owns its own `Arc` to the queue and a notify
        // handle. This lets it enqueue the task without touching `self`,
        // including from other threads (see `RunnableQueue::schedule`).
        let schedule = move |runnable| {
            runnables.schedule(runnable, &handle);
        };
        let (runnable, task) = async_task::spawn_unchecked(future, schedule);
        // Queue the task for its first poll.
        runnable.schedule();
        task
    }

    /// Low level API to control the runtime.
    ///
    /// Run a batch of scheduled tasks, bounded by the configured
    /// `event_interval` (see [`RuntimeBuilder::event_interval`]).
    ///
    /// The return value indicates whether there are still tasks in the queue.
    pub fn run(&self) -> bool {
        // SAFETY: self is !Send + !Sync (via `_p`), so we are on the thread
        // that created `runnables` in `with_builder`.
        unsafe { self.runnables.run(self.event_interval) }
    }

    /// Block on the future till it completes.
    ///
    /// While blocking, this runtime is the current runtime. The future can
    /// therefore use [`spawn`], [`submit`] and the other free functions of
    /// this module.
    pub fn block_on<F: Future>(&self, future: F) -> F::Output {
        CURRENT_RUNTIME.set(self, || {
            let mut result = None;
            // The task writes its output into `result` on this stack frame.
            // That requires the unchecked spawn, because the future is not
            // `'static`. The loop below only returns once `result` has been
            // filled, i.e. after the task has finished.
            unsafe { self.spawn_unchecked(async { result = Some(future.await) }) }.detach();
            loop {
                let remaining_tasks = self.run();
                if let Some(result) = result.take() {
                    return result;
                }
                // With tasks still queued, only collect already-ready events
                // (zero timeout). Otherwise wait, using the timer-derived
                // timeout from `current_timeout`.
                if remaining_tasks {
                    self.poll_with(Some(Duration::ZERO));
                } else {
                    self.poll();
                }
            }
        })
    }

    /// Spawns a new asynchronous task, returning a [`Task`] for it.
    ///
    /// Spawning a task enables the task to execute concurrently to other tasks.
    /// There is no guarantee that a spawned task will execute to completion.
    ///
    /// A panic inside `future` is caught and surfaces as an `Err` from the
    /// returned [`JoinHandle`].
    pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
        unsafe { self.spawn_unchecked(AssertUnwindSafe(future).catch_unwind()) }
    }

    /// Spawns a blocking task in a new thread, and waits for it.
    ///
    /// The task will not be cancelled even if the future is dropped. A panic
    /// inside `f` is caught and surfaces as an `Err` from the returned
    /// [`JoinHandle`].
    pub fn spawn_blocking<T: Send + 'static>(
        &self,
        f: impl (FnOnce() -> T) + Send + 'static,
    ) -> JoinHandle<T> {
        let op = Asyncify::new(move || {
            let res = std::panic::catch_unwind(AssertUnwindSafe(f));
            // The closure's result travels back in the op's buffer slot.
            // It is presumably retrieved by `into_inner` below. The reported
            // length is unused.
            BufResult(Ok(0), res)
        });
        // It is safe and sound to use `submit` here because the task is spawned
        // immediately.
        #[allow(deprecated)]
        unsafe {
            self.spawn_unchecked(self.submit(op).map(|res| res.1.into_inner()))
        }
    }

    /// Attach a raw file descriptor/handle/socket to the runtime.
    ///
    /// You only need this when authoring your own high-level APIs. High-level
    /// resources in this crate are attached automatically.
    ///
    /// # Errors
    ///
    /// Returns the driver's error if the descriptor cannot be attached.
    pub fn attach(&self, fd: RawFd) -> io::Result<()> {
        self.driver.borrow_mut().attach(fd)
    }

    // Push `op` to the driver. `Ready` carries an immediate result; `Pending`
    // carries the key used to poll for completion later.
    fn submit_raw<T: OpCode + 'static>(&self, op: T) -> PushEntry<Key<T>, BufResult<usize, T>> {
        self.driver.borrow_mut().push(op)
    }

    /// Submit an operation to the runtime.
    ///
    /// You only need this when authoring your own [`OpCode`].
    ///
    /// It is safe to send the returned future to another runtime and poll it,
    /// but the exact behavior is not guaranteed, e.g. it may stay pending
    /// forever or misbehave in other ways.
    #[deprecated = "use compio::runtime::submit instead"]
    pub fn submit<T: OpCode + 'static>(&self, op: T) -> impl Future<Output = BufResult<usize, T>> {
        #[allow(deprecated)]
        self.submit_with_flags(op).map(|(res, _)| res)
    }

    /// Submit an operation to the runtime.
    ///
    /// The difference from [`Runtime::submit`] is that this method also
    /// returns the flags.
    ///
    /// You only need this when authoring your own [`OpCode`].
    ///
    /// It is safe to send the returned future to another runtime and poll it,
    /// but the exact behavior is not guaranteed, e.g. it may stay pending
    /// forever or misbehave in other ways.
    #[deprecated = "use compio::runtime::submit_with_flags instead"]
    pub fn submit_with_flags<T: OpCode + 'static>(
        &self,
        op: T,
    ) -> impl Future<Output = (BufResult<usize, T>, u32)> {
        match self.submit_raw(op) {
            PushEntry::Pending(user_data) => Either::Left(OpFuture::new(user_data)),
            PushEntry::Ready(res) => {
                // An op is not expected to complete immediately. If it does,
                // it is presumably an error that carries no flags, so report 0.
                Either::Right(ready((res, 0)))
            }
        }
    }

    // Ask the driver to cancel the operation identified by `op`.
    pub(crate) fn cancel_op<T: OpCode>(&self, op: Key<T>) {
        self.driver.borrow_mut().cancel(op);
    }

    // Remove the timer identified by `key` from the timer runtime.
    #[cfg(feature = "time")]
    pub(crate) fn cancel_timer(&self, key: usize) {
        self.timer_runtime.borrow_mut().cancel(key);
    }

    // Check whether `op` has completed. If it is still pending, register the
    // waker from `cx` with the driver and hand the key back to the caller.
    pub(crate) fn poll_task<T: OpCode>(
        &self,
        cx: &mut Context,
        op: Key<T>,
    ) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
        instrument!(compio_log::Level::DEBUG, "poll_task", ?op);
        let mut driver = self.driver.borrow_mut();
        driver.pop(op).map_pending(|mut k| {
            driver.update_waker(&mut k, cx.waker().clone());
            k
        })
    }

    // Ready once the timer `key` has completed. Otherwise register the waker
    // from `cx` and return `Pending`.
    #[cfg(feature = "time")]
    pub(crate) fn poll_timer(&self, cx: &mut Context, key: usize) -> Poll<()> {
        instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
        let mut timer_runtime = self.timer_runtime.borrow_mut();
        if !timer_runtime.is_completed(key) {
            debug!("pending");
            timer_runtime.update_waker(key, cx.waker().clone());
            Poll::Pending
        } else {
            debug!("ready");
            Poll::Ready(())
        }
    }

    /// Low level API to control the runtime.
    ///
    /// Get the timeout value to be passed to [`Proactor::poll`].
    ///
    /// Without the `time` feature this is always `None`. With it, the value
    /// comes from the timer runtime's `min_timeout` (presumably the time until
    /// the nearest pending timer).
    pub fn current_timeout(&self) -> Option<Duration> {
        #[cfg(not(feature = "time"))]
        let timeout = None;
        #[cfg(feature = "time")]
        let timeout = self.timer_runtime.borrow().min_timeout();
        timeout
    }

    /// Low level API to control the runtime.
    ///
    /// Poll the inner proactor. It is equal to calling [`Runtime::poll_with`]
    /// with [`Runtime::current_timeout`].
    pub fn poll(&self) {
        instrument!(compio_log::Level::DEBUG, "poll");
        let timeout = self.current_timeout();
        debug!("timeout: {:?}", timeout);
        self.poll_with(timeout)
    }

    /// Low level API to control the runtime.
    ///
    /// Poll the inner proactor with a custom timeout.
    ///
    /// ## Panics
    ///
    /// Panics if the proactor reports an error other than
    /// [`io::ErrorKind::TimedOut`] or [`io::ErrorKind::Interrupted`].
    pub fn poll_with(&self, timeout: Option<Duration>) {
        instrument!(compio_log::Level::DEBUG, "poll_with");

        let mut driver = self.driver.borrow_mut();
        match driver.poll(timeout) {
            Ok(()) => {}
            Err(e) => match e.kind() {
                // A timeout or an interrupted wait is a normal wake-up.
                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
                    debug!("expected error: {e}");
                }
                _ => panic!("{e:?}"),
            },
        }
        // Wake the timer runtime after the driver poll. NOTE(review): the
        // `driver` borrow is still held at this point.
        #[cfg(feature = "time")]
        self.timer_runtime.borrow_mut().wake();
    }
}
355
impl Drop for Runtime {
    // Discard every queued runnable before the runtime's fields are dropped.
    fn drop(&mut self) {
        // Enter the runtime so that code running while task futures are
        // dropped can still reach it through `Runtime::with_current`
        // (presumably e.g. futures that cancel their pending op on drop;
        // confirm).
        self.enter(|| {
            // Drop runnables that were scheduled from other threads.
            while self.runnables.sync_runnables.pop().is_some() {}
            // SAFETY: `Runtime` is !Send, so `drop` runs on the thread that
            // created the local queue.
            let local_runnables = unsafe { self.runnables.local_runnables.get_unchecked() };
            loop {
                // Pop in a separate `let` so the `RefMut` is released before
                // `runnable` is dropped at the end of this iteration.
                // Dropping a runnable may (presumably) schedule further
                // runnables onto this deque, which would panic while a borrow
                // is outstanding. Any such runnables are drained by later
                // iterations.
                let runnable = local_runnables.borrow_mut().pop_front();
                if runnable.is_none() {
                    break;
                }
            }
        })
    }
}
370
371impl AsRawFd for Runtime {
372    fn as_raw_fd(&self) -> RawFd {
373        self.driver.borrow().as_raw_fd()
374    }
375}
376
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    /// Drives `future` to completion with the inherent [`Runtime::block_on`].
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(self, future)
    }
}
383
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    /// Forwards to the inherent [`Runtime::block_on`] of the referenced
    /// runtime.
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        let runtime: &Runtime = self;
        Runtime::block_on(runtime, future)
    }
}
390
391/// Builder for [`Runtime`].
392#[derive(Debug, Clone)]
393pub struct RuntimeBuilder {
394    proactor_builder: ProactorBuilder,
395    event_interval: usize,
396}
397
398impl Default for RuntimeBuilder {
399    fn default() -> Self {
400        Self::new()
401    }
402}
403
404impl RuntimeBuilder {
405    /// Create the builder with default config.
406    pub fn new() -> Self {
407        Self {
408            proactor_builder: ProactorBuilder::new(),
409            event_interval: 61,
410        }
411    }
412
413    /// Replace proactor builder.
414    pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
415        self.proactor_builder = builder;
416        self
417    }
418
419    /// Sets the number of scheduler ticks after which the scheduler will poll
420    /// for external events (timers, I/O, and so on).
421    ///
422    /// A scheduler “tick” roughly corresponds to one poll invocation on a task.
423    pub fn event_interval(&mut self, val: usize) -> &mut Self {
424        self.event_interval = val;
425        self
426    }
427
428    /// Build [`Runtime`].
429    pub fn build(&self) -> io::Result<Runtime> {
430        Runtime::with_builder(self)
431    }
432}
433
434/// Spawns a new asynchronous task, returning a [`Task`] for it.
435///
436/// Spawning a task enables the task to execute concurrently to other tasks.
437/// There is no guarantee that a spawned task will execute to completion.
438///
439/// ```
440/// # compio_runtime::Runtime::new().unwrap().block_on(async {
441/// let task = compio_runtime::spawn(async {
442///     println!("Hello from a spawned task!");
443///     42
444/// });
445///
446/// assert_eq!(
447///     task.await.unwrap_or_else(|e| std::panic::resume_unwind(e)),
448///     42
449/// );
450/// # })
451/// ```
452///
453/// ## Panics
454///
455/// This method doesn't create runtime. It tries to obtain the current runtime
456/// by [`Runtime::with_current`].
457pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
458    Runtime::with_current(|r| r.spawn(future))
459}
460
461/// Spawns a blocking task in a new thread, and wait for it.
462///
463/// The task will not be cancelled even if the future is dropped.
464///
465/// ## Panics
466///
467/// This method doesn't create runtime. It tries to obtain the current runtime
468/// by [`Runtime::with_current`].
469pub fn spawn_blocking<T: Send + 'static>(
470    f: impl (FnOnce() -> T) + Send + 'static,
471) -> JoinHandle<T> {
472    Runtime::with_current(|r| r.spawn_blocking(f))
473}
474
475/// Submit an operation to the current runtime, and return a future for it.
476///
477/// ## Panics
478///
479/// This method doesn't create runtime. It tries to obtain the current runtime
480/// by [`Runtime::with_current`].
481pub async fn submit<T: OpCode + 'static>(op: T) -> BufResult<usize, T> {
482    submit_with_flags(op).await.0
483}
484
485/// Submit an operation to the current runtime, and return a future for it with
486/// flags.
487///
488/// ## Panics
489///
490/// This method doesn't create runtime. It tries to obtain the current runtime
491/// by [`Runtime::with_current`].
492pub async fn submit_with_flags<T: OpCode + 'static>(op: T) -> (BufResult<usize, T>, u32) {
493    let state = Runtime::with_current(|r| r.submit_raw(op));
494    match state {
495        PushEntry::Pending(user_data) => OpFuture::new(user_data).await,
496        PushEntry::Ready(res) => {
497            // submit_flags won't be ready immediately, if ready, it must be error without
498            // flags, or the flags are not necessary
499            (res, 0)
500        }
501    }
502}
503
// Register a timer for `instant` on the current runtime and wait until it
// fires. Panics outside a compio runtime (via `Runtime::with_current`).
#[cfg(feature = "time")]
pub(crate) async fn create_timer(instant: std::time::Instant) {
    let issued = Runtime::with_current(|runtime| runtime.timer_runtime.borrow_mut().insert(instant));
    let Some(key) = issued else {
        // The timer runtime issued no key (presumably the deadline has
        // already passed), so there is nothing to wait for.
        return;
    };
    TimerFuture::new(key).await
}