// snarkvm_circuit_environment/helpers/linear_combination.rs
use crate::{Mode, *};
use snarkvm_fields::PrimeField;

use core::{
    fmt,
    ops::{Add, AddAssign, Mul, Neg, Sub},
};
/// A linear combination `constant + Σ coefficient_i * variable_i` over the field `F`,
/// carried together with its running concrete `value`.
#[derive(Clone)]
pub struct LinearCombination<F: PrimeField> {
    /// The non-constant terms as `(variable, coefficient)` pairs. The operations in this
    /// module insert via binary search (keeping them ordered by variable) and drop any
    /// term whose coefficient becomes zero.
    terms: Vec<(Variable<F>, F)>,
    /// The constant offset of the linear combination.
    constant: F,
    /// The concrete field value, updated alongside `constant` and `terms` by every
    /// arithmetic operation in this module.
    value: F,
}
impl<F: PrimeField> LinearCombination<F> {
    /// Returns the additive identity: no terms, a zero constant, and a default (zero) value.
    pub(crate) fn zero() -> Self {
        Self { terms: Vec::new(), constant: F::zero(), value: F::default() }
    }

    /// Returns the constant one: no terms, with both the constant and the value set to one.
    pub(crate) fn one() -> Self {
        Self { terms: Vec::new(), constant: F::one(), value: F::one() }
    }

    /// Returns `true` when the linear combination has no variable terms.
    pub fn is_constant(&self) -> bool {
        self.terms.is_empty()
    }

    /// Returns `true` when the linear combination is exactly one public variable
    /// with coefficient one and a zero constant.
    pub fn is_public(&self) -> bool {
        if !self.constant.is_zero() {
            return false;
        }
        match self.terms.as_slice() {
            [(Variable::Public(..), coefficient)] => *coefficient == F::one(),
            _ => false,
        }
    }

    /// Returns `true` when the linear combination is neither constant nor public.
    pub fn is_private(&self) -> bool {
        !(self.is_constant() || self.is_public())
    }

    /// Classifies the linear combination as constant, public, or private.
    pub fn mode(&self) -> Mode {
        match (self.is_constant(), self.is_public()) {
            (true, _) => Mode::Constant,
            (false, true) => Mode::Public,
            (false, false) => Mode::Private,
        }
    }

    /// Returns the concrete field value of the linear combination.
    pub fn value(&self) -> F {
        self.value
    }

    /// Returns `true` if the linear combination is a well-formed `Boolean`.
    ///
    /// Properties checked:
    /// 1. Either the constant or the terms are used, never both.
    /// 2. Every variable in the terms has value `0` or `1`.
    /// 3. The overall value is `0` or `1`.
    ///
    /// Violations of properties 1–3 in the non-constant case are reported on stderr.
    pub fn is_boolean_type(&self) -> bool {
        let is_bit = |x: &F| x.is_zero() || x.is_one();
        match (self.terms.is_empty(), self.constant.is_zero()) {
            // Constant-only: the constant itself must be a bit.
            (true, _) => is_bit(&self.constant),
            // Variable-only: each variable, then the overall value, must be a bit.
            (false, true) => {
                if !self.terms.iter().all(|(variable, _)| is_bit(&variable.value())) {
                    eprintln!("Property 2 of the `Boolean` type was violated in {self}");
                    false
                } else if !is_bit(&self.value) {
                    eprintln!("Property 3 of the `Boolean` type was violated");
                    false
                } else {
                    true
                }
            }
            // Both a nonzero constant and terms are present (property 1).
            (false, false) => {
                eprintln!("Both LC::constant and LC::terms contain elements, which is a violation");
                false
            }
        }
    }

    /// Returns the constant offset.
    pub(super) fn to_constant(&self) -> F {
        self.constant
    }

    /// Returns the `(variable, coefficient)` terms as a slice.
    pub(super) fn to_terms(&self) -> &[(Variable<F>, F)] {
        self.terms.as_slice()
    }

    /// Returns the number of nonzero entries: one per term, plus one for a nonzero constant.
    pub(super) fn num_nonzeros(&self) -> u64 {
        let constant_entry = u64::from(!self.constant.is_zero());
        (self.terms.len() as u64).saturating_add(constant_entry)
    }

    /// Returns the number of additions needed to sum the nonzero entries,
    /// i.e. one fewer than the nonzero count (saturating at zero).
    #[cfg(test)]
    pub(super) fn num_additions(&self) -> u64 {
        self.num_nonzeros().saturating_sub(1)
    }
}
impl<F: PrimeField> From<Variable<F>> for LinearCombination<F> {
    /// Builds a linear combination from a single owned variable.
    fn from(variable: Variable<F>) -> Self {
        Self::from(&variable)
    }
}

impl<F: PrimeField> From<&Variable<F>> for LinearCombination<F> {
    /// Builds a linear combination from a single borrowed variable.
    ///
    /// The variable is viewed as a one-element slice, so no clone is made up front.
    /// The slice conversion clones only if it actually stores the variable as a term.
    fn from(variable: &Variable<F>) -> Self {
        Self::from(core::slice::from_ref(variable))
    }
}

impl<F: PrimeField, const N: usize> From<[Variable<F>; N]> for LinearCombination<F> {
    /// Builds a linear combination summing every variable in the array.
    fn from(variables: [Variable<F>; N]) -> Self {
        Self::from(&variables[..])
    }
}

impl<F: PrimeField, const N: usize> From<&[Variable<F>; N]> for LinearCombination<F> {
    /// Builds a linear combination summing every variable in the borrowed array.
    fn from(variables: &[Variable<F>; N]) -> Self {
        Self::from(&variables[..])
    }
}

impl<F: PrimeField> From<Vec<Variable<F>>> for LinearCombination<F> {
    /// Builds a linear combination summing every variable in the vector.
    fn from(variables: Vec<Variable<F>>) -> Self {
        Self::from(variables.as_slice())
    }
}

impl<F: PrimeField> From<&Vec<Variable<F>>> for LinearCombination<F> {
    /// Builds a linear combination summing every variable in the borrowed vector.
    fn from(variables: &Vec<Variable<F>>) -> Self {
        Self::from(variables.as_slice())
    }
}
impl<F: PrimeField> From<&[Variable<F>]> for LinearCombination<F> {
    /// Sums a slice of variables into a linear combination.
    ///
    /// Constant variables are folded into the constant. Every other variable adds one to
    /// its coefficient: new variables are inserted at their binary-search position, and
    /// any term whose coefficient reaches zero is removed. The running value is always
    /// incremented by the variable's value.
    fn from(variables: &[Variable<F>]) -> Self {
        variables.iter().fold(Self::zero(), |mut lc, variable| {
            if variable.is_constant() {
                lc.constant += variable.value();
            } else {
                let position = lc.terms.binary_search_by(|(existing, _)| existing.cmp(variable));
                match position {
                    Err(slot) => lc.terms.insert(slot, (variable.clone(), F::one())),
                    Ok(slot) => {
                        let coefficient = &mut lc.terms[slot].1;
                        *coefficient += F::one();
                        if coefficient.is_zero() {
                            lc.terms.remove(slot);
                        }
                    }
                }
            }
            lc.value += variable.value();
            lc
        })
    }
}
impl<F: PrimeField> Neg for LinearCombination<F> {
    type Output = Self;

    /// Negates the constant, every coefficient, and the running value.
    #[inline]
    fn neg(mut self) -> Self::Output {
        self.constant = -self.constant;
        for (_, coefficient) in &mut self.terms {
            *coefficient = -*coefficient;
        }
        self.value = -self.value;
        self
    }
}

impl<F: PrimeField> Neg for &LinearCombination<F> {
    type Output = LinearCombination<F>;

    /// Negates a borrowed linear combination by negating a clone.
    #[inline]
    fn neg(self) -> Self::Output {
        let owned = self.clone();
        -owned
    }
}
impl<F: PrimeField> Add<Variable<F>> for LinearCombination<F> {
    type Output = Self;

    /// Adds an owned variable by delegating to the `&Variable` implementation.
    fn add(self, other: Variable<F>) -> Self::Output {
        let borrowed = &other;
        self + borrowed
    }
}

impl<F: PrimeField> Add<&Variable<F>> for LinearCombination<F> {
    type Output = Self;

    /// Lifts the variable into a linear combination and adds it.
    fn add(self, other: &Variable<F>) -> Self::Output {
        let addend = Self::from(other);
        self + addend
    }
}

impl<F: PrimeField> Add<Variable<F>> for &LinearCombination<F> {
    type Output = LinearCombination<F>;

    /// Adds an owned variable to a borrowed linear combination, cloning the left operand.
    fn add(self, other: Variable<F>) -> Self::Output {
        let owned = self.clone();
        let borrowed = &other;
        owned + borrowed
    }
}

impl<F: PrimeField> Add<LinearCombination<F>> for LinearCombination<F> {
    type Output = Self;

    /// Adds two owned linear combinations via the borrowed implementation.
    fn add(self, other: Self) -> Self::Output {
        let borrowed = &other;
        self + borrowed
    }
}

impl<F: PrimeField> Add<&LinearCombination<F>> for LinearCombination<F> {
    type Output = Self;

    /// Adds a borrowed linear combination to an owned one via `&LC + &LC`.
    fn add(self, other: &Self) -> Self::Output {
        let lhs = &self;
        lhs + other
    }
}

impl<F: PrimeField> Add<LinearCombination<F>> for &LinearCombination<F> {
    type Output = LinearCombination<F>;

    /// Adds an owned linear combination to a borrowed one via `&LC + &LC`.
    fn add(self, other: LinearCombination<F>) -> Self::Output {
        let rhs = &other;
        self + rhs
    }
}
impl<F: PrimeField> Add<&LinearCombination<F>> for &LinearCombination<F> {
    type Output = LinearCombination<F>;

    /// Adds two borrowed linear combinations.
    ///
    /// If either operand is empty (zero constant, no terms), a clone of the other is
    /// returned. Otherwise the operand with strictly more terms is cloned and the other
    /// is accumulated into it; on a tie, `other` is the one cloned.
    fn add(self, other: &LinearCombination<F>) -> Self::Output {
        let is_empty = |lc: &LinearCombination<F>| lc.constant.is_zero() && lc.terms.is_empty();

        if is_empty(self) {
            return other.clone();
        }
        if is_empty(other) {
            return self.clone();
        }

        // Cloning the larger operand means fewer terms need to be merged in.
        let (mut accumulator, addend) =
            if self.terms.len() > other.terms.len() { (self.clone(), other) } else { (other.clone(), self) };
        accumulator += addend;
        accumulator
    }
}
impl<F: PrimeField> AddAssign<LinearCombination<F>> for LinearCombination<F> {
    /// Adds an owned linear combination in place via the borrowed implementation.
    fn add_assign(&mut self, other: Self) {
        let borrowed = &other;
        *self += borrowed;
    }
}

impl<F: PrimeField> AddAssign<&LinearCombination<F>> for LinearCombination<F> {
    /// Adds `other` into `self` in place, merging the coefficients of shared variables.
    ///
    /// New variables are inserted at their binary-search position, and any term whose
    /// coefficient cancels to zero is removed.
    ///
    /// # Panics
    /// Panics if a constant variable appears among `other`'s terms.
    fn add_assign(&mut self, other: &Self) {
        // Adding an empty combination (zero constant, no terms) is a no-op.
        if other.terms.is_empty() && other.constant.is_zero() {
            return;
        }
        // An empty `self` simply becomes a copy of `other`.
        if self.terms.is_empty() && self.constant.is_zero() {
            *self = other.clone();
            return;
        }

        self.constant += other.constant;
        for (variable, coefficient) in &other.terms {
            if variable.is_constant() {
                panic!("Malformed linear combination found");
            }
            match self.terms.binary_search_by(|(existing, _)| existing.cmp(variable)) {
                Err(slot) => self.terms.insert(slot, (variable.clone(), *coefficient)),
                Ok(slot) => {
                    let merged = &mut self.terms[slot].1;
                    *merged += *coefficient;
                    if merged.is_zero() {
                        self.terms.remove(slot);
                    }
                }
            }
        }
        self.value += other.value;
    }
}
impl<F: PrimeField> Sub<Variable<F>> for LinearCombination<F> {
    type Output = Self;

    /// Subtracts an owned variable by delegating to the `&Variable` implementation.
    fn sub(self, other: Variable<F>) -> Self::Output {
        let borrowed = &other;
        self - borrowed
    }
}

impl<F: PrimeField> Sub<&Variable<F>> for LinearCombination<F> {
    type Output = Self;

    /// Lifts the variable into a linear combination and subtracts it.
    fn sub(self, other: &Variable<F>) -> Self::Output {
        let subtrahend = Self::from(other);
        self - subtrahend
    }
}

impl<F: PrimeField> Sub<Variable<F>> for &LinearCombination<F> {
    type Output = LinearCombination<F>;

    /// Subtracts an owned variable from a borrowed linear combination, cloning the left operand.
    fn sub(self, other: Variable<F>) -> Self::Output {
        let owned = self.clone();
        let borrowed = &other;
        owned - borrowed
    }
}

impl<F: PrimeField> Sub<LinearCombination<F>> for LinearCombination<F> {
    type Output = Self;

    /// Subtracts two owned linear combinations via the borrowed implementation.
    fn sub(self, other: Self) -> Self::Output {
        let borrowed = &other;
        self - borrowed
    }
}

impl<F: PrimeField> Sub<&LinearCombination<F>> for LinearCombination<F> {
    type Output = Self;

    /// Subtracts a borrowed linear combination from an owned one via `&LC - &LC`.
    fn sub(self, other: &Self) -> Self::Output {
        let lhs = &self;
        lhs - other
    }
}

impl<F: PrimeField> Sub<LinearCombination<F>> for &LinearCombination<F> {
    type Output = LinearCombination<F>;

    /// Subtracts an owned linear combination from a borrowed one via `&LC - &LC`.
    fn sub(self, other: LinearCombination<F>) -> Self::Output {
        let rhs = &other;
        self - rhs
    }
}
impl<F: PrimeField> Sub<&LinearCombination<F>> for &LinearCombination<F> {
    type Output = LinearCombination<F>;

    /// Computes `self - other` as `self + (-other)`.
    fn sub(self, other: &LinearCombination<F>) -> Self::Output {
        let negated_other = -other;
        self + &negated_other
    }
}
impl<F: PrimeField> Mul<F> for LinearCombination<F> {
    type Output = Self;

    /// Scales by an owned field element via the borrowed implementation.
    fn mul(self, coefficient: F) -> Self::Output {
        let scalar = &coefficient;
        self * scalar
    }
}

impl<F: PrimeField> Mul<&F> for LinearCombination<F> {
    type Output = Self;

    /// Scales the constant, every term coefficient, and the running value by `coefficient`.
    ///
    /// Terms whose scaled coefficient becomes zero are dropped in place; the relative
    /// order of the surviving terms is preserved.
    fn mul(mut self, coefficient: &F) -> Self::Output {
        self.constant *= coefficient;
        self.terms.retain_mut(|(_, term_coefficient)| {
            *term_coefficient *= coefficient;
            !term_coefficient.is_zero()
        });
        self.value *= coefficient;
        self
    }
}

impl<F: PrimeField> Mul<F> for &LinearCombination<F> {
    type Output = LinearCombination<F>;

    /// Scales a clone of the borrowed linear combination by an owned field element.
    fn mul(self, coefficient: F) -> Self::Output {
        self.clone() * coefficient
    }
}

impl<F: PrimeField> Mul<&F> for &LinearCombination<F> {
    type Output = LinearCombination<F>;

    /// Scales a clone of the borrowed linear combination by a borrowed field element.
    fn mul(self, coefficient: &F) -> Self::Output {
        let owned: LinearCombination<F> = self.clone();
        owned * coefficient
    }
}
impl<F: PrimeField> fmt::Debug for LinearCombination<F> {
    /// Formats as `Constant(c) + [k * ]Var + ...`, omitting the coefficient when it is one.
    ///
    /// Output is written straight into the formatter, rather than being built in an
    /// intermediate `String` with a fresh `format!` allocation per term.
    ///
    /// # Panics
    /// Panics if any term's variable has `Mode::Constant`. Constants belong in the
    /// `constant` field, so such a term means the linear combination is malformed.
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        write!(f, "Constant({})", self.constant)?;
        for (variable, coefficient) in &self.terms {
            match (variable.mode(), coefficient.is_one()) {
                (Mode::Constant, _) => panic!("Malformed linear combination at: ({coefficient} * {variable:?})"),
                (_, true) => write!(f, " + {variable:?}")?,
                _ => write!(f, " + {coefficient} * {variable:?}")?,
            }
        }
        Ok(())
    }
}
impl<F: PrimeField> fmt::Display for LinearCombination<F> {
    /// Displays only the concrete field value of the linear combination.
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        let value = self.value;
        write!(f, "{value}")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_fields::{One as O, Zero as Z};
    use std::rc::Rc;

    /// `LinearCombination::zero()` has a zero constant, no terms, and a zero value.
    #[test]
    fn test_zero() {
        let zero = <Circuit as Environment>::BaseField::zero();
        let candidate = LinearCombination::zero();
        assert_eq!(zero, candidate.constant);
        assert!(candidate.terms.is_empty());
        assert_eq!(zero, candidate.value());
    }

    /// `LinearCombination::one()` has a constant of one, no terms, and a value of one.
    #[test]
    fn test_one() {
        let one = <Circuit as Environment>::BaseField::one();
        let candidate = LinearCombination::one();
        assert_eq!(one, candidate.constant);
        assert!(candidate.terms.is_empty());
        assert_eq!(one, candidate.value());
    }

    /// Adding two constant combinations sums their constants and values without creating terms.
    #[test]
    fn test_two() {
        let one = <Circuit as Environment>::BaseField::one();
        let two = one + one;
        let candidate = LinearCombination::one() + LinearCombination::one();
        assert_eq!(two, candidate.constant);
        assert!(candidate.terms.is_empty());
        assert_eq!(two, candidate.value());
    }

    /// Both `zero()` and `one()` are reported as constant.
    #[test]
    fn test_is_constant() {
        let zero = <Circuit as Environment>::BaseField::zero();
        let one = <Circuit as Environment>::BaseField::one();
        let candidate = LinearCombination::zero();
        assert!(candidate.is_constant());
        assert_eq!(zero, candidate.constant);
        assert_eq!(zero, candidate.value());
        let candidate = LinearCombination::one();
        assert!(candidate.is_constant());
        assert_eq!(one, candidate.constant);
        assert_eq!(one, candidate.value());
    }

    /// Scaling a single public variable by four scales its coefficient and the value,
    /// leaves the constant at zero, and does not touch the variable itself.
    #[test]
    fn test_mul() {
        let zero = <Circuit as Environment>::BaseField::zero();
        let one = <Circuit as Environment>::BaseField::one();
        let two = one + one;
        let four = two + two;
        let start = LinearCombination::from(Variable::Public(Rc::new((1, one))));
        assert!(!start.is_constant());
        assert_eq!(one, start.value());
        let candidate = start * four;
        assert_eq!(four, candidate.value());
        assert_eq!(zero, candidate.constant);
        assert_eq!(1, candidate.terms.len());
        let (candidate_variable, candidate_coefficient) = candidate.terms.first().unwrap();
        assert!(candidate_variable.is_public());
        assert_eq!(one, candidate_variable.value());
        assert_eq!(four, *candidate_coefficient);
    }

    /// The `Debug` rendering is the same no matter the order in which the same terms are added.
    /// The expected strings depend on the indices `Circuit` assigns to the two variables below.
    #[test]
    fn test_debug() {
        let one_public = &Circuit::new_variable(Mode::Public, <Circuit as Environment>::BaseField::one());
        let one_private = &Circuit::new_variable(Mode::Private, <Circuit as Environment>::BaseField::one());
        // One of each variable: every coefficient is one and is therefore omitted.
        {
            let expected = "Constant(1) + Public(1, 1) + Private(0, 1)";
            let candidate = LinearCombination::one() + one_public + one_private;
            assert_eq!(expected, format!("{candidate:?}"));
            let candidate = one_private + one_public + LinearCombination::one();
            assert_eq!(expected, format!("{candidate:?}"));
            let candidate = one_private + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));
            let candidate = one_public + LinearCombination::one() + one_private;
            assert_eq!(expected, format!("{candidate:?}"));
        }
        // The public variable is added twice, so its coefficient is printed as 2.
        {
            let expected = "Constant(1) + 2 * Public(1, 1) + Private(0, 1)";
            let candidate = LinearCombination::one() + one_public + one_public + one_private;
            assert_eq!(expected, format!("{candidate:?}"));
            let candidate = one_private + one_public + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));
            let candidate = one_public + one_private + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));
            let candidate = one_public + LinearCombination::one() + one_private + one_public;
            assert_eq!(expected, format!("{candidate:?}"));
        }
        // The private variable is added twice, so its coefficient is printed as 2.
        {
            let expected = "Constant(1) + Public(1, 1) + 2 * Private(0, 1)";
            let candidate = LinearCombination::one() + one_public + one_private + one_private;
            assert_eq!(expected, format!("{candidate:?}"));
            let candidate = one_private + one_public + LinearCombination::one() + one_private;
            assert_eq!(expected, format!("{candidate:?}"));
            let candidate = one_private + one_private + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));
            let candidate = one_public + LinearCombination::one() + one_private + one_private;
            assert_eq!(expected, format!("{candidate:?}"));
        }
        // Adding then subtracting the private variable cancels it, so its term disappears.
        {
            let expected = "Constant(1) + Public(1, 1)";
            let candidate = LinearCombination::one() + one_public + one_private - one_private;
            assert_eq!(expected, format!("{candidate:?}"));
            let candidate = one_private + one_public + LinearCombination::one() - one_private;
            assert_eq!(expected, format!("{candidate:?}"));
            let candidate = one_private - one_private + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));
            let candidate = one_public + LinearCombination::one() + one_private - one_private;
            assert_eq!(expected, format!("{candidate:?}"));
        }
    }

    /// `num_additions` is the number of nonzero entries (terms plus a nonzero constant)
    /// minus one, saturating at zero. Repeated variables merge into one term, and
    /// cancelled variables remove their term entirely.
    #[rustfmt::skip]
    #[test]
    fn test_num_additions() {
        let one_public = &Circuit::new_variable(Mode::Public, <Circuit as Environment>::BaseField::one());
        let one_private = &Circuit::new_variable(Mode::Private, <Circuit as Environment>::BaseField::one());
        let two_private = one_private + one_private;
        // Constant-only combinations need no additions.
        let candidate = LinearCombination::<<Circuit as Environment>::BaseField>::zero();
        assert_eq!(0, candidate.num_additions());
        let candidate = LinearCombination::<<Circuit as Environment>::BaseField>::one();
        assert_eq!(0, candidate.num_additions());
        // Additions only.
        let candidate = LinearCombination::zero() + one_public;
        assert_eq!(0, candidate.num_additions());
        let candidate = LinearCombination::one() + one_public;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::zero() + one_public + one_public;
        assert_eq!(0, candidate.num_additions());
        let candidate = LinearCombination::one() + one_public + one_public;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::zero() + one_public + one_private;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::one() + one_public + one_private;
        assert_eq!(2, candidate.num_additions());
        let candidate = LinearCombination::zero() + one_public + one_private + one_public;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::one() + one_public + one_private + one_public;
        assert_eq!(2, candidate.num_additions());
        let candidate = LinearCombination::zero() + one_public + one_private + one_public + one_private;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::one() + one_public + one_private + one_public + one_private;
        assert_eq!(2, candidate.num_additions());
        let candidate = LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public + one_private;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public + one_private;
        assert_eq!(2, candidate.num_additions());
        let candidate = LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public + one_private;
        assert_eq!(2, candidate.num_additions());
        let candidate = LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public + one_private + &two_private;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public + one_private + &two_private;
        assert_eq!(2, candidate.num_additions());
        let candidate = LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public + one_private + &two_private;
        assert_eq!(2, candidate.num_additions());
        // Subtractions: cancelled variables no longer count as terms.
        let candidate = LinearCombination::zero() - one_public;
        assert_eq!(0, candidate.num_additions());
        let candidate = LinearCombination::one() - one_public;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::zero() + one_public - one_public;
        assert_eq!(0, candidate.num_additions());
        let candidate = LinearCombination::one() + one_public - one_public;
        assert_eq!(0, candidate.num_additions());
        let candidate = LinearCombination::zero() + one_public - one_private;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::one() + one_public - one_private;
        assert_eq!(2, candidate.num_additions());
        let candidate = LinearCombination::zero() + one_public + one_private - one_public;
        assert_eq!(0, candidate.num_additions());
        let candidate = LinearCombination::one() + one_public + one_private - one_public;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::zero() + one_public + one_private + one_public - one_private;
        assert_eq!(0, candidate.num_additions());
        let candidate = LinearCombination::one() + one_public + one_private + one_public - one_private;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public - one_private;
        assert_eq!(0, candidate.num_additions());
        let candidate = LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public - one_private;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public - one_private;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public + one_private - &two_private;
        assert_eq!(0, candidate.num_additions());
        let candidate = LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public + one_private - &two_private;
        assert_eq!(1, candidate.num_additions());
        let candidate = LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public + one_private - &two_private;
        assert_eq!(1, candidate.num_additions());
    }
}