1use std::{
2 any::Any,
3 cell::RefCell,
4 collections::VecDeque,
5 future::{Future, ready},
6 io,
7 marker::PhantomData,
8 panic::AssertUnwindSafe,
9 rc::Rc,
10 sync::Arc,
11 task::{Context, Poll},
12 time::Duration,
13};
14
15use async_task::{Runnable, Task};
16use compio_buf::IntoInner;
17use compio_driver::{
18 AsRawFd, Key, NotifyHandle, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd, op::Asyncify,
19};
20use compio_log::{debug, instrument};
21use crossbeam_queue::SegQueue;
22use futures_util::{FutureExt, future::Either};
23
24pub(crate) mod op;
25#[cfg(feature = "time")]
26pub(crate) mod time;
27
28mod send_wrapper;
29use send_wrapper::SendWrapper;
30
31#[cfg(feature = "time")]
32use crate::runtime::time::{TimerFuture, TimerRuntime};
33use crate::{BufResult, runtime::op::OpFuture};
34
// Scoped thread-local slot for the runtime entered on this thread. It is set by
// `Runtime::enter` / `Runtime::block_on` and read by `Runtime::with_current` and
// `Runtime::try_with_current`.
scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);

/// Handle to a spawned task.
///
/// The task's output is wrapped in `Ok`. If the task panicked, the result is
/// `Err` holding the panic payload: both `Runtime::spawn` and
/// `Runtime::spawn_blocking` catch unwinds.
pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;
40
41struct RunnableQueue {
42 local_runnables: SendWrapper<RefCell<VecDeque<Runnable>>>,
43 sync_runnables: SegQueue<Runnable>,
44}
45
46impl RunnableQueue {
47 pub fn new() -> Self {
48 Self {
49 local_runnables: SendWrapper::new(RefCell::new(VecDeque::new())),
50 sync_runnables: SegQueue::new(),
51 }
52 }
53
54 pub fn schedule(&self, runnable: Runnable, handle: &NotifyHandle) {
55 if let Some(runnables) = self.local_runnables.get() {
56 runnables.borrow_mut().push_back(runnable);
57 } else {
58 self.sync_runnables.push(runnable);
59 handle.notify().ok();
60 }
61 }
62
63 pub unsafe fn run(&self, event_interval: usize) -> bool {
65 let local_runnables = self.local_runnables.get_unchecked();
66 for _i in 0..event_interval {
67 let next_task = local_runnables.borrow_mut().pop_front();
68 let has_local_task = next_task.is_some();
69 if let Some(task) = next_task {
70 task.run();
71 }
72 let has_sync_task = !self.sync_runnables.is_empty();
74 if has_sync_task {
75 if let Some(task) = self.sync_runnables.pop() {
76 task.run();
77 }
78 } else if !has_local_task {
79 break;
80 }
81 }
82 !(local_runnables.borrow_mut().is_empty() && self.sync_runnables.is_empty())
83 }
84}
85
/// Single-threaded async runtime.
///
/// It owns the I/O driver, the task run queue and, with the `time` feature,
/// the timer state.
pub struct Runtime {
    /// The proactor that pushes, polls and cancels I/O operations.
    driver: RefCell<Proactor>,
    /// Task run queue, shared with the schedule callbacks of spawned tasks.
    runnables: Arc<RunnableQueue>,
    /// Registered timers, polled via `poll_timer` and woken in `poll_with`.
    #[cfg(feature = "time")]
    timer_runtime: RefCell<TimerRuntime>,
    /// Maximum number of scheduling rounds per call to `Runtime::run`.
    event_interval: usize,
    /// Holds an `Rc`, which makes `Runtime` neither `Send` nor `Sync`. The
    /// runtime therefore stays on the thread that created it.
    _p: PhantomData<Rc<VecDeque<Runnable>>>,
}
98
99impl Runtime {
100 pub fn new() -> io::Result<Self> {
102 Self::builder().build()
103 }
104
105 pub fn builder() -> RuntimeBuilder {
107 RuntimeBuilder::new()
108 }
109
110 fn with_builder(builder: &RuntimeBuilder) -> io::Result<Self> {
111 Ok(Self {
112 driver: RefCell::new(builder.proactor_builder.build()?),
113 runnables: Arc::new(RunnableQueue::new()),
114 #[cfg(feature = "time")]
115 timer_runtime: RefCell::new(TimerRuntime::new()),
116 event_interval: builder.event_interval,
117 _p: PhantomData,
118 })
119 }
120
121 pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
124 if CURRENT_RUNTIME.is_set() {
125 Ok(CURRENT_RUNTIME.with(f))
126 } else {
127 Err(f)
128 }
129 }
130
131 pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
137 #[cold]
138 fn not_in_compio_runtime() -> ! {
139 panic!("not in a compio runtime")
140 }
141
142 if CURRENT_RUNTIME.is_set() {
143 CURRENT_RUNTIME.with(f)
144 } else {
145 not_in_compio_runtime()
146 }
147 }
148
149 pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
152 CURRENT_RUNTIME.set(self, f)
153 }
154
155 pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
161 let runnables = self.runnables.clone();
162 let handle = self.driver.borrow().handle();
163 let schedule = move |runnable| {
164 runnables.schedule(runnable, &handle);
165 };
166 let (runnable, task) = async_task::spawn_unchecked(future, schedule);
167 runnable.schedule();
168 task
169 }
170
171 pub fn run(&self) -> bool {
177 unsafe { self.runnables.run(self.event_interval) }
179 }
180
181 pub fn block_on<F: Future>(&self, future: F) -> F::Output {
183 CURRENT_RUNTIME.set(self, || {
184 let mut result = None;
185 unsafe { self.spawn_unchecked(async { result = Some(future.await) }) }.detach();
186 loop {
187 let remaining_tasks = self.run();
188 if let Some(result) = result.take() {
189 return result;
190 }
191 if remaining_tasks {
192 self.poll_with(Some(Duration::ZERO));
193 } else {
194 self.poll();
195 }
196 }
197 })
198 }
199
200 pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
205 unsafe { self.spawn_unchecked(AssertUnwindSafe(future).catch_unwind()) }
206 }
207
208 pub fn spawn_blocking<T: Send + 'static>(
212 &self,
213 f: impl (FnOnce() -> T) + Send + 'static,
214 ) -> JoinHandle<T> {
215 let op = Asyncify::new(move || {
216 let res = std::panic::catch_unwind(AssertUnwindSafe(f));
217 BufResult(Ok(0), res)
218 });
219 #[allow(deprecated)]
222 unsafe {
223 self.spawn_unchecked(self.submit(op).map(|res| res.1.into_inner()))
224 }
225 }
226
227 pub fn attach(&self, fd: RawFd) -> io::Result<()> {
232 self.driver.borrow_mut().attach(fd)
233 }
234
235 fn submit_raw<T: OpCode + 'static>(&self, op: T) -> PushEntry<Key<T>, BufResult<usize, T>> {
236 self.driver.borrow_mut().push(op)
237 }
238
239 #[deprecated = "use compio::runtime::submit instead"]
247 pub fn submit<T: OpCode + 'static>(&self, op: T) -> impl Future<Output = BufResult<usize, T>> {
248 #[allow(deprecated)]
249 self.submit_with_flags(op).map(|(res, _)| res)
250 }
251
252 #[deprecated = "use compio::runtime::submit_with_flags instead"]
263 pub fn submit_with_flags<T: OpCode + 'static>(
264 &self,
265 op: T,
266 ) -> impl Future<Output = (BufResult<usize, T>, u32)> {
267 match self.submit_raw(op) {
268 PushEntry::Pending(user_data) => Either::Left(OpFuture::new(user_data)),
269 PushEntry::Ready(res) => {
270 Either::Right(ready((res, 0)))
273 }
274 }
275 }
276
277 pub(crate) fn cancel_op<T: OpCode>(&self, op: Key<T>) {
278 self.driver.borrow_mut().cancel(op);
279 }
280
281 #[cfg(feature = "time")]
282 pub(crate) fn cancel_timer(&self, key: usize) {
283 self.timer_runtime.borrow_mut().cancel(key);
284 }
285
286 pub(crate) fn poll_task<T: OpCode>(
287 &self,
288 cx: &mut Context,
289 op: Key<T>,
290 ) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
291 instrument!(compio_log::Level::DEBUG, "poll_task", ?op);
292 let mut driver = self.driver.borrow_mut();
293 driver.pop(op).map_pending(|mut k| {
294 driver.update_waker(&mut k, cx.waker().clone());
295 k
296 })
297 }
298
299 #[cfg(feature = "time")]
300 pub(crate) fn poll_timer(&self, cx: &mut Context, key: usize) -> Poll<()> {
301 instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
302 let mut timer_runtime = self.timer_runtime.borrow_mut();
303 if !timer_runtime.is_completed(key) {
304 debug!("pending");
305 timer_runtime.update_waker(key, cx.waker().clone());
306 Poll::Pending
307 } else {
308 debug!("ready");
309 Poll::Ready(())
310 }
311 }
312
313 pub fn current_timeout(&self) -> Option<Duration> {
317 #[cfg(not(feature = "time"))]
318 let timeout = None;
319 #[cfg(feature = "time")]
320 let timeout = self.timer_runtime.borrow().min_timeout();
321 timeout
322 }
323
324 pub fn poll(&self) {
329 instrument!(compio_log::Level::DEBUG, "poll");
330 let timeout = self.current_timeout();
331 debug!("timeout: {:?}", timeout);
332 self.poll_with(timeout)
333 }
334
335 pub fn poll_with(&self, timeout: Option<Duration>) {
339 instrument!(compio_log::Level::DEBUG, "poll_with");
340
341 let mut driver = self.driver.borrow_mut();
342 match driver.poll(timeout) {
343 Ok(()) => {}
344 Err(e) => match e.kind() {
345 io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
346 debug!("expected error: {e}");
347 }
348 _ => panic!("{e:?}"),
349 },
350 }
351 #[cfg(feature = "time")]
352 self.timer_runtime.borrow_mut().wake();
353 }
354}
355
impl Drop for Runtime {
    /// Discards every queued runnable while this runtime is entered.
    ///
    /// Per `async_task`, dropping a `Runnable` cancels its task and drops its
    /// future. Doing this inside `enter` presumably lets future destructors
    /// that call into the current runtime still find it.
    fn drop(&mut self) {
        self.enter(|| {
            // Drop runnables that other threads queued.
            while self.runnables.sync_runnables.pop().is_some() {}
            // SAFETY: `Runtime` is `!Send`, so `drop` runs on the thread that
            // created the runtime and its local queue.
            let local_runnables = unsafe { self.runnables.local_runnables.get_unchecked() };
            loop {
                // Pop in a separate statement so the `RefCell` borrow ends
                // before `runnable` is dropped at the end of the iteration.
                // Dropping a task may presumably wake others, and `schedule`
                // borrows this queue mutably. A `while let` would keep the
                // borrow alive across the drop.
                let runnable = local_runnables.borrow_mut().pop_front();
                if runnable.is_none() {
                    break;
                }
            }
        })
    }
}
370
371impl AsRawFd for Runtime {
372 fn as_raw_fd(&self) -> RawFd {
373 self.driver.borrow().as_raw_fd()
374 }
375}
376
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    /// Delegates to the inherent [`Runtime::block_on`].
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(self, future)
    }
}
383
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    /// Delegates to the inherent [`Runtime::block_on`] of the referenced
    /// runtime.
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(*self, future)
    }
}
390
/// Configuration used to build a [`Runtime`].
#[derive(Debug, Clone)]
pub struct RuntimeBuilder {
    /// Configuration for the runtime's proactor (I/O driver).
    proactor_builder: ProactorBuilder,
    /// Maximum scheduling rounds per `Runtime::run` call.
    event_interval: usize,
}
397
398impl Default for RuntimeBuilder {
399 fn default() -> Self {
400 Self::new()
401 }
402}
403
404impl RuntimeBuilder {
405 pub fn new() -> Self {
407 Self {
408 proactor_builder: ProactorBuilder::new(),
409 event_interval: 61,
410 }
411 }
412
413 pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
415 self.proactor_builder = builder;
416 self
417 }
418
419 pub fn event_interval(&mut self, val: usize) -> &mut Self {
424 self.event_interval = val;
425 self
426 }
427
428 pub fn build(&self) -> io::Result<Runtime> {
430 Runtime::with_builder(self)
431 }
432}
433
434pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
458 Runtime::with_current(|r| r.spawn(future))
459}
460
461pub fn spawn_blocking<T: Send + 'static>(
470 f: impl (FnOnce() -> T) + Send + 'static,
471) -> JoinHandle<T> {
472 Runtime::with_current(|r| r.spawn_blocking(f))
473}
474
475pub async fn submit<T: OpCode + 'static>(op: T) -> BufResult<usize, T> {
482 submit_with_flags(op).await.0
483}
484
485pub async fn submit_with_flags<T: OpCode + 'static>(op: T) -> (BufResult<usize, T>, u32) {
493 let state = Runtime::with_current(|r| r.submit_raw(op));
494 match state {
495 PushEntry::Pending(user_data) => OpFuture::new(user_data).await,
496 PushEntry::Ready(res) => {
497 (res, 0)
500 }
501 }
502}
503
/// Registers a timer for `instant` on the current runtime and waits for it.
///
/// If the timer runtime returns no key, this returns immediately.
#[cfg(feature = "time")]
pub(crate) async fn create_timer(instant: std::time::Instant) {
    let timer_key =
        Runtime::with_current(|runtime| runtime.timer_runtime.borrow_mut().insert(instant));
    if let Some(timer) = timer_key.map(TimerFuture::new) {
        timer.await
    }
}