num_modular/
mersenne.rs

1use crate::reduced::impl_reduced_binary_pow;
2use crate::{udouble, umax, ModularUnaryOps, Reducer};
3
4// FIXME: use unchecked operators to speed up calculation (after https://github.com/rust-lang/rust/issues/85122)
5/// A modular reducer for (pseudo) Mersenne numbers `2^P - K` as modulus. It supports `P` up to 127 and `K < 2^(P-1)`
6///
7/// The `P` is limited to 127 so that it's not necessary to check overflow. This limit won't be a problem for any
8/// Mersenne primes within the range of [umax] (i.e. [u128]).
9#[derive(Debug, Clone, Copy)]
10pub struct FixedMersenne<const P: u8, const K: umax>();
11
// XXX: support other special primes as moduli, such as Solinas primes and Proth primes, and
//      support multi-precision integers
// REF: Handbook of Applied Cryptography, Section 14.3.4 (reduction methods for moduli of special form)
14
15impl<const P: u8, const K: umax> FixedMersenne<P, K> {
16    const BITMASK: umax = (1 << P) - 1;
17    pub const MODULUS: umax = (1 << P) - K;
18
19    // Calculate v % Self::MODULUS, where v is a umax integer
20    const fn reduce_single(v: umax) -> umax {
21        let mut lo = v & Self::BITMASK;
22        let mut hi = v >> P;
23        while hi > 0 {
24            let sum = if K == 1 { hi + lo } else { hi * K + lo };
25            lo = sum & Self::BITMASK;
26            hi = sum >> P;
27        }
28
29        if lo >= Self::MODULUS {
30            lo - Self::MODULUS
31        } else {
32            lo
33        }
34    }
35
36    // Calculate v % Self::MODULUS, where v is a udouble integer
37    fn reduce_double(v: udouble) -> umax {
38        // reduce modulo
39        let mut lo = v.lo & Self::BITMASK;
40        let mut hi = v >> P;
41        while hi.hi > 0 {
42            // first reduce until high bits fit in umax
43            let sum = if K == 1 { hi + lo } else { hi * K + lo };
44            lo = sum.lo & Self::BITMASK;
45            hi = sum >> P;
46        }
47
48        let mut hi = hi.lo;
49        while hi > 0 {
50            // then reduce the smaller high bits
51            let sum = if K == 1 { hi + lo } else { hi * K + lo };
52            lo = sum & Self::BITMASK;
53            hi = sum >> P;
54        }
55
56        if lo >= Self::MODULUS {
57            lo - Self::MODULUS
58        } else {
59            lo
60        }
61    }
62}
63
64impl<const P: u8, const K: umax> Reducer<umax> for FixedMersenne<P, K> {
65    #[inline]
66    fn new(m: &umax) -> Self {
67        assert!(
68            *m == Self::MODULUS,
69            "the given modulus doesn't match with the generic params"
70        );
71        debug_assert!(P <= 127);
72        debug_assert!(K > 0 && K < (2 as umax).pow(P as u32 - 1) && K % 2 == 1);
73        debug_assert!(
74            Self::MODULUS % 3 != 0
75                && Self::MODULUS % 5 != 0
76                && Self::MODULUS % 7 != 0
77                && Self::MODULUS % 11 != 0
78                && Self::MODULUS % 13 != 0
79        ); // error on easy composites
80        Self {}
81    }
82    #[inline]
83    fn transform(&self, target: umax) -> umax {
84        Self::reduce_single(target)
85    }
86    fn check(&self, target: &umax) -> bool {
87        *target < Self::MODULUS
88    }
89    #[inline]
90    fn residue(&self, target: umax) -> umax {
91        target
92    }
93    #[inline]
94    fn modulus(&self) -> umax {
95        Self::MODULUS
96    }
97    #[inline]
98    fn is_zero(&self, target: &umax) -> bool {
99        target == &0
100    }
101
102    #[inline]
103    fn add(&self, lhs: &umax, rhs: &umax) -> umax {
104        let mut sum = lhs + rhs;
105        if sum >= Self::MODULUS {
106            sum -= Self::MODULUS
107        }
108        sum
109    }
110    #[inline]
111    fn sub(&self, lhs: &umax, rhs: &umax) -> umax {
112        if lhs >= rhs {
113            lhs - rhs
114        } else {
115            Self::MODULUS - (rhs - lhs)
116        }
117    }
118    #[inline]
119    fn dbl(&self, target: umax) -> umax {
120        self.add(&target, &target)
121    }
122    #[inline]
123    fn neg(&self, target: umax) -> umax {
124        if target == 0 {
125            0
126        } else {
127            Self::MODULUS - target
128        }
129    }
130    #[inline]
131    fn mul(&self, lhs: &umax, rhs: &umax) -> umax {
132        if (P as u32) < (umax::BITS / 2) {
133            Self::reduce_single(lhs * rhs)
134        } else {
135            Self::reduce_double(udouble::widening_mul(*lhs, *rhs))
136        }
137    }
138    #[inline]
139    fn inv(&self, target: umax) -> Option<umax> {
140        if (P as u32) < usize::BITS {
141            (target as usize)
142                .invm(&(Self::MODULUS as usize))
143                .map(|v| v as umax)
144        } else {
145            target.invm(&Self::MODULUS)
146        }
147    }
148    #[inline]
149    fn sqr(&self, target: umax) -> umax {
150        if (P as u32) < (umax::BITS / 2) {
151            Self::reduce_single(target * target)
152        } else {
153            Self::reduce_double(udouble::widening_square(target))
154        }
155    }
156
157    impl_reduced_binary_pow!(umax);
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow};
    use rand::random;

    type M1 = FixedMersenne<31, 1>;
    type M2 = FixedMersenne<61, 1>;
    type M3 = FixedMersenne<127, 1>;
    type M4 = FixedMersenne<32, 5>;
    type M5 = FixedMersenne<56, 5>;
    type M6 = FixedMersenne<122, 3>;

    /// Number of random samples drawn by each test.
    const NRANDOM: u32 = 10;

    /// Round-tripping a random value through `transform`/`residue` must agree with a plain `%`
    /// against a hand-written modulus, for every reducer type.
    #[test]
    fn creation_test() {
        macro_rules! assert_roundtrip {
            ($x:ident; $($M:ty => $modulus:expr),* $(,)?) => {$({
                const MODULUS: umax = $modulus;
                let reducer = <$M>::new(&MODULUS);
                assert_eq!(reducer.residue(reducer.transform($x)), $x % MODULUS);
            })*};
        }

        for _ in 0..NRANDOM {
            let sample = random::<umax>();
            assert_roundtrip!(sample;
                M1 => (1 << 31) - 1,
                M2 => (1 << 61) - 1,
                M3 => (1 << 127) - 1,
                M4 => (1 << 32) - 5,
                M5 => (1 << 56) - 5,
                M6 => (1 << 122) - 3,
            );
        }
    }

    /// Every reducer operation must agree with the generic modular-arithmetic helpers.
    #[test]
    fn test_against_modops() {
        macro_rules! assert_matches_modops {
            ($x:ident, $y:ident, $exp:ident; $($M:ty),* $(,)?) => {$({
                const MODULUS: umax = <$M>::MODULUS;
                let reducer = <$M>::new(&MODULUS);
                let xm = reducer.transform($x);
                let ym = reducer.transform($y);
                assert_eq!(reducer.add(&xm, &ym), $x.addm($y, &MODULUS));
                assert_eq!(reducer.sub(&xm, &ym), $x.subm($y, &MODULUS));
                assert_eq!(reducer.mul(&xm, &ym), $x.mulm($y, &MODULUS));
                assert_eq!(reducer.neg(xm), $x.negm(&MODULUS));
                assert_eq!(reducer.inv(xm), $x.invm(&MODULUS));
                assert_eq!(reducer.dbl(xm), $x.dblm(&MODULUS));
                assert_eq!(reducer.sqr(xm), $x.sqm(&MODULUS));
                assert_eq!(reducer.pow(xm, &$exp), $x.powm($exp, &MODULUS));
            })*};
        }

        for _ in 0..NRANDOM {
            let lhs = random::<u128>();
            let rhs = random::<u128>();
            let exponent = umax::from(random::<u8>());
            assert_matches_modops!(lhs, rhs, exponent; M1, M2, M3, M4, M5, M6);
        }
    }
}
227}