// tokio_rayon/async_thread_pool.rs
use crate::AsyncRayonHandle;
use rayon::ThreadPool;
use std::panic::{catch_unwind, AssertUnwindSafe};
use tokio::sync::oneshot;
/// Extension trait that integrates Rayon's [`ThreadPool`](rayon::ThreadPool)
/// with Tokio.
///
/// This trait is sealed and cannot be implemented by external crates.
pub trait AsyncThreadPool: private::Sealed {
    /// Asynchronous wrapper around Rayon's
    /// [`ThreadPool::spawn`](rayon::ThreadPool::spawn).
    ///
    /// Runs `func` on *this* thread pool (not the global Rayon pool) with
    /// LIFO priority, producing a future that resolves with `func`'s return
    /// value.
    ///
    /// # Panics
    /// If `func` panics, the panic is caught on the worker thread and
    /// propagated through the returned future. This will NOT trigger the
    /// Rayon thread pool's panic handler.
    ///
    /// Dropping the returned handle does not cancel the task: it still runs,
    /// and its return value is then dropped on the worker thread. If that
    /// value panics when dropped, the panic WILL trigger the thread pool's
    /// panic handler.
    fn spawn_async<F, R>(&self, func: F) -> AsyncRayonHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static;

    /// Asynchronous wrapper around Rayon's
    /// [`ThreadPool::spawn_fifo`](rayon::ThreadPool::spawn_fifo).
    ///
    /// Runs `func` on *this* thread pool (not the global Rayon pool) with
    /// FIFO priority, producing a future that resolves with `func`'s return
    /// value.
    ///
    /// # Panics
    /// If `func` panics, the panic is caught on the worker thread and
    /// propagated through the returned future. This will NOT trigger the
    /// Rayon thread pool's panic handler.
    ///
    /// Dropping the returned handle does not cancel the task: it still runs,
    /// and its return value is then dropped on the worker thread. If that
    /// value panics when dropped, the panic WILL trigger the thread pool's
    /// panic handler.
    fn spawn_fifo_async<F, R>(&self, func: F) -> AsyncRayonHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static;
}
impl AsyncThreadPool for ThreadPool {
    fn spawn_async<F, R>(&self, func: F) -> AsyncRayonHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let (job, handle) = job_with_handle(func);
        self.spawn(job);
        handle
    }

    fn spawn_fifo_async<F, R>(&self, func: F) -> AsyncRayonHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let (job, handle) = job_with_handle(func);
        self.spawn_fifo(job);
        handle
    }
}

/// Pairs `func` with a oneshot channel.
///
/// Returns a job closure that runs `func` under `catch_unwind` and sends the
/// outcome (value or panic payload) into the channel, together with the
/// handle that owns the receiving end. The caller chooses how the job is
/// scheduled (`spawn` vs `spawn_fifo`).
fn job_with_handle<F, R>(func: F) -> (impl FnOnce() + Send + 'static, AsyncRayonHandle<R>)
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    let (sender, receiver) = oneshot::channel();
    let job = move || {
        // A send error only means the handle was dropped; the unreceived
        // outcome is simply dropped here on the worker thread.
        let _unreceived = sender.send(catch_unwind(AssertUnwindSafe(func)));
    };
    (job, AsyncRayonHandle { rx: receiver })
}
mod private {
use rayon::ThreadPool;
pub trait Sealed {}
impl Sealed for ThreadPool {}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rayon::{ThreadPool, ThreadPoolBuilder};

    /// Value every successful task returns.
    const ANSWER: usize = 1337;

    /// Builds a one-thread pool, so the only worker index is `0`.
    fn build_thread_pool() -> ThreadPool {
        ThreadPoolBuilder::new().num_threads(1).build().unwrap()
    }

    /// Task body shared by the "works" tests. It checks that it executes on
    /// the pool's single worker, then returns [`ANSWER`].
    fn run_on_single_worker() -> usize {
        assert_eq!(rayon::current_thread_index(), Some(0));
        ANSWER
    }

    #[tokio::test]
    async fn test_spawn_async_works() {
        let pool = build_thread_pool();
        let value = pool.spawn_async(run_on_single_worker).await;
        assert_eq!(value, ANSWER);
        // The awaiting (Tokio) thread is not a Rayon worker.
        assert_eq!(rayon::current_thread_index(), None);
    }

    #[tokio::test]
    async fn test_spawn_fifo_async_works() {
        let pool = build_thread_pool();
        let value = pool.spawn_fifo_async(run_on_single_worker).await;
        assert_eq!(value, ANSWER);
        // The awaiting (Tokio) thread is not a Rayon worker.
        assert_eq!(rayon::current_thread_index(), None);
    }

    #[tokio::test]
    #[should_panic(expected = "Task failed successfully")]
    async fn test_spawn_async_propagates_panic() {
        let pool = build_thread_pool();
        // Awaiting re-raises the task's panic on this thread.
        pool.spawn_async(|| {
            panic!("Task failed successfully");
        })
        .await;
    }

    #[tokio::test]
    #[should_panic(expected = "Task failed successfully")]
    async fn test_spawn_fifo_async_propagates_panic() {
        let pool = build_thread_pool();
        // Awaiting re-raises the task's panic on this thread.
        pool.spawn_fifo_async(|| {
            panic!("Task failed successfully");
        })
        .await;
    }
}