fedimint_client/sm/state.rs

1use std::any::Any;
2use std::fmt::Debug;
3use std::future::Future;
4use std::hash;
5use std::io::{Error, Read, Write};
6use std::pin::Pin;
7use std::sync::Arc;
8
9use fedimint_core::core::{IntoDynInstance, ModuleInstanceId, ModuleKind, OperationId};
10use fedimint_core::encoding::{Decodable, DecodeError, DynEncodable, Encodable};
11use fedimint_core::module::registry::ModuleDecoderRegistry;
12use fedimint_core::task::{MaybeSend, MaybeSync};
13use fedimint_core::util::BoxFuture;
14use fedimint_core::{maybe_add_send, maybe_add_send_sync, module_plugin_dyn_newtype_define};
15
16use crate::sm::ClientSMDatabaseTransaction;
17use crate::DynGlobalClientContext;
18
19/// Implementors act as state machines that can be executed
20pub trait State:
21    Debug
22    + Clone
23    + Eq
24    + PartialEq
25    + std::hash::Hash
26    + Encodable
27    + Decodable
28    + MaybeSend
29    + MaybeSync
30    + 'static
31{
32    /// Additional resources made available in this module's state transitions
33    type ModuleContext: Context;
34
35    /// All possible transitions from the current state to other states. See
36    /// [`StateTransition`] for details.
37    fn transitions(
38        &self,
39        context: &Self::ModuleContext,
40        global_context: &DynGlobalClientContext,
41    ) -> Vec<StateTransition<Self>>;
42
43    // TODO: move out of this interface into wrapper struct (see OperationState)
44    /// Operation this state machine belongs to. See [`OperationId`] for
45    /// details.
46    fn operation_id(&self) -> OperationId;
47}
48
49/// Object-safe version of [`State`]
50pub trait IState: Debug + DynEncodable + MaybeSend + MaybeSync {
51    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any));
52
53    /// All possible transitions from the state
54    fn transitions(
55        &self,
56        context: &DynContext,
57        global_context: &DynGlobalClientContext,
58    ) -> Vec<StateTransition<DynState>>;
59
60    /// Operation this state machine belongs to. See [`OperationId`] for
61    /// details.
62    fn operation_id(&self) -> OperationId;
63
64    /// Clone state
65    fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState;
66
67    fn erased_eq_no_instance_id(&self, other: &DynState) -> bool;
68
69    fn erased_hash_no_instance_id(&self, hasher: &mut dyn std::hash::Hasher);
70}
71
72/// Something that can be a [`DynContext`] for a state machine
73///
74/// General purpose code should use [`DynContext`] instead
75pub trait IContext: Debug {
76    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any));
77    fn module_kind(&self) -> Option<ModuleKind>;
78}
79
// Declares the `DynContext` newtype over `Arc<IContext>`. Code in this file
// calls `.as_any()` on a `DynContext`, so the macro presumably also generates
// `Deref` to the inner trait object — TODO confirm against
// `module_plugin_dyn_newtype_define!` in fedimint_core.
module_plugin_dyn_newtype_define! {
    /// A shared, type-erased context for a module client's state machines.
    /// Since it wraps an `Arc`, cloning shares the same underlying context.
    #[derive(Clone)]
    pub DynContext(Arc<IContext>)
}
85
86/// Additional data made available to state machines of a module (e.g. API
87/// clients)
88pub trait Context: std::fmt::Debug + MaybeSend + MaybeSync + 'static {
89    const KIND: Option<ModuleKind>;
90}
91
92/// Type-erased version of [`Context`]
93impl<T> IContext for T
94where
95    T: Context + 'static + MaybeSend + MaybeSync,
96{
97    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
98        self
99    }
100
101    fn module_kind(&self) -> Option<ModuleKind> {
102        T::KIND
103    }
104}
105
/// Type-erased `trigger` of a [`StateTransition`]: a boxed future whose output
/// is the trigger's result encoded as a [`serde_json::Value`].
type TriggerFuture = Pin<Box<maybe_add_send!(dyn Future<Output = serde_json::Value> + 'static)>>;

// TODO: remove Arc, maybe make it a fn pointer?
// NOTE(review): the transition closures built in this file (e.g. in
// `StateTransition::new`) capture values such as the typed transition, which a
// plain `fn` pointer cannot do.
/// Type-erased `transition` of a [`StateTransition`].
///
/// It is called with the state machine database transaction, the JSON value
/// produced by the trigger, and the current state, and returns a future that
/// resolves to the next state.
pub(super) type StateTransitionFunction<S> = Arc<
    maybe_add_send_sync!(
        dyn for<'a> Fn(
            &'a mut ClientSMDatabaseTransaction<'_, '_>,
            serde_json::Value,
            S,
        ) -> BoxFuture<'a, S>
    ),
>;
118
/// Represents one or multiple possible state transitions triggered in a common
/// way.
pub struct StateTransition<S> {
    /// Future that resolves once a state transition is possible.
    ///
    /// **The trigger future must be idempotent since it might be re-run if the
    /// client is restarted.**
    ///
    /// To wait for a possible state transition it can query external APIs,
    /// subscribe to events emitted by other state machines, etc.
    /// Optionally, it can also return some data that will be given to the
    /// state transition function; see the `transition` docs for details.
    pub trigger: TriggerFuture,
    /// State transition function that, using the output of the `trigger`,
    /// performs the appropriate state transition.
    ///
    /// **This function shall not block on network IO or similar things as all
    /// actual state transitions are run serially.**
    ///
    /// Since this function can return different output states depending on
    /// the `Value` returned by the `trigger` future, it can be used to model
    /// multiple possible state transitions at once. E.g. instead of having
    /// two state transitions querying the same API endpoint and each waiting
    /// for a specific value to be returned to trigger their respective state
    /// transition, we can have one `trigger` future querying the API and,
    /// depending on the return value, run different state transitions,
    /// saving network requests.
    pub transition: StateTransitionFunction<S>,
}
148
149impl<S> StateTransition<S> {
150    /// Creates a new `StateTransition` where the `trigger` future returns a
151    /// value of type `V` that is then given to the `transition` function.
152    pub fn new<V, Trigger, TransitionFn>(
153        trigger: Trigger,
154        transition: TransitionFn,
155    ) -> StateTransition<S>
156    where
157        S: MaybeSend + MaybeSync + Clone + 'static,
158        V: serde::Serialize + serde::de::DeserializeOwned + Send,
159        Trigger: Future<Output = V> + MaybeSend + 'static,
160        TransitionFn: for<'a> Fn(&'a mut ClientSMDatabaseTransaction<'_, '_>, V, S) -> BoxFuture<'a, S>
161            + MaybeSend
162            + MaybeSync
163            + Clone
164            + 'static,
165    {
166        StateTransition {
167            trigger: Box::pin(async {
168                let val = trigger.await;
169                serde_json::to_value(val).expect("Value could not be serialized")
170            }),
171            transition: Arc::new(move |dbtx, val, state| {
172                let transition = transition.clone();
173                Box::pin(async move {
174                    let typed_val: V = serde_json::from_value(val)
175                        .expect("Deserialize trigger return value failed");
176                    transition(dbtx, typed_val, state.clone()).await
177                })
178            }),
179        }
180    }
181}
182
183impl<T> IState for T
184where
185    T: State,
186{
187    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
188        self
189    }
190
191    fn transitions(
192        &self,
193        context: &DynContext,
194        global_context: &DynGlobalClientContext,
195    ) -> Vec<StateTransition<DynState>> {
196        <T as State>::transitions(
197            self,
198            context.as_any().downcast_ref().expect("Wrong module"),
199            global_context,
200        )
201        .into_iter()
202        .map(|st| StateTransition {
203            trigger: st.trigger,
204            transition: Arc::new(
205                move |dbtx: &mut ClientSMDatabaseTransaction<'_, '_>, val, state: DynState| {
206                    let transition = st.transition.clone();
207                    Box::pin(async move {
208                        let new_state = transition(
209                            dbtx,
210                            val,
211                            state
212                                .as_any()
213                                .downcast_ref::<T>()
214                                .expect("Wrong module")
215                                .clone(),
216                        )
217                        .await;
218                        DynState::from_typed(state.module_instance_id(), new_state)
219                    })
220                },
221            ),
222        })
223        .collect()
224    }
225
226    fn operation_id(&self) -> OperationId {
227        <T as State>::operation_id(self)
228    }
229
230    fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState {
231        DynState::from_typed(module_instance_id, <T as Clone>::clone(self))
232    }
233
234    fn erased_eq_no_instance_id(&self, other: &DynState) -> bool {
235        let other: &T = other
236            .as_any()
237            .downcast_ref()
238            .expect("Type is ensured in previous step");
239
240        self == other
241    }
242
243    fn erased_hash_no_instance_id(&self, mut hasher: &mut dyn std::hash::Hasher) {
244        self.hash(&mut hasher);
245    }
246}
247
/// A type-erased state of a state machine belonging to a module instance, see
/// [`State`].
///
/// Field `0` is the boxed concrete state; field `1` is the id of the module
/// instance it belongs to.
pub struct DynState(
    Box<maybe_add_send_sync!(dyn IState + 'static)>,
    ModuleInstanceId,
);
254
255impl IState for DynState {
256    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
257        (**self).as_any()
258    }
259
260    fn transitions(
261        &self,
262        context: &DynContext,
263        global_context: &DynGlobalClientContext,
264    ) -> Vec<StateTransition<DynState>> {
265        (**self).transitions(context, global_context)
266    }
267
268    fn operation_id(&self) -> OperationId {
269        (**self).operation_id()
270    }
271
272    fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState {
273        (**self).clone(module_instance_id)
274    }
275
276    fn erased_eq_no_instance_id(&self, other: &DynState) -> bool {
277        (**self).erased_eq_no_instance_id(other)
278    }
279
280    fn erased_hash_no_instance_id(&self, hasher: &mut dyn std::hash::Hasher) {
281        (**self).erased_hash_no_instance_id(hasher);
282    }
283}
284
285impl IntoDynInstance for DynState {
286    type DynType = DynState;
287
288    fn into_dyn(self, instance_id: ModuleInstanceId) -> Self::DynType {
289        assert_eq!(instance_id, self.1);
290        self
291    }
292}
293
294impl std::ops::Deref for DynState {
295    type Target = maybe_add_send_sync!(dyn IState + 'static);
296
297    fn deref(&self) -> &<Self as std::ops::Deref>::Target {
298        &*self.0
299    }
300}
301
302impl hash::Hash for DynState {
303    fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
304        self.1.hash(hasher);
305        self.0.erased_hash_no_instance_id(hasher);
306    }
307}
308
309impl DynState {
310    pub fn module_instance_id(&self) -> ModuleInstanceId {
311        self.1
312    }
313
314    pub fn from_typed<I>(module_instance_id: ModuleInstanceId, typed: I) -> Self
315    where
316        I: IState + 'static,
317    {
318        Self(Box::new(typed), module_instance_id)
319    }
320
321    pub fn from_parts(
322        module_instance_id: ::fedimint_core::core::ModuleInstanceId,
323        dynbox: Box<maybe_add_send_sync!(dyn IState + 'static)>,
324    ) -> Self {
325        Self(dynbox, module_instance_id)
326    }
327}
328
329impl std::fmt::Debug for DynState {
330    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331        std::fmt::Debug::fmt(&self.0, f)
332    }
333}
334
335impl std::ops::DerefMut for DynState {
336    fn deref_mut(&mut self) -> &mut <Self as std::ops::Deref>::Target {
337        &mut *self.0
338    }
339}
340
341impl Clone for DynState {
342    fn clone(&self) -> Self {
343        self.0.clone(self.1)
344    }
345}
346
347impl PartialEq for DynState {
348    fn eq(&self, other: &Self) -> bool {
349        if self.1 != other.1 {
350            return false;
351        }
352        self.erased_eq_no_instance_id(other)
353    }
354}
355impl Eq for DynState {}
356
357impl Encodable for DynState {
358    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> {
359        self.1.consensus_encode(writer)?;
360        self.0.consensus_encode_dyn(writer)
361    }
362}
363impl Decodable for DynState {
364    fn consensus_decode_partial<R: std::io::Read>(
365        reader: &mut R,
366        decoders: &::fedimint_core::module::registry::ModuleDecoderRegistry,
367    ) -> Result<Self, fedimint_core::encoding::DecodeError> {
368        let module_id =
369            fedimint_core::core::ModuleInstanceId::consensus_decode_partial(reader, decoders)?;
370        decoders
371            .get_expect(module_id)
372            .decode_partial(reader, module_id, decoders)
373    }
374}
375
376impl DynState {
377    /// `true` if this state allows no further transitions
378    pub fn is_terminal(
379        &self,
380        context: &DynContext,
381        global_context: &DynGlobalClientContext,
382    ) -> bool {
383        self.transitions(context, global_context).is_empty()
384    }
385}
386
/// Pairs an inner state machine state `S` with the [`OperationId`] it belongs
/// to, so `S` itself does not have to carry the id around.
#[derive(Debug)]
pub struct OperationState<S> {
    /// Operation this state machine belongs to.
    pub operation_id: OperationId,
    /// The wrapped inner state.
    pub state: S,
}
392
393/// Wrapper for states that don't want to carry around their operation id. `S`
394/// is allowed to panic when `operation_id` is called.
395impl<S> State for OperationState<S>
396where
397    S: State,
398{
399    type ModuleContext = S::ModuleContext;
400
401    fn transitions(
402        &self,
403        context: &Self::ModuleContext,
404        global_context: &DynGlobalClientContext,
405    ) -> Vec<StateTransition<Self>> {
406        let transitions: Vec<StateTransition<OperationState<S>>> = self
407            .state
408            .transitions(context, global_context)
409            .into_iter()
410            .map(
411                |StateTransition {
412                     trigger,
413                     transition,
414                 }| {
415                    let op_transition: StateTransitionFunction<Self> =
416                        Arc::new(move |dbtx, value, op_state| {
417                            let transition = transition.clone();
418                            Box::pin(async move {
419                                let state = transition(dbtx, value, op_state.state).await;
420                                OperationState {
421                                    operation_id: op_state.operation_id,
422                                    state,
423                                }
424                            })
425                        });
426
427                    StateTransition {
428                        trigger,
429                        transition: op_transition,
430                    }
431                },
432            )
433            .collect();
434        transitions
435    }
436
437    fn operation_id(&self) -> OperationId {
438        self.operation_id
439    }
440}
441
// NOTE(review): an earlier TODO here referred to a `GC` type parameter that no
// longer exists on `OperationState`.
/// Erases an `OperationState` into a [`DynState`] tagged with `instance_id`.
impl<S> IntoDynInstance for OperationState<S>
where
    S: State,
{
    type DynType = DynState;

    fn into_dyn(self, instance_id: ModuleInstanceId) -> Self::DynType {
        DynState::from_typed(instance_id, self)
    }
}
454
455impl<S> Encodable for OperationState<S>
456where
457    S: State,
458{
459    fn consensus_encode<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
460        let mut len = 0;
461        len += self.operation_id.consensus_encode(writer)?;
462        len += self.state.consensus_encode(writer)?;
463        Ok(len)
464    }
465}
466
467impl<S> Decodable for OperationState<S>
468where
469    S: State,
470{
471    fn consensus_decode_partial<R: Read>(
472        read: &mut R,
473        modules: &ModuleDecoderRegistry,
474    ) -> Result<Self, DecodeError> {
475        let operation_id = OperationId::consensus_decode_partial(read, modules)?;
476        let state = S::consensus_decode_partial(read, modules)?;
477
478        Ok(OperationState {
479            operation_id,
480            state,
481        })
482    }
483}
484
485// TODO: derive after getting rid of `GC` type arg
486impl<S> PartialEq for OperationState<S>
487where
488    S: State,
489{
490    fn eq(&self, other: &Self) -> bool {
491        self.operation_id.eq(&other.operation_id) && self.state.eq(&other.state)
492    }
493}
494
495impl<S> Eq for OperationState<S> where S: State {}
496
497impl<S> hash::Hash for OperationState<S>
498where
499    S: hash::Hash,
500{
501    fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
502        self.operation_id.hash(hasher);
503        self.state.hash(hasher);
504    }
505}
506
507impl<S> Clone for OperationState<S>
508where
509    S: State,
510{
511    fn clone(&self) -> Self {
512        OperationState {
513            operation_id: self.operation_id,
514            state: self.state.clone(),
515        }
516    }
517}