1use std::collections::{BTreeMap, HashMap};
2use std::fmt::Debug;
3use std::hash::Hash;
4use std::io::Write;
5
6use anyhow::{ensure, format_err};
7use async_trait::async_trait;
8use bitcoin::hashes::sha256::{Hash as Sha256, HashEngine};
9use bitcoin::hashes::Hash as BitcoinHash;
10use bitcoin::secp256k1;
11use bls12_381::Scalar;
12use fedimint_core::config::{
13 DkgError, DkgGroup, DkgMessage, DkgPeerMsg, DkgResult, ISupportedDkgMessage,
14};
15use fedimint_core::core::{Decoder, ModuleInstanceId, ModuleKind};
16use fedimint_core::encoding::{Decodable, Encodable};
17use fedimint_core::module::registry::ModuleDecoderRegistry;
18use fedimint_core::module::PeerHandle;
19use fedimint_core::net::peers::MuxPeerConnections;
20use fedimint_core::runtime::spawn;
21use fedimint_core::{NumPeersExt, PeerId};
22use rand::rngs::OsRng;
23use rand::{RngCore, SeedableRng};
24use rand_chacha::ChaChaRng;
25use serde::de::DeserializeOwned;
26use serde::Serialize;
27use sha3::Digest;
28use threshold_crypto::ff::Field;
29use threshold_crypto::group::Curve;
30use threshold_crypto::poly::Commitment;
31use threshold_crypto::serde_impl::SerdeSecret;
32use threshold_crypto::{
33 G1Affine, G1Projective, G2Affine, G2Projective, PublicKeySet, SecretKeyShare,
34};
35
/// State of one distributed key generation run, as seen by the participant `our_id`.
///
/// Every per-peer map below is filled round by round; a round is complete once it holds one
/// entry per element of `peers` (our own entry included).
struct Dkg<G> {
    /// Group generator `g` the keys are generated over.
    gen_g: G,
    /// All participants of this DKG, including ourselves.
    peers: Vec<PeerId>,
    /// Our own participant id.
    our_id: PeerId,
    /// Number of polynomial coefficients (polynomial degree + 1) every participant must use.
    threshold: usize,
    /// Coefficients of our secret polynomial `f1`; its evaluations become the key shares.
    f1_poly: Vec<Scalar>,
    /// Coefficients of our blinding polynomial `f2`, used only inside the commitment.
    f2_poly: Vec<Scalar>,
    /// Round 1: hash of each participant's commitment.
    hashed_commits: BTreeMap<PeerId, Sha256>,
    /// Round 2: each participant's per-coefficient commitment `g * f1_i + h * f2_i`.
    commitments: BTreeMap<PeerId, Vec<G>>,
    /// Round 3: verified evaluation of each participant's `f1` at our own point.
    sk_shares: BTreeMap<PeerId, Scalar>,
    /// Round 4: each participant's `g * f1_i` coefficients, combined into the public key set.
    pk_shares: BTreeMap<PeerId, Vec<G>>,
}
48
49impl<G: DkgGroup> Dkg<G> {
56 pub fn new(
58 group: G,
59 our_id: PeerId,
60 peers: Vec<PeerId>,
61 threshold: usize,
62 rng: &mut impl rand::RngCore,
63 ) -> (Self, DkgStep<G>) {
64 let f1_poly = random_scalar_coefficients(threshold - 1, rng);
65 let f2_poly = random_scalar_coefficients(threshold - 1, rng);
66
67 let mut dkg = Dkg {
68 gen_g: group,
69 peers,
70 our_id,
71 threshold,
72 f1_poly,
73 f2_poly,
74 hashed_commits: BTreeMap::new(),
75 commitments: BTreeMap::new(),
76 sk_shares: BTreeMap::new(),
77 pk_shares: BTreeMap::new(),
78 };
79
80 let commit: Vec<G> = dkg
82 .f1_poly
83 .iter()
84 .map(|c| dkg.gen_g * *c)
85 .zip(dkg.f2_poly.iter().map(|c| dkg.gen_h() * *c))
86 .map(|(g, h)| g + h)
87 .collect();
88
89 let hashed = Dkg::hash(&commit);
90 dkg.commitments.insert(our_id, commit);
91 dkg.hashed_commits.insert(our_id, hashed);
92 let step = dkg.broadcast(&DkgMessage::HashedCommit(hashed));
93
94 (dkg, step)
95 }
96
97 pub fn step(&mut self, peer: PeerId, msg: DkgMessage<G>) -> anyhow::Result<DkgStep<G>> {
99 match msg {
100 DkgMessage::HashedCommit(hashed) => {
101 match self.hashed_commits.get(&peer) {
102 Some(old) if *old != hashed => {
103 return Err(format_err!("{peer} sent us two hashes!"))
104 }
105 _ => self.hashed_commits.insert(peer, hashed),
106 };
107
108 if self.hashed_commits.len() == self.peers.len() {
109 let our_commit = self.commitments[&self.our_id].clone();
110 return Ok(self.broadcast(&DkgMessage::Commit(our_commit)));
111 }
112 }
113 DkgMessage::Commit(commit) => {
114 let hash = Self::hash(&commit);
115 ensure!(self.threshold == commit.len(), "wrong degree from {peer}");
116 ensure!(hash == self.hashed_commits[&peer], "wrong hash from {peer}");
117
118 match self.commitments.get(&peer) {
119 Some(old) if *old != commit => {
120 return Err(format_err!("{peer} sent us two commitments!"))
121 }
122 _ => self.commitments.insert(peer, commit),
123 };
124
125 if self.commitments.len() == self.peers.len() {
127 let mut messages = vec![];
128 for peer in &self.peers {
129 let s1 = evaluate_polynomial_scalar(&self.f1_poly, &scalar(peer));
130 let s2 = evaluate_polynomial_scalar(&self.f2_poly, &scalar(peer));
131
132 if *peer == self.our_id {
133 self.sk_shares.insert(self.our_id, s1);
134 } else {
135 messages.push((*peer, DkgMessage::Share(s1, s2)));
136 }
137 }
138 return Ok(DkgStep::Messages(messages));
139 }
140 }
141 DkgMessage::Share(s1, s2) => {
143 let share_product = (self.gen_g * s1) + (self.gen_h() * s2);
144 let commitment = self
145 .commitments
146 .get(&peer)
147 .ok_or_else(|| format_err!("{peer} sent share before commit"))?;
148 let commit_product: G = commitment
149 .iter()
150 .enumerate()
151 .map(|(idx, commit)| *commit * scalar(&self.our_id).pow(&[idx as u64, 0, 0, 0]))
152 .reduce(|a, b| a + b)
153 .expect("sums");
154
155 ensure!(share_product == commit_product, "bad commit from {peer}");
156 match self.sk_shares.get(&peer) {
157 Some(old) if *old != s1 => {
158 return Err(format_err!("{peer} sent us two shares!"))
159 }
160 _ => self.sk_shares.insert(peer, s1),
161 };
162
163 if self.sk_shares.len() == self.peers.len() {
164 let extract: Vec<G> = self.f1_poly.iter().map(|c| self.gen_g * *c).collect();
165
166 self.pk_shares.insert(self.our_id, extract.clone());
167 return Ok(self.broadcast(&DkgMessage::Extract(extract)));
168 }
169 }
170 DkgMessage::Extract(extract) => {
172 let share = self
173 .sk_shares
174 .get(&peer)
175 .ok_or_else(|| format_err!("{peer} sent extract before share"))?;
176 let share_product = self.gen_g * *share;
177 let extract_product: G = extract
178 .iter()
179 .enumerate()
180 .map(|(idx, commit)| *commit * scalar(&self.our_id).pow(&[idx as u64, 0, 0, 0]))
181 .reduce(|a, b| a + b)
182 .expect("sums");
183
184 ensure!(share_product == extract_product, "bad extract from {peer}");
185 ensure!(self.threshold == extract.len(), "wrong degree from {peer}");
186 match self.pk_shares.get(&peer) {
187 Some(old) if *old != extract => {
188 return Err(format_err!("{peer} sent us two extracts!"))
189 }
190 _ => self.pk_shares.insert(peer, extract),
191 };
192
193 if self.pk_shares.len() == self.peers.len() {
194 let sks = self.sk_shares.values().sum();
195
196 let pks: Vec<G> = (0..self.threshold)
197 .map(|idx| {
198 self.pk_shares
199 .values()
200 .map(|shares| *shares.get(idx).unwrap())
201 .reduce(|a, b| a + b)
202 .expect("sums")
203 })
204 .collect();
205
206 return Ok(DkgStep::Result(DkgKeys {
207 public_key_set: pks,
208 secret_key_share: sks,
209 }));
210 }
211 }
212 }
213
214 Ok(DkgStep::Messages(vec![]))
215 }
216
217 fn hash(poly: &[G]) -> Sha256 {
218 let mut engine = HashEngine::default();
219 for element in poly {
220 engine
221 .write_all(element.to_bytes().as_ref())
222 .expect("hashes");
223 }
224 Sha256::from_engine(engine)
225 }
226
227 fn broadcast(&self, msg: &DkgMessage<G>) -> DkgStep<G> {
228 let others = self.peers.iter().filter(|p| **p != self.our_id);
229 DkgStep::Messages(others.map(|peer| (*peer, msg.clone())).collect())
230 }
231
232 fn gen_h(&self) -> G {
234 let mut hash_engine = sha3::Sha3_256::new();
235
236 hash_engine.update(self.gen_g.clone().to_bytes().as_ref());
237
238 G::random(&mut ChaChaRng::from_seed(hash_engine.finalize().into()))
239 }
240}
241
242pub fn scalar(peer: &PeerId) -> Scalar {
244 Scalar::from(peer.to_usize() as u64 + 1)
245}
246
/// Runs one or more DKGs concurrently over the peer connections, each identified by a key of
/// type `T`.
pub struct DkgRunner<T> {
    /// All participants of every DKG, including ourselves.
    peers: Vec<PeerId>,
    /// Our own participant id.
    our_id: PeerId,
    /// Threshold (number of polynomial coefficients) to use for each DKG, keyed by its id.
    dkg_config: HashMap<T, usize>,
}
252
253impl<T> DkgRunner<T>
257where
258 T: Serialize + DeserializeOwned + Unpin + Send + Clone + Eq + Hash,
259{
260 pub fn multi(keys: Vec<T>, threshold: usize, our_id: &PeerId, peers: &[PeerId]) -> Self {
262 let dkg_config = keys.into_iter().map(|key| (key, threshold)).collect();
263
264 Self {
265 our_id: *our_id,
266 peers: peers.to_vec(),
267 dkg_config,
268 }
269 }
270
271 pub fn new(key: T, threshold: usize, our_id: &PeerId, peers: &[PeerId]) -> Self {
273 Self::multi(vec![key], threshold, our_id, peers)
274 }
275
276 pub fn add(&mut self, key: T, threshold: usize) {
278 self.dkg_config.insert(key, threshold);
279 }
280
281 pub async fn run_g2(
283 &mut self,
284 module_id: ModuleInstanceId,
285 connections: &MuxPeerConnections<(ModuleInstanceId, String), DkgPeerMsg>,
286 ) -> DkgResult<HashMap<T, DkgKeys<G2Projective>>> {
287 self.run(module_id, G2Projective::generator(), connections)
288 .await
289 }
290
291 pub async fn run_g1(
293 &mut self,
294 module_id: ModuleInstanceId,
295 connections: &MuxPeerConnections<(ModuleInstanceId, String), DkgPeerMsg>,
296 ) -> DkgResult<HashMap<T, DkgKeys<G1Projective>>> {
297 self.run(module_id, G1Projective::generator(), connections)
298 .await
299 }
300
301 pub async fn run<G: DkgGroup>(
306 &mut self,
307 module_id: ModuleInstanceId,
308 group: G,
309 connections: &MuxPeerConnections<(ModuleInstanceId, String), DkgPeerMsg>,
310 ) -> DkgResult<HashMap<T, DkgKeys<G>>>
311 where
312 DkgMessage<G>: ISupportedDkgMessage,
313 {
314 let (send, mut receive) = tokio::sync::mpsc::channel(10_000);
316
317 self.dkg_config
319 .clone()
320 .into_iter()
321 .for_each(|(key, threshold)| {
322 let our_id = self.our_id;
323 let peers = self.peers.clone();
324 let connections = connections.clone();
325 let key = serde_json::to_string(&key).expect("serialization can't fail");
326 let send = send.clone();
327
328 spawn("dkg runner", async move {
329 let (dkg, step) = Dkg::new(group, our_id, peers, threshold, &mut OsRng);
330 let result =
331 Self::run_dkg_key((module_id, key.clone()), connections, dkg, step).await;
332 send.send((key, result)).await.expect("channel open");
333 });
334 });
335
336 let mut results: HashMap<T, DkgKeys<G>> = HashMap::new();
338 while results.len() < self.dkg_config.len() {
339 let (key, result) = receive.recv().await.expect("channel open");
340 let key = serde_json::from_str(&key).expect("serialization can't fail");
341 results.insert(key, result?);
342 }
343 Ok(results)
344 }
345
346 async fn run_dkg_key<G: DkgGroup>(
348 key_id: (ModuleInstanceId, String),
349 connections: MuxPeerConnections<(ModuleInstanceId, String), DkgPeerMsg>,
350 mut dkg: Dkg<G>,
351 initial_step: DkgStep<G>,
352 ) -> DkgResult<DkgKeys<G>>
353 where
354 DkgMessage<G>: ISupportedDkgMessage,
355 {
356 if let DkgStep::Messages(messages) = initial_step {
357 for (peer, msg) in messages {
358 let send_msg = DkgPeerMsg::DistributedGen(msg.to_msg());
359 connections.send(&[peer], key_id.clone(), send_msg).await?;
360 }
361 }
362
363 loop {
365 let (peer, msg) = connections.receive(key_id.clone()).await?;
366
367 let message = match msg {
368 DkgPeerMsg::DistributedGen(v) => Ok(v),
369 _ => Err(format_err!(
370 "Key {key_id:?} wrong message received: {msg:?}"
371 )),
372 }?;
373
374 let message = ISupportedDkgMessage::from_msg(message)?;
375 let step = dkg.step(peer, message)?;
376
377 match step {
378 DkgStep::Messages(messages) => {
379 for (peer, msg) in messages {
380 let send_msg = DkgPeerMsg::DistributedGen(msg.to_msg());
381 connections.send(&[peer], key_id.clone(), send_msg).await?;
382 }
383 }
384 DkgStep::Result(result) => {
385 return Ok(result);
386 }
387 }
388 }
389 }
390}
391
392pub fn random_scalar_coefficients(degree: usize, rng: &mut impl RngCore) -> Vec<Scalar> {
393 (0..=degree).map(|_| random_scalar(rng)).collect()
394}
395
396fn random_scalar(rng: &mut impl RngCore) -> Scalar {
397 Scalar::random(rng)
398}
399
400pub fn evaluate_polynomial_scalar(coefficients: &[Scalar], x: &Scalar) -> Scalar {
401 coefficients
402 .iter()
403 .copied()
404 .rev()
405 .reduce(|acc, coefficient| acc * x + coefficient)
406 .expect("We have at least one coefficient")
407}
408
/// Outcome of a single step of the DKG state machine.
#[derive(Debug, Clone)]
pub enum DkgStep<G: DkgGroup> {
    /// Messages to send, each paired with the peer it is addressed to (the list may be empty).
    Messages(Vec<(PeerId, DkgMessage<G>)>),
    /// The protocol finished and these are our final keys.
    Result(DkgKeys<G>),
}
414
/// Result of a completed DKG run for one participant.
#[derive(Debug, Clone)]
pub struct DkgKeys<G> {
    /// Coefficients of the joint public polynomial. Entry `i` is the sum over all participants
    /// of `g * f1_i`.
    pub public_key_set: Vec<G>,
    /// Our share of the joint secret: the sum of the `f1` evaluations every participant dealt
    /// to us (our own included).
    pub secret_key_share: Scalar,
}
420
/// G1 DKG output converted into `threshold_crypto` key types.
#[derive(Debug, Clone)]
pub struct ThresholdKeys {
    /// Public key set built from the joint public polynomial's coefficients.
    pub public_key_set: PublicKeySet,
    /// Our secret key share, wrapped in threshold_crypto's `SerdeSecret` serde wrapper.
    pub secret_key_share: SerdeSecret<SecretKeyShare>,
}
427
428impl DkgKeys<G2Projective> {
429 pub fn tbs(self) -> (Vec<G2Projective>, tbs::SecretKeyShare) {
430 (
431 self.public_key_set,
432 tbs::SecretKeyShare(self.secret_key_share),
433 )
434 }
435}
436
437impl DkgKeys<G1Projective> {
438 pub fn threshold_crypto(&self) -> ThresholdKeys {
439 ThresholdKeys {
440 public_key_set: PublicKeySet::from(Commitment::from(self.public_key_set.clone())),
441 secret_key_share: SerdeSecret(SecretKeyShare::from_mut(
442 &mut self.secret_key_share.clone(),
443 )),
444 }
445 }
446
447 pub fn tpe(self) -> (Vec<G1Projective>, Scalar) {
448 (self.public_key_set, self.secret_key_share)
449 }
450}
451
452pub fn evaluate_polynomial_g1(coefficients: &[G1Projective], x: &Scalar) -> G1Affine {
453 coefficients
454 .iter()
455 .copied()
456 .rev()
457 .reduce(|acc, coefficient| acc * x + coefficient)
458 .expect("We have at least one coefficient")
459 .to_affine()
460}
461
462pub fn evaluate_polynomial_g2(coefficients: &[G2Projective], x: &Scalar) -> G2Affine {
463 coefficients
464 .iter()
465 .copied()
466 .rev()
467 .reduce(|acc, coefficient| acc * x + coefficient)
468 .expect("We have at least one coefficient")
469 .to_affine()
470}
471
/// Peer-to-peer operations a module can run during config generation.
#[async_trait]
pub trait PeerHandleOps {
    /// Runs a DKG over G1 for the single key `v` among all peers.
    async fn run_dkg_g1<T>(&self, v: T) -> DkgResult<HashMap<T, DkgKeys<G1Projective>>>
    where
        T: Serialize + DeserializeOwned + Unpin + Send + Clone + Eq + Hash + Sync;

    /// Runs one DKG over G2 per key in `v`, concurrently, among all peers.
    async fn run_dkg_multi_g2<T>(&self, v: Vec<T>) -> DkgResult<HashMap<T, DkgKeys<G2Projective>>>
    where
        T: Serialize + DeserializeOwned + Unpin + Send + Clone + Eq + Hash + Sync;

    /// Sends our `key` to all peers under `dkg_key`.
    ///
    /// Returns every peer's key, our own included.
    async fn exchange_pubkeys(
        &self,
        dkg_key: String,
        key: secp256k1::PublicKey,
    ) -> DkgResult<BTreeMap<PeerId, secp256k1::PublicKey>>;

    /// Broadcasts `data` (consensus-encoded) under `dkg_key` and collects every peer's value.
    ///
    /// Received values are decoded with `decoder`, registered for module `kind`. Returns every
    /// peer's value, our own included.
    async fn exchange_with_peers<T: Encodable + Decodable + Send + Sync>(
        &self,
        dkg_key: String,
        data: T,
        kind: ModuleKind,
        decoder: Decoder,
    ) -> DkgResult<BTreeMap<PeerId, T>>;
}
505
506#[async_trait]
507impl<'a> PeerHandleOps for PeerHandle<'a> {
508 async fn run_dkg_g1<T>(&self, v: T) -> DkgResult<HashMap<T, DkgKeys<G1Projective>>>
509 where
510 T: Serialize + DeserializeOwned + Unpin + Send + Clone + Eq + Hash + Sync,
511 {
512 let mut dkg = DkgRunner::new(
513 v,
514 self.peers.to_num_peers().threshold(),
515 &self.our_id,
516 &self.peers,
517 );
518 dkg.run_g1(self.module_instance_id, self.connections).await
519 }
520
521 async fn run_dkg_multi_g2<T>(&self, v: Vec<T>) -> DkgResult<HashMap<T, DkgKeys<G2Projective>>>
522 where
523 T: Serialize + DeserializeOwned + Unpin + Send + Clone + Eq + Hash + Sync,
524 {
525 let mut dkg = DkgRunner::multi(
526 v,
527 self.peers.to_num_peers().threshold(),
528 &self.our_id,
529 &self.peers,
530 );
531
532 dkg.run_g2(self.module_instance_id, self.connections).await
533 }
534
535 async fn exchange_pubkeys(
536 &self,
537 dkg_key: String,
538 key: secp256k1::PublicKey,
539 ) -> DkgResult<BTreeMap<PeerId, secp256k1::PublicKey>> {
540 let mut peer_peg_in_keys: BTreeMap<PeerId, secp256k1::PublicKey> = BTreeMap::new();
541
542 self.connections
543 .send(
544 &self.peers,
545 (self.module_instance_id, dkg_key.clone()),
546 DkgPeerMsg::PublicKey(key),
547 )
548 .await?;
549
550 peer_peg_in_keys.insert(self.our_id, key);
551 while peer_peg_in_keys.len() < self.peers.len() {
552 match self
553 .connections
554 .receive((self.module_instance_id, dkg_key.clone()))
555 .await?
556 {
557 (peer, DkgPeerMsg::PublicKey(key)) => {
558 peer_peg_in_keys.insert(peer, key);
559 }
560 (peer, msg) => {
561 return Err(
562 format_err!("Invalid message received from: {peer}: {msg:?}").into(),
563 );
564 }
565 }
566 }
567
568 Ok(peer_peg_in_keys)
569 }
570
571 async fn exchange_with_peers<T: Encodable + Decodable + Send + Sync>(
572 &self,
573 dkg_key: String,
574 data: T,
575 kind: ModuleKind,
576 decoder: Decoder,
577 ) -> DkgResult<BTreeMap<PeerId, T>> {
578 let mut peer_data: BTreeMap<PeerId, T> = BTreeMap::new();
579 let msg = DkgPeerMsg::Module(data.consensus_encode_to_vec());
580
581 self.connections
582 .send(&self.peers, (self.module_instance_id, dkg_key.clone()), msg)
583 .await?;
584 peer_data.insert(self.our_id, data);
585
586 let modules =
587 ModuleDecoderRegistry::new([(self.module_instance_id, kind.clone(), decoder)]);
588 while peer_data.len() < self.peers.len() {
589 match self
590 .connections
591 .receive((self.module_instance_id, dkg_key.clone()))
592 .await?
593 {
594 (peer, DkgPeerMsg::Module(bytes)) => {
595 let received_data: T = T::consensus_decode_vec(bytes, &modules)
596 .map_err(|_| DkgError::ModuleDecodeError(kind.clone()))?;
597 peer_data.insert(peer, received_data);
598 }
599 (peer, msg) => {
600 return Err(format_err!("Invalid message received from {peer}: {msg:?}").into());
601 }
602 }
603 }
604
605 Ok(peer_data)
606 }
607}
608
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, VecDeque};

    use fedimint_core::PeerId;
    use rand::rngs::OsRng;
    use threshold_crypto::{G1Projective, G2Projective};

    use crate::config::distributedgen::{
        evaluate_polynomial_g2, scalar, Dkg, DkgGroup, DkgKeys, DkgStep, ThresholdKeys,
    };

    /// Runs a 3-of-4 DKG over G1 and G2. Checks that every peer's secret share matches the
    /// public key set evaluated at that peer's point.
    #[test_log::test]
    fn test_dkg() {
        for (peer, keys) in run(G1Projective::generator()) {
            let ThresholdKeys {
                public_key_set,
                secret_key_share,
            } = keys.threshold_crypto();
            assert_eq!(public_key_set.threshold(), 2);
            assert_eq!(
                public_key_set.public_key_share(peer.to_usize()),
                secret_key_share.public_key_share()
            );
        }

        for (peer, keys) in run(G2Projective::generator()) {
            let (pk, sk) = keys.tbs();
            assert_eq!(pk.len(), 3);
            assert_eq!(
                evaluate_polynomial_g2(&pk, &scalar(&peer)),
                sk.to_pub_key_share().0
            );
        }
    }

    /// Simulates the whole protocol in-process for 4 peers with threshold 3.
    ///
    /// Messages are delivered in FIFO order.
    ///
    /// # Panics
    /// Panics if any step is rejected, or if no messages are pending before every peer has
    /// produced its keys (the protocol stalled).
    fn run<G: DkgGroup>(group: G) -> HashMap<PeerId, DkgKeys<G>> {
        let mut rng = OsRng;
        let num_peers = 4;
        let threshold = 3;
        let peers = (0..num_peers as u16).map(PeerId::from).collect::<Vec<_>>();

        let mut steps: VecDeque<(PeerId, DkgStep<G>)> = VecDeque::new();
        let mut dkgs: HashMap<PeerId, Dkg<G>> = HashMap::new();
        let mut keys: HashMap<PeerId, DkgKeys<G>> = HashMap::new();

        for peer in &peers {
            let (dkg, step) = Dkg::new(group, *peer, peers.clone(), threshold, &mut rng);
            dkgs.insert(*peer, dkg);
            steps.push_back((*peer, step));
        }

        while keys.len() < peers.len() {
            match steps.pop_front() {
                Some((peer, DkgStep::Messages(messages))) => {
                    for (receive_peer, msg) in messages {
                        let receive_dkg = dkgs.get_mut(&receive_peer).unwrap();
                        let step = receive_dkg.step(peer, msg).unwrap_or_else(|err| {
                            panic!("{receive_peer} rejected a message from {peer}: {err}")
                        });
                        steps.push_back((receive_peer, step));
                    }
                }
                Some((peer, DkgStep::Result(step_keys))) => {
                    keys.insert(peer, step_keys);
                }
                // Previously this arm was `_ => {}`, which spun forever on a stalled protocol.
                None => panic!(
                    "DKG stalled: no pending messages but only {} of {} peers have keys",
                    keys.len(),
                    peers.len()
                ),
            }
        }

        keys
    }
}