// num_bigint/biguint/multiplication.rs

1use super::addition::{__add2, add2};
2use super::subtraction::sub2;
3use super::{biguint_from_vec, cmp_slice, BigUint, IntDigits};
4
5use crate::big_digit::{self, BigDigit, DoubleBigDigit};
6use crate::Sign::{self, Minus, NoSign, Plus};
7use crate::{BigInt, UsizePromotion};
8
9use core::cmp::Ordering;
10use core::iter::Product;
11use core::ops::{Mul, MulAssign};
12use num_traits::{CheckedMul, FromPrimitive, One, Zero};
13
14#[inline]
15pub(super) fn mac_with_carry(
16    a: BigDigit,
17    b: BigDigit,
18    c: BigDigit,
19    acc: &mut DoubleBigDigit,
20) -> BigDigit {
21    *acc += DoubleBigDigit::from(a);
22    *acc += DoubleBigDigit::from(b) * DoubleBigDigit::from(c);
23    let lo = *acc as BigDigit;
24    *acc >>= big_digit::BITS;
25    lo
26}
27
28#[inline]
29fn mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit {
30    *acc += DoubleBigDigit::from(a) * DoubleBigDigit::from(b);
31    let lo = *acc as BigDigit;
32    *acc >>= big_digit::BITS;
33    lo
34}
35
36/// Three argument multiply accumulate:
37/// acc += b * c
38fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
39    if c == 0 {
40        return;
41    }
42
43    let mut carry = 0;
44    let (a_lo, a_hi) = acc.split_at_mut(b.len());
45
46    for (a, &b) in a_lo.iter_mut().zip(b) {
47        *a = mac_with_carry(*a, b, c, &mut carry);
48    }
49
50    let (carry_hi, carry_lo) = big_digit::from_doublebigdigit(carry);
51
52    let final_carry = if carry_hi == 0 {
53        __add2(a_hi, &[carry_lo])
54    } else {
55        __add2(a_hi, &[carry_hi, carry_lo])
56    };
57    assert_eq!(final_carry, 0, "carry overflow during multiplication!");
58}
59
60fn bigint_from_slice(slice: &[BigDigit]) -> BigInt {
61    BigInt::from(biguint_from_vec(slice.to_vec()))
62}
63
64/// Three argument multiply accumulate:
65/// acc += b * c
66#[allow(clippy::many_single_char_names)]
67fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) {
68    // Least-significant zeros have no effect on the output.
69    if let Some(&0) = b.first() {
70        if let Some(nz) = b.iter().position(|&d| d != 0) {
71            b = &b[nz..];
72            acc = &mut acc[nz..];
73        } else {
74            return;
75        }
76    }
77    if let Some(&0) = c.first() {
78        if let Some(nz) = c.iter().position(|&d| d != 0) {
79            c = &c[nz..];
80            acc = &mut acc[nz..];
81        } else {
82            return;
83        }
84    }
85
86    let acc = acc;
87    let (x, y) = if b.len() < c.len() { (b, c) } else { (c, b) };
88
89    // We use four algorithms for different input sizes.
90    //
91    // - For small inputs, long multiplication is fastest.
92    // - If y is at least least twice as long as x, split using Half-Karatsuba.
93    // - Next we use Karatsuba multiplication (Toom-2), which we have optimized
94    //   to avoid unnecessary allocations for intermediate values.
95    // - For the largest inputs we use Toom-3, which better optimizes the
96    //   number of operations, but uses more temporary allocations.
97    //
98    // The thresholds are somewhat arbitrary, chosen by evaluating the results
99    // of `cargo bench --bench bigint multiply`.
100
101    if x.len() <= 32 {
102        // Long multiplication:
103        for (i, xi) in x.iter().enumerate() {
104            mac_digit(&mut acc[i..], y, *xi);
105        }
106    } else if x.len() * 2 <= y.len() {
107        // Karatsuba Multiplication for factors with significant length disparity.
108        //
109        // The Half-Karatsuba Multiplication Algorithm is a specialized case of
110        // the normal Karatsuba multiplication algorithm, designed for the scenario
111        // where y has at least twice as many base digits as x.
112        //
113        // In this case y (the longer input) is split into high2 and low2,
114        // at m2 (half the length of y) and x (the shorter input),
115        // is used directly without splitting.
116        //
117        // The algorithm then proceeds as follows:
118        //
119        // 1. Compute the product z0 = x * low2.
120        // 2. Compute the product temp = x * high2.
121        // 3. Adjust the weight of temp by adding m2 (* NBASE ^ m2)
122        // 4. Add temp and z0 to obtain the final result.
123        //
124        // Proof:
125        //
126        // The algorithm can be derived from the original Karatsuba algorithm by
127        // simplifying the formula when the shorter factor x is not split into
128        // high and low parts, as shown below.
129        //
130        // Original Karatsuba formula:
131        //
132        //     result = (z2 * NBASE ^ (m2 × 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0
133        //
134        // Substitutions:
135        //
136        //     low1 = x
137        //     high1 = 0
138        //
139        // Applying substitutions:
140        //
141        //     z0 = (low1 * low2)
142        //        = (x * low2)
143        //
144        //     z1 = ((low1 + high1) * (low2 + high2))
145        //        = ((x + 0) * (low2 + high2))
146        //        = (x * low2) + (x * high2)
147        //
148        //     z2 = (high1 * high2)
149        //        = (0 * high2)
150        //        = 0
151        //
152        // Simplified using the above substitutions:
153        //
154        //     result = (z2 * NBASE ^ (m2 × 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0
155        //            = (0 * NBASE ^ (m2 × 2)) + ((z1 - 0 - z0) * NBASE ^ m2) + z0
156        //            = ((z1 - z0) * NBASE ^ m2) + z0
157        //            = ((z1 - z0) * NBASE ^ m2) + z0
158        //            = (x * high2) * NBASE ^ m2 + z0
159        let m2 = y.len() / 2;
160        let (low2, high2) = y.split_at(m2);
161
162        // (x * high2) * NBASE ^ m2 + z0
163        mac3(acc, x, low2);
164        mac3(&mut acc[m2..], x, high2);
165    } else if x.len() <= 256 {
166        // Karatsuba multiplication:
167        //
168        // The idea is that we break x and y up into two smaller numbers that each have about half
169        // as many digits, like so (note that multiplying by b is just a shift):
170        //
171        // x = x0 + x1 * b
172        // y = y0 + y1 * b
173        //
174        // With some algebra, we can compute x * y with three smaller products, where the inputs to
175        // each of the smaller products have only about half as many digits as x and y:
176        //
177        // x * y = (x0 + x1 * b) * (y0 + y1 * b)
178        //
179        // x * y = x0 * y0
180        //       + x0 * y1 * b
181        //       + x1 * y0 * b
182        //       + x1 * y1 * b^2
183        //
184        // Let p0 = x0 * y0 and p2 = x1 * y1:
185        //
186        // x * y = p0
187        //       + (x0 * y1 + x1 * y0) * b
188        //       + p2 * b^2
189        //
190        // The real trick is that middle term:
191        //
192        //         x0 * y1 + x1 * y0
193        //
194        //       = x0 * y1 + x1 * y0 - p0 + p0 - p2 + p2
195        //
196        //       = x0 * y1 + x1 * y0 - x0 * y0 - x1 * y1 + p0 + p2
197        //
198        // Now we complete the square:
199        //
200        //       = -(x0 * y0 - x0 * y1 - x1 * y0 + x1 * y1) + p0 + p2
201        //
202        //       = -((x1 - x0) * (y1 - y0)) + p0 + p2
203        //
204        // Let p1 = (x1 - x0) * (y1 - y0), and substitute back into our original formula:
205        //
206        // x * y = p0
207        //       + (p0 + p2 - p1) * b
208        //       + p2 * b^2
209        //
210        // Where the three intermediate products are:
211        //
212        // p0 = x0 * y0
213        // p1 = (x1 - x0) * (y1 - y0)
214        // p2 = x1 * y1
215        //
216        // In doing the computation, we take great care to avoid unnecessary temporary variables
217        // (since creating a BigUint requires a heap allocation): thus, we rearrange the formula a
218        // bit so we can use the same temporary variable for all the intermediate products:
219        //
220        // x * y = p2 * b^2 + p2 * b
221        //       + p0 * b + p0
222        //       - p1 * b
223        //
224        // The other trick we use is instead of doing explicit shifts, we slice acc at the
225        // appropriate offset when doing the add.
226
227        // When x is smaller than y, it's significantly faster to pick b such that x is split in
228        // half, not y:
229        let b = x.len() / 2;
230        let (x0, x1) = x.split_at(b);
231        let (y0, y1) = y.split_at(b);
232
233        // We reuse the same BigUint for all the intermediate multiplies and have to size p
234        // appropriately here: x1.len() >= x0.len and y1.len() >= y0.len():
235        let len = x1.len() + y1.len() + 1;
236        let mut p = BigUint { data: vec![0; len] };
237
238        // p2 = x1 * y1
239        mac3(&mut p.data, x1, y1);
240
241        // Not required, but the adds go faster if we drop any unneeded 0s from the end:
242        p.normalize();
243
244        add2(&mut acc[b..], &p.data);
245        add2(&mut acc[b * 2..], &p.data);
246
247        // Zero out p before the next multiply:
248        p.data.truncate(0);
249        p.data.resize(len, 0);
250
251        // p0 = x0 * y0
252        mac3(&mut p.data, x0, y0);
253        p.normalize();
254
255        add2(acc, &p.data);
256        add2(&mut acc[b..], &p.data);
257
258        // p1 = (x1 - x0) * (y1 - y0)
259        // We do this one last, since it may be negative and acc can't ever be negative:
260        let (j0_sign, j0) = sub_sign(x1, x0);
261        let (j1_sign, j1) = sub_sign(y1, y0);
262
263        match j0_sign * j1_sign {
264            Plus => {
265                p.data.truncate(0);
266                p.data.resize(len, 0);
267
268                mac3(&mut p.data, &j0.data, &j1.data);
269                p.normalize();
270
271                sub2(&mut acc[b..], &p.data);
272            }
273            Minus => {
274                mac3(&mut acc[b..], &j0.data, &j1.data);
275            }
276            NoSign => (),
277        }
278    } else {
279        // Toom-3 multiplication:
280        //
281        // Toom-3 is like Karatsuba above, but dividing the inputs into three parts.
282        // Both are instances of Toom-Cook, using `k=3` and `k=2` respectively.
283        //
284        // The general idea is to treat the large integers digits as
285        // polynomials of a certain degree and determine the coefficients/digits
286        // of the product of the two via interpolation of the polynomial product.
287        let i = y.len() / 3 + 1;
288
289        let x0_len = Ord::min(x.len(), i);
290        let x1_len = Ord::min(x.len() - x0_len, i);
291
292        let y0_len = i;
293        let y1_len = Ord::min(y.len() - y0_len, i);
294
295        // Break x and y into three parts, representating an order two polynomial.
296        // t is chosen to be the size of a digit so we can use faster shifts
297        // in place of multiplications.
298        //
299        // x(t) = x2*t^2 + x1*t + x0
300        let x0 = bigint_from_slice(&x[..x0_len]);
301        let x1 = bigint_from_slice(&x[x0_len..x0_len + x1_len]);
302        let x2 = bigint_from_slice(&x[x0_len + x1_len..]);
303
304        // y(t) = y2*t^2 + y1*t + y0
305        let y0 = bigint_from_slice(&y[..y0_len]);
306        let y1 = bigint_from_slice(&y[y0_len..y0_len + y1_len]);
307        let y2 = bigint_from_slice(&y[y0_len + y1_len..]);
308
309        // Let w(t) = x(t) * y(t)
310        //
311        // This gives us the following order-4 polynomial.
312        //
313        // w(t) = w4*t^4 + w3*t^3 + w2*t^2 + w1*t + w0
314        //
315        // We need to find the coefficients w4, w3, w2, w1 and w0. Instead
316        // of simply multiplying the x and y in total, we can evaluate w
317        // at 5 points. An n-degree polynomial is uniquely identified by (n + 1)
318        // points.
319        //
320        // It is arbitrary as to what points we evaluate w at but we use the
321        // following.
322        //
323        // w(t) at t = 0, 1, -1, -2 and inf
324        //
325        // The values for w(t) in terms of x(t)*y(t) at these points are:
326        //
327        // let a = w(0)   = x0 * y0
328        // let b = w(1)   = (x2 + x1 + x0) * (y2 + y1 + y0)
329        // let c = w(-1)  = (x2 - x1 + x0) * (y2 - y1 + y0)
330        // let d = w(-2)  = (4*x2 - 2*x1 + x0) * (4*y2 - 2*y1 + y0)
331        // let e = w(inf) = x2 * y2 as t -> inf
332
333        // x0 + x2, avoiding temporaries
334        let p = &x0 + &x2;
335
336        // y0 + y2, avoiding temporaries
337        let q = &y0 + &y2;
338
339        // x2 - x1 + x0, avoiding temporaries
340        let p2 = &p - &x1;
341
342        // y2 - y1 + y0, avoiding temporaries
343        let q2 = &q - &y1;
344
345        // w(0)
346        let r0 = &x0 * &y0;
347
348        // w(inf)
349        let r4 = &x2 * &y2;
350
351        // w(1)
352        let r1 = (p + x1) * (q + y1);
353
354        // w(-1)
355        let r2 = &p2 * &q2;
356
357        // w(-2)
358        let r3 = ((p2 + x2) * 2 - x0) * ((q2 + y2) * 2 - y0);
359
360        // Evaluating these points gives us the following system of linear equations.
361        //
362        //  0  0  0  0  1 | a
363        //  1  1  1  1  1 | b
364        //  1 -1  1 -1  1 | c
365        // 16 -8  4 -2  1 | d
366        //  1  0  0  0  0 | e
367        //
368        // The solved equation (after gaussian elimination or similar)
369        // in terms of its coefficients:
370        //
371        // w0 = w(0)
372        // w1 = w(0)/2 + w(1)/3 - w(-1) + w(-2)/6 - 2*w(inf)
373        // w2 = -w(0) + w(1)/2 + w(-1)/2 - w(inf)
374        // w3 = -w(0)/2 + w(1)/6 + w(-1)/2 - w(-2)/6 + 2*w(inf)
375        // w4 = w(inf)
376        //
377        // This particular sequence is given by Bodrato and is an interpolation
378        // of the above equations.
379        let mut comp3: BigInt = (r3 - &r1) / 3u32;
380        let mut comp1: BigInt = (r1 - &r2) >> 1;
381        let mut comp2: BigInt = r2 - &r0;
382        comp3 = ((&comp2 - comp3) >> 1) + (&r4 << 1);
383        comp2 += &comp1 - &r4;
384        comp1 -= &comp3;
385
386        // Recomposition. The coefficients of the polynomial are now known.
387        //
388        // Evaluate at w(t) where t is our given base to get the result.
389        //
390        //     let bits = u64::from(big_digit::BITS) * i as u64;
391        //     let result = r0
392        //         + (comp1 << bits)
393        //         + (comp2 << (2 * bits))
394        //         + (comp3 << (3 * bits))
395        //         + (r4 << (4 * bits));
396        //     let result_pos = result.to_biguint().unwrap();
397        //     add2(&mut acc[..], &result_pos.data);
398        //
399        // But with less intermediate copying:
400        for (j, result) in [&r0, &comp1, &comp2, &comp3, &r4].iter().enumerate().rev() {
401            match result.sign() {
402                Plus => add2(&mut acc[i * j..], result.digits()),
403                Minus => sub2(&mut acc[i * j..], result.digits()),
404                NoSign => {}
405            }
406        }
407    }
408}
409
410fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
411    let len = x.len() + y.len() + 1;
412    let mut prod = BigUint { data: vec![0; len] };
413
414    mac3(&mut prod.data, x, y);
415    prod.normalized()
416}
417
418fn scalar_mul(a: &mut BigUint, b: BigDigit) {
419    match b {
420        0 => a.set_zero(),
421        1 => {}
422        _ => {
423            if b.is_power_of_two() {
424                *a <<= b.trailing_zeros();
425            } else {
426                let mut carry = 0;
427                for a in a.data.iter_mut() {
428                    *a = mul_with_carry(*a, b, &mut carry);
429                }
430                if carry != 0 {
431                    a.data.push(carry as BigDigit);
432                }
433            }
434        }
435    }
436}
437
438fn sub_sign(mut a: &[BigDigit], mut b: &[BigDigit]) -> (Sign, BigUint) {
439    // Normalize:
440    if let Some(&0) = a.last() {
441        a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
442    }
443    if let Some(&0) = b.last() {
444        b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
445    }
446
447    match cmp_slice(a, b) {
448        Ordering::Greater => {
449            let mut a = a.to_vec();
450            sub2(&mut a, b);
451            (Plus, biguint_from_vec(a))
452        }
453        Ordering::Less => {
454            let mut b = b.to_vec();
455            sub2(&mut b, a);
456            (Minus, biguint_from_vec(b))
457        }
458        Ordering::Equal => (NoSign, BigUint::ZERO),
459    }
460}
461
462macro_rules! impl_mul {
463    ($(impl Mul<$Other:ty> for $Self:ty;)*) => {$(
464        impl Mul<$Other> for $Self {
465            type Output = BigUint;
466
467            #[inline]
468            fn mul(self, other: $Other) -> BigUint {
469                match (&*self.data, &*other.data) {
470                    // multiply by zero
471                    (&[], _) | (_, &[]) => BigUint::ZERO,
472                    // multiply by a scalar
473                    (_, &[digit]) => self * digit,
474                    (&[digit], _) => other * digit,
475                    // full multiplication
476                    (x, y) => mul3(x, y),
477                }
478            }
479        }
480    )*}
481}
482impl_mul! {
483    impl Mul<BigUint> for BigUint;
484    impl Mul<BigUint> for &BigUint;
485    impl Mul<&BigUint> for BigUint;
486    impl Mul<&BigUint> for &BigUint;
487}
488
489macro_rules! impl_mul_assign {
490    ($(impl MulAssign<$Other:ty> for BigUint;)*) => {$(
491        impl MulAssign<$Other> for BigUint {
492            #[inline]
493            fn mul_assign(&mut self, other: $Other) {
494                match (&*self.data, &*other.data) {
495                    // multiply by zero
496                    (&[], _) => {},
497                    (_, &[]) => self.set_zero(),
498                    // multiply by a scalar
499                    (_, &[digit]) => *self *= digit,
500                    (&[digit], _) => *self = other * digit,
501                    // full multiplication
502                    (x, y) => *self = mul3(x, y),
503                }
504            }
505        }
506    )*}
507}
508impl_mul_assign! {
509    impl MulAssign<BigUint> for BigUint;
510    impl MulAssign<&BigUint> for BigUint;
511}
512
// Scalar multiplication plumbing. These macros are defined elsewhere in the
// crate. NOTE(review): judging by their names, the `promote_*` macros
// presumably route the smaller unsigned primitives to the `u32`/`u64`/`u128`
// impls below, and the `forward_*` macros derive the by-reference and
// scalar-on-the-left forms from the by-value impls — confirm against the
// macro definitions.
promote_unsigned_scalars!(impl Mul for BigUint, mul);
promote_unsigned_scalars_assign!(impl MulAssign for BigUint, mul_assign);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u32> for BigUint, mul);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u64> for BigUint, mul);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u128> for BigUint, mul);
518
519impl Mul<u32> for BigUint {
520    type Output = BigUint;
521
522    #[inline]
523    fn mul(mut self, other: u32) -> BigUint {
524        self *= other;
525        self
526    }
527}
528impl MulAssign<u32> for BigUint {
529    #[inline]
530    fn mul_assign(&mut self, other: u32) {
531        scalar_mul(self, other as BigDigit);
532    }
533}
534
535impl Mul<u64> for BigUint {
536    type Output = BigUint;
537
538    #[inline]
539    fn mul(mut self, other: u64) -> BigUint {
540        self *= other;
541        self
542    }
543}
544impl MulAssign<u64> for BigUint {
545    cfg_digit!(
546        #[inline]
547        fn mul_assign(&mut self, other: u64) {
548            if let Some(other) = BigDigit::from_u64(other) {
549                scalar_mul(self, other);
550            } else {
551                let (hi, lo) = big_digit::from_doublebigdigit(other);
552                *self = mul3(&self.data, &[lo, hi]);
553            }
554        }
555
556        #[inline]
557        fn mul_assign(&mut self, other: u64) {
558            scalar_mul(self, other);
559        }
560    );
561}
562
563impl Mul<u128> for BigUint {
564    type Output = BigUint;
565
566    #[inline]
567    fn mul(mut self, other: u128) -> BigUint {
568        self *= other;
569        self
570    }
571}
572
573impl MulAssign<u128> for BigUint {
574    cfg_digit!(
575        #[inline]
576        fn mul_assign(&mut self, other: u128) {
577            if let Some(other) = BigDigit::from_u128(other) {
578                scalar_mul(self, other);
579            } else {
580                *self = match super::u32_from_u128(other) {
581                    (0, 0, c, d) => mul3(&self.data, &[d, c]),
582                    (0, b, c, d) => mul3(&self.data, &[d, c, b]),
583                    (a, b, c, d) => mul3(&self.data, &[d, c, b, a]),
584                };
585            }
586        }
587
588        #[inline]
589        fn mul_assign(&mut self, other: u128) {
590            if let Some(other) = BigDigit::from_u128(other) {
591                scalar_mul(self, other);
592            } else {
593                let (hi, lo) = big_digit::from_doublebigdigit(other);
594                *self = mul3(&self.data, &[lo, hi]);
595            }
596        }
597    );
598}
599
600impl CheckedMul for BigUint {
601    #[inline]
602    fn checked_mul(&self, v: &BigUint) -> Option<BigUint> {
603        Some(self.mul(v))
604    }
605}
606
// NOTE(review): macro defined elsewhere in the crate; presumably it
// generates the `core::iter::Product` impls for `BigUint` (the trait is
// imported at the top of this file) — confirm against the macro definition.
impl_product_iter_type!(BigUint);
608
/// Checks `sub_sign` against `BigInt` subtraction in both operand orders.
#[test]
fn test_sub_sign() {
    use crate::BigInt;
    use num_traits::Num;

    // Runs `sub_sign` on the raw digits and folds the result into a BigInt.
    let signed_difference = |lhs: &BigUint, rhs: &BigUint| -> BigInt {
        let (sign, magnitude) = sub_sign(&lhs.data, &rhs.data);
        BigInt::from_biguint(sign, magnitude)
    };

    let big = BigUint::from_str_radix("265252859812191058636308480000000", 10).unwrap();
    let small = BigUint::from_str_radix("26525285981219105863630848000000", 10).unwrap();
    let big_i = BigInt::from(big.clone());
    let small_i = BigInt::from(small.clone());

    assert_eq!(signed_difference(&big, &small), &big_i - &small_i);
    assert_eq!(signed_difference(&small, &big), &small_i - &big_i);
}
626}