1use std::collections::BTreeMap;
7
8use bls12_381::{pairing, G1Affine, G1Projective, G2Affine, G2Projective, Scalar};
9use fedimint_core::bls12_381_serde;
10use fedimint_core::encoding::{Decodable, Encodable};
11use group::ff::Field;
12use group::{Curve, Group};
13use hex::encode;
14use rand::rngs::OsRng;
15use rand::SeedableRng;
16use rand_chacha::ChaChaRng;
17use serde::{Deserialize, Serialize};
18use sha3::Digest;
19
/// Domain-separation prefix fed to SHA3-256 before the input bytes in
/// `hash_bytes_to_g1`, keeping those digests distinct from other SHA3 uses.
const HASH_TAG: &[u8] = b"TBS_BLS12-381_";
/// Domain-separation prefix fed to SHA3-256 before a blinding key's scalar
/// bytes when computing the fingerprint shown by its `Debug`/`Display` impls.
const FINGERPRINT_TAG: &[u8] = b"TBS_KFP24_";
22
23fn hash_bytes_to_g1(data: &[u8]) -> G1Projective {
24 let mut hash_engine = sha3::Sha3_256::new();
25
26 hash_engine.update(HASH_TAG);
27 hash_engine.update(data);
28
29 let mut prng = ChaChaRng::from_seed(hash_engine.finalize().into());
30
31 G1Projective::random(&mut prng)
32}
33
34#[derive(Copy, Clone, Debug, Eq, PartialEq, Encodable, Decodable, Serialize, Deserialize)]
35pub struct SecretKeyShare(#[serde(with = "bls12_381_serde::scalar")] pub Scalar);
36
37#[derive(Copy, Clone, Debug, Eq, PartialEq, Encodable, Decodable, Serialize, Deserialize)]
38pub struct PublicKeyShare(#[serde(with = "bls12_381_serde::g2")] pub G2Affine);
39
40#[derive(Copy, Clone, Debug, Eq, PartialEq, Encodable, Decodable, Serialize, Deserialize)]
41pub struct AggregatePublicKey(#[serde(with = "bls12_381_serde::g2")] pub G2Affine);
42
43#[derive(Copy, Clone, Debug, Eq, PartialEq, Encodable, Decodable, Serialize, Deserialize)]
44pub struct Message(#[serde(with = "bls12_381_serde::g1")] pub G1Affine);
45
46#[derive(Copy, Clone, Eq, PartialEq, Encodable, Decodable, Serialize, Deserialize)]
47pub struct BlindingKey(#[serde(with = "bls12_381_serde::scalar")] pub Scalar);
48
49#[derive(Copy, Clone, Debug, Eq, PartialEq, Encodable, Decodable, Serialize, Deserialize)]
50pub struct BlindedMessage(#[serde(with = "bls12_381_serde::g1")] pub G1Affine);
51
52#[derive(Copy, Clone, Debug, Eq, PartialEq, Encodable, Decodable, Serialize, Deserialize)]
53pub struct BlindedSignatureShare(#[serde(with = "bls12_381_serde::g1")] pub G1Affine);
54
55#[derive(Copy, Clone, Debug, Eq, PartialEq, Encodable, Decodable, Serialize, Deserialize)]
56pub struct BlindedSignature(#[serde(with = "bls12_381_serde::g1")] pub G1Affine);
57
58#[derive(Copy, Clone, Debug, Eq, PartialEq, Encodable, Decodable, Serialize, Deserialize)]
59pub struct Signature(#[serde(with = "bls12_381_serde::g1")] pub G1Affine);
60
61macro_rules! point_hash_impl {
62 ($type:ty) => {
63 impl std::hash::Hash for $type {
64 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
65 self.0.to_compressed().hash(state);
66 }
67 }
68 };
69}
70
71point_hash_impl!(PublicKeyShare);
72point_hash_impl!(AggregatePublicKey);
73point_hash_impl!(Message);
74point_hash_impl!(BlindedMessage);
75point_hash_impl!(BlindedSignatureShare);
76point_hash_impl!(BlindedSignature);
77point_hash_impl!(Signature);
78
79impl SecretKeyShare {
80 pub fn to_pub_key_share(self) -> PublicKeyShare {
81 PublicKeyShare((G2Projective::generator() * self.0).to_affine())
82 }
83}
84
85impl BlindingKey {
86 pub fn random() -> BlindingKey {
87 BlindingKey(Scalar::random(OsRng))
89 }
90
91 fn fingerprint(&self) -> [u8; 32] {
92 let mut hash_engine = sha3::Sha3_256::new();
93 hash_engine.update(FINGERPRINT_TAG);
94 hash_engine.update(self.0.to_bytes());
95 let result = hash_engine.finalize();
96 result.into()
97 }
98}
99
100impl ::core::fmt::Debug for BlindingKey {
101 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
102 let fingerprint = self.fingerprint();
103 let fingerprint_hex = encode(&fingerprint[..]);
104 write!(f, "BlindingKey({fingerprint_hex})")
105 }
106}
107
108impl ::core::fmt::Display for BlindingKey {
109 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
110 let fingerprint = self.fingerprint();
111 let fingerprint_hex = encode(&fingerprint[..]);
112 write!(f, "{fingerprint_hex}")
113 }
114}
115
116impl Message {
117 pub fn from_bytes(msg: &[u8]) -> Message {
118 Message(hash_bytes_to_g1(msg).to_affine())
119 }
120}
121
122pub fn blind_message(msg: Message, blinding_key: BlindingKey) -> BlindedMessage {
123 let blinded_msg = msg.0 * blinding_key.0;
124
125 BlindedMessage(blinded_msg.to_affine())
126}
127
128pub fn sign_blinded_msg(msg: BlindedMessage, sks: SecretKeyShare) -> BlindedSignatureShare {
129 let sig = msg.0 * sks.0;
130 BlindedSignatureShare(sig.to_affine())
131}
132
133pub fn verify_blind_share(
134 msg: BlindedMessage,
135 sig: BlindedSignatureShare,
136 pk: PublicKeyShare,
137) -> bool {
138 pairing(&msg.0, &pk.0) == pairing(&sig.0, &G2Affine::generator())
139}
140
141pub fn aggregate_signature_shares(
147 shares: &BTreeMap<u64, BlindedSignatureShare>,
148) -> BlindedSignature {
149 if shares.len() == 1 {
151 return BlindedSignature(
152 shares
153 .values()
154 .next()
155 .expect("We have at least one value")
156 .0,
157 );
158 }
159
160 BlindedSignature(
161 lagrange_multipliers(shares.keys().cloned().map(Scalar::from).collect())
162 .into_iter()
163 .zip(shares.values())
164 .map(|(lagrange_multiplier, share)| lagrange_multiplier * share.0)
165 .reduce(|a, b| a + b)
166 .expect("We have at least one share")
167 .to_affine(),
168 )
169}
170
171pub fn aggregate_public_key_shares(shares: &BTreeMap<u64, PublicKeyShare>) -> AggregatePublicKey {
176 if shares.len() == 1 {
178 return AggregatePublicKey(
179 shares
180 .values()
181 .next()
182 .expect("We have at least one value")
183 .0,
184 );
185 }
186
187 AggregatePublicKey(
188 lagrange_multipliers(shares.keys().cloned().map(Scalar::from).collect())
189 .into_iter()
190 .zip(shares.values())
191 .map(|(lagrange_multiplier, share)| lagrange_multiplier * share.0)
192 .reduce(|a, b| a + b)
193 .expect("We have at least one share")
194 .to_affine(),
195 )
196}
197
198fn lagrange_multipliers(scalars: Vec<Scalar>) -> Vec<Scalar> {
199 scalars
200 .iter()
201 .map(|i| {
202 scalars
203 .iter()
204 .filter(|j| *j != i)
205 .map(|j| j * (j - i).invert().expect("We filtered the case j == i"))
206 .reduce(|a, b| a * b)
207 .expect("We have at least one share")
208 })
209 .collect()
210}
211
212pub fn verify_blinded_signature(
213 msg: BlindedMessage,
214 sig: BlindedSignature,
215 pk: AggregatePublicKey,
216) -> bool {
217 pairing(&msg.0, &pk.0) == pairing(&sig.0, &G2Affine::generator())
218}
219
220pub fn unblind_signature(blinding_key: BlindingKey, blinded_sig: BlindedSignature) -> Signature {
221 let sig = blinded_sig.0 * blinding_key.0.invert().unwrap();
222 Signature(sig.to_affine())
223}
224
225pub fn verify(msg: Message, sig: Signature, pk: AggregatePublicKey) -> bool {
226 pairing(&msg.0, &pk.0) == pairing(&sig.0, &G2Affine::generator())
227}
228
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use bls12_381::{G2Projective, Scalar};
    use group::ff::Field;
    use group::Curve;
    use rand::rngs::OsRng;

    use crate::{
        aggregate_public_key_shares, aggregate_signature_shares, blind_message, sign_blinded_msg,
        unblind_signature, verify, verify_blind_share, AggregatePublicKey, BlindedMessage,
        BlindedSignatureShare, BlindingKey, Message, PublicKeyShare, SecretKeyShare,
    };

    /// Trusted-dealer key generation for tests.
    ///
    /// Samples a random polynomial with `threshold` coefficients and returns
    /// the aggregate public key (its value at zero, lifted to G2), plus
    /// `keys` public and secret key shares evaluated at `1..=keys`.
    fn dealer_keygen(
        threshold: usize,
        keys: usize,
    ) -> (AggregatePublicKey, Vec<PublicKeyShare>, Vec<SecretKeyShare>) {
        let mut rng = OsRng;
        let poly: Vec<Scalar> = (0..threshold).map(|_| Scalar::random(&mut rng)).collect();

        let apk = (G2Projective::generator() * eval_polynomial(&poly, &Scalar::zero())).to_affine();

        let sks: Vec<SecretKeyShare> = (0..keys)
            .map(|idx| SecretKeyShare(eval_polynomial(&poly, &Scalar::from(idx as u64 + 1))))
            .collect();

        let pks = sks.iter().map(|sk| sk.to_pub_key_share()).collect();

        (AggregatePublicKey(apk), pks, sks)
    }

    /// Evaluates the polynomial with the given coefficients (constant term
    /// first) at `x`, using Horner's rule.
    fn eval_polynomial(coefficients: &[Scalar], x: &Scalar) -> Scalar {
        coefficients
            .iter()
            .cloned()
            .rev()
            .reduce(|acc, coefficient| acc * x + coefficient)
            .expect("We have at least one coefficient")
    }

    /// Blinds `msg` with a fresh key and signs it with every secret key share.
    ///
    /// Returns the blinding key, the blinded message, and the shares keyed by
    /// their 1-based evaluation index (matching `dealer_keygen`).
    fn blind_and_sign(
        msg: Message,
        sks: &[SecretKeyShare],
    ) -> (BlindingKey, BlindedMessage, BTreeMap<u64, BlindedSignatureShare>) {
        let bkey = BlindingKey::random();
        let bmsg = blind_message(msg, bkey);
        let shares = (1_u64..)
            .zip(sks.iter().map(|sk| sign_blinded_msg(bmsg, *sk)))
            .collect();
        (bkey, bmsg, shares)
    }

    #[test]
    fn test_roundtrip() {
        let (pk, pks, sks) = dealer_keygen(5, 15);

        let msg = Message::from_bytes(b"Hello World!");
        let (bkey, bmsg, bsig_shares) = blind_and_sign(msg, &sks);

        for (share, pk_share) in bsig_shares.values().zip(pks) {
            assert!(verify_blind_share(bmsg, *share, pk_share));
        }

        let threshold_shares: BTreeMap<u64, BlindedSignatureShare> =
            bsig_shares.into_iter().take(5).collect();

        let bsig = aggregate_signature_shares(&threshold_shares);
        let sig = unblind_signature(bkey, bsig);

        assert!(verify(msg, sig, pk));
        assert!(!verify(Message::from_bytes(b"Goodbye World!"), sig, pk));
    }

    #[test]
    fn test_roundtrip_with_non_contiguous_shares() {
        let (pk, _pks, sks) = dealer_keygen(5, 15);

        let msg = Message::from_bytes(b"Hello World!");
        let (bkey, _bmsg, bsig_shares) = blind_and_sign(msg, &sks);

        // Shares 3, 6, 9, 12 and 15 exercise interpolation at arbitrary indices.
        let subset: BTreeMap<u64, BlindedSignatureShare> = bsig_shares
            .into_iter()
            .filter(|(index, _)| index % 3 == 0)
            .collect();
        assert_eq!(subset.len(), 5);

        let sig = unblind_signature(bkey, aggregate_signature_shares(&subset));
        assert!(verify(msg, sig, pk));
    }

    #[test]
    fn test_threshold_one_single_share() {
        let (pk, pks, sks) = dealer_keygen(1, 3);

        let msg = Message::from_bytes(b"Hello World!");
        let (bkey, _bmsg, bsig_shares) = blind_and_sign(msg, &sks);

        // With threshold one, any single share is the full signature.
        let single: BTreeMap<u64, BlindedSignatureShare> =
            bsig_shares.into_iter().skip(1).take(1).collect();
        let sig = unblind_signature(bkey, aggregate_signature_shares(&single));
        assert!(verify(msg, sig, pk));

        let single_pk: BTreeMap<u64, PublicKeyShare> = (1_u64..).zip(pks).take(1).collect();
        assert_eq!(aggregate_public_key_shares(&single_pk), pk);
    }

    #[test]
    fn test_aggregate_public_key_shares() {
        let (pk, pks, _sks) = dealer_keygen(5, 15);

        let pk_shares: BTreeMap<u64, PublicKeyShare> =
            (1_u64..).zip(pks).skip(4).take(5).collect();

        assert_eq!(aggregate_public_key_shares(&pk_shares), pk);
    }

    #[test]
    fn test_blindingkey_fingerprint_multiple_calls_same_result() {
        let bkey = BlindingKey::random();
        assert_eq!(bkey.fingerprint(), bkey.fingerprint());
    }

    #[test]
    fn test_blindingkey_fingerprint_ne_scalar() {
        let bkey = BlindingKey::random();
        assert_ne!(bkey.fingerprint(), bkey.0.to_bytes());
    }
}