compio_dispatcher/
lib.rs

1//! Multithreading dispatcher for compio.
2
3#![warn(missing_docs)]
4
5use std::{
6    future::Future,
7    io,
8    num::NonZeroUsize,
9    panic::resume_unwind,
10    sync::{Arc, Mutex},
11    thread::{JoinHandle, available_parallelism},
12};
13
14use compio_driver::{AsyncifyPool, DispatchError, Dispatchable, ProactorBuilder};
15use compio_runtime::{JoinHandle as CompioJoinHandle, Runtime, event::Event};
16use flume::{Sender, unbounded};
17use futures_channel::oneshot;
18
19type Spawning = Box<dyn Spawnable + Send>;
20
21trait Spawnable {
22    fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()>;
23}
24
25/// Concrete type for the closure we're sending to worker threads
26struct Concrete<F, R> {
27    callback: oneshot::Sender<R>,
28    func: F,
29}
30
31impl<F, R> Concrete<F, R> {
32    pub fn new(func: F) -> (Self, oneshot::Receiver<R>) {
33        let (tx, rx) = oneshot::channel();
34        (Self { callback: tx, func }, rx)
35    }
36}
37
38impl<F, Fut, R> Spawnable for Concrete<F, R>
39where
40    F: FnOnce() -> Fut + Send + 'static,
41    Fut: Future<Output = R>,
42    R: Send + 'static,
43{
44    fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()> {
45        let Concrete { callback, func } = *self;
46        handle.spawn(async move {
47            let res = func().await;
48            callback.send(res).ok();
49        })
50    }
51}
52
53impl<F, R> Dispatchable for Concrete<F, R>
54where
55    F: FnOnce() -> R + Send + 'static,
56    R: Send + 'static,
57{
58    fn run(self: Box<Self>) {
59        let Concrete { callback, func } = *self;
60        let res = func();
61        callback.send(res).ok();
62    }
63}
64
65/// The dispatcher. It manages the threads and dispatches the tasks.
66#[derive(Debug)]
67pub struct Dispatcher {
68    sender: Sender<Spawning>,
69    threads: Vec<JoinHandle<()>>,
70    pool: AsyncifyPool,
71}
72
73impl Dispatcher {
74    /// Create the dispatcher with specified number of threads.
75    pub(crate) fn new_impl(mut builder: DispatcherBuilder) -> io::Result<Self> {
76        let mut proactor_builder = builder.proactor_builder;
77        proactor_builder.force_reuse_thread_pool();
78        let pool = proactor_builder.create_or_get_thread_pool();
79        let (sender, receiver) = unbounded::<Spawning>();
80
81        let threads = (0..builder.nthreads)
82            .map({
83                |index| {
84                    let proactor_builder = proactor_builder.clone();
85                    let receiver = receiver.clone();
86
87                    let thread_builder = std::thread::Builder::new();
88                    let thread_builder = if let Some(s) = builder.stack_size {
89                        thread_builder.stack_size(s)
90                    } else {
91                        thread_builder
92                    };
93                    let thread_builder = if let Some(f) = &mut builder.names {
94                        thread_builder.name(f(index))
95                    } else {
96                        thread_builder
97                    };
98
99                    thread_builder.spawn(move || {
100                        Runtime::builder()
101                            .with_proactor(proactor_builder)
102                            .build()
103                            .expect("cannot create compio runtime")
104                            .block_on(async move {
105                                while let Ok(f) = receiver.recv_async().await {
106                                    let task = Runtime::with_current(|rt| f.spawn(rt));
107                                    if builder.concurrent {
108                                        task.detach()
109                                    } else {
110                                        task.await.ok();
111                                    }
112                                }
113                            });
114                    })
115                }
116            })
117            .collect::<io::Result<Vec<_>>>()?;
118        Ok(Self {
119            sender,
120            threads,
121            pool,
122        })
123    }
124
125    /// Create the dispatcher with default config.
126    pub fn new() -> io::Result<Self> {
127        Self::builder().build()
128    }
129
130    /// Create a builder to build a dispatcher.
131    pub fn builder() -> DispatcherBuilder {
132        DispatcherBuilder::default()
133    }
134
135    /// Dispatch a task to the threads
136    ///
137    /// The provided `f` should be [`Send`] because it will be send to another
138    /// thread before calling. The returned [`Future`] need not to be [`Send`]
139    /// because it will be executed on only one thread.
140    ///
141    /// # Error
142    ///
143    /// If all threads have panicked, this method will return an error with the
144    /// sent closure.
145    pub fn dispatch<Fn, Fut, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
146    where
147        Fn: (FnOnce() -> Fut) + Send + 'static,
148        Fut: Future<Output = R> + 'static,
149        R: Send + 'static,
150    {
151        let (concrete, rx) = Concrete::new(f);
152
153        match self.sender.send(Box::new(concrete)) {
154            Ok(_) => Ok(rx),
155            Err(err) => {
156                // SAFETY: We know the dispatchable we sent has type `Concrete<Fn, R>`
157                let recovered =
158                    unsafe { Box::from_raw(Box::into_raw(err.0) as *mut Concrete<Fn, R>) };
159                Err(DispatchError(recovered.func))
160            }
161        }
162    }
163
164    /// Dispatch a blocking task to the threads.
165    ///
166    /// Blocking pool of the dispatcher will be obtained from the proactor
167    /// builder. So any configuration of the proactor's blocking pool will be
168    /// applied to the dispatcher.
169    ///
170    /// # Error
171    ///
172    /// If all threads are busy and the thread pool is full, this method will
173    /// return an error with the original closure. The limit can be configured
174    /// with [`DispatcherBuilder::proactor_builder`] and
175    /// [`ProactorBuilder::thread_pool_limit`].
176    pub fn dispatch_blocking<Fn, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
177    where
178        Fn: FnOnce() -> R + Send + 'static,
179        R: Send + 'static,
180    {
181        let (concrete, rx) = Concrete::new(f);
182
183        self.pool
184            .dispatch(concrete)
185            .map_err(|e| DispatchError(e.0.func))?;
186
187        Ok(rx)
188    }
189
190    /// Stop the dispatcher and wait for the threads to complete. If there is a
191    /// thread panicked, this method will resume the panic.
192    pub async fn join(self) -> io::Result<()> {
193        drop(self.sender);
194        let results = Arc::new(Mutex::new(vec![]));
195        let event = Event::new();
196        let handle = event.handle();
197        if let Err(f) = self.pool.dispatch({
198            let results = results.clone();
199            move || {
200                *results.lock().unwrap() = self
201                    .threads
202                    .into_iter()
203                    .map(|thread| thread.join())
204                    .collect();
205                handle.notify();
206            }
207        }) {
208            std::thread::spawn(f.0);
209        }
210        event.wait().await;
211        let mut guard = results.lock().unwrap();
212        for res in std::mem::take::<Vec<std::thread::Result<()>>>(guard.as_mut()) {
213            res.unwrap_or_else(|e| resume_unwind(e));
214        }
215        Ok(())
216    }
217}
218
219/// A builder for [`Dispatcher`].
220pub struct DispatcherBuilder {
221    nthreads: usize,
222    concurrent: bool,
223    stack_size: Option<usize>,
224    names: Option<Box<dyn FnMut(usize) -> String>>,
225    proactor_builder: ProactorBuilder,
226}
227
228impl DispatcherBuilder {
229    /// Create a builder with default settings.
230    pub fn new() -> Self {
231        Self {
232            nthreads: available_parallelism().map(|n| n.get()).unwrap_or(1),
233            concurrent: true,
234            stack_size: None,
235            names: None,
236            proactor_builder: ProactorBuilder::new(),
237        }
238    }
239
240    /// If execute tasks concurrently. Default to be `true`.
241    ///
242    /// When set to `false`, tasks are executed sequentially without any
243    /// concurrency within the thread.
244    pub fn concurrent(mut self, concurrent: bool) -> Self {
245        self.concurrent = concurrent;
246        self
247    }
248
249    /// Set the number of worker threads of the dispatcher. The default value is
250    /// the CPU number. If the CPU number could not be retrieved, the
251    /// default value is 1.
252    pub fn worker_threads(mut self, nthreads: NonZeroUsize) -> Self {
253        self.nthreads = nthreads.get();
254        self
255    }
256
257    /// Set the size of stack of the worker threads.
258    pub fn stack_size(mut self, s: usize) -> Self {
259        self.stack_size = Some(s);
260        self
261    }
262
263    /// Provide a function to assign names to the worker threads.
264    pub fn thread_names(mut self, f: impl (FnMut(usize) -> String) + 'static) -> Self {
265        self.names = Some(Box::new(f) as _);
266        self
267    }
268
269    /// Set the proactor builder for the inner runtimes.
270    pub fn proactor_builder(mut self, builder: ProactorBuilder) -> Self {
271        self.proactor_builder = builder;
272        self
273    }
274
275    /// Build the [`Dispatcher`].
276    pub fn build(self) -> io::Result<Dispatcher> {
277        Dispatcher::new_impl(self)
278    }
279}
280
281impl Default for DispatcherBuilder {
282    fn default() -> Self {
283        Self::new()
284    }
285}