starknet_crypto/
ecdsa.rs

1use starknet_curve::curve_params::{ALPHA, BETA, EC_ORDER, GENERATOR};
2
3use crate::{
4    fe_utils::{add_unbounded, bigint_mul_mod_floor, mod_inverse, mul_mod_floor},
5    RecoverError, SignError, VerifyError,
6};
7use starknet_types_core::curve::{AffinePoint, ProjectivePoint};
8use starknet_types_core::felt::Felt;
9
/// The (exclusive) upper bound on many ECDSA-related elements based on the original C++
/// implementation from [`crypto-cpp`](https://github.com/starkware-libs/crypto-cpp).
///
/// The C++ implementation [imposes](https://github.com/starkware-libs/crypto-cpp/blob/78e3ed8dc7a0901fe6d62f4e99becc6e7936adfd/src/starkware/crypto/ecdsa.cc#L23)
/// an upper bound of `0x0800000000000000000000000000000000000000000000000000000000000000`
/// (i.e. `2^251`).
///
/// In this module, values that are greater than or equal to this bound are rejected with an
/// error rather than reduced: message hashes and `r` values in `sign`, `verify` and `recover`,
/// the computed `s` in `sign`, and `s` (and its inverse `w`) in `verify`.
// NOTE(review): `Felt::from_raw` presumably takes the field element's internal (Montgomery)
// limb representation rather than plain integer limbs, which is why these limbs do not look
// like `2^251` — confirm they decode to the hex value documented above.
const ELEMENT_UPPER_BOUND: Felt = Felt::from_raw([
    576459263475450960,
    18446744073709255680,
    160989183,
    18446743986131435553,
]);
24
/// Stark ECDSA signature, consisting of the `(r, s)` pair.
///
/// [`sign`] returns an [`ExtendedSignature`], which converts into this type by dropping `v`;
/// [`verify`] checks an `(r, s)` pair against a public key and message hash.
#[derive(Debug)]
pub struct Signature {
    /// The `r` value of a signature: the x-coordinate of `k * G` for the signing nonce `k`.
    pub r: Felt,
    /// The `s` value of a signature
    pub s: Felt,
}
33
/// Stark ECDSA signature with `v`, useful for recovering the public key (see [`recover`]).
#[derive(Debug)]
pub struct ExtendedSignature {
    /// The `r` value of a signature
    pub r: Felt,
    /// The `s` value of a signature
    pub s: Felt,
    /// The `v` value of a signature. As produced by [`sign`], this is the parity (`0` or `1`)
    /// of the y-coordinate of `k * G`; [`recover`] rejects any value greater than `1`.
    pub v: Felt,
}
44
45impl From<ExtendedSignature> for Signature {
46    fn from(value: ExtendedSignature) -> Self {
47        Self {
48            r: value.r,
49            s: value.s,
50        }
51    }
52}
53
#[cfg(feature = "signature-display")]
impl core::fmt::Display for Signature {
    /// Writes `r` then `s`, each as the `hex::encode` of its 32 big-endian bytes, with no
    /// prefix or separator between them.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let r_hex = hex::encode(self.r.to_bytes_be());
        let s_hex = hex::encode(self.s.to_bytes_be());
        write!(f, "{}{}", r_hex, s_hex)
    }
}
65
#[cfg(feature = "signature-display")]
impl core::fmt::Display for ExtendedSignature {
    /// Writes `r` and `s` (each the `hex::encode` of its 32 big-endian bytes), followed by `v`
    /// formatted as lowercase hex padded to at least two digits.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let encoded_r = hex::encode(self.r.to_bytes_be());
        let encoded_s = hex::encode(self.s.to_bytes_be());
        write!(f, "{}{}{:02x}", encoded_r, encoded_s, self.v)
    }
}
78
79/// Computes the public key given a Stark private key.
80///
81/// ### Parameters
82///
83/// - `private_key`: The private key.
84pub fn get_public_key(private_key: &Felt) -> Felt {
85    mul_by_bits(&GENERATOR, private_key)
86        .to_affine()
87        .unwrap()
88        .x()
89}
90
91/// Computes ECDSA signature given a Stark private key and message hash.
92///
93/// ### Parameters
94///
95/// - `private_key`: The private key.
96/// - `message`: The message hash.
97/// - `k`: A random `k` value. You **MUST NOT** use the same `k` on different signatures.
98pub fn sign(private_key: &Felt, message: &Felt, k: &Felt) -> Result<ExtendedSignature, SignError> {
99    if message >= &ELEMENT_UPPER_BOUND {
100        return Err(SignError::InvalidMessageHash);
101    }
102    if k == &Felt::ZERO {
103        return Err(SignError::InvalidK);
104    }
105
106    let full_r = mul_by_bits(&GENERATOR, k).to_affine().unwrap();
107    let r = full_r.x();
108    if r == Felt::ZERO || r >= ELEMENT_UPPER_BOUND {
109        return Err(SignError::InvalidK);
110    }
111
112    let k_inv = mod_inverse(k, &EC_ORDER);
113
114    let s = mul_mod_floor(&r, private_key, &EC_ORDER);
115    let s = add_unbounded(&s, message);
116    let s = bigint_mul_mod_floor(s, &k_inv, &EC_ORDER);
117    if s == Felt::ZERO || s >= ELEMENT_UPPER_BOUND {
118        return Err(SignError::InvalidK);
119    }
120
121    Ok(ExtendedSignature {
122        r,
123        s,
124        v: (full_r.y().to_bigint() & Felt::ONE.to_bigint()).into(),
125    })
126}
127
128/// Verifies if a signature is valid over a message hash given a public key. Returns an error
129/// instead of `false` if the public key is invalid.
130///
131/// ### Parameters
132///
133/// - `public_key`: The public key.
134/// - `message`: The message hash.
135/// - `r`: The `r` value of the signature.
136/// - `s`: The `s` value of the signature.
137pub fn verify(public_key: &Felt, message: &Felt, r: &Felt, s: &Felt) -> Result<bool, VerifyError> {
138    if message >= &ELEMENT_UPPER_BOUND {
139        return Err(VerifyError::InvalidMessageHash);
140    }
141    if r == &Felt::ZERO || r >= &ELEMENT_UPPER_BOUND {
142        return Err(VerifyError::InvalidR);
143    }
144    if s == &Felt::ZERO || s >= &ELEMENT_UPPER_BOUND {
145        return Err(VerifyError::InvalidS);
146    }
147
148    let full_public_key = AffinePoint::new(
149        *public_key,
150        (public_key.square() * public_key + ALPHA * public_key + BETA)
151            .sqrt()
152            .ok_or(VerifyError::InvalidPublicKey)?,
153    )
154    .unwrap();
155
156    let w = mod_inverse(s, &EC_ORDER);
157    if w == Felt::ZERO || w >= ELEMENT_UPPER_BOUND {
158        return Err(VerifyError::InvalidS);
159    }
160
161    let zw = mul_mod_floor(message, &w, &EC_ORDER);
162    let zw_g = mul_by_bits(&GENERATOR, &zw);
163
164    let rw = mul_mod_floor(r, &w, &EC_ORDER);
165    let rw_q = mul_by_bits(&full_public_key, &rw);
166
167    Ok((&zw_g + &rw_q).to_affine().unwrap().x() == *r
168        || (&zw_g - &rw_q).to_affine().unwrap().x() == *r)
169}
170
171/// Recovers the public key from a message and (r, s, v) signature parameters
172///
173/// ### Parameters
174///
175/// - `msg_hash`: The message hash.
176/// - `r_bytes`: The `r` value of the signature.
177/// - `s_bytes`: The `s` value of the signature.
178/// - `v_bytes`: The `v` value of the signature.
179pub fn recover(message: &Felt, r: &Felt, s: &Felt, v: &Felt) -> Result<Felt, RecoverError> {
180    if message >= &ELEMENT_UPPER_BOUND {
181        return Err(RecoverError::InvalidMessageHash);
182    }
183    if r == &Felt::ZERO || r >= &ELEMENT_UPPER_BOUND {
184        return Err(RecoverError::InvalidR);
185    }
186    if s == &Felt::ZERO || s >= &EC_ORDER {
187        return Err(RecoverError::InvalidS);
188    }
189    if v > &Felt::ONE {
190        return Err(RecoverError::InvalidV);
191    }
192
193    let full_r = AffinePoint::new(
194        *r,
195        (r * r * r + ALPHA * r + BETA)
196            .sqrt()
197            .ok_or(RecoverError::InvalidR)?,
198    )
199    .unwrap();
200
201    let mut full_r_y = full_r.y();
202
203    let mut bits = [false; 256];
204
205    for (i, (&a, &b)) in full_r
206        .y()
207        .to_bits_le()
208        .iter()
209        .zip(Felt::ONE.to_bits_le().iter())
210        .enumerate()
211    {
212        bits[i] = a && b;
213    }
214
215    if bits != v.to_bits_le() {
216        full_r_y = -full_r.y();
217    }
218
219    let full_rs = mul_by_bits(&AffinePoint::new(full_r.x(), full_r_y).unwrap(), s);
220    let zg = mul_by_bits(&GENERATOR, message);
221
222    let r_inv = mod_inverse(r, &EC_ORDER);
223
224    let rs_zg = &full_rs - &zg;
225
226    let k = mul_by_bits(&rs_zg.to_affine().unwrap(), &r_inv);
227
228    Ok(k.to_affine().unwrap().x())
229}
230
231#[inline(always)]
232fn mul_by_bits(x: &AffinePoint, y: &Felt) -> ProjectivePoint {
233    &ProjectivePoint::from_affine(x.x(), x.y()).unwrap() * *y
234}
235
#[cfg(test)]
mod tests {
    #[cfg(not(feature = "std"))]
    use alloc::collections::BTreeMap;
    #[cfg(feature = "std")]
    use std::collections::BTreeMap;

    use super::*;
    use crate::test_utils::field_element_from_be_hex;

    // Test cases ported from:
    //   https://github.com/starkware-libs/crypto-cpp/blob/95864fbe11d5287e345432dbe1e80dea3c35fc58/src/starkware/crypto/ffi/crypto_lib_test.go

    /// Known private/public key pair from the crypto-cpp test vectors.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_get_public_key_1() {
        let private_key = field_element_from_be_hex(
            "03c1e9550e66958296d11b60f8e8e7a7ad990d07fa65d5f7652c4a6c87d4e3cc",
        );
        let expected_public_key = field_element_from_be_hex(
            "077a3b314db07c45076d11f62b6f9e748a39790441823307743cf00d6597ea43",
        );

        assert_eq!(get_public_key(&private_key), expected_public_key);
    }

    /// Small private key (`0x12`) with its known public key.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_get_public_key_2() {
        let private_key = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000012",
        );
        let expected_public_key = field_element_from_be_hex(
            "019661066e96a8b9f06a1d136881ee924dfb6a885239caa5fd3f87a54c6b25c4",
        );

        assert_eq!(get_public_key(&private_key), expected_public_key);
    }

    /// Checks every private -> public key pair in the bundled precomputed-keys JSON file.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_get_public_keys_from_json() {
        // Precomputed keys can be found here:
        // https://github.com/starkware-libs/starkex-for-spot-trading/blob/607f0b4ce507e1d95cd018d206a2797f6ba4aab4/src/starkware/crypto/starkware/crypto/signature/src/config/keys_precomputed.json

        // Reading the JSON file
        let json_data = include_str!("../test-data/keys_precomputed.json");

        // Parsing the JSON
        let key_map: BTreeMap<String, String> =
            serde_json::from_str(json_data).expect("Unable to parse the JSON");

        // Iterating over each element in the JSON
        for (private_key, expected_public_key) in key_map {
            // Strip any `0x` prefix and left-pad odd-length hex with a `0` (presumably
            // `field_element_from_be_hex` expects an even digit count — TODO confirm). The
            // parity check runs before stripping, but `0x` has even length, so it matches the
            // parity of the bare digits.
            let private_key = if private_key.len() % 2 != 0 {
                format!("0{}", private_key.trim_start_matches("0x"))
            } else {
                private_key.trim_start_matches("0x").to_owned()
            };

            let expected_public_key = if expected_public_key.len() % 2 != 0 {
                format!("0{}", expected_public_key.trim_start_matches("0x"))
            } else {
                expected_public_key.trim_start_matches("0x").to_owned()
            };

            // Assertion (the prefixes were already stripped above, so these second
            // `trim_start_matches` calls are no-ops)
            assert_eq!(
                get_public_key(&field_element_from_be_hex(
                    private_key.trim_start_matches("0x")
                )),
                field_element_from_be_hex(expected_public_key.trim_start_matches("0x"))
            );
        }
    }

    /// A known-good signature must verify.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_verify_valid_message() {
        let stark_key = field_element_from_be_hex(
            "01ef15c18599971b7beced415a40f0c7deacfd9b0d1819e03d723d8bc943cfca",
        );
        let msg_hash = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000002",
        );
        let r_bytes = field_element_from_be_hex(
            "0411494b501a98abd8262b0da1351e17899a0c4ef23dd2f96fec5ba847310b20",
        );
        let s_bytes = field_element_from_be_hex(
            "0405c3191ab3883ef2b763af35bc5f5d15b3b4e99461d70e84c654a351a7c81b",
        );

        assert!(verify(&stark_key, &msg_hash, &r_bytes, &s_bytes).unwrap());
    }

    /// A signature that does not match the message/key must return `Ok(false)`, not an error.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_verify_invalid_message() {
        let stark_key = field_element_from_be_hex(
            "077a4b314db07c45076d11f62b6f9e748a39790441823307743cf00d6597ea43",
        );
        let msg_hash = field_element_from_be_hex(
            "0397e76d1667c4454bfb83514e120583af836f8e32a516765497823eabe16a3f",
        );
        let r_bytes = field_element_from_be_hex(
            "0173fd03d8b008ee7432977ac27d1e9d1a1f6c98b1a2f05fa84a21c84c44e882",
        );
        let s_bytes = field_element_from_be_hex(
            "01f2c44a7798f55192f153b4c48ea5c1241fbb69e6132cc8a0da9c5b62a4286e",
        );

        assert!(!verify(&stark_key, &msg_hash, &r_bytes, &s_bytes).unwrap());
    }

    /// A public key with no matching y on the curve must yield `InvalidPublicKey`.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_verify_invalid_public_key() {
        let stark_key = field_element_from_be_hex(
            "03ee9bffffffffff26ffffffff60ffffffffffffffffffffffffffff004accff",
        );
        let msg_hash = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000002",
        );
        let r_bytes = field_element_from_be_hex(
            "0411494b501a98abd8262b0da1351e17899a0c4ef23dd2f96fec5ba847310b20",
        );
        let s_bytes = field_element_from_be_hex(
            "0405c3191ab3883ef2b763af35bc5f5d15b3b4e99461d70e84c654a351a7c81b",
        );

        match verify(&stark_key, &msg_hash, &r_bytes, &s_bytes) {
            Err(VerifyError::InvalidPublicKey) => {}
            _ => panic!("unexpected result"),
        }
    }

    /// Round trip: a freshly produced signature verifies against the signer's public key.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_sign() {
        let private_key = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000001",
        );
        let message = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000002",
        );
        let k = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000003",
        );

        let signature = sign(&private_key, &message, &k).unwrap();
        let public_key = get_public_key(&private_key);

        assert!(verify(&public_key, &message, &signature.r, &signature.s).unwrap());
    }

    /// Round trip: recovering from `(r, s, v)` yields the signer's public key.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_recover() {
        let private_key = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000001",
        );
        let message = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000002",
        );
        let k = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000003",
        );

        let signature = sign(&private_key, &message, &k).unwrap();
        let public_key = recover(&message, &signature.r, &signature.s, &signature.v).unwrap();

        assert_eq!(get_public_key(&private_key), public_key);
    }

    /// An `r` that is not the x-coordinate of a curve point must yield `InvalidR`
    /// (same off-curve value as in `test_verify_invalid_public_key`).
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    fn test_recover_invalid_r() {
        let message = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000002",
        );
        let r = field_element_from_be_hex(
            "03ee9bffffffffff26ffffffff60ffffffffffffffffffffffffffff004accff",
        );
        let s = field_element_from_be_hex(
            "0405c3191ab3883ef2b763af35bc5f5d15b3b4e99461d70e84c654a351a7c81b",
        );
        let v = field_element_from_be_hex(
            "0000000000000000000000000000000000000000000000000000000000000000",
        );

        match recover(&message, &r, &s, &v) {
            Err(RecoverError::InvalidR) => {}
            _ => panic!("unexpected result"),
        }
    }
}