1use core::ops::{Mul, MulAssign};
4
5use subtle::CtOption;
6
7use crate::{
8 Checked, CheckedMul, Concat, ConcatMixed, ConstCtOption, Limb, Uint, WideningMul, Wrapping,
9 WrappingMul, Zero,
10};
11
12use self::karatsuba::UintKaratsubaMul;
13
14pub(crate) mod karatsuba;
15
16const fn schoolbook_multiplication(lhs: &[Limb], rhs: &[Limb], lo: &mut [Limb], hi: &mut [Limb]) {
21 if lhs.len() != lo.len() || rhs.len() != hi.len() {
22 panic!("schoolbook multiplication length mismatch");
23 }
24
25 let mut i = 0;
26 while i < lhs.len() {
27 let mut j = 0;
28 let mut carry = Limb::ZERO;
29 let xi = lhs[i];
30
31 while j < rhs.len() {
32 let k = i + j;
33
34 if k >= lhs.len() {
35 (hi[k - lhs.len()], carry) = hi[k - lhs.len()].mac(xi, rhs[j], carry);
36 } else {
37 (lo[k], carry) = lo[k].mac(xi, rhs[j], carry);
38 }
39
40 j += 1;
41 }
42
43 if i + j >= lhs.len() {
44 hi[i + j - lhs.len()] = carry;
45 } else {
46 lo[i + j] = carry;
47 }
48 i += 1;
49 }
50}
51
52pub(crate) const fn schoolbook_squaring(limbs: &[Limb], lo: &mut [Limb], hi: &mut [Limb]) {
56 if limbs.len() != lo.len() || lo.len() != hi.len() {
62 panic!("schoolbook squaring length mismatch");
63 }
64
65 let mut i = 1;
66 while i < limbs.len() {
67 let mut j = 0;
68 let mut carry = Limb::ZERO;
69 let xi = limbs[i];
70
71 while j < i {
72 let k = i + j;
73
74 if k >= limbs.len() {
75 (hi[k - limbs.len()], carry) = hi[k - limbs.len()].mac(xi, limbs[j], carry);
76 } else {
77 (lo[k], carry) = lo[k].mac(xi, limbs[j], carry);
78 }
79
80 j += 1;
81 }
82
83 if (2 * i) < limbs.len() {
84 lo[2 * i] = carry;
85 } else {
86 hi[2 * i - limbs.len()] = carry;
87 }
88
89 i += 1;
90 }
91
92 let mut carry = Limb::ZERO;
95 let mut i = 0;
96 while i < limbs.len() {
97 (lo[i].0, carry) = ((lo[i].0 << 1) | carry.0, lo[i].shr(Limb::BITS - 1));
98 i += 1;
99 }
100
101 let mut i = 0;
102 while i < limbs.len() - 1 {
103 (hi[i].0, carry) = ((hi[i].0 << 1) | carry.0, hi[i].shr(Limb::BITS - 1));
104 i += 1;
105 }
106 hi[limbs.len() - 1] = carry;
107
108 let mut carry = Limb::ZERO;
110 let mut i = 0;
111 while i < limbs.len() {
112 let xi = limbs[i];
113 if (i * 2) < limbs.len() {
114 (lo[i * 2], carry) = lo[i * 2].mac(xi, xi, carry);
115 } else {
116 (hi[i * 2 - limbs.len()], carry) = hi[i * 2 - limbs.len()].mac(xi, xi, carry);
117 }
118
119 if (i * 2 + 1) < limbs.len() {
120 (lo[i * 2 + 1], carry) = lo[i * 2 + 1].overflowing_add(carry);
121 } else {
122 (hi[i * 2 + 1 - limbs.len()], carry) =
123 hi[i * 2 + 1 - limbs.len()].overflowing_add(carry);
124 }
125
126 i += 1;
127 }
128}
129
130impl<const LIMBS: usize> Uint<LIMBS> {
131 pub const fn widening_mul<const RHS_LIMBS: usize, const WIDE_LIMBS: usize>(
133 &self,
134 rhs: &Uint<RHS_LIMBS>,
135 ) -> Uint<WIDE_LIMBS>
136 where
137 Self: ConcatMixed<Uint<RHS_LIMBS>, MixedOutput = Uint<WIDE_LIMBS>>,
138 {
139 let (lo, hi) = self.split_mul(rhs);
140 Uint::concat_mixed(&lo, &hi)
141 }
142
143 pub const fn split_mul<const RHS_LIMBS: usize>(
146 &self,
147 rhs: &Uint<RHS_LIMBS>,
148 ) -> (Self, Uint<RHS_LIMBS>) {
149 if LIMBS == RHS_LIMBS {
150 if LIMBS == 128 {
151 let (a, b) = UintKaratsubaMul::<128>::multiply(&self.limbs, &rhs.limbs);
152 return (a.resize(), b.resize());
154 }
155 if LIMBS == 64 {
156 let (a, b) = UintKaratsubaMul::<64>::multiply(&self.limbs, &rhs.limbs);
157 return (a.resize(), b.resize());
158 }
159 if LIMBS == 32 {
160 let (a, b) = UintKaratsubaMul::<32>::multiply(&self.limbs, &rhs.limbs);
161 return (a.resize(), b.resize());
162 }
163 if LIMBS == 16 {
164 let (a, b) = UintKaratsubaMul::<16>::multiply(&self.limbs, &rhs.limbs);
165 return (a.resize(), b.resize());
166 }
167 }
168
169 uint_mul_limbs(&self.limbs, &rhs.limbs)
170 }
171
172 pub const fn wrapping_mul<const H: usize>(&self, rhs: &Uint<H>) -> Self {
174 self.split_mul(rhs).0
175 }
176
177 pub const fn saturating_mul<const RHS_LIMBS: usize>(&self, rhs: &Uint<RHS_LIMBS>) -> Self {
179 let (res, overflow) = self.split_mul(rhs);
180 Self::select(&res, &Self::MAX, overflow.is_nonzero())
181 }
182}
183
184impl<const LIMBS: usize> Uint<LIMBS> {
186 pub const fn square_wide(&self) -> (Self, Self) {
188 if LIMBS == 128 {
189 let (a, b) = UintKaratsubaMul::<128>::square(&self.limbs);
190 return (a.resize(), b.resize());
192 }
193 if LIMBS == 64 {
194 let (a, b) = UintKaratsubaMul::<64>::square(&self.limbs);
195 return (a.resize(), b.resize());
196 }
197
198 uint_square_limbs(&self.limbs)
199 }
200
201 pub const fn widening_square<const WIDE_LIMBS: usize>(&self) -> Uint<WIDE_LIMBS>
203 where
204 Self: ConcatMixed<Uint<LIMBS>, MixedOutput = Uint<WIDE_LIMBS>>,
205 {
206 let (lo, hi) = self.square_wide();
207 Uint::concat_mixed(&lo, &hi)
208 }
209
210 pub const fn checked_square(&self) -> ConstCtOption<Uint<LIMBS>> {
212 let (lo, hi) = self.square_wide();
213 ConstCtOption::new(lo, Self::eq(&hi, &Self::ZERO))
214 }
215
216 pub const fn wrapping_square(&self) -> Uint<LIMBS> {
218 self.square_wide().0
219 }
220
221 pub const fn saturating_square(&self) -> Self {
223 let (res, overflow) = self.square_wide();
224 Self::select(&res, &Self::MAX, overflow.is_nonzero())
225 }
226}
227
228impl<const LIMBS: usize, const WIDE_LIMBS: usize> Uint<LIMBS>
229where
230 Self: Concat<Output = Uint<WIDE_LIMBS>>,
231{
232 pub const fn square(&self) -> Uint<WIDE_LIMBS> {
234 let (lo, hi) = self.square_wide();
235 lo.concat(&hi)
236 }
237}
238
239impl<const LIMBS: usize, const RHS_LIMBS: usize> CheckedMul<Uint<RHS_LIMBS>> for Uint<LIMBS> {
240 #[inline]
241 fn checked_mul(&self, rhs: &Uint<RHS_LIMBS>) -> CtOption<Self> {
242 let (lo, hi) = self.split_mul(rhs);
243 CtOption::new(lo, hi.is_zero())
244 }
245}
246
247impl<const LIMBS: usize, const RHS_LIMBS: usize> Mul<Uint<RHS_LIMBS>> for Uint<LIMBS> {
248 type Output = Uint<LIMBS>;
249
250 fn mul(self, rhs: Uint<RHS_LIMBS>) -> Self {
251 self.mul(&rhs)
252 }
253}
254
255impl<const LIMBS: usize, const RHS_LIMBS: usize> Mul<&Uint<RHS_LIMBS>> for Uint<LIMBS> {
256 type Output = Uint<LIMBS>;
257
258 fn mul(self, rhs: &Uint<RHS_LIMBS>) -> Self {
259 (&self).mul(rhs)
260 }
261}
262
263impl<const LIMBS: usize, const RHS_LIMBS: usize> Mul<Uint<RHS_LIMBS>> for &Uint<LIMBS> {
264 type Output = Uint<LIMBS>;
265
266 fn mul(self, rhs: Uint<RHS_LIMBS>) -> Self::Output {
267 self.mul(&rhs)
268 }
269}
270
271impl<const LIMBS: usize, const RHS_LIMBS: usize> Mul<&Uint<RHS_LIMBS>> for &Uint<LIMBS> {
272 type Output = Uint<LIMBS>;
273
274 fn mul(self, rhs: &Uint<RHS_LIMBS>) -> Self::Output {
275 self.checked_mul(rhs)
276 .expect("attempted to multiply with overflow")
277 }
278}
279
280impl<const LIMBS: usize, const RHS_LIMBS: usize> MulAssign<Uint<RHS_LIMBS>> for Uint<LIMBS> {
281 fn mul_assign(&mut self, rhs: Uint<RHS_LIMBS>) {
282 *self = self.mul(&rhs)
283 }
284}
285
286impl<const LIMBS: usize, const RHS_LIMBS: usize> MulAssign<&Uint<RHS_LIMBS>> for Uint<LIMBS> {
287 fn mul_assign(&mut self, rhs: &Uint<RHS_LIMBS>) {
288 *self = self.mul(rhs)
289 }
290}
291
292impl<const LIMBS: usize> MulAssign<Wrapping<Uint<LIMBS>>> for Wrapping<Uint<LIMBS>> {
293 fn mul_assign(&mut self, other: Wrapping<Uint<LIMBS>>) {
294 *self = *self * other;
295 }
296}
297
298impl<const LIMBS: usize> MulAssign<&Wrapping<Uint<LIMBS>>> for Wrapping<Uint<LIMBS>> {
299 fn mul_assign(&mut self, other: &Wrapping<Uint<LIMBS>>) {
300 *self = *self * other;
301 }
302}
303
304impl<const LIMBS: usize> MulAssign<Checked<Uint<LIMBS>>> for Checked<Uint<LIMBS>> {
305 fn mul_assign(&mut self, other: Checked<Uint<LIMBS>>) {
306 *self = *self * other;
307 }
308}
309
310impl<const LIMBS: usize> MulAssign<&Checked<Uint<LIMBS>>> for Checked<Uint<LIMBS>> {
311 fn mul_assign(&mut self, other: &Checked<Uint<LIMBS>>) {
312 *self = *self * other;
313 }
314}
315
316impl<const LIMBS: usize, const RHS_LIMBS: usize, const WIDE_LIMBS: usize>
317 WideningMul<Uint<RHS_LIMBS>> for Uint<LIMBS>
318where
319 Self: ConcatMixed<Uint<RHS_LIMBS>, MixedOutput = Uint<WIDE_LIMBS>>,
320{
321 type Output = <Self as ConcatMixed<Uint<RHS_LIMBS>>>::MixedOutput;
322
323 #[inline]
324 fn widening_mul(&self, rhs: Uint<RHS_LIMBS>) -> Self::Output {
325 self.widening_mul(&rhs)
326 }
327}
328
329impl<const LIMBS: usize, const RHS_LIMBS: usize, const WIDE_LIMBS: usize>
330 WideningMul<&Uint<RHS_LIMBS>> for Uint<LIMBS>
331where
332 Self: ConcatMixed<Uint<RHS_LIMBS>, MixedOutput = Uint<WIDE_LIMBS>>,
333{
334 type Output = <Self as ConcatMixed<Uint<RHS_LIMBS>>>::MixedOutput;
335
336 #[inline]
337 fn widening_mul(&self, rhs: &Uint<RHS_LIMBS>) -> Self::Output {
338 self.widening_mul(rhs)
339 }
340}
341
342impl<const LIMBS: usize> WrappingMul for Uint<LIMBS> {
343 fn wrapping_mul(&self, v: &Self) -> Self {
344 self.wrapping_mul(v)
345 }
346}
347
348#[inline]
350pub(crate) const fn uint_mul_limbs<const LIMBS: usize, const RHS_LIMBS: usize>(
351 lhs: &[Limb],
352 rhs: &[Limb],
353) -> (Uint<LIMBS>, Uint<RHS_LIMBS>) {
354 debug_assert!(lhs.len() == LIMBS && rhs.len() == RHS_LIMBS);
355 let mut lo: Uint<LIMBS> = Uint::<LIMBS>::ZERO;
356 let mut hi = Uint::<RHS_LIMBS>::ZERO;
357 schoolbook_multiplication(lhs, rhs, &mut lo.limbs, &mut hi.limbs);
358 (lo, hi)
359}
360
361#[inline]
363pub(crate) const fn uint_square_limbs<const LIMBS: usize>(
364 limbs: &[Limb],
365) -> (Uint<LIMBS>, Uint<LIMBS>) {
366 let mut lo = Uint::<LIMBS>::ZERO;
367 let mut hi = Uint::<LIMBS>::ZERO;
368 schoolbook_squaring(limbs, &mut lo.limbs, &mut hi.limbs);
369 (lo, hi)
370}
371
#[cfg(feature = "alloc")]
/// Multiply `lhs` by `rhs` into `out`, which must hold `lhs.len() + rhs.len()` limbs.
///
/// The first `lhs.len()` limbs of `out` receive the low part of the product and the rest
/// receive the high part. The schoolbook routine accumulates into `out`, so callers are
/// expected to pass zeroed limbs.
pub(crate) fn mul_limbs(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) {
    debug_assert_eq!(lhs.len() + rhs.len(), out.len());
    let boundary = lhs.len();
    let (low, high) = out.split_at_mut(boundary);
    schoolbook_multiplication(lhs, rhs, low, high);
}
379
#[cfg(feature = "alloc")]
/// Square `limbs` into `out`, which must hold `2 * limbs.len()` limbs.
///
/// The low half of the square is written to `out[..limbs.len()]` and the high half to the
/// rest. The schoolbook routine accumulates into `out`, so callers are expected to pass
/// zeroed limbs.
pub(crate) fn square_limbs(limbs: &[Limb], out: &mut [Limb]) {
    debug_assert_eq!(limbs.len() * 2, out.len());
    let half = limbs.len();
    let (low, high) = out.split_at_mut(half);
    schoolbook_squaring(limbs, low, high);
}
387
#[cfg(test)]
mod tests {
    use crate::{CheckedMul, ConstChoice, Limb, Uint, Word, Zero, U128, U192, U256, U64};

    #[test]
    fn mul_wide_zero_and_one() {
        assert_eq!(U64::ZERO.split_mul(&U64::ZERO), (U64::ZERO, U64::ZERO));
        assert_eq!(U64::ZERO.split_mul(&U64::ONE), (U64::ZERO, U64::ZERO));
        assert_eq!(U64::ONE.split_mul(&U64::ZERO), (U64::ZERO, U64::ZERO));
        assert_eq!(U64::ONE.split_mul(&U64::ONE), (U64::ONE, U64::ZERO));
    }

    #[test]
    fn mul_wide_lo_only() {
        let primes: &[u32] = &[3, 5, 17, 257, 65537];

        for &a_int in primes {
            for &b_int in primes {
                let (lo, hi) = U64::from_u32(a_int).split_mul(&U64::from_u32(b_int));
                let expected = U64::from_u64(a_int as u64 * b_int as u64);
                assert_eq!(lo, expected);
                assert!(bool::from(hi.is_zero()));
            }
        }
    }

    #[test]
    fn mul_concat_even() {
        assert_eq!(U64::ZERO.widening_mul(&U64::MAX), U128::ZERO);
        assert_eq!(U64::MAX.widening_mul(&U64::ZERO), U128::ZERO);
        assert_eq!(
            U64::MAX.widening_mul(&U64::MAX),
            U128::from_u128(0xfffffffffffffffe_0000000000000001)
        );
        assert_eq!(
            U64::ONE.widening_mul(&U64::MAX),
            U128::from_u128(0x0000000000000000_ffffffffffffffff)
        );
    }

    #[test]
    fn mul_concat_mixed() {
        let a = U64::from_u64(0x0011223344556677);
        let b = U128::from_u128(0x8899aabbccddeeff_8899aabbccddeeff);
        assert_eq!(a.widening_mul(&b), U192::from(&a).saturating_mul(&b));
        assert_eq!(b.widening_mul(&a), U192::from(&b).saturating_mul(&a));
    }

    #[test]
    fn checked_mul_ok() {
        let n = U64::from_u32(0xffff_ffff);
        assert_eq!(
            n.checked_mul(&n).unwrap(),
            U64::from_u64(0xffff_fffe_0000_0001)
        );
    }

    #[test]
    fn checked_mul_overflow() {
        let n = U64::from_u64(0xffff_ffff_ffff_ffff);
        assert!(bool::from(n.checked_mul(&n).is_none()));
    }

    #[test]
    fn saturating_mul_no_overflow() {
        let n = U64::from_u8(8);
        assert_eq!(n.saturating_mul(&n), U64::from_u8(64));
    }

    #[test]
    fn saturating_mul_overflow() {
        let a = U64::from(0xffff_ffff_ffff_ffffu64);
        let b = U64::from(2u8);
        assert_eq!(a.saturating_mul(&b), U64::MAX);
    }

    #[test]
    fn square() {
        let n = U64::from_u64(0xffff_ffff_ffff_ffff);
        let (lo, hi) = n.square().split();
        assert_eq!(lo, U64::from_u64(1));
        assert_eq!(hi, U64::from_u64(0xffff_ffff_ffff_fffe));
    }

    #[test]
    fn square_larger() {
        let n = U256::MAX;
        let (lo, hi) = n.square().split();
        assert_eq!(lo, U256::ONE);
        assert_eq!(hi, U256::MAX.wrapping_sub(&U256::ONE));
    }

    #[test]
    fn checked_square() {
        let n = U256::from_u64(u64::MAX).wrapping_add(&U256::ONE);
        let n2 = n.checked_square();
        assert_eq!(n2.is_some(), ConstChoice::TRUE);
        let n4 = n2.unwrap().checked_square();
        assert_eq!(n4.is_none(), ConstChoice::TRUE);
    }

    #[test]
    fn wrapping_square() {
        let n = U256::from_u64(u64::MAX).wrapping_add(&U256::ONE);
        let n2 = n.wrapping_square();
        assert_eq!(n2, U256::from_u128(u128::MAX).wrapping_add(&U256::ONE));
        let n4 = n2.wrapping_square();
        assert_eq!(n4, U256::ZERO);
    }

    #[test]
    fn saturating_square() {
        let n = U256::from_u64(u64::MAX).wrapping_add(&U256::ONE);
        let n2 = n.saturating_square();
        assert_eq!(n2, U256::from_u128(u128::MAX).wrapping_add(&U256::ONE));
        let n4 = n2.saturating_square();
        assert_eq!(n4, U256::MAX);
    }

    /// Deterministic pseudo-random value whose limbs all differ, so carries propagate
    /// through every column of the multiplication grid.
    fn sample<const LIMBS: usize>(seed: Word) -> Uint<LIMBS> {
        let mut limbs = [Limb::ZERO; LIMBS];
        let mut state = seed;
        for limb in limbs.iter_mut() {
            // Small multiplicative/additive constants fit in both 32- and 64-bit `Word`s.
            state = state.wrapping_mul(0x9E37_79B9).wrapping_add(0x7F4A_7C15);
            *limb = Limb(state);
        }
        Uint::new(limbs)
    }

    /// Check `split_mul` and `square_wide` against the plain schoolbook routines at one width.
    /// Those two may be Karatsuba-accelerated.
    fn check_against_schoolbook<const LIMBS: usize>() {
        let cases = [
            (sample::<LIMBS>(1), sample::<LIMBS>(2)),
            (Uint::<LIMBS>::MAX, Uint::<LIMBS>::MAX),
            (Uint::<LIMBS>::MAX, sample::<LIMBS>(3)),
        ];
        for (a, b) in cases {
            assert_eq!(
                a.split_mul(&b),
                super::uint_mul_limbs::<LIMBS, LIMBS>(&a.limbs, &b.limbs)
            );
            assert_eq!(
                a.square_wide(),
                super::uint_square_limbs::<LIMBS>(&a.limbs)
            );
        }
    }

    /// The Karatsuba dispatch in `split_mul` (16/32/64/128 limbs) and `square_wide`
    /// (64/128 limbs) must agree with schoolbook arithmetic. This runs without the optional
    /// `rand_core` feature that `mul_cmp` requires.
    #[test]
    fn karatsuba_widths_match_schoolbook() {
        check_against_schoolbook::<16>();
        check_against_schoolbook::<32>();
        check_against_schoolbook::<64>();
        check_against_schoolbook::<128>();
        // A width that never takes a Karatsuba path, as a sanity check of the harness itself.
        check_against_schoolbook::<3>();
    }

    #[cfg(feature = "rand_core")]
    #[test]
    fn mul_cmp() {
        use crate::{Random, U4096};
        use rand_core::SeedableRng;
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(1);

        for _ in 0..50 {
            let a = U4096::random(&mut rng);
            assert_eq!(a.split_mul(&a), a.square_wide(), "a = {a}");
        }
    }
}