1use std::ops::{Add, Index, IndexMut, Mul, Sub};
2
3use rand_core::{CryptoRng, RngCore};
4use subtle::{Choice, ConstantTimeEq};
5
6use crate::constants;
7
/// A scalar stored as fourteen 32-bit limbs, least-significant limb first
/// (448 bits of storage in total).
///
/// The arithmetic in this file reduces its results modulo `MODULUS`.
#[derive(Debug, Copy, Clone)]
pub struct Scalar(pub(crate) [u32; 14]);
13
/// The modulus used by every reduction in this file, taken from
/// `constants::BASEPOINT_ORDER`.
pub(crate) const MODULUS: Scalar = constants::BASEPOINT_ORDER;
15
/// Conversion constant for `montgomery_multiply`: multiplying by it is how this
/// file cancels the extra factor a Montgomery multiplication introduces.
///
/// NOTE(review): presumably `R^2 mod MODULUS` with `R = 2^448` — confirm
/// against the reference implementation.
const R2: Scalar = Scalar([
    0x049b9b60, 0xe3539257, 0xc1b195d9, 0x7af32c4b, 0x88ea1859, 0x0d66de23, 0x5ee4d838, 0xae17cf72,
    0xa3c47c44, 0x1a9cc14b, 0xe4d070af, 0x2052bcb7, 0xf823b729, 0x3402a939,
]);
/// Constant that `Scalar::from_bytes_mod_order_wide` Montgomery-multiplies the
/// low 56 input bytes by in order to reduce them.
///
/// NOTE(review): presumably `2^448 mod MODULUS` (the Montgomery radix) —
/// confirm against the reference implementation.
const R: Scalar = Scalar([
    0x529eec34, 0x721cf5b5, 0xc8e9c2ab, 0x7a4cf635, 0x44a725bf, 0xeec492d9, 0xcd77058, 0x2, 0, 0,
    0, 0, 0, 0,
]);
25
26impl ConstantTimeEq for Scalar {
27 fn ct_eq(&self, other: &Self) -> Choice {
28 self.to_bytes().ct_eq(&other.to_bytes())
29 }
30}
31
32impl PartialEq for Scalar {
33 fn eq(&self, other: &Scalar) -> bool {
34 self.ct_eq(&other).into()
35 }
36}
// Marker impl: `eq` above compares the complete byte encodings, so equality
// is reflexive, symmetric and transitive.
impl Eq for Scalar {}
38
39impl From<u32> for Scalar {
40 fn from(a: u32) -> Scalar {
41 Scalar([a, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
42 }
43}
44
45impl Index<usize> for Scalar {
46 type Output = u32;
47 fn index(&self, index: usize) -> &Self::Output {
48 &self.0[index]
49 }
50}
51impl IndexMut<usize> for Scalar {
52 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
53 &mut self.0[index]
54 }
55}
56
57impl Add<Scalar> for Scalar {
60 type Output = Scalar;
61 fn add(self, rhs: Scalar) -> Self::Output {
62 add(&self, &rhs)
63 }
64}
65impl Mul<Scalar> for Scalar {
66 type Output = Scalar;
67 fn mul(self, rhs: Scalar) -> Self::Output {
68 let unreduced = montgomery_multiply(&self, &rhs);
69 montgomery_multiply(&unreduced, &R2)
70 }
71}
72impl Sub<Scalar> for Scalar {
73 type Output = Scalar;
74 fn sub(self, rhs: Scalar) -> Self::Output {
75 sub_extra(&self, &rhs, 0)
76 }
77}
78impl Default for Scalar {
79 fn default() -> Scalar {
80 Scalar::zero()
81 }
82}
83
84impl Scalar {
85 pub const fn one() -> Scalar {
86 Scalar([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
87 }
88 pub const fn zero() -> Scalar {
89 Scalar([0; 14])
90 }
91 pub(crate) fn div_by_four(&mut self) {
95 for i in 0..=12 {
96 self.0[i] = (self.0[i + 1] << 30) | (self.0[i] >> 2);
97 }
98 self.0[13] >>= 2
99 }
100 pub(crate) fn to_radix_16(&self) -> [i8; 113] {
105 let bytes = self.to_bytes();
106 let mut output = [0i8; 113];
107
108 #[inline(always)]
111 fn bot_half(x: u8) -> u8 {
112 (x >> 0) & 15
113 }
114 #[inline(always)]
115 fn top_half(x: u8) -> u8 {
116 (x >> 4) & 15
117 }
118
119 for i in 0..56 {
121 output[2 * i] = bot_half(bytes[i]) as i8;
122 output[2 * i + 1] = top_half(bytes[i]) as i8;
123 }
124 for i in 0..112 {
126 let carry = (output[i] + 8) >> 4;
127 output[i] -= carry << 4;
128 output[i + 1] += carry;
129 }
130
131 output
132 }
133 pub fn bits(&self) -> Vec<bool> {
135 let mut bits: Vec<bool> = Vec::with_capacity(14 * 32);
136 for limb in self.0.iter() {
139 for j in 0..32 {
141 bits.push(limb & (1 << j) != 0)
142 }
143 }
144
145 bits
147 }
148 pub fn from_bytes(bytes: [u8; 56]) -> Scalar {
149 let load7 = |input: &[u8]| -> u64 {
150 (input[0] as u64)
151 | ((input[1] as u64) << 8)
152 | ((input[2] as u64) << 16)
153 | ((input[3] as u64) << 24)
154 };
155
156 let mut res = Scalar::zero();
157 for i in 0..14 {
158 let out = load7(&bytes[i * 4..]);
160 res[i] = out as u32;
161 }
162
163 res
164 }
165 pub fn to_bytes(&self) -> [u8; 56] {
166 let mut res = [0u8; 56];
167
168 for i in 0..14 {
169 let mut l = self.0[i];
170 for j in 0..4 {
171 res[4 * i + j] = l as u8;
172 l >>= 8;
173 }
174 }
175 res
176 }
177 fn square(&self) -> Scalar {
178 montgomery_multiply(&self, &self)
179 }
180 pub fn invert(&self) -> Self {
181 let mut pre_comp: Vec<Scalar> = vec![Scalar::zero(); 8];
182 let mut result = Scalar::zero();
183
184 let scalar_window_bits = 3;
185 let last = (1 << scalar_window_bits) - 1;
186
187 pre_comp[0] = montgomery_multiply(self, &R2);
189
190 if last > 0 {
191 pre_comp[last] = montgomery_multiply(&pre_comp[0], &pre_comp[0]);
192 }
193
194 for i in 1..=last {
195 pre_comp[i] = montgomery_multiply(&pre_comp[i - 1], &pre_comp[last])
196 }
197
198 let mut residue: usize = 0;
200 let mut trailing: usize = 0;
201 let mut started: usize = 0;
202
203 let loop_start = -scalar_window_bits as isize;
205 let loop_end = 446 - 1;
206 for i in (loop_start..=loop_end).rev() {
207 if started != 0 {
208 result = result.square()
209 }
210
211 let mut w: u32;
212 if i >= 0 {
213 w = MODULUS[(i / 32) as usize];
214 } else {
215 w = 0;
216 }
217
218 if i >= 0 && i < 32 {
219 w -= 2
220 }
221
222 residue = (((residue as u32) << 1) | ((w >> ((i as u32) % 32)) & 1)) as usize;
223 if residue >> scalar_window_bits != 0 {
224 trailing = residue;
225 residue = 0
226 }
227
228 if trailing > 0 && (trailing & ((1 << scalar_window_bits) - 1)) == 0 {
229 if started != 0 {
230 result = montgomery_multiply(
231 &result,
232 &pre_comp[trailing >> (scalar_window_bits + 1)],
233 )
234 } else {
235 result = pre_comp[trailing >> (scalar_window_bits + 1)];
236 started = 1
237 }
238 trailing = 0
239 }
240 trailing <<= 1
241 }
242
243 montgomery_multiply(&result, &Scalar::one())
246 }
247
248 pub fn halve(&self) -> Self {
250 let mut result = Scalar::zero();
251
252 let mask = 0u32.wrapping_sub(self[0] & 1);
253 let mut chain = 0u64;
254
255 for i in 0..14 {
256 chain += (self[i] as u64) + ((MODULUS[i] & mask) as u64);
257 result[i] = chain as u32;
258 chain >>= 32
259 }
260
261 for i in 0..13 {
262 result[i] = (result[i] >> 1) | (result[i + 1] << 31);
263 }
264 result[13] = (result[13] >> 1) | ((chain << 31) as u32);
265
266 result
267 }
268
269 pub fn from_canonical_bytes(bytes: [u8; 57]) -> Option<Scalar> {
277 if bytes[56] != 0u8 || (bytes[55] >> 6) != 0u8 {
279 return None;
280 }
281 let bytes: [u8; 56] = std::array::from_fn(|i| bytes[i]);
282 let candidate = Scalar::from_bytes(bytes);
283
284 let reduced = sub_extra(&candidate, &MODULUS, 0);
285
286 if candidate == reduced {
287 Some(candidate)
288 } else {
289 None
290 }
291 }
292
293 pub fn to_bytes_rfc_8032(&self) -> [u8; 57] {
296 let bytes = self.to_bytes();
297 let res: [u8; 57] = std::array::from_fn(|i| if i < 56 { bytes[i] } else { 0 });
298 res
299 }
300
301 pub fn from_bytes_mod_order_wide(input: &[u8; 114]) -> Scalar {
304 let lo: [u8; 56] = std::array::from_fn(|i| input[i]);
305 let lo = Scalar::from_bytes(lo);
306 let lo = montgomery_multiply(&lo, &R);
309
310 let hi: [u8; 56] = std::array::from_fn(|i| input[i + 56]);
311 let hi = Scalar::from_bytes(hi);
312 let hi = montgomery_multiply(&hi, &R2);
314
315 let top: [u8; 56] = std::array::from_fn(|i| if i < 2 { input[i + 112] } else { 0 });
317 let top = Scalar::from_bytes(top);
318 let top = montgomery_multiply(&top, &R2);
320 let top = montgomery_multiply(&top, &R2);
322
323 add(&lo, &hi).add(top)
325 }
326
327 pub fn random<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
337 let mut scalar_bytes = [0u8; 114];
338 rng.fill_bytes(&mut scalar_bytes);
339 Scalar::from_bytes_mod_order_wide(&scalar_bytes)
340 }
341}
342pub fn add(a: &Scalar, b: &Scalar) -> Scalar {
344 let mut result = Scalar::zero();
348
349 let mut chain = 0u64;
351 for i in 0..14 {
353 chain += (a[i] as u64) + (b[i] as u64);
354 result[i] = chain as u32;
356 chain >>= 32;
358 }
359
360 sub_extra(&result, &MODULUS, chain as u32)
362}
363
364fn sub_extra(a: &Scalar, b: &Scalar, carry: u32) -> Scalar {
367 let mut result = Scalar::zero();
368
369 let mut chain = 0i64;
371 for i in 0..14 {
372 chain += a[i] as i64 - b[i] as i64;
373 result[i] = chain as u32;
375 chain >>= 32
377 }
378
379 let borrow = chain + (carry as i64);
385 assert!(borrow == -1 || borrow == 0);
386
387 chain = 0i64;
388 for i in 0..14 {
389 chain += (result[i] as i64) + ((MODULUS[i] as i64) & borrow);
390 result[i] = chain as u32;
392 chain >>= 32;
394 }
395
396 result
397}
398
399fn montgomery_multiply(x: &Scalar, y: &Scalar) -> Scalar {
400 const MONTGOMERY_FACTOR: u32 = 0xae918bc5;
401
402 let mut result = Scalar::zero();
403 let mut carry = 0u32;
404
405 let mul_add = |a: u32, b: u32, c: u32| -> u64 { ((a as u64) * (b as u64)) + (c as u64) };
407
408 for i in 0..14 {
409 let mut chain = 0u64;
410 for j in 0..14 {
411 chain += mul_add(x[i], y[j], result[j]);
412 result[j] = chain as u32;
413 chain >>= 32;
414 }
415
416 let saved = chain as u32;
417 let multiplicand = result[0].wrapping_mul(MONTGOMERY_FACTOR);
418 chain = 0u64;
419
420 for j in 0..14 {
421 chain += mul_add(multiplicand, MODULUS[j], result[j]);
422 if j > 0 {
423 result[j - 1] = chain as u32;
424 }
425 chain >>= 32;
426 }
427 chain += (saved as u64) + (carry as u64);
428 result[14 - 1] = chain as u32;
429 carry = (chain >> 32) as u32;
430 }
431
432 sub_extra(&result, &MODULUS, carry)
433}
434#[cfg(test)]
435mod test {
436 use std::convert::TryInto;
437
438 use super::*;
439 #[test]
440 fn test_basic_add() {
441 let five = Scalar::from(5);
442 let six = Scalar::from(6);
443
444 assert_eq!(five + six, Scalar::from(11))
445 }
446
447 #[test]
448 fn test_basic_sub() {
449 let ten = Scalar::from(10);
450 let five = Scalar::from(5);
451 assert_eq!(ten - five, Scalar::from(5))
452 }
453
454 #[test]
455 fn test_basic_mul() {
456 let ten = Scalar::from(10);
457 let five = Scalar::from(5);
458
459 assert_eq!(ten * five, Scalar::from(50))
460 }
461
462 #[test]
463 fn test_mul() {
464 let a = Scalar([
465 0xffb823a3, 0xc96a3c35, 0x7f8ed27d, 0x087b8fb9, 0x1d9ac30a, 0x74d65764, 0xc0be082e,
466 0xa8cb0ae8, 0xa8fa552b, 0x2aae8688, 0x2c3dc273, 0x47cf8cac, 0x3b089f07, 0x1e63e807,
467 ]);
468
469 let b = Scalar([
470 0xd8bedc42, 0x686eb329, 0xe416b899, 0x17aa6d9b, 0x1e30b38b, 0x188c6b1a, 0xd099595b,
471 0xbc343bcb, 0x1adaa0e7, 0x24e8d499, 0x8e59b308, 0x0a92de2d, 0xcae1cb68, 0x16c5450a,
472 ]);
473
474 let exp = Scalar([
475 0xa18d010a, 0x1f5b3197, 0x994c9c2b, 0x6abd26f5, 0x08a3a0e4, 0x36a14920, 0x74e9335f,
476 0x07bcd931, 0xf2d89c1e, 0xb9036ff6, 0x203d424b, 0xfccd61b3, 0x4ca389ed, 0x31e055c1,
477 ]);
478
479 assert_eq!(a * b, exp)
480 }
481 #[test]
482 fn test_basic_square() {
483 let a = Scalar([
484 0xcf5fac3d, 0x7e56a34b, 0xf640922b, 0x3fa50692, 0x1370f8b8, 0x6f08f331, 0x8dccc486,
485 0x4bb395e0, 0xf22c6951, 0x21cc3078, 0xd2391f9d, 0x930392e5, 0x04b3273b, 0x31620816,
486 ]);
487 let expected_a_squared = Scalar([
488 0x15598f62, 0xb9b1ed71, 0x52fcd042, 0x862a9f10, 0x1e8a309f, 0x9988f8e0, 0xa22347d7,
489 0xe9ab2c22, 0x38363f74, 0xfd7c58aa, 0xc49a1433, 0xd9a6c4c3, 0x75d3395e, 0x0d79f6e3,
490 ]);
491
492 assert_eq!(a.square(), expected_a_squared)
493 }
494
495 #[test]
496 fn test_sanity_check_index_mut() {
497 let mut x = Scalar::one();
498 x[0] = 2u32;
499 assert_eq!(x, Scalar::from(2))
500 }
501 #[test]
502 fn test_basic_halving() {
503 let eight = Scalar::from(8);
504 let four = Scalar::from(4);
505 let two = Scalar::from(2);
506 assert_eq!(eight.halve(), four);
507 assert_eq!(four.halve(), two);
508 assert_eq!(two.halve(), Scalar::one());
509 }
510
511 #[test]
512 fn test_equals() {
513 let a = Scalar::from(5);
514 let b = Scalar::from(5);
515 let c = Scalar::from(10);
516 assert!(a == b);
517 assert!(!(a == c))
518 }
519
520 #[test]
521 fn test_basic_inversion() {
522 for i in 1..=100 {
524 let x = Scalar::from(i);
525 let x_inv = x.invert();
526 assert_eq!(x_inv * x, Scalar::one())
527 }
528
529 let zero = Scalar::zero();
531 let expected_zero = zero.invert();
532 assert_eq!(expected_zero, zero)
533 }
534 #[test]
535 fn test_serialise() {
536 let scalar = Scalar([
537 0x15598f62, 0xb9b1ed71, 0x52fcd042, 0x862a9f10, 0x1e8a309f, 0x9988f8e0, 0xa22347d7,
538 0xe9ab2c22, 0x38363f74, 0xfd7c58aa, 0xc49a1433, 0xd9a6c4c3, 0x75d3395e, 0x0d79f6e3,
539 ]);
540 let got = Scalar::from_bytes(scalar.to_bytes());
541 assert_eq!(scalar, got)
542 }
543 #[test]
544 fn test_debug() {
545 let k = Scalar([
546 200, 210, 250, 145, 130, 180, 147, 122, 222, 230, 214, 247, 203, 32,
547 ]);
548 let s = k;
549 dbg!(&s.to_radix_16()[..]);
550 }
551 #[test]
552 fn test_from_canonical_bytes() {
553 let mut bytes: [u8; 57] = hex::decode("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").unwrap().try_into().unwrap();
555 bytes.reverse();
556 let s = Scalar::from_canonical_bytes(bytes);
557 assert_eq!(s, None);
558
559 let mut bytes: [u8; 57] = hex::decode("003fffffffffffffffffffffffffffffffffffffffffffffffffffffff7cca23e9c44edb49aed63690216cc2728dc58f552378c292ab5844f3").unwrap().try_into().unwrap();
561 bytes.reverse();
562 let s = Scalar::from_canonical_bytes(bytes);
563 assert_eq!(s, None);
564
565 let mut bytes: [u8; 57] = hex::decode("003fffffffffffffffffffffffffffffffffffffffffffffffffffffff7cca23e9c44edb49aed63690216cc2728dc58f552378c292ab5844f2").unwrap().try_into().unwrap();
567 bytes.reverse();
568 let s = Scalar::from_canonical_bytes(bytes);
569 match s {
570 Some(s) => assert_eq!(s, Scalar::zero() - Scalar::one()),
571 None => panic!("should not return None"),
572 };
573 }
574
575 #[test]
576 fn test_from_bytes_mod_order_wide() {
577 let mut bytes: [u8; 114] = hex::decode("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003fffffffffffffffffffffffffffffffffffffffffffffffffffffff7cca23e9c44edb49aed63690216cc2728dc58f552378c292ab5844f3").unwrap().try_into().unwrap();
579 bytes.reverse();
580 let s = Scalar::from_bytes_mod_order_wide(&bytes);
581 assert_eq!(s, Scalar::zero());
582
583 let mut bytes: [u8; 114] = hex::decode("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003fffffffffffffffffffffffffffffffffffffffffffffffffffffff7cca23e9c44edb49aed63690216cc2728dc58f552378c292ab5844f2").unwrap().try_into().unwrap();
585 bytes.reverse();
586 let s = Scalar::from_bytes_mod_order_wide(&bytes);
587 assert_eq!(s, Scalar::zero() - Scalar::one());
588
589 let mut bytes: [u8; 114] = hex::decode("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003fffffffffffffffffffffffffffffffffffffffffffffffffffffff7cca23e9c44edb49aed63690216cc2728dc58f552378c292ab5844f4").unwrap().try_into().unwrap();
591 bytes.reverse();
592 let s = Scalar::from_bytes_mod_order_wide(&bytes);
593 assert_eq!(s, Scalar::one());
594
595 let mut bytes: [u8; 114] = hex::decode("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").unwrap().try_into().unwrap();
597 bytes.reverse();
598 let s = Scalar::from_bytes_mod_order_wide(&bytes);
599 let mut bytes: [u8; 57] = hex::decode("002939f823b7292052bcb7e4d070af1a9cc14ba3c47c44ae17cf72c985bb24b6c520e319fb37a63e29800f160787ad1d2e11883fa931e7de81").unwrap().try_into().unwrap();
600 bytes.reverse();
601 let reduced = Scalar::from_canonical_bytes(bytes).unwrap();
602 assert_eq!(s, reduced);
603 }
604
605 #[test]
606 fn test_to_bytes_rfc8032() {
607 let mut bytes: [u8; 57] = hex::decode("003fffffffffffffffffffffffffffffffffffffffffffffffffffffff7cca23e9c44edb49aed63690216cc2728dc58f552378c292ab5844f2").unwrap().try_into().unwrap();
609 bytes.reverse();
610 let x = Scalar::zero() - Scalar::one();
611 let candidate = x.to_bytes_rfc_8032();
612 assert_eq!(bytes, candidate);
613 }
614}