compio_driver/
asyncify.rs

1use std::{
2    fmt,
3    sync::{
4        Arc,
5        atomic::{AtomicUsize, Ordering},
6    },
7    time::Duration,
8};
9
10use crossbeam_channel::{Receiver, Sender, TrySendError, bounded};
11
/// Error returned when a dispatchable could not be handed to a worker thread
/// because all of them are busy.
///
/// The rejected dispatchable is carried in the public field so the caller can
/// take it back. The [`fmt::Debug`] and [`fmt::Display`] implementations do
/// not print the payload, so they place no bounds on `T`.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct DispatchError<T>(pub T);
17
18impl<T> DispatchError<T> {
19    /// Consume the error, yielding the dispatchable that failed to be sent.
20    pub fn into_inner(self) -> T {
21        self.0
22    }
23}
24
25impl<T> fmt::Debug for DispatchError<T> {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        "DispatchError(..)".fmt(f)
28    }
29}
30
31impl<T> fmt::Display for DispatchError<T> {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        "all threads are busy".fmt(f)
34    }
35}
36
// Lets `DispatchError` be used wherever a `std::error::Error` is expected.
// The default trait methods are kept, so no source error is reported.
impl<T> std::error::Error for DispatchError<T> {}
38
/// A type-erased, heap-allocated dispatchable; this is the message type
/// carried by the pool's channel.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;
40
/// A trait for dispatching a closure. It's implemented for all
/// `FnOnce() + Send + 'static` but may also be implemented for any other types
/// that are `Send` and `'static`.
pub trait Dispatchable: Send + 'static {
    /// Run the dispatchable, consuming the box that holds it.
    fn run(self: Box<Self>);
}
48
49impl<F> Dispatchable for F
50where
51    F: FnOnce() + Send + 'static,
52{
53    fn run(self: Box<Self>) {
54        (*self)()
55    }
56}
57
/// Guard over a shared counter: dropping it subtracts one from the counter.
/// Whoever creates the guard is responsible for the matching increment.
struct CounterGuard(Arc<AtomicUsize>);
59
60impl Drop for CounterGuard {
61    fn drop(&mut self) {
62        self.0.fetch_sub(1, Ordering::AcqRel);
63    }
64}
65
66fn worker(
67    receiver: Receiver<BoxedDispatchable>,
68    counter: Arc<AtomicUsize>,
69    timeout: Duration,
70) -> impl FnOnce() {
71    move || {
72        counter.fetch_add(1, Ordering::AcqRel);
73        let _guard = CounterGuard(counter);
74        while let Ok(f) = receiver.recv_timeout(timeout) {
75            f.run();
76        }
77    }
78}
79
80/// A thread pool to perform blocking operations in other threads.
81#[derive(Debug, Clone)]
82pub struct AsyncifyPool {
83    sender: Sender<BoxedDispatchable>,
84    receiver: Receiver<BoxedDispatchable>,
85    counter: Arc<AtomicUsize>,
86    thread_limit: usize,
87    recv_timeout: Duration,
88}
89
90impl AsyncifyPool {
91    /// Create [`AsyncifyPool`] with thread number limit and channel receive
92    /// timeout.
93    pub fn new(thread_limit: usize, recv_timeout: Duration) -> Self {
94        let (sender, receiver) = bounded(0);
95        Self {
96            sender,
97            receiver,
98            counter: Arc::new(AtomicUsize::new(0)),
99            thread_limit,
100            recv_timeout,
101        }
102    }
103
104    /// Send a dispatchable, usually a closure, to another thread. Usually the
105    /// user should not use it. When all threads are busy and thread number
106    /// limit has been reached, it will return an error with the original
107    /// dispatchable.
108    pub fn dispatch<D: Dispatchable>(&self, f: D) -> Result<(), DispatchError<D>> {
109        match self.sender.try_send(Box::new(f) as BoxedDispatchable) {
110            Ok(_) => Ok(()),
111            Err(e) => match e {
112                TrySendError::Full(f) => {
113                    if self.counter.load(Ordering::Acquire) >= self.thread_limit {
114                        // Safety: we can ensure the type
115                        Err(DispatchError(*unsafe {
116                            Box::from_raw(Box::into_raw(f).cast())
117                        }))
118                    } else {
119                        std::thread::spawn(worker(
120                            self.receiver.clone(),
121                            self.counter.clone(),
122                            self.recv_timeout,
123                        ));
124                        self.sender.send(f).expect("the channel should not be full");
125                        Ok(())
126                    }
127                }
128                TrySendError::Disconnected(_) => {
129                    unreachable!("receiver should not all disconnected")
130                }
131            },
132        }
133    }
134}