1pub mod recovery;
2
3use std::collections::BTreeMap;
4use std::fmt::Debug;
5use std::marker;
6use std::sync::Arc;
7
8use fedimint_api_client::api::{DynGlobalApi, DynModuleApi};
9use fedimint_core::config::{ClientModuleConfig, FederationId, ModuleInitRegistry};
10use fedimint_core::core::{Decoder, ModuleInstanceId, ModuleKind};
11use fedimint_core::db::{Database, DatabaseVersion};
12use fedimint_core::module::{
13 ApiAuth, ApiVersion, CommonModuleInit, IDynCommonModuleInit, ModuleInit, MultiApiVersion,
14};
15use fedimint_core::task::{MaybeSend, MaybeSync, TaskGroup};
16use fedimint_core::{apply, async_trait_maybe_send, dyn_newtype_define, NumPeers};
17use fedimint_derive_secret::DerivableSecret;
18use tokio::sync::watch;
19use tracing::warn;
20
21use super::recovery::{DynModuleBackup, RecoveryProgress};
22use super::{ClientContext, FinalClient};
23use crate::db::ClientMigrationFn;
24use crate::module::{ClientModule, DynClientModule};
25use crate::sm::{ModuleNotifier, Notifier};
26
/// Registry holding the type-erased client module initializers
/// ([`DynClientModuleInit`]) known to a client.
pub type ClientModuleInitRegistry = ModuleInitRegistry<DynClientModuleInit>;
28
/// Typed arguments handed to [`ClientModuleInit::init`] when constructing a
/// client module instance.
///
/// Built by the blanket [`IClientModuleInit`] implementation from its untyped
/// inputs; modules read the fields through the accessor methods.
pub struct ClientModuleInitArgs<C>
where
    C: ClientModuleInit,
{
    /// Federation the module is being initialized for.
    federation_id: FederationId,
    // NOTE(review): presumably the number of federation peers (the recover
    // counterpart uses `num_peers: NumPeers`) — confirm at the caller.
    peer_num: usize,
    /// Module client config, already cast to the module's typed config.
    cfg: <<C as ModuleInit>::Common as CommonModuleInit>::ClientConfig,
    /// Database handle prefixed with this module's instance id.
    db: Database,
    /// Core API version passed in by the client.
    core_api_version: ApiVersion,
    /// Module API version passed in by the client.
    module_api_version: ApiVersion,
    /// Root secret handed to this module.
    module_root_secret: DerivableSecret,
    /// Notifier scoped to this module instance's state machine states.
    notifier: ModuleNotifier<<<C as ClientModuleInit>::Module as ClientModule>::States>,
    /// Global federation API handle.
    api: DynGlobalApi,
    /// Optional admin authentication, if the caller supplied one.
    admin_auth: Option<ApiAuth>,
    /// Federation API handle scoped to this module instance.
    module_api: DynModuleApi,
    /// Client context bound to this module instance.
    context: ClientContext<<C as ClientModuleInit>::Module>,
    // Passed through from the caller; presumably for spawning module tasks.
    task_group: TaskGroup,
}
47
48impl<C> ClientModuleInitArgs<C>
49where
50 C: ClientModuleInit,
51{
52 pub fn federation_id(&self) -> &FederationId {
53 &self.federation_id
54 }
55
56 pub fn peer_num(&self) -> usize {
57 self.peer_num
58 }
59
60 pub fn cfg(&self) -> &<<C as ModuleInit>::Common as CommonModuleInit>::ClientConfig {
61 &self.cfg
62 }
63
64 pub fn db(&self) -> &Database {
65 &self.db
66 }
67
68 pub fn core_api_version(&self) -> &ApiVersion {
69 &self.core_api_version
70 }
71
72 pub fn module_api_version(&self) -> &ApiVersion {
73 &self.module_api_version
74 }
75
76 pub fn module_root_secret(&self) -> &DerivableSecret {
77 &self.module_root_secret
78 }
79
80 pub fn notifier(
81 &self,
82 ) -> &ModuleNotifier<<<C as ClientModuleInit>::Module as ClientModule>::States> {
83 &self.notifier
84 }
85
86 pub fn api(&self) -> &DynGlobalApi {
87 &self.api
88 }
89
90 pub fn admin_auth(&self) -> Option<&ApiAuth> {
91 self.admin_auth.as_ref()
92 }
93
94 pub fn module_api(&self) -> &DynModuleApi {
95 &self.module_api
96 }
97
98 pub fn context(&self) -> ClientContext<<C as ClientModuleInit>::Module> {
105 self.context.clone()
106 }
107
108 pub fn task_group(&self) -> &TaskGroup {
109 &self.task_group
110 }
111}
112
113pub struct ClientModuleRecoverArgs<C>
114where
115 C: ClientModuleInit,
116{
117 federation_id: FederationId,
118 num_peers: NumPeers,
119 cfg: <<C as ModuleInit>::Common as CommonModuleInit>::ClientConfig,
120 db: Database,
121 core_api_version: ApiVersion,
122 module_api_version: ApiVersion,
123 module_root_secret: DerivableSecret,
124 notifier: ModuleNotifier<<<C as ClientModuleInit>::Module as ClientModule>::States>,
125 api: DynGlobalApi,
126 admin_auth: Option<ApiAuth>,
127 module_api: DynModuleApi,
128 context: ClientContext<<C as ClientModuleInit>::Module>,
129 progress_tx: tokio::sync::watch::Sender<RecoveryProgress>,
130 task_group: TaskGroup,
131}
132
133impl<C> ClientModuleRecoverArgs<C>
134where
135 C: ClientModuleInit,
136{
137 pub fn federation_id(&self) -> &FederationId {
138 &self.federation_id
139 }
140
141 pub fn num_peers(&self) -> NumPeers {
142 self.num_peers
143 }
144
145 pub fn cfg(&self) -> &<<C as ModuleInit>::Common as CommonModuleInit>::ClientConfig {
146 &self.cfg
147 }
148
149 pub fn db(&self) -> &Database {
150 &self.db
151 }
152
153 pub fn task_group(&self) -> &TaskGroup {
154 &self.task_group
155 }
156
157 pub fn core_api_version(&self) -> &ApiVersion {
158 &self.core_api_version
159 }
160
161 pub fn module_api_version(&self) -> &ApiVersion {
162 &self.module_api_version
163 }
164
165 pub fn module_root_secret(&self) -> &DerivableSecret {
166 &self.module_root_secret
167 }
168
169 pub fn notifier(
170 &self,
171 ) -> &ModuleNotifier<<<C as ClientModuleInit>::Module as ClientModule>::States> {
172 &self.notifier
173 }
174
175 pub fn api(&self) -> &DynGlobalApi {
176 &self.api
177 }
178
179 pub fn admin_auth(&self) -> Option<&ApiAuth> {
180 self.admin_auth.as_ref()
181 }
182
183 pub fn module_api(&self) -> &DynModuleApi {
184 &self.module_api
185 }
186
187 pub fn context(&self) -> ClientContext<<C as ClientModuleInit>::Module> {
192 self.context.clone()
193 }
194
195 pub fn update_recovery_progress(&self, progress: RecoveryProgress) {
196 if progress.is_done() {
197 warn!("Module trying to send a completed recovery progress. Ignoring");
200 } else if progress.is_none() {
201 warn!("Module trying to send a none recovery progress. Ignoring");
204 } else if self.progress_tx.send(progress).is_err() {
205 warn!("Module trying to send a recovery progress but nothing is listening");
206 }
207 }
208}
209
/// Typed initializer for a client module.
///
/// Implementors describe how to construct (`init`) and optionally recover
/// (`recover`) their [`ClientModule`]. The object-safe [`IClientModuleInit`]
/// is provided for every implementor by a blanket impl.
#[apply(async_trait_maybe_send!)]
pub trait ClientModuleInit: ModuleInit + Sized {
    /// Client module type produced by `init`.
    type Module: ClientModule;

    /// API versions this client module supports.
    fn supported_api_versions(&self) -> MultiApiVersion;

    /// Recovers module state, optionally starting from a backup `snapshot`.
    ///
    /// Progress can be reported through
    /// [`ClientModuleRecoverArgs::update_recovery_progress`].
    ///
    /// The default implementation does no recovery. It logs a warning that
    /// includes the module kind and returns `Ok(())`.
    async fn recover(
        &self,
        _args: &ClientModuleRecoverArgs<Self>,
        _snapshot: Option<&<Self::Module as ClientModule>::Backup>,
    ) -> anyhow::Result<()> {
        warn!(
            kind = %<Self::Module as ClientModule>::kind(),
            "Module does not support recovery, completing without doing anything"
        );
        Ok(())
    }

    /// Constructs the module instance from the typed init arguments.
    async fn init(&self, args: &ClientModuleInitArgs<Self>) -> anyhow::Result<Self::Module>;

    /// Client-side database migrations, keyed by database version.
    ///
    /// Defaults to an empty map (no migrations).
    fn get_database_migrations(&self) -> BTreeMap<DatabaseVersion, ClientMigrationFn> {
        BTreeMap::new()
    }
}
245
/// Object-safe, type-erased form of [`ClientModuleInit`], stored behind
/// [`DynClientModuleInit`].
///
/// Inputs are untyped (`ClientModuleConfig`, `Notifier`, `DynModuleBackup`).
/// The blanket impl for `T: ClientModuleInit` casts `cfg`, downcasts
/// `snapshot`, and scopes `db`, `notifier` and `api` to `instance_id` before
/// delegating to the typed trait.
#[apply(async_trait_maybe_send!)]
pub trait IClientModuleInit: IDynCommonModuleInit + Debug + MaybeSend + MaybeSync {
    /// Decoder of the module produced by this initializer.
    fn decoder(&self) -> Decoder;

    /// Kind of the module produced by this initializer.
    fn module_kind(&self) -> ModuleKind;

    /// Upcasts to the `IDynCommonModuleInit` view of this initializer.
    fn as_common(&self) -> &(dyn IDynCommonModuleInit + Send + Sync + 'static);

    /// API versions supported by this client module.
    fn supported_api_versions(&self) -> MultiApiVersion;

    /// Runs recovery for module instance `instance_id`, optionally starting
    /// from a backup `snapshot`. Progress is reported through `progress_tx`.
    #[allow(clippy::too_many_arguments)]
    async fn recover(
        &self,
        final_client: FinalClient,
        federation_id: FederationId,
        num_peers: NumPeers,
        cfg: ClientModuleConfig,
        db: Database,
        instance_id: ModuleInstanceId,
        core_api_version: ApiVersion,
        module_api_version: ApiVersion,
        module_root_secret: DerivableSecret,
        notifier: Notifier,
        api: DynGlobalApi,
        admin_auth: Option<ApiAuth>,
        snapshot: Option<&DynModuleBackup>,
        progress_tx: watch::Sender<RecoveryProgress>,
        task_group: TaskGroup,
    ) -> anyhow::Result<()>;

    /// Constructs module instance `instance_id` and returns it type-erased.
    #[allow(clippy::too_many_arguments)]
    async fn init(
        &self,
        final_client: FinalClient,
        federation_id: FederationId,
        peer_num: usize,
        cfg: ClientModuleConfig,
        db: Database,
        instance_id: ModuleInstanceId,
        core_api_version: ApiVersion,
        module_api_version: ApiVersion,
        module_root_secret: DerivableSecret,
        notifier: Notifier,
        api: DynGlobalApi,
        admin_auth: Option<ApiAuth>,
        task_group: TaskGroup,
    ) -> anyhow::Result<DynClientModule>;

    /// Client-side database migrations, keyed by database version.
    fn get_database_migrations(&self) -> BTreeMap<DatabaseVersion, ClientMigrationFn>;
}
297
298#[apply(async_trait_maybe_send!)]
299impl<T> IClientModuleInit for T
300where
301 T: ClientModuleInit + 'static + MaybeSend + Sync,
302{
303 fn decoder(&self) -> Decoder {
304 <<T as ClientModuleInit>::Module as ClientModule>::decoder()
305 }
306
307 fn module_kind(&self) -> ModuleKind {
308 <Self as ModuleInit>::Common::KIND
309 }
310
311 fn as_common(&self) -> &(dyn IDynCommonModuleInit + Send + Sync + 'static) {
312 self
313 }
314
315 fn supported_api_versions(&self) -> MultiApiVersion {
316 <Self as ClientModuleInit>::supported_api_versions(self)
317 }
318
319 async fn recover(
320 &self,
321 final_client: FinalClient,
322 federation_id: FederationId,
323 num_peers: NumPeers,
324 cfg: ClientModuleConfig,
325 db: Database,
326 instance_id: ModuleInstanceId,
327 core_api_version: ApiVersion,
328 module_api_version: ApiVersion,
329 module_root_secret: DerivableSecret,
330 notifier: Notifier,
332 api: DynGlobalApi,
333 admin_auth: Option<ApiAuth>,
334 snapshot: Option<&DynModuleBackup>,
335 progress_tx: watch::Sender<RecoveryProgress>,
336 task_group: TaskGroup,
337 ) -> anyhow::Result<()> {
338 let typed_cfg: &<<T as fedimint_core::module::ModuleInit>::Common as CommonModuleInit>::ClientConfig = cfg.cast()?;
339 let snapshot: Option<&<<Self as ClientModuleInit>::Module as ClientModule>::Backup> =
340 snapshot.map(|s| {
341 s.as_any()
342 .downcast_ref()
343 .expect("can't convert client module backup to desired type")
344 });
345
346 let (module_db, global_dbtx_access_token) = db.with_prefix_module_id(instance_id);
347 Ok(self
348 .recover(
349 &ClientModuleRecoverArgs {
350 federation_id,
351 num_peers,
352 cfg: typed_cfg.clone(),
353 db: module_db.clone(),
354 core_api_version,
355 module_api_version,
356 module_root_secret,
357 notifier: notifier.module_notifier(instance_id),
358 api: api.clone(),
359 admin_auth,
360 module_api: api.with_module(instance_id),
361 context: ClientContext {
362 client: final_client,
363 module_instance_id: instance_id,
364 global_dbtx_access_token,
365 module_db,
366 _marker: marker::PhantomData,
367 },
368 progress_tx,
369 task_group,
370 },
371 snapshot,
372 )
373 .await?)
374 }
375
376 async fn init(
377 &self,
378 final_client: FinalClient,
379 federation_id: FederationId,
380 peer_num: usize,
381 cfg: ClientModuleConfig,
382 db: Database,
383 instance_id: ModuleInstanceId,
384 core_api_version: ApiVersion,
385 module_api_version: ApiVersion,
386 module_root_secret: DerivableSecret,
387 notifier: Notifier,
389 api: DynGlobalApi,
390 admin_auth: Option<ApiAuth>,
391 task_group: TaskGroup,
392 ) -> anyhow::Result<DynClientModule> {
393 let typed_cfg: &<<T as fedimint_core::module::ModuleInit>::Common as CommonModuleInit>::ClientConfig = cfg.cast()?;
394 let (module_db, global_dbtx_access_token) = db.with_prefix_module_id(instance_id);
395 Ok(self
396 .init(&ClientModuleInitArgs {
397 federation_id,
398 peer_num,
399 cfg: typed_cfg.clone(),
400 db: module_db.clone(),
401 core_api_version,
402 module_api_version,
403 module_root_secret,
404 notifier: notifier.module_notifier(instance_id),
405 api: api.clone(),
406 admin_auth,
407 module_api: api.with_module(instance_id),
408 context: ClientContext {
409 client: final_client,
410 module_instance_id: instance_id,
411 module_db,
412 global_dbtx_access_token,
413 _marker: marker::PhantomData,
414 },
415 task_group,
416 })
417 .await?
418 .into())
419 }
420
421 fn get_database_migrations(&self) -> BTreeMap<DatabaseVersion, ClientMigrationFn> {
422 <Self as ClientModuleInit>::get_database_migrations(self)
423 }
424}
425
// Type-erased, `Clone`-able handle wrapping an `Arc` of an `IClientModuleInit`
// trait object; this is the element type of `ClientModuleInitRegistry`.
// The struct itself is generated by `fedimint_core::dyn_newtype_define!`.
dyn_newtype_define!(
    #[derive(Clone)]
    pub DynClientModuleInit(Arc<IClientModuleInit>)
);
430
431impl AsRef<dyn IDynCommonModuleInit + Send + Sync + 'static> for DynClientModuleInit {
432 fn as_ref(&self) -> &(dyn IDynCommonModuleInit + Send + Sync + 'static) {
433 self.inner.as_common()
434 }
435}
436
437impl AsRef<dyn IClientModuleInit + 'static> for DynClientModuleInit {
438 fn as_ref(&self) -> &(dyn IClientModuleInit + 'static) {
439 self.inner.as_ref()
440 }
441}