// solana_program/big_mod_exp.rs

/// Parameter block handed, by pointer, to the `sol_big_mod_exp` syscall.
///
/// `#[repr(C)]` pins the field order and layout to the C ABI, so the
/// receiving side can read it as a plain C struct. Each operand is described
/// by a pointer to its first byte plus its length in bytes. [`big_mod_exp`]
/// fills these in from its big-endian byte-slice arguments.
#[repr(C)]
pub struct BigModExpParams {
    /// Pointer to the first byte of the base operand.
    pub base: *const u8,
    /// Length of the base operand, in bytes.
    pub base_len: u64,
    /// Pointer to the first byte of the exponent operand.
    pub exponent: *const u8,
    /// Length of the exponent operand, in bytes.
    pub exponent_len: u64,
    /// Pointer to the first byte of the modulus operand.
    pub modulus: *const u8,
    /// Length of the modulus operand, in bytes (also the byte width of the result).
    pub modulus_len: u64,
}

11/// Big integer modular exponentiation
12pub fn big_mod_exp(base: &[u8], exponent: &[u8], modulus: &[u8]) -> Vec<u8> {
13    #[cfg(not(target_os = "solana"))]
14    {
15        use {
16            num_bigint::BigUint,
17            num_traits::{One, Zero},
18        };
19
20        let modulus_len = modulus.len();
21        let base = BigUint::from_bytes_be(base);
22        let exponent = BigUint::from_bytes_be(exponent);
23        let modulus = BigUint::from_bytes_be(modulus);
24
25        if modulus.is_zero() || modulus.is_one() {
26            return vec![0_u8; modulus_len];
27        }
28
29        let ret_int = base.modpow(&exponent, &modulus);
30        let ret_int = ret_int.to_bytes_be();
31        let mut return_value = vec![0_u8; modulus_len.saturating_sub(ret_int.len())];
32        return_value.extend(ret_int);
33        return_value
34    }
35
36    #[cfg(target_os = "solana")]
37    {
38        let mut return_value = vec![0_u8; modulus.len()];
39
40        let param = BigModExpParams {
41            base: base as *const _ as *const u8,
42            base_len: base.len() as u64,
43            exponent: exponent as *const _ as *const u8,
44            exponent_len: exponent.len() as u64,
45            modulus: modulus as *const _ as *const u8,
46            modulus_len: modulus.len() as u64,
47        };
48        unsafe {
49            crate::syscalls::sol_big_mod_exp(
50                &param as *const _ as *const u8,
51                return_value.as_mut_slice() as *mut _ as *mut u8,
52            )
53        };
54
55        return_value
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `big_mod_exp` against known vectors, including
    /// moduli of zero and one and operands with leading zero bytes.
    #[test]
    fn big_mod_exp_test() {
        #[derive(serde_derive::Deserialize)]
        #[serde(rename_all = "PascalCase")]
        struct TestCase {
            base: String,
            exponent: String,
            modulus: String,
            expected: String,
        }

        let test_data = r#"[
        {
            "Base":     "1111111111111111111111111111111111111111111111111111111111111111",
            "Exponent": "1111111111111111111111111111111111111111111111111111111111111111",
            "Modulus":  "111111111111111111111111111111111111111111111111111111111111110A",
            "Expected": "0A7074864588D6847F33A168209E516F60005A0CEC3F33AAF70E8002FE964BCD"
        },
        {
            "Base":     "2222222222222222222222222222222222222222222222222222222222222222",
            "Exponent": "2222222222222222222222222222222222222222222222222222222222222222",
            "Modulus":  "1111111111111111111111111111111111111111111111111111111111111111",
            "Expected": "0000000000000000000000000000000000000000000000000000000000000000"
        },
        {
            "Base":     "3333333333333333333333333333333333333333333333333333333333333333",
            "Exponent": "3333333333333333333333333333333333333333333333333333333333333333",
            "Modulus":  "2222222222222222222222222222222222222222222222222222222222222222",
            "Expected": "1111111111111111111111111111111111111111111111111111111111111111"
        },
        {
            "Base":     "9874231472317432847923174392874918237439287492374932871937289719",
            "Exponent": "0948403985401232889438579475812347232099080051356165126166266222",
            "Modulus":  "25532321a214321423124212222224222b242222222222222222222222222444",
            "Expected": "220ECE1C42624E98AEE7EB86578B2FE5C4855DFFACCB43CCBB708A3AB37F184D"
        },
        {
            "Base":     "3494396663463663636363662632666565656456646566786786676786768766",
            "Exponent": "2324324333246536456354655645656616169896565698987033121934984955",
            "Modulus":  "0218305479243590485092843590249879879842313131156656565565656566",
            "Expected": "012F2865E8B9E79B645FCE3A9E04156483AE1F9833F6BFCF86FCA38FC2D5BEF0"
        },
        {
            "Base":     "0000000000000000000000000000000000000000000000000000000000000005",
            "Exponent": "0000000000000000000000000000000000000000000000000000000000000002",
            "Modulus":  "0000000000000000000000000000000000000000000000000000000000000007",
            "Expected": "0000000000000000000000000000000000000000000000000000000000000004"
        },
        {
            "Base":     "0000000000000000000000000000000000000000000000000000000000000019",
            "Exponent": "0000000000000000000000000000000000000000000000000000000000000019",
            "Modulus":  "0000000000000000000000000000000000000000000000000000000000000064",
            "Expected": "0000000000000000000000000000000000000000000000000000000000000019"
        },
        {
            "Base":     "0000000000000000000000000000000000000000000000000000000000000019",
            "Exponent": "0000000000000000000000000000000000000000000000000000000000000019",
            "Modulus":  "0000000000000000000000000000000000000000000000000000000000000000",
            "Expected": "0000000000000000000000000000000000000000000000000000000000000000"
        },
        {
            "Base":     "0000000000000000000000000000000000000000000000000000000000000019",
            "Exponent": "0000000000000000000000000000000000000000000000000000000000000019",
            "Modulus":  "0000000000000000000000000000000000000000000000000000000000000001",
            "Expected": "0000000000000000000000000000000000000000000000000000000000000000"
        }
        ]"#;

        let test_cases: Vec<TestCase> = serde_json::from_str(test_data).unwrap();
        for (i, test) in test_cases.iter().enumerate() {
            let base = array_bytes::hex2bytes_unchecked(&test.base);
            let exponent = array_bytes::hex2bytes_unchecked(&test.exponent);
            let modulus = array_bytes::hex2bytes_unchecked(&test.modulus);
            let expected = array_bytes::hex2bytes_unchecked(&test.expected);
            let result = big_mod_exp(&base, &exponent, &modulus);
            // Report which vector failed; the hex operands alone are hard to tell apart.
            assert_eq!(result, expected, "test case #{i} failed");
        }
    }

    /// Boundary inputs the table above does not cover: empty operands,
    /// a zero exponent, and zero-padding of the result to the modulus width.
    #[test]
    fn big_mod_exp_edge_cases() {
        // An empty modulus is zero, so the (empty-width) result is empty.
        assert_eq!(big_mod_exp(&[], &[], &[]), Vec::<u8>::new());
        // An empty exponent is zero: 25^0 mod 7 == 1.
        assert_eq!(big_mod_exp(&[0x19], &[], &[0x07]), vec![0x01]);
        // The result keeps the modulus byte width, including leading zeros.
        assert_eq!(big_mod_exp(&[0x19], &[0x00], &[0x00, 0x07]), vec![0x00, 0x01]);
        // A zero residue is still padded out to the modulus width.
        assert_eq!(big_mod_exp(&[0x07], &[0x01], &[0x00, 0x07]), vec![0x00, 0x00]);
    }
}