use std::collections::HashMap;
use num::Zero;
use rand::prelude::*;
use triton_vm::prelude::*;
use triton_vm::twenty_first::prelude::U32s;
use crate::data_type::DataType;
use crate::empty_stack;
use crate::library::Library;
use crate::push_encodable;
use crate::traits::deprecated_snippet::DeprecatedSnippet;
use crate::InitVmState;
/// Snippet adding two `u64` values that live on the stack as pairs of `u32` limbs.
///
/// The snippet consumes both operands, leaves their sum on the stack, and crashes
/// if the sum does not fit into a `u64`.
#[derive(Debug, Clone)]
pub struct AddU64;
impl DeprecatedSnippet for AddU64 {
fn entrypoint_name(&self) -> String {
"tasmlib_arithmetic_u64_add".to_string()
}
fn input_field_names(&self) -> Vec<String> {
vec![
"rhs_hi".to_string(),
"rhs_lo".to_string(),
"lhs_hi".to_string(),
"lhs_lo".to_string(),
]
}
fn input_types(&self) -> Vec<crate::data_type::DataType> {
vec![DataType::U64, DataType::U64]
}
fn output_field_names(&self) -> Vec<String> {
vec!["(lhs + rhs)_hi".to_string(), "(lhs + rhs)_lo".to_string()]
}
fn output_types(&self) -> Vec<crate::data_type::DataType> {
vec![DataType::U64]
}
fn stack_diff(&self) -> isize {
-2
}
fn function_code(&self, _library: &mut Library) -> String {
let entrypoint = self.entrypoint_name();
format!(
"
// BEFORE: _ rhs_hi rhs_lo lhs_hi lhs_lo
// AFTER: _ sum_hi sum_lo
{entrypoint}:
swap 1 swap 2
// _ rhs_hi lhs_hi lhs_lo rhs_lo
add
split
// _ rhs_hi lhs_hi carry sum_lo
swap 3
// _ sum_lo lhs_hi carry rhs_hi
add
add
// _ sum_lo (lhs_hi+rhs_hi+carry)
split
// _ sum_lo overflow sum_hi
swap 1
push 0
eq
assert
// _ sum_lo sum_hi
swap 1
// _ sum_hi sum_lo
return
"
)
}
fn crash_conditions(&self) -> Vec<String> {
vec!["if (lhs + rhs) overflows u64".to_string()]
}
fn gen_input_states(&self) -> Vec<InitVmState> {
let mut rng = rand::thread_rng();
let zero = U32s::<2>::zero();
let small_a: U32s<2> = rng.gen::<u32>().into();
let small_b: U32s<2> = rng.gen::<u32>().into();
let large_a = U32s::<2>::try_from(rng.gen::<u64>()).unwrap();
let mut states = vec![];
states.push({
let mut stack = empty_stack();
push_encodable(&mut stack, &zero);
push_encodable(&mut stack, &large_a);
InitVmState::with_stack(stack)
});
states.push({
let mut stack = empty_stack();
push_encodable(&mut stack, &small_a);
push_encodable(&mut stack, &small_b);
InitVmState::with_stack(stack)
});
states
}
fn common_case_input_state(&self) -> InitVmState {
InitVmState::with_stack(
[
empty_stack(),
vec![BFieldElement::zero(), BFieldElement::new(1 << 31)],
vec![BFieldElement::zero(), BFieldElement::new(1 << 30)],
]
.concat(),
)
}
fn worst_case_input_state(&self) -> InitVmState {
InitVmState::with_stack(
[
empty_stack(),
vec![BFieldElement::new(1 << 31), BFieldElement::new(1 << 31)],
vec![
BFieldElement::new(1 << 30),
BFieldElement::new((1 << 31) + 10),
],
]
.concat(),
)
}
fn rust_shadowing(
&self,
stack: &mut Vec<BFieldElement>,
_std_in: Vec<BFieldElement>,
_secret_in: Vec<BFieldElement>,
_memory: &mut HashMap<BFieldElement, BFieldElement>,
) {
let a0: u32 = stack.pop().unwrap().try_into().unwrap();
let b0: u32 = stack.pop().unwrap().try_into().unwrap();
let ab0 = U32s::<2>::new([a0, b0]);
let a1: u32 = stack.pop().unwrap().try_into().unwrap();
let b1: u32 = stack.pop().unwrap().try_into().unwrap();
let ab1 = U32s::<2>::new([a1, b1]);
let ab0_plus_ab1 = ab0 + ab1;
let mut res = ab0_plus_ab1.encode();
for _ in 0..res.len() {
stack.push(res.pop().unwrap());
}
}
}
/// Unit tests comparing the assembly of [`AddU64`] against its Rust model.
#[cfg(test)]
mod tests {
    use num::BigUint;
    use num::One;
    use num::Zero;
    use rand::prelude::*;

    use super::*;
    use crate::empty_stack;
    use crate::test_helpers::test_rust_equivalence_given_input_values_deprecated;
    use crate::test_helpers::test_rust_equivalence_multiple_deprecated;

    /// Builds the initial stack `_ rhs_hi rhs_lo lhs_hi lhs_lo` for the given operands.
    fn init_stack_for(lhs: U32s<2>, rhs: U32s<2>) -> Vec<BFieldElement> {
        let mut init_stack = empty_stack();
        init_stack.extend(rhs.encode().into_iter().rev());
        init_stack.extend(lhs.encode().into_iter().rev());
        init_stack
    }

    /// Builds the stack expected after a successful run: `_ sum_hi sum_lo`.
    fn end_stack_for(sum_hi: BFieldElement, sum_lo: BFieldElement) -> Vec<BFieldElement> {
        [empty_stack(), vec![sum_hi, sum_lo]].concat()
    }

    /// Runs the assembly on the given operands without comparing it to the Rust model.
    /// Used by the tests that expect the run to panic.
    fn run_tasm_only(lhs: U32s<2>, rhs: U32s<2>) {
        let init_stack = init_stack_for(lhs, rhs);
        AddU64.link_and_run_tasm_from_state_for_test(&mut InitVmState::with_stack(init_stack));
    }

    /// Checks that assembly and Rust model agree on `lhs + rhs` and, if `expected` is
    /// given, that the final stack equals it.
    fn prop_add(lhs: U32s<2>, rhs: U32s<2>, expected: Option<&[BFieldElement]>) {
        let init_stack = init_stack_for(lhs, rhs);
        test_rust_equivalence_given_input_values_deprecated(
            &AddU64,
            &init_stack,
            &[],
            HashMap::default(),
            expected,
        );
    }

    #[test]
    fn add_u64_test() {
        test_rust_equivalence_multiple_deprecated(&AddU64, true);
    }

    #[test]
    fn u32s_2_add_no_overflow() {
        let expected_end_stack = end_stack_for(BFieldElement::zero(), BFieldElement::new(256));
        prop_add(
            U32s::new([127, 0]),
            U32s::new([129, 0]),
            Some(&expected_end_stack),
        );

        let expected_end_stack = end_stack_for(BFieldElement::new(1045), BFieldElement::new(256));
        prop_add(
            U32s::new([127, 45]),
            U32s::new([129, 1000]),
            Some(&expected_end_stack),
        );

        let expected_end_stack = end_stack_for(
            BFieldElement::new(2004),
            BFieldElement::new(u32::MAX as u64),
        );
        prop_add(
            U32s::new([u32::MAX, 0]),
            U32s::new([0, 2004]),
            Some(&expected_end_stack),
        );

        let expected_end_stack = end_stack_for(
            BFieldElement::new(1_000_000_014),
            BFieldElement::new(u32::MAX as u64),
        );
        prop_add(
            U32s::new([(1 << 31) - 1, 14]),
            U32s::new([1 << 31, 1_000_000_000]),
            Some(&expected_end_stack),
        );
    }

    #[test]
    fn u32s_2_add_with_overflow_in_least_significant_u32() {
        let expected_end_stack = end_stack_for(BFieldElement::one(), BFieldElement::zero());
        prop_add(
            U32s::new([1 << 31, 0]),
            U32s::new([1 << 31, 0]),
            Some(&expected_end_stack),
        );

        let expected_end_stack =
            end_stack_for(BFieldElement::one(), BFieldElement::new((1 << 32) - 3));
        prop_add(
            U32s::new([((1u64 << 32) - 1) as u32, 0]),
            U32s::new([((1u64 << 32) - 2) as u32, 0]),
            Some(&expected_end_stack),
        );
    }

    #[test]
    fn u32s_2_add_pbt() {
        let mut rng = rand::thread_rng();
        for _ in 0..100 {
            // Halving the high limbs guarantees that the sum fits into a u64.
            prop_add(
                U32s::new([rng.next_u32(), rng.next_u32() / 2]),
                U32s::new([rng.next_u32(), rng.next_u32() / 2]),
                None,
            );
        }
    }

    #[should_panic]
    #[test]
    fn overflow_test() {
        let lhs: U32s<2> = U32s::from(BigUint::from(1u64 << 63));
        let rhs: U32s<2> = U32s::from(BigUint::from((1u64 << 63) + 1));
        run_tasm_only(lhs, rhs);
    }

    #[should_panic]
    #[test]
    fn overflow_test_2() {
        let lhs: U32s<2> = U32s::from(BigUint::from(u64::MAX));
        let rhs: U32s<2> = U32s::from(BigUint::from(u64::MAX));
        run_tasm_only(lhs, rhs);
    }
}
/// Benchmark entry point for [`AddU64`], compiled only for test builds.
#[cfg(test)]
mod benches {
    use super::AddU64;
    use crate::snippet_bencher::bench_and_write;

    /// Hands the snippet to `bench_and_write`.
    // NOTE(review): presumably this benchmarks the snippet and persists the result;
    // confirm against `snippet_bencher`.
    #[test]
    fn add_u64_benchmark() {
        let snippet = AddU64;
        bench_and_write(snippet);
    }
}