tokio_util/task/
spawn_pinned.rs

1use futures_util::future::{AbortHandle, Abortable};
2use std::fmt;
3use std::fmt::{Debug, Formatter};
4use std::future::Future;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::Arc;
7use tokio::runtime::Builder;
8use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
9use tokio::sync::oneshot;
10use tokio::task::{spawn_local, JoinHandle, LocalSet};
11
12/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
13///
14/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
15/// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will
16/// execute on the same thread) inside the Future you supply to the various spawn methods
17/// of `LocalPoolHandle`.
18///
19/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
20/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
21///
22/// # Examples
23///
24/// ```
25/// use std::rc::Rc;
26/// use tokio::task;
27/// use tokio_util::task::LocalPoolHandle;
28///
29/// #[tokio::main(flavor = "current_thread")]
30/// async fn main() {
31///     let pool = LocalPoolHandle::new(5);
32///
33///     let output = pool.spawn_pinned(|| {
34///         // `data` is !Send + !Sync
35///         let data = Rc::new("local data");
36///         let data_clone = data.clone();
37///
38///         async move {
39///             task::spawn_local(async move {
40///                 println!("{}", data_clone);
41///             });
42///     
43///             data.to_string()
44///         }   
45///     }).await.unwrap();
46///     println!("output: {}", output);
47/// }
48/// ```
49///
50#[derive(Clone)]
51pub struct LocalPoolHandle {
52    pool: Arc<LocalPool>,
53}
54
55impl LocalPoolHandle {
56    /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
57    /// pool via [`LocalPoolHandle::spawn_pinned`].
58    ///
59    /// # Panics
60    ///
61    /// Panics if the pool size is less than one.
62    #[track_caller]
63    pub fn new(pool_size: usize) -> LocalPoolHandle {
64        assert!(pool_size > 0);
65
66        let workers = (0..pool_size)
67            .map(|_| LocalWorkerHandle::new_worker())
68            .collect();
69
70        let pool = Arc::new(LocalPool { workers });
71
72        LocalPoolHandle { pool }
73    }
74
75    /// Returns the number of threads of the Pool.
76    #[inline]
77    pub fn num_threads(&self) -> usize {
78        self.pool.workers.len()
79    }
80
81    /// Returns the number of tasks scheduled on each worker. The indices of the
82    /// worker threads correspond to the indices of the returned `Vec`.
83    pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
84        self.pool
85            .workers
86            .iter()
87            .map(|worker| worker.task_count.load(Ordering::SeqCst))
88            .collect::<Vec<_>>()
89    }
90
91    /// Spawn a task onto a worker thread and pin it there so it can't be moved
92    /// off of the thread. Note that the future is not [`Send`], but the
93    /// [`FnOnce`] which creates it is.
94    ///
95    /// # Examples
96    /// ```
97    /// use std::rc::Rc;
98    /// use tokio_util::task::LocalPoolHandle;
99    ///
100    /// #[tokio::main]
101    /// async fn main() {
102    ///     // Create the local pool
103    ///     let pool = LocalPoolHandle::new(1);
104    ///
105    ///     // Spawn a !Send future onto the pool and await it
106    ///     let output = pool
107    ///         .spawn_pinned(|| {
108    ///             // Rc is !Send + !Sync
109    ///             let local_data = Rc::new("test");
110    ///
111    ///             // This future holds an Rc, so it is !Send
112    ///             async move { local_data.to_string() }
113    ///         })
114    ///         .await
115    ///         .unwrap();
116    ///
117    ///     assert_eq!(output, "test");
118    /// }
119    /// ```
120    pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
121    where
122        F: FnOnce() -> Fut,
123        F: Send + 'static,
124        Fut: Future + 'static,
125        Fut::Output: Send + 'static,
126    {
127        self.pool
128            .spawn_pinned(create_task, WorkerChoice::LeastBurdened)
129    }
130
131    /// Differs from `spawn_pinned` only in that you can choose a specific worker thread
132    /// of the pool, whereas `spawn_pinned` chooses the worker with the smallest
133    /// number of tasks scheduled.
134    ///
135    /// A worker thread is chosen by index. Indices are 0 based and the largest index
136    /// is given by `num_threads() - 1`
137    ///
138    /// # Panics
139    ///
140    /// This method panics if the index is out of bounds.
141    ///
142    /// # Examples
143    ///
144    /// This method can be used to spawn a task on all worker threads of the pool:
145    ///
146    /// ```
147    /// use tokio_util::task::LocalPoolHandle;
148    ///
149    /// #[tokio::main]
150    /// async fn main() {
151    ///     const NUM_WORKERS: usize = 3;
152    ///     let pool = LocalPoolHandle::new(NUM_WORKERS);
153    ///     let handles = (0..pool.num_threads())
154    ///         .map(|worker_idx| {
155    ///             pool.spawn_pinned_by_idx(
156    ///                 || {
157    ///                     async {
158    ///                         "test"
159    ///                     }
160    ///                 },
161    ///                 worker_idx,
162    ///             )
163    ///         })
164    ///         .collect::<Vec<_>>();
165    ///
166    ///     for handle in handles {
167    ///         handle.await.unwrap();
168    ///     }
169    /// }
170    /// ```
171    ///
172    #[track_caller]
173    pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
174    where
175        F: FnOnce() -> Fut,
176        F: Send + 'static,
177        Fut: Future + 'static,
178        Fut::Output: Send + 'static,
179    {
180        self.pool
181            .spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
182    }
183}
184
185impl Debug for LocalPoolHandle {
186    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
187        f.write_str("LocalPoolHandle")
188    }
189}
190
/// How `LocalPool::spawn_pinned` selects the worker thread that a new task is
/// pinned to.
///
/// This is a small plain-data selector, so it derives the standard value
/// traits. It can then be copied, compared and printed while debugging.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerChoice {
    /// Pick the worker that currently has the fewest scheduled tasks.
    LeastBurdened,
    /// Pick the worker at this index in the pool. Spawning panics if the index
    /// is out of bounds.
    ByIdx(usize),
}
195
196struct LocalPool {
197    workers: Box<[LocalWorkerHandle]>,
198}
199
200impl LocalPool {
201    /// Spawn a `?Send` future onto a worker
202    #[track_caller]
203    fn spawn_pinned<F, Fut>(
204        &self,
205        create_task: F,
206        worker_choice: WorkerChoice,
207    ) -> JoinHandle<Fut::Output>
208    where
209        F: FnOnce() -> Fut,
210        F: Send + 'static,
211        Fut: Future + 'static,
212        Fut::Output: Send + 'static,
213    {
214        let (sender, receiver) = oneshot::channel();
215        let (worker, job_guard) = match worker_choice {
216            WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(),
217            WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx),
218        };
219        let worker_spawner = worker.spawner.clone();
220
221        // Spawn a future onto the worker's runtime so we can immediately return
222        // a join handle.
223        worker.runtime_handle.spawn(async move {
224            // Move the job guard into the task
225            let _job_guard = job_guard;
226
227            // Propagate aborts via Abortable/AbortHandle
228            let (abort_handle, abort_registration) = AbortHandle::new_pair();
229            let _abort_guard = AbortGuard(abort_handle);
230
231            // Inside the future we can't run spawn_local yet because we're not
232            // in the context of a LocalSet. We need to send create_task to the
233            // LocalSet task for spawning.
234            let spawn_task = Box::new(move || {
235                // Once we're in the LocalSet context we can call spawn_local
236                let join_handle =
237                    spawn_local(
238                        async move { Abortable::new(create_task(), abort_registration).await },
239                    );
240
241                // Send the join handle back to the spawner. If sending fails,
242                // we assume the parent task was canceled, so cancel this task
243                // as well.
244                if let Err(join_handle) = sender.send(join_handle) {
245                    join_handle.abort()
246                }
247            });
248
249            // Send the callback to the LocalSet task
250            if let Err(e) = worker_spawner.send(spawn_task) {
251                // Propagate the error as a panic in the join handle.
252                panic!("Failed to send job to worker: {e}");
253            }
254
255            // Wait for the task's join handle
256            let join_handle = match receiver.await {
257                Ok(handle) => handle,
258                Err(e) => {
259                    // We sent the task successfully, but failed to get its
260                    // join handle... We assume something happened to the worker
261                    // and the task was not spawned. Propagate the error as a
262                    // panic in the join handle.
263                    panic!("Worker failed to send join handle: {e}");
264                }
265            };
266
267            // Wait for the task to complete
268            let join_result = join_handle.await;
269
270            match join_result {
271                Ok(Ok(output)) => output,
272                Ok(Err(_)) => {
273                    // Pinned task was aborted. But that only happens if this
274                    // task is aborted. So this is an impossible branch.
275                    unreachable!(
276                        "Reaching this branch means this task was previously \
277                         aborted but it continued running anyways"
278                    )
279                }
280                Err(e) => {
281                    if e.is_panic() {
282                        std::panic::resume_unwind(e.into_panic());
283                    } else if e.is_cancelled() {
284                        // No one else should have the join handle, so this is
285                        // unexpected. Forward this error as a panic in the join
286                        // handle.
287                        panic!("spawn_pinned task was canceled: {e}");
288                    } else {
289                        // Something unknown happened (not a panic or
290                        // cancellation). Forward this error as a panic in the
291                        // join handle.
292                        panic!("spawn_pinned task failed: {e}");
293                    }
294                }
295            }
296        })
297    }
298
299    /// Find the worker with the least number of tasks, increment its task
300    /// count, and return its handle. Make sure to actually spawn a task on
301    /// the worker so the task count is kept consistent with load.
302    ///
303    /// A job count guard is also returned to ensure the task count gets
304    /// decremented when the job is done.
305    fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
306        loop {
307            let (worker, task_count) = self
308                .workers
309                .iter()
310                .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
311                .min_by_key(|&(_, count)| count)
312                .expect("There must be more than one worker");
313
314            // Make sure the task count hasn't changed since when we choose this
315            // worker. Otherwise, restart the search.
316            if worker
317                .task_count
318                .compare_exchange(
319                    task_count,
320                    task_count + 1,
321                    Ordering::SeqCst,
322                    Ordering::Relaxed,
323                )
324                .is_ok()
325            {
326                return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
327            }
328        }
329    }
330
331    #[track_caller]
332    fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
333        let worker = &self.workers[idx];
334        worker.task_count.fetch_add(1, Ordering::SeqCst);
335
336        (worker, JobCountGuard(Arc::clone(&worker.task_count)))
337    }
338}
339
/// Guard that gives back one unit of a worker's job count when it is dropped,
/// i.e. when the job that was counted has finished.
///
/// The count is incremented just before a guard is created, so every guard
/// balances exactly one increment.
struct JobCountGuard(Arc<AtomicUsize>);

impl Drop for JobCountGuard {
    fn drop(&mut self) {
        let JobCountGuard(job_count) = self;
        let previous_value = job_count.fetch_sub(1, Ordering::SeqCst);
        // A zero here would mean more decrements than increments.
        debug_assert!(previous_value >= 1);
    }
}
351
352/// Calls abort on the handle when dropped.
353struct AbortGuard(AbortHandle);
354
355impl Drop for AbortGuard {
356    fn drop(&mut self) {
357        self.0.abort();
358    }
359}
360
361type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;
362
363struct LocalWorkerHandle {
364    runtime_handle: tokio::runtime::Handle,
365    spawner: UnboundedSender<PinnedFutureSpawner>,
366    task_count: Arc<AtomicUsize>,
367}
368
369impl LocalWorkerHandle {
370    /// Create a new worker for executing pinned tasks
371    fn new_worker() -> LocalWorkerHandle {
372        let (sender, receiver) = unbounded_channel();
373        let runtime = Builder::new_current_thread()
374            .enable_all()
375            .build()
376            .expect("Failed to start a pinned worker thread runtime");
377        let runtime_handle = runtime.handle().clone();
378        let task_count = Arc::new(AtomicUsize::new(0));
379        let task_count_clone = Arc::clone(&task_count);
380
381        std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));
382
383        LocalWorkerHandle {
384            runtime_handle,
385            spawner: sender,
386            task_count,
387        }
388    }
389
390    fn run(
391        runtime: tokio::runtime::Runtime,
392        mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
393        task_count: Arc<AtomicUsize>,
394    ) {
395        let local_set = LocalSet::new();
396        local_set.block_on(&runtime, async {
397            while let Some(spawn_task) = task_receiver.recv().await {
398                // Calls spawn_local(future)
399                (spawn_task)();
400            }
401        });
402
403        // If there are any tasks on the runtime associated with a LocalSet task
404        // that has already completed, but whose output has not yet been
405        // reported, let that task complete.
406        //
407        // Since the task_count is decremented when the runtime task exits,
408        // reading that counter lets us know if any such tasks completed during
409        // the call to `block_on`.
410        //
411        // Tasks on the LocalSet can't complete during this loop since they're
412        // stored on the LocalSet and we aren't accessing it.
413        let mut previous_task_count = task_count.load(Ordering::SeqCst);
414        loop {
415            // This call will also run tasks spawned on the runtime.
416            runtime.block_on(tokio::task::yield_now());
417            let new_task_count = task_count.load(Ordering::SeqCst);
418            if new_task_count == previous_task_count {
419                break;
420            } else {
421                previous_task_count = new_task_count;
422            }
423        }
424
425        // It's now no longer possible for a task on the runtime to be
426        // associated with a LocalSet task that has completed. Drop both the
427        // LocalSet and runtime to let tasks on the runtime be cancelled if and
428        // only if they are still on the LocalSet.
429        //
430        // Drop the LocalSet task first so that anyone awaiting the runtime
431        // JoinHandle will see the cancelled error after the LocalSet task
432        // destructor has completed.
433        drop(local_set);
434        drop(runtime);
435    }
436}