1use crate::reduced::{impl_reduced_binary_pow, Vanilla};
31use crate::{DivExact, ModularUnaryOps, Reducer};
32
/// Precomputed data for dividing a single word by a fixed single-word divisor
/// using a multiplication and shifts instead of a hardware division.
///
/// The divisor itself is not stored; callers pass it again to `div_rem`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PreMulInv1by1<T> {
    /// Magic multiplier computed by `new`:
    /// `floor(2^BITS * (2^l - d) / d) + 1`, where `l = ceil(log2(d))`.
    m: T,

    /// `l - 1`, the final right shift applied to the quotient estimate.
    shift: u32,
}
47
macro_rules! impl_premulinv_1by1_for {
    ($T:ty) => {
        impl PreMulInv1by1<$T> {
            /// Precompute the multiplier and shift for dividing by `divisor`.
            ///
            /// Uses the method of Granlund & Montgomery, "Division by Invariant
            /// Integers using Multiplication" (1994), Figure 4.1. With
            /// `l = ceil(log2(divisor))`, the multiplier is
            /// `floor(2^BITS * (2^l - divisor) / divisor) + 1` and the final
            /// shift is `l - 1`.
            ///
            /// # Panics
            ///
            /// Panics if `divisor <= 1`. A divisor of 1 would need a zero
            /// pre-shift, which [`div_rem`](Self::div_rem) does not implement.
            pub const fn new(divisor: $T) -> Self {
                // A hard assert, not a debug-only one. With divisor == 1, `n - 1`
                // below would wrap to u32::MAX in release builds, and `div_rem`
                // would then silently return wrong results.
                assert!(divisor > 1);

                // l = ceil(log2(divisor)); valid because divisor >= 2.
                let n = <$T>::BITS - (divisor - 1).leading_zeros();

                // `ones(n) - (divisor - 1)` is 2^l - divisor (ones(n) being the
                // n-bit all-ones mask). The quotient fits in one word because
                // 2^(l-1) < divisor, i.e. 2^l - divisor < divisor.
                let (lo, _hi) = split(merge(0, ones(n) - (divisor - 1)) / extend(divisor));
                debug_assert!(_hi == 0);
                Self {
                    shift: n - 1,
                    m: lo + 1,
                }
            }

            /// Divide `a` by `d`, returning `(quotient, remainder)`.
            ///
            /// `d` must be the divisor this instance was created with. It is not
            /// stored, so passing a different value gives meaningless results
            /// (and may panic on overflow in debug builds).
            #[inline]
            pub const fn div_rem(&self, a: $T, d: $T) -> ($T, $T) {
                // t = floor(m * a / 2^BITS), the high word of the product.
                let (_, t) = split(wmul(self.m, a));
                // q = (t + (a - t) / 2) >> (l - 1). Halving (a - t) first keeps
                // the sum from overflowing a word.
                let q = (t + ((a - t) >> 1)) >> self.shift;
                let r = a - q * d;
                (q, r)
            }
        }

        impl DivExact<$T, PreMulInv1by1<$T>> for $T {
            type Output = $T;

            /// Return `Some(self / d)` if `d` divides `self` exactly, otherwise
            /// `None`. `pre` must have been built from `d`.
            #[inline]
            fn div_exact(self, d: $T, pre: &PreMulInv1by1<$T>) -> Option<Self::Output> {
                match pre.div_rem(self, d) {
                    (q, 0) => Some(q),
                    _ => None,
                }
            }
        }
    };
}
123
/// A single-word divisor whose highest bit is set ("normalized"), together with
/// the precomputed reciprocal used for 2-by-1 word division.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Normalized2by1Divisor<T> {
    /// The divisor; `new` asserts that its top bit is set.
    divisor: T,

    /// Reciprocal `floor((2^(2*BITS) - 1) / divisor) - 2^BITS`, as computed by
    /// `invert_word`.
    m: T,
}
137
macro_rules! impl_normdiv_2by1_for {
    ($T:ty, $D:ty) => {
        impl Normalized2by1Divisor<$T> {
            /// Compute the reciprocal `floor((2^(2*BITS) - 1) / divisor) - 2^BITS`
            /// of a normalized divisor (top bit set).
            ///
            /// For a normalized divisor the full quotient lies in
            /// `[2^BITS, 2^(BITS+1))`, so only its low word is returned. Debug
            /// builds check that the high word is 1.
            #[inline]
            pub const fn invert_word(divisor: $T) -> $T {
                let (m, _hi) = split(<$D>::MAX / extend(divisor));
                debug_assert!(_hi == 1);
                m
            }

            /// Build a divider for `divisor`, precomputing its reciprocal.
            ///
            /// # Panics
            ///
            /// Panics if the top bit of `divisor` is not set.
            #[inline]
            pub const fn new(divisor: $T) -> Self {
                assert!(divisor.leading_zeros() == 0);
                Self {
                    divisor,
                    m: Self::invert_word(divisor),
                }
            }

            /// Divide a single word `a` by the divisor, returning
            /// `(quotient, remainder)`.
            ///
            /// The divisor is normalized, so `a < 2 * divisor`. The quotient is
            /// therefore 0 or 1, and a single comparison suffices.
            #[inline]
            pub const fn div_rem_1by1(&self, a: $T) -> ($T, $T) {
                if a < self.divisor {
                    (0, a)
                } else {
                    (1, a - self.divisor)
                }
            }

            /// Divide the double word `a` by the divisor, returning
            /// `(quotient, remainder)`.
            ///
            /// Requires the high word of `a` to be less than the divisor, so the
            /// quotient fits in one word (checked in debug builds). Follows
            /// Algorithm 4 of Möller & Granlund, "Improved Division by Invariant
            /// Integers", with the first correction step done branch-free.
            #[inline]
            pub const fn div_rem_2by1(&self, a: $D) -> ($T, $T) {
                let (a_lo, a_hi) = split(a);
                debug_assert!(a_hi < self.divisor);

                // Quotient estimate: <q1, q0> = m * a_hi + a.
                let (q0, q1) = split(wmul(self.m, a_hi) + a);

                // Candidate quotient q1 + 1 and its remainder, both mod 2^BITS.
                let q = q1.wrapping_add(1);
                let r = a_lo.wrapping_sub(q.wrapping_mul(self.divisor));

                // `decrease` is all ones if r > q0 and zero otherwise. It is the
                // borrow (high word) of the double-width difference q0 - r.
                let (_, decrease) = split(extend(q0).wrapping_sub(extend(r)));
                // If r > q0 the candidate was one too large. Adding all ones
                // subtracts 1 from q (mod 2^BITS), and the divisor is added back
                // to r.
                let mut q = q.wrapping_add(decrease);
                let mut r = r.wrapping_add(decrease & self.divisor);

                // Final correction step of the algorithm.
                if r >= self.divisor {
                    q += 1;
                    r -= self.divisor;
                }

                (q, r)
            }
        }
    };
}
240
/// Modular reducer for a single-word modulus that need not be normalized.
///
/// The modulus is shifted left by `shift` bits so that its top bit is set.
/// Residues are kept multiplied by `2^shift`, so they can be reduced by the
/// normalized divisor directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PreMulInv2by1<T> {
    /// Divider for the normalized (left-shifted) modulus.
    div: Normalized2by1Divisor<T>,
    /// Number of leading zero bits of the original modulus.
    shift: u32,
}
247
248impl<T> PreMulInv2by1<T> {
249 #[inline]
250 pub const fn divider(&self) -> &Normalized2by1Divisor<T> {
251 &self.div
252 }
253 #[inline]
254 pub const fn shift(&self) -> u32 {
255 self.shift
256 }
257}
258
macro_rules! impl_premulinv_2by1_reducer_for {
    ($T:ty) => {
        impl PreMulInv2by1<$T> {
            /// Build a reducer for the modulus `divisor`. The modulus is shifted
            /// left until its top bit is set, and the reciprocal of the result is
            /// precomputed.
            #[inline]
            pub const fn new(divisor: $T) -> Self {
                let shift = divisor.leading_zeros();
                Self {
                    div: Normalized2by1Divisor::<$T>::new(divisor << shift),
                    shift,
                }
            }

            /// The **normalized** divisor, i.e. the modulus shifted left by
            /// [`shift`](Self::shift) bits, not the modulus itself.
            #[inline]
            pub const fn divisor(&self) -> $T {
                self.div.divisor
            }
        }

        impl Reducer<$T> for PreMulInv2by1<$T> {
            #[inline]
            fn new(m: &$T) -> Self {
                let modulus = *m;
                PreMulInv2by1::<$T>::new(modulus)
            }

            // Reduced values are `(x mod modulus) << shift`, i.e. residues
            // scaled into the range of the normalized divisor.
            #[inline]
            fn transform(&self, target: $T) -> $T {
                let (_, rem) = match self.shift {
                    // Already normalized: target < 2 * modulus, so at most one
                    // subtraction is needed.
                    0 => self.div.div_rem_1by1(target),
                    shift => self.div.div_rem_2by1(extend(target) << shift),
                };
                rem
            }

            #[inline]
            fn check(&self, target: &$T) -> bool {
                // A valid reduced value is below the normalized divisor and has
                // its low `shift` bits cleared.
                if *target >= self.div.divisor {
                    return false;
                }
                *target & ones(self.shift) == 0
            }

            #[inline]
            fn residue(&self, target: $T) -> $T {
                let shift = self.shift;
                target >> shift
            }

            #[inline]
            fn modulus(&self) -> $T {
                let shift = self.shift;
                self.div.divisor >> shift
            }

            #[inline]
            fn is_zero(&self, target: &$T) -> bool {
                *target == 0
            }

            // Addition-like operations act directly on the scaled values,
            // modulo the normalized divisor.
            #[inline(always)]
            fn add(&self, lhs: &$T, rhs: &$T) -> $T {
                let m = &self.div.divisor;
                Vanilla::<$T>::add(m, *lhs, *rhs)
            }

            #[inline(always)]
            fn dbl(&self, target: $T) -> $T {
                let m = &self.div.divisor;
                Vanilla::<$T>::dbl(m, target)
            }

            #[inline(always)]
            fn sub(&self, lhs: &$T, rhs: &$T) -> $T {
                let m = &self.div.divisor;
                Vanilla::<$T>::sub(m, *lhs, *rhs)
            }

            #[inline(always)]
            fn neg(&self, target: $T) -> $T {
                let m = &self.div.divisor;
                Vanilla::<$T>::neg(m, target)
            }

            #[inline(always)]
            fn inv(&self, target: $T) -> Option<$T> {
                // Invert the unscaled residue, then scale the result back up.
                let inverse = self.residue(target).invm(&self.modulus())?;
                Some(inverse << self.shift)
            }

            #[inline]
            fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
                // (x * 2^s >> s) * (y * 2^s) = x * y * 2^s, reduced by the
                // normalized divisor.
                let product = wmul(*lhs >> self.shift, *rhs);
                let (_, rem) = self.div.div_rem_2by1(product);
                rem
            }

            #[inline]
            fn sqr(&self, target: $T) -> $T {
                // (x * 2^s)^2 >> s = x^2 * 2^s exactly, since its low 2s bits
                // are zero.
                let scaled_square = wsqr(target) >> self.shift;
                let (_, rem) = self.div.div_rem_2by1(scaled_square);
                rem
            }

            impl_reduced_binary_pow!($T);
        }
    };
}
342
/// A double-word divisor whose highest bit is set ("normalized"), together with
/// the single-word reciprocal used for 3-by-2 word division.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Normalized3by2Divisor<T, D> {
    /// The double-word divisor; `new` asserts that its top bit is set.
    divisor: D,

    /// Reciprocal of the divisor as computed by `invert_double_word`
    /// (`floor((2^(3*BITS) - 1) / divisor) - 2^BITS`, BITS being the width of `T`).
    m: T,
}
357
macro_rules! impl_normdiv_3by2_for {
    ($T:ty, $D:ty) => {
        impl Normalized3by2Divisor<$T, $D> {
            /// Compute the reciprocal `floor((2^(3*BITS) - 1) / divisor) - 2^BITS`
            /// of a normalized double-word divisor.
            ///
            /// Follows Algorithm 6 of Möller & Granlund, "Improved Division by
            /// Invariant Integers". It starts from the reciprocal of the high
            /// word and corrects it for the low word.
            #[inline]
            pub const fn invert_double_word(divisor: $D) -> $T {
                let (d0, d1) = split(divisor);
                // Reciprocal of the high word alone; adjusted downward below.
                let mut v = Normalized2by1Divisor::<$T>::invert_word(d1);
                // p = d1 * v + d0 (mod 2^BITS). A carry means v is too large.
                let (mut p, c) = d1.wrapping_mul(v).overflowing_add(d0);
                if c {
                    v -= 1;
                    if p >= d1 {
                        v -= 1;
                        p -= d1;
                    }
                    p = p.wrapping_sub(d1);
                }
                // Account for the low word by adding the high word of v * d0 to
                // p. Another carry means v must drop by one or two more.
                let (t0, t1) = split(extend(v) * extend(d0));
                let (p, c) = p.overflowing_add(t1);
                if c {
                    v -= 1;
                    if merge(t0, p) >= divisor {
                        v -= 1;
                    }
                }

                v
            }

            /// Build a divider for `divisor`, precomputing its reciprocal.
            ///
            /// # Panics
            ///
            /// Panics if the top bit of `divisor` is not set.
            #[inline]
            pub const fn new(divisor: $D) -> Self {
                assert!(divisor.leading_zeros() == 0);
                Self {
                    divisor,
                    m: Self::invert_double_word(divisor),
                }
            }

            /// Divide a double word `a` by the divisor, returning
            /// `(quotient, remainder)`.
            ///
            /// The divisor is normalized, so `a < 2 * divisor`. The quotient is
            /// therefore 0 or 1.
            #[inline]
            pub const fn div_rem_2by2(&self, a: $D) -> ($D, $D) {
                if a < self.divisor {
                    (0, a)
                } else {
                    (1, a - self.divisor)
                }
            }

            /// Divide the three-word number `a_hi * 2^BITS + a_lo` by the divisor,
            /// returning a single-word quotient and a double-word remainder.
            ///
            /// Requires `a_hi < divisor` (checked in debug builds). Follows
            /// Algorithm 5 of Möller & Granlund, "Improved Division by Invariant
            /// Integers", with the first correction step done branch-free.
            pub const fn div_rem_3by2(&self, a_lo: $T, a_hi: $D) -> ($T, $D) {
                debug_assert!(a_hi < self.divisor);
                let (a1, a2) = split(a_hi);
                let (d0, d1) = split(self.divisor);

                // Quotient estimate: <q1, q0> = m * a2 + <a2, a1>.
                let (q0, q1) = split(wmul(self.m, a2) + a_hi);
                // Candidate remainder <r1, a_lo> - d0 * q1 - divisor, all mod
                // 2^(2*BITS). The extra `- divisor` corresponds to the candidate
                // quotient q1 + 1.
                let r1 = a1.wrapping_sub(q1.wrapping_mul(d1));
                let t = wmul(d0, q1);
                let r = merge(a_lo, r1).wrapping_sub(t).wrapping_sub(self.divisor);

                let (_, r1) = split(r);
                // `decrease` is all ones if r1 < q0 and zero otherwise.
                let (_, decrease) = split(extend(r1).wrapping_sub(extend(q0)));
                // If r1 < q0, keep the candidate q1 + 1 (subtracting all ones adds
                // 1). Otherwise the quotient stays q1 and the divisor is added
                // back to r.
                let mut q1 = q1.wrapping_sub(decrease);
                let mut r = r.wrapping_add(merge(!decrease, !decrease) & self.divisor);

                // Final correction step of the algorithm.
                if r >= self.divisor {
                    q1 += 1;
                    r -= self.divisor;
                }

                (q1, r)
            }

            /// Divide the four-word number `a_hi * 2^(2*BITS) + a_lo` by the
            /// divisor, returning a double-word quotient and remainder.
            ///
            /// Requires `a_hi < divisor`. This is schoolbook division with two
            /// 3-by-2 steps: the remainder of the first step becomes the high
            /// part of the second.
            pub const fn div_rem_4by2(&self, a_lo: $D, a_hi: $D) -> ($D, $D) {
                let (a0, a1) = split(a_lo);
                let (q1, r1) = self.div_rem_3by2(a1, a_hi);
                let (q0, r0) = self.div_rem_3by2(a0, r1);
                (merge(q0, q1), r0)
            }
        }
    };
}
459
/// Modular reducer for a double-word modulus that need not be normalized.
///
/// The modulus is shifted left by `shift` bits so that its top bit is set.
/// Residues are kept multiplied by `2^shift`, so they can be reduced by the
/// normalized divisor directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PreMulInv3by2<T, D> {
    /// Divider for the normalized (left-shifted) modulus.
    div: Normalized3by2Divisor<T, D>,
    /// Number of leading zero bits of the original modulus.
    shift: u32,
}
466
467impl<T, D> PreMulInv3by2<T, D> {
468 #[inline]
469 pub const fn divider(&self) -> &Normalized3by2Divisor<T, D> {
470 &self.div
471 }
472 #[inline]
473 pub const fn shift(&self) -> u32 {
474 self.shift
475 }
476}
477
macro_rules! impl_premulinv_3by2_reducer_for {
    ($T:ty, $D:ty) => {
        impl PreMulInv3by2<$T, $D> {
            /// Create a divider for `divisor`, normalizing it by shifting out its
            /// leading zero bits.
            ///
            /// This does not check that `divisor` exceeds a single word.
            /// `Reducer::new` does check this, because the reducer's `transform`
            /// and `check` assume the shift is smaller than the word width.
            ///
            /// # Panics
            ///
            /// Panics if `divisor` is zero.
            #[inline]
            pub const fn new(divisor: $D) -> Self {
                let shift = divisor.leading_zeros();
                let div = Normalized3by2Divisor::<$T, $D>::new(divisor << shift);
                Self { div, shift }
            }

            /// Get the **normalized** divisor, i.e. the modulus shifted left by
            /// [`shift`](Self::shift) bits, not the modulus itself.
            #[inline]
            pub const fn divisor(&self) -> $D {
                self.div.divisor
            }
        }

        impl Reducer<$D> for PreMulInv3by2<$T, $D> {
            /// # Panics
            ///
            /// Panics if `m` fits in a single word. A single-word reducer such
            /// as `PreMulInv2by1` should be used for those moduli.
            #[inline]
            fn new(m: &$D) -> Self {
                // `transform` and `check` below rely on the normalizing shift
                // being smaller than one word, which holds exactly when m does
                // not fit in a single word.
                assert!(
                    *m > <$T>::MAX as $D,
                    "PreMulInv3by2 requires a modulus that does not fit in a single word"
                );
                // Delegate to the inherent constructor (as the 2-by-1 reducer
                // does) rather than duplicating the normalization logic.
                PreMulInv3by2::<$T, $D>::new(*m)
            }

            /// Map `target` to `(target mod m) << shift`.
            #[inline]
            fn transform(&self, target: $D) -> $D {
                if self.shift == 0 {
                    self.div.div_rem_2by2(target).1
                } else {
                    // Form the three-word number target << shift = <n12, n0>.
                    // Because shift < BITS, only the bits pushed out of the low
                    // word need to be carried into the high part.
                    let (lo, hi) = split(target);
                    let (n0, carry) = split(extend(lo) << self.shift);
                    let n12 = (extend(hi) << self.shift) | extend(carry);
                    self.div.div_rem_3by2(n0, n12).1
                }
            }

            /// A valid reduced value is below the normalized divisor and has its
            /// low `shift` bits (all within the low word) cleared.
            #[inline]
            fn check(&self, target: &$D) -> bool {
                *target < self.div.divisor && split(*target).0 & ones(self.shift) == 0
            }

            #[inline]
            fn residue(&self, target: $D) -> $D {
                target >> self.shift
            }

            #[inline]
            fn modulus(&self) -> $D {
                self.div.divisor >> self.shift
            }

            #[inline]
            fn is_zero(&self, target: &$D) -> bool {
                *target == 0
            }

            #[inline(always)]
            fn add(&self, lhs: &$D, rhs: &$D) -> $D {
                Vanilla::<$D>::add(&self.div.divisor, *lhs, *rhs)
            }

            #[inline(always)]
            fn dbl(&self, target: $D) -> $D {
                Vanilla::<$D>::dbl(&self.div.divisor, target)
            }

            #[inline(always)]
            fn sub(&self, lhs: &$D, rhs: &$D) -> $D {
                Vanilla::<$D>::sub(&self.div.divisor, *lhs, *rhs)
            }

            #[inline(always)]
            fn neg(&self, target: $D) -> $D {
                Vanilla::<$D>::neg(&self.div.divisor, target)
            }

            /// Invert the unscaled residue, then scale the result back up.
            #[inline(always)]
            fn inv(&self, target: $D) -> Option<$D> {
                self.residue(target)
                    .invm(&self.modulus())
                    .map(|v| v << self.shift)
            }

            /// (x * 2^s >> s) * (y * 2^s) = x * y * 2^s, computed as a four-word
            /// product and reduced by the normalized divisor.
            #[inline]
            fn mul(&self, lhs: &$D, rhs: &$D) -> $D {
                let prod = DoubleWordModule::wmul(lhs >> self.shift, *rhs);
                let (lo, hi) = DoubleWordModule::split(prod);
                self.div.div_rem_4by2(lo, hi).1
            }

            /// (x * 2^s)^2 >> s = x^2 * 2^s exactly, since its low 2s bits are
            /// zero.
            #[inline]
            fn sqr(&self, target: $D) -> $D {
                let prod = DoubleWordModule::wsqr(target) >> self.shift;
                let (lo, hi) = DoubleWordModule::split(prod);
                self.div.div_rem_4by2(lo, hi).1
            }

            impl_reduced_binary_pow!($D);
        }
    };
}
571
572macro_rules! collect_impls {
573 ($T:ident, $ns:ident) => {
574 mod $ns {
575 use super::*;
576 use crate::word::$T::*;
577
578 impl_premulinv_1by1_for!(Word);
579 impl_normdiv_2by1_for!(Word, DoubleWord);
580 impl_premulinv_2by1_reducer_for!(Word);
581 impl_normdiv_3by2_for!(Word, DoubleWord);
582 impl_premulinv_3by2_reducer_for!(Word, DoubleWord);
583 }
584 };
585}
586collect_impls!(u8, u8_impl);
587collect_impls!(u16, u16_impl);
588collect_impls!(u32, u32_impl);
589collect_impls!(u64, u64_impl);
590collect_impls!(usize, usize_impl);
591
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reduced::tests::ReducedTester;
    use rand::prelude::*;

    // NOTE: every test draws its inputs from a seeded RNG. The order of the
    // `rng` calls determines which inputs are exercised, so keep it stable.

    /// Random single-word divisions with `PreMulInv1by1`, checked against the
    /// native `/` and `%` and against the matching `div_exact` result.
    #[test]
    fn test_mul_inv_1by1() {
        type Word = u64;
        let mut rng = StdRng::seed_from_u64(1);
        for _ in 0..400000 {
            // Pick the divisor's bit length first, then a divisor with exactly
            // that many bits, so every bit length from 2 up is equally likely.
            let d_bits = rng.gen_range(2..=Word::BITS);
            let max_d = Word::MAX >> (Word::BITS - d_bits);
            let d = rng.gen_range(max_d / 2 + 1..=max_d);
            let fast_div = PreMulInv1by1::<Word>::new(d);
            let n = rng.gen();
            let (q, r) = fast_div.div_rem(n, d);
            assert_eq!(q, n / d);
            assert_eq!(r, n % d);

            if r == 0 {
                assert_eq!(n.div_exact(d, &fast_div), Some(q));
            } else {
                assert_eq!(n.div_exact(d, &fast_div), None);
            }
        }
    }

    /// 2-by-1 division: build the dividend as `q * d + r` from a random
    /// quotient and a remainder `r < d`, then check both are recovered.
    /// Because q fits in one word, the dividend's high word is below d, which
    /// satisfies the precondition of `div_rem_2by1`.
    #[test]
    fn test_mul_inv_2by1() {
        type Word = u64;
        type Divider = Normalized2by1Divisor<Word>;
        use crate::word::u64::*;

        // Edge case: the largest normalized divisor with a zero dividend.
        let fast_div = Divider::new(Word::MAX);
        assert_eq!(fast_div.div_rem_2by1(0), (0, 0));

        let mut rng = StdRng::seed_from_u64(1);
        for _ in 0..200000 {
            let d = rng.gen_range(Word::MAX / 2 + 1..=Word::MAX);
            let q = rng.gen();
            let r = rng.gen_range(0..d);
            let (a0, a1) = split(wmul(q, d) + extend(r));
            let fast_div = Divider::new(d);
            assert_eq!(fast_div.div_rem_2by1(merge(a0, a1)), (q, r));
        }
    }

    /// 3-by-2 division: the three-word dividend `q * d + r` is assembled word
    /// by word from the halves of `d` and `r`, with explicit carries.
    #[test]
    fn test_mul_inv_3by2() {
        type Word = u64;
        type DoubleWord = u128;
        type Divider = Normalized3by2Divisor<Word, DoubleWord>;
        use crate::word::u64::*;

        // Edge case: the largest normalized divisor with a zero dividend.
        let d = DoubleWord::MAX;
        let fast_div = Divider::new(d);
        assert_eq!(fast_div.div_rem_3by2(0, 0), (0, 0));

        let mut rng = StdRng::seed_from_u64(1);
        for _ in 0..100000 {
            let d = rng.gen_range(DoubleWord::MAX / 2 + 1..=DoubleWord::MAX);
            let r = rng.gen_range(0..d);
            let q = rng.gen();

            let (d0, d1) = split(d);
            let (r0, r1) = split(r);
            let (a0, c) = split(wmul(q, d0) + extend(r0));
            let (a1, a2) = split(wmul(q, d1) + extend(r1) + extend(c));
            let a12 = merge(a1, a2);

            let fast_div = Divider::new(d);
            assert_eq!(
                fast_div.div_rem_3by2(a0, a12),
                (q, r),
                "failed at {:?} / {}",
                (a0, a12),
                d
            );
        }
    }

    /// 4-by-2 division with a double-word quotient. The helpers come from the
    /// u128 word module, so `wmul(q, d)` here is a four-(u64-)word product. The
    /// local `DoubleWord` alias (u128) shadows the module's own `DoubleWord`.
    #[test]
    fn test_mul_inv_4by2() {
        type Word = u64;
        type DoubleWord = u128;
        type Divider = Normalized3by2Divisor<Word, DoubleWord>;
        use crate::word::u128::*;

        let mut rng = StdRng::seed_from_u64(1);
        for _ in 0..20000 {
            let d = rng.gen_range(DoubleWord::MAX / 2 + 1..=DoubleWord::MAX);
            let q = rng.gen();
            let r = rng.gen_range(0..d);
            let (a_lo, a_hi) = split(wmul(q, d) + r as DoubleWord);
            let fast_div = Divider::new(d);
            assert_eq!(fast_div.div_rem_4by2(a_lo, a_hi), (q, r));
        }
    }

    /// Cross-check the single-word reducer for every word size using the shared
    /// `ReducedTester` (defined in `crate::reduced::tests`; presumably it
    /// compares results with plain modular arithmetic).
    #[test]
    fn test_2by1_against_modops() {
        for _ in 0..10 {
            ReducedTester::<u8>::test_against_modops::<PreMulInv2by1<u8>>(false);
            ReducedTester::<u16>::test_against_modops::<PreMulInv2by1<u16>>(false);
            ReducedTester::<u32>::test_against_modops::<PreMulInv2by1<u32>>(false);
            ReducedTester::<u64>::test_against_modops::<PreMulInv2by1<u64>>(false);
            ReducedTester::<usize>::test_against_modops::<PreMulInv2by1<usize>>(false);
        }
    }

    /// Cross-check the double-word reducer for every supported word pair.
    #[test]
    fn test_3by2_against_modops() {
        for _ in 0..10 {
            ReducedTester::<u16>::test_against_modops::<PreMulInv3by2<u8, u16>>(false);
            ReducedTester::<u32>::test_against_modops::<PreMulInv3by2<u16, u32>>(false);
            ReducedTester::<u64>::test_against_modops::<PreMulInv3by2<u32, u64>>(false);
            ReducedTester::<u128>::test_against_modops::<PreMulInv3by2<u64, u128>>(false);
        }
    }
}