1use crate::big_digit::{BigDigit, DoubleBigDigit, BITS};
2use crate::bigint::Sign::*;
3use crate::bigint::{BigInt, ToBigInt};
4use crate::biguint::{BigUint, IntDigits};
5use crate::integer::Integer;
6use alloc::borrow::Cow;
7use core::ops::Neg;
8use num_traits::{One, Signed, Zero};
9
10pub fn xgcd(
22 a_in: &BigInt,
23 b_in: &BigInt,
24 extended: bool,
25) -> (BigInt, Option<BigInt>, Option<BigInt>) {
26 if a_in.is_zero() && b_in.is_zero() {
28 if extended {
29 return (0.into(), Some(0.into()), Some(0.into()));
30 } else {
31 return (0.into(), None, None);
32 }
33 }
34
35 if a_in.is_zero() {
38 if extended {
39 let mut y = BigInt::one();
40 if b_in.sign == Minus {
41 y.sign = Minus;
42 }
43
44 return (b_in.abs(), Some(0.into()), Some(y));
45 } else {
46 return (b_in.abs(), None, None);
47 }
48 }
49
50 if b_in.is_zero() {
53 if extended {
54 let mut x = BigInt::one();
55 if a_in.sign == Minus {
56 x.sign = Minus;
57 }
58
59 return (a_in.abs(), Some(x), Some(0.into()));
60 } else {
61 return (a_in.abs(), None, None);
62 }
63 }
64 lehmer_gcd(a_in, b_in, extended)
65}
66
67fn lehmer_gcd(
78 a_in: &BigInt,
79 b_in: &BigInt,
80 extended: bool,
81) -> (BigInt, Option<BigInt>, Option<BigInt>) {
82 let mut a = a_in.clone();
83 let mut b = b_in.clone();
84
85 a.sign = Plus;
87 b.sign = Plus;
88
89 let mut ua = if extended { Some(1.into()) } else { None };
91 let mut ub = if extended { Some(0.into()) } else { None };
92
93 let mut q: BigInt = 0.into();
95 let mut r: BigInt = 0.into();
96 let mut s: BigInt = 0.into();
97 let mut t: BigInt = 0.into();
98
99 if a < b {
101 core::mem::swap(&mut a, &mut b);
102 core::mem::swap(&mut ua, &mut ub);
103 }
104
105 while b.len() > 1 {
107 let (u0, u1, v0, v1, even) = lehmer_simulate(&a, &b);
109
110 if v0 != 0 {
112 lehmer_update(
116 &mut a, &mut b, &mut q, &mut r, &mut s, &mut t, u0, u1, v0, v1, even,
117 );
118
119 if extended {
120 lehmer_update(
123 ua.as_mut().unwrap(),
124 ub.as_mut().unwrap(),
125 &mut q,
126 &mut r,
127 &mut s,
128 &mut t,
129 u0,
130 u1,
131 v0,
132 v1,
133 even,
134 );
135 }
136 } else {
137 euclid_udpate(
139 &mut a, &mut b, &mut ua, &mut ub, &mut q, &mut r, &mut s, &mut t, extended,
140 );
141 }
142 }
143
144 if b.len() > 0 {
145 if a.len() > 1 {
147 euclid_udpate(
149 &mut a, &mut b, &mut ua, &mut ub, &mut q, &mut r, &mut s, &mut t, extended,
150 );
151 }
152
153 if b.len() > 0 {
154 let mut a_word = a.digits()[0];
156 let mut b_word = b.digits()[0];
157
158 if extended {
159 let mut ua_word: BigDigit = 1;
160 let mut ub_word: BigDigit = 0;
161 let mut va: BigDigit = 0;
162 let mut vb: BigDigit = 1;
163 let mut even = true;
164
165 while b_word != 0 {
166 let q = a_word / b_word;
167 let r = a_word % b_word;
168 a_word = b_word;
169 b_word = r;
170
171 let k = ua_word.wrapping_add(q.wrapping_mul(ub_word));
172 ua_word = ub_word;
173 ub_word = k;
174
175 let k = va.wrapping_add(q.wrapping_mul(vb));
176 va = vb;
177 vb = k;
178 even = !even;
179 }
180
181 t.data.set_digit(ua_word);
182 s.data.set_digit(va);
183 t.sign = if even { Plus } else { Minus };
184 s.sign = if even { Minus } else { Plus };
185
186 if let Some(ua) = ua.as_mut() {
187 t *= &*ua;
188 s *= ub.unwrap();
189
190 *ua = &t + &s;
191 }
192 } else {
193 while b_word != 0 {
194 let quotient = a_word % b_word;
195 a_word = b_word;
196 b_word = quotient;
197 }
198 }
199 a.digits_mut()[0] = a_word;
200 }
201 }
202
203 a.normalize();
204
205 let mut neg_a: bool = false;
207 if a_in.sign == Minus {
208 neg_a = true;
209 }
210
211 let y = if let Some(ref mut ua) = ua {
212 let mut tmp = a_in * &*ua;
216
217 if neg_a {
218 tmp.sign = tmp.sign.neg();
219 ua.sign = ua.sign.neg();
220 }
221
222 tmp = &a - &tmp;
224 tmp = &tmp / b_in;
225
226 Some(tmp)
227 } else {
228 None
229 };
230
231 a.sign = Plus;
232
233 (a, ua, y)
234}
235
236pub fn extended_gcd(
240 a_in: Cow<BigUint>,
241 b_in: Cow<BigUint>,
242 extended: bool,
243) -> (BigInt, Option<BigInt>, Option<BigInt>) {
244 if a_in.is_zero() && b_in.is_zero() {
245 if extended {
246 return (b_in.to_bigint().unwrap(), Some(0.into()), Some(0.into()));
247 } else {
248 return (b_in.to_bigint().unwrap(), None, None);
249 }
250 }
251
252 if a_in.is_zero() {
253 if extended {
254 return (b_in.to_bigint().unwrap(), Some(0.into()), Some(1.into()));
255 } else {
256 return (b_in.to_bigint().unwrap(), None, None);
257 }
258 }
259
260 if b_in.is_zero() {
261 if extended {
262 return (a_in.to_bigint().unwrap(), Some(1.into()), Some(0.into()));
263 } else {
264 return (a_in.to_bigint().unwrap(), None, None);
265 }
266 }
267
268 let a_in = a_in.to_bigint().unwrap();
269 let b_in = b_in.to_bigint().unwrap();
270
271 let mut a = a_in.clone();
272 let mut b = b_in.clone();
273
274 let mut ua = if extended { Some(1.into()) } else { None };
276 let mut ub = if extended { Some(0.into()) } else { None };
277
278 if a < b {
280 core::mem::swap(&mut a, &mut b);
281 core::mem::swap(&mut ua, &mut ub);
282 }
283
284 let mut q: BigInt = 0.into();
285 let mut r: BigInt = 0.into();
286 let mut s: BigInt = 0.into();
287 let mut t: BigInt = 0.into();
288
289 while b.len() > 1 {
290 let (u0, u1, v0, v1, even) = lehmer_simulate(&a, &b);
292
293 if v0 != 0 {
295 lehmer_update(
299 &mut a, &mut b, &mut q, &mut r, &mut s, &mut t, u0, u1, v0, v1, even,
300 );
301
302 if extended {
303 lehmer_update(
306 ua.as_mut().unwrap(),
307 ub.as_mut().unwrap(),
308 &mut q,
309 &mut r,
310 &mut s,
311 &mut t,
312 u0,
313 u1,
314 v0,
315 v1,
316 even,
317 );
318 }
319 } else {
320 euclid_udpate(
322 &mut a, &mut b, &mut ua, &mut ub, &mut q, &mut r, &mut s, &mut t, extended,
323 );
324 }
325 }
326
327 if b.len() > 0 {
328 if a.len() > 1 {
330 euclid_udpate(
332 &mut a, &mut b, &mut ua, &mut ub, &mut q, &mut r, &mut s, &mut t, extended,
333 );
334 }
335
336 if b.len() > 0 {
337 let mut a_word = a.digits()[0];
339 let mut b_word = b.digits()[0];
340
341 if extended {
342 let mut ua_word: BigDigit = 1;
343 let mut ub_word: BigDigit = 0;
344 let mut va: BigDigit = 0;
345 let mut vb: BigDigit = 1;
346 let mut even = true;
347
348 while b_word != 0 {
349 let q = a_word / b_word;
350 let r = a_word % b_word;
351 a_word = b_word;
352 b_word = r;
353
354 let k = ua_word.wrapping_add(q.wrapping_mul(ub_word));
355 ua_word = ub_word;
356 ub_word = k;
357
358 let k = va.wrapping_add(q.wrapping_mul(vb));
359 va = vb;
360 vb = k;
361 even = !even;
362 }
363
364 t.data.set_digit(ua_word);
365 s.data.set_digit(va);
366 t.sign = if even { Plus } else { Minus };
367 s.sign = if even { Minus } else { Plus };
368
369 if let Some(ua) = ua.as_mut() {
370 t *= &*ua;
371 s *= ub.unwrap();
372
373 *ua = &t + &s;
374 }
375 } else {
376 while b_word != 0 {
377 let quotient = a_word % b_word;
378 a_word = b_word;
379 b_word = quotient;
380 }
381 }
382 a.digits_mut()[0] = a_word;
383 }
384 }
385
386 a.normalize();
387
388 let y = if let Some(ref ua) = ua {
389 Some((&a - (&a_in * ua)) / &b_in)
391 } else {
392 None
393 };
394
395 (a, ua, y)
396}
397
398#[inline]
409fn lehmer_simulate(a: &BigInt, b: &BigInt) -> (BigDigit, BigDigit, BigDigit, BigDigit, bool) {
410 let m = b.len();
412 let n = a.len();
414
415 let h = a.digits()[n - 1].leading_zeros();
423
424 let mut a1: BigDigit = a.digits()[n - 1] << h
425 | ((a.digits()[n - 2] as DoubleBigDigit) >> (BITS as u32 - h)) as BigDigit;
426
427 let mut a2: BigDigit = if n == m {
429 b.digits()[n - 1] << h
430 | ((b.digits()[n - 2] as DoubleBigDigit) >> (BITS as u32 - h)) as BigDigit
431 } else if n == m + 1 {
432 ((b.digits()[n - 2] as DoubleBigDigit) >> (BITS as u32 - h)) as BigDigit
433 } else {
434 0
435 };
436
437 let mut even = false;
439
440 let mut u0 = 0;
441 let mut u1 = 1;
442 let mut u2 = 0;
443
444 let mut v0 = 0;
445 let mut v1 = 0;
446 let mut v2 = 1;
447
448 while a2 >= v2 && a1.wrapping_sub(a2) >= v1 + v2 {
450 let q = a1 / a2;
451 let r = a1 % a2;
452
453 a1 = a2;
454 a2 = r;
455
456 let k = u1 + q * u2;
457 u0 = u1;
458 u1 = u2;
459 u2 = k;
460
461 let k = v1 + q * v2;
462 v0 = v1;
463 v1 = v2;
464 v2 = k;
465
466 even = !even;
467 }
468
469 (u0, u1, v0, v1, even)
470}
471
472fn lehmer_update(
473 a: &mut BigInt,
474 b: &mut BigInt,
475 q: &mut BigInt,
476 r: &mut BigInt,
477 s: &mut BigInt,
478 t: &mut BigInt,
479 u0: BigDigit,
480 u1: BigDigit,
481 v0: BigDigit,
482 v1: BigDigit,
483 even: bool,
484) {
485 t.data.set_digit(u0);
486 s.data.set_digit(v0);
487 if even {
488 t.sign = Plus;
489 s.sign = Minus
490 } else {
491 t.sign = Minus;
492 s.sign = Plus;
493 }
494
495 *t *= &*a;
496 *s *= &*b;
497
498 r.data.set_digit(u1);
499 q.data.set_digit(v1);
500 if even {
501 q.sign = Plus;
502 r.sign = Minus
503 } else {
504 q.sign = Minus;
505 r.sign = Plus;
506 }
507
508 *r *= &*a;
509 *q *= &*b;
510
511 *a = t + s;
512 *b = r + q;
513}
514
515fn euclid_udpate(
516 a: &mut BigInt,
517 b: &mut BigInt,
518 ua: &mut Option<BigInt>,
519 ub: &mut Option<BigInt>,
520 q: &mut BigInt,
521 r: &mut BigInt,
522 s: &mut BigInt,
523 t: &mut BigInt,
524 extended: bool,
525) {
526 let (q_new, r_new) = a.div_rem(b);
527 *q = q_new;
528 *r = r_new;
529
530 core::mem::swap(a, b);
531 core::mem::swap(b, r);
532
533 if extended {
534 if let Some(ub) = ub.as_mut() {
536 if let Some(ua) = ua.as_mut() {
537 *t = ub.clone();
538 *s = &*ub * &*q;
539 *ub = &*ua - &*s;
540 *ua = t.clone();
541 }
542 }
543 }
544}
545
#[cfg(test)]
mod tests {
    use super::*;
    use core::str::FromStr;

    use num_traits::FromPrimitive;

    #[cfg(feature = "rand")]
    use crate::bigrand::RandBigInt;
    #[cfg(feature = "rand")]
    use num_traits::{One, Zero};
    #[cfg(feature = "rand")]
    use rand::SeedableRng;
    #[cfg(feature = "rand")]
    use rand_xorshift::XorShiftRng;

    /// Reference implementation: the textbook extended Euclidean algorithm.
    ///
    /// Returns `(gcd, s, t)` with `a * s + b * t == gcd`. The Lehmer-based
    /// `extended_gcd` is checked against it.
    #[cfg(feature = "rand")]
    fn extended_gcd_euclid(a: Cow<BigUint>, b: Cow<BigUint>) -> (BigInt, BigInt, BigInt) {
        if a.is_zero() && b.is_zero() {
            return (0.into(), 0.into(), 0.into());
        }

        let (mut s, mut old_s) = (BigInt::zero(), BigInt::one());
        let (mut t, mut old_t) = (BigInt::one(), BigInt::zero());
        let (mut r, mut old_r) = (b.to_bigint().unwrap(), a.to_bigint().unwrap());

        while !r.is_zero() {
            let q = &old_r / &r;

            // (old_r, r) <- (r, old_r - q * r), and likewise for s and t.
            old_r = old_r - &q * &r;
            core::mem::swap(&mut old_r, &mut r);
            old_s = old_s - &q * &s;
            core::mem::swap(&mut old_s, &mut s);
            old_t = old_t - q * &t;
            core::mem::swap(&mut old_t, &mut t);
        }

        (old_r, old_s, old_t)
    }

    /// The coefficients returned by `extended_gcd` satisfy the Bezout
    /// identity `a * s + b * t == gcd` for random inputs of many sizes.
    #[test]
    #[cfg(feature = "rand")]
    fn test_extended_gcd_assumptions() {
        let mut rng = XorShiftRng::from_seed([1u8; 16]);

        for i in 1usize..100 {
            for j in &[1usize, 64, 128] {
                let a = rng.gen_biguint(i * j);
                let b = rng.gen_biguint(i * j);
                let (q, s_k, t_k) = extended_gcd(Cow::Borrowed(&a), Cow::Borrowed(&b), true);

                let lhs = BigInt::from_biguint(Plus, a) * &s_k.unwrap();
                let rhs = BigInt::from_biguint(Plus, b) * &t_k.unwrap();

                assert_eq!(q, &lhs + &rhs, "{} = {} + {}", q, lhs, rhs);
            }
        }
    }

    /// gcd(240, 46) = 2 = 240 * (-9) + 46 * 47.
    #[test]
    fn test_extended_gcd_example() {
        let a = BigUint::from_u32(240).unwrap();
        let b = BigUint::from_u32(46).unwrap();
        let (q, s_k, t_k) = extended_gcd(Cow::Owned(a), Cow::Owned(b), true);

        assert_eq!(q, BigInt::from_i32(2).unwrap());
        assert_eq!(s_k.unwrap(), BigInt::from_i32(-9).unwrap());
        assert_eq!(t_k.unwrap(), BigInt::from_i32(47).unwrap());
    }

    /// Without `extended` only the gcd is produced.
    #[test]
    fn test_extended_gcd_example_not_extended() {
        let a = BigUint::from_u32(240).unwrap();
        let b = BigUint::from_u32(46).unwrap();
        let (q, s_k, t_k) = extended_gcd(Cow::Owned(a), Cow::Owned(b), false);

        assert_eq!(q, BigInt::from_i32(2).unwrap());
        assert_eq!(s_k, None);
        assert_eq!(t_k, None);
    }

    /// Signed example with a negative first operand.
    #[test]
    fn test_extended_gcd_example_wolfram() {
        let a = BigInt::from_str("-565721958").unwrap();
        let b = BigInt::from_str("4486780496").unwrap();

        let (q, s_k, t_k) = xgcd(&a, &b, true);

        assert_eq!(q, BigInt::from(2));
        assert_eq!(s_k, Some(BigInt::from(-1090996795)));
        assert_eq!(t_k, Some(BigInt::from(-137559848)));
    }

    /// Table-driven check of `xgcd` on signed inputs, including zero
    /// operands and multi-word values.
    #[test]
    fn test_golang_bignum_negative() {
        // Each row is [d, x, y, a, b] with d = gcd(a, b) = a * x + b * y.
        let gcd_test_cases = [
            ["0", "0", "0", "0", "0"],
            ["7", "0", "1", "0", "7"],
            ["7", "0", "-1", "0", "-7"],
            ["11", "1", "0", "11", "0"],
            ["7", "-1", "-2", "-77", "35"],
            ["935", "-3", "8", "64515", "24310"],
            ["935", "-3", "-8", "64515", "-24310"],
            ["935", "3", "-8", "-64515", "-24310"],
            ["1", "-9", "47", "120", "23"],
            ["7", "1", "-2", "77", "35"],
            ["935", "-3", "8", "64515", "24310"],
            [
                "935000000000000000",
                "-3",
                "8",
                "64515000000000000000",
                "24310000000000000000",
            ],
            [
                "1",
                "-221",
                "22059940471369027483332068679400581064239780177629666810348940098015901108344",
                "98920366548084643601728869055592650835572950932266967461790948584315647051443",
                "991",
            ],
        ];

        let parse = |text: &str| BigInt::from_str(text).unwrap();

        for case in gcd_test_cases.iter() {
            let d_case = parse(case[0]);
            let x_case = parse(case[1]);
            let y_case = parse(case[2]);
            let a_case = parse(case[3]);
            let b_case = parse(case[4]);

            // Plain gcd: no coefficients are returned.
            let (d, x, y) = xgcd(&a_case, &b_case, false);

            assert_eq!(d, d_case);
            assert_eq!(x, None);
            assert_eq!(y, None);

            // Extended gcd: coefficients must match the table.
            let (d, x, y) = xgcd(&a_case, &b_case, true);

            assert_eq!(d, d_case);
            assert_eq!(x.unwrap(), x_case);
            assert_eq!(y.unwrap(), y_case);
        }
    }

    /// `extended_gcd` agrees with the textbook algorithm, coefficients
    /// included, on random inputs.
    #[test]
    #[cfg(feature = "rand")]
    fn test_gcd_lehmer_euclid_extended() {
        let mut rng = XorShiftRng::from_seed([1u8; 16]);

        for i in 1usize..80 {
            for j in &[1usize, 16, 24, 64, 128] {
                let a = rng.gen_biguint(i * j);
                let b = rng.gen_biguint(i * j);
                let (q, s_k, t_k) = extended_gcd(Cow::Borrowed(&a), Cow::Borrowed(&b), true);

                let expected = extended_gcd_euclid(Cow::Borrowed(&a), Cow::Borrowed(&b));
                assert_eq!(q, expected.0);
                assert_eq!(s_k.unwrap(), expected.1);
                assert_eq!(t_k.unwrap(), expected.2);
            }
        }
    }

    /// `extended_gcd` without coefficients agrees with the textbook
    /// algorithm on the gcd for random inputs.
    #[test]
    #[cfg(feature = "rand")]
    fn test_gcd_lehmer_euclid_not_extended() {
        let mut rng = XorShiftRng::from_seed([1u8; 16]);

        for i in 1usize..80 {
            for j in &[1usize, 16, 24, 64, 128] {
                let a = rng.gen_biguint(i * j);
                let b = rng.gen_biguint(i * j);
                let (q, s_k, t_k) = extended_gcd(Cow::Borrowed(&a), Cow::Borrowed(&b), false);

                let expected = extended_gcd_euclid(Cow::Borrowed(&a), Cow::Borrowed(&b));
                assert_eq!(
                    q, expected.0,
                    "gcd({}, {}) = {} != {}",
                    &a, &b, &q, expected.0
                );
                assert_eq!(s_k, None);
                assert_eq!(t_k, None);
            }
        }
    }
}