1use crate::reduced::impl_reduced_binary_pow;
2use crate::{udouble, umax, ModularUnaryOps, Reducer};
3
4#[derive(Debug, Clone, Copy)]
10pub struct FixedMersenne<const P: u8, const K: umax>();
11
12impl<const P: u8, const K: umax> FixedMersenne<P, K> {
16 const BITMASK: umax = (1 << P) - 1;
17 pub const MODULUS: umax = (1 << P) - K;
18
19 const fn reduce_single(v: umax) -> umax {
21 let mut lo = v & Self::BITMASK;
22 let mut hi = v >> P;
23 while hi > 0 {
24 let sum = if K == 1 { hi + lo } else { hi * K + lo };
25 lo = sum & Self::BITMASK;
26 hi = sum >> P;
27 }
28
29 if lo >= Self::MODULUS {
30 lo - Self::MODULUS
31 } else {
32 lo
33 }
34 }
35
36 fn reduce_double(v: udouble) -> umax {
38 let mut lo = v.lo & Self::BITMASK;
40 let mut hi = v >> P;
41 while hi.hi > 0 {
42 let sum = if K == 1 { hi + lo } else { hi * K + lo };
44 lo = sum.lo & Self::BITMASK;
45 hi = sum >> P;
46 }
47
48 let mut hi = hi.lo;
49 while hi > 0 {
50 let sum = if K == 1 { hi + lo } else { hi * K + lo };
52 lo = sum & Self::BITMASK;
53 hi = sum >> P;
54 }
55
56 if lo >= Self::MODULUS {
57 lo - Self::MODULUS
58 } else {
59 lo
60 }
61 }
62}
63
64impl<const P: u8, const K: umax> Reducer<umax> for FixedMersenne<P, K> {
65 #[inline]
66 fn new(m: &umax) -> Self {
67 assert!(
68 *m == Self::MODULUS,
69 "the given modulus doesn't match with the generic params"
70 );
71 debug_assert!(P <= 127);
72 debug_assert!(K > 0 && K < (2 as umax).pow(P as u32 - 1) && K % 2 == 1);
73 debug_assert!(
74 Self::MODULUS % 3 != 0
75 && Self::MODULUS % 5 != 0
76 && Self::MODULUS % 7 != 0
77 && Self::MODULUS % 11 != 0
78 && Self::MODULUS % 13 != 0
79 ); Self {}
81 }
82 #[inline]
83 fn transform(&self, target: umax) -> umax {
84 Self::reduce_single(target)
85 }
86 fn check(&self, target: &umax) -> bool {
87 *target < Self::MODULUS
88 }
89 #[inline]
90 fn residue(&self, target: umax) -> umax {
91 target
92 }
93 #[inline]
94 fn modulus(&self) -> umax {
95 Self::MODULUS
96 }
97 #[inline]
98 fn is_zero(&self, target: &umax) -> bool {
99 target == &0
100 }
101
102 #[inline]
103 fn add(&self, lhs: &umax, rhs: &umax) -> umax {
104 let mut sum = lhs + rhs;
105 if sum >= Self::MODULUS {
106 sum -= Self::MODULUS
107 }
108 sum
109 }
110 #[inline]
111 fn sub(&self, lhs: &umax, rhs: &umax) -> umax {
112 if lhs >= rhs {
113 lhs - rhs
114 } else {
115 Self::MODULUS - (rhs - lhs)
116 }
117 }
118 #[inline]
119 fn dbl(&self, target: umax) -> umax {
120 self.add(&target, &target)
121 }
122 #[inline]
123 fn neg(&self, target: umax) -> umax {
124 if target == 0 {
125 0
126 } else {
127 Self::MODULUS - target
128 }
129 }
130 #[inline]
131 fn mul(&self, lhs: &umax, rhs: &umax) -> umax {
132 if (P as u32) < (umax::BITS / 2) {
133 Self::reduce_single(lhs * rhs)
134 } else {
135 Self::reduce_double(udouble::widening_mul(*lhs, *rhs))
136 }
137 }
138 #[inline]
139 fn inv(&self, target: umax) -> Option<umax> {
140 if (P as u32) < usize::BITS {
141 (target as usize)
142 .invm(&(Self::MODULUS as usize))
143 .map(|v| v as umax)
144 } else {
145 target.invm(&Self::MODULUS)
146 }
147 }
148 #[inline]
149 fn sqr(&self, target: umax) -> umax {
150 if (P as u32) < (umax::BITS / 2) {
151 Self::reduce_single(target * target)
152 } else {
153 Self::reduce_double(udouble::widening_square(target))
154 }
155 }
156
157 impl_reduced_binary_pow!(umax);
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow};
    use rand::random;

    /// Number of random inputs tried per reducer configuration.
    const NRANDOM: u32 = 10;

    /// Builds `FixedMersenne<P, K>` from an explicitly written `modulus`, so `new`
    /// cross-checks it against the generic parameters. Then verifies that a
    /// transform/residue round trip agrees with plain `%`.
    fn check_transform<const P: u8, const K: umax>(modulus: umax, value: umax) {
        let reducer = FixedMersenne::<P, K>::new(&modulus);
        let roundtrip = reducer.residue(reducer.transform(value));
        assert_eq!(roundtrip, value % modulus);
    }

    /// Compares every reducer operation on `a`, `b` (and exponent `e`) with the
    /// generic modular-arithmetic trait implementations.
    fn check_against_modops<const P: u8, const K: umax>(a: umax, b: umax, e: umax) {
        let modulus = FixedMersenne::<P, K>::MODULUS;
        let reducer = FixedMersenne::<P, K>::new(&modulus);
        let (am, bm) = (reducer.transform(a), reducer.transform(b));

        assert_eq!(reducer.add(&am, &bm), a.addm(b, &modulus));
        assert_eq!(reducer.sub(&am, &bm), a.subm(b, &modulus));
        assert_eq!(reducer.mul(&am, &bm), a.mulm(b, &modulus));
        assert_eq!(reducer.neg(am), a.negm(&modulus));
        assert_eq!(reducer.inv(am), a.invm(&modulus));
        assert_eq!(reducer.dbl(am), a.dblm(&modulus));
        assert_eq!(reducer.sqr(am), a.sqm(&modulus));
        assert_eq!(reducer.pow(am, &e), a.powm(e, &modulus));
    }

    #[test]
    fn creation_test() {
        for _ in 0..NRANDOM {
            let a = random::<umax>();
            check_transform::<31, 1>((1 << 31) - 1, a);
            check_transform::<61, 1>((1 << 61) - 1, a);
            check_transform::<127, 1>((1 << 127) - 1, a);
            check_transform::<32, 5>((1 << 32) - 5, a);
            check_transform::<56, 5>((1 << 56) - 5, a);
            check_transform::<122, 3>((1 << 122) - 3, a);
        }
    }

    #[test]
    fn test_against_modops() {
        for _ in 0..NRANDOM {
            let (a, b) = (random::<umax>(), random::<umax>());
            let e = random::<u8>() as umax;
            check_against_modops::<31, 1>(a, b, e);
            check_against_modops::<61, 1>(a, b, e);
            check_against_modops::<127, 1>(a, b, e);
            check_against_modops::<32, 5>(a, b, e);
            check_against_modops::<56, 5>(a, b, e);
            check_against_modops::<122, 3>(a, b, e);
        }
    }
}