1use std::collections::BTreeMap;
2use std::fmt::Debug;
3use std::io::Write;
4
5use anyhow::{ensure, Context};
6use async_trait::async_trait;
7use bitcoin::hashes::sha256::{Hash as Sha256, HashEngine};
8use bitcoin::hashes::Hash as BitcoinHash;
9use bls12_381::Scalar;
10use fedimint_core::config::{DkgGroup, DkgMessage, DkgPeerMessage, ISupportedDkgMessage};
11use fedimint_core::encoding::{Decodable, Encodable};
12use fedimint_core::module::registry::ModuleDecoderRegistry;
13use fedimint_core::module::PeerHandle;
14use fedimint_core::net::peers::{DynP2PConnections, Recipient};
15use fedimint_core::{NumPeers, PeerId};
16use rand::rngs::OsRng;
17use rand::SeedableRng;
18use rand_chacha::ChaChaRng;
19use threshold_crypto::ff::Field;
20use threshold_crypto::group::Curve;
21use threshold_crypto::{G1Affine, G1Projective, G2Affine, G2Projective};
22
23struct Dkg<G> {
24 num_peers: NumPeers,
25 identity: PeerId,
26 generator: G,
27 f1_poly: Vec<Scalar>,
28 f2_poly: Vec<Scalar>,
29 hashed_commits: BTreeMap<PeerId, Sha256>,
30 commitments: BTreeMap<PeerId, Vec<G>>,
31 sk_shares: BTreeMap<PeerId, Scalar>,
32 pk_shares: BTreeMap<PeerId, Vec<G>>,
33}
34
35impl<G: DkgGroup> Dkg<G> {
42 pub fn new(num_peers: NumPeers, identity: PeerId, generator: G) -> (Self, DkgMessage<G>) {
44 let f1_poly = random_coefficients(num_peers.threshold() - 1);
45 let f2_poly = random_coefficients(num_peers.threshold() - 1);
46
47 let mut dkg = Dkg {
48 num_peers,
49 identity,
50 generator,
51 f1_poly,
52 f2_poly,
53 hashed_commits: BTreeMap::new(),
54 commitments: BTreeMap::new(),
55 sk_shares: BTreeMap::new(),
56 pk_shares: BTreeMap::new(),
57 };
58
59 let commit: Vec<G> = dkg
61 .f1_poly
62 .iter()
63 .map(|c| dkg.generator * *c)
64 .zip(dkg.f2_poly.iter().map(|c| gen_h::<G>() * *c))
65 .map(|(g, h)| g + h)
66 .collect();
67
68 let hashed = Dkg::hash(&commit);
69
70 dkg.commitments.insert(identity, commit);
71 dkg.hashed_commits.insert(identity, hashed);
72
73 (dkg, DkgMessage::HashedCommit(hashed))
74 }
75
76 pub fn step(&mut self, peer: PeerId, msg: DkgMessage<G>) -> anyhow::Result<DkgStep<G>> {
78 match msg {
79 DkgMessage::HashedCommit(hashed) => {
80 ensure!(
81 self.hashed_commits.insert(peer, hashed).is_none(),
82 "DKG: peer {peer} sent us two hash commitments."
83 );
84
85 if self.hashed_commits.len() == self.num_peers.total() {
86 let commit = self
87 .commitments
88 .get(&self.identity)
89 .expect("DKG hash commitment not found for identity.")
90 .clone();
91
92 return Ok(DkgStep::Broadcast(DkgMessage::Commit(commit)));
93 }
94 }
95 DkgMessage::Commit(commit) => {
96 ensure!(
97 self.num_peers.threshold() == commit.len(),
98 "DKG: polynomial commitment from peer {peer} is of wrong degree."
99 );
100
101 let hash_commitment = *self
102 .hashed_commits
103 .get(&peer)
104 .context("DKG: hash commitment not found for peer {peer}")?;
105
106 ensure!(
107 Self::hash(&commit) == hash_commitment,
108 "DKG: polynomial commitment from peer {peer} has invalid hash."
109 );
110
111 ensure!(
112 self.commitments.insert(peer, commit).is_none(),
113 "DKG: peer {peer} sent us two commitments."
114 );
115
116 if self.commitments.len() == self.num_peers.total() {
118 let mut messages = vec![];
119
120 for peer in self.num_peers.peer_ids() {
121 let s1 = eval_poly_scalar(&self.f1_poly, &scalar(&peer));
122 let s2 = eval_poly_scalar(&self.f2_poly, &scalar(&peer));
123
124 if peer == self.identity {
125 self.sk_shares.insert(self.identity, s1);
126 } else {
127 messages.push((peer, DkgMessage::Share(s1, s2)));
128 }
129 }
130
131 return Ok(DkgStep::Messages(messages));
132 }
133 }
134 DkgMessage::Share(s1, s2) => {
136 let share_product: G = (self.generator * s1) + (gen_h::<G>() * s2);
137
138 let commitment = self
139 .commitments
140 .get(&peer)
141 .context("DKG: polynomial commitment not found for peer {peer}.")?;
142
143 let commit_product: G = commitment
144 .iter()
145 .enumerate()
146 .map(|(idx, commit)| {
147 *commit * scalar(&self.identity).pow(&[idx as u64, 0, 0, 0])
148 })
149 .reduce(|a, b| a + b)
150 .expect("DKG: polynomial commitment from peer {peer} is empty.");
151
152 ensure!(
153 share_product == commit_product,
154 "DKG: share from {peer} is invalid."
155 );
156
157 ensure!(
158 self.sk_shares.insert(peer, s1).is_none(),
159 "Peer {peer} sent us two shares."
160 );
161
162 if self.sk_shares.len() == self.num_peers.total() {
163 let extract = self
164 .f1_poly
165 .iter()
166 .map(|c| self.generator * *c)
167 .collect::<Vec<G>>();
168
169 self.pk_shares.insert(self.identity, extract.clone());
170
171 return Ok(DkgStep::Broadcast(DkgMessage::Extract(extract)));
172 }
173 }
174 DkgMessage::Extract(extract) => {
176 let share = self
177 .sk_shares
178 .get(&peer)
179 .context("DKG share not found for peer {peer}.")?;
180
181 let extract_product: G = extract
182 .iter()
183 .enumerate()
184 .map(|(idx, commit)| {
185 *commit * scalar(&self.identity).pow(&[idx as u64, 0, 0, 0])
186 })
187 .reduce(|a, b| a + b)
188 .expect("sums");
189
190 ensure!(
191 self.generator * *share == extract_product,
192 "DKG: extract from {peer} is invalid."
193 );
194
195 ensure!(
196 self.num_peers.threshold() == extract.len(),
197 "wrong degree from {peer}."
198 );
199
200 ensure!(
201 self.pk_shares.insert(peer, extract).is_none(),
202 "DKG: peer {peer} sent us two extracts."
203 );
204
205 if self.pk_shares.len() == self.num_peers.total() {
206 let sks = self.sk_shares.values().sum();
207
208 let pks: Vec<G> = (0..self.num_peers.threshold())
209 .map(|i| {
210 self.pk_shares
211 .values()
212 .map(|shares| shares[i])
213 .reduce(|a, b| a + b)
214 .expect("DKG: pk shares are empty.")
215 })
216 .collect();
217
218 return Ok(DkgStep::Result((pks, sks)));
219 }
220 }
221 }
222
223 Ok(DkgStep::Messages(vec![]))
224 }
225
226 fn hash(poly: &[G]) -> Sha256 {
227 let mut engine = HashEngine::default();
228
229 for element in poly {
230 engine
231 .write_all(element.to_bytes().as_ref())
232 .expect("Writing to a hash engine cannot fail.");
233 }
234
235 Sha256::from_engine(engine)
236 }
237}
238
239fn gen_h<G: DkgGroup>() -> G {
240 G::random(&mut ChaChaRng::from_seed([42; 32]))
241}
242
243fn scalar(peer: &PeerId) -> Scalar {
245 Scalar::from(peer.to_usize() as u64 + 1)
246}
247
248pub async fn run_dkg<G: DkgGroup>(
251 num_peers: NumPeers,
252 identity: PeerId,
253 generator: G,
254 connections: &DynP2PConnections<DkgPeerMessage>,
255) -> anyhow::Result<(Vec<G>, Scalar)>
256where
257 DkgMessage<G>: ISupportedDkgMessage,
258{
259 let (mut dkg, initial_message) = Dkg::new(num_peers, identity, generator);
260
261 connections
262 .send(
263 Recipient::Everyone,
264 DkgPeerMessage::DistributedGen(initial_message.to_msg()),
265 )
266 .await;
267
268 loop {
269 for peer in num_peers.peer_ids().filter(|p| *p != identity) {
270 let message = connections
271 .receive_from_peer(peer)
272 .await
273 .context("Unexpected shutdown of p2p connections")?;
274
275 let message = match message {
276 DkgPeerMessage::DistributedGen(message) => message,
277 _ => anyhow::bail!("Wrong message received: {message:?}"),
278 };
279
280 match dkg.step(peer, ISupportedDkgMessage::from_msg(message)?)? {
281 DkgStep::Broadcast(message) => {
282 connections
283 .send(
284 Recipient::Everyone,
285 DkgPeerMessage::DistributedGen(message.to_msg()),
286 )
287 .await;
288 }
289 DkgStep::Messages(messages) => {
290 for (peer, message) in messages {
291 connections
292 .send(
293 Recipient::Peer(peer),
294 DkgPeerMessage::DistributedGen(message.to_msg()),
295 )
296 .await;
297 }
298 }
299 DkgStep::Result(result) => {
300 return Ok(result);
301 }
302 }
303 }
304 }
305}
306
307fn random_coefficients(degree: usize) -> Vec<Scalar> {
308 (0..=degree).map(|_| Scalar::random(&mut OsRng)).collect()
309}
310
311fn eval_poly_scalar(coefficients: &[Scalar], x: &Scalar) -> Scalar {
312 coefficients
313 .iter()
314 .copied()
315 .rev()
316 .reduce(|acc, coefficient| acc * x + coefficient)
317 .expect("We have at least one coefficient")
318}
319
320#[derive(Debug, Clone)]
321pub enum DkgStep<G: DkgGroup> {
322 Broadcast(DkgMessage<G>),
323 Messages(Vec<(PeerId, DkgMessage<G>)>),
324 Result((Vec<G>, Scalar)),
325}
326
327pub fn eval_poly_g1(coefficients: &[G1Projective], peer: &PeerId) -> G1Affine {
328 coefficients
329 .iter()
330 .copied()
331 .rev()
332 .reduce(|acc, coefficient| acc * scalar(peer) + coefficient)
333 .expect("We have at least one coefficient")
334 .to_affine()
335}
336
337pub fn eval_poly_g2(coefficients: &[G2Projective], peer: &PeerId) -> G2Affine {
338 coefficients
339 .iter()
340 .copied()
341 .rev()
342 .reduce(|acc, coefficient| acc * scalar(peer) + coefficient)
343 .expect("We have at least one coefficient")
344 .to_affine()
345}
346
347#[async_trait]
350pub trait PeerHandleOps {
351 async fn run_dkg_g1(&self) -> anyhow::Result<(Vec<G1Projective>, Scalar)>;
352
353 async fn run_dkg_g2(&self) -> anyhow::Result<(Vec<G2Projective>, Scalar)>;
354
355 async fn exchange_encodable<T: Encodable + Decodable + Send + Sync>(
361 &self,
362 data: T,
363 ) -> anyhow::Result<BTreeMap<PeerId, T>>;
364}
365
366#[async_trait]
367impl<'a> PeerHandleOps for PeerHandle<'a> {
368 async fn run_dkg_g1(&self) -> anyhow::Result<(Vec<G1Projective>, Scalar)> {
369 run_dkg(
370 self.num_peers,
371 self.identity,
372 G1Projective::generator(),
373 self.connections,
374 )
375 .await
376 }
377
378 async fn run_dkg_g2(&self) -> anyhow::Result<(Vec<G2Projective>, Scalar)> {
379 run_dkg(
380 self.num_peers,
381 self.identity,
382 G2Projective::generator(),
383 self.connections,
384 )
385 .await
386 }
387
388 async fn exchange_encodable<T: Encodable + Decodable + Send + Sync>(
389 &self,
390 data: T,
391 ) -> anyhow::Result<BTreeMap<PeerId, T>> {
392 let mut peer_data: BTreeMap<PeerId, T> = BTreeMap::new();
393
394 self.connections
395 .send(
396 Recipient::Everyone,
397 DkgPeerMessage::Encodable(data.consensus_encode_to_vec()),
398 )
399 .await;
400
401 peer_data.insert(self.identity, data);
402
403 for peer in self.num_peers.peer_ids().filter(|p| *p != self.identity) {
404 let message = self
405 .connections
406 .receive_from_peer(peer)
407 .await
408 .context("Unexpected shutdown of p2p connections")?;
409
410 match message {
411 DkgPeerMessage::Encodable(bytes) => {
412 peer_data.insert(
413 peer,
414 T::consensus_decode_whole(&bytes, &ModuleDecoderRegistry::default())?,
415 );
416 }
417 message => {
418 anyhow::bail!("Invalid message from {peer}: {message:?}");
419 }
420 }
421 }
422
423 Ok(peer_data)
424 }
425}
426
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, VecDeque};

    use bls12_381::Scalar;
    use fedimint_core::{NumPeersExt, PeerId};
    use tbs::derive_pk_share;
    use threshold_crypto::poly::Commitment;
    use threshold_crypto::serde_impl::SerdeSecret;
    use threshold_crypto::{G1Projective, G2Projective, PublicKeySet, SecretKeyShare};

    use crate::config::distributedgen::{eval_poly_g2, Dkg, DkgGroup, DkgStep};

    /// Runs an in-memory 4-peer DKG in G1 and in G2 and checks that each peer's secret key
    /// share matches the public key share derived from the joint polynomial.
    #[test_log::test]
    fn test_dkg() {
        for (peer, (polynomial, mut sks)) in run(G1Projective::generator()) {
            let public_key_set = PublicKeySet::from(Commitment::from(polynomial));
            let secret_key_share = SerdeSecret(SecretKeyShare::from_mut(&mut sks));

            // `threshold()` here is the polynomial degree: 3 coefficients -> degree 2.
            assert_eq!(public_key_set.threshold(), 2);
            assert_eq!(
                public_key_set.public_key_share(peer.to_usize()),
                secret_key_share.public_key_share()
            );
        }

        for (peer, (polynomial, sks)) in run(G2Projective::generator()) {
            assert_eq!(polynomial.len(), 3);
            assert_eq!(
                eval_poly_g2(&polynomial, &peer),
                derive_pk_share(&tbs::SecretKeyShare(sks)).0
            );
        }
    }

    /// Drives one `Dkg` instance per peer to completion by delivering every produced step in
    /// FIFO order, and returns each peer's `(public polynomial, secret key share)`.
    ///
    /// # Panics
    /// Panics if any step fails, or if no steps are left to deliver before every peer has
    /// produced its result.
    fn run<G: DkgGroup>(group: G) -> HashMap<PeerId, (Vec<G>, Scalar)> {
        let peers = (0..4_u16).map(PeerId::from).collect::<Vec<_>>();

        let mut steps: VecDeque<(PeerId, DkgStep<G>)> = VecDeque::new();
        let mut dkgs: HashMap<PeerId, Dkg<G>> = HashMap::new();
        let mut keys: HashMap<PeerId, (Vec<G>, Scalar)> = HashMap::new();

        for peer in &peers {
            let (dkg, initial_message) = Dkg::new(peers.to_num_peers(), *peer, group);
            dkgs.insert(*peer, dkg);
            steps.push_back((*peer, DkgStep::Broadcast(initial_message)));
        }

        while keys.len() < peers.len() {
            match steps.pop_front() {
                Some((peer, DkgStep::Broadcast(message))) => {
                    for receive_peer in peers.iter().filter(|p| **p != peer) {
                        let receive_dkg = dkgs.get_mut(receive_peer).unwrap();
                        let step = receive_dkg.step(peer, message.clone());
                        steps.push_back((*receive_peer, step.unwrap()));
                    }
                }
                Some((peer, DkgStep::Messages(messages))) => {
                    for (receive_peer, message) in messages {
                        let receive_dkg = dkgs.get_mut(&receive_peer).unwrap();
                        let step = receive_dkg.step(peer, message);
                        steps.push_back((receive_peer, step.unwrap()));
                    }
                }
                Some((peer, DkgStep::Result(step_keys))) => {
                    keys.insert(peer, step_keys);
                }
                // An empty queue means no peer can make progress any more; fail loudly instead
                // of spinning in this loop forever.
                None => panic!(
                    "DKG stalled: no pending steps, but only {} of {} peers produced keys",
                    keys.len(),
                    peers.len()
                ),
            }
        }

        assert!(steps.is_empty());

        keys
    }
}