// fedimint_client/transaction/builder.rs

1use std::fmt;
2use std::ops::RangeInclusive;
3use std::sync::Arc;
4
5use bitcoin::key::Keypair;
6use bitcoin::secp256k1;
7use fedimint_core::core::{
8    DynInput, DynOutput, IInput, IOutput, IntoDynInstance, ModuleInstanceId,
9};
10use fedimint_core::encoding::{Decodable, Encodable};
11use fedimint_core::task::{MaybeSend, MaybeSync};
12use fedimint_core::transaction::{Transaction, TransactionSignature};
13use fedimint_core::Amount;
14use itertools::multiunzip;
15use rand::{CryptoRng, Rng, RngCore};
16use secp256k1::Secp256k1;
17
18use crate::module::{IdxRange, OutPointRange, StateGenerator};
19use crate::sm::{self, DynState};
20use crate::{
21    states_add_instance, states_to_instanceless_dyn, InstancelessDynClientInput,
22    InstancelessDynClientInputBundle, InstancelessDynClientInputSM, InstancelessDynClientOutput,
23    InstancelessDynClientOutputBundle, InstancelessDynClientOutputSM,
24};
25
26#[derive(Clone, Debug)]
27pub struct ClientInput<I = DynInput> {
28    pub input: I,
29    pub keys: Vec<Keypair>,
30    pub amount: Amount,
31}
32
33#[derive(Clone)]
34pub struct ClientInputSM<S = DynState> {
35    pub state_machines: StateGenerator<S>,
36}
37
38impl<S> fmt::Debug for ClientInputSM<S> {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        f.write_str("ClientInputSM")
41    }
42}
43
44/// A fake [`sm::Context`] for [`NeverClientStateMachine`]
45#[derive(Debug, Clone, Eq, PartialEq, Hash, Decodable, Encodable)]
46pub enum NeverClientContext {}
47
48impl sm::Context for NeverClientContext {
49    const KIND: Option<fedimint_core::core::ModuleKind> = None;
50}
51
52/// A fake [`sm::State`] that can actually never happen.
53///
54/// Useful as a default for type inference in cases where there are no
55/// state machines involved in [`ClientInputBundle`].
56#[derive(Debug, Clone, Eq, PartialEq, Hash, Decodable, Encodable)]
57pub enum NeverClientStateMachine {}
58
59impl IntoDynInstance for NeverClientStateMachine {
60    type DynType = DynState;
61
62    fn into_dyn(self, _instance_id: ModuleInstanceId) -> Self::DynType {
63        unreachable!()
64    }
65}
66impl sm::State for NeverClientStateMachine {
67    type ModuleContext = NeverClientContext;
68
69    fn transitions(
70        &self,
71        _context: &Self::ModuleContext,
72        _global_context: &crate::DynGlobalClientContext,
73    ) -> Vec<sm::StateTransition<Self>> {
74        unreachable!()
75    }
76
77    fn operation_id(&self) -> fedimint_core::core::OperationId {
78        unreachable!()
79    }
80}
81
82/// A group of inputs and state machines responsible for driving their state
83///
84/// These must be kept together as a whole when including in a transaction.
85#[derive(Clone, Debug)]
86pub struct ClientInputBundle<I = DynInput, S = DynState> {
87    pub(crate) inputs: Vec<ClientInput<I>>,
88    pub(crate) sms: Vec<ClientInputSM<S>>,
89}
90
91impl<I> ClientInputBundle<I, NeverClientStateMachine> {
92    /// A version of [`Self::new`] for times where input does not require any
93    /// state machines
94    ///
95    /// This avoids type inference issues of `S`, and saves some typing.
96    pub fn new_no_sm(inputs: Vec<ClientInput<I>>) -> Self {
97        Self {
98            inputs,
99            sms: vec![],
100        }
101    }
102}
103
104impl<I, S> ClientInputBundle<I, S>
105where
106    I: IInput + MaybeSend + MaybeSync + 'static,
107    S: sm::IState + MaybeSend + MaybeSync + 'static,
108{
109    pub fn new(inputs: Vec<ClientInput<I>>, sms: Vec<ClientInputSM<S>>) -> Self {
110        Self { inputs, sms }
111    }
112
113    pub fn inputs(&self) -> &[ClientInput<I>] {
114        &self.inputs
115    }
116
117    pub fn sms(&self) -> &[ClientInputSM<S>] {
118        &self.sms
119    }
120
121    pub fn into_instanceless(self) -> InstancelessDynClientInputBundle {
122        InstancelessDynClientInputBundle {
123            inputs: self
124                .inputs
125                .into_iter()
126                .map(|input| InstancelessDynClientInput {
127                    input: Box::new(input.input),
128                    keys: input.keys,
129                    amount: input.amount,
130                })
131                .collect(),
132            sms: self
133                .sms
134                .into_iter()
135                .map(|input_sm| InstancelessDynClientInputSM {
136                    state_machines: states_to_instanceless_dyn(input_sm.state_machines),
137                })
138                .collect(),
139        }
140    }
141}
142
143impl<I> IntoDynInstance for ClientInput<I>
144where
145    I: IntoDynInstance<DynType = DynInput> + 'static,
146{
147    type DynType = ClientInput;
148
149    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientInput {
150        ClientInput {
151            input: self.input.into_dyn(module_instance_id),
152            keys: self.keys,
153            amount: self.amount,
154        }
155    }
156}
157
158impl<S> IntoDynInstance for ClientInputSM<S>
159where
160    S: IntoDynInstance<DynType = DynState> + 'static,
161{
162    type DynType = ClientInputSM;
163
164    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientInputSM {
165        ClientInputSM {
166            state_machines: state_gen_to_dyn(self.state_machines, module_instance_id),
167        }
168    }
169}
170
171impl<I, S> IntoDynInstance for ClientInputBundle<I, S>
172where
173    I: IntoDynInstance<DynType = DynInput> + 'static,
174    S: IntoDynInstance<DynType = DynState> + 'static,
175{
176    type DynType = ClientInputBundle;
177
178    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientInputBundle {
179        ClientInputBundle {
180            inputs: self
181                .inputs
182                .into_iter()
183                .map(|input| input.into_dyn(module_instance_id))
184                .collect::<Vec<ClientInput>>(),
185
186            sms: self
187                .sms
188                .into_iter()
189                .map(|input_sm| input_sm.into_dyn(module_instance_id))
190                .collect::<Vec<ClientInputSM>>(),
191        }
192    }
193}
194
195impl IntoDynInstance for InstancelessDynClientInputBundle {
196    type DynType = ClientInputBundle;
197
198    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientInputBundle {
199        ClientInputBundle {
200            inputs: self
201                .inputs
202                .into_iter()
203                .map(|input| ClientInput {
204                    input: DynInput::from_parts(module_instance_id, input.input),
205                    keys: input.keys,
206                    amount: input.amount,
207                })
208                .collect::<Vec<ClientInput>>(),
209
210            sms: self
211                .sms
212                .into_iter()
213                .map(|input_sm| ClientInputSM {
214                    state_machines: states_add_instance(
215                        module_instance_id,
216                        input_sm.state_machines,
217                    ),
218                })
219                .collect::<Vec<ClientInputSM>>(),
220        }
221    }
222}
223
224#[derive(Clone, Debug)]
225pub struct ClientOutputBundle<O = DynOutput, S = DynState> {
226    pub(crate) outputs: Vec<ClientOutput<O>>,
227    pub(crate) sms: Vec<ClientOutputSM<S>>,
228}
229
230#[derive(Clone, Debug)]
231pub struct ClientOutput<O = DynOutput> {
232    pub output: O,
233    pub amount: Amount,
234}
235
236#[derive(Clone)]
237pub struct ClientOutputSM<S = DynState> {
238    pub state_machines: StateGenerator<S>,
239}
240
241impl<S> fmt::Debug for ClientOutputSM<S> {
242    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243        f.write_str("ClientOutputSM")
244    }
245}
246impl<O> ClientOutputBundle<O, NeverClientStateMachine> {
247    /// A version of [`Self::new`] for times where output does not require any
248    /// state machines
249    ///
250    /// This avoids type inference issues of `S`, and saves some typing.
251    pub fn new_no_sm(outputs: Vec<ClientOutput<O>>) -> Self {
252        Self {
253            outputs,
254            sms: vec![],
255        }
256    }
257}
258
259impl<O, S> ClientOutputBundle<O, S>
260where
261    O: IOutput + MaybeSend + MaybeSync + 'static,
262    S: sm::IState + MaybeSend + MaybeSync + 'static,
263{
264    pub fn new(outputs: Vec<ClientOutput<O>>, sms: Vec<ClientOutputSM<S>>) -> Self {
265        Self { outputs, sms }
266    }
267
268    pub fn outputs(&self) -> &[ClientOutput<O>] {
269        &self.outputs
270    }
271
272    pub fn sms(&self) -> &[ClientOutputSM<S>] {
273        &self.sms
274    }
275
276    pub fn with(mut self, other: Self) -> Self {
277        self.outputs.extend(other.outputs);
278        self.sms.extend(other.sms);
279        self
280    }
281
282    pub fn into_instanceless(self) -> InstancelessDynClientOutputBundle {
283        InstancelessDynClientOutputBundle {
284            outputs: self
285                .outputs
286                .into_iter()
287                .map(|output| InstancelessDynClientOutput {
288                    output: Box::new(output.output),
289                    amount: output.amount,
290                })
291                .collect(),
292            sms: self
293                .sms
294                .into_iter()
295                .map(|output_sm| InstancelessDynClientOutputSM {
296                    state_machines: states_to_instanceless_dyn(output_sm.state_machines),
297                })
298                .collect(),
299        }
300    }
301}
302
303impl<I, S> IntoDynInstance for ClientOutputBundle<I, S>
304where
305    I: IntoDynInstance<DynType = DynOutput> + 'static,
306    S: IntoDynInstance<DynType = DynState> + 'static,
307{
308    type DynType = ClientOutputBundle;
309
310    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientOutputBundle {
311        ClientOutputBundle {
312            outputs: self
313                .outputs
314                .into_iter()
315                .map(|output| output.into_dyn(module_instance_id))
316                .collect::<Vec<ClientOutput>>(),
317
318            sms: self
319                .sms
320                .into_iter()
321                .map(|output_sm| output_sm.into_dyn(module_instance_id))
322                .collect::<Vec<ClientOutputSM>>(),
323        }
324    }
325}
326
327impl IntoDynInstance for InstancelessDynClientOutputBundle {
328    type DynType = ClientOutputBundle;
329
330    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientOutputBundle {
331        ClientOutputBundle {
332            outputs: self
333                .outputs
334                .into_iter()
335                .map(|output| ClientOutput {
336                    output: DynOutput::from_parts(module_instance_id, output.output),
337                    amount: output.amount,
338                })
339                .collect::<Vec<ClientOutput>>(),
340
341            sms: self
342                .sms
343                .into_iter()
344                .map(|output_sm| ClientOutputSM {
345                    state_machines: states_add_instance(
346                        module_instance_id,
347                        output_sm.state_machines,
348                    ),
349                })
350                .collect::<Vec<ClientOutputSM>>(),
351        }
352    }
353}
354
355impl<I> IntoDynInstance for ClientOutput<I>
356where
357    I: IntoDynInstance<DynType = DynOutput> + 'static,
358{
359    type DynType = ClientOutput;
360
361    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientOutput {
362        ClientOutput {
363            output: self.output.into_dyn(module_instance_id),
364            amount: self.amount,
365        }
366    }
367}
368
369impl<S> IntoDynInstance for ClientOutputSM<S>
370where
371    S: IntoDynInstance<DynType = DynState> + 'static,
372{
373    type DynType = ClientOutputSM;
374
375    fn into_dyn(self, module_instance_id: ModuleInstanceId) -> ClientOutputSM {
376        ClientOutputSM {
377            state_machines: state_gen_to_dyn(self.state_machines, module_instance_id),
378        }
379    }
380}
381
382#[derive(Default, Clone, Debug)]
383pub struct TransactionBuilder {
384    inputs: Vec<ClientInputBundle>,
385    outputs: Vec<ClientOutputBundle>,
386}
387
388impl TransactionBuilder {
389    pub fn new() -> Self {
390        Self::default()
391    }
392
393    pub fn with_inputs(mut self, inputs: ClientInputBundle) -> Self {
394        self.inputs.push(inputs);
395        self
396    }
397
398    pub fn with_outputs(mut self, outputs: ClientOutputBundle) -> Self {
399        self.outputs.push(outputs);
400        self
401    }
402
403    pub fn build<C, R: RngCore + CryptoRng>(
404        self,
405        secp_ctx: &Secp256k1<C>,
406        mut rng: R,
407    ) -> (Transaction, Vec<DynState>)
408    where
409        C: secp256k1::Signing + secp256k1::Verification,
410    {
411        // `input_idx_to_bundle_idx[input_idx]` stores the index of a bundle the input
412        // at `input_idx` comes from, so we can call state machines of the
413        // corresponding bundle for every input bundle. It is always
414        // monotonically increasing, e.g. `[0, 0, 1, 2, 2, 2, 4]`
415        let (input_idx_to_bundle_idx, inputs, input_keys): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(
416            self.inputs
417                .iter()
418                .enumerate()
419                .flat_map(|(bundle_idx, bundle)| {
420                    bundle
421                        .inputs
422                        .iter()
423                        .map(move |input| (bundle_idx, input.input.clone(), input.keys.clone()))
424                }),
425        );
426        // `output_idx_to_bundle` works exactly like `input_idx_to_bundle_idx` above,
427        // but for outputs.
428        let (output_idx_to_bundle_idx, outputs): (Vec<_>, Vec<_>) = multiunzip(
429            self.outputs
430                .iter()
431                .enumerate()
432                .flat_map(|(bundle_idx, bundle)| {
433                    bundle
434                        .outputs
435                        .iter()
436                        .map(move |output| (bundle_idx, output.output.clone()))
437                }),
438        );
439        let nonce: [u8; 8] = rng.gen();
440
441        let txid = Transaction::tx_hash_from_parts(&inputs, &outputs, nonce);
442        let msg = secp256k1::Message::from_digest_slice(&txid[..]).expect("txid has right length");
443
444        let signatures = input_keys
445            .iter()
446            .flatten()
447            .map(|keypair| secp_ctx.sign_schnorr(&msg, keypair))
448            .collect();
449
450        let transaction = Transaction {
451            inputs,
452            outputs,
453            nonce,
454            signatures: TransactionSignature::NaiveMultisig(signatures),
455        };
456
457        let input_states = self
458            .inputs
459            .into_iter()
460            .enumerate()
461            .flat_map(|(bundle_idx, bundle)| {
462                let input_idxs = find_range_of_matching_items(&input_idx_to_bundle_idx, bundle_idx);
463                bundle.sms.into_iter().flat_map(move |sm| {
464                    if let Some(input_idxs) = input_idxs.as_ref() {
465                        (sm.state_machines)(OutPointRange::new(
466                            txid,
467                            IdxRange::from(input_idxs.clone()),
468                        ))
469                    } else {
470                        vec![]
471                    }
472                })
473            });
474
475        let output_states =
476            self.outputs
477                .into_iter()
478                .enumerate()
479                .flat_map(|(bundle_idx, bundle)| {
480                    let output_idxs =
481                        find_range_of_matching_items(&output_idx_to_bundle_idx, bundle_idx);
482                    bundle.sms.into_iter().flat_map(move |sm| {
483                        if let Some(output_idxs) = output_idxs.as_ref() {
484                            (sm.state_machines)(OutPointRange::new(
485                                txid,
486                                IdxRange::from(output_idxs.clone()),
487                            ))
488                        } else {
489                            vec![]
490                        }
491                    })
492                });
493        (transaction, input_states.chain(output_states).collect())
494    }
495
496    pub(crate) fn inputs(&self) -> impl Iterator<Item = &ClientInput> {
497        self.inputs.iter().flat_map(|i| i.inputs.iter())
498    }
499
500    pub(crate) fn outputs(&self) -> impl Iterator<Item = &ClientOutput> {
501        self.outputs.iter().flat_map(|i| i.outputs.iter())
502    }
503}
504
/// Finds the inclusive range of indexes in the monotonically increasing `arr`
/// whose elements are equal to `item`.
///
/// Returns `None` when `item` does not occur in `arr` (including when `arr` is
/// empty). Because `arr` is sorted, all matches form one contiguous run, which
/// is located with two binary searches in `O(log n)` instead of a full scan.
/// This matters because the function is called once per bundle.
fn find_range_of_matching_items(arr: &[usize], item: usize) -> Option<RangeInclusive<u64>> {
    // `arr` must be monotonically increasing
    debug_assert!(arr.windows(2).all(|w| w[0] <= w[1]));

    // `start`: first index whose element is not below `item`.
    // `end`: first index whose element is above `item`.
    let start = arr.partition_point(|&x| x < item);
    let end = arr.partition_point(|&x| x <= item);

    // `end - 1` is evaluated only when the run is non-empty, so it cannot underflow.
    if start < end {
        Some(start as u64..=(end - 1) as u64)
    } else {
        None
    }
}
519
#[test]
fn find_range_of_matching_items_sanity() {
    assert_eq!(find_range_of_matching_items(&[0, 0], 0), Some(0..=1));
    assert_eq!(find_range_of_matching_items(&[0, 0, 1], 0), Some(0..=1));
    assert_eq!(find_range_of_matching_items(&[0, 0, 1], 1), Some(2..=2));
    assert_eq!(find_range_of_matching_items(&[0, 0, 1], 2), None);
    assert_eq!(find_range_of_matching_items(&[], 0), None);

    // A bundle that contributed no items (bundle 3 here) has no range, while
    // the bundles on either side keep their exact ranges.
    let with_gap = [0, 0, 1, 2, 2, 2, 4];
    assert_eq!(find_range_of_matching_items(&with_gap, 2), Some(3..=5));
    assert_eq!(find_range_of_matching_items(&with_gap, 3), None);
    assert_eq!(find_range_of_matching_items(&with_gap, 4), Some(6..=6));

    // An item smaller than every element.
    assert_eq!(find_range_of_matching_items(&[1, 1], 0), None);
}
528
529fn state_gen_to_dyn<S>(
530    state_gen: StateGenerator<S>,
531    module_instance: ModuleInstanceId,
532) -> StateGenerator<DynState>
533where
534    S: IntoDynInstance<DynType = DynState> + 'static,
535{
536    Arc::new(move |out_point_range| {
537        let states = state_gen(out_point_range);
538        states
539            .into_iter()
540            .map(|state| state.into_dyn(module_instance))
541            .collect()
542    })
543}