snarkvm_algorithms/fft/polynomial/
multiplier.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::{borrow::Borrow, collections::BTreeMap};
17
18use crate::fft::domain::{FFTPrecomputation, IFFTPrecomputation};
19
20/// A struct that helps multiply a batch of polynomials
21use super::*;
22use snarkvm_utilities::{ExecutionPool, cfg_into_iter, cfg_iter, cfg_iter_mut, cfg_reduce_with};
23
24#[derive(Default)]
25pub struct PolyMultiplier<'a, F: PrimeField> {
26    polynomials: Vec<(String, Cow<'a, DensePolynomial<F>>)>,
27    evaluations: Vec<(String, Cow<'a, crate::fft::Evaluations<F>>)>,
28    fft_precomputation: Option<Cow<'a, FFTPrecomputation<F>>>,
29    ifft_precomputation: Option<Cow<'a, IFFTPrecomputation<F>>>,
30}
31
32impl<'a, F: PrimeField> PolyMultiplier<'a, F> {
33    #[inline]
34    pub fn new() -> Self {
35        Self { polynomials: Vec::new(), evaluations: Vec::new(), fft_precomputation: None, ifft_precomputation: None }
36    }
37
38    #[inline]
39    pub fn add_precomputation(&mut self, fft_pc: &'a FFTPrecomputation<F>, ifft_pc: &'a IFFTPrecomputation<F>) {
40        self.fft_precomputation = Some(Cow::Borrowed(fft_pc));
41        self.ifft_precomputation = Some(Cow::Borrowed(ifft_pc));
42    }
43
44    #[inline]
45    pub fn add_polynomial(&mut self, poly: DensePolynomial<F>, label: impl ToString) {
46        self.polynomials.push((label.to_string(), Cow::Owned(poly)))
47    }
48
49    #[inline]
50    pub fn add_evaluation(&mut self, evals: Evaluations<F>, label: impl ToString) {
51        self.evaluations.push((label.to_string(), Cow::Owned(evals)))
52    }
53
54    #[inline]
55    pub fn add_polynomial_ref(&mut self, poly: &'a DensePolynomial<F>, label: impl ToString) {
56        self.polynomials.push((label.to_string(), Cow::Borrowed(poly)))
57    }
58
59    #[inline]
60    pub fn add_evaluation_ref(&mut self, evals: &'a Evaluations<F>, label: impl ToString) {
61        self.evaluations.push((label.to_string(), Cow::Borrowed(evals)))
62    }
63
64    /// Multiplies all polynomials stored in `self`.
65    ///
66    /// Returns `None` if any of the stored evaluations are over a domain that's
67    /// insufficiently large to interpolate the product, or if `F` does not contain
68    /// a sufficiently large subgroup for interpolation.
69    #[allow(unused_mut)]
70    pub fn multiply(mut self) -> Option<DensePolynomial<F>> {
71        if self.polynomials.is_empty() && self.evaluations.is_empty() {
72            Some(DensePolynomial::zero())
73        } else {
74            let degree = self.polynomials.iter().map(|(_, p)| p.degree() + 1).sum::<usize>();
75            let domain = EvaluationDomain::new(degree)?;
76            if self.evaluations.iter().any(|(_, e)| e.domain() != domain) {
77                None
78            } else {
79                #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
80                {
81                    let mut poly_slices = Vec::new();
82                    for (_, p) in &self.polynomials {
83                        poly_slices.push(p.coeffs().to_vec());
84                    }
85                    let mut eval_slices = Vec::new();
86                    for (_, e) in &self.evaluations {
87                        eval_slices.push(e.evaluations().to_vec());
88                    }
89
90                    let gpu_result_vec =
91                        snarkvm_algorithms_cuda::polymul(domain.size(), &poly_slices, &eval_slices, &F::zero());
92                    if let Ok(result) = gpu_result_vec {
93                        return Some(DensePolynomial::from_coefficients_vec(result));
94                    }
95                }
96
97                if self.fft_precomputation.is_none() {
98                    self.fft_precomputation = Some(Cow::Owned(domain.precompute_fft()));
99                }
100                if self.ifft_precomputation.is_none() {
101                    self.ifft_precomputation =
102                        Some(Cow::Owned(self.fft_precomputation.as_ref().unwrap().to_ifft_precomputation()));
103                }
104                let fft_pc = &self.fft_precomputation.unwrap();
105                let ifft_pc = &self.ifft_precomputation.unwrap();
106                let mut pool = ExecutionPool::with_capacity(self.polynomials.len() + self.evaluations.len());
107                for (_, p) in self.polynomials {
108                    pool.add_job(move || {
109                        let mut p = p.into_owned().coeffs;
110                        p.resize(domain.size(), F::zero());
111                        domain.out_order_fft_in_place_with_pc(&mut p, fft_pc);
112                        p
113                    })
114                }
115                for (_, e) in self.evaluations {
116                    pool.add_job(move || {
117                        let mut e = e.into_owned().evaluations;
118                        e.resize(domain.size(), F::zero());
119                        crate::fft::domain::derange(&mut e);
120                        e
121                    })
122                }
123                let results = pool.execute_all();
124                let iter = cfg_into_iter!(results);
125                let mut result = cfg_reduce_with!(iter, |mut a, b| {
126                    cfg_iter_mut!(a).zip(b).for_each(|(a, b)| *a *= b);
127                    a
128                })
129                .unwrap();
130                domain.out_order_ifft_in_place_with_pc(&mut result, ifft_pc);
131                Some(DensePolynomial::from_coefficients_vec(result))
132            }
133        }
134    }
135
136    pub fn element_wise_arithmetic_4_over_domain<T: Borrow<str>>(
137        mut self,
138        domain: EvaluationDomain<F>,
139        labels: [T; 4],
140        f: impl Fn(F, F, F, F) -> F + Sync,
141    ) -> Option<DensePolynomial<F>> {
142        if self.fft_precomputation.is_none() {
143            self.fft_precomputation = Some(Cow::Owned(domain.precompute_fft()));
144        }
145        if self.ifft_precomputation.is_none() {
146            self.ifft_precomputation =
147                Some(Cow::Owned(self.fft_precomputation.as_ref().unwrap().to_ifft_precomputation()));
148        }
149        let fft_pc = self.fft_precomputation.as_ref().unwrap();
150        let mut pool = ExecutionPool::with_capacity(self.polynomials.len() + self.evaluations.len());
151        for (l, p) in self.polynomials {
152            pool.add_job(move || {
153                let mut p = p.clone().into_owned().coeffs;
154                p.resize(domain.size(), F::zero());
155                domain.out_order_fft_in_place_with_pc(&mut p, fft_pc);
156                (l, p)
157            })
158        }
159        for (l, e) in self.evaluations {
160            pool.add_job(move || {
161                let mut e = e.clone().into_owned().evaluations;
162                e.resize(domain.size(), F::zero());
163                crate::fft::domain::derange(&mut e);
164                (l, e)
165            })
166        }
167        let p = pool.execute_all().into_iter().collect::<BTreeMap<_, _>>();
168        assert_eq!(p.len(), 4);
169        let mut result = cfg_iter!(p[labels[0].borrow()])
170            .zip(&p[labels[1].borrow()])
171            .zip(&p[labels[2].borrow()])
172            .zip(&p[labels[3].borrow()])
173            .map(|(((a, b), c), d)| f(*a, *b, *c, *d))
174            .collect::<Vec<_>>();
175        drop(p);
176        domain.out_order_ifft_in_place_with_pc(&mut result, &self.ifft_precomputation.unwrap());
177        Some(DensePolynomial::from_coefficients_vec(result))
178    }
179}