compio_driver/
asyncify.rs1use std::{
2 fmt,
3 sync::{
4 Arc,
5 atomic::{AtomicUsize, Ordering},
6 },
7 time::Duration,
8};
9
10use crossbeam_channel::{Receiver, Sender, TrySendError, bounded};
11
/// Error returned by [`AsyncifyPool::dispatch`] when every worker thread is
/// busy and the pool has already reached its thread limit.
///
/// The rejected dispatchable is handed back to the caller, so nothing is lost.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct DispatchError<T>(pub T);

impl<T> DispatchError<T> {
    /// Consumes the error and returns the dispatchable that could not be run.
    pub fn into_inner(self) -> T {
        self.0
    }
}

// `Debug` deliberately does not require `T: Debug`: the payload is usually a
// closure, which has no useful debug representation.
impl<T> fmt::Debug for DispatchError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Write the text verbatim; going through `str`'s `Debug` impl would
        // wrap it in quotes and escape it.
        f.write_str("DispatchError(..)")
    }
}

impl<T> fmt::Display for DispatchError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width/alignment flags, exactly like `str`'s `Display`.
        f.pad("all threads are busy")
    }
}

impl<T> std::error::Error for DispatchError<T> {}
38
/// Owned, type-erased task as it travels through the pool's channel.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;

/// A one-shot unit of work that can be executed on a pool worker thread.
///
/// Tasks are consumed through `Box<Self>`, which keeps the trait usable as a
/// trait object so different task types can share one [`BoxedDispatchable`]
/// channel.
pub trait Dispatchable: Send + 'static {
    /// Executes the task, consuming it.
    fn run(self: Box<Self>);
}

/// Any sendable, `'static` `FnOnce()` closure is a dispatchable task.
impl<F: FnOnce() + Send + 'static> Dispatchable for F {
    fn run(self: Box<Self>) {
        // Move the closure out of its box, then invoke it by value.
        let task = *self;
        task();
    }
}
57
/// Releases one unit of a shared counter when it goes out of scope.
struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::AcqRel);
    }
}
65
66fn worker(
67 receiver: Receiver<BoxedDispatchable>,
68 counter: Arc<AtomicUsize>,
69 timeout: Duration,
70) -> impl FnOnce() {
71 move || {
72 counter.fetch_add(1, Ordering::AcqRel);
73 let _guard = CounterGuard(counter);
74 while let Ok(f) = receiver.recv_timeout(timeout) {
75 f.run();
76 }
77 }
78}
79
80#[derive(Debug, Clone)]
82pub struct AsyncifyPool {
83 sender: Sender<BoxedDispatchable>,
84 receiver: Receiver<BoxedDispatchable>,
85 counter: Arc<AtomicUsize>,
86 thread_limit: usize,
87 recv_timeout: Duration,
88}
89
90impl AsyncifyPool {
91 pub fn new(thread_limit: usize, recv_timeout: Duration) -> Self {
94 let (sender, receiver) = bounded(0);
95 Self {
96 sender,
97 receiver,
98 counter: Arc::new(AtomicUsize::new(0)),
99 thread_limit,
100 recv_timeout,
101 }
102 }
103
104 pub fn dispatch<D: Dispatchable>(&self, f: D) -> Result<(), DispatchError<D>> {
109 match self.sender.try_send(Box::new(f) as BoxedDispatchable) {
110 Ok(_) => Ok(()),
111 Err(e) => match e {
112 TrySendError::Full(f) => {
113 if self.counter.load(Ordering::Acquire) >= self.thread_limit {
114 Err(DispatchError(*unsafe {
116 Box::from_raw(Box::into_raw(f).cast())
117 }))
118 } else {
119 std::thread::spawn(worker(
120 self.receiver.clone(),
121 self.counter.clone(),
122 self.recv_timeout,
123 ));
124 self.sender.send(f).expect("the channel should not be full");
125 Ok(())
126 }
127 }
128 TrySendError::Disconnected(_) => {
129 unreachable!("receiver should not all disconnected")
130 }
131 },
132 }
133 }
134}