malachite_base/num/arithmetic/mod_inverse.rs

// Copyright © 2025 Mikhail Hogrefe
//
// Uses code adopted from the FLINT Library.
//
//      Copyright © 2009, 2016 William Hart
//
// This file is part of Malachite.
//
// Malachite is free software: you can redistribute it and/or modify it under the terms of the GNU
// Lesser General Public License (LGPL) as published by the Free Software Foundation; either version
// 3 of the License, or (at your option) any later version. See <https://www.gnu.org/licenses/>.
12
13use crate::num::arithmetic::traits::ModInverse;
14use crate::num::basic::signeds::PrimitiveSigned;
15use crate::num::basic::unsigneds::PrimitiveUnsigned;
16use crate::num::conversion::traits::WrappingFrom;
17use crate::rounding_modes::RoundingMode::*;
18
// Computes the multiplicative inverse of `x` modulo `m`, or returns `None` when the Euclidean
// remainder sequence of `m` and `x` does not end in 1 (i.e. `x` and `m` are not coprime).
//
// This is a variation of `n_xgcd` from `ulong_extras/xgcd.c`, FLINT 2.7.1: an extended Euclidean
// algorithm on single words that selects small quotients (1, 2 or 3) by comparison and only falls
// back to `div_rem` for larger ones.
//
// Type parameters:
// - `U`: the unsigned word type of `x`, `m` and the result.
// - `S`: a signed type holding the Bézout cofactors. Values cross between `U` and `S` with
//   `wrapping_from`. NOTE(review): this presumably requires `U` and `S` to have the same width,
//   so that the conversions are bit reinterpretations — confirm against the instantiations.
//
// Panics if `x` is zero or if `x >= m`.
//
// Variables: `u3`/`v3` are consecutive remainders, starting at `m` and `x`. The cofactor pairs
// satisfy, modulo 2^W where W = `U::WIDTH`:
//   u3 = u1 * m + v1 * x
//   v3 = u2 * m + v2 * x
// Every step with quotient q maps
//   (u3, v3) -> (v3, u3 - q * v3),
//   (u1, u2) -> (u2, u1 - q * u2),
//   (v1, v2) -> (v2, v1 - q * v2),
// which preserves both relations. `t1`, `t2` and `d` are temporaries.
pub_test! {mod_inverse_binary<
    U: WrappingFrom<S> + PrimitiveUnsigned,
    S: PrimitiveSigned + WrappingFrom<U>,
>(
    x: U,
    m: U,
) -> Option<U> {
    assert_ne!(x, U::ZERO);
    assert!(x < m, "x must be reduced mod m, but {x} >= {m}");
    let mut u1 = S::ONE;
    let mut v2 = S::ONE;
    let mut u2 = S::ZERO;
    let mut v1 = S::ZERO;
    let mut u3 = m;
    let mut v3 = x;
    let mut d;
    let mut t2;
    let mut t1;
    // If both `m` and `x` have their top bit set, then `m < 2 * x`, so the first quotient is 1.
    // Take that step up front, leaving `v3 = m - x` with its top bit clear. In the other cases
    // `x`'s top bit is already clear: if it were set, `m > x` would have it set too. Either way,
    // `v3 << 1` below cannot overflow.
    if (m & x).get_highest_bit() {
        d = u3 - v3;
        t2 = v2;
        t1 = u2;
        u2 = u1 - u2;
        u1 = t1;
        u3 = v3;
        v2 = v1 - v2;
        v1 = t2;
        v3 = d;
    }
    // While `v3` has its second-highest bit set, `v3 >= 2^(W - 2)` and `u3 < 2^W`. So the
    // quotient `u3 / v3` is 1, 2 or 3 and is chosen by comparing `d = u3 - v3` with `v3`.
    // NOTE(review): this loop uses plain (non-wrapping) arithmetic. It is presumably safe because
    // the remainders are still within a factor of 4 of `m`, which keeps the cofactors tiny —
    // confirm.
    while v3.get_bit(U::WIDTH - 2) {
        d = u3 - v3;
        if d < v3 {
            // quot = 1
            t2 = v2;
            t1 = u2;
            u2 = u1 - u2;
            u1 = t1;
            u3 = v3;
            v2 = v1 - v2;
            v1 = t2;
            v3 = d;
        } else if d < (v3 << 1) {
            // quot = 2
            t1 = u2;
            u2 = u1 - (u2 << 1);
            u1 = t1;
            u3 = v3;
            t2 = v2;
            v2 = v1 - (v2 << 1);
            v1 = t2;
            // `u3` already holds the old `v3`, so this is old u3 - 2 * old v3.
            v3 = d - u3;
        } else {
            // quot = 3
            t1 = u2;
            u2 = u1 - S::wrapping_from(3) * u2;
            u1 = t1;
            u3 = v3;
            t2 = v2;
            v2 = v1 - S::wrapping_from(3) * v2;
            v1 = t2;
            // `u3` already holds the old `v3`, so this is old u3 - 3 * old v3.
            v3 = d - (u3 << 1);
        }
    }
    // Main loop: continue until the remainder reaches zero. The cofactor updates for quotients
    // >= 2 use wrapping arithmetic. The cofactors produced by the final step can have magnitude
    // close to `m / gcd(x, m)`, which need not fit in `S`. The relations above still hold
    // modulo 2^W.
    // NOTE(review): the quot = 1 branch uses plain subtraction. It is presumably safe because a
    // quotient-1 step is never the last one (the last step has `v3` dividing `u3` with `u3 > v3`,
    // so its quotient is at least 2), and earlier cofactors stay small — confirm.
    while v3 != U::ZERO {
        d = u3 - v3;
        // overflow not possible, top 2 bits of v3 not set
        // If `u3 < 4 * v3`, the quotient is again 1, 2 or 3 and is chosen by comparison.
        // Otherwise it is computed with `div_rem`.
        if u3 < (v3 << 2) {
            if d < v3 {
                // quot = 1
                t2 = v2;
                t1 = u2;
                u2 = u1 - u2;
                u1 = t1;
                u3 = v3;
                v2 = v1 - v2;
                v1 = t2;
                v3 = d;
            } else if d < (v3 << 1) {
                // quot = 2
                t1 = u2;
                u2 = u1.wrapping_sub(u2 << 1);
                u1 = t1;
                u3 = v3;
                t2 = v2;
                v2 = v1.wrapping_sub(v2 << 1);
                v1 = t2;
                v3 = d - u3;
            } else {
                // quot = 3
                t1 = u2;
                u2 = u1.wrapping_sub(S::wrapping_from(3).wrapping_mul(u2));
                u1 = t1;
                u3 = v3;
                t2 = v2;
                v2 = v1.wrapping_sub(S::wrapping_from(3).wrapping_mul(v2));
                v1 = t2;
                v3 = d.wrapping_sub(u3 << 1);
            }
        } else {
            let (quot, rem) = u3.div_rem(v3);
            t1 = u2;
            u2 = u1.wrapping_sub(S::wrapping_from(quot).wrapping_mul(u2));
            u1 = t1;
            u3 = v3;
            t2 = v2;
            v2 = v1.wrapping_sub(S::wrapping_from(quot).wrapping_mul(v2));
            v1 = t2;
            v3 = rem;
        }
    }
    // The loop ends with `u3` = gcd(x, m); an inverse exists only if it is 1.
    if u3 != U::ONE {
        return None;
    }
    // By the relations above, the exact cofactors satisfy u1 * m + v1 * x = 1, so
    // v1 * x ≡ 1 (mod m). The register `v1` holds that cofactor modulo 2^W, and
    // `wrapping_from` reinterprets its bits as `U`.
    let mut inverse = U::wrapping_from(v1);
    // NOTE(review): this sign test presumably relies on `u1` not having wrapped. FLINT's `n_xgcd`
    // makes the same comparison — confirm the bound.
    if u1 <= S::ZERO {
        inverse.wrapping_sub_assign(m);
    }
    // `limit` is the wrapping negation of floor(m / 2). If `inverse` is below it, add the smallest
    // multiple of `m` that lifts it to at least `limit` (with wrapping).
    let limit = (m >> 1u32).wrapping_neg();
    if inverse < limit {
        let k = (limit - inverse).div_round(m, Ceiling).0;
        inverse.wrapping_add_assign(m.wrapping_mul(k));
    }
    // A value with its top bit set is treated as negative and shifted up by `m`, bringing the
    // result into `[0, m)`.
    // NOTE(review): since u1 * m + v1 * x = 1 with m >= 2, `u1 <= 0` implies `v1 > 0`, and
    // `u1 > 0` implies `v1 < 0`. So this normalization may collapse to "return v1 if u1 <= 0,
    // else v1 + m (wrapping)". Left unchanged; verify before simplifying.
    Some(if inverse.get_highest_bit() {
        inverse.wrapping_add(m)
    } else {
        inverse
    })
}}
148
149macro_rules! impl_mod_inverse {
150    ($u:ident, $s:ident) => {
151        impl ModInverse<$u> for $u {
152            type Output = $u;
153
154            /// Computes the multiplicative inverse of a number modulo another number $m$. The input
155            /// must be already reduced modulo $m$.
156            ///
157            /// Returns `None` if $x$ and $m$ are not coprime.
158            ///
159            /// $f(x, m) = y$, where $x, y < m$, $\gcd(x, y) = 1$, and $xy \equiv 1 \mod m$.
160            ///
161            /// # Worst-case complexity
162            /// $T(n) = O(n^2)$
163            ///
164            /// $M(n) = O(n)$
165            ///
166            /// where $T$ is time, $M$ is additional memory, and $n$ is
167            /// `max(self.significant_bits(), m.significant_bits())`.
168            ///
169            /// # Panics
170            /// Panics if `self` is greater than or equal to `m`.
171            ///
172            /// # Examples
173            /// See [here](super::mod_inverse#mod_inverse).
174            #[inline]
175            fn mod_inverse(self, m: $u) -> Option<$u> {
176                mod_inverse_binary::<$u, $s>(self, m)
177            }
178        }
179    };
180}
181apply_to_unsigned_signed_pairs!(impl_mod_inverse);