await_tree/
registry.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::fmt::{Debug, Display};
17use std::hash::Hash;
18use std::sync::{Arc, Weak};
19
20use derive_builder::Builder;
21use parking_lot::RwLock;
22use weak_table::WeakValueHashMap;
23
24use crate::context::{ContextId, Tree, TreeContext};
25use crate::obj_utils::{DynEq, DynHash};
26use crate::{Span, TreeRoot};
27
28/// Configuration for an await-tree registry, which affects the behavior of all await-trees in the
29/// registry.
30#[derive(Debug, Clone, Builder)]
31#[builder(default)]
32pub struct Config {
33    /// Whether to include the **verbose** span in the await-tree.
34    verbose: bool,
35}
36
37#[allow(clippy::derivable_impls)]
38impl Default for Config {
39    fn default() -> Self {
40        Self { verbose: false }
41    }
42}
43
/// A key identifying a task and its await-tree inside a [`Registry`].
///
/// Implemented automatically (via a blanket impl) for every `'static`, thread-safe type that
/// can serve as a hash-map key, i.e. anything that is `Hash + Eq + Debug + Send + Sync`.
pub trait Key: Hash + Eq + Debug + Send + Sync + 'static {}

impl<K> Key for K where K: Hash + Eq + Debug + Send + Sync + 'static {}
50
51/// The object-safe version of [`Key`], automatically implemented.
52trait ObjKey: DynHash + DynEq + Debug + Send + Sync + 'static {}
53impl<T> ObjKey for T where T: DynHash + DynEq + Debug + Send + Sync + 'static {}
54
55/// Type-erased key for the [`Registry`].
56#[derive(Clone)]
57pub struct AnyKey(Arc<dyn ObjKey>);
58
59impl PartialEq for AnyKey {
60    fn eq(&self, other: &Self) -> bool {
61        self.0.dyn_eq(other.0.as_dyn_eq())
62    }
63}
64
65impl Eq for AnyKey {}
66
67impl Hash for AnyKey {
68    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
69        self.0.dyn_hash(state);
70    }
71}
72
73impl Debug for AnyKey {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        self.0.fmt(f)
76    }
77}
78
79impl Display for AnyKey {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        // TODO: for all `impl Display`?
82        if let Some(s) = self.as_any().downcast_ref::<String>() {
83            write!(f, "{}", s)
84        } else if let Some(s) = self.as_any().downcast_ref::<&str>() {
85            write!(f, "{}", s)
86        } else {
87            write!(f, "{:?}", self)
88        }
89    }
90}
91
92impl AnyKey {
93    fn new(key: impl ObjKey) -> Self {
94        Self(Arc::new(key))
95    }
96
97    /// Cast the key to `dyn Any`.
98    pub fn as_any(&self) -> &dyn Any {
99        self.0.as_ref().as_any()
100    }
101
102    /// Returns whether the key is of type `K`.
103    ///
104    /// Equivalent to `self.as_any().is::<K>()`.
105    pub fn is<K: Any>(&self) -> bool {
106        self.as_any().is::<K>()
107    }
108
109    /// Returns whether the key corresponds to an anonymous await-tree.
110    pub fn is_anonymous(&self) -> bool {
111        self.as_any().is::<ContextId>()
112    }
113
114    /// Returns the key as a reference to type `K`, if it is of type `K`.
115    ///
116    /// Equivalent to `self.as_any().downcast_ref::<K>()`.
117    pub fn downcast_ref<K: Any>(&self) -> Option<&K> {
118        self.as_any().downcast_ref()
119    }
120}
121
122type Contexts = RwLock<WeakValueHashMap<AnyKey, Weak<TreeContext>>>;
123
124struct RegistryCore {
125    contexts: Contexts,
126    config: Config,
127}
128
129/// The registry of multiple await-trees.
130///
131/// Can be cheaply cloned to share the same registry.
132pub struct Registry(Arc<RegistryCore>);
133
134impl Debug for Registry {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        f.debug_struct("Registry")
137            .field("config", self.config())
138            .finish_non_exhaustive()
139    }
140}
141
142impl Clone for Registry {
143    fn clone(&self) -> Self {
144        Self(Arc::clone(&self.0))
145    }
146}
147
148impl Registry {
149    fn contexts(&self) -> &Contexts {
150        &self.0.contexts
151    }
152
153    fn config(&self) -> &Config {
154        &self.0.config
155    }
156}
157
158impl Registry {
159    /// Create a new registry with given `config`.
160    pub fn new(config: Config) -> Self {
161        Self(
162            RegistryCore {
163                contexts: Default::default(),
164                config,
165            }
166            .into(),
167        )
168    }
169
170    fn register_inner(&self, key: impl Key, context: Arc<TreeContext>) -> TreeRoot {
171        self.contexts()
172            .write()
173            .insert(AnyKey::new(key), Arc::clone(&context));
174
175        TreeRoot {
176            context,
177            registry: WeakRegistry(Arc::downgrade(&self.0)),
178        }
179    }
180
181    /// Register with given key. Returns a [`TreeRoot`] that can be used to instrument a future.
182    ///
183    /// If the key already exists, a new [`TreeRoot`] is returned and the reference to the old
184    /// [`TreeRoot`] is dropped.
185    pub fn register(&self, key: impl Key, root_span: impl Into<Span>) -> TreeRoot {
186        let context = Arc::new(TreeContext::new(root_span.into(), self.config().verbose));
187        self.register_inner(key, context)
188    }
189
190    /// Register an anonymous await-tree without specifying a key. Returns a [`TreeRoot`] that can
191    /// be used to instrument a future.
192    ///
193    /// Anonymous await-trees are not able to be retrieved through the [`Registry::get`] method. Use
194    /// [`Registry::collect_anonymous`] or [`Registry::collect_all`] to collect them.
195    pub fn register_anonymous(&self, root_span: impl Into<Span>) -> TreeRoot {
196        let context = Arc::new(TreeContext::new(root_span.into(), self.config().verbose));
197        self.register_inner(context.id(), context) // use the private id as the key
198    }
199
200    /// Get a clone of the await-tree with given key.
201    ///
202    /// Returns `None` if the key does not exist or the tree root has been dropped.
203    pub fn get(&self, key: impl Key) -> Option<Tree> {
204        self.contexts()
205            .read()
206            .get(&AnyKey::new(key)) // TODO: accept ref can?
207            .map(|v| v.tree().clone())
208    }
209
210    /// Remove all the registered await-trees.
211    pub fn clear(&self) {
212        self.contexts().write().clear();
213    }
214
215    /// Collect the snapshots of all await-trees with the key of type `K`.
216    pub fn collect<K: Key + Clone>(&self) -> Vec<(K, Tree)> {
217        self.contexts()
218            .read()
219            .iter()
220            .filter_map(|(k, v)| {
221                k.0.as_ref()
222                    .as_any()
223                    .downcast_ref::<K>()
224                    .map(|k| (k.clone(), v.tree().clone()))
225            })
226            .collect()
227    }
228
229    /// Collect the snapshots of all await-trees registered with [`Registry::register_anonymous`].
230    pub fn collect_anonymous(&self) -> Vec<Tree> {
231        self.contexts()
232            .read()
233            .iter()
234            .filter_map(|(k, v)| {
235                if k.is_anonymous() {
236                    Some(v.tree().clone())
237                } else {
238                    None
239                }
240            })
241            .collect()
242    }
243
244    /// Collect the snapshots of all await-trees regardless of the key type.
245    pub fn collect_all(&self) -> Vec<(AnyKey, Tree)> {
246        self.contexts()
247            .read()
248            .iter()
249            .map(|(k, v)| (k.clone(), v.tree().clone()))
250            .collect()
251    }
252}
253
254pub(crate) struct WeakRegistry(Weak<RegistryCore>);
255
256impl WeakRegistry {
257    pub fn upgrade(&self) -> Option<Registry> {
258        self.0.upgrade().map(Registry)
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registry() {
        let registry = Registry::new(Config::default());

        let _0_i32 = registry.register(0_i32, "0");
        let _1_i32 = registry.register(1_i32, "1");
        let _2_i32 = registry.register(2_i32, "2");

        let _0_str = registry.register("0", "0");
        let _1_str = registry.register("1", "1");

        // Re-registering `()` replaces the map entry. `_unit` stays alive but is no longer
        // reachable through the registry.
        let _unit = registry.register((), "()");
        let _unit_replaced = registry.register((), "[]");

        // Each anonymous registration gets its own unique key, so both are kept.
        let _anon_1 = registry.register_anonymous("anon");
        let _anon_2 = registry.register_anonymous("anon");

        let i32s = registry.collect::<i32>();
        assert_eq!(i32s.len(), 3);

        let strs = registry.collect::<&'static str>();
        assert_eq!(strs.len(), 2);

        let units = registry.collect::<()>();
        assert_eq!(units.len(), 1);

        let anons = registry.collect_anonymous();
        assert_eq!(anons.len(), 2);

        let all = registry.collect_all();
        assert_eq!(all.len(), 8);
    }

    #[test]
    fn test_get_and_clear() {
        let registry = Registry::new(Config::default());

        let _root_i32 = registry.register(1_i32, "1");
        let _root_str = registry.register("1", "1");

        assert!(registry.get(1_i32).is_some());
        assert!(registry.get("1").is_some());
        // Unknown value of a registered key type.
        assert!(registry.get(2_i32).is_none());
        // Same value, different key type.
        assert!(registry.get(1_u32).is_none());

        registry.clear();
        assert!(registry.get(1_i32).is_none());
        assert!(registry.get("1").is_none());
        assert!(registry.collect_all().is_empty());
    }

    #[test]
    fn test_dropped_root_expires() {
        let registry = Registry::new(Config::default());

        let dropped = registry.register(0_i32, "0");
        let _kept = registry.register(1_i32, "1");
        assert_eq!(registry.collect::<i32>().len(), 2);

        // The registry only holds weak references, so dropping the root makes its entry
        // unreachable.
        drop(dropped);
        assert!(registry.get(0_i32).is_none());
        assert!(registry.get(1_i32).is_some());
        assert_eq!(registry.collect::<i32>().len(), 1);
    }
}