fedimint_client/transaction/
builder.rs

1use std::fmt;
2use std::ops::RangeInclusive;
3use std::sync::Arc;
4
5use bitcoin::key::Keypair;
6use bitcoin::secp256k1;
7use fedimint_core::core::{
8    DynInput, DynOutput, IInput, IOutput, IntoDynInstance, ModuleInstanceId,
9};
10use fedimint_core::encoding::{Decodable, Encodable};
11use fedimint_core::task::{MaybeSend, MaybeSync};
12use fedimint_core::transaction::{Transaction, TransactionSignature};
13use fedimint_core::Amount;
14use fedimint_logging::LOG_CLIENT;
15use itertools::multiunzip;
16use rand::{CryptoRng, Rng, RngCore};
17use secp256k1::Secp256k1;
18use tracing::warn;
19
20use crate::module::{IdxRange, OutPointRange, StateGenerator};
21use crate::sm::{self, DynState};
22use crate::{
23    states_add_instance, states_to_instanceless_dyn, InstancelessDynClientInput,
24    InstancelessDynClientInputBundle, InstancelessDynClientInputSM, InstancelessDynClientOutput,
25    InstancelessDynClientOutputBundle, InstancelessDynClientOutputSM,
26};
27
28#[derive(Clone, Debug)]
29pub struct ClientInput<I = DynInput> {
30    pub input: I,
31    pub keys: Vec<Keypair>,
32    pub amount: Amount,
33}
34
35#[derive(Clone)]
36pub struct ClientInputSM<S = DynState> {
37    pub state_machines: StateGenerator<S>,
38}
39
40impl<S> fmt::Debug for ClientInputSM<S> {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.write_str("ClientInputSM")
43    }
44}
45
46/// A fake [`sm::Context`] for [`NeverClientStateMachine`]
47#[derive(Debug, Clone, Eq, PartialEq, Hash, Decodable, Encodable)]
48pub enum NeverClientContext {}
49
50impl sm::Context for NeverClientContext {
51    const KIND: Option<fedimint_core::core::ModuleKind> = None;
52}
53
54/// A fake [`sm::State`] that can actually never happen.
55///
56/// Useful as a default for type inference in cases where there are no
57/// state machines involved in [`ClientInputBundle`].
58#[derive(Debug, Clone, Eq, PartialEq, Hash, Decodable, Encodable)]
59pub enum NeverClientStateMachine {}
60
61impl IntoDynInstance for NeverClientStateMachine {
62    type DynType = DynState;
63
64    fn into_dyn(self, _instance_id: ModuleInstanceId) -> Self::DynType {
65        unreachable!()
66    }
67}
68impl sm::State for NeverClientStateMachine {
69    type ModuleContext = NeverClientContext;
70
71    fn transitions(
72        &self,
73        _context: &Self::ModuleContext,
74        _global_context: &crate::DynGlobalClientContext,
75    ) -> Vec<sm::StateTransition<Self>> {
76        unreachable!()
77    }
78
79    fn operation_id(&self) -> fedimint_core::core::OperationId {
80        unreachable!()
81    }
82}
83
84/// A group of inputs and state machines responsible for driving their state
85///
86/// These must be kept together as a whole when including in a transaction.
87#[derive(Clone, Debug)]
88pub struct ClientInputBundle<I = DynInput, S = DynState> {
89    pub(crate) inputs: Vec<ClientInput<I>>,
90    pub(crate) sm_gens: Vec<ClientInputSM<S>>,
91}
92
93impl<I> ClientInputBundle<I, NeverClientStateMachine> {
94    /// A version of [`Self::new`] for times where input does not require any
95    /// state machines
96    ///
97    /// This avoids type inference issues of `S`, and saves some typing.
98    pub fn new_no_sm(inputs: Vec<ClientInput<I>>) -> Self {
99        if inputs.is_empty() {
100            // TODO: Make it return Result or assert?
101            warn!(target: LOG_CLIENT, "Empty input bundle will be illegal in the future");
102        }
103        Self {
104            inputs,
105            sm_gens: vec![],
106        }
107    }
108}
109
110impl<I, S> ClientInputBundle<I, S>
111where
112    I: IInput + MaybeSend + MaybeSync + 'static,
113    S: sm::IState + MaybeSend + MaybeSync + 'static,
114{
115    pub fn new(inputs: Vec<ClientInput<I>>, sm_gens: Vec<ClientInputSM<S>>) -> Self {
116        Self { inputs, sm_gens }
117    }
118
119    pub fn inputs(&self) -> &[ClientInput<I>] {
120        &self.inputs
121    }
122
123    pub fn sms(&self) -> &[ClientInputSM<S>] {
124        &self.sm_gens
125    }
126
127    pub fn into_instanceless(self) -> InstancelessDynClientInputBundle {
128        InstancelessDynClientInputBundle {
129            inputs: self
130                .inputs
131                .into_iter()
132                .map(|input| InstancelessDynClientInput {
133                    input: Box::new(input.input),
134                    keys: input.keys,
135                    amount: input.amount,
136                })
137                .collect(),
138            sm_gens: self
139                .sm_gens
140                .into_iter()
141                .map(|input_sm| InstancelessDynClientInputSM {
142                    state_machines: states_to_instanceless_dyn(input_sm.state_machines),
143                })
144                .collect(),
145        }
146    }
147}
148
149impl<I, S> ClientInputBundle<I, S> {
150    pub fn is_empty(&self) -> bool {
151        // Notably, sm_gen will not be called when inputs are empty anyway
152        self.inputs.is_empty()
153    }
154}
155
156impl<I> IntoDynInstance for ClientInput<I>
157where
158    I: IntoDynInstance<DynType = DynInput> + 'static,
159{
160    type DynType = ClientInput;
161
162    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientInput {
163        ClientInput {
164            input: self.input.into_dyn(module_instance_id),
165            keys: self.keys,
166            amount: self.amount,
167        }
168    }
169}
170
171impl<S> IntoDynInstance for ClientInputSM<S>
172where
173    S: IntoDynInstance<DynType = DynState> + 'static,
174{
175    type DynType = ClientInputSM;
176
177    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientInputSM {
178        ClientInputSM {
179            state_machines: state_gen_to_dyn(self.state_machines, module_instance_id),
180        }
181    }
182}
183
184impl<I, S> IntoDynInstance for ClientInputBundle<I, S>
185where
186    I: IntoDynInstance<DynType = DynInput> + 'static,
187    S: IntoDynInstance<DynType = DynState> + 'static,
188{
189    type DynType = ClientInputBundle;
190
191    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientInputBundle {
192        ClientInputBundle {
193            inputs: self
194                .inputs
195                .into_iter()
196                .map(|input| input.into_dyn(module_instance_id))
197                .collect::<Vec<ClientInput>>(),
198
199            sm_gens: self
200                .sm_gens
201                .into_iter()
202                .map(|input_sm| input_sm.into_dyn(module_instance_id))
203                .collect::<Vec<ClientInputSM>>(),
204        }
205    }
206}
207
208impl IntoDynInstance for InstancelessDynClientInputBundle {
209    type DynType = ClientInputBundle;
210
211    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientInputBundle {
212        ClientInputBundle {
213            inputs: self
214                .inputs
215                .into_iter()
216                .map(|input| ClientInput {
217                    input: DynInput::from_parts(module_instance_id, input.input),
218                    keys: input.keys,
219                    amount: input.amount,
220                })
221                .collect::<Vec<ClientInput>>(),
222
223            sm_gens: self
224                .sm_gens
225                .into_iter()
226                .map(|input_sm| ClientInputSM {
227                    state_machines: states_add_instance(
228                        module_instance_id,
229                        input_sm.state_machines,
230                    ),
231                })
232                .collect::<Vec<ClientInputSM>>(),
233        }
234    }
235}
236
237#[derive(Clone, Debug)]
238pub struct ClientOutputBundle<O = DynOutput, S = DynState> {
239    pub(crate) outputs: Vec<ClientOutput<O>>,
240    pub(crate) sm_gens: Vec<ClientOutputSM<S>>,
241}
242
243#[derive(Clone, Debug)]
244pub struct ClientOutput<O = DynOutput> {
245    pub output: O,
246    pub amount: Amount,
247}
248
249#[derive(Clone)]
250pub struct ClientOutputSM<S = DynState> {
251    pub state_machines: StateGenerator<S>,
252}
253
254impl<S> fmt::Debug for ClientOutputSM<S> {
255    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256        f.write_str("ClientOutputSM")
257    }
258}
259impl<O> ClientOutputBundle<O, NeverClientStateMachine> {
260    /// A version of [`Self::new`] for times where output does not require any
261    /// state machines
262    ///
263    /// This avoids type inference issues of `S`, and saves some typing.
264    pub fn new_no_sm(outputs: Vec<ClientOutput<O>>) -> Self {
265        if outputs.is_empty() {
266            // TODO: Make it return Result or assert?
267            warn!(target: LOG_CLIENT, "Empty output bundle will be illegal in the future");
268        }
269        Self {
270            outputs,
271            sm_gens: vec![],
272        }
273    }
274}
275
276impl<O, S> ClientOutputBundle<O, S>
277where
278    O: IOutput + MaybeSend + MaybeSync + 'static,
279    S: sm::IState + MaybeSend + MaybeSync + 'static,
280{
281    pub fn new(outputs: Vec<ClientOutput<O>>, sm_gens: Vec<ClientOutputSM<S>>) -> Self {
282        Self { outputs, sm_gens }
283    }
284
285    pub fn outputs(&self) -> &[ClientOutput<O>] {
286        &self.outputs
287    }
288
289    pub fn sms(&self) -> &[ClientOutputSM<S>] {
290        &self.sm_gens
291    }
292
293    pub fn with(mut self, other: Self) -> Self {
294        self.outputs.extend(other.outputs);
295        self.sm_gens.extend(other.sm_gens);
296        self
297    }
298
299    pub fn into_instanceless(self) -> InstancelessDynClientOutputBundle {
300        InstancelessDynClientOutputBundle {
301            outputs: self
302                .outputs
303                .into_iter()
304                .map(|output| InstancelessDynClientOutput {
305                    output: Box::new(output.output),
306                    amount: output.amount,
307                })
308                .collect(),
309            sm_gens: self
310                .sm_gens
311                .into_iter()
312                .map(|output_sm| InstancelessDynClientOutputSM {
313                    state_machines: states_to_instanceless_dyn(output_sm.state_machines),
314                })
315                .collect(),
316        }
317    }
318}
319
320impl<O, S> ClientOutputBundle<O, S> {
321    pub fn is_empty(&self) -> bool {
322        // Notably, sm_gen will not be called when outputs are empty anyway
323        self.outputs.is_empty()
324    }
325}
326
327impl<I, S> IntoDynInstance for ClientOutputBundle<I, S>
328where
329    I: IntoDynInstance<DynType = DynOutput> + 'static,
330    S: IntoDynInstance<DynType = DynState> + 'static,
331{
332    type DynType = ClientOutputBundle;
333
334    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientOutputBundle {
335        ClientOutputBundle {
336            outputs: self
337                .outputs
338                .into_iter()
339                .map(|output| output.into_dyn(module_instance_id))
340                .collect::<Vec<ClientOutput>>(),
341
342            sm_gens: self
343                .sm_gens
344                .into_iter()
345                .map(|output_sm| output_sm.into_dyn(module_instance_id))
346                .collect::<Vec<ClientOutputSM>>(),
347        }
348    }
349}
350
351impl IntoDynInstance for InstancelessDynClientOutputBundle {
352    type DynType = ClientOutputBundle;
353
354    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientOutputBundle {
355        ClientOutputBundle {
356            outputs: self
357                .outputs
358                .into_iter()
359                .map(|output| ClientOutput {
360                    output: DynOutput::from_parts(module_instance_id, output.output),
361                    amount: output.amount,
362                })
363                .collect::<Vec<ClientOutput>>(),
364
365            sm_gens: self
366                .sm_gens
367                .into_iter()
368                .map(|output_sm| ClientOutputSM {
369                    state_machines: states_add_instance(
370                        module_instance_id,
371                        output_sm.state_machines,
372                    ),
373                })
374                .collect::<Vec<ClientOutputSM>>(),
375        }
376    }
377}
378
379impl<I> IntoDynInstance for ClientOutput<I>
380where
381    I: IntoDynInstance<DynType = DynOutput> + 'static,
382{
383    type DynType = ClientOutput;
384
385    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientOutput {
386        ClientOutput {
387            output: self.output.into_dyn(module_instance_id),
388            amount: self.amount,
389        }
390    }
391}
392
393impl<S> IntoDynInstance for ClientOutputSM<S>
394where
395    S: IntoDynInstance<DynType = DynState> + 'static,
396{
397    type DynType = ClientOutputSM;
398
399    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientOutputSM {
400        ClientOutputSM {
401            state_machines: state_gen_to_dyn(self.state_machines, module_instance_id),
402        }
403    }
404}
405
406#[derive(Default, Clone, Debug)]
407pub struct TransactionBuilder {
408    inputs: Vec<ClientInputBundle>,
409    outputs: Vec<ClientOutputBundle>,
410}
411
412impl TransactionBuilder {
413    pub fn new() -> Self {
414        Self::default()
415    }
416
417    pub fn with_inputs(mut self, inputs: ClientInputBundle) -> Self {
418        self.inputs.push(inputs);
419        self
420    }
421
422    pub fn with_outputs(mut self, outputs: ClientOutputBundle) -> Self {
423        self.outputs.push(outputs);
424        self
425    }
426
427    pub fn build<C, R: RngCore + CryptoRng>(
428        self,
429        secp_ctx: &Secp256k1<C>,
430        mut rng: R,
431    ) -> (Transaction, Vec<DynState>)
432    where
433        C: secp256k1::Signing + secp256k1::Verification,
434    {
435        // `input_idx_to_bundle_idx[input_idx]` stores the index of a bundle the input
436        // at `input_idx` comes from, so we can call state machines of the
437        // corresponding bundle for every input bundle. It is always
438        // monotonically increasing, e.g. `[0, 0, 1, 2, 2, 2, 4]`
439        let (input_idx_to_bundle_idx, inputs, input_keys): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(
440            self.inputs
441                .iter()
442                .enumerate()
443                .flat_map(|(bundle_idx, bundle)| {
444                    bundle
445                        .inputs
446                        .iter()
447                        .map(move |input| (bundle_idx, input.input.clone(), input.keys.clone()))
448                }),
449        );
450        // `output_idx_to_bundle` works exactly like `input_idx_to_bundle_idx` above,
451        // but for outputs.
452        let (output_idx_to_bundle_idx, outputs): (Vec<_>, Vec<_>) = multiunzip(
453            self.outputs
454                .iter()
455                .enumerate()
456                .flat_map(|(bundle_idx, bundle)| {
457                    bundle
458                        .outputs
459                        .iter()
460                        .map(move |output| (bundle_idx, output.output.clone()))
461                }),
462        );
463        let nonce: [u8; 8] = rng.gen();
464
465        let txid = Transaction::tx_hash_from_parts(&inputs, &outputs, nonce);
466        let msg = secp256k1::Message::from_digest_slice(&txid[..]).expect("txid has right length");
467
468        let signatures = input_keys
469            .iter()
470            .flatten()
471            .map(|keypair| secp_ctx.sign_schnorr(&msg, keypair))
472            .collect();
473
474        let transaction = Transaction {
475            inputs,
476            outputs,
477            nonce,
478            signatures: TransactionSignature::NaiveMultisig(signatures),
479        };
480
481        let input_states = self
482            .inputs
483            .into_iter()
484            .enumerate()
485            .filter(|(_, bundle)| !bundle.is_empty())
486            .flat_map(|(bundle_idx, bundle)| {
487                let input_idxs = find_range_of_matching_items(&input_idx_to_bundle_idx, bundle_idx)
488                    .expect("Non empty bundles must always have a match");
489                bundle.sm_gens.into_iter().flat_map(move |sm| {
490                    (sm.state_machines)(OutPointRange::new(
491                        txid,
492                        IdxRange::from_inclusive(input_idxs.clone()).expect("can't overflow"),
493                    ))
494                })
495            });
496
497        let output_states = self
498            .outputs
499            .into_iter()
500            .enumerate()
501            .filter(|(_, bundle)| !bundle.is_empty())
502            .flat_map(|(bundle_idx, bundle)| {
503                let output_idxs =
504                    find_range_of_matching_items(&output_idx_to_bundle_idx, bundle_idx)
505                        .expect("Non empty bundles must always have a match");
506                bundle.sm_gens.into_iter().flat_map(move |sm| {
507                    (sm.state_machines)(OutPointRange::new(
508                        txid,
509                        IdxRange::from_inclusive(output_idxs.clone())
510                            .expect("can't possibly overflow"),
511                    ))
512                })
513            });
514        (transaction, input_states.chain(output_states).collect())
515    }
516
517    pub(crate) fn inputs(&self) -> impl Iterator<Item = &ClientInput> {
518        self.inputs.iter().flat_map(|i| i.inputs.iter())
519    }
520
521    pub(crate) fn outputs(&self) -> impl Iterator<Item = &ClientOutput> {
522        self.outputs.iter().flat_map(|i| i.outputs.iter())
523    }
524}
525
/// Find the range of indexes in a monotonically increasing `arr` whose
/// elements are equal to `item`
///
/// Returns `None` when `arr` does not contain `item` (which includes an empty
/// `arr`). Because `arr` is sorted, all matches are adjacent, so both ends of
/// the range are located with a binary search instead of scanning the whole
/// slice.
///
/// `arr` must be monotonically increasing (equal neighbours are allowed);
/// this is only checked in debug builds, and the result is unspecified
/// otherwise.
fn find_range_of_matching_items(arr: &[usize], item: usize) -> Option<RangeInclusive<u64>> {
    // `arr` must be monotonically increasing
    debug_assert!(arr.windows(2).all(|w| w[0] <= w[1]));

    // Index of the first element that is not smaller than `item`
    let start = arr.partition_point(|&arr_item| arr_item < item);
    // Number of elements equal to `item`; everything from `start` on is
    // `>= item`, so the matches form a prefix of that tail
    let matching = arr[start..].partition_point(|&arr_item| arr_item == item);

    // `usize` is at most 64 bits wide on supported targets, so the casts
    // below are lossless
    (matching != 0).then(|| start as u64..=(start + matching - 1) as u64)
}
540
/// Sanity checks for [`find_range_of_matching_items`]: matches at the start
/// and at the end of the slice, a missing item and an empty slice.
#[test]
fn find_range_of_matching_items_sanity() {
    /// Asserts that looking up `item` in `arr` yields `expected`
    fn check(arr: &[usize], item: usize, expected: Option<RangeInclusive<u64>>) {
        assert_eq!(find_range_of_matching_items(arr, item), expected);
    }

    check(&[0, 0], 0, Some(0..=1));
    check(&[0, 0, 1], 0, Some(0..=1));
    check(&[0, 0, 1], 1, Some(2..=2));
    check(&[0, 0, 1], 2, None);
    check(&[], 0, None);
}
549
550fn state_gen_to_dyn<S>(
551    state_gen: StateGenerator<S>,
552    module_instance: ModuleInstanceId,
553) -> StateGenerator<DynState>
554where
555    S: IntoDynInstance<DynType = DynState> + 'static,
556{
557    Arc::new(move |out_point_range| {
558        let states = state_gen(out_point_range);
559        states
560            .into_iter()
561            .map(|state| state.into_dyn(module_instance))
562            .collect()
563    })
564}
565
566#[cfg(test)]
567mod tests;