1use starknet_curve::curve_params::{ALPHA, BETA, EC_ORDER, GENERATOR};
2
3use crate::{
4 fe_utils::{add_unbounded, bigint_mul_mod_floor, mod_inverse, mul_mod_floor},
5 RecoverError, SignError, VerifyError,
6};
7use starknet_types_core::curve::{AffinePoint, ProjectivePoint};
8use starknet_types_core::felt::Felt;
9
/// Exclusive upper bound enforced in this module on message hashes and on the signature
/// components `r`, `s` (and the derived `w` in `verify`).
///
/// NOTE(review): the limbs are handed to `Felt::from_raw`, presumably in the field's internal
/// (Montgomery) representation; they are believed to encode 2^251 — confirm before editing.
const ELEMENT_UPPER_BOUND: Felt = Felt::from_raw([
    576459263475450960,
    18446744073709255680,
    160989183,
    18446743986131435553,
]);
24
25#[derive(Debug)]
27pub struct Signature {
28 pub r: Felt,
30 pub s: Felt,
32}
33
34#[derive(Debug)]
36pub struct ExtendedSignature {
37 pub r: Felt,
39 pub s: Felt,
41 pub v: Felt,
43}
44
45impl From<ExtendedSignature> for Signature {
46 fn from(value: ExtendedSignature) -> Self {
47 Self {
48 r: value.r,
49 s: value.s,
50 }
51 }
52}
53
#[cfg(feature = "signature-display")]
impl core::fmt::Display for Signature {
    /// Writes `r` then `s`, each as the hex encoding of its big-endian byte representation,
    /// with no separator.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let r_hex = hex::encode(self.r.to_bytes_be());
        let s_hex = hex::encode(self.s.to_bytes_be());
        write!(f, "{}{}", r_hex, s_hex)
    }
}
65
#[cfg(feature = "signature-display")]
impl core::fmt::Display for ExtendedSignature {
    /// Writes `r` and `s` as hex of their big-endian bytes, followed by `v` rendered through
    /// its `LowerHex` impl with a `{:02x}` spec.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let r_hex = hex::encode(self.r.to_bytes_be());
        let s_hex = hex::encode(self.s.to_bytes_be());
        write!(f, "{}{}{:02x}", r_hex, s_hex, self.v)
    }
}
78
79pub fn get_public_key(private_key: &Felt) -> Felt {
85 mul_by_bits(&GENERATOR, private_key)
86 .to_affine()
87 .unwrap()
88 .x()
89}
90
91pub fn sign(private_key: &Felt, message: &Felt, k: &Felt) -> Result<ExtendedSignature, SignError> {
99 if message >= &ELEMENT_UPPER_BOUND {
100 return Err(SignError::InvalidMessageHash);
101 }
102 if k == &Felt::ZERO {
103 return Err(SignError::InvalidK);
104 }
105
106 let full_r = mul_by_bits(&GENERATOR, k).to_affine().unwrap();
107 let r = full_r.x();
108 if r == Felt::ZERO || r >= ELEMENT_UPPER_BOUND {
109 return Err(SignError::InvalidK);
110 }
111
112 let k_inv = mod_inverse(k, &EC_ORDER);
113
114 let s = mul_mod_floor(&r, private_key, &EC_ORDER);
115 let s = add_unbounded(&s, message);
116 let s = bigint_mul_mod_floor(s, &k_inv, &EC_ORDER);
117 if s == Felt::ZERO || s >= ELEMENT_UPPER_BOUND {
118 return Err(SignError::InvalidK);
119 }
120
121 Ok(ExtendedSignature {
122 r,
123 s,
124 v: (full_r.y().to_bigint() & Felt::ONE.to_bigint()).into(),
125 })
126}
127
128pub fn verify(public_key: &Felt, message: &Felt, r: &Felt, s: &Felt) -> Result<bool, VerifyError> {
138 if message >= &ELEMENT_UPPER_BOUND {
139 return Err(VerifyError::InvalidMessageHash);
140 }
141 if r == &Felt::ZERO || r >= &ELEMENT_UPPER_BOUND {
142 return Err(VerifyError::InvalidR);
143 }
144 if s == &Felt::ZERO || s >= &ELEMENT_UPPER_BOUND {
145 return Err(VerifyError::InvalidS);
146 }
147
148 let full_public_key = AffinePoint::new(
149 *public_key,
150 (public_key.square() * public_key + ALPHA * public_key + BETA)
151 .sqrt()
152 .ok_or(VerifyError::InvalidPublicKey)?,
153 )
154 .unwrap();
155
156 let w = mod_inverse(s, &EC_ORDER);
157 if w == Felt::ZERO || w >= ELEMENT_UPPER_BOUND {
158 return Err(VerifyError::InvalidS);
159 }
160
161 let zw = mul_mod_floor(message, &w, &EC_ORDER);
162 let zw_g = mul_by_bits(&GENERATOR, &zw);
163
164 let rw = mul_mod_floor(r, &w, &EC_ORDER);
165 let rw_q = mul_by_bits(&full_public_key, &rw);
166
167 Ok((&zw_g + &rw_q).to_affine().unwrap().x() == *r
168 || (&zw_g - &rw_q).to_affine().unwrap().x() == *r)
169}
170
171pub fn recover(message: &Felt, r: &Felt, s: &Felt, v: &Felt) -> Result<Felt, RecoverError> {
180 if message >= &ELEMENT_UPPER_BOUND {
181 return Err(RecoverError::InvalidMessageHash);
182 }
183 if r == &Felt::ZERO || r >= &ELEMENT_UPPER_BOUND {
184 return Err(RecoverError::InvalidR);
185 }
186 if s == &Felt::ZERO || s >= &EC_ORDER {
187 return Err(RecoverError::InvalidS);
188 }
189 if v > &Felt::ONE {
190 return Err(RecoverError::InvalidV);
191 }
192
193 let full_r = AffinePoint::new(
194 *r,
195 (r * r * r + ALPHA * r + BETA)
196 .sqrt()
197 .ok_or(RecoverError::InvalidR)?,
198 )
199 .unwrap();
200
201 let mut full_r_y = full_r.y();
202
203 let mut bits = [false; 256];
204
205 for (i, (&a, &b)) in full_r
206 .y()
207 .to_bits_le()
208 .iter()
209 .zip(Felt::ONE.to_bits_le().iter())
210 .enumerate()
211 {
212 bits[i] = a && b;
213 }
214
215 if bits != v.to_bits_le() {
216 full_r_y = -full_r.y();
217 }
218
219 let full_rs = mul_by_bits(&AffinePoint::new(full_r.x(), full_r_y).unwrap(), s);
220 let zg = mul_by_bits(&GENERATOR, message);
221
222 let r_inv = mod_inverse(r, &EC_ORDER);
223
224 let rs_zg = &full_rs - &zg;
225
226 let k = mul_by_bits(&rs_zg.to_affine().unwrap(), &r_inv);
227
228 Ok(k.to_affine().unwrap().x())
229}
230
231#[inline(always)]
232fn mul_by_bits(x: &AffinePoint, y: &Felt) -> ProjectivePoint {
233 &ProjectivePoint::from_affine(x.x(), x.y()).unwrap() * *y
234}
235
#[cfg(test)]
mod tests {
    #[cfg(not(feature = "std"))]
    use alloc::collections::BTreeMap;
    #[cfg(feature = "std")]
    use std::collections::BTreeMap;

    use super::*;
    use crate::test_utils::field_element_from_be_hex;

    /// Strips one leading `0x` and left-pads with a `0` to an even number of digits.
    ///
    /// Every other test feeds `field_element_from_be_hex` even-length hex without a prefix.
    /// NOTE(review): presumably it rejects odd-length input — confirm in `test_utils`.
    fn normalize_hex(value: &str) -> String {
        let digits = value.trim_start_matches("0x");
        if digits.len() % 2 == 0 {
            digits.to_owned()
        } else {
            format!("0{}", digits)
        }
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_get_public_key_1() {
        let private_key = field_element_from_be_hex(
            "03c1e9550e66958296d11b60f8e8e7a7ad990d07fa65d5f7652c4a6c87d4e3cc",
        );
        let expected_public_key = field_element_from_be_hex(
            "077a3b314db07c45076d11f62b6f9e748a39790441823307743cf00d6597ea43",
        );

        assert_eq!(get_public_key(&private_key), expected_public_key);
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_get_public_key_2() {
        let private_key = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000012",
        );
        let expected_public_key = field_element_from_be_hex(
            "019661066e96a8b9f06a1d136881ee924dfb6a885239caa5fd3f87a54c6b25c4",
        );

        assert_eq!(get_public_key(&private_key), expected_public_key);
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_get_public_keys_from_json() {
        let json_data = include_str!("../test-data/keys_precomputed.json");

        let key_map: BTreeMap<String, String> =
            serde_json::from_str(json_data).expect("Unable to parse the JSON");

        for (private_key, expected_public_key) in key_map {
            let private_key = field_element_from_be_hex(&normalize_hex(&private_key));
            let expected_public_key =
                field_element_from_be_hex(&normalize_hex(&expected_public_key));

            assert_eq!(get_public_key(&private_key), expected_public_key);
        }
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_verify_valid_message() {
        let stark_key = field_element_from_be_hex(
            "01ef15c18599971b7beced415a40f0c7deacfd9b0d1819e03d723d8bc943cfca",
        );
        let msg_hash = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000002",
        );
        let r_bytes = field_element_from_be_hex(
            "0411494b501a98abd8262b0da1351e17899a0c4ef23dd2f96fec5ba847310b20",
        );
        let s_bytes = field_element_from_be_hex(
            "0405c3191ab3883ef2b763af35bc5f5d15b3b4e99461d70e84c654a351a7c81b",
        );

        assert!(verify(&stark_key, &msg_hash, &r_bytes, &s_bytes).unwrap());
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_verify_invalid_message() {
        let stark_key = field_element_from_be_hex(
            "077a4b314db07c45076d11f62b6f9e748a39790441823307743cf00d6597ea43",
        );
        let msg_hash = field_element_from_be_hex(
            "0397e76d1667c4454bfb83514e120583af836f8e32a516765497823eabe16a3f",
        );
        let r_bytes = field_element_from_be_hex(
            "0173fd03d8b008ee7432977ac27d1e9d1a1f6c98b1a2f05fa84a21c84c44e882",
        );
        let s_bytes = field_element_from_be_hex(
            "01f2c44a7798f55192f153b4c48ea5c1241fbb69e6132cc8a0da9c5b62a4286e",
        );

        assert!(!verify(&stark_key, &msg_hash, &r_bytes, &s_bytes).unwrap());
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_verify_invalid_public_key() {
        let stark_key = field_element_from_be_hex(
            "03ee9bffffffffff26ffffffff60ffffffffffffffffffffffffffff004accff",
        );
        let msg_hash = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000002",
        );
        let r_bytes = field_element_from_be_hex(
            "0411494b501a98abd8262b0da1351e17899a0c4ef23dd2f96fec5ba847310b20",
        );
        let s_bytes = field_element_from_be_hex(
            "0405c3191ab3883ef2b763af35bc5f5d15b3b4e99461d70e84c654a351a7c81b",
        );

        match verify(&stark_key, &msg_hash, &r_bytes, &s_bytes) {
            Err(VerifyError::InvalidPublicKey) => {}
            _ => panic!("unexpected result"),
        }
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_sign() {
        let private_key = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000001",
        );
        let message = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000002",
        );
        let k = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000003",
        );

        let signature = sign(&private_key, &message, &k).unwrap();
        let public_key = get_public_key(&private_key);

        assert!(verify(&public_key, &message, &signature.r, &signature.s).unwrap());
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_recover() {
        let private_key = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000001",
        );
        let message = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000002",
        );
        let k = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000003",
        );

        let signature = sign(&private_key, &message, &k).unwrap();
        let public_key = recover(&message, &signature.r, &signature.s, &signature.v).unwrap();

        assert_eq!(get_public_key(&private_key), public_key);
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_recover_invalid_r() {
        let message = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000002",
        );
        let r = field_element_from_be_hex(
            "03ee9bffffffffff26ffffffff60ffffffffffffffffffffffffffff004accff",
        );
        let s = field_element_from_be_hex(
            "0405c3191ab3883ef2b763af35bc5f5d15b3b4e99461d70e84c654a351a7c81b",
        );
        let v = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000000",
        );

        match recover(&message, &r, &s, &v) {
            Err(RecoverError::InvalidR) => {}
            _ => panic!("unexpected result"),
        }
    }
}