1#![warn(missing_docs)]
4
5use std::{
6 future::Future,
7 io,
8 num::NonZeroUsize,
9 panic::resume_unwind,
10 thread::{JoinHandle, available_parallelism},
11};
12
13use compio_driver::{AsyncifyPool, DispatchError, Dispatchable, ProactorBuilder};
14use compio_runtime::{JoinHandle as CompioJoinHandle, Runtime};
15use flume::{Sender, unbounded};
16use futures_channel::oneshot;
17
18type Spawning = Box<dyn Spawnable + Send>;
19
20trait Spawnable {
21 fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()>;
22}
23
24struct Concrete<F, R> {
26 callback: oneshot::Sender<R>,
27 func: F,
28}
29
30impl<F, R> Concrete<F, R> {
31 pub fn new(func: F) -> (Self, oneshot::Receiver<R>) {
32 let (tx, rx) = oneshot::channel();
33 (Self { callback: tx, func }, rx)
34 }
35}
36
37impl<F, Fut, R> Spawnable for Concrete<F, R>
38where
39 F: FnOnce() -> Fut + Send + 'static,
40 Fut: Future<Output = R>,
41 R: Send + 'static,
42{
43 fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()> {
44 let Concrete { callback, func } = *self;
45 handle.spawn(async move {
46 let res = func().await;
47 callback.send(res).ok();
48 })
49 }
50}
51
52impl<F, R> Dispatchable for Concrete<F, R>
53where
54 F: FnOnce() -> R + Send + 'static,
55 R: Send + 'static,
56{
57 fn run(self: Box<Self>) {
58 let Concrete { callback, func } = *self;
59 let res = func();
60 callback.send(res).ok();
61 }
62}
63
64#[derive(Debug)]
66pub struct Dispatcher {
67 sender: Sender<Spawning>,
68 threads: Vec<JoinHandle<()>>,
69 pool: AsyncifyPool,
70}
71
72impl Dispatcher {
73 pub(crate) fn new_impl(mut builder: DispatcherBuilder) -> io::Result<Self> {
75 let mut proactor_builder = builder.proactor_builder;
76 proactor_builder.force_reuse_thread_pool();
77 let pool = proactor_builder.create_or_get_thread_pool();
78 let (sender, receiver) = unbounded::<Spawning>();
79
80 let threads = (0..builder.nthreads)
81 .map({
82 |index| {
83 let proactor_builder = proactor_builder.clone();
84 let receiver = receiver.clone();
85
86 let thread_builder = std::thread::Builder::new();
87 let thread_builder = if let Some(s) = builder.stack_size {
88 thread_builder.stack_size(s)
89 } else {
90 thread_builder
91 };
92 let thread_builder = if let Some(f) = &mut builder.names {
93 thread_builder.name(f(index))
94 } else {
95 thread_builder
96 };
97
98 thread_builder.spawn(move || {
99 Runtime::builder()
100 .with_proactor(proactor_builder)
101 .build()
102 .expect("cannot create compio runtime")
103 .block_on(async move {
104 while let Ok(f) = receiver.recv_async().await {
105 let task = Runtime::with_current(|rt| f.spawn(rt));
106 if builder.concurrent {
107 task.detach()
108 } else {
109 task.await.ok();
110 }
111 }
112 });
113 })
114 }
115 })
116 .collect::<io::Result<Vec<_>>>()?;
117 Ok(Self {
118 sender,
119 threads,
120 pool,
121 })
122 }
123
124 pub fn new() -> io::Result<Self> {
126 Self::builder().build()
127 }
128
129 pub fn builder() -> DispatcherBuilder {
131 DispatcherBuilder::default()
132 }
133
134 pub fn dispatch<Fn, Fut, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
145 where
146 Fn: (FnOnce() -> Fut) + Send + 'static,
147 Fut: Future<Output = R> + 'static,
148 R: Send + 'static,
149 {
150 let (concrete, rx) = Concrete::new(f);
151
152 match self.sender.send(Box::new(concrete)) {
153 Ok(_) => Ok(rx),
154 Err(err) => {
155 let recovered =
157 unsafe { Box::from_raw(Box::into_raw(err.0) as *mut Concrete<Fn, R>) };
158 Err(DispatchError(recovered.func))
159 }
160 }
161 }
162
163 pub fn dispatch_blocking<Fn, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
176 where
177 Fn: FnOnce() -> R + Send + 'static,
178 R: Send + 'static,
179 {
180 let (concrete, rx) = Concrete::new(f);
181
182 self.pool
183 .dispatch(concrete)
184 .map_err(|e| DispatchError(e.0.func))?;
185
186 Ok(rx)
187 }
188
189 pub async fn join(self) -> io::Result<()> {
192 drop(self.sender);
193 let (tx, rx) = oneshot::channel::<Vec<_>>();
194 if let Err(f) = self.pool.dispatch({
195 move || {
196 let results = self
197 .threads
198 .into_iter()
199 .map(|thread| thread.join())
200 .collect();
201 tx.send(results).ok();
202 }
203 }) {
204 std::thread::spawn(f.0);
205 }
206 let results = rx
207 .await
208 .map_err(|_| io::Error::other("the join task cancelled unexpectedly"))?;
209 for res in results {
210 res.unwrap_or_else(|e| resume_unwind(e));
211 }
212 Ok(())
213 }
214}
215
216pub struct DispatcherBuilder {
218 nthreads: usize,
219 concurrent: bool,
220 stack_size: Option<usize>,
221 names: Option<Box<dyn FnMut(usize) -> String>>,
222 proactor_builder: ProactorBuilder,
223}
224
225impl DispatcherBuilder {
226 pub fn new() -> Self {
228 Self {
229 nthreads: available_parallelism().map(|n| n.get()).unwrap_or(1),
230 concurrent: true,
231 stack_size: None,
232 names: None,
233 proactor_builder: ProactorBuilder::new(),
234 }
235 }
236
237 pub fn concurrent(mut self, concurrent: bool) -> Self {
242 self.concurrent = concurrent;
243 self
244 }
245
246 pub fn worker_threads(mut self, nthreads: NonZeroUsize) -> Self {
250 self.nthreads = nthreads.get();
251 self
252 }
253
254 pub fn stack_size(mut self, s: usize) -> Self {
256 self.stack_size = Some(s);
257 self
258 }
259
260 pub fn thread_names(mut self, f: impl (FnMut(usize) -> String) + 'static) -> Self {
262 self.names = Some(Box::new(f) as _);
263 self
264 }
265
266 pub fn proactor_builder(mut self, builder: ProactorBuilder) -> Self {
268 self.proactor_builder = builder;
269 self
270 }
271
272 pub fn build(self) -> io::Result<Dispatcher> {
274 Dispatcher::new_impl(self)
275 }
276}
277
278impl Default for DispatcherBuilder {
279 fn default() -> Self {
280 Self::new()
281 }
282}