snarkvm_circuit_types_scalar/
add.rsuse super::*;
impl<E: Environment> Add<Scalar<E>> for Scalar<E> {
type Output = Scalar<E>;
fn add(self, other: Scalar<E>) -> Self::Output {
self + &other
}
}
impl<E: Environment> Add<Scalar<E>> for &Scalar<E> {
type Output = Scalar<E>;
fn add(self, other: Scalar<E>) -> Self::Output {
self + &other
}
}
impl<E: Environment> Add<&Scalar<E>> for Scalar<E> {
type Output = Scalar<E>;
fn add(self, other: &Scalar<E>) -> Self::Output {
&self + other
}
}
impl<E: Environment> Add<&Scalar<E>> for &Scalar<E> {
type Output = Scalar<E>;
fn add(self, other: &Scalar<E>) -> Self::Output {
let mut result = self.clone();
result += other;
result
}
}
impl<E: Environment> AddAssign<Scalar<E>> for Scalar<E> {
fn add_assign(&mut self, other: Scalar<E>) {
*self += &other;
}
}
impl<E: Environment> AddAssign<&Scalar<E>> for Scalar<E> {
fn add_assign(&mut self, other: &Scalar<E>) {
if self.is_constant() && other.is_constant() {
*self = witness!(|self, other| self + other);
} else {
let sum = self.to_field() + other.to_field();
let bits_le = sum.to_lower_bits_le(E::ScalarField::size_in_bits() + 1);
let sum = Field::from_bits_le(&bits_le);
let modulus =
Field::constant(match console::FromBits::from_bits_le(&E::ScalarField::modulus().to_bits_le()) {
Ok(modulus) => modulus,
Err(error) => E::halt(format!("Failed to retrieve the scalar modulus as bytes: {error}")),
});
let wrapping_sum = Ternary::ternary(&sum.is_less_than(&modulus), &sum, &(&sum - &modulus));
let bits_le = wrapping_sum.to_lower_bits_le(console::Scalar::<E::Network>::size_in_bits());
*self = Scalar { field: wrapping_sum, bits_le: OnceCell::with_value(bits_le) };
}
}
}
impl<E: Environment> Metrics<dyn Add<Scalar<E>, Output = Scalar<E>>> for Scalar<E> {
type Case = (Mode, Mode);
fn count(case: &Self::Case) -> Count {
match (case.0, case.1) {
(Mode::Constant, Mode::Constant) => Count::is(1, 0, 0, 0),
(_, _) => Count::is(1, 0, 755, 757),
}
}
}
impl<E: Environment> OutputMode<dyn Add<Scalar<E>, Output = Scalar<E>>> for Scalar<E> {
type Case = (Mode, Mode);
fn output_mode(case: &Self::Case) -> Mode {
match (case.0, case.1) {
(Mode::Constant, Mode::Constant) => Mode::Constant,
(_, _) => Mode::Private,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use snarkvm_circuit_environment::Circuit;
const ITERATIONS: u64 = 128;
#[rustfmt::skip]
fn check_add(
name: &str,
first: console::Scalar<<Circuit as Environment>::Network>,
second: console::Scalar<<Circuit as Environment>::Network>,
mode_a: Mode,
mode_b: Mode,
) {
let a = Scalar::<Circuit>::new(mode_a, first);
let b = Scalar::<Circuit>::new(mode_b, second);
let case = format!("({} + {})", a.eject_value(), b.eject_value());
let expected = first + second;
Circuit::scope(name, || {
let candidate = a + b;
assert_eq!(expected, candidate.eject_value(), "{case}");
assert_count!(Add(Scalar, Scalar) => Scalar, &(mode_a, mode_b));
assert_output_mode!(Add(Scalar, Scalar) => Scalar, &(mode_a, mode_b), candidate);
});
}
#[rustfmt::skip]
fn run_test(
mode_a: Mode,
mode_b: Mode,
) {
let mut rng = TestRng::default();
for i in 0..ITERATIONS {
let first = Uniform::rand(&mut rng);
let second = Uniform::rand(&mut rng);
let name = format!("Add: {mode_a} + {mode_b} {i}");
check_add(&name, first, second, mode_a, mode_b);
let name = format!("Add: {mode_a} + {mode_b} {i} (commutative)");
check_add(&name, second, first, mode_a, mode_b);
}
}
#[test]
fn test_scalar_constant_plus_constant() {
run_test(Mode::Constant, Mode::Constant);
}
#[test]
fn test_scalar_constant_plus_public() {
run_test(Mode::Constant, Mode::Public);
}
#[test]
fn test_scalar_constant_plus_private() {
run_test(Mode::Constant, Mode::Private);
}
#[test]
fn test_scalar_public_plus_constant() {
run_test(Mode::Public, Mode::Constant);
}
#[test]
fn test_scalar_private_plus_constant() {
run_test(Mode::Private, Mode::Constant);
}
#[test]
fn test_scalar_public_plus_public() {
run_test(Mode::Public, Mode::Public);
}
#[test]
fn test_scalar_public_plus_private() {
run_test(Mode::Public, Mode::Private);
}
#[test]
fn test_scalar_private_plus_public() {
run_test(Mode::Private, Mode::Public);
}
#[test]
fn test_scalar_private_plus_private() {
run_test(Mode::Private, Mode::Private);
}
}