1#![warn(missing_docs)]
4
5use std::{
6 future::Future,
7 io,
8 num::NonZeroUsize,
9 panic::resume_unwind,
10 sync::{Arc, Mutex},
11 thread::{JoinHandle, available_parallelism},
12};
13
14use compio_driver::{AsyncifyPool, DispatchError, Dispatchable, ProactorBuilder};
15use compio_runtime::{JoinHandle as CompioJoinHandle, Runtime, event::Event};
16use flume::{Sender, unbounded};
17use futures_channel::oneshot;
18
19type Spawning = Box<dyn Spawnable + Send>;
20
21trait Spawnable {
22 fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()>;
23}
24
25struct Concrete<F, R> {
27 callback: oneshot::Sender<R>,
28 func: F,
29}
30
31impl<F, R> Concrete<F, R> {
32 pub fn new(func: F) -> (Self, oneshot::Receiver<R>) {
33 let (tx, rx) = oneshot::channel();
34 (Self { callback: tx, func }, rx)
35 }
36}
37
38impl<F, Fut, R> Spawnable for Concrete<F, R>
39where
40 F: FnOnce() -> Fut + Send + 'static,
41 Fut: Future<Output = R>,
42 R: Send + 'static,
43{
44 fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()> {
45 let Concrete { callback, func } = *self;
46 handle.spawn(async move {
47 let res = func().await;
48 callback.send(res).ok();
49 })
50 }
51}
52
53impl<F, R> Dispatchable for Concrete<F, R>
54where
55 F: FnOnce() -> R + Send + 'static,
56 R: Send + 'static,
57{
58 fn run(self: Box<Self>) {
59 let Concrete { callback, func } = *self;
60 let res = func();
61 callback.send(res).ok();
62 }
63}
64
65#[derive(Debug)]
67pub struct Dispatcher {
68 sender: Sender<Spawning>,
69 threads: Vec<JoinHandle<()>>,
70 pool: AsyncifyPool,
71}
72
73impl Dispatcher {
74 pub(crate) fn new_impl(mut builder: DispatcherBuilder) -> io::Result<Self> {
76 let mut proactor_builder = builder.proactor_builder;
77 proactor_builder.force_reuse_thread_pool();
78 let pool = proactor_builder.create_or_get_thread_pool();
79 let (sender, receiver) = unbounded::<Spawning>();
80
81 let threads = (0..builder.nthreads)
82 .map({
83 |index| {
84 let proactor_builder = proactor_builder.clone();
85 let receiver = receiver.clone();
86
87 let thread_builder = std::thread::Builder::new();
88 let thread_builder = if let Some(s) = builder.stack_size {
89 thread_builder.stack_size(s)
90 } else {
91 thread_builder
92 };
93 let thread_builder = if let Some(f) = &mut builder.names {
94 thread_builder.name(f(index))
95 } else {
96 thread_builder
97 };
98
99 thread_builder.spawn(move || {
100 Runtime::builder()
101 .with_proactor(proactor_builder)
102 .build()
103 .expect("cannot create compio runtime")
104 .block_on(async move {
105 while let Ok(f) = receiver.recv_async().await {
106 let task = Runtime::with_current(|rt| f.spawn(rt));
107 if builder.concurrent {
108 task.detach()
109 } else {
110 task.await.ok();
111 }
112 }
113 });
114 })
115 }
116 })
117 .collect::<io::Result<Vec<_>>>()?;
118 Ok(Self {
119 sender,
120 threads,
121 pool,
122 })
123 }
124
125 pub fn new() -> io::Result<Self> {
127 Self::builder().build()
128 }
129
130 pub fn builder() -> DispatcherBuilder {
132 DispatcherBuilder::default()
133 }
134
135 pub fn dispatch<Fn, Fut, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
146 where
147 Fn: (FnOnce() -> Fut) + Send + 'static,
148 Fut: Future<Output = R> + 'static,
149 R: Send + 'static,
150 {
151 let (concrete, rx) = Concrete::new(f);
152
153 match self.sender.send(Box::new(concrete)) {
154 Ok(_) => Ok(rx),
155 Err(err) => {
156 let recovered =
158 unsafe { Box::from_raw(Box::into_raw(err.0) as *mut Concrete<Fn, R>) };
159 Err(DispatchError(recovered.func))
160 }
161 }
162 }
163
164 pub fn dispatch_blocking<Fn, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
177 where
178 Fn: FnOnce() -> R + Send + 'static,
179 R: Send + 'static,
180 {
181 let (concrete, rx) = Concrete::new(f);
182
183 self.pool
184 .dispatch(concrete)
185 .map_err(|e| DispatchError(e.0.func))?;
186
187 Ok(rx)
188 }
189
190 pub async fn join(self) -> io::Result<()> {
193 drop(self.sender);
194 let results = Arc::new(Mutex::new(vec![]));
195 let event = Event::new();
196 let handle = event.handle();
197 if let Err(f) = self.pool.dispatch({
198 let results = results.clone();
199 move || {
200 *results.lock().unwrap() = self
201 .threads
202 .into_iter()
203 .map(|thread| thread.join())
204 .collect();
205 handle.notify();
206 }
207 }) {
208 std::thread::spawn(f.0);
209 }
210 event.wait().await;
211 let mut guard = results.lock().unwrap();
212 for res in std::mem::take::<Vec<std::thread::Result<()>>>(guard.as_mut()) {
213 res.unwrap_or_else(|e| resume_unwind(e));
214 }
215 Ok(())
216 }
217}
218
219pub struct DispatcherBuilder {
221 nthreads: usize,
222 concurrent: bool,
223 stack_size: Option<usize>,
224 names: Option<Box<dyn FnMut(usize) -> String>>,
225 proactor_builder: ProactorBuilder,
226}
227
228impl DispatcherBuilder {
229 pub fn new() -> Self {
231 Self {
232 nthreads: available_parallelism().map(|n| n.get()).unwrap_or(1),
233 concurrent: true,
234 stack_size: None,
235 names: None,
236 proactor_builder: ProactorBuilder::new(),
237 }
238 }
239
240 pub fn concurrent(mut self, concurrent: bool) -> Self {
245 self.concurrent = concurrent;
246 self
247 }
248
249 pub fn worker_threads(mut self, nthreads: NonZeroUsize) -> Self {
253 self.nthreads = nthreads.get();
254 self
255 }
256
257 pub fn stack_size(mut self, s: usize) -> Self {
259 self.stack_size = Some(s);
260 self
261 }
262
263 pub fn thread_names(mut self, f: impl (FnMut(usize) -> String) + 'static) -> Self {
265 self.names = Some(Box::new(f) as _);
266 self
267 }
268
269 pub fn proactor_builder(mut self, builder: ProactorBuilder) -> Self {
271 self.proactor_builder = builder;
272 self
273 }
274
275 pub fn build(self) -> io::Result<Dispatcher> {
277 Dispatcher::new_impl(self)
278 }
279}
280
281impl Default for DispatcherBuilder {
282 fn default() -> Self {
283 Self::new()
284 }
285}