1use std::any::Any;
2use std::fmt::Debug;
3use std::future::Future;
4use std::hash;
5use std::io::{Error, Read, Write};
6use std::pin::Pin;
7use std::sync::Arc;
8
9use fedimint_core::core::{IntoDynInstance, ModuleInstanceId, ModuleKind, OperationId};
10use fedimint_core::encoding::{Decodable, DecodeError, DynEncodable, Encodable};
11use fedimint_core::module::registry::ModuleDecoderRegistry;
12use fedimint_core::task::{MaybeSend, MaybeSync};
13use fedimint_core::util::BoxFuture;
14use fedimint_core::{maybe_add_send, maybe_add_send_sync, module_plugin_dyn_newtype_define};
15
16use crate::sm::ClientSMDatabaseTransaction;
17use crate::DynGlobalClientContext;
18
19pub trait State:
21 Debug
22 + Clone
23 + Eq
24 + PartialEq
25 + std::hash::Hash
26 + Encodable
27 + Decodable
28 + MaybeSend
29 + MaybeSync
30 + 'static
31{
32 type ModuleContext: Context;
34
35 fn transitions(
38 &self,
39 context: &Self::ModuleContext,
40 global_context: &DynGlobalClientContext,
41 ) -> Vec<StateTransition<Self>>;
42
43 fn operation_id(&self) -> OperationId;
47}
48
49pub trait IState: Debug + DynEncodable + MaybeSend + MaybeSync {
51 fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any));
52
53 fn transitions(
55 &self,
56 context: &DynContext,
57 global_context: &DynGlobalClientContext,
58 ) -> Vec<StateTransition<DynState>>;
59
60 fn operation_id(&self) -> OperationId;
63
64 fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState;
66
67 fn erased_eq_no_instance_id(&self, other: &DynState) -> bool;
68
69 fn erased_hash_no_instance_id(&self, hasher: &mut dyn std::hash::Hasher);
70}
71
72pub trait IContext: Debug {
76 fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any));
77 fn module_kind(&self) -> Option<ModuleKind>;
78}
79
// Type-erased handle for module contexts (`DynContext`).
// NOTE(review): the macro presumably expands to a `Clone`able newtype around
// `Arc<dyn IContext>`; confirm against `module_plugin_dyn_newtype_define!` in
// `fedimint_core`. Call sites downcast it back to a concrete `Context` via
// `as_any()`.
module_plugin_dyn_newtype_define! {
    #[derive(Clone)]
    pub DynContext(Arc<IContext>)
}
85
86pub trait Context: std::fmt::Debug + MaybeSend + MaybeSync + 'static {
89 const KIND: Option<ModuleKind>;
90}
91
92impl<T> IContext for T
94where
95 T: Context + 'static + MaybeSend + MaybeSync,
96{
97 fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
98 self
99 }
100
101 fn module_kind(&self) -> Option<ModuleKind> {
102 T::KIND
103 }
104}
105
106type TriggerFuture = Pin<Box<maybe_add_send!(dyn Future<Output = serde_json::Value> + 'static)>>;
107
108pub(super) type StateTransitionFunction<S> = Arc<
110 maybe_add_send_sync!(
111 dyn for<'a> Fn(
112 &'a mut ClientSMDatabaseTransaction<'_, '_>,
113 serde_json::Value,
114 S,
115 ) -> BoxFuture<'a, S>
116 ),
117>;
118
119pub struct StateTransition<S> {
122 pub trigger: TriggerFuture,
132 pub transition: StateTransitionFunction<S>,
147}
148
149impl<S> StateTransition<S> {
150 pub fn new<V, Trigger, TransitionFn>(
153 trigger: Trigger,
154 transition: TransitionFn,
155 ) -> StateTransition<S>
156 where
157 S: MaybeSend + MaybeSync + Clone + 'static,
158 V: serde::Serialize + serde::de::DeserializeOwned + Send,
159 Trigger: Future<Output = V> + MaybeSend + 'static,
160 TransitionFn: for<'a> Fn(&'a mut ClientSMDatabaseTransaction<'_, '_>, V, S) -> BoxFuture<'a, S>
161 + MaybeSend
162 + MaybeSync
163 + Clone
164 + 'static,
165 {
166 StateTransition {
167 trigger: Box::pin(async {
168 let val = trigger.await;
169 serde_json::to_value(val).expect("Value could not be serialized")
170 }),
171 transition: Arc::new(move |dbtx, val, state| {
172 let transition = transition.clone();
173 Box::pin(async move {
174 let typed_val: V = serde_json::from_value(val)
175 .expect("Deserialize trigger return value failed");
176 transition(dbtx, typed_val, state.clone()).await
177 })
178 }),
179 }
180 }
181}
182
183impl<T> IState for T
184where
185 T: State,
186{
187 fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
188 self
189 }
190
191 fn transitions(
192 &self,
193 context: &DynContext,
194 global_context: &DynGlobalClientContext,
195 ) -> Vec<StateTransition<DynState>> {
196 <T as State>::transitions(
197 self,
198 context.as_any().downcast_ref().expect("Wrong module"),
199 global_context,
200 )
201 .into_iter()
202 .map(|st| StateTransition {
203 trigger: st.trigger,
204 transition: Arc::new(
205 move |dbtx: &mut ClientSMDatabaseTransaction<'_, '_>, val, state: DynState| {
206 let transition = st.transition.clone();
207 Box::pin(async move {
208 let new_state = transition(
209 dbtx,
210 val,
211 state
212 .as_any()
213 .downcast_ref::<T>()
214 .expect("Wrong module")
215 .clone(),
216 )
217 .await;
218 DynState::from_typed(state.module_instance_id(), new_state)
219 })
220 },
221 ),
222 })
223 .collect()
224 }
225
226 fn operation_id(&self) -> OperationId {
227 <T as State>::operation_id(self)
228 }
229
230 fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState {
231 DynState::from_typed(module_instance_id, <T as Clone>::clone(self))
232 }
233
234 fn erased_eq_no_instance_id(&self, other: &DynState) -> bool {
235 let other: &T = other
236 .as_any()
237 .downcast_ref()
238 .expect("Type is ensured in previous step");
239
240 self == other
241 }
242
243 fn erased_hash_no_instance_id(&self, mut hasher: &mut dyn std::hash::Hasher) {
244 self.hash(&mut hasher);
245 }
246}
247
248pub struct DynState(
251 Box<maybe_add_send_sync!(dyn IState + 'static)>,
252 ModuleInstanceId,
253);
254
255impl IState for DynState {
256 fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
257 (**self).as_any()
258 }
259
260 fn transitions(
261 &self,
262 context: &DynContext,
263 global_context: &DynGlobalClientContext,
264 ) -> Vec<StateTransition<DynState>> {
265 (**self).transitions(context, global_context)
266 }
267
268 fn operation_id(&self) -> OperationId {
269 (**self).operation_id()
270 }
271
272 fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState {
273 (**self).clone(module_instance_id)
274 }
275
276 fn erased_eq_no_instance_id(&self, other: &DynState) -> bool {
277 (**self).erased_eq_no_instance_id(other)
278 }
279
280 fn erased_hash_no_instance_id(&self, hasher: &mut dyn std::hash::Hasher) {
281 (**self).erased_hash_no_instance_id(hasher);
282 }
283}
284
285impl IntoDynInstance for DynState {
286 type DynType = DynState;
287
288 fn into_dyn(self, instance_id: ModuleInstanceId) -> Self::DynType {
289 assert_eq!(instance_id, self.1);
290 self
291 }
292}
293
294impl std::ops::Deref for DynState {
295 type Target = maybe_add_send_sync!(dyn IState + 'static);
296
297 fn deref(&self) -> &<Self as std::ops::Deref>::Target {
298 &*self.0
299 }
300}
301
302impl hash::Hash for DynState {
303 fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
304 self.1.hash(hasher);
305 self.0.erased_hash_no_instance_id(hasher);
306 }
307}
308
309impl DynState {
310 pub fn module_instance_id(&self) -> ModuleInstanceId {
311 self.1
312 }
313
314 pub fn from_typed<I>(module_instance_id: ModuleInstanceId, typed: I) -> Self
315 where
316 I: IState + 'static,
317 {
318 Self(Box::new(typed), module_instance_id)
319 }
320
321 pub fn from_parts(
322 module_instance_id: ::fedimint_core::core::ModuleInstanceId,
323 dynbox: Box<maybe_add_send_sync!(dyn IState + 'static)>,
324 ) -> Self {
325 Self(dynbox, module_instance_id)
326 }
327}
328
329impl std::fmt::Debug for DynState {
330 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331 std::fmt::Debug::fmt(&self.0, f)
332 }
333}
334
335impl std::ops::DerefMut for DynState {
336 fn deref_mut(&mut self) -> &mut <Self as std::ops::Deref>::Target {
337 &mut *self.0
338 }
339}
340
341impl Clone for DynState {
342 fn clone(&self) -> Self {
343 self.0.clone(self.1)
344 }
345}
346
347impl PartialEq for DynState {
348 fn eq(&self, other: &Self) -> bool {
349 if self.1 != other.1 {
350 return false;
351 }
352 self.erased_eq_no_instance_id(other)
353 }
354}
355impl Eq for DynState {}
356
357impl Encodable for DynState {
358 fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> {
359 self.1.consensus_encode(writer)?;
360 self.0.consensus_encode_dyn(writer)
361 }
362}
363impl Decodable for DynState {
364 fn consensus_decode_partial<R: std::io::Read>(
365 reader: &mut R,
366 decoders: &::fedimint_core::module::registry::ModuleDecoderRegistry,
367 ) -> Result<Self, fedimint_core::encoding::DecodeError> {
368 let module_id =
369 fedimint_core::core::ModuleInstanceId::consensus_decode_partial(reader, decoders)?;
370 decoders
371 .get_expect(module_id)
372 .decode_partial(reader, module_id, decoders)
373 }
374}
375
376impl DynState {
377 pub fn is_terminal(
379 &self,
380 context: &DynContext,
381 global_context: &DynGlobalClientContext,
382 ) -> bool {
383 self.transitions(context, global_context).is_empty()
384 }
385}
386
387#[derive(Debug)]
388pub struct OperationState<S> {
389 pub operation_id: OperationId,
390 pub state: S,
391}
392
393impl<S> State for OperationState<S>
396where
397 S: State,
398{
399 type ModuleContext = S::ModuleContext;
400
401 fn transitions(
402 &self,
403 context: &Self::ModuleContext,
404 global_context: &DynGlobalClientContext,
405 ) -> Vec<StateTransition<Self>> {
406 let transitions: Vec<StateTransition<OperationState<S>>> = self
407 .state
408 .transitions(context, global_context)
409 .into_iter()
410 .map(
411 |StateTransition {
412 trigger,
413 transition,
414 }| {
415 let op_transition: StateTransitionFunction<Self> =
416 Arc::new(move |dbtx, value, op_state| {
417 let transition = transition.clone();
418 Box::pin(async move {
419 let state = transition(dbtx, value, op_state.state).await;
420 OperationState {
421 operation_id: op_state.operation_id,
422 state,
423 }
424 })
425 });
426
427 StateTransition {
428 trigger,
429 transition: op_transition,
430 }
431 },
432 )
433 .collect();
434 transitions
435 }
436
437 fn operation_id(&self) -> OperationId {
438 self.operation_id
439 }
440}
441
442impl<S> IntoDynInstance for OperationState<S>
445where
446 S: State,
447{
448 type DynType = DynState;
449
450 fn into_dyn(self, instance_id: ModuleInstanceId) -> Self::DynType {
451 DynState::from_typed(instance_id, self)
452 }
453}
454
455impl<S> Encodable for OperationState<S>
456where
457 S: State,
458{
459 fn consensus_encode<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
460 let mut len = 0;
461 len += self.operation_id.consensus_encode(writer)?;
462 len += self.state.consensus_encode(writer)?;
463 Ok(len)
464 }
465}
466
467impl<S> Decodable for OperationState<S>
468where
469 S: State,
470{
471 fn consensus_decode_partial<R: Read>(
472 read: &mut R,
473 modules: &ModuleDecoderRegistry,
474 ) -> Result<Self, DecodeError> {
475 let operation_id = OperationId::consensus_decode_partial(read, modules)?;
476 let state = S::consensus_decode_partial(read, modules)?;
477
478 Ok(OperationState {
479 operation_id,
480 state,
481 })
482 }
483}
484
485impl<S> PartialEq for OperationState<S>
487where
488 S: State,
489{
490 fn eq(&self, other: &Self) -> bool {
491 self.operation_id.eq(&other.operation_id) && self.state.eq(&other.state)
492 }
493}
494
495impl<S> Eq for OperationState<S> where S: State {}
496
497impl<S> hash::Hash for OperationState<S>
498where
499 S: hash::Hash,
500{
501 fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
502 self.operation_id.hash(hasher);
503 self.state.hash(hasher);
504 }
505}
506
507impl<S> Clone for OperationState<S>
508where
509 S: State,
510{
511 fn clone(&self) -> Self {
512 OperationState {
513 operation_id: self.operation_id,
514 state: self.state.clone(),
515 }
516 }
517}