use crate::{
evaluations::multivariate::multilinear::{swap_bits, MultilinearExtension},
Polynomial,
};
use ark_ff::{Field, Zero};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::{
fmt,
fmt::Formatter,
iter::IntoIterator,
log2,
ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub, SubAssign},
rand::Rng,
slice::{Iter, IterMut},
vec::*,
};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// A multilinear polynomial stored densely as a table of evaluations, one
/// field element per entry of `evaluations`, together with its variable count.
///
/// The constructors on this type require `evaluations.len() == 2^num_vars`.
/// Both fields are public and `Default` is derived, so that relation is not
/// enforced for values built another way (for example `Default::default()`
/// yields `num_vars == 0` with an empty `evaluations`).
#[derive(Clone, PartialEq, Eq, Hash, Default, CanonicalSerialize, CanonicalDeserialize)]
pub struct DenseMultilinearExtension<F: Field> {
/// Evaluation table. `fix_variables` folds entries `2b` and `2b + 1` together
/// for the first variable, so that variable corresponds to the lowest index bit.
pub evaluations: Vec<F>,
/// Number of variables of the polynomial.
pub num_vars: usize,
}
impl<F: Field> DenseMultilinearExtension<F> {
    /// Builds a polynomial from a borrowed evaluation table, copying it.
    ///
    /// # Panics
    /// Panics if `evaluations.len() != 2^num_vars`.
    pub fn from_evaluations_slice(num_vars: usize, evaluations: &[F]) -> Self {
        Self::from_evaluations_vec(num_vars, evaluations.to_vec())
    }

    /// Builds a polynomial that takes ownership of `evaluations`.
    ///
    /// # Panics
    /// Panics if `evaluations.len() != 2^num_vars`.
    pub fn from_evaluations_vec(num_vars: usize, evaluations: Vec<F>) -> Self {
        assert_eq!(
            evaluations.len(),
            1 << num_vars,
            "The size of evaluations should be 2^num_vars."
        );
        Self {
            num_vars,
            evaluations,
        }
    }

    /// Exchanges, in place, the `k` consecutive variables starting at index
    /// `a` with the `k` consecutive variables starting at index `b`.
    ///
    /// The order of `a` and `b` does not matter. The call is a no-op when
    /// `a == b` or `k == 0`.
    ///
    /// # Panics
    /// Panics if the higher window extends past `num_vars`, or if the two
    /// windows overlap.
    pub fn relabel_in_place(&mut self, mut a: usize, mut b: usize, k: usize) {
        // Normalise so that `a` is the lower window.
        if a > b {
            ark_std::mem::swap(&mut a, &mut b);
        }
        if a == b || k == 0 {
            return;
        }
        assert!(b + k <= self.num_vars, "invalid relabel argument");
        assert!(a + k <= b, "overlapped swap window is not allowed");
        for i in 0..self.evaluations.len() {
            let j = swap_bits(i, a, b, k);
            // Only act on the smaller index of each pair so that an exchange
            // is performed once and not undone when the loop reaches `j`.
            if i < j {
                self.evaluations.swap(i, j);
            }
        }
    }

    /// Returns an iterator over the evaluations, in table order.
    pub fn iter(&self) -> Iter<'_, F> {
        self.evaluations.iter()
    }

    /// Returns a mutable iterator over the evaluations, in table order.
    pub fn iter_mut(&mut self) -> IterMut<'_, F> {
        self.evaluations.iter_mut()
    }

    /// Concatenates the evaluation tables of `polys`, in iteration order, into
    /// a single polynomial.
    ///
    /// The combined table is padded with zeros up to the next power of two,
    /// and the resulting number of variables is the base-2 logarithm of that
    /// padded length. An empty input yields a single zero evaluation.
    ///
    /// `polys` is iterated twice (once to size the buffer, once to fill it),
    /// which is why it must be `Clone`.
    pub fn concat(polys: impl IntoIterator<Item = impl AsRef<Self>> + Clone) -> Self {
        let total_len: usize = polys
            .clone()
            .into_iter()
            .map(|poly| poly.as_ref().evaluations.len())
            .sum();
        let padded_len = total_len.next_power_of_two();
        let num_vars = log2(padded_len);

        let mut evaluations: Vec<F> = Vec::with_capacity(padded_len);
        for poly in polys {
            evaluations.extend_from_slice(&poly.as_ref().evaluations);
        }
        evaluations.resize(padded_len, F::zero());

        // Hand the buffer over directly; going through the slice constructor
        // would copy the whole table a second time.
        Self::from_evaluations_vec(num_vars as usize, evaluations)
    }
}
/// Identity `AsRef`, so that APIs taking `impl AsRef<DenseMultilinearExtension<F>>`
/// (such as `concat`) accept a polynomial directly.
impl<F: Field> AsRef<DenseMultilinearExtension<F>> for DenseMultilinearExtension<F> {
    fn as_ref(&self) -> &Self {
        self
    }
}
impl<F: Field> MultilinearExtension<F> for DenseMultilinearExtension<F> {
    /// Returns the number of variables of this polynomial.
    fn num_vars(&self) -> usize {
        self.num_vars
    }

    /// Samples a polynomial with `num_vars` variables whose `2^num_vars`
    /// evaluations are drawn from `rng`.
    fn rand<R: Rng>(num_vars: usize, rng: &mut R) -> Self {
        Self::from_evaluations_vec(
            num_vars,
            // `1usize` keeps the range bound in `usize`; an untyped literal
            // here would fall back to `i32`.
            (0..(1usize << num_vars)).map(|_| F::rand(rng)).collect(),
        )
    }

    /// Returns a copy of this polynomial with the `k`-variable windows at
    /// `a` and `b` exchanged; see `relabel_in_place`.
    fn relabel(&self, a: usize, b: usize, k: usize) -> Self {
        let mut copied = self.clone();
        copied.relabel_in_place(a, b, k);
        copied
    }

    /// Fixes the first `partial_point.len()` variables to the given values and
    /// returns the polynomial in the remaining variables.
    ///
    /// # Panics
    /// Panics if `partial_point` has more entries than there are variables.
    fn fix_variables(&self, partial_point: &[F]) -> Self {
        assert!(
            partial_point.len() <= self.num_vars,
            "invalid size of partial point"
        );
        let nv = self.num_vars;
        let dim = partial_point.len();
        let mut poly = self.evaluations.to_vec();
        // Each round folds adjacent pairs, halving the live prefix of `poly`.
        for (round, &r) in partial_point.iter().enumerate() {
            let half = 1 << (nv - round - 1);
            for b in 0..half {
                let left = poly[b << 1];
                let right = poly[(b << 1) + 1];
                // Linear interpolation: (1 - r) * left + r * right.
                poly[b] = left + r * (right - left);
            }
        }
        // Reuse the working buffer instead of copying the prefix again.
        poly.truncate(1 << (nv - dim));
        Self::from_evaluations_vec(nv - dim, poly)
    }

    /// Returns a copy of the evaluation table.
    fn to_evaluations(&self) -> Vec<F> {
        self.evaluations.clone()
    }
}
/// Indexes directly into the evaluation table.
impl<F: Field> Index<usize> for DenseMultilinearExtension<F> {
    type Output = F;

    /// Returns the evaluation stored at `index`.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds for the evaluation table.
    fn index(&self, index: usize) -> &Self::Output {
        let Self { evaluations, .. } = self;
        &evaluations[index]
    }
}
/// Owned addition; delegates to the by-reference implementation.
impl<F: Field> Add for DenseMultilinearExtension<F> {
    type Output = DenseMultilinearExtension<F>;

    fn add(self, other: DenseMultilinearExtension<F>) -> Self {
        Add::add(&self, &other)
    }
}
impl<'a, 'b, F: Field> Add<&'a DenseMultilinearExtension<F>> for &'b DenseMultilinearExtension<F> {
type Output = DenseMultilinearExtension<F>;
fn add(self, rhs: &'a DenseMultilinearExtension<F>) -> Self::Output {
if rhs.is_zero() {
return self.clone();
}
if self.is_zero() {
return rhs.clone();
}
assert_eq!(self.num_vars, rhs.num_vars);
let result: Vec<F> = cfg_iter!(self.evaluations)
.zip(cfg_iter!(rhs.evaluations))
.map(|(a, b)| *a + *b)
.collect();
Self::Output::from_evaluations_vec(self.num_vars, result)
}
}
/// `self += other` for an owned right-hand side.
impl<F: Field> AddAssign for DenseMultilinearExtension<F> {
    fn add_assign(&mut self, other: Self) {
        let sum = &*self + &other;
        *self = sum;
    }
}
/// `self += &other` for a borrowed right-hand side.
impl<'a, F: Field> AddAssign<&'a DenseMultilinearExtension<F>> for DenseMultilinearExtension<F> {
    fn add_assign(&mut self, other: &'a DenseMultilinearExtension<F>) {
        let sum = &*self + other;
        *self = sum;
    }
}
/// `self += (f, &other)`: adds `f * other` to `self`.
impl<'a, F: Field> AddAssign<(F, &'a DenseMultilinearExtension<F>)>
    for DenseMultilinearExtension<F>
{
    fn add_assign(&mut self, (f, other): (F, &'a DenseMultilinearExtension<F>)) {
        // Scale a copy of `other` first, then reuse the ordinary addition.
        let scaled_evaluations = cfg_iter!(other.evaluations)
            .map(|eval| f * eval)
            .collect();
        let scaled = Self {
            num_vars: other.num_vars,
            evaluations: scaled_evaluations,
        };
        *self = &*self + &scaled;
    }
}
impl<F: Field> Neg for DenseMultilinearExtension<F> {
type Output = DenseMultilinearExtension<F>;
fn neg(self) -> Self::Output {
Self::Output {
num_vars: self.num_vars,
evaluations: cfg_iter!(self.evaluations).map(|x| -*x).collect(),
}
}
}
/// Owned subtraction; delegates to the by-reference implementation.
impl<F: Field> Sub for DenseMultilinearExtension<F> {
    type Output = DenseMultilinearExtension<F>;

    fn sub(self, other: DenseMultilinearExtension<F>) -> Self {
        Sub::sub(&self, &other)
    }
}
/// Subtraction by reference, computed as `self + (-rhs)`.
impl<'a, 'b, F: Field> Sub<&'a DenseMultilinearExtension<F>> for &'b DenseMultilinearExtension<F> {
    type Output = DenseMultilinearExtension<F>;

    fn sub(self, rhs: &'a DenseMultilinearExtension<F>) -> Self::Output {
        // `Neg` consumes its operand, so negate a clone of `rhs`.
        let negated_rhs = -rhs.clone();
        self + &negated_rhs
    }
}
/// `self -= other` for an owned right-hand side.
impl<F: Field> SubAssign for DenseMultilinearExtension<F> {
    fn sub_assign(&mut self, other: Self) {
        let difference = &*self - &other;
        *self = difference;
    }
}
/// `self -= &other` for a borrowed right-hand side.
impl<'a, F: Field> SubAssign<&'a DenseMultilinearExtension<F>> for DenseMultilinearExtension<F> {
    fn sub_assign(&mut self, other: &'a DenseMultilinearExtension<F>) {
        let difference = &*self - other;
        *self = difference;
    }
}
/// Owned scalar multiplication; delegates to the by-reference implementation.
impl<F: Field> Mul<F> for DenseMultilinearExtension<F> {
    type Output = DenseMultilinearExtension<F>;

    fn mul(self, scalar: F) -> Self::Output {
        Mul::mul(&self, &scalar)
    }
}
impl<'a, 'b, F: Field> Mul<&'a F> for &'b DenseMultilinearExtension<F> {
type Output = DenseMultilinearExtension<F>;
fn mul(self, scalar: &'a F) -> Self::Output {
if scalar.is_zero() {
return DenseMultilinearExtension::zero();
} else if scalar.is_one() {
return self.clone();
}
let result: Vec<F> = self.evaluations.iter().map(|&x| x * scalar).collect();
DenseMultilinearExtension {
num_vars: self.num_vars,
evaluations: result,
}
}
}
/// `self *= scalar` for an owned scalar.
impl<F: Field> MulAssign<F> for DenseMultilinearExtension<F> {
    fn mul_assign(&mut self, scalar: F) {
        let product = &*self * &scalar;
        *self = product;
    }
}
/// `self *= &scalar` for a borrowed scalar.
impl<'a, F: Field> MulAssign<&'a F> for DenseMultilinearExtension<F> {
    fn mul_assign(&mut self, scalar: &'a F) {
        let product = &*self * scalar;
        *self = product;
    }
}
/// Compact debug output: the variable count followed by a short preview of
/// the evaluation table.
impl<F: Field> fmt::Debug for DenseMultilinearExtension<F> {
    /// Writes `DenseML(nv = N, evaluations = [e0 e1 ... ])`, showing at most
    /// the first four evaluations and an ellipsis only when some are omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
        const PREVIEW_LEN: usize = 4;
        write!(f, "DenseML(nv = {}, evaluations = [", self.num_vars)?;
        for eval in self.evaluations.iter().take(PREVIEW_LEN) {
            write!(f, "{:?} ", eval)?;
        }
        // `<=`, not `<`: a table of exactly PREVIEW_LEN entries is shown in
        // full, so it must not be followed by "...".
        if self.evaluations.len() <= PREVIEW_LEN {
            write!(f, "])")
        } else {
            write!(f, "...])")
        }
    }
}
/// The zero polynomial is represented with no variables and a single zero
/// evaluation.
impl<F: Field> Zero for DenseMultilinearExtension<F> {
    /// Returns the canonical zero polynomial (`num_vars == 0`, one zero
    /// evaluation).
    fn zero() -> Self {
        Self {
            num_vars: 0,
            evaluations: vec![F::zero()],
        }
    }

    /// Returns `true` only for a zero-variable polynomial whose evaluation is
    /// zero. A polynomial with more variables is never reported as zero, even
    /// if all of its evaluations vanish.
    ///
    /// An empty evaluation table (as produced by the derived `Default`) is
    /// treated as zero instead of panicking on an out-of-bounds index.
    fn is_zero(&self) -> bool {
        self.num_vars == 0 && self.evaluations.first().map_or(true, |eval| eval.is_zero())
    }
}
/// `Polynomial` view of a dense multilinear extension; points are vectors
/// with one coordinate per variable.
impl<F: Field> Polynomial<F> for DenseMultilinearExtension<F> {
    type Point = Vec<F>;

    /// Reports the number of variables as the degree.
    fn degree(&self) -> usize {
        self.num_vars
    }

    /// Evaluates the polynomial at `point` by fixing every variable.
    ///
    /// # Panics
    /// Panics if `point.len() != self.num_vars`.
    fn evaluate(&self, point: &Self::Point) -> F {
        assert!(point.len() == self.num_vars);
        // With all variables fixed, a single evaluation remains.
        let fully_fixed = self.fix_variables(point.as_slice());
        fully_fixed[0]
    }
}
#[cfg(test)]
mod tests {
use crate::{DenseMultilinearExtension, MultilinearExtension, Polynomial};
use ark_ff::{Field, One, Zero};
use ark_std::{ops::Neg, test_rng, vec::*, UniformRand};
use ark_test_curves::bls12_381::Fr;
/// Reference evaluator used by the tests: folds `data` one variable at a time
/// using the `(1 - r) * left + r * right` form.
///
/// Panics if `data.len() != 2^point.len()`.
fn evaluate_data_array<F: Field>(data: &[F], point: &[F]) -> F {
    let nv = point.len();
    if data.len() != (1 << nv) {
        panic!("Data size mismatch with number of variables. ")
    }
    let mut table = data.to_vec();
    for (round, &r) in point.iter().enumerate() {
        // Number of entries that remain live after this round.
        let half = 1 << (nv - round - 1);
        for b in 0..half {
            table[b] = table[b << 1] * (F::one() - r) + table[(b << 1) + 1] * r;
        }
    }
    table[0]
}
/// `evaluate` must agree with the reference evaluator at random points.
#[test]
fn evaluate_at_a_point() {
    const NV: usize = 10;
    const TRIALS: usize = 10;
    let mut rng = test_rng();
    let poly = DenseMultilinearExtension::rand(NV, &mut rng);
    for _ in 0..TRIALS {
        let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
        let reference = evaluate_data_array(&poly.evaluations, &point);
        assert_eq!(reference, poly.evaluate(&point));
    }
}
/// Relabelling the polynomial and applying the matching coordinate swaps to
/// the point must leave the evaluation unchanged.
#[test]
fn relabel_polynomial() {
    // `(a, b, k)` arguments, applied cumulatively in this order. Each step
    // exchanges the `k`-wide windows starting at `a` and `b`.
    const STEPS: [(usize, usize, usize); 6] = [
        (2, 2, 1), // identical windows: must be a no-op
        (3, 4, 1),
        (7, 5, 1), // arguments given in descending order
        (2, 5, 3),
        (7, 0, 2),
        (0, 9, 1),
    ];
    let mut rng = test_rng();
    for _ in 0..20 {
        let mut poly = DenseMultilinearExtension::rand(10, &mut rng);
        let mut point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();
        let expected = poly.evaluate(&point);
        for &(a, b, k) in STEPS.iter() {
            poly.relabel_in_place(a, b, k);
            // Mirror the relabelling on the point, one coordinate at a time.
            for offset in 0..k {
                point.swap(a + offset, b + offset);
            }
            assert_eq!(expected, poly.evaluate(&point));
        }
    }
}
/// Every arithmetic operator must commute with evaluation at a random point.
#[test]
fn arithmetic() {
    const NV: usize = 10;
    let mut rng = test_rng();
    for _ in 0..20 {
        let scalar = Fr::rand(&mut rng);
        let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
        let poly1 = DenseMultilinearExtension::rand(NV, &mut rng);
        let poly2 = DenseMultilinearExtension::rand(NV, &mut rng);
        let v1 = poly1.evaluate(&point);
        let v2 = poly2.evaluate(&point);

        // Operators that produce a new polynomial.
        let sum = &poly1 + &poly2;
        assert_eq!(sum.evaluate(&point), v1 + v2);
        let difference = &poly1 - &poly2;
        assert_eq!(difference.evaluate(&point), v1 - v2);
        let negated = poly1.clone().neg();
        assert_eq!(negated.evaluate(&point), -v1);
        let scaled = &poly1 * &scalar;
        assert_eq!(scaled.evaluate(&point), v1 * scalar);

        // `+=` with a borrowed polynomial.
        let mut acc = poly1.clone();
        acc += &poly2;
        assert_eq!(acc.evaluate(&point), v1 + v2);

        // `-=` with a borrowed polynomial.
        let mut acc = poly1.clone();
        acc -= &poly2;
        assert_eq!(acc.evaluate(&point), v1 - v2);

        // `+=` with a `(coefficient, polynomial)` pair.
        let mut acc = poly1.clone();
        let coeff = Fr::rand(&mut rng);
        acc += (coeff, &poly2);
        assert_eq!(acc.evaluate(&point), v1 + coeff * v2);

        // The zero polynomial is neutral on either side of `+` and in `+=`.
        assert_eq!(&poly1 + &DenseMultilinearExtension::zero(), poly1);
        assert_eq!(&DenseMultilinearExtension::zero() + &poly1, poly1);
        let mut unchanged = poly1.clone();
        unchanged += &DenseMultilinearExtension::zero();
        assert_eq!(&unchanged, &poly1);

        // Scaled accumulation into a zero polynomial.
        let mut zero = DenseMultilinearExtension::zero();
        let coeff = Fr::rand(&mut rng);
        zero += (coeff, &poly1);
        assert_eq!(zero.evaluate(&point), coeff * v1);

        // `*=` by one, by a random scalar, and by zero.
        let mut product = poly1.clone();
        product *= Fr::one();
        assert_eq!(product.evaluate(&point), v1);
        product *= scalar;
        assert_eq!(product.evaluate(&point), v1 * scalar);
        product *= Fr::zero();
        assert_eq!(product, DenseMultilinearExtension::zero());
    }
}
/// Concatenating two polynomials of equal size adds one variable that
/// selects between them.
#[test]
fn concat_two_equal_polys() {
    let mut rng = test_rng();
    let degree = 10;
    let poly_l = DenseMultilinearExtension::rand(degree, &mut rng);
    let poly_r = DenseMultilinearExtension::rand(degree, &mut rng);
    let merged = DenseMultilinearExtension::concat(&[&poly_l, &poly_r]);
    for _ in 0..10 {
        let point: Vec<_> = (0..(degree + 1)).map(|_| Fr::rand(&mut rng)).collect();
        // The last coordinate selects the half; the rest go to that half.
        let selector = point[degree];
        let inner = point[..degree].to_vec();
        let left_part = (Fr::ONE - selector) * poly_l.evaluate(&inner);
        let right_part = selector * poly_r.evaluate(&inner);
        let expected = left_part + right_part;
        assert_eq!(expected, merged.evaluate(&point));
    }
}
/// Concatenating polynomials of different sizes zero-pads the table, so the
/// smaller polynomial occupies only half of the upper half.
#[test]
fn concat_unequal_polys() {
    let mut rng = test_rng();
    let degree = 10;
    let poly_l = DenseMultilinearExtension::rand(degree, &mut rng);
    let poly_r = DenseMultilinearExtension::rand(degree - 1, &mut rng);
    let merged = DenseMultilinearExtension::concat(&[&poly_l, &poly_r]);
    for _ in 0..10 {
        let point: Vec<_> = (0..(degree + 1)).map(|_| Fr::rand(&mut rng)).collect();
        let top = point[degree];
        let second = point[degree - 1];
        let left_part = (Fr::ONE - top) * poly_l.evaluate(&point[..degree].to_vec());
        // The padding contributes nothing, leaving only the `(1 - second)` term.
        let right_inner = (Fr::ONE - second) * poly_r.evaluate(&point[..degree - 1].to_vec());
        let expected = left_part + top * right_inner;
        assert_eq!(expected, merged.evaluate(&point));
    }
}
/// `concat` must accept a chained iterator of polynomials and lay them out in
/// iteration order.
#[test]
fn concat_two_iterators() {
    let mut rng = test_rng();
    let degree = 10;
    // Draw every polynomial from the shared `rng`. Creating a fresh
    // `test_rng()` per polynomial can give four identical inputs when that
    // generator is deterministically seeded, and then the test could not
    // detect `concat` placing them in the wrong order.
    let polys_l: Vec<_> = (0..2)
        .map(|_| DenseMultilinearExtension::rand(degree - 2, &mut rng))
        .collect();
    let polys_r: Vec<_> = (0..2)
        .map(|_| DenseMultilinearExtension::rand(degree - 2, &mut rng))
        .collect();
    let merged = DenseMultilinearExtension::<Fr>::concat(polys_l.iter().chain(polys_r.iter()));
    for _ in 0..10 {
        let point: Vec<_> = (0..(degree)).map(|_| Fr::rand(&mut rng)).collect();
        // `point[9]` picks the left or right pair; `point[8]` picks within it.
        let expected = (Fr::ONE - point[9])
            * ((Fr::ONE - point[8]) * polys_l[0].evaluate(&point[..8].to_vec())
                + point[8] * polys_l[1].evaluate(&point[..8].to_vec()))
            + point[9]
                * ((Fr::ONE - point[8]) * polys_r[0].evaluate(&point[..8].to_vec())
                    + point[8] * polys_r[1].evaluate(&point[..8].to_vec()));
        assert_eq!(expected, merged.evaluate(&point));
    }
}
}