// pingora_runtime/lib.rs

// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
//! Pingora tokio runtime.
//!
//! The tokio runtime comes in two flavors: a single-threaded runtime
//! and a multi-threaded one which provides work stealing.
//! Benchmarks show that, compared to the single-threaded runtime, the multi-threaded one
//! has some overhead due to its more sophisticated work-stealing scheduler.
//!
//! This crate provides a third flavor: a multi-threaded runtime without work stealing.
//! This flavor is as efficient as the single-threaded runtime while allowing the async
//! program to use multiple cores.
25
26use once_cell::sync::{Lazy, OnceCell};
27use rand::Rng;
28use std::sync::Arc;
29use std::thread::JoinHandle;
30use std::time::Duration;
31use thread_local::ThreadLocal;
32use tokio::runtime::{Builder, Handle};
33use tokio::sync::oneshot::{channel, Sender};
34
35/// Pingora async multi-threaded runtime
36///
37/// The `Steal` flavor is effectively tokio multi-threaded runtime.
38///
39/// The `NoSteal` flavor is backed by multiple tokio single-threaded runtime.
40pub enum Runtime {
41    Steal(tokio::runtime::Runtime),
42    NoSteal(NoStealRuntime),
43}
44
45impl Runtime {
46    /// Create a `Steal` flavor runtime. This just a regular tokio runtime
47    pub fn new_steal(threads: usize, name: &str) -> Self {
48        Self::Steal(
49            Builder::new_multi_thread()
50                .enable_all()
51                .worker_threads(threads)
52                .thread_name(name)
53                .build()
54                .unwrap(),
55        )
56    }
57
58    /// Create a `NoSteal` flavor runtime. This is backed by multiple tokio current-thread runtime
59    pub fn new_no_steal(threads: usize, name: &str) -> Self {
60        Self::NoSteal(NoStealRuntime::new(threads, name))
61    }
62
63    /// Return the &[Handle] of the [Runtime].
64    /// For `Steal` flavor, it will just return the &[Handle].
65    /// For `NoSteal` flavor, it will return the &[Handle] of a random thread in its pool.
66    /// So if we want tasks to spawn on all the threads, call this function to get a fresh [Handle]
67    /// for each async task.
68    pub fn get_handle(&self) -> &Handle {
69        match self {
70            Self::Steal(r) => r.handle(),
71            Self::NoSteal(r) => r.get_runtime(),
72        }
73    }
74
75    /// Call tokio's `shutdown_timeout` of all the runtimes. This function is blocking until
76    /// all runtimes exit.
77    pub fn shutdown_timeout(self, timeout: Duration) {
78        match self {
79            Self::Steal(r) => r.shutdown_timeout(timeout),
80            Self::NoSteal(r) => r.shutdown_timeout(timeout),
81        }
82    }
83}
84
85// only NoStealRuntime set the pools in thread threads
86static CURRENT_HANDLE: Lazy<ThreadLocal<Pools>> = Lazy::new(ThreadLocal::new);
87
88/// Return the [Handle] of current runtime.
89/// If the current thread is under a `Steal` runtime, the current [Handle] is returned.
90/// If the current thread is under a `NoSteal` runtime, the [Handle] of a random thread
91/// under this runtime is returned. This function will panic if called outside any runtime.
92pub fn current_handle() -> Handle {
93    if let Some(pools) = CURRENT_HANDLE.get() {
94        // safety: the CURRENT_HANDLE is set when the pool is being initialized in init_pools()
95        let pools = pools.get().unwrap();
96        let mut rng = rand::thread_rng();
97        let index = rng.gen_range(0..pools.len());
98        pools[index].clone()
99    } else {
100        // not NoStealRuntime, just check the current tokio runtime
101        Handle::current()
102    }
103}
104
105type Control = (Sender<Duration>, JoinHandle<()>);
106type Pools = Arc<OnceCell<Box<[Handle]>>>;
107
108/// Multi-threaded runtime backed by a pool of single threaded tokio runtime
109pub struct NoStealRuntime {
110    threads: usize,
111    name: String,
112    // Lazily init the runtimes so that they are created after pingora
113    // daemonize itself. Otherwise the runtime threads are lost.
114    pools: Arc<OnceCell<Box<[Handle]>>>,
115    controls: OnceCell<Vec<Control>>,
116}
117
118impl NoStealRuntime {
119    /// Create a new [NoStealRuntime]. Panic if `threads` is 0
120    pub fn new(threads: usize, name: &str) -> Self {
121        assert!(threads != 0);
122        NoStealRuntime {
123            threads,
124            name: name.to_string(),
125            pools: Arc::new(OnceCell::new()),
126            controls: OnceCell::new(),
127        }
128    }
129
130    fn init_pools(&self) -> (Box<[Handle]>, Vec<Control>) {
131        let mut pools = Vec::with_capacity(self.threads);
132        let mut controls = Vec::with_capacity(self.threads);
133        for _ in 0..self.threads {
134            let rt = Builder::new_current_thread().enable_all().build().unwrap();
135            let handler = rt.handle().clone();
136            let (tx, rx) = channel::<Duration>();
137            let pools_ref = self.pools.clone();
138            let join = std::thread::Builder::new()
139                .name(self.name.clone())
140                .spawn(move || {
141                    CURRENT_HANDLE.get_or(|| pools_ref);
142                    if let Ok(timeout) = rt.block_on(rx) {
143                        rt.shutdown_timeout(timeout);
144                    } // else Err(_): tx is dropped, just exit
145                })
146                .unwrap();
147            pools.push(handler);
148            controls.push((tx, join));
149        }
150
151        (pools.into_boxed_slice(), controls)
152    }
153
154    /// Return the &[Handle] of a random thread of this runtime
155    pub fn get_runtime(&self) -> &Handle {
156        let mut rng = rand::thread_rng();
157
158        let index = rng.gen_range(0..self.threads);
159        self.get_runtime_at(index)
160    }
161
162    /// Return the number of threads of this runtime
163    pub fn threads(&self) -> usize {
164        self.threads
165    }
166
167    fn get_pools(&self) -> &[Handle] {
168        if let Some(p) = self.pools.get() {
169            p
170        } else {
171            // TODO: use a mutex to avoid creating a lot threads only to drop them
172            let (pools, controls) = self.init_pools();
173            // there could be another thread racing with this one to init the pools
174            match self.pools.try_insert(pools) {
175                Ok(p) => {
176                    // unwrap to make sure that this is the one that init both pools and controls
177                    self.controls.set(controls).unwrap();
178                    p
179                }
180                // another thread already set it, just return it
181                Err((p, _my_pools)) => p,
182            }
183        }
184    }
185
186    /// Return the &[Handle] of a given thread of this runtime
187    pub fn get_runtime_at(&self, index: usize) -> &Handle {
188        let pools = self.get_pools();
189        &pools[index]
190    }
191
192    /// Call tokio's `shutdown_timeout` of all the runtimes. This function is blocking until
193    /// all runtimes exit.
194    pub fn shutdown_timeout(mut self, timeout: Duration) {
195        if let Some(controls) = self.controls.take() {
196            let (txs, joins): (Vec<Sender<_>>, Vec<JoinHandle<()>>) = controls.into_iter().unzip();
197            for tx in txs {
198                let _ = tx.send(timeout); // Err() when rx is dropped
199            }
200            for join in joins {
201                let _ = join.join(); // ignore thread error
202            }
203        } // else, the controls and the runtimes are not even init yet, just return;
204    }
205
206    // TODO: runtime metrics
207}
208
#[test]
fn test_steal_runtime() {
    use tokio::time::{sleep, Duration};

    let rt = Runtime::new_steal(2, "test");
    let handle = rt.get_handle();
    let ret = handle.block_on(async {
        sleep(Duration::from_millis(10)).await;
        let handle = current_handle();
        let join = handle.spawn(async {
            sleep(Duration::from_millis(10)).await;
            // spawned work runs on one of the runtime's named worker threads
            std::thread::current().name().map(str::to_owned)
        });
        join.await.unwrap()
    });

    assert_eq!(ret.as_deref(), Some("test"));
}
227
#[test]
fn test_no_steal_runtime() {
    use tokio::time::{sleep, Duration};

    let rt = Runtime::new_no_steal(2, "test");
    let handle = rt.get_handle();
    let ret = handle.block_on(async {
        sleep(Duration::from_millis(10)).await;
        // This future runs on the test thread, so `current_handle()` falls back to the
        // runtime entered by `block_on`, i.e. one of the pool's handles.
        let handle = current_handle();
        let join = handle.spawn(async {
            sleep(Duration::from_millis(10)).await;
            // Now on a pool thread: `current_handle()` must pick a handle of the same pool.
            current_handle()
                .spawn(async { std::thread::current().name().map(str::to_owned) })
                .await
                .unwrap()
        });
        join.await.unwrap()
    });

    assert_eq!(ret.as_deref(), Some("test"));
}
246
#[test]
fn test_no_steal_shutdown() {
    use tokio::time::{sleep, Duration};

    let rt = Runtime::new_no_steal(2, "test");
    let handle = rt.get_handle();
    let ret = handle.block_on(async {
        sleep(Duration::from_millis(10)).await;
        let handle = current_handle();
        let join = handle.spawn(async {
            sleep(Duration::from_millis(10)).await;
        });
        join.await.unwrap();
        1
    });
    assert_eq!(ret, 1);

    rt.shutdown_timeout(Duration::from_secs(1));

    // A runtime whose threads were never started has nothing to stop and returns at once.
    Runtime::new_no_steal(2, "test").shutdown_timeout(Duration::from_secs(1));
}