1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#[cfg(feature = "parallel")]
use rayon::iter::{IntoParallelIterator, ParallelIterator};

use crate::field::{element::FieldElement, traits::IsField};
use crate::polynomial::error::MultilinearError;
use alloc::vec::Vec;

/// A multilinear polynomial over the field `F`, stored as a list of evaluations at selected
/// points of the boolean hypercube `{0, 1}^num_vars`.
///
/// Each entry of `evals` is an `(index, value)` pair. The big-endian bit decomposition of
/// `index` over `num_vars` bits selects the hypercube point: the first variable corresponds to
/// the most significant bit. Points that are not listed contribute nothing when evaluating.
pub struct SparseMultilinearPolynomial<F>
where
    F: IsField,
    <F as IsField>::BaseType: Send + Sync,
{
    /// Number of variables the polynomial is defined over.
    num_vars: usize,
    /// `(hypercube index, evaluation)` pairs for the listed points.
    evals: Vec<(usize, FieldElement<F>)>,
}

impl<F: IsField> SparseMultilinearPolynomial<F>
where
    <F as IsField>::BaseType: Send + Sync,
{
    /// Creates a sparse multilinear polynomial in `num_vars` variables.
    ///
    /// `evals` holds `(index, value)` pairs: `value` is the polynomial's evaluation at the
    /// hypercube point whose big-endian `num_vars`-bit decomposition is `index`. No validation
    /// is performed here (indices are not checked against `2^num_vars`, duplicates are kept).
    pub fn new(num_vars: usize, evals: Vec<(usize, FieldElement<F>)>) -> Self {
        SparseMultilinearPolynomial { num_vars, evals }
    }

    /// Returns the number of variables of the polynomial.
    pub fn num_vars(&self) -> usize {
        self.num_vars
    }

    /// Computes the multilinear equality polynomial `eq(a, r) = Π_j (a_j ? r_j : 1 - r_j)`.
    ///
    /// When `r` is itself a boolean point this is `1` if `r == a` and `0` otherwise; for a
    /// general `r` it is the Lagrange basis polynomial of the hypercube vertex `a`, evaluated
    /// at `r`.
    ///
    /// # Errors
    ///
    /// Returns [`MultilinearError::ChisAndEvalsLengthMismatch`] if `a` and `r` have different
    /// lengths.
    fn compute_chi(a: &[bool], r: &[FieldElement<F>]) -> Result<FieldElement<F>, MultilinearError> {
        // Report the mismatch to the caller instead of panicking (no `assert_eq!` here, it
        // would make this error path unreachable).
        if a.len() != r.len() {
            return Err(MultilinearError::ChisAndEvalsLengthMismatch(
                a.len(),
                r.len(),
            ));
        }
        let mut chi = FieldElement::<F>::one();
        for (&bit, r_j) in a.iter().zip(r) {
            if bit {
                chi *= r_j;
            } else {
                chi *= FieldElement::<F>::one() - r_j;
            }
        }
        Ok(chi)
    }

    /// Evaluates the polynomial at the point `r`.
    ///
    /// Computes `Σ value · eq(bits(index), r)` over the stored `(index, value)` pairs, where
    /// `bits(index)` is the big-endian decomposition of `index` into `num_vars` bits. Entries
    /// sharing an index contribute additively. Costs `O(evals.len() · num_vars)` field
    /// multiplications.
    ///
    /// # Errors
    ///
    /// Returns [`MultilinearError::IncorrectNumberofEvaluationPoints`] if `r.len()` differs from
    /// [`Self::num_vars`].
    pub fn evaluate(&self, r: &[FieldElement<F>]) -> Result<FieldElement<F>, MultilinearError> {
        Self::evaluate_with(self.num_vars, &self.evals, r)
    }

    /// Evaluates the sparse polynomial described by `num_vars` and `evals` at the point `r`,
    /// without needing a constructed [`SparseMultilinearPolynomial`].
    ///
    /// Computes `Σ value · eq(bits(index), r)` over `evals`, where `bits(index)` is the
    /// big-endian decomposition of `index` into `num_vars` bits. Costs
    /// `O(evals.len() · num_vars)` field multiplications; with the `parallel` feature the terms
    /// are computed with rayon.
    ///
    /// # Errors
    ///
    /// Returns [`MultilinearError::IncorrectNumberofEvaluationPoints`] if `r.len()` differs from
    /// `num_vars`.
    pub fn evaluate_with(
        num_vars: usize,
        evals: &[(usize, FieldElement<F>)],
        r: &[FieldElement<F>],
    ) -> Result<FieldElement<F>, MultilinearError> {
        // Return the error rather than asserting: a wrong-length point is caller input, not a
        // broken invariant.
        if r.len() != num_vars {
            return Err(MultilinearError::IncorrectNumberofEvaluationPoints(
                r.len(),
                num_vars,
            ));
        }

        #[cfg(feature = "parallel")]
        let iter = evals.into_par_iter();

        #[cfg(not(feature = "parallel"))]
        let iter = evals.iter();

        Ok(iter
            .map(|(index, value)| {
                let bits = get_bits(*index, num_vars);
                // `bits` has exactly `num_vars == r.len()` entries, so the length check inside
                // `compute_chi` cannot fail here.
                Self::compute_chi(&bits, r)
                    .expect("bit decomposition length always equals r.len()")
                    * value
            })
            .sum())
    }
}

/// Returns the big-endian bit decomposition of `n` as `num_bits` booleans.
///
/// Element `0` is bit `num_bits - 1` of `n` (the most significant of the requested bits) and the
/// last element is bit `0`. This is the order in which a sparse evaluation's index is matched
/// against the evaluation point. Positions at or beyond `usize::BITS` are reported as `false`
/// (a `usize` has no such bits) instead of overflowing the shift.
fn get_bits(n: usize, num_bits: usize) -> Vec<bool> {
    let width = usize::BITS as usize;
    (0..num_bits)
        .rev()
        // Short-circuit keeps the shift amount strictly below the integer width.
        .map(|pos| pos < width && (n >> pos) & 1 == 1)
        .collect()
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::field::fields::u64_prime_field::U64PrimeField;
    use alloc::vec;

    const ORDER: u64 = 101;
    type F = U64PrimeField<ORDER>;
    type FE = FieldElement<F>;

    /// Non-zero evaluations of `p(x_1, x_2, x_3) = (x_1 + x_2) * x_3` on the boolean cube.
    ///
    /// The dense table is `[0, 0, 0, 1, 0, 1, 0, 2]`; `x_1` is the most significant bit of the
    /// index, so e.g. index 3 = `0b011` is the point `(0, 1, 1)`.
    fn sparse_evals() -> Vec<(usize, FE)> {
        vec![(3, FE::one()), (5, FE::one()), (7, FE::from(2))]
    }

    /// Asserts that both `evaluate` and `evaluate_with` return `expected` at the point `x`.
    fn assert_evaluates_to(x: &[FE], expected: FE) {
        let evals = sparse_evals();
        let m_poly = SparseMultilinearPolynomial::<F>::new(3, evals.clone());
        assert_eq!(m_poly.evaluate(x).unwrap(), expected);
        assert_eq!(
            SparseMultilinearPolynomial::evaluate_with(3, &evals, x).unwrap(),
            expected
        );
    }

    #[test]
    fn evaluate() {
        // Boolean-cube points reproduce the stored table (and zero for unlisted indices).
        assert_evaluates_to(&[FE::one(), FE::one(), FE::one()], FE::from(2));
        assert_evaluates_to(&[FE::one(), FE::zero(), FE::one()], FE::one());
        assert_evaluates_to(&[FE::zero(), FE::one(), FE::one()], FE::one());
        assert_evaluates_to(&[FE::one(), FE::one(), FE::zero()], FE::zero());
    }

    #[test]
    fn evaluate_off_cube() {
        // p is already multilinear, so its extension equals p itself: (2 + 3) * 4 = 20.
        assert_evaluates_to(&[FE::from(2), FE::from(3), FE::from(4)], FE::from(20));
    }

    #[test]
    fn evaluate_rejects_wrong_arity() {
        let m_poly = SparseMultilinearPolynomial::<F>::new(3, sparse_evals());
        assert_eq!(m_poly.num_vars(), 3);
        assert!(m_poly.evaluate(&[FE::one(), FE::one()]).is_err());
    }
}