1#![allow(clippy::needless_range_loop)]
14
15#[macro_use]
16mod macros;
17
18#[cfg(feature = "alloc")]
19pub(crate) mod boxed;
20
21use crate::{ConstChoice, ConstCtOption, Inverter, Limb, Odd, Uint, Word};
22use subtle::CtOption;
23
24#[derive(Clone, Debug)]
49pub struct SafeGcdInverter<const SAT_LIMBS: usize, const UNSAT_LIMBS: usize> {
50 pub(super) modulus: UnsatInt<UNSAT_LIMBS>,
52
53 adjuster: UnsatInt<UNSAT_LIMBS>,
55
56 inverse: i64,
58}
59
/// Transition matrix for one batch of 62 divsteps, as produced by `jump`. Its entries are scaled
/// by 2^62, so applying it and then dropping one 62-bit limb (see `fg`) yields the updated values.
type Matrix = [[i64; 2]; 2];
62
63impl<const SAT_LIMBS: usize, const UNSAT_LIMBS: usize> SafeGcdInverter<SAT_LIMBS, UNSAT_LIMBS> {
64 pub const fn new(modulus: &Odd<Uint<SAT_LIMBS>>, adjuster: &Uint<SAT_LIMBS>) -> Self {
68 Self {
69 modulus: UnsatInt::from_uint(&modulus.0),
70 adjuster: UnsatInt::from_uint(adjuster),
71 inverse: inv_mod2_62(modulus.0.as_words()),
72 }
73 }
74
75 pub const fn inv(&self, value: &Uint<SAT_LIMBS>) -> ConstCtOption<Uint<SAT_LIMBS>> {
78 let (d, f) = divsteps(
79 self.adjuster,
80 self.modulus,
81 UnsatInt::from_uint(value),
82 self.inverse,
83 );
84
85 let antiunit = f.eq(&UnsatInt::MINUS_ONE);
89 let ret = self.norm(d, antiunit);
90 let is_some = f.eq(&UnsatInt::ONE).or(antiunit);
91 ConstCtOption::new(ret.to_uint(), is_some)
92 }
93
94 pub const fn inv_vartime(&self, value: &Uint<SAT_LIMBS>) -> ConstCtOption<Uint<SAT_LIMBS>> {
99 let (d, f) = divsteps_vartime(
100 self.adjuster,
101 self.modulus,
102 UnsatInt::from_uint(value),
103 self.inverse,
104 );
105
106 let antiunit = f.eq(&UnsatInt::MINUS_ONE);
110 let ret = self.norm(d, antiunit);
111 let is_some = f.eq(&UnsatInt::ONE).or(antiunit);
112 ConstCtOption::new(ret.to_uint(), is_some)
113 }
114
115 pub(crate) const fn gcd(f: &Uint<SAT_LIMBS>, g: &Uint<SAT_LIMBS>) -> Uint<SAT_LIMBS> {
121 let inverse = inv_mod2_62(f.as_words());
122 let e = UnsatInt::<UNSAT_LIMBS>::ONE;
123 let f = UnsatInt::from_uint(f);
124 let g = UnsatInt::from_uint(g);
125 let (_, mut f) = divsteps(e, f, g, inverse);
126 f = UnsatInt::select(&f, &f.neg(), f.is_negative());
127 f.to_uint()
128 }
129
130 pub(crate) const fn gcd_vartime(f: &Uint<SAT_LIMBS>, g: &Uint<SAT_LIMBS>) -> Uint<SAT_LIMBS> {
134 let inverse = inv_mod2_62(f.as_words());
135 let e = UnsatInt::<UNSAT_LIMBS>::ONE;
136 let f = UnsatInt::from_uint(f);
137 let g = UnsatInt::from_uint(g);
138 let (_, mut f) = divsteps_vartime(e, f, g, inverse);
139 f = UnsatInt::select(&f, &f.neg(), f.is_negative());
140 f.to_uint()
141 }
142
143 const fn norm(
147 &self,
148 mut value: UnsatInt<UNSAT_LIMBS>,
149 negate: ConstChoice,
150 ) -> UnsatInt<UNSAT_LIMBS> {
151 value = UnsatInt::select(&value, &value.add(&self.modulus), value.is_negative());
152 value = UnsatInt::select(&value, &value.neg(), negate);
153 value = UnsatInt::select(&value, &value.add(&self.modulus), value.is_negative());
154 value
155 }
156}
157
158impl<const SAT_LIMBS: usize, const UNSAT_LIMBS: usize> Inverter
159 for SafeGcdInverter<SAT_LIMBS, UNSAT_LIMBS>
160{
161 type Output = Uint<SAT_LIMBS>;
162
163 fn invert(&self, value: &Uint<SAT_LIMBS>) -> CtOption<Self::Output> {
164 self.inv(value).into()
165 }
166
167 fn invert_vartime(&self, value: &Uint<SAT_LIMBS>) -> CtOption<Self::Output> {
168 self.inv_vartime(value).into()
169 }
170}
171
172const fn inv_mod2_62(value: &[Word]) -> i64 {
182 let value = {
183 #[cfg(target_pointer_width = "32")]
184 {
185 debug_assert!(value.len() >= 1);
186 let mut ret = value[0] as u64;
187
188 if value.len() >= 2 {
189 ret |= (value[1] as u64) << 32;
190 }
191
192 ret
193 }
194
195 #[cfg(target_pointer_width = "64")]
196 {
197 value[0]
198 }
199 };
200
201 let x = value.wrapping_mul(3) ^ 2;
202 let y = 1u64.wrapping_sub(x.wrapping_mul(value));
203 let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
204 let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
205 let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
206 (x.wrapping_mul(y.wrapping_add(1)) & (u64::MAX >> 2)) as i64
207}
208
209const fn divsteps<const LIMBS: usize>(
215 mut e: UnsatInt<LIMBS>,
216 f_0: UnsatInt<LIMBS>,
217 mut g: UnsatInt<LIMBS>,
218 inverse: i64,
219) -> (UnsatInt<LIMBS>, UnsatInt<LIMBS>) {
220 let mut d = UnsatInt::ZERO;
221 let mut f = f_0;
222 let mut delta = 1;
223 let mut matrix;
224 let mut i = 0;
225 let m = iterations(f_0.bits(), g.bits());
226
227 while i < m {
228 (delta, matrix) = jump(&f.0, &g.0, delta);
229 (f, g) = fg(f, g, matrix);
230 (d, e) = de(&f_0, inverse, matrix, d, e);
231 i += 1;
232 }
233
234 debug_assert!(g.eq(&UnsatInt::ZERO).to_bool_vartime());
235 (d, f)
236}
237
238const fn divsteps_vartime<const LIMBS: usize>(
243 mut e: UnsatInt<LIMBS>,
244 f_0: UnsatInt<LIMBS>,
245 mut g: UnsatInt<LIMBS>,
246 inverse: i64,
247) -> (UnsatInt<LIMBS>, UnsatInt<LIMBS>) {
248 let mut d = UnsatInt::ZERO;
249 let mut f = f_0;
250 let mut delta = 1;
251 let mut matrix;
252
253 while !g.eq(&UnsatInt::ZERO).to_bool_vartime() {
254 (delta, matrix) = jump(&f.0, &g.0, delta);
255 (f, g) = fg(f, g, matrix);
256 (d, e) = de(&f_0, inverse, matrix, d, e);
257 }
258
259 (d, f)
260}
261
262const fn jump(f: &[u64], g: &[u64], mut delta: i64) -> (i64, Matrix) {
266 const fn min(a: i64, b: i64) -> i64 {
268 if a > b {
269 b
270 } else {
271 a
272 }
273 }
274
275 let (mut steps, mut f, mut g) = (62, f[0] as i64, g[0] as i128);
276 let mut t: Matrix = [[1, 0], [0, 1]];
277
278 loop {
279 let zeros = min(steps, g.trailing_zeros() as i64);
280 (steps, delta, g) = (steps - zeros, delta + zeros, g >> zeros);
281 t[0] = [t[0][0] << zeros, t[0][1] << zeros];
282
283 if steps == 0 {
284 break;
285 }
286 if delta > 0 {
287 (delta, f, g) = (-delta, g as i64, -f as i128);
288 (t[0], t[1]) = (t[1], [-t[0][0], -t[0][1]]);
289 }
290
291 let mask = (1 << min(min(steps, 1 - delta), 5)) - 1;
295 let w = (g as i64).wrapping_mul(f.wrapping_mul(3) ^ 28) & mask;
296
297 t[1] = [t[0][0] * w + t[1][0], t[0][1] * w + t[1][1]];
298 g += w as i128 * f as i128;
299 }
300
301 (delta, t)
302}
303
304const fn fg<const LIMBS: usize>(
309 f: UnsatInt<LIMBS>,
310 g: UnsatInt<LIMBS>,
311 t: Matrix,
312) -> (UnsatInt<LIMBS>, UnsatInt<LIMBS>) {
313 (
314 f.mul(t[0][0]).add(&g.mul(t[0][1])).shr(),
315 f.mul(t[1][0]).add(&g.mul(t[1][1])).shr(),
316 )
317}
318
319const fn de<const LIMBS: usize>(
327 modulus: &UnsatInt<LIMBS>,
328 inverse: i64,
329 t: Matrix,
330 d: UnsatInt<LIMBS>,
331 e: UnsatInt<LIMBS>,
332) -> (UnsatInt<LIMBS>, UnsatInt<LIMBS>) {
333 let mask = UnsatInt::<LIMBS>::MASK as i64;
334 let mut md =
335 t[0][0] * d.is_negative().to_u8() as i64 + t[0][1] * e.is_negative().to_u8() as i64;
336 let mut me =
337 t[1][0] * d.is_negative().to_u8() as i64 + t[1][1] * e.is_negative().to_u8() as i64;
338
339 let cd = t[0][0]
340 .wrapping_mul(d.lowest() as i64)
341 .wrapping_add(t[0][1].wrapping_mul(e.lowest() as i64))
342 & mask;
343
344 let ce = t[1][0]
345 .wrapping_mul(d.lowest() as i64)
346 .wrapping_add(t[1][1].wrapping_mul(e.lowest() as i64))
347 & mask;
348
349 md -= (inverse.wrapping_mul(cd).wrapping_add(md)) & mask;
350 me -= (inverse.wrapping_mul(ce).wrapping_add(me)) & mask;
351
352 let cd = d.mul(t[0][0]).add(&e.mul(t[0][1])).add(&modulus.mul(md));
353 let ce = d.mul(t[1][0]).add(&e.mul(t[1][1])).add(&modulus.mul(me));
354
355 (cd.shr(), ce.shr())
356}
357
358pub(crate) const fn iterations(f_bits: u32, g_bits: u32) -> usize {
368 let d = ConstChoice::from_u32_lt(f_bits, g_bits).select_u32(f_bits, g_bits);
370 let addend = ConstChoice::from_u32_lt(d, 46).select_u32(57, 80);
371 ((49 * d + addend) / 17) as usize
372}
373
374#[derive(Clone, Copy, Debug)]
379pub(super) struct UnsatInt<const LIMBS: usize>(pub [u64; LIMBS]);
380
381impl<const LIMBS: usize> UnsatInt<LIMBS> {
382 pub const LIMB_BITS: usize = 62;
384
385 pub const MASK: u64 = u64::MAX >> (64 - Self::LIMB_BITS);
387
388 pub const MINUS_ONE: Self = Self([Self::MASK; LIMBS]);
390
391 pub const ZERO: Self = Self([0; LIMBS]);
393
394 pub const ONE: Self = {
396 let mut ret = Self::ZERO;
397 ret.0[0] = 1;
398 ret
399 };
400
401 #[allow(trivial_numeric_casts)]
409 pub const fn from_uint<const SAT_LIMBS: usize>(input: &Uint<SAT_LIMBS>) -> Self {
410 if LIMBS != safegcd_nlimbs!(SAT_LIMBS * Limb::BITS as usize) {
411 panic!("incorrect number of limbs");
412 }
413
414 let mut output = [0; LIMBS];
415 impl_limb_convert!(Word, Word::BITS as usize, input.as_words(), u64, 62, output);
416
417 Self(output)
418 }
419
420 #[allow(trivial_numeric_casts, clippy::wrong_self_convention)]
428 pub const fn to_uint<const SAT_LIMBS: usize>(&self) -> Uint<SAT_LIMBS> {
429 debug_assert!(
430 !self.is_negative().to_bool_vartime(),
431 "can't convert negative number to Uint"
432 );
433
434 if LIMBS != safegcd_nlimbs!(SAT_LIMBS * Limb::BITS as usize) {
435 panic!("incorrect number of limbs");
436 }
437
438 let mut ret = [0 as Word; SAT_LIMBS];
439 impl_limb_convert!(u64, 62, &self.0, Word, Word::BITS as usize, ret);
440 Uint::from_words(ret)
441 }
442
443 pub const fn add(&self, other: &Self) -> Self {
445 let (mut ret, mut carry) = (Self::ZERO, 0);
446 let mut i = 0;
447
448 while i < LIMBS {
449 let sum = self.0[i] + other.0[i] + carry;
450 ret.0[i] = sum & Self::MASK;
451 carry = sum >> Self::LIMB_BITS;
452 i += 1;
453 }
454
455 ret
456 }
457
458 pub const fn mul(&self, other: i64) -> Self {
460 let mut ret = Self::ZERO;
461 let (other, mut carry, mask) = if other < 0 {
475 (-other, -other as u64, Self::MASK)
476 } else {
477 (other, 0, 0)
478 };
479
480 let mut i = 0;
481 while i < LIMBS {
482 let sum = (carry as u128) + ((self.0[i] ^ mask) as u128) * (other as u128);
483 ret.0[i] = sum as u64 & Self::MASK;
484 carry = (sum >> Self::LIMB_BITS) as u64;
485 i += 1;
486 }
487
488 ret
489 }
490
491 pub const fn neg(&self) -> Self {
493 let (mut ret, mut carry) = (Self::ZERO, 1);
496 let mut i = 0;
497
498 while i < LIMBS {
499 let sum = (self.0[i] ^ Self::MASK) + carry;
500 ret.0[i] = sum & Self::MASK;
501 carry = sum >> Self::LIMB_BITS;
502 i += 1;
503 }
504
505 ret
506 }
507
508 pub const fn shr(&self) -> Self {
510 let mut ret = Self::ZERO;
511 ret.0[LIMBS - 1] = self.is_negative().select_u64(ret.0[LIMBS - 1], Self::MASK);
512
513 let mut i = 0;
514 while i < LIMBS - 1 {
515 ret.0[i] = self.0[i + 1];
516 i += 1;
517 }
518
519 ret
520 }
521
522 pub const fn eq(&self, other: &Self) -> ConstChoice {
524 let mut ret = ConstChoice::TRUE;
525 let mut i = 0;
526
527 while i < LIMBS {
528 ret = ret.and(ConstChoice::from_u64_eq(self.0[i], other.0[i]));
529 i += 1;
530 }
531
532 ret
533 }
534
535 pub const fn is_negative(&self) -> ConstChoice {
537 ConstChoice::from_u64_gt(self.0[LIMBS - 1], Self::MASK >> 1)
538 }
539
540 pub const fn lowest(&self) -> u64 {
542 self.0[0]
543 }
544
545 pub const fn select(a: &Self, b: &Self, choice: ConstChoice) -> Self {
547 let mut ret = Self::ZERO;
548 let mut i = 0;
549
550 while i < LIMBS {
551 ret.0[i] = choice.select_u64(a.0[i], b.0[i]);
552 i += 1;
553 }
554
555 ret
556 }
557
558 pub const fn leading_zeros(&self) -> u32 {
560 let mut count = 0;
561 let mut i = LIMBS;
562 let mut nonzero_limb_not_encountered = ConstChoice::TRUE;
563
564 while i > 0 {
565 i -= 1;
566 let l = self.0[i];
567 let z = l.leading_zeros() - 2;
568 count += nonzero_limb_not_encountered.if_true_u32(z);
569 nonzero_limb_not_encountered =
570 nonzero_limb_not_encountered.and(ConstChoice::from_u64_nonzero(l).not());
571 }
572
573 count
574 }
575
576 pub const fn bits(&self) -> u32 {
578 (LIMBS as u32 * 62) - self.leading_zeros()
579 }
580}
581
#[cfg(test)]
mod tests {
    use super::iterations;
    use crate::{Inverter, PrecomputeInverter, U256};

    type UnsatInt = super::UnsatInt<4>;

    impl<const LIMBS: usize> PartialEq for crate::modular::safegcd::UnsatInt<LIMBS> {
        fn eq(&self, other: &Self) -> bool {
            self.eq(other).to_bool_vartime()
        }
    }

    #[test]
    fn invert() {
        let g =
            U256::from_be_hex("00000000CBF9350842F498CE441FC2DC23C7BF47D3DE91C327B2157C5E4EED77");
        let modulus =
            U256::from_be_hex("FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551")
                .to_odd()
                .unwrap();
        let inverter = modulus.precompute_inverter();
        let result = inverter.invert(&g).unwrap();
        assert_eq!(
            U256::from_be_hex("FB668F8F509790BC549B077098918604283D42901C92981062EB48BC723F617B"),
            result
        );
    }

    /// The fixed-round and the run-until-zero loops must agree. This guards the round count
    /// used by `divsteps`.
    #[test]
    fn invert_vartime_matches_invert() {
        let g =
            U256::from_be_hex("00000000CBF9350842F498CE441FC2DC23C7BF47D3DE91C327B2157C5E4EED77");
        let modulus =
            U256::from_be_hex("FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551")
                .to_odd()
                .unwrap();
        let inverter = modulus.precompute_inverter();
        assert_eq!(
            inverter.invert(&g).unwrap(),
            inverter.invert_vartime(&g).unwrap()
        );
    }

    #[test]
    fn invert_zero_is_none() {
        let modulus =
            U256::from_be_hex("FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551")
                .to_odd()
                .unwrap();
        let inverter = modulus.precompute_inverter();
        assert!(bool::from(inverter.invert(&U256::ZERO).is_none()));
    }

    #[test]
    fn iterations_boundary_conditions() {
        assert_eq!(iterations(0, 0), 4);
        assert_eq!(iterations(0, 45), 134);
        assert_eq!(iterations(0, 46), 135);
    }

    #[test]
    fn unsatint_add() {
        assert_eq!(UnsatInt::ZERO, UnsatInt::ZERO.add(&UnsatInt::ZERO));
        assert_eq!(UnsatInt::ONE, UnsatInt::ONE.add(&UnsatInt::ZERO));
        assert_eq!(UnsatInt::ZERO, UnsatInt::MINUS_ONE.add(&UnsatInt::ONE));
    }

    #[test]
    fn unsatint_mul() {
        assert_eq!(UnsatInt::ZERO, UnsatInt::ZERO.mul(0));
        assert_eq!(UnsatInt::ZERO, UnsatInt::ZERO.mul(1));
        assert_eq!(UnsatInt::ZERO, UnsatInt::ONE.mul(0));
        assert_eq!(UnsatInt::ZERO, UnsatInt::MINUS_ONE.mul(0));
        assert_eq!(UnsatInt::ONE, UnsatInt::ONE.mul(1));
        assert_eq!(UnsatInt::MINUS_ONE, UnsatInt::MINUS_ONE.mul(1));

        // Negative multipliers (the two's-complement path exercised by `fg` and `de`).
        assert_eq!(UnsatInt::MINUS_ONE, UnsatInt::ONE.mul(-1));
        assert_eq!(UnsatInt::ONE, UnsatInt::MINUS_ONE.mul(-1));
        assert_eq!(UnsatInt::ZERO, UnsatInt::ZERO.mul(-7));
        assert_eq!(UnsatInt::ONE.mul(5).neg(), UnsatInt::ONE.mul(-5));
    }

    #[test]
    fn unsatint_neg() {
        assert_eq!(UnsatInt::ZERO, UnsatInt::ZERO.neg());
        assert_eq!(UnsatInt::MINUS_ONE, UnsatInt::ONE.neg());
        assert_eq!(UnsatInt::ONE, UnsatInt::MINUS_ONE.neg());
    }

    #[test]
    fn unsatint_is_negative() {
        assert!(!UnsatInt::ZERO.is_negative().to_bool_vartime());
        assert!(!UnsatInt::ONE.is_negative().to_bool_vartime());
        assert!(UnsatInt::MINUS_ONE.is_negative().to_bool_vartime());
    }

    #[test]
    fn unsatint_bits() {
        assert_eq!(UnsatInt::ZERO.bits(), 0);
        assert_eq!(UnsatInt::ONE.bits(), 1);
        assert_eq!(UnsatInt::ONE.mul(5).bits(), 3);
        assert_eq!(UnsatInt::MINUS_ONE.bits(), 4 * 62);
    }

    #[test]
    fn unsatint_shr() {
        let n = super::UnsatInt([
            0,
            1211048314408256470,
            1344008336933394898,
            3913497193346473913,
            2764114971089162538,
            4,
        ]);

        assert_eq!(
            &n.shr().0,
            &[
                1211048314408256470,
                1344008336933394898,
                3913497193346473913,
                2764114971089162538,
                4,
                0
            ]
        );
    }
}