tokio_rayon/
async_handle.rs

1use std::future::Future;
2use std::panic::resume_unwind;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::thread;
6use tokio::sync::oneshot::Receiver;
7
8/// Async handle for a blocking task running in a Rayon thread pool.
9///
10/// If the spawned task panics, `poll()` will propagate the panic.
11#[must_use]
12#[derive(Debug)]
13pub struct AsyncRayonHandle<T> {
14    pub(crate) rx: Receiver<thread::Result<T>>,
15}
16
17impl<T> Future for AsyncRayonHandle<T> {
18    type Output = T;
19
20    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
21        let rx = Pin::new(&mut self.rx);
22        rx.poll(cx).map(|result| {
23            result
24                .expect("Unreachable error: Tokio channel closed")
25                .unwrap_or_else(|err| resume_unwind(err))
26        })
27    }
28}
29
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::init;
    use std::panic::catch_unwind;
    use std::thread;
    use tokio::sync::oneshot::channel;

    /// A panic payload sent through the channel must be re-raised when the
    /// handle is awaited, keeping its original message.
    #[tokio::test]
    #[should_panic(expected = "Task failed successfully")]
    async fn test_poll_propagates_panic() {
        init();
        let payload = catch_unwind(|| {
            panic!("Task failed successfully");
        })
        .unwrap_err();

        let (sender, receiver) = channel::<thread::Result<()>>();
        sender.send(Err(payload)).unwrap();
        AsyncRayonHandle { rx: receiver }.await;
    }

    /// Dropping the sender without sending anything closes the channel; the
    /// handle must surface this as the closed-channel panic.
    #[tokio::test]
    #[should_panic(expected = "Unreachable error: Tokio channel closed")]
    async fn test_unreachable_channel_closed() {
        init();
        let (sender, receiver) = channel::<thread::Result<()>>();
        drop(sender);
        AsyncRayonHandle { rx: receiver }.await;
    }
}