// tokio_rayon/async_thread_pool.rs

1use crate::AsyncRayonHandle;
2use rayon::ThreadPool;
3use std::panic::{catch_unwind, AssertUnwindSafe};
4use tokio::sync::oneshot;
5
6/// Extension trait that integrates Rayon's [`ThreadPool`](rayon::ThreadPool)
7/// with Tokio.
8///
9/// This trait is sealed and cannot be implemented by external crates.
10pub trait AsyncThreadPool: private::Sealed {
11    /// Asynchronous wrapper around Rayon's
12    /// [`ThreadPool::spawn`](rayon::ThreadPool::spawn).
13    ///
14    /// Runs a function on the global Rayon thread pool with LIFO priority,
15    /// produciing a future that resolves with the function's return value.
16    ///
17    /// # Panics
18    /// If the task function panics, the panic will be propagated through the
19    /// returned future. This will NOT trigger the Rayon thread pool's panic
20    /// handler.
21    ///
22    /// If the returned handle is dropped, and the return value of `func` panics
23    /// when dropped, that panic WILL trigger the thread pool's panic
24    /// handler.
25    fn spawn_async<F, R>(&self, func: F) -> AsyncRayonHandle<R>
26    where
27        F: FnOnce() -> R + Send + 'static,
28        R: Send + 'static;
29
30    /// Asynchronous wrapper around Rayon's
31    /// [`ThreadPool::spawn_fifo`](rayon::ThreadPool::spawn_fifo).
32    ///
33    /// Runs a function on the global Rayon thread pool with FIFO priority,
34    /// produciing a future that resolves with the function's return value.
35    ///
36    /// # Panics
37    /// If the task function panics, the panic will be propagated through the
38    /// returned future. This will NOT trigger the Rayon thread pool's panic
39    /// handler.
40    ///
41    /// If the returned handle is dropped, and the return value of `func` panics
42    /// when dropped, that panic WILL trigger the thread pool's panic
43    /// handler.
44    fn spawn_fifo_async<F, R>(&self, f: F) -> AsyncRayonHandle<R>
45    where
46        F: FnOnce() -> R + Send + 'static,
47        R: Send + 'static;
48}
49
50impl AsyncThreadPool for ThreadPool {
51    fn spawn_async<F, R>(&self, func: F) -> AsyncRayonHandle<R>
52    where
53        F: FnOnce() -> R + Send + 'static,
54        R: Send + 'static,
55    {
56        let (tx, rx) = oneshot::channel();
57
58        self.spawn(move || {
59            let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
60        });
61
62        AsyncRayonHandle { rx }
63    }
64
65    fn spawn_fifo_async<F, R>(&self, func: F) -> AsyncRayonHandle<R>
66    where
67        F: FnOnce() -> R + Send + 'static,
68        R: Send + 'static,
69    {
70        let (tx, rx) = oneshot::channel();
71
72        self.spawn_fifo(move || {
73            let _result = tx.send(catch_unwind(AssertUnwindSafe(func)));
74        });
75
76        AsyncRayonHandle { rx }
77    }
78}
79
80mod private {
81    use rayon::ThreadPool;
82
83    pub trait Sealed {}
84
85    impl Sealed for ThreadPool {}
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use rayon::{ThreadPool, ThreadPoolBuilder};

    /// Builds a dedicated pool with a single worker, so every task spawned
    /// on it must run on worker index 0.
    fn build_thread_pool() -> ThreadPool {
        ThreadPoolBuilder::new()
            .num_threads(1)
            .build()
            .unwrap()
    }

    /// Task body for the happy-path tests. It checks that it is executing on
    /// the pool's only worker, then returns a recognizable value.
    fn task_on_worker() -> usize {
        assert_eq!(rayon::current_thread_index(), Some(0));
        1337
    }

    /// Task body for the panic-propagation tests.
    fn failing_task() {
        panic!("Task failed successfully");
    }

    #[tokio::test]
    async fn test_spawn_async_works() {
        let pool = build_thread_pool();
        let value = pool.spawn_async(task_on_worker).await;
        assert_eq!(value, 1337);
        // The awaiting side runs on a Tokio thread, not a Rayon worker.
        assert!(rayon::current_thread_index().is_none());
    }

    #[tokio::test]
    async fn test_spawn_fifo_async_works() {
        let pool = build_thread_pool();
        let value = pool.spawn_fifo_async(task_on_worker).await;
        assert_eq!(value, 1337);
        // The awaiting side runs on a Tokio thread, not a Rayon worker.
        assert!(rayon::current_thread_index().is_none());
    }

    #[tokio::test]
    #[should_panic(expected = "Task failed successfully")]
    async fn test_spawn_async_propagates_panic() {
        let pool = build_thread_pool();
        pool.spawn_async(failing_task).await;
    }

    #[tokio::test]
    #[should_panic(expected = "Task failed successfully")]
    async fn test_spawn_fifo_async_propagates_panic() {
        let pool = build_thread_pool();
        pool.spawn_fifo_async(failing_task).await;
    }
}