snarkvm_algorithms_cuda/
lib.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#[allow(unused_imports)]
17use blst::*;
18
19use core::ffi::c_void;
// Expand sppark's `cuda_error!` macro. Nothing else visible in this file
// defines the `cuda` module, so this is presumably where `cuda::Error` (the
// return type of every FFI call below, checked via its `code` field) comes
// from — confirm against the sppark crate.
sppark::cuda_error!();
21
/// Element ordering of the NTT input and output buffers, passed by value to
/// the CUDA backend as a C enum (discriminants are part of the ABI).
///
/// NOTE(review): the variant names presumably read `<input><output>` with
/// `N` = natural order and `R` = bit-reversed order — confirm against the
/// CUDA implementation.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NTTInputOutputOrder {
    NN = 0,
    NR = 1,
    RN = 2,
    RR = 3,
}
29
/// Direction of the NTT (forward or inverse transform), passed by value to
/// the CUDA backend as a C enum (discriminants are part of the ABI).
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NTTDirection {
    Forward = 0,
    Inverse = 1,
}
35
/// Selects a standard or a coset NTT, passed by value to the CUDA backend as
/// a C enum (discriminants are part of the ABI).
///
/// NOTE(review): `snarkvm_ntt` takes no coset-shift argument, so the shift is
/// presumably fixed on the CUDA side — confirm.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NTTType {
    Standard = 0,
    Coset = 1,
}
41
42extern "C" {
43    fn snarkvm_ntt(
44        inout: *mut core::ffi::c_void,
45        lg_domain_size: u32,
46        ntt_order: NTTInputOutputOrder,
47        ntt_direction: NTTDirection,
48        ntt_type: NTTType,
49    ) -> cuda::Error;
50
51    fn snarkvm_polymul(
52        out: *mut core::ffi::c_void,
53        pcount: usize,
54        polynomials: *const core::ffi::c_void,
55        plens: *const core::ffi::c_void,
56        ecount: usize,
57        evaluations: *const core::ffi::c_void,
58        elens: *const core::ffi::c_void,
59        lg_domain_size: u32,
60    ) -> cuda::Error;
61
62    fn snarkvm_msm(
63        out: *mut c_void,
64        points_with_infinity: *const c_void,
65        npoints: usize,
66        scalars: *const c_void,
67        ffi_affine_sz: usize,
68    ) -> cuda::Error;
69}
70
71///////////////////////////////////////////////////////////////////////////////
72// Rust functions
73///////////////////////////////////////////////////////////////////////////////
74
75/// Compute an in-place NTT on the input data.
76#[allow(non_snake_case)]
77pub fn NTT<T>(
78    domain_size: usize,
79    inout: &mut [T],
80    ntt_order: NTTInputOutputOrder,
81    ntt_direction: NTTDirection,
82    ntt_type: NTTType,
83) -> Result<(), cuda::Error> {
84    if (domain_size & (domain_size - 1)) != 0 {
85        panic!("domain_size is not power of 2");
86    }
87    let lg_domain_size = domain_size.trailing_zeros();
88
89    let err = unsafe {
90        snarkvm_ntt(inout.as_mut_ptr() as *mut core::ffi::c_void, lg_domain_size, ntt_order, ntt_direction, ntt_type)
91    };
92
93    if err.code != 0 {
94        return Err(err);
95    }
96    Ok(())
97}
98
99/// Compute a polynomial multiply
100pub fn polymul<T: std::clone::Clone>(
101    domain: usize,
102    polynomials: &Vec<Vec<T>>,
103    evaluations: &Vec<Vec<T>>,
104    zero: &T,
105) -> Result<Vec<T>, cuda::Error> {
106    let initial_domain_size = domain;
107    if (initial_domain_size & (initial_domain_size - 1)) != 0 {
108        panic!("domain_size is not power of 2");
109    }
110
111    let lg_domain_size = initial_domain_size.trailing_zeros();
112
113    let mut pptrs = Vec::new();
114    let mut plens = Vec::new();
115    for polynomial in polynomials {
116        pptrs.push(polynomial.as_ptr() as *const core::ffi::c_void);
117        plens.push(polynomial.len());
118    }
119    let mut eptrs = Vec::new();
120    let mut elens = Vec::new();
121    for evaluation in evaluations {
122        eptrs.push(evaluation.as_ptr() as *const core::ffi::c_void);
123        elens.push(evaluation.len());
124    }
125
126    let mut out = Vec::new();
127    out.resize(initial_domain_size, zero.clone());
128    let err = unsafe {
129        snarkvm_polymul(
130            out.as_mut_ptr() as *mut core::ffi::c_void,
131            pptrs.len(),
132            pptrs.as_ptr() as *const core::ffi::c_void,
133            plens.as_ptr() as *const core::ffi::c_void,
134            eptrs.len(),
135            eptrs.as_ptr() as *const core::ffi::c_void,
136            elens.as_ptr() as *const core::ffi::c_void,
137            lg_domain_size,
138        )
139    };
140
141    if err.code != 0 {
142        return Err(err);
143    }
144    Ok(out)
145}
146
147/// Compute a multi-scalar multiplication
148pub fn msm<Affine, Projective, Scalar>(points: &[Affine], scalars: &[Scalar]) -> Result<Projective, cuda::Error> {
149    let npoints = scalars.len();
150    if npoints > points.len() {
151        panic!("length mismatch {} points < {} scalars", npoints, scalars.len())
152    }
153    #[allow(clippy::uninit_assumed_init)]
154    let mut ret: Projective = unsafe { std::mem::MaybeUninit::uninit().assume_init() };
155    let err = unsafe {
156        snarkvm_msm(
157            &mut ret as *mut _ as *mut c_void,
158            points as *const _ as *const c_void,
159            npoints,
160            scalars as *const _ as *const c_void,
161            std::mem::size_of::<Affine>(),
162        )
163    };
164    if err.code != 0 {
165        return Err(err);
166    }
167    Ok(ret)
168}