tokio_rayon/
async_thread_pool.rs1use crate::AsyncRayonHandle;
2use rayon::ThreadPool;
3use std::panic::{catch_unwind, AssertUnwindSafe};
4use tokio::sync::oneshot;
5
6pub trait AsyncThreadPool: private::Sealed {
11 fn spawn_async<F, R>(&self, func: F) -> AsyncRayonHandle<R>
26 where
27 F: FnOnce() -> R + Send + 'static,
28 R: Send + 'static;
29
30 fn spawn_fifo_async<F, R>(&self, f: F) -> AsyncRayonHandle<R>
45 where
46 F: FnOnce() -> R + Send + 'static,
47 R: Send + 'static;
48}
49
50impl AsyncThreadPool for ThreadPool {
51 fn spawn_async<F, R>(&self, func: F) -> AsyncRayonHandle<R>
52 where
53 F: FnOnce() -> R + Send + 'static,
54 R: Send + 'static,
55 {
56 let (tx, rx) = oneshot::channel();
57
58 self.spawn(move || {
59 let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
60 });
61
62 AsyncRayonHandle { rx }
63 }
64
65 fn spawn_fifo_async<F, R>(&self, func: F) -> AsyncRayonHandle<R>
66 where
67 F: FnOnce() -> R + Send + 'static,
68 R: Send + 'static,
69 {
70 let (tx, rx) = oneshot::channel();
71
72 self.spawn_fifo(move || {
73 let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
74 });
75
76 AsyncRayonHandle { rx }
77 }
78}
79
80mod private {
81 use rayon::ThreadPool;
82
83 pub trait Sealed {}
84
85 impl Sealed for ThreadPool {}
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use rayon::{ThreadPool, ThreadPoolBuilder};

    /// Value produced by the jobs in the success-path tests.
    const ANSWER: usize = 1337;

    /// Builds a pool with exactly one worker, so every job runs on worker index 0.
    fn build_thread_pool() -> ThreadPool {
        ThreadPoolBuilder::new().num_threads(1).build().unwrap()
    }

    /// Job body for the success-path tests. It asserts that it is executing on the
    /// pool's only worker, then returns [`ANSWER`].
    fn job_on_worker() -> usize {
        assert_eq!(rayon::current_thread_index(), Some(0));
        ANSWER
    }

    /// Asserts that the calling thread (the tokio test thread) is not a rayon worker.
    fn assert_not_on_worker() {
        assert_eq!(rayon::current_thread_index(), None);
    }

    #[tokio::test]
    async fn test_spawn_async_works() {
        let pool = build_thread_pool();
        let value = pool.spawn_async(job_on_worker).await;
        assert_eq!(value, ANSWER);
        assert_not_on_worker();
    }

    #[tokio::test]
    async fn test_spawn_fifo_async_works() {
        let pool = build_thread_pool();
        let value = pool.spawn_fifo_async(job_on_worker).await;
        assert_eq!(value, ANSWER);
        assert_not_on_worker();
    }

    #[tokio::test]
    #[should_panic(expected = "Task failed successfully")]
    async fn test_spawn_async_propagates_panic() {
        let pool = build_thread_pool();
        // The panic raised on the worker must resurface when the handle is awaited.
        pool.spawn_async(|| {
            panic!("Task failed successfully");
        })
        .await;
    }

    #[tokio::test]
    #[should_panic(expected = "Task failed successfully")]
    async fn test_spawn_fifo_async_propagates_panic() {
        let pool = build_thread_pool();
        // Same as above, for the FIFO spawn path.
        pool.spawn_fifo_async(|| {
            panic!("Task failed successfully");
        })
        .await;
    }
}