snarkvm_algorithms/msm/variable_base/
mod.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16pub mod batched;
17pub mod standard;
18
19#[cfg(target_arch = "x86_64")]
20pub mod prefetch;
21
22use snarkvm_curves::{bls12_377::G1Affine, traits::AffineCurve};
23use snarkvm_fields::PrimeField;
24
25use core::any::TypeId;
26
27pub struct VariableBase;
28
29impl VariableBase {
30    pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
31        // For BLS12-377, we perform variable base MSM using a batched addition technique.
32        if TypeId::of::<G>() == TypeId::of::<G1Affine>() {
33            #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
34            // TODO SNP: where to set the threshold
35            if scalars.len() > 1024 {
36                let result = snarkvm_algorithms_cuda::msm::<G, G::Projective, <G::ScalarField as PrimeField>::BigInteger>(
37                    bases, scalars,
38                );
39                if let Ok(result) = result {
40                    return result;
41                }
42            }
43            batched::msm(bases, scalars)
44        }
45        // For all other curves, we perform variable base MSM using Pippenger's algorithm.
46        else {
47            standard::msm(bases, scalars)
48        }
49    }
50
51    #[cfg(test)]
52    fn msm_naive<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
53        use itertools::Itertools;
54        use snarkvm_utilities::BitIteratorBE;
55
56        bases.iter().zip_eq(scalars).map(|(base, scalar)| base.mul_bits(BitIteratorBE::new(*scalar))).sum()
57    }
58
59    #[cfg(test)]
60    fn msm_naive_parallel<G: AffineCurve>(
61        bases: &[G],
62        scalars: &[<G::ScalarField as PrimeField>::BigInteger],
63    ) -> G::Projective {
64        use rayon::prelude::*;
65        use snarkvm_utilities::BitIteratorBE;
66
67        bases.par_iter().zip_eq(scalars).map(|(base, scalar)| base.mul_bits(BitIteratorBE::new(*scalar))).sum()
68    }
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_curves::{
        ProjectiveCurve,
        bls12_377::{Fr, G1Affine},
    };
    use snarkvm_fields::PrimeField;
    use snarkvm_utilities::rand::TestRng;

    /// Samples `size` random bases followed by `size` random scalars (as big integers).
    fn create_scalar_bases<G: AffineCurve<ScalarField = F>, F: PrimeField>(
        rng: &mut TestRng,
        size: usize,
    ) -> (Vec<G>, Vec<F::BigInteger>) {
        // All bases are drawn before any scalar, so the RNG stream is consumed in a fixed order.
        let bases: Vec<G> = core::iter::repeat_with(|| G::rand(rng)).take(size).collect();
        let scalars: Vec<F::BigInteger> = core::iter::repeat_with(|| F::rand(rng).to_bigint()).take(size).collect();
        (bases, scalars)
    }

    /// Checks that the naive, parallel-naive, Pippenger and batched implementations all agree
    /// on BLS12-377 G1 for a range of input sizes.
    #[test]
    fn test_msm() {
        for msm_size in [1, 5, 10, 50, 100, 500, 1000] {
            let mut rng = TestRng::default();
            let (base_vec, scalar_vec) = create_scalar_bases::<G1Affine, Fr>(&mut rng, msm_size);
            let (bases, scalars) = (base_vec.as_slice(), scalar_vec.as_slice());

            // The sequential naive sum is the reference every other result is compared against.
            let reference = VariableBase::msm_naive(bases, scalars).to_affine();
            let reference_parallel = VariableBase::msm_naive_parallel(bases, scalars).to_affine();
            assert_eq!(reference, reference_parallel, "MSM size: {msm_size}");

            let pippenger = standard::msm(bases, scalars).to_affine();
            assert_eq!(reference, pippenger, "MSM size: {msm_size}");

            let batched_sum = batched::msm(bases, scalars).to_affine();
            assert_eq!(reference, batched_sum, "MSM size: {msm_size}");
        }
    }

    /// Compares the dispatching `VariableBase::msm` (CUDA-enabled build) against the standard
    /// implementation for power-of-two sizes from 2^2 through 2^16.
    #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
    #[test]
    fn test_msm_cuda() {
        let mut rng = TestRng::default();
        for log_size in 2..17 {
            let (base_vec, scalar_vec) = create_scalar_bases::<G1Affine, Fr>(&mut rng, 1 << log_size);
            let (bases, scalars) = (base_vec.as_slice(), scalar_vec.as_slice());

            let expected = standard::msm(bases, scalars);
            let actual = VariableBase::msm::<G1Affine>(bases, scalars);
            assert_eq!(expected.to_affine(), actual.to_affine());
        }
    }
}