// fedimint_core/module/registry.rs

1use std::collections::BTreeMap;
2
3use anyhow::anyhow;
4
5pub use crate::core::ModuleInstanceId;
6use crate::core::{Decoder, ModuleKind};
7use crate::server::DynServerModule;
8
9/// Module Registry hold module-specific data `M` by the `ModuleInstanceId`
10#[derive(Debug)]
11pub struct ModuleRegistry<M, State = ()> {
12    inner: BTreeMap<ModuleInstanceId, (ModuleKind, M)>,
13    // It is sometimes useful for registries to have some state to modify
14    // their behavior.
15    state: State,
16}
17
18impl<M, State> Clone for ModuleRegistry<M, State>
19where
20    State: Clone,
21    M: Clone,
22{
23    fn clone(&self) -> Self {
24        Self {
25            inner: self.inner.clone(),
26            state: self.state.clone(),
27        }
28    }
29}
30
31impl<M, State> Default for ModuleRegistry<M, State>
32where
33    State: Default,
34{
35    fn default() -> Self {
36        Self {
37            inner: BTreeMap::new(),
38            state: State::default(),
39        }
40    }
41}
42
43impl<M, State> From<BTreeMap<ModuleInstanceId, (ModuleKind, M)>> for ModuleRegistry<M, State>
44where
45    State: Default,
46{
47    fn from(value: BTreeMap<ModuleInstanceId, (ModuleKind, M)>) -> Self {
48        Self {
49            inner: value,
50            state: Default::default(),
51        }
52    }
53}
54
55impl<M, State> FromIterator<(ModuleInstanceId, ModuleKind, M)> for ModuleRegistry<M, State>
56where
57    State: Default,
58{
59    fn from_iter<T: IntoIterator<Item = (ModuleInstanceId, ModuleKind, M)>>(iter: T) -> Self {
60        Self::new(iter)
61    }
62}
63
64impl<M, State> ModuleRegistry<M, State> {
65    /// Create [`Self`] from an iterator of pairs
66    pub fn new(iter: impl IntoIterator<Item = (ModuleInstanceId, ModuleKind, M)>) -> Self
67    where
68        State: Default,
69    {
70        Self {
71            inner: iter
72                .into_iter()
73                .map(|(id, kind, module)| (id, (kind, module)))
74                .collect(),
75            state: Default::default(),
76        }
77    }
78
79    /// Is registry empty?
80    pub fn is_empty(&self) -> bool {
81        self.inner.is_empty()
82    }
83
84    /// Return an iterator over all module data
85    pub fn iter_modules(&self) -> impl Iterator<Item = (ModuleInstanceId, &ModuleKind, &M)> {
86        self.inner.iter().map(|(id, (kind, m))| (*id, kind, m))
87    }
88
89    /// Return an iterator over module ids an kinds
90    pub fn iter_modules_id_kind(&self) -> impl Iterator<Item = (ModuleInstanceId, &ModuleKind)> {
91        self.inner.iter().map(|(id, (kind, _))| (*id, kind))
92    }
93
94    /// Return an iterator over all module data
95    pub fn iter_modules_mut(
96        &mut self,
97    ) -> impl Iterator<Item = (ModuleInstanceId, &ModuleKind, &mut M)> {
98        self.inner
99            .iter_mut()
100            .map(|(id, (kind, m))| (*id, &*kind, m))
101    }
102
103    /// Return an iterator over all module data
104    pub fn into_iter_modules(self) -> impl Iterator<Item = (ModuleInstanceId, ModuleKind, M)> {
105        self.inner.into_iter().map(|(id, (kind, m))| (id, kind, m))
106    }
107
108    /// Get module data by instance id
109    pub fn get(&self, id: ModuleInstanceId) -> Option<&M> {
110        self.inner.get(&id).map(|m| &m.1)
111    }
112
113    /// Get module data by instance id, including [`ModuleKind`]
114    pub fn get_with_kind(&self, id: ModuleInstanceId) -> Option<&(ModuleKind, M)> {
115        self.inner.get(&id)
116    }
117}
118
119impl<M: std::fmt::Debug, State> ModuleRegistry<M, State> {
120    /// Return the module data belonging to the module identified by the
121    /// supplied `module_id`
122    ///
123    /// # Panics
124    /// If the module isn't in the registry
125    pub fn get_expect(&self, id: ModuleInstanceId) -> &M {
126        &self
127            .inner
128            .get(&id)
129            .ok_or_else(|| {
130                anyhow!(
131                    "Instance ID not found: got {}, expected one of {:?}",
132                    id,
133                    self.inner.keys().collect::<Vec<_>>()
134                )
135            })
136            .expect("Only existing instance should be fetched")
137            .1
138    }
139
140    /// Add a module to the registry
141    pub fn register_module(&mut self, id: ModuleInstanceId, kind: ModuleKind, module: M) {
142        // FIXME: return result
143        assert!(
144            self.inner.insert(id, (kind, module)).is_none(),
145            "Module was already registered!"
146        );
147    }
148
149    pub fn append_module(&mut self, kind: ModuleKind, module: M) {
150        let last_id = self
151            .inner
152            .last_key_value()
153            .map(|id| id.0.checked_add(1).expect("Module id overflow"))
154            .unwrap_or_default();
155        assert!(
156            self.inner.insert(last_id, (kind, module)).is_none(),
157            "Module was already registered?!"
158        );
159    }
160}
161
162/// Collection of server modules
163pub type ServerModuleRegistry = ModuleRegistry<DynServerModule>;
164
165impl ServerModuleRegistry {
166    /// Generate a `ModuleDecoderRegistry` from this `ModuleRegistry`
167    pub fn decoder_registry(&self) -> ModuleDecoderRegistry {
168        // TODO: cache decoders
169        self.inner
170            .iter()
171            .map(|(&id, (kind, module))| (id, kind.clone(), module.decoder()))
172            .collect::<ModuleDecoderRegistry>()
173    }
174}
175
/// How unknown module instance ids are treated when decoding.
///
/// The default is [`DecodingMode::Reject`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DecodingMode {
    /// Unknown module instance ids are rejected
    #[default]
    Reject,
    /// Unknown module instance ids fall back to being decoded as
    /// [`crate::core::DynUnknown`]
    Fallback,
}
185
186/// Collection of decoders belonging to modules, typically obtained from a
187/// `ModuleRegistry`
188pub type ModuleDecoderRegistry = ModuleRegistry<Decoder, DecodingMode>;
189
190impl ModuleDecoderRegistry {
191    pub fn with_fallback(self) -> Self {
192        Self {
193            state: DecodingMode::Fallback,
194            ..self
195        }
196    }
197
198    pub fn decoding_mode(&self) -> DecodingMode {
199        self.state
200    }
201
202    /// Panic if the [`Self::decoding_mode`] is not `Reject`
203    pub fn assert_reject_mode(&self) {
204        assert_eq!(self.state, DecodingMode::Reject);
205    }
206}