crypto_bigint/uint/mul.rs

1//! [`Uint`] multiplication operations.
2
3use core::ops::{Mul, MulAssign};
4
5use subtle::CtOption;
6
7use crate::{
8    Checked, CheckedMul, Concat, ConcatMixed, ConstCtOption, Limb, Uint, WideningMul, Wrapping,
9    WrappingMul, Zero,
10};
11
12use self::karatsuba::UintKaratsubaMul;
13
14pub(crate) mod karatsuba;
15
16/// Schoolbook multiplication a.k.a. long multiplication, i.e. the traditional method taught in
17/// schools.
18///
19/// The most efficient method for small numbers.
20const fn schoolbook_multiplication(lhs: &[Limb], rhs: &[Limb], lo: &mut [Limb], hi: &mut [Limb]) {
21    if lhs.len() != lo.len() || rhs.len() != hi.len() {
22        panic!("schoolbook multiplication length mismatch");
23    }
24
25    let mut i = 0;
26    while i < lhs.len() {
27        let mut j = 0;
28        let mut carry = Limb::ZERO;
29        let xi = lhs[i];
30
31        while j < rhs.len() {
32            let k = i + j;
33
34            if k >= lhs.len() {
35                (hi[k - lhs.len()], carry) = hi[k - lhs.len()].mac(xi, rhs[j], carry);
36            } else {
37                (lo[k], carry) = lo[k].mac(xi, rhs[j], carry);
38            }
39
40            j += 1;
41        }
42
43        if i + j >= lhs.len() {
44            hi[i + j - lhs.len()] = carry;
45        } else {
46            lo[i + j] = carry;
47        }
48        i += 1;
49    }
50}
51
52/// Schoolbook method of squaring.
53///
54/// Like schoolbook multiplication, but only considering half of the multiplication grid.
55pub(crate) const fn schoolbook_squaring(limbs: &[Limb], lo: &mut [Limb], hi: &mut [Limb]) {
56    // Translated from https://github.com/ucbrise/jedi-pairing/blob/c4bf151/include/core/bigint.hpp#L410
57    //
58    // Permission to relicense the resulting translation as Apache 2.0 + MIT was given
59    // by the original author Sam Kumar: https://github.com/RustCrypto/crypto-bigint/pull/133#discussion_r1056870411
60
61    if limbs.len() != lo.len() || lo.len() != hi.len() {
62        panic!("schoolbook squaring length mismatch");
63    }
64
65    let mut i = 1;
66    while i < limbs.len() {
67        let mut j = 0;
68        let mut carry = Limb::ZERO;
69        let xi = limbs[i];
70
71        while j < i {
72            let k = i + j;
73
74            if k >= limbs.len() {
75                (hi[k - limbs.len()], carry) = hi[k - limbs.len()].mac(xi, limbs[j], carry);
76            } else {
77                (lo[k], carry) = lo[k].mac(xi, limbs[j], carry);
78            }
79
80            j += 1;
81        }
82
83        if (2 * i) < limbs.len() {
84            lo[2 * i] = carry;
85        } else {
86            hi[2 * i - limbs.len()] = carry;
87        }
88
89        i += 1;
90    }
91
92    // Double the current result, this accounts for the other half of the multiplication grid.
93    // The top word is empty, so we use a special purpose shl.
94    let mut carry = Limb::ZERO;
95    let mut i = 0;
96    while i < limbs.len() {
97        (lo[i].0, carry) = ((lo[i].0 << 1) | carry.0, lo[i].shr(Limb::BITS - 1));
98        i += 1;
99    }
100
101    let mut i = 0;
102    while i < limbs.len() - 1 {
103        (hi[i].0, carry) = ((hi[i].0 << 1) | carry.0, hi[i].shr(Limb::BITS - 1));
104        i += 1;
105    }
106    hi[limbs.len() - 1] = carry;
107
108    // Handle the diagonal of the multiplication grid, which finishes the multiplication grid.
109    let mut carry = Limb::ZERO;
110    let mut i = 0;
111    while i < limbs.len() {
112        let xi = limbs[i];
113        if (i * 2) < limbs.len() {
114            (lo[i * 2], carry) = lo[i * 2].mac(xi, xi, carry);
115        } else {
116            (hi[i * 2 - limbs.len()], carry) = hi[i * 2 - limbs.len()].mac(xi, xi, carry);
117        }
118
119        if (i * 2 + 1) < limbs.len() {
120            (lo[i * 2 + 1], carry) = lo[i * 2 + 1].overflowing_add(carry);
121        } else {
122            (hi[i * 2 + 1 - limbs.len()], carry) =
123                hi[i * 2 + 1 - limbs.len()].overflowing_add(carry);
124        }
125
126        i += 1;
127    }
128}
129
130impl<const LIMBS: usize> Uint<LIMBS> {
131    /// Multiply `self` by `rhs`, returning a concatenated "wide" result.
132    pub const fn widening_mul<const RHS_LIMBS: usize, const WIDE_LIMBS: usize>(
133        &self,
134        rhs: &Uint<RHS_LIMBS>,
135    ) -> Uint<WIDE_LIMBS>
136    where
137        Self: ConcatMixed<Uint<RHS_LIMBS>, MixedOutput = Uint<WIDE_LIMBS>>,
138    {
139        let (lo, hi) = self.split_mul(rhs);
140        Uint::concat_mixed(&lo, &hi)
141    }
142
143    /// Compute "wide" multiplication as a 2-tuple containing the `(lo, hi)` components of the product, whose sizes
144    /// correspond to the sizes of the operands.
145    pub const fn split_mul<const RHS_LIMBS: usize>(
146        &self,
147        rhs: &Uint<RHS_LIMBS>,
148    ) -> (Self, Uint<RHS_LIMBS>) {
149        if LIMBS == RHS_LIMBS {
150            if LIMBS == 128 {
151                let (a, b) = UintKaratsubaMul::<128>::multiply(&self.limbs, &rhs.limbs);
152                // resize() should be a no-op, but the compiler can't infer that Uint<LIMBS> is Uint<128>
153                return (a.resize(), b.resize());
154            }
155            if LIMBS == 64 {
156                let (a, b) = UintKaratsubaMul::<64>::multiply(&self.limbs, &rhs.limbs);
157                return (a.resize(), b.resize());
158            }
159            if LIMBS == 32 {
160                let (a, b) = UintKaratsubaMul::<32>::multiply(&self.limbs, &rhs.limbs);
161                return (a.resize(), b.resize());
162            }
163            if LIMBS == 16 {
164                let (a, b) = UintKaratsubaMul::<16>::multiply(&self.limbs, &rhs.limbs);
165                return (a.resize(), b.resize());
166            }
167        }
168
169        uint_mul_limbs(&self.limbs, &rhs.limbs)
170    }
171
172    /// Perform wrapping multiplication, discarding overflow.
173    pub const fn wrapping_mul<const H: usize>(&self, rhs: &Uint<H>) -> Self {
174        self.split_mul(rhs).0
175    }
176
177    /// Perform saturating multiplication, returning `MAX` on overflow.
178    pub const fn saturating_mul<const RHS_LIMBS: usize>(&self, rhs: &Uint<RHS_LIMBS>) -> Self {
179        let (res, overflow) = self.split_mul(rhs);
180        Self::select(&res, &Self::MAX, overflow.is_nonzero())
181    }
182}
183
184/// Squaring operations
185impl<const LIMBS: usize> Uint<LIMBS> {
186    /// Square self, returning a "wide" result in two parts as (lo, hi).
187    pub const fn square_wide(&self) -> (Self, Self) {
188        if LIMBS == 128 {
189            let (a, b) = UintKaratsubaMul::<128>::square(&self.limbs);
190            // resize() should be a no-op, but the compiler can't infer that Uint<LIMBS> is Uint<128>
191            return (a.resize(), b.resize());
192        }
193        if LIMBS == 64 {
194            let (a, b) = UintKaratsubaMul::<64>::square(&self.limbs);
195            return (a.resize(), b.resize());
196        }
197
198        uint_square_limbs(&self.limbs)
199    }
200
201    /// Square self, returning a concatenated "wide" result.
202    pub const fn widening_square<const WIDE_LIMBS: usize>(&self) -> Uint<WIDE_LIMBS>
203    where
204        Self: ConcatMixed<Uint<LIMBS>, MixedOutput = Uint<WIDE_LIMBS>>,
205    {
206        let (lo, hi) = self.square_wide();
207        Uint::concat_mixed(&lo, &hi)
208    }
209
210    /// Square self, checking that the result fits in the original [`Uint`] size.
211    pub const fn checked_square(&self) -> ConstCtOption<Uint<LIMBS>> {
212        let (lo, hi) = self.square_wide();
213        ConstCtOption::new(lo, Self::eq(&hi, &Self::ZERO))
214    }
215
216    /// Perform wrapping square, discarding overflow.
217    pub const fn wrapping_square(&self) -> Uint<LIMBS> {
218        self.square_wide().0
219    }
220
221    /// Perform saturating squaring, returning `MAX` on overflow.
222    pub const fn saturating_square(&self) -> Self {
223        let (res, overflow) = self.square_wide();
224        Self::select(&res, &Self::MAX, overflow.is_nonzero())
225    }
226}
227
228impl<const LIMBS: usize, const WIDE_LIMBS: usize> Uint<LIMBS>
229where
230    Self: Concat<Output = Uint<WIDE_LIMBS>>,
231{
232    /// Square self, returning a concatenated "wide" result.
233    pub const fn square(&self) -> Uint<WIDE_LIMBS> {
234        let (lo, hi) = self.square_wide();
235        lo.concat(&hi)
236    }
237}
238
239impl<const LIMBS: usize, const RHS_LIMBS: usize> CheckedMul<Uint<RHS_LIMBS>> for Uint<LIMBS> {
240    #[inline]
241    fn checked_mul(&self, rhs: &Uint<RHS_LIMBS>) -> CtOption<Self> {
242        let (lo, hi) = self.split_mul(rhs);
243        CtOption::new(lo, hi.is_zero())
244    }
245}
246
247impl<const LIMBS: usize, const RHS_LIMBS: usize> Mul<Uint<RHS_LIMBS>> for Uint<LIMBS> {
248    type Output = Uint<LIMBS>;
249
250    fn mul(self, rhs: Uint<RHS_LIMBS>) -> Self {
251        self.mul(&rhs)
252    }
253}
254
255impl<const LIMBS: usize, const RHS_LIMBS: usize> Mul<&Uint<RHS_LIMBS>> for Uint<LIMBS> {
256    type Output = Uint<LIMBS>;
257
258    fn mul(self, rhs: &Uint<RHS_LIMBS>) -> Self {
259        (&self).mul(rhs)
260    }
261}
262
263impl<const LIMBS: usize, const RHS_LIMBS: usize> Mul<Uint<RHS_LIMBS>> for &Uint<LIMBS> {
264    type Output = Uint<LIMBS>;
265
266    fn mul(self, rhs: Uint<RHS_LIMBS>) -> Self::Output {
267        self.mul(&rhs)
268    }
269}
270
271impl<const LIMBS: usize, const RHS_LIMBS: usize> Mul<&Uint<RHS_LIMBS>> for &Uint<LIMBS> {
272    type Output = Uint<LIMBS>;
273
274    fn mul(self, rhs: &Uint<RHS_LIMBS>) -> Self::Output {
275        self.checked_mul(rhs)
276            .expect("attempted to multiply with overflow")
277    }
278}
279
280impl<const LIMBS: usize, const RHS_LIMBS: usize> MulAssign<Uint<RHS_LIMBS>> for Uint<LIMBS> {
281    fn mul_assign(&mut self, rhs: Uint<RHS_LIMBS>) {
282        *self = self.mul(&rhs)
283    }
284}
285
286impl<const LIMBS: usize, const RHS_LIMBS: usize> MulAssign<&Uint<RHS_LIMBS>> for Uint<LIMBS> {
287    fn mul_assign(&mut self, rhs: &Uint<RHS_LIMBS>) {
288        *self = self.mul(rhs)
289    }
290}
291
292impl<const LIMBS: usize> MulAssign<Wrapping<Uint<LIMBS>>> for Wrapping<Uint<LIMBS>> {
293    fn mul_assign(&mut self, other: Wrapping<Uint<LIMBS>>) {
294        *self = *self * other;
295    }
296}
297
298impl<const LIMBS: usize> MulAssign<&Wrapping<Uint<LIMBS>>> for Wrapping<Uint<LIMBS>> {
299    fn mul_assign(&mut self, other: &Wrapping<Uint<LIMBS>>) {
300        *self = *self * other;
301    }
302}
303
304impl<const LIMBS: usize> MulAssign<Checked<Uint<LIMBS>>> for Checked<Uint<LIMBS>> {
305    fn mul_assign(&mut self, other: Checked<Uint<LIMBS>>) {
306        *self = *self * other;
307    }
308}
309
310impl<const LIMBS: usize> MulAssign<&Checked<Uint<LIMBS>>> for Checked<Uint<LIMBS>> {
311    fn mul_assign(&mut self, other: &Checked<Uint<LIMBS>>) {
312        *self = *self * other;
313    }
314}
315
316impl<const LIMBS: usize, const RHS_LIMBS: usize, const WIDE_LIMBS: usize>
317    WideningMul<Uint<RHS_LIMBS>> for Uint<LIMBS>
318where
319    Self: ConcatMixed<Uint<RHS_LIMBS>, MixedOutput = Uint<WIDE_LIMBS>>,
320{
321    type Output = <Self as ConcatMixed<Uint<RHS_LIMBS>>>::MixedOutput;
322
323    #[inline]
324    fn widening_mul(&self, rhs: Uint<RHS_LIMBS>) -> Self::Output {
325        self.widening_mul(&rhs)
326    }
327}
328
329impl<const LIMBS: usize, const RHS_LIMBS: usize, const WIDE_LIMBS: usize>
330    WideningMul<&Uint<RHS_LIMBS>> for Uint<LIMBS>
331where
332    Self: ConcatMixed<Uint<RHS_LIMBS>, MixedOutput = Uint<WIDE_LIMBS>>,
333{
334    type Output = <Self as ConcatMixed<Uint<RHS_LIMBS>>>::MixedOutput;
335
336    #[inline]
337    fn widening_mul(&self, rhs: &Uint<RHS_LIMBS>) -> Self::Output {
338        self.widening_mul(rhs)
339    }
340}
341
342impl<const LIMBS: usize> WrappingMul for Uint<LIMBS> {
343    fn wrapping_mul(&self, v: &Self) -> Self {
344        self.wrapping_mul(v)
345    }
346}
347
348/// Helper method to perform schoolbook multiplication
349#[inline]
350pub(crate) const fn uint_mul_limbs<const LIMBS: usize, const RHS_LIMBS: usize>(
351    lhs: &[Limb],
352    rhs: &[Limb],
353) -> (Uint<LIMBS>, Uint<RHS_LIMBS>) {
354    debug_assert!(lhs.len() == LIMBS && rhs.len() == RHS_LIMBS);
355    let mut lo: Uint<LIMBS> = Uint::<LIMBS>::ZERO;
356    let mut hi = Uint::<RHS_LIMBS>::ZERO;
357    schoolbook_multiplication(lhs, rhs, &mut lo.limbs, &mut hi.limbs);
358    (lo, hi)
359}
360
361/// Helper method to perform schoolbook multiplication
362#[inline]
363pub(crate) const fn uint_square_limbs<const LIMBS: usize>(
364    limbs: &[Limb],
365) -> (Uint<LIMBS>, Uint<LIMBS>) {
366    let mut lo = Uint::<LIMBS>::ZERO;
367    let mut hi = Uint::<LIMBS>::ZERO;
368    schoolbook_squaring(limbs, &mut lo.limbs, &mut hi.limbs);
369    (lo, hi)
370}
371
/// Wrapper function used by `BoxedUint`: writes the full product of `lhs` and `rhs` into
/// `out`, whose length must be `lhs.len() + rhs.len()`.
///
/// The schoolbook routine accumulates into `out`, so callers are expected to pass a zeroed
/// buffer.
#[cfg(feature = "alloc")]
pub(crate) fn mul_limbs(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) {
    debug_assert_eq!(lhs.len() + rhs.len(), out.len());
    let split = lhs.len();
    let (out_lo, out_hi) = out.split_at_mut(split);
    schoolbook_multiplication(lhs, rhs, out_lo, out_hi);
}
379
/// Wrapper function used by `BoxedUint`: writes the square of `limbs` into `out`, whose
/// length must be `2 * limbs.len()`.
///
/// The schoolbook routine accumulates into `out`, so callers are expected to pass a zeroed
/// buffer.
#[cfg(feature = "alloc")]
pub(crate) fn square_limbs(limbs: &[Limb], out: &mut [Limb]) {
    debug_assert_eq!(limbs.len() * 2, out.len());
    let half = limbs.len();
    let (out_lo, out_hi) = out.split_at_mut(half);
    schoolbook_squaring(limbs, out_lo, out_hi);
}
387
#[cfg(test)]
mod tests {
    use crate::{CheckedMul, ConstChoice, Zero, U128, U192, U256, U64};

    #[test]
    fn mul_wide_zero_and_one() {
        assert_eq!(U64::ZERO.split_mul(&U64::ZERO), (U64::ZERO, U64::ZERO));
        assert_eq!(U64::ZERO.split_mul(&U64::ONE), (U64::ZERO, U64::ZERO));
        assert_eq!(U64::ONE.split_mul(&U64::ZERO), (U64::ZERO, U64::ZERO));
        assert_eq!(U64::ONE.split_mul(&U64::ONE), (U64::ONE, U64::ZERO));
    }

    #[test]
    fn mul_wide_lo_only() {
        let primes: &[u32] = &[3, 5, 17, 257, 65537];

        for &a_int in primes {
            for &b_int in primes {
                let (lo, hi) = U64::from_u32(a_int).split_mul(&U64::from_u32(b_int));
                let expected = U64::from_u64(a_int as u64 * b_int as u64);
                assert_eq!(lo, expected);
                assert!(bool::from(hi.is_zero()));
            }
        }
    }

    #[test]
    fn mul_concat_even() {
        assert_eq!(U64::ZERO.widening_mul(&U64::MAX), U128::ZERO);
        assert_eq!(U64::MAX.widening_mul(&U64::ZERO), U128::ZERO);
        assert_eq!(
            U64::MAX.widening_mul(&U64::MAX),
            U128::from_u128(0xfffffffffffffffe_0000000000000001)
        );
        assert_eq!(
            U64::ONE.widening_mul(&U64::MAX),
            U128::from_u128(0x0000000000000000_ffffffffffffffff)
        );
    }

    #[test]
    fn mul_concat_mixed() {
        let a = U64::from_u64(0x0011223344556677);
        let b = U128::from_u128(0x8899aabbccddeeff_8899aabbccddeeff);
        assert_eq!(a.widening_mul(&b), U192::from(&a).saturating_mul(&b));
        assert_eq!(b.widening_mul(&a), U192::from(&b).saturating_mul(&a));
    }

    #[test]
    fn checked_mul_ok() {
        let n = U64::from_u32(0xffff_ffff);
        assert_eq!(
            n.checked_mul(&n).unwrap(),
            U64::from_u64(0xffff_fffe_0000_0001)
        );
    }

    #[test]
    fn checked_mul_overflow() {
        let n = U64::from_u64(0xffff_ffff_ffff_ffff);
        assert!(bool::from(n.checked_mul(&n).is_none()));
    }

    #[test]
    fn saturating_mul_no_overflow() {
        let n = U64::from_u8(8);
        assert_eq!(n.saturating_mul(&n), U64::from_u8(64));
    }

    #[test]
    fn saturating_mul_overflow() {
        let a = U64::from(0xffff_ffff_ffff_ffffu64);
        let b = U64::from(2u8);
        assert_eq!(a.saturating_mul(&b), U64::MAX);
    }

    #[test]
    fn square() {
        let n = U64::from_u64(0xffff_ffff_ffff_ffff);
        let (lo, hi) = n.square().split();
        assert_eq!(lo, U64::from_u64(1));
        assert_eq!(hi, U64::from_u64(0xffff_ffff_ffff_fffe));
    }

    #[test]
    fn square_larger() {
        let n = U256::MAX;
        let (lo, hi) = n.square().split();
        assert_eq!(lo, U256::ONE);
        assert_eq!(hi, U256::MAX.wrapping_sub(&U256::ONE));
    }

    #[test]
    fn checked_square() {
        let n = U256::from_u64(u64::MAX).wrapping_add(&U256::ONE);
        let n2 = n.checked_square();
        assert_eq!(n2.is_some(), ConstChoice::TRUE);
        let n4 = n2.unwrap().checked_square();
        assert_eq!(n4.is_none(), ConstChoice::TRUE);
    }

    #[test]
    fn wrapping_square() {
        let n = U256::from_u64(u64::MAX).wrapping_add(&U256::ONE);
        let n2 = n.wrapping_square();
        assert_eq!(n2, U256::from_u128(u128::MAX).wrapping_add(&U256::ONE));
        let n4 = n2.wrapping_square();
        assert_eq!(n4, U256::ZERO);
    }

    #[test]
    fn saturating_square() {
        let n = U256::from_u64(u64::MAX).wrapping_add(&U256::ONE);
        let n2 = n.saturating_square();
        assert_eq!(n2, U256::from_u128(u128::MAX).wrapping_add(&U256::ONE));
        let n4 = n2.saturating_square();
        assert_eq!(n4, U256::MAX);
    }

    #[test]
    fn square_wide_matches_split_mul() {
        // `mul_cmp` only cross-checks squaring against multiplication when `rand_core` is
        // enabled, and only for random `U4096` values. This deterministic variant covers small
        // operands with both even and odd limb counts.
        let seeds: [u128; 5] = [
            0,
            1,
            u64::MAX as u128,
            0x0123_4567_89ab_cdef_fedc_ba98_7654_3210,
            u128::MAX,
        ];

        for &seed in seeds.iter() {
            // Truncating the seed to its low 64 bits is intentional here.
            let a = U64::from_u64(seed as u64);
            assert_eq!(a.split_mul(&a), a.square_wide(), "a = {a}");

            let b = U128::from_u128(seed);
            assert_eq!(b.split_mul(&b), b.square_wide(), "b = {b}");

            let c = U192::MAX.wrapping_sub(&U192::from_u128(seed));
            assert_eq!(c.split_mul(&c), c.square_wide(), "c = {c}");

            let d = U256::from_u128(seed).wrapping_mul(&U256::from_u128(seed.rotate_left(17)));
            assert_eq!(d.split_mul(&d), d.square_wide(), "d = {d}");
        }
    }

    #[cfg(feature = "rand_core")]
    #[test]
    fn mul_cmp() {
        use crate::{Random, U4096};
        use rand_core::SeedableRng;
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(1);

        for _ in 0..50 {
            let a = U4096::random(&mut rng);
            assert_eq!(a.split_mul(&a), a.square_wide(), "a = {a}");
        }
    }
}