ih_muse_core/
state.rs

1// crates/ih-muse/src/state.rs
2
3use std::cmp;
4use std::net::SocketAddr;
5use std::sync::atomic::{AtomicU8, Ordering};
6use std::sync::Arc;
7use std::sync::OnceLock;
8
9use arc_swap::{ArcSwap, ArcSwapOption};
10use imbl::{HashMap, HashSet, OrdMap, Vector};
11use uuid::Uuid;
12
13use ih_muse_proto::*;
14
15pub struct State {
16    nodes: Arc<ArcSwap<HashMap<Uuid, NodeInfo>>>,
17    element_kinds: OnceLock<Arc<HashSet<String>>>,
18    registered_metrics: OnceLock<Arc<HashMap<String, Arc<MetricDefinition>>>>,
19    metric_order: Arc<ArcSwap<Vector<Arc<MetricDefinition>>>>,
20    min_element_id: Arc<ArcSwapOption<ElementId>>,
21    max_element_id: Arc<ArcSwapOption<ElementId>>,
22    range_to_node: Arc<ArcSwap<OrdMap<OrdRangeInc, Uuid>>>,
23    finest_resolution: AtomicU8,
24    element_id_map: Arc<ArcSwap<HashMap<LocalElementId, ElementId>>>,
25}
26
27impl Default for State {
28    fn default() -> Self {
29        Self::new(TimestampResolution::default())
30    }
31}
32
33impl State {
34    pub fn new(default_resolution: TimestampResolution) -> Self {
35        Self {
36            nodes: Arc::new(ArcSwap::from_pointee(HashMap::new())),
37            element_kinds: OnceLock::new(),
38            registered_metrics: OnceLock::new(),
39            metric_order: Arc::new(ArcSwap::from_pointee(Vector::new())),
40            min_element_id: Arc::new(ArcSwapOption::empty()),
41            max_element_id: Arc::new(ArcSwapOption::empty()),
42            range_to_node: Arc::new(ArcSwap::from_pointee(OrdMap::new())),
43            finest_resolution: default_resolution.as_u8().into(),
44            element_id_map: Arc::new(ArcSwap::from_pointee(HashMap::new())),
45        }
46    }
47
48    /// Update the nodes
49    pub async fn update_nodes(&self, new_nodes: HashMap<Uuid, NodeInfo>) {
50        self.nodes.store(Arc::new(new_nodes));
51    }
52
53    pub async fn get_nodes(&self) -> HashMap<Uuid, NodeInfo> {
54        let nodes = self.nodes.load();
55        (**nodes).clone()
56    }
57
58    /// Inits `element_kinds` only once. Subsequent calls will return an error.
59    pub async fn init_element_kinds(&self, element_kinds: &[ElementKindRegistration]) {
60        let codes_set = element_kinds
61            .iter()
62            .map(|kind| kind.code.clone())
63            .collect::<HashSet<String>>();
64        let _ = self.element_kinds.set(Arc::new(codes_set));
65    }
66
67    /// Check if an element kind code is valid.
68    pub fn is_valid_element_kind_code(&self, element_kind_code: &str) -> bool {
69        if let Some(kinds) = self.element_kinds.get() {
70            kinds.contains(element_kind_code)
71        } else {
72            false
73        }
74    }
75
76    /// Inits `registered_metrics` only once. Subsequent calls will return an error.
77    pub async fn init_metrics(&self, metric_definitions: &[MetricDefinition]) {
78        let metrics = metric_definitions
79            .iter()
80            .map(|m| (m.code.clone(), Arc::new(m.clone()))) // Map code to an Arc<MetricDefinition>
81            .collect::<HashMap<String, Arc<MetricDefinition>>>();
82
83        // Attempt to set `registered_metrics` only once
84        let _ = self.registered_metrics.set(Arc::new(metrics));
85    }
86
87    /// Check if a metric code is valid.
88    pub fn is_valid_metric_code(&self, metric_code: &str) -> bool {
89        // Load the Arc<HashMap> from `registered_metrics` and check for the presence of the code
90        self.registered_metrics
91            .get()
92            .map_or(false, |metrics| metrics.contains_key(metric_code))
93    }
94
95    /// Update `metric_order` atomically with a new order.
96    /// ! This is only updated in one tasks, no concurrency issues
97    /// * It can safely being read from multiple threads
98    pub async fn update_metric_order(&self, metric_order: Vec<MetricDefinition>) {
99        let ordered_metrics = metric_order
100            .into_iter()
101            .map(Arc::new) // Wrap each MetricDefinition in an Arc
102            .collect::<Vector<_>>(); // Collect into an imbl Vector
103        self.metric_order.store(Arc::new(ordered_metrics));
104    }
105
106    // Retrieve a Vec<&MetricDefinition> slice for cases where that's required.
107    /// Returns a reference to the ordered metric definitions as a slice.
108    pub fn get_metric_order(&self) -> Arc<Vector<Arc<MetricDefinition>>> {
109        self.metric_order.load_full()
110    }
111
112    /// Updates `min_element_id` and `max_element_id` based on the provided `element_id`.
113    /// If either is `None`, it will set both to the `element_id`.
114    /// Otherwise, it updates `min_element_id` if `element_id` is smaller,
115    /// and `max_element_id` if `element_id` is larger.
116    pub async fn update_min_max_element_id(&self, element_id: ElementId) {
117        self.min_element_id.rcu(|current_min| {
118            Some(Arc::new(match current_min.as_deref() {
119                Some(&min_id) => cmp::min(min_id, element_id),
120                None => element_id,
121            }))
122        });
123
124        self.max_element_id.rcu(|current_max| {
125            Some(Arc::new(match current_max.as_deref() {
126                Some(&max_id) => cmp::max(max_id, element_id),
127                None => element_id,
128            }))
129        });
130    }
131
132    pub async fn get_element_id_range(&self) -> (Option<ElementId>, Option<ElementId>) {
133        let min_id = self.min_element_id.load_full().as_deref().cloned();
134        let max_id = self.max_element_id.load_full().as_deref().cloned();
135        (min_id, max_id)
136    }
137
138    pub async fn update_node_elem_ranges(&self, ranges: &[NodeElementRange]) {
139        self.range_to_node.rcu(|current| {
140            let mut new_map = (**current).clone();
141            for node_range in ranges {
142                new_map.insert(node_range.range.clone(), node_range.node_id);
143            }
144            Arc::new(new_map)
145        });
146    }
147
148    pub async fn get_node_elem_ranges(&self) -> OrdMap<OrdRangeInc, Uuid> {
149        let ranges = self.range_to_node.load();
150        (**ranges).clone()
151    }
152
153    /// Find the node ID corresponding to a given element ID
154    pub fn find_node(&self, element_id: u64) -> Option<Uuid> {
155        let map = self.range_to_node.load();
156        let bound = OrdRangeInc::new_bound(element_id);
157
158        // Find the last range that starts before or at `element_id`
159        map.range(..=bound).last().and_then(|(range, node_id)| {
160            if range.contains(&element_id) {
161                Some(*node_id)
162            } else {
163                None
164            }
165        })
166    }
167
168    /// Find the node address corresponding to a given element ID
169    pub fn find_element_node_addr(&self, element_id: u64) -> Option<SocketAddr> {
170        let node_id = self.find_node(element_id)?;
171        let nodes = self.nodes.load();
172        nodes.get(&node_id).map(|node_info| node_info.node_addr)
173    }
174
175    /// Update `finest_resolution` atomically.
176    pub async fn update_finest_resolution(&self, finest_resolution: TimestampResolution) {
177        self.finest_resolution
178            .store(finest_resolution.as_u8(), Ordering::SeqCst);
179    }
180
181    /// Retrieve the current `finest_resolution` as `TimestampResolution`.
182    pub fn get_finest_resolution(&self) -> TimestampResolution {
183        TimestampResolution::from_u8(self.finest_resolution.load(Ordering::SeqCst))
184    }
185
186    pub async fn update_element_id(&self, local_id: LocalElementId, element_id: ElementId) {
187        self.element_id_map.rcu(|current| {
188            let mut new_map = (**current).clone();
189            new_map.insert(local_id, element_id);
190            Arc::new(new_map)
191        });
192    }
193
194    // Retrieve an ElementId from a LocalElementId
195    pub fn get_element_id(&self, local_id: &LocalElementId) -> Option<ElementId> {
196        let element_map = self.element_id_map.load();
197        element_map.get(local_id).cloned()
198    }
199}