await_tree/
registry.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::fmt::{Debug, Display};
17use std::hash::Hash;
18use std::sync::{Arc, Weak};
19
20use derive_builder::Builder;
21use parking_lot::RwLock;
22use weak_table::WeakValueHashMap;
23
24use crate::context::{ContextId, Tree, TreeContext};
25use crate::obj_utils::{DynEq, DynHash};
26use crate::{Span, TreeRoot};
27
28/// Configuration for an await-tree registry, which affects the behavior of all await-trees in the
29/// registry.
30#[derive(Debug, Clone, Builder)]
31#[builder(default)]
32pub struct Config {
33    /// Whether to include the **verbose** span in the await-tree.
34    verbose: bool,
35}
36
37#[allow(clippy::derivable_impls)]
38impl Default for Config {
39    fn default() -> Self {
40        Self { verbose: false }
41    }
42}
43
/// Identifies a task (and its await-tree) inside a [`Registry`].
///
/// Any type that could serve as a hash-map key and is also thread-safe and `'static` qualifies.
/// Nothing needs to be implemented by hand: the blanket impl below covers every such type.
pub trait Key: Hash + Eq + Debug + Send + Sync + 'static {}

impl<T: Hash + Eq + Debug + Send + Sync + 'static> Key for T {}
50
51/// The object-safe version of [`Key`], automatically implemented.
52trait ObjKey: DynHash + DynEq + Debug + Send + Sync + 'static {}
53impl<T> ObjKey for T where T: DynHash + DynEq + Debug + Send + Sync + 'static {}
54
55/// Key type for anonymous await-trees.
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
57struct AnonymousKey(ContextId);
58
59impl Display for AnonymousKey {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        write!(f, "Anonymous #{}", self.0 .0)
62    }
63}
64
65/// Type-erased key for the [`Registry`].
66#[derive(Clone)]
67pub struct AnyKey(Arc<dyn ObjKey>);
68
69impl PartialEq for AnyKey {
70    fn eq(&self, other: &Self) -> bool {
71        self.0.dyn_eq(other.0.as_dyn_eq())
72    }
73}
74
75impl Eq for AnyKey {}
76
77impl Hash for AnyKey {
78    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
79        self.0.dyn_hash(state);
80    }
81}
82
83impl Debug for AnyKey {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        self.0.fmt(f)
86    }
87}
88
89impl Display for AnyKey {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        // TODO: for all `impl Display`?
92        macro_rules! delegate_to_display {
93            ($($t:ty),* $(,)?) => {
94                $(
95                    if let Some(k) = self.as_any().downcast_ref::<$t>() {
96                        return write!(f, "{}", k);
97                    }
98                )*
99            };
100        }
101        delegate_to_display!(String, &str, AnonymousKey);
102
103        write!(f, "{:?}", self)
104    }
105}
106
107impl AnyKey {
108    fn new(key: impl ObjKey) -> Self {
109        Self(Arc::new(key))
110    }
111
112    /// Cast the key to `dyn Any`.
113    pub fn as_any(&self) -> &dyn Any {
114        self.0.as_ref().as_any()
115    }
116
117    /// Returns whether the key is of type `K`.
118    ///
119    /// Equivalent to `self.as_any().is::<K>()`.
120    pub fn is<K: Any>(&self) -> bool {
121        self.as_any().is::<K>()
122    }
123
124    /// Returns whether the key corresponds to an anonymous await-tree.
125    pub fn is_anonymous(&self) -> bool {
126        self.as_any().is::<AnonymousKey>()
127    }
128
129    /// Returns the key as a reference to type `K`, if it is of type `K`.
130    ///
131    /// Equivalent to `self.as_any().downcast_ref::<K>()`.
132    pub fn downcast_ref<K: Any>(&self) -> Option<&K> {
133        self.as_any().downcast_ref()
134    }
135}
136
137type Contexts = RwLock<WeakValueHashMap<AnyKey, Weak<TreeContext>>>;
138
139struct RegistryCore {
140    contexts: Contexts,
141    config: Config,
142}
143
144/// The registry of multiple await-trees.
145///
146/// Can be cheaply cloned to share the same registry.
147pub struct Registry(Arc<RegistryCore>);
148
149impl Debug for Registry {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        f.debug_struct("Registry")
152            .field("config", self.config())
153            .finish_non_exhaustive()
154    }
155}
156
157impl Clone for Registry {
158    fn clone(&self) -> Self {
159        Self(Arc::clone(&self.0))
160    }
161}
162
163impl Registry {
164    fn contexts(&self) -> &Contexts {
165        &self.0.contexts
166    }
167
168    fn config(&self) -> &Config {
169        &self.0.config
170    }
171}
172
173impl Registry {
174    /// Create a new registry with given `config`.
175    pub fn new(config: Config) -> Self {
176        Self(
177            RegistryCore {
178                contexts: Default::default(),
179                config,
180            }
181            .into(),
182        )
183    }
184
185    /// Returns the current registry, if exists.
186    ///
187    /// 1. If the current task is registered with a registry, returns the registry.
188    /// 2. If the global registry is initialized with
189    ///    [`init_global_registry`](crate::global::init_global_registry), returns the global
190    ///    registry.
191    /// 3. Otherwise, returns `None`.
192    pub fn try_current() -> Option<Self> {
193        crate::root::current_registry()
194    }
195
196    /// Returns the current registry, panics if not exists.
197    ///
198    /// See [`Registry::try_current`] for more information.
199    pub fn current() -> Self {
200        Self::try_current().expect("no current registry")
201    }
202
203    fn register_inner(&self, key: impl Key, context: Arc<TreeContext>) -> TreeRoot {
204        self.contexts()
205            .write()
206            .insert(AnyKey::new(key), Arc::clone(&context));
207
208        TreeRoot {
209            context,
210            registry: WeakRegistry(Arc::downgrade(&self.0)),
211        }
212    }
213
214    /// Register with given key. Returns a [`TreeRoot`] that can be used to instrument a future.
215    ///
216    /// If the key already exists, a new [`TreeRoot`] is returned and the reference to the old
217    /// [`TreeRoot`] is dropped.
218    pub fn register(&self, key: impl Key, root_span: impl Into<Span>) -> TreeRoot {
219        let context = Arc::new(TreeContext::new(root_span.into(), self.config().verbose));
220        self.register_inner(key, context)
221    }
222
223    /// Register an anonymous await-tree without specifying a key. Returns a [`TreeRoot`] that can
224    /// be used to instrument a future.
225    ///
226    /// Anonymous await-trees are not able to be retrieved through the [`Registry::get`] method. Use
227    /// [`Registry::collect_anonymous`] or [`Registry::collect_all`] to collect them.
228    // TODO: we have keyed and anonymous, should we also have a typed-anonymous (for classification
229    // only)?
230    pub fn register_anonymous(&self, root_span: impl Into<Span>) -> TreeRoot {
231        let context = Arc::new(TreeContext::new(root_span.into(), self.config().verbose));
232        self.register_inner(AnonymousKey(context.id()), context) // use the private id as the key
233    }
234
235    /// Get a clone of the await-tree with given key.
236    ///
237    /// Returns `None` if the key does not exist or the tree root has been dropped.
238    pub fn get(&self, key: impl Key) -> Option<Tree> {
239        self.contexts()
240            .read()
241            .get(&AnyKey::new(key)) // TODO: accept ref can?
242            .map(|v| v.tree().clone())
243    }
244
245    /// Remove all the registered await-trees.
246    pub fn clear(&self) {
247        self.contexts().write().clear();
248    }
249
250    /// Collect the snapshots of all await-trees with the key of type `K`.
251    pub fn collect<K: Key + Clone>(&self) -> Vec<(K, Tree)> {
252        self.contexts()
253            .read()
254            .iter()
255            .filter_map(|(k, v)| {
256                k.0.as_ref()
257                    .as_any()
258                    .downcast_ref::<K>()
259                    .map(|k| (k.clone(), v.tree().clone()))
260            })
261            .collect()
262    }
263
264    /// Collect the snapshots of all await-trees registered with [`Registry::register_anonymous`].
265    pub fn collect_anonymous(&self) -> Vec<Tree> {
266        self.contexts()
267            .read()
268            .iter()
269            .filter_map(|(k, v)| {
270                if k.is_anonymous() {
271                    Some(v.tree().clone())
272                } else {
273                    None
274                }
275            })
276            .collect()
277    }
278
279    /// Collect the snapshots of all await-trees regardless of the key type.
280    pub fn collect_all(&self) -> Vec<(AnyKey, Tree)> {
281        self.contexts()
282            .read()
283            .iter()
284            .map(|(k, v)| (k.clone(), v.tree().clone()))
285            .collect()
286    }
287}
288
289pub(crate) struct WeakRegistry(Weak<RegistryCore>);
290
291impl WeakRegistry {
292    pub fn upgrade(&self) -> Option<Registry> {
293        self.0.upgrade().map(Registry)
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers keyed and anonymous trees of several key types. Checks that each collection
    /// API reports exactly the entries it should.
    #[test]
    fn test_registry() {
        let registry = Registry::new(Config::default());

        // The registry only holds weak references, so every `TreeRoot` is kept alive here until
        // the end of the test.
        let mut roots = Vec::new();

        // Three `i32` keys.
        for (key, span) in [(0_i32, "0"), (1, "1"), (2, "2")] {
            roots.push(registry.register(key, span));
        }

        // Two `&'static str` keys.
        for (key, span) in [("0", "0"), ("1", "1")] {
            roots.push(registry.register(key, span));
        }

        // Registering `()` twice replaces the entry, so only one `()` tree is visible.
        roots.push(registry.register((), "()"));
        roots.push(registry.register((), "[]"));

        // Each anonymous registration gets its own key, so both are visible.
        roots.push(registry.register_anonymous("anon"));
        roots.push(registry.register_anonymous("anon"));

        assert_eq!(registry.collect::<i32>().len(), 3);
        assert_eq!(registry.collect::<&'static str>().len(), 2);
        assert_eq!(registry.collect::<()>().len(), 1);
        assert_eq!(registry.collect_anonymous().len(), 2);
        assert_eq!(registry.collect_all().len(), 8);
    }
}