1use alloc::{vec, vec::Vec};
21use codec::{Decode, Encode};
22use core::{cell::RefCell, cmp::Ordering, ops};
23use num_traits::{One, Zero};
24
/// A single limb ("digit") of a [`BigUint`] in base `B`.
pub type Single = u32;
/// An unsigned integer exactly one limb wider than [`Single`] (enforced by the assertion below),
/// used to hold intermediate limb products, sums and carries.
pub type Double = u64;

/// Number of bits in one [`Single`] limb.
const SHIFT: usize = 32;
/// The numeric base of the limb representation, `2^SHIFT`.
///
/// The `as` cast is a lossless widening; `From` cannot be used here because it is not `const`.
const B: Double = Single::MAX as Double + 1;

// `Double` must be exactly `SHIFT` bits wider than `Single`: all limb arithmetic relies on a
// `Double` being able to hold `B * B - 1`.
const _: () =
    assert!(core::mem::size_of::<Double>() - core::mem::size_of::<Single>() == SHIFT / 8);
41
42pub fn split(a: Double) -> (Single, Single) {
44 let al = a as Single;
45 let ah = (a >> SHIFT) as Single;
46 (ah, al)
47}
48
49pub fn mul_single(a: Single, b: Single) -> Double {
53 let a: Double = a.into();
54 let b: Double = b.into();
55 a * b
56}
57
58pub fn add_single(a: Single, b: Single) -> (Single, Single) {
63 let a: Double = a.into();
64 let b: Double = b.into();
65 let q = a + b;
66 let (carry, r) = split(q);
67 (r, carry)
68}
69
70fn div_single(a: Double, b: Single) -> (Double, Single) {
75 let b: Double = b.into();
76 let q = a / b;
77 let r = a % b;
78 (q, r as Single)
80}
81
82#[derive(Encode, Decode, Clone, Default)]
84pub struct BigUint {
85 pub(crate) digits: Vec<Single>,
87}
88
89impl BigUint {
90 pub fn with_capacity(size: usize) -> Self {
95 Self { digits: vec![0; size.max(1)] }
96 }
97
98 pub fn from_limbs(limbs: &[Single]) -> Self {
101 if !limbs.is_empty() {
102 Self { digits: limbs.to_vec() }
103 } else {
104 Zero::zero()
105 }
106 }
107
108 pub fn len(&self) -> usize {
110 self.digits.len()
111 }
112
113 pub fn get(&self, index: usize) -> Single {
119 self.digits[self.len() - 1 - index]
120 }
121
122 pub fn checked_get(&self, index: usize) -> Option<Single> {
124 let i = self.len().checked_sub(1)?;
125 let j = i.checked_sub(index)?;
126 self.digits.get(j).cloned()
127 }
128
129 pub fn set(&mut self, index: usize, value: Single) {
135 let len = self.digits.len();
136 self.digits[len - 1 - index] = value;
137 }
138
139 pub fn lsb(&self) -> Single {
145 self.digits[self.len() - 1]
146 }
147
148 pub fn msb(&self) -> Single {
154 self.digits[0]
155 }
156
157 pub fn lstrip(&mut self) {
159 if self.len().is_zero() {
163 return
164 }
165 let index = self.digits.iter().position(|&elem| elem != 0).unwrap_or(self.len() - 1);
166
167 if index > 0 {
168 self.digits = self.digits[index..].to_vec()
169 }
170 }
171
172 pub fn lpad(&mut self, size: usize) {
175 let n = self.len();
176 if n >= size {
177 return
178 }
179 let pad = size - n;
180 let mut new_digits = (0..pad).map(|_| 0).collect::<Vec<Single>>();
181 new_digits.extend(self.digits.iter());
182 self.digits = new_digits;
183 }
184
185 pub fn add(self, other: &Self) -> Self {
194 let n = self.len().max(other.len());
195 let mut k: Double = 0;
196 let mut w = Self::with_capacity(n + 1);
197
198 for j in 0..n {
199 let u = Double::from(self.checked_get(j).unwrap_or(0));
200 let v = Double::from(other.checked_get(j).unwrap_or(0));
201 let s = u + v + k;
202 w.set(j, (s % B) as Single);
204 k = s / B;
205 }
206 w.set(n, k as Single);
208 w
209 }
210
211 pub fn sub(self, other: &Self) -> Result<Self, Self> {
218 let n = self.len().max(other.len());
219 let mut k = 0;
220 let mut w = Self::with_capacity(n);
221 for j in 0..n {
222 let s = {
223 let u = Double::from(self.checked_get(j).unwrap_or(0));
224 let v = Double::from(other.checked_get(j).unwrap_or(0));
225
226 if let Some(v2) = u.checked_sub(v).and_then(|v1| v1.checked_sub(k)) {
227 let t = v2;
229 k = 0;
230
231 t
232 } else {
233 let t = u + B - v - k;
238 k = 1;
239
240 t
241 }
242 };
243 w.set(j, s as Single);
244 }
245
246 if k.is_zero() {
247 Ok(w)
248 } else {
249 Err(w)
250 }
251 }
252
253 pub fn mul(self, other: &Self) -> Self {
262 let n = self.len();
263 let m = other.len();
264 let mut w = Self::with_capacity(m + n);
265
266 for j in 0..n {
267 if self.get(j) == 0 {
268 continue
271 }
272
273 let mut k = 0;
274 for i in 0..m {
275 let t = mul_single(self.get(j), other.get(i)) +
277 Double::from(w.get(i + j)) +
278 Double::from(k);
279 w.set(i + j, (t % B) as Single);
280 k = (t / B) as Single;
282 }
283 w.set(j + m, k);
284 }
285 w
286 }
287
288 pub fn div_unit(self, mut other: Single) -> Self {
293 other = other.max(1);
294 let n = self.len();
295 let mut out = Self::with_capacity(n);
296 let mut r: Single = 0;
297 let with_r = |x: Single, r: Single| Double::from(r) * B + Double::from(x);
299 for d in (0..n).rev() {
300 let (q, rr) = div_single(with_r(self.get(d), r), other);
301 out.set(d, q as Single);
302 r = rr;
303 }
304 out
305 }
306
307 pub fn div(self, other: &Self, rem: bool) -> Option<(Self, Self)> {
321 if other.len() <= 1 || other.msb() == 0 || self.msb() == 0 || self.len() <= other.len() {
322 return None
323 }
324 let n = other.len();
325 let m = self.len() - n;
326
327 let mut q = Self::with_capacity(m + 1);
328 let mut r = Self::with_capacity(n);
329
330 let normalizer_bits = other.msb().leading_zeros() as Single;
333 let normalizer = 2_u32.pow(normalizer_bits as u32) as Single;
334
335 let mut self_norm = self.mul(&Self::from(normalizer));
337 let mut other_norm = other.clone().mul(&Self::from(normalizer));
338
339 self_norm.lpad(n + m + 1);
341 other_norm.lstrip();
342
343 for j in (0..=m).rev() {
345 let (qhat, rhat) = {
347 let dividend =
350 Double::from(self_norm.get(j + n)) * B + Double::from(self_norm.get(j + n - 1));
351 let divisor = other_norm.get(n - 1);
352 div_single(dividend, divisor)
353 };
354
355 let qhat = RefCell::new(qhat);
358 let rhat = RefCell::new(Double::from(rhat));
359
360 let test = || {
361 let qhat_local = *qhat.borrow();
363 let rhat_local = *rhat.borrow();
364 let predicate_1 = qhat_local >= B;
365 let predicate_2 = {
366 let lhs = qhat_local * Double::from(other_norm.get(n - 2));
367 let rhs = B * rhat_local + Double::from(self_norm.get(j + n - 2));
368 lhs > rhs
369 };
370 if predicate_1 || predicate_2 {
371 *qhat.borrow_mut() -= 1;
372 *rhat.borrow_mut() += Double::from(other_norm.get(n - 1));
373 true
374 } else {
375 false
376 }
377 };
378
379 test();
380 while (*rhat.borrow() as Double) < B {
381 if !test() {
382 break
383 }
384 }
385
386 let qhat = qhat.into_inner();
387 let lhs = Self { digits: (j..=j + n).rev().map(|d| self_norm.get(d)).collect() };
391 let rhs = other_norm.clone().mul(&Self::from(qhat));
392
393 let maybe_sub = lhs.sub(&rhs);
394 let mut negative = false;
395 let sub = match maybe_sub {
396 Ok(t) => t,
397 Err(t) => {
398 negative = true;
399 t
400 },
401 };
402 (j..=j + n).for_each(|d| {
403 self_norm.set(d, sub.get(d - j));
404 });
405
406 q.set(j, qhat as Single);
410
411 if negative {
413 q.set(j, q.get(j) - 1);
414 let u = Self { digits: (j..=j + n).rev().map(|d| self_norm.get(d)).collect() };
415 let r = other_norm.clone().add(&u);
416 (j..=j + n).rev().for_each(|d| {
417 self_norm.set(d, r.get(d - j));
418 })
419 }
420 }
421
422 if rem {
424 if normalizer_bits > 0 {
426 let s = SHIFT as u32;
427 let nb = normalizer_bits;
428 for d in 0..n - 1 {
429 let v = self_norm.get(d) >> nb | self_norm.get(d + 1).overflowing_shl(s - nb).0;
430 r.set(d, v);
431 }
432 r.set(n - 1, self_norm.get(n - 1) >> normalizer_bits);
433 } else {
434 r = self_norm;
435 }
436 }
437
438 Some((q, r))
439 }
440}
441
442impl core::fmt::Debug for BigUint {
443 #[cfg(feature = "std")]
444 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
445 write!(
446 f,
447 "BigUint {{ {:?} ({:?})}}",
448 self.digits,
449 u128::try_from(self.clone()).unwrap_or(0),
450 )
451 }
452
453 #[cfg(not(feature = "std"))]
454 fn fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
455 Ok(())
456 }
457}
458
459impl PartialEq for BigUint {
460 fn eq(&self, other: &Self) -> bool {
461 self.cmp(other) == Ordering::Equal
462 }
463}
464
465impl Eq for BigUint {}
466
467impl Ord for BigUint {
468 fn cmp(&self, other: &Self) -> Ordering {
469 let lhs_first = self.digits.iter().position(|&e| e != 0);
470 let rhs_first = other.digits.iter().position(|&e| e != 0);
471
472 match (lhs_first, rhs_first) {
473 (None, None) => Ordering::Equal,
476 (Some(_), None) => Ordering::Greater,
477 (None, Some(_)) => Ordering::Less,
478 (Some(lhs_idx), Some(rhs_idx)) => {
479 let lhs = &self.digits[lhs_idx..];
480 let rhs = &other.digits[rhs_idx..];
481 let len_cmp = lhs.len().cmp(&rhs.len());
482 match len_cmp {
483 Ordering::Equal => lhs.cmp(rhs),
484 _ => len_cmp,
485 }
486 },
487 }
488 }
489}
490
491impl PartialOrd for BigUint {
492 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
493 Some(self.cmp(other))
494 }
495}
496
497impl ops::Add for BigUint {
498 type Output = Self;
499 fn add(self, rhs: Self) -> Self::Output {
500 self.add(&rhs)
501 }
502}
503
504impl ops::Sub for BigUint {
505 type Output = Self;
506 fn sub(self, rhs: Self) -> Self::Output {
507 self.sub(&rhs).unwrap_or_else(|e| e)
508 }
509}
510
511impl ops::Mul for BigUint {
512 type Output = Self;
513 fn mul(self, rhs: Self) -> Self::Output {
514 self.mul(&rhs)
515 }
516}
517
518impl Zero for BigUint {
519 fn zero() -> Self {
520 Self { digits: vec![Zero::zero()] }
521 }
522
523 fn is_zero(&self) -> bool {
524 self.digits.iter().all(|d| d.is_zero())
525 }
526}
527
528impl One for BigUint {
529 fn one() -> Self {
530 Self { digits: vec![Single::one()] }
531 }
532}
533
534macro_rules! impl_try_from_number_for {
535 ($([$type:ty, $len:expr]),+) => {
536 $(
537 impl TryFrom<BigUint> for $type {
538 type Error = &'static str;
539 fn try_from(mut value: BigUint) -> Result<$type, Self::Error> {
540 value.lstrip();
541 let error_message = concat!("cannot fit a number into ", stringify!($type));
542 if value.len() * SHIFT > $len {
543 Err(error_message)
544 } else {
545 let mut acc: $type = Zero::zero();
546 for (i, d) in value.digits.iter().rev().cloned().enumerate() {
547 let d: $type = d.into();
548 acc += d << (SHIFT * i);
549 }
550 Ok(acc)
551 }
552 }
553 }
554 )*
555 };
556}
557impl_try_from_number_for!([u128, 128], [u64, 64]);
559
560macro_rules! impl_from_for_smaller_than_word {
561 ($($type:ty),+) => {
562 $(impl From<$type> for BigUint {
563 fn from(a: $type) -> Self {
564 Self { digits: vec! [a.into()] }
565 }
566 })*
567 }
568}
569impl_from_for_smaller_than_word!(u8, u16, u32);
570
571impl From<u64> for BigUint {
572 fn from(a: Double) -> Self {
573 let (ah, al) = split(a);
574 Self { digits: vec![ah, al] }
575 }
576}
577
578impl From<u128> for BigUint {
579 fn from(a: u128) -> Self {
580 crate::helpers_128bit::to_big_uint(a)
581 }
582}
583
#[cfg(test)]
pub mod tests {
    use super::*;

    /// A number with `n` limbs, each equal to `1`.
    fn with_limbs(n: usize) -> BigUint {
        BigUint { digits: vec![1; n] }
    }

    #[test]
    fn split_works() {
        let a = SHIFT / 2;
        let b = SHIFT * 3 / 2;
        let num: Double = 1 << a | 1 << b;
        assert_eq!(num, 0x_0001_0000_0001_0000);
        assert_eq!(split(num), (1 << a, 1 << a));

        let a = SHIFT / 2 + 4;
        let b = SHIFT / 2 - 4;
        let num: Double = 1 << (SHIFT + a) | 1 << b;
        assert_eq!(num, 0x_0010_0000_0000_1000);
        assert_eq!(split(num), (1 << a, 1 << b));
    }

    #[test]
    fn strip_works() {
        let mut a = BigUint::from_limbs(&[0, 1, 0]);
        a.lstrip();
        assert_eq!(a.digits, vec![1, 0]);

        let mut a = BigUint::from_limbs(&[0, 0, 1]);
        a.lstrip();
        assert_eq!(a.digits, vec![1]);

        let mut a = BigUint::from_limbs(&[0, 0]);
        a.lstrip();
        assert_eq!(a.digits, vec![0]);

        let mut a = BigUint::from_limbs(&[0, 0, 0]);
        a.lstrip();
        assert_eq!(a.digits, vec![0]);
    }

    #[test]
    fn lpad_works() {
        let mut a = BigUint::from_limbs(&[0, 1, 0]);
        a.lpad(2);
        assert_eq!(a.digits, vec![0, 1, 0]);

        let mut a = BigUint::from_limbs(&[0, 1, 0]);
        a.lpad(3);
        assert_eq!(a.digits, vec![0, 1, 0]);

        let mut a = BigUint::from_limbs(&[0, 1, 0]);
        a.lpad(4);
        assert_eq!(a.digits, vec![0, 0, 1, 0]);
    }

    #[test]
    fn equality_works() {
        assert!(BigUint { digits: vec![1, 2, 3] } == BigUint { digits: vec![1, 2, 3] });
        assert!(BigUint { digits: vec![3, 2, 3] } != BigUint { digits: vec![1, 2, 3] });
        assert!(BigUint { digits: vec![0, 1, 2, 3] } == BigUint { digits: vec![1, 2, 3] });
    }

    #[test]
    fn ordering_works() {
        assert!(BigUint { digits: vec![0] } < BigUint { digits: vec![1] });
        assert!(BigUint { digits: vec![0] } == BigUint { digits: vec![0] });
        assert!(BigUint { digits: vec![] } == BigUint { digits: vec![0] });
        assert!(BigUint { digits: vec![] } == BigUint { digits: vec![] });
        assert!(BigUint { digits: vec![] } < BigUint { digits: vec![1] });

        assert!(BigUint { digits: vec![1, 2, 3] } == BigUint { digits: vec![1, 2, 3] });
        assert!(BigUint { digits: vec![0, 1, 2, 3] } == BigUint { digits: vec![1, 2, 3] });

        assert!(BigUint { digits: vec![1, 2, 4] } > BigUint { digits: vec![1, 2, 3] });
        assert!(BigUint { digits: vec![0, 1, 2, 4] } > BigUint { digits: vec![1, 2, 3] });
        assert!(BigUint { digits: vec![1, 2, 1, 0] } > BigUint { digits: vec![1, 2, 3] });

        assert!(BigUint { digits: vec![0, 1, 2, 1] } < BigUint { digits: vec![1, 2, 3] });
    }

    #[test]
    fn can_try_build_numbers_from_types() {
        assert_eq!(u64::try_from(with_limbs(1)).unwrap(), 1);
        assert_eq!(u64::try_from(with_limbs(2)).unwrap(), u32::MAX as u64 + 2);
        assert_eq!(u64::try_from(with_limbs(3)).unwrap_err(), "cannot fit a number into u64");
        assert_eq!(u128::try_from(with_limbs(3)).unwrap(), u32::MAX as u128 + u64::MAX as u128 + 3);
    }

    #[test]
    fn zero_works() {
        assert_eq!(BigUint::zero(), BigUint { digits: vec![0] });
        assert!(!BigUint { digits: vec![0, 1, 0] }.is_zero());
        assert!(BigUint { digits: vec![0, 0, 0] }.is_zero());

        let a = BigUint::zero();
        let b = BigUint::zero();
        let c = a * b;
        assert_eq!(c.digits, vec![0, 0]);
    }

    #[test]
    fn sub_negative_works() {
        assert_eq!(
            BigUint::from(10 as Single).sub(&BigUint::from(5 as Single)).unwrap(),
            BigUint::from(5 as Single)
        );
        assert_eq!(
            BigUint::from(10 as Single).sub(&BigUint::from(10 as Single)).unwrap(),
            BigUint::from(0 as Single)
        );
        assert_eq!(
            BigUint::from(10 as Single).sub(&BigUint::from(13 as Single)).unwrap_err(),
            BigUint::from((B - 3) as Single),
        );
    }

    #[test]
    fn mul_always_appends_one_digit() {
        let a = BigUint::from(10 as Single);
        let b = BigUint::from(4 as Single);
        assert_eq!(a.len(), 1);
        assert_eq!(b.len(), 1);

        let n = a.mul(&b);

        assert_eq!(n.len(), 2);
        assert_eq!(n.digits, vec![0, 40]);
    }

    #[test]
    fn div_conditions_work() {
        let a = BigUint { digits: vec![2] };
        let b = BigUint { digits: vec![1, 2] };
        let c = BigUint { digits: vec![1, 1, 2] };
        let d = BigUint { digits: vec![0, 2] };
        let e = BigUint { digits: vec![0, 1, 1, 2] };
        let f = BigUint { digits: vec![7, 8] };

        assert!(a.clone().div(&b, true).is_none());
        assert!(c.clone().div(&a, true).is_none());
        assert!(c.clone().div(&d, true).is_none());
        assert!(e.clone().div(&a, true).is_none());

        assert!(f.clone().div(&b, true).is_none());
        assert!(c.clone().div(&b, true).is_some());
    }

    #[test]
    fn div_matches_native_arithmetic() {
        // (5·B² + 7·B + 9) / (2·B + 3), cross-checked against native `u128` arithmetic.
        let a = BigUint { digits: vec![5, 7, 9] };
        let b = BigUint { digits: vec![2, 3] };
        let a_native: u128 = (5u128 << 64) + (7u128 << 32) + 9;
        let b_native: u128 = (2u128 << 32) + 3;

        let (q, r) = a.div(&b, true).unwrap();
        assert_eq!(u128::try_from(q).unwrap(), a_native / b_native);
        assert_eq!(u128::try_from(r).unwrap(), a_native % b_native);
    }

    #[test]
    fn div_unit_works() {
        let a = BigUint { digits: vec![100] };
        let b = BigUint { digits: vec![1, 100] };
        let c = BigUint { digits: vec![14, 28, 100] };

        assert_eq!(a.clone().div_unit(1), a);
        assert_eq!(a.clone().div_unit(0), a);
        assert_eq!(a.clone().div_unit(2), BigUint::from(50 as Single));
        assert_eq!(a.clone().div_unit(7), BigUint::from(14 as Single));

        assert_eq!(b.clone().div_unit(1), b);
        assert_eq!(b.clone().div_unit(0), b);
        assert_eq!(b.clone().div_unit(2), BigUint::from(((B + 100) / 2) as Single));
        assert_eq!(b.clone().div_unit(7), BigUint::from(((B + 100) / 7) as Single));

        assert_eq!(c.clone().div_unit(1), c);
        assert_eq!(c.clone().div_unit(0), c);
        assert_eq!(c.clone().div_unit(2), BigUint { digits: vec![7, 14, 50] });
        assert_eq!(c.clone().div_unit(7), BigUint { digits: vec![2, 4, 14] });
    }
}