compio_dispatcher/
lib.rs

1//! Multithreading dispatcher for compio.
2
3#![warn(missing_docs)]
4
5use std::{
6    future::Future,
7    io,
8    num::NonZeroUsize,
9    panic::resume_unwind,
10    thread::{JoinHandle, available_parallelism},
11};
12
13use compio_driver::{AsyncifyPool, DispatchError, Dispatchable, ProactorBuilder};
14use compio_runtime::{JoinHandle as CompioJoinHandle, Runtime};
15use flume::{Sender, unbounded};
16use futures_channel::oneshot;
17
18type Spawning = Box<dyn Spawnable + Send>;
19
20trait Spawnable {
21    fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()>;
22}
23
24/// Concrete type for the closure we're sending to worker threads
25struct Concrete<F, R> {
26    callback: oneshot::Sender<R>,
27    func: F,
28}
29
30impl<F, R> Concrete<F, R> {
31    pub fn new(func: F) -> (Self, oneshot::Receiver<R>) {
32        let (tx, rx) = oneshot::channel();
33        (Self { callback: tx, func }, rx)
34    }
35}
36
37impl<F, Fut, R> Spawnable for Concrete<F, R>
38where
39    F: FnOnce() -> Fut + Send + 'static,
40    Fut: Future<Output = R>,
41    R: Send + 'static,
42{
43    fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()> {
44        let Concrete { callback, func } = *self;
45        handle.spawn(async move {
46            let res = func().await;
47            callback.send(res).ok();
48        })
49    }
50}
51
52impl<F, R> Dispatchable for Concrete<F, R>
53where
54    F: FnOnce() -> R + Send + 'static,
55    R: Send + 'static,
56{
57    fn run(self: Box<Self>) {
58        let Concrete { callback, func } = *self;
59        let res = func();
60        callback.send(res).ok();
61    }
62}
63
64/// The dispatcher. It manages the threads and dispatches the tasks.
65#[derive(Debug)]
66pub struct Dispatcher {
67    sender: Sender<Spawning>,
68    threads: Vec<JoinHandle<()>>,
69    pool: AsyncifyPool,
70}
71
72impl Dispatcher {
73    /// Create the dispatcher with specified number of threads.
74    pub(crate) fn new_impl(mut builder: DispatcherBuilder) -> io::Result<Self> {
75        let mut proactor_builder = builder.proactor_builder;
76        proactor_builder.force_reuse_thread_pool();
77        let pool = proactor_builder.create_or_get_thread_pool();
78        let (sender, receiver) = unbounded::<Spawning>();
79
80        let threads = (0..builder.nthreads)
81            .map({
82                |index| {
83                    let proactor_builder = proactor_builder.clone();
84                    let receiver = receiver.clone();
85
86                    let thread_builder = std::thread::Builder::new();
87                    let thread_builder = if let Some(s) = builder.stack_size {
88                        thread_builder.stack_size(s)
89                    } else {
90                        thread_builder
91                    };
92                    let thread_builder = if let Some(f) = &mut builder.names {
93                        thread_builder.name(f(index))
94                    } else {
95                        thread_builder
96                    };
97
98                    thread_builder.spawn(move || {
99                        Runtime::builder()
100                            .with_proactor(proactor_builder)
101                            .build()
102                            .expect("cannot create compio runtime")
103                            .block_on(async move {
104                                while let Ok(f) = receiver.recv_async().await {
105                                    let task = Runtime::with_current(|rt| f.spawn(rt));
106                                    if builder.concurrent {
107                                        task.detach()
108                                    } else {
109                                        task.await.ok();
110                                    }
111                                }
112                            });
113                    })
114                }
115            })
116            .collect::<io::Result<Vec<_>>>()?;
117        Ok(Self {
118            sender,
119            threads,
120            pool,
121        })
122    }
123
124    /// Create the dispatcher with default config.
125    pub fn new() -> io::Result<Self> {
126        Self::builder().build()
127    }
128
129    /// Create a builder to build a dispatcher.
130    pub fn builder() -> DispatcherBuilder {
131        DispatcherBuilder::default()
132    }
133
134    /// Dispatch a task to the threads
135    ///
136    /// The provided `f` should be [`Send`] because it will be send to another
137    /// thread before calling. The returned [`Future`] need not to be [`Send`]
138    /// because it will be executed on only one thread.
139    ///
140    /// # Error
141    ///
142    /// If all threads have panicked, this method will return an error with the
143    /// sent closure.
144    pub fn dispatch<Fn, Fut, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
145    where
146        Fn: (FnOnce() -> Fut) + Send + 'static,
147        Fut: Future<Output = R> + 'static,
148        R: Send + 'static,
149    {
150        let (concrete, rx) = Concrete::new(f);
151
152        match self.sender.send(Box::new(concrete)) {
153            Ok(_) => Ok(rx),
154            Err(err) => {
155                // SAFETY: We know the dispatchable we sent has type `Concrete<Fn, R>`
156                let recovered =
157                    unsafe { Box::from_raw(Box::into_raw(err.0) as *mut Concrete<Fn, R>) };
158                Err(DispatchError(recovered.func))
159            }
160        }
161    }
162
163    /// Dispatch a blocking task to the threads.
164    ///
165    /// Blocking pool of the dispatcher will be obtained from the proactor
166    /// builder. So any configuration of the proactor's blocking pool will be
167    /// applied to the dispatcher.
168    ///
169    /// # Error
170    ///
171    /// If all threads are busy and the thread pool is full, this method will
172    /// return an error with the original closure. The limit can be configured
173    /// with [`DispatcherBuilder::proactor_builder`] and
174    /// [`ProactorBuilder::thread_pool_limit`].
175    pub fn dispatch_blocking<Fn, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
176    where
177        Fn: FnOnce() -> R + Send + 'static,
178        R: Send + 'static,
179    {
180        let (concrete, rx) = Concrete::new(f);
181
182        self.pool
183            .dispatch(concrete)
184            .map_err(|e| DispatchError(e.0.func))?;
185
186        Ok(rx)
187    }
188
189    /// Stop the dispatcher and wait for the threads to complete. If there is a
190    /// thread panicked, this method will resume the panic.
191    pub async fn join(self) -> io::Result<()> {
192        drop(self.sender);
193        let (tx, rx) = oneshot::channel::<Vec<_>>();
194        if let Err(f) = self.pool.dispatch({
195            move || {
196                let results = self
197                    .threads
198                    .into_iter()
199                    .map(|thread| thread.join())
200                    .collect();
201                tx.send(results).ok();
202            }
203        }) {
204            std::thread::spawn(f.0);
205        }
206        let results = rx
207            .await
208            .map_err(|_| io::Error::other("the join task cancelled unexpectedly"))?;
209        for res in results {
210            res.unwrap_or_else(|e| resume_unwind(e));
211        }
212        Ok(())
213    }
214}
215
216/// A builder for [`Dispatcher`].
217pub struct DispatcherBuilder {
218    nthreads: usize,
219    concurrent: bool,
220    stack_size: Option<usize>,
221    names: Option<Box<dyn FnMut(usize) -> String>>,
222    proactor_builder: ProactorBuilder,
223}
224
225impl DispatcherBuilder {
226    /// Create a builder with default settings.
227    pub fn new() -> Self {
228        Self {
229            nthreads: available_parallelism().map(|n| n.get()).unwrap_or(1),
230            concurrent: true,
231            stack_size: None,
232            names: None,
233            proactor_builder: ProactorBuilder::new(),
234        }
235    }
236
237    /// If execute tasks concurrently. Default to be `true`.
238    ///
239    /// When set to `false`, tasks are executed sequentially without any
240    /// concurrency within the thread.
241    pub fn concurrent(mut self, concurrent: bool) -> Self {
242        self.concurrent = concurrent;
243        self
244    }
245
246    /// Set the number of worker threads of the dispatcher. The default value is
247    /// the CPU number. If the CPU number could not be retrieved, the
248    /// default value is 1.
249    pub fn worker_threads(mut self, nthreads: NonZeroUsize) -> Self {
250        self.nthreads = nthreads.get();
251        self
252    }
253
254    /// Set the size of stack of the worker threads.
255    pub fn stack_size(mut self, s: usize) -> Self {
256        self.stack_size = Some(s);
257        self
258    }
259
260    /// Provide a function to assign names to the worker threads.
261    pub fn thread_names(mut self, f: impl (FnMut(usize) -> String) + 'static) -> Self {
262        self.names = Some(Box::new(f) as _);
263        self
264    }
265
266    /// Set the proactor builder for the inner runtimes.
267    pub fn proactor_builder(mut self, builder: ProactorBuilder) -> Self {
268        self.proactor_builder = builder;
269        self
270    }
271
272    /// Build the [`Dispatcher`].
273    pub fn build(self) -> io::Result<Dispatcher> {
274        Dispatcher::new_impl(self)
275    }
276}
277
278impl Default for DispatcherBuilder {
279    fn default() -> Self {
280        Self::new()
281    }
282}