num_bigint/biguint/multiplication.rs
1use super::addition::{__add2, add2};
2use super::subtraction::sub2;
3use super::{biguint_from_vec, cmp_slice, BigUint, IntDigits};
4
5use crate::big_digit::{self, BigDigit, DoubleBigDigit};
6use crate::Sign::{self, Minus, NoSign, Plus};
7use crate::{BigInt, UsizePromotion};
8
9use core::cmp::Ordering;
10use core::iter::Product;
11use core::ops::{Mul, MulAssign};
12use num_traits::{CheckedMul, FromPrimitive, One, Zero};
13
14#[inline]
15pub(super) fn mac_with_carry(
16 a: BigDigit,
17 b: BigDigit,
18 c: BigDigit,
19 acc: &mut DoubleBigDigit,
20) -> BigDigit {
21 *acc += DoubleBigDigit::from(a);
22 *acc += DoubleBigDigit::from(b) * DoubleBigDigit::from(c);
23 let lo = *acc as BigDigit;
24 *acc >>= big_digit::BITS;
25 lo
26}
27
28#[inline]
29fn mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit {
30 *acc += DoubleBigDigit::from(a) * DoubleBigDigit::from(b);
31 let lo = *acc as BigDigit;
32 *acc >>= big_digit::BITS;
33 lo
34}
35
36/// Three argument multiply accumulate:
37/// acc += b * c
38fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
39 if c == 0 {
40 return;
41 }
42
43 let mut carry = 0;
44 let (a_lo, a_hi) = acc.split_at_mut(b.len());
45
46 for (a, &b) in a_lo.iter_mut().zip(b) {
47 *a = mac_with_carry(*a, b, c, &mut carry);
48 }
49
50 let (carry_hi, carry_lo) = big_digit::from_doublebigdigit(carry);
51
52 let final_carry = if carry_hi == 0 {
53 __add2(a_hi, &[carry_lo])
54 } else {
55 __add2(a_hi, &[carry_hi, carry_lo])
56 };
57 assert_eq!(final_carry, 0, "carry overflow during multiplication!");
58}
59
60fn bigint_from_slice(slice: &[BigDigit]) -> BigInt {
61 BigInt::from(biguint_from_vec(slice.to_vec()))
62}
63
64/// Three argument multiply accumulate:
65/// acc += b * c
66#[allow(clippy::many_single_char_names)]
67fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) {
68 // Least-significant zeros have no effect on the output.
69 if let Some(&0) = b.first() {
70 if let Some(nz) = b.iter().position(|&d| d != 0) {
71 b = &b[nz..];
72 acc = &mut acc[nz..];
73 } else {
74 return;
75 }
76 }
77 if let Some(&0) = c.first() {
78 if let Some(nz) = c.iter().position(|&d| d != 0) {
79 c = &c[nz..];
80 acc = &mut acc[nz..];
81 } else {
82 return;
83 }
84 }
85
86 let acc = acc;
87 let (x, y) = if b.len() < c.len() { (b, c) } else { (c, b) };
88
89 // We use four algorithms for different input sizes.
90 //
91 // - For small inputs, long multiplication is fastest.
92 // - If y is at least least twice as long as x, split using Half-Karatsuba.
93 // - Next we use Karatsuba multiplication (Toom-2), which we have optimized
94 // to avoid unnecessary allocations for intermediate values.
95 // - For the largest inputs we use Toom-3, which better optimizes the
96 // number of operations, but uses more temporary allocations.
97 //
98 // The thresholds are somewhat arbitrary, chosen by evaluating the results
99 // of `cargo bench --bench bigint multiply`.
100
101 if x.len() <= 32 {
102 // Long multiplication:
103 for (i, xi) in x.iter().enumerate() {
104 mac_digit(&mut acc[i..], y, *xi);
105 }
106 } else if x.len() * 2 <= y.len() {
107 // Karatsuba Multiplication for factors with significant length disparity.
108 //
109 // The Half-Karatsuba Multiplication Algorithm is a specialized case of
110 // the normal Karatsuba multiplication algorithm, designed for the scenario
111 // where y has at least twice as many base digits as x.
112 //
113 // In this case y (the longer input) is split into high2 and low2,
114 // at m2 (half the length of y) and x (the shorter input),
115 // is used directly without splitting.
116 //
117 // The algorithm then proceeds as follows:
118 //
119 // 1. Compute the product z0 = x * low2.
120 // 2. Compute the product temp = x * high2.
121 // 3. Adjust the weight of temp by adding m2 (* NBASE ^ m2)
122 // 4. Add temp and z0 to obtain the final result.
123 //
124 // Proof:
125 //
126 // The algorithm can be derived from the original Karatsuba algorithm by
127 // simplifying the formula when the shorter factor x is not split into
128 // high and low parts, as shown below.
129 //
130 // Original Karatsuba formula:
131 //
132 // result = (z2 * NBASE ^ (m2 × 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0
133 //
134 // Substitutions:
135 //
136 // low1 = x
137 // high1 = 0
138 //
139 // Applying substitutions:
140 //
141 // z0 = (low1 * low2)
142 // = (x * low2)
143 //
144 // z1 = ((low1 + high1) * (low2 + high2))
145 // = ((x + 0) * (low2 + high2))
146 // = (x * low2) + (x * high2)
147 //
148 // z2 = (high1 * high2)
149 // = (0 * high2)
150 // = 0
151 //
152 // Simplified using the above substitutions:
153 //
154 // result = (z2 * NBASE ^ (m2 × 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0
155 // = (0 * NBASE ^ (m2 × 2)) + ((z1 - 0 - z0) * NBASE ^ m2) + z0
156 // = ((z1 - z0) * NBASE ^ m2) + z0
157 // = ((z1 - z0) * NBASE ^ m2) + z0
158 // = (x * high2) * NBASE ^ m2 + z0
159 let m2 = y.len() / 2;
160 let (low2, high2) = y.split_at(m2);
161
162 // (x * high2) * NBASE ^ m2 + z0
163 mac3(acc, x, low2);
164 mac3(&mut acc[m2..], x, high2);
165 } else if x.len() <= 256 {
166 // Karatsuba multiplication:
167 //
168 // The idea is that we break x and y up into two smaller numbers that each have about half
169 // as many digits, like so (note that multiplying by b is just a shift):
170 //
171 // x = x0 + x1 * b
172 // y = y0 + y1 * b
173 //
174 // With some algebra, we can compute x * y with three smaller products, where the inputs to
175 // each of the smaller products have only about half as many digits as x and y:
176 //
177 // x * y = (x0 + x1 * b) * (y0 + y1 * b)
178 //
179 // x * y = x0 * y0
180 // + x0 * y1 * b
181 // + x1 * y0 * b
182 // + x1 * y1 * b^2
183 //
184 // Let p0 = x0 * y0 and p2 = x1 * y1:
185 //
186 // x * y = p0
187 // + (x0 * y1 + x1 * y0) * b
188 // + p2 * b^2
189 //
190 // The real trick is that middle term:
191 //
192 // x0 * y1 + x1 * y0
193 //
194 // = x0 * y1 + x1 * y0 - p0 + p0 - p2 + p2
195 //
196 // = x0 * y1 + x1 * y0 - x0 * y0 - x1 * y1 + p0 + p2
197 //
198 // Now we complete the square:
199 //
200 // = -(x0 * y0 - x0 * y1 - x1 * y0 + x1 * y1) + p0 + p2
201 //
202 // = -((x1 - x0) * (y1 - y0)) + p0 + p2
203 //
204 // Let p1 = (x1 - x0) * (y1 - y0), and substitute back into our original formula:
205 //
206 // x * y = p0
207 // + (p0 + p2 - p1) * b
208 // + p2 * b^2
209 //
210 // Where the three intermediate products are:
211 //
212 // p0 = x0 * y0
213 // p1 = (x1 - x0) * (y1 - y0)
214 // p2 = x1 * y1
215 //
216 // In doing the computation, we take great care to avoid unnecessary temporary variables
217 // (since creating a BigUint requires a heap allocation): thus, we rearrange the formula a
218 // bit so we can use the same temporary variable for all the intermediate products:
219 //
220 // x * y = p2 * b^2 + p2 * b
221 // + p0 * b + p0
222 // - p1 * b
223 //
224 // The other trick we use is instead of doing explicit shifts, we slice acc at the
225 // appropriate offset when doing the add.
226
227 // When x is smaller than y, it's significantly faster to pick b such that x is split in
228 // half, not y:
229 let b = x.len() / 2;
230 let (x0, x1) = x.split_at(b);
231 let (y0, y1) = y.split_at(b);
232
233 // We reuse the same BigUint for all the intermediate multiplies and have to size p
234 // appropriately here: x1.len() >= x0.len and y1.len() >= y0.len():
235 let len = x1.len() + y1.len() + 1;
236 let mut p = BigUint { data: vec![0; len] };
237
238 // p2 = x1 * y1
239 mac3(&mut p.data, x1, y1);
240
241 // Not required, but the adds go faster if we drop any unneeded 0s from the end:
242 p.normalize();
243
244 add2(&mut acc[b..], &p.data);
245 add2(&mut acc[b * 2..], &p.data);
246
247 // Zero out p before the next multiply:
248 p.data.truncate(0);
249 p.data.resize(len, 0);
250
251 // p0 = x0 * y0
252 mac3(&mut p.data, x0, y0);
253 p.normalize();
254
255 add2(acc, &p.data);
256 add2(&mut acc[b..], &p.data);
257
258 // p1 = (x1 - x0) * (y1 - y0)
259 // We do this one last, since it may be negative and acc can't ever be negative:
260 let (j0_sign, j0) = sub_sign(x1, x0);
261 let (j1_sign, j1) = sub_sign(y1, y0);
262
263 match j0_sign * j1_sign {
264 Plus => {
265 p.data.truncate(0);
266 p.data.resize(len, 0);
267
268 mac3(&mut p.data, &j0.data, &j1.data);
269 p.normalize();
270
271 sub2(&mut acc[b..], &p.data);
272 }
273 Minus => {
274 mac3(&mut acc[b..], &j0.data, &j1.data);
275 }
276 NoSign => (),
277 }
278 } else {
279 // Toom-3 multiplication:
280 //
281 // Toom-3 is like Karatsuba above, but dividing the inputs into three parts.
282 // Both are instances of Toom-Cook, using `k=3` and `k=2` respectively.
283 //
284 // The general idea is to treat the large integers digits as
285 // polynomials of a certain degree and determine the coefficients/digits
286 // of the product of the two via interpolation of the polynomial product.
287 let i = y.len() / 3 + 1;
288
289 let x0_len = Ord::min(x.len(), i);
290 let x1_len = Ord::min(x.len() - x0_len, i);
291
292 let y0_len = i;
293 let y1_len = Ord::min(y.len() - y0_len, i);
294
295 // Break x and y into three parts, representating an order two polynomial.
296 // t is chosen to be the size of a digit so we can use faster shifts
297 // in place of multiplications.
298 //
299 // x(t) = x2*t^2 + x1*t + x0
300 let x0 = bigint_from_slice(&x[..x0_len]);
301 let x1 = bigint_from_slice(&x[x0_len..x0_len + x1_len]);
302 let x2 = bigint_from_slice(&x[x0_len + x1_len..]);
303
304 // y(t) = y2*t^2 + y1*t + y0
305 let y0 = bigint_from_slice(&y[..y0_len]);
306 let y1 = bigint_from_slice(&y[y0_len..y0_len + y1_len]);
307 let y2 = bigint_from_slice(&y[y0_len + y1_len..]);
308
309 // Let w(t) = x(t) * y(t)
310 //
311 // This gives us the following order-4 polynomial.
312 //
313 // w(t) = w4*t^4 + w3*t^3 + w2*t^2 + w1*t + w0
314 //
315 // We need to find the coefficients w4, w3, w2, w1 and w0. Instead
316 // of simply multiplying the x and y in total, we can evaluate w
317 // at 5 points. An n-degree polynomial is uniquely identified by (n + 1)
318 // points.
319 //
320 // It is arbitrary as to what points we evaluate w at but we use the
321 // following.
322 //
323 // w(t) at t = 0, 1, -1, -2 and inf
324 //
325 // The values for w(t) in terms of x(t)*y(t) at these points are:
326 //
327 // let a = w(0) = x0 * y0
328 // let b = w(1) = (x2 + x1 + x0) * (y2 + y1 + y0)
329 // let c = w(-1) = (x2 - x1 + x0) * (y2 - y1 + y0)
330 // let d = w(-2) = (4*x2 - 2*x1 + x0) * (4*y2 - 2*y1 + y0)
331 // let e = w(inf) = x2 * y2 as t -> inf
332
333 // x0 + x2, avoiding temporaries
334 let p = &x0 + &x2;
335
336 // y0 + y2, avoiding temporaries
337 let q = &y0 + &y2;
338
339 // x2 - x1 + x0, avoiding temporaries
340 let p2 = &p - &x1;
341
342 // y2 - y1 + y0, avoiding temporaries
343 let q2 = &q - &y1;
344
345 // w(0)
346 let r0 = &x0 * &y0;
347
348 // w(inf)
349 let r4 = &x2 * &y2;
350
351 // w(1)
352 let r1 = (p + x1) * (q + y1);
353
354 // w(-1)
355 let r2 = &p2 * &q2;
356
357 // w(-2)
358 let r3 = ((p2 + x2) * 2 - x0) * ((q2 + y2) * 2 - y0);
359
360 // Evaluating these points gives us the following system of linear equations.
361 //
362 // 0 0 0 0 1 | a
363 // 1 1 1 1 1 | b
364 // 1 -1 1 -1 1 | c
365 // 16 -8 4 -2 1 | d
366 // 1 0 0 0 0 | e
367 //
368 // The solved equation (after gaussian elimination or similar)
369 // in terms of its coefficients:
370 //
371 // w0 = w(0)
372 // w1 = w(0)/2 + w(1)/3 - w(-1) + w(-2)/6 - 2*w(inf)
373 // w2 = -w(0) + w(1)/2 + w(-1)/2 - w(inf)
374 // w3 = -w(0)/2 + w(1)/6 + w(-1)/2 - w(-2)/6 + 2*w(inf)
375 // w4 = w(inf)
376 //
377 // This particular sequence is given by Bodrato and is an interpolation
378 // of the above equations.
379 let mut comp3: BigInt = (r3 - &r1) / 3u32;
380 let mut comp1: BigInt = (r1 - &r2) >> 1;
381 let mut comp2: BigInt = r2 - &r0;
382 comp3 = ((&comp2 - comp3) >> 1) + (&r4 << 1);
383 comp2 += &comp1 - &r4;
384 comp1 -= &comp3;
385
386 // Recomposition. The coefficients of the polynomial are now known.
387 //
388 // Evaluate at w(t) where t is our given base to get the result.
389 //
390 // let bits = u64::from(big_digit::BITS) * i as u64;
391 // let result = r0
392 // + (comp1 << bits)
393 // + (comp2 << (2 * bits))
394 // + (comp3 << (3 * bits))
395 // + (r4 << (4 * bits));
396 // let result_pos = result.to_biguint().unwrap();
397 // add2(&mut acc[..], &result_pos.data);
398 //
399 // But with less intermediate copying:
400 for (j, result) in [&r0, &comp1, &comp2, &comp3, &r4].iter().enumerate().rev() {
401 match result.sign() {
402 Plus => add2(&mut acc[i * j..], result.digits()),
403 Minus => sub2(&mut acc[i * j..], result.digits()),
404 NoSign => {}
405 }
406 }
407 }
408}
409
410fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
411 let len = x.len() + y.len() + 1;
412 let mut prod = BigUint { data: vec![0; len] };
413
414 mac3(&mut prod.data, x, y);
415 prod.normalized()
416}
417
418fn scalar_mul(a: &mut BigUint, b: BigDigit) {
419 match b {
420 0 => a.set_zero(),
421 1 => {}
422 _ => {
423 if b.is_power_of_two() {
424 *a <<= b.trailing_zeros();
425 } else {
426 let mut carry = 0;
427 for a in a.data.iter_mut() {
428 *a = mul_with_carry(*a, b, &mut carry);
429 }
430 if carry != 0 {
431 a.data.push(carry as BigDigit);
432 }
433 }
434 }
435 }
436}
437
438fn sub_sign(mut a: &[BigDigit], mut b: &[BigDigit]) -> (Sign, BigUint) {
439 // Normalize:
440 if let Some(&0) = a.last() {
441 a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
442 }
443 if let Some(&0) = b.last() {
444 b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
445 }
446
447 match cmp_slice(a, b) {
448 Ordering::Greater => {
449 let mut a = a.to_vec();
450 sub2(&mut a, b);
451 (Plus, biguint_from_vec(a))
452 }
453 Ordering::Less => {
454 let mut b = b.to_vec();
455 sub2(&mut b, a);
456 (Minus, biguint_from_vec(b))
457 }
458 Ordering::Equal => (NoSign, BigUint::ZERO),
459 }
460}
461
462macro_rules! impl_mul {
463 ($(impl Mul<$Other:ty> for $Self:ty;)*) => {$(
464 impl Mul<$Other> for $Self {
465 type Output = BigUint;
466
467 #[inline]
468 fn mul(self, other: $Other) -> BigUint {
469 match (&*self.data, &*other.data) {
470 // multiply by zero
471 (&[], _) | (_, &[]) => BigUint::ZERO,
472 // multiply by a scalar
473 (_, &[digit]) => self * digit,
474 (&[digit], _) => other * digit,
475 // full multiplication
476 (x, y) => mul3(x, y),
477 }
478 }
479 }
480 )*}
481}
482impl_mul! {
483 impl Mul<BigUint> for BigUint;
484 impl Mul<BigUint> for &BigUint;
485 impl Mul<&BigUint> for BigUint;
486 impl Mul<&BigUint> for &BigUint;
487}
488
489macro_rules! impl_mul_assign {
490 ($(impl MulAssign<$Other:ty> for BigUint;)*) => {$(
491 impl MulAssign<$Other> for BigUint {
492 #[inline]
493 fn mul_assign(&mut self, other: $Other) {
494 match (&*self.data, &*other.data) {
495 // multiply by zero
496 (&[], _) => {},
497 (_, &[]) => self.set_zero(),
498 // multiply by a scalar
499 (_, &[digit]) => *self *= digit,
500 (&[digit], _) => *self = other * digit,
501 // full multiplication
502 (x, y) => *self = mul3(x, y),
503 }
504 }
505 }
506 )*}
507}
508impl_mul_assign! {
509 impl MulAssign<BigUint> for BigUint;
510 impl MulAssign<&BigUint> for BigUint;
511}
512
513promote_unsigned_scalars!(impl Mul for BigUint, mul);
514promote_unsigned_scalars_assign!(impl MulAssign for BigUint, mul_assign);
515forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u32> for BigUint, mul);
516forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u64> for BigUint, mul);
517forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u128> for BigUint, mul);
518
519impl Mul<u32> for BigUint {
520 type Output = BigUint;
521
522 #[inline]
523 fn mul(mut self, other: u32) -> BigUint {
524 self *= other;
525 self
526 }
527}
528impl MulAssign<u32> for BigUint {
529 #[inline]
530 fn mul_assign(&mut self, other: u32) {
531 scalar_mul(self, other as BigDigit);
532 }
533}
534
535impl Mul<u64> for BigUint {
536 type Output = BigUint;
537
538 #[inline]
539 fn mul(mut self, other: u64) -> BigUint {
540 self *= other;
541 self
542 }
543}
544impl MulAssign<u64> for BigUint {
545 cfg_digit!(
546 #[inline]
547 fn mul_assign(&mut self, other: u64) {
548 if let Some(other) = BigDigit::from_u64(other) {
549 scalar_mul(self, other);
550 } else {
551 let (hi, lo) = big_digit::from_doublebigdigit(other);
552 *self = mul3(&self.data, &[lo, hi]);
553 }
554 }
555
556 #[inline]
557 fn mul_assign(&mut self, other: u64) {
558 scalar_mul(self, other);
559 }
560 );
561}
562
563impl Mul<u128> for BigUint {
564 type Output = BigUint;
565
566 #[inline]
567 fn mul(mut self, other: u128) -> BigUint {
568 self *= other;
569 self
570 }
571}
572
573impl MulAssign<u128> for BigUint {
574 cfg_digit!(
575 #[inline]
576 fn mul_assign(&mut self, other: u128) {
577 if let Some(other) = BigDigit::from_u128(other) {
578 scalar_mul(self, other);
579 } else {
580 *self = match super::u32_from_u128(other) {
581 (0, 0, c, d) => mul3(&self.data, &[d, c]),
582 (0, b, c, d) => mul3(&self.data, &[d, c, b]),
583 (a, b, c, d) => mul3(&self.data, &[d, c, b, a]),
584 };
585 }
586 }
587
588 #[inline]
589 fn mul_assign(&mut self, other: u128) {
590 if let Some(other) = BigDigit::from_u128(other) {
591 scalar_mul(self, other);
592 } else {
593 let (hi, lo) = big_digit::from_doublebigdigit(other);
594 *self = mul3(&self.data, &[lo, hi]);
595 }
596 }
597 );
598}
599
600impl CheckedMul for BigUint {
601 #[inline]
602 fn checked_mul(&self, v: &BigUint) -> Option<BigUint> {
603 Some(self.mul(v))
604 }
605}
606
// Macro defined elsewhere in the crate; presumably generates the
// `core::iter::Product` impls for `BigUint` -- TODO confirm.
impl_product_iter_type!(BigUint);
608
/// Checks `sub_sign` against `BigInt` subtraction in both operand orders.
#[test]
fn test_sub_sign() {
    use crate::BigInt;
    use num_traits::Num;

    // Wraps the (sign, magnitude) result of `sub_sign` into a `BigInt`.
    fn signed_diff(a: &[BigDigit], b: &[BigDigit]) -> BigInt {
        let (sign, magnitude) = sub_sign(a, b);
        BigInt::from_biguint(sign, magnitude)
    }

    let parse = |text: &str| BigUint::from_str_radix(text, 10).unwrap();

    let larger = parse("265252859812191058636308480000000");
    let smaller = parse("26525285981219105863630848000000");
    let larger_i = BigInt::from(larger.clone());
    let smaller_i = BigInt::from(smaller.clone());

    assert_eq!(signed_diff(&larger.data, &smaller.data), &larger_i - &smaller_i);
    assert_eq!(signed_diff(&smaller.data, &larger.data), &smaller_i - &larger_i);
}