use rand::prelude::*;
use triton_vm::prelude::*;
use crate::data_type::DataType;
use crate::empty_stack;
use crate::traits::basic_snippet::BasicSnippet;
use crate::traits::closure::Closure;
/// Snippet computing `lhs - rhs` on two `u64`s with wrap-around.
///
/// Besides the wrapped difference it yields a `Bool` telling whether the
/// subtraction overflowed, i.e. it mirrors [`u64::overflowing_sub`], which is
/// what the Rust shadow of this snippet calls.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct OverflowingSub;
impl BasicSnippet for OverflowingSub {
/// The two `u64` operands, `lhs` first and `rhs` last.
fn inputs(&self) -> Vec<(DataType, String)> {
vec![
(DataType::U64, "lhs".to_string()),
(DataType::U64, "rhs".to_string()),
]
}
/// The wrapped `u64` difference followed by the overflow flag.
fn outputs(&self) -> Vec<(DataType, String)> {
vec![
(DataType::U64, "wrapped_diff".to_owned()),
(DataType::Bool, "overflow".to_owned()),
]
}
/// Label under which the snippet's code is emitted.
fn entrypoint(&self) -> String {
"tasmlib_arithmetic_u64_overflowing_sub".to_string()
}
/// Emits the assembly for the subtraction.
///
/// The subtraction is done limb-wise on 32-bit limbs. For each limb pair the
/// code computes `a - b + 2^32` and splits the result into two 32-bit words:
/// the upper word is 0 exactly when `a < b`, i.e. when a borrow occurred.
/// The borrow of the low limbs is added to `rhs_hi` before the high limbs are
/// subtracted; the borrow of the high limbs is the overflow flag.
///
/// Stack comments below list the deepest element first, the top last. The
/// input/output layout is the one `rust_shadow` pops and pushes.
fn code(&self, _library: &mut crate::library::Library) -> Vec<LabelledInstruction> {
let entrypoint = self.entrypoint();
// 2^32, added so the limb difference never goes negative before `split`.
const TWO_POW_32: &str = "4294967296";
triton_asm!(
// BEFORE: _ lhs_hi lhs_lo rhs_hi rhs_lo
// AFTER:  _ diff_hi diff_lo overflow
{entrypoint}:
// Negate rhs_lo and bring lhs_lo next to it.
push -1
mul
swap 1 swap 2
// _ lhs_hi rhs_hi (-rhs_lo) lhs_lo
add
push {TWO_POW_32}
add
// _ lhs_hi rhs_hi (lhs_lo - rhs_lo + 2^32)
split
// _ lhs_hi rhs_hi no_borrow diff_lo
swap 2 swap 1
// _ lhs_hi diff_lo rhs_hi no_borrow
push 0
eq
// _ lhs_hi diff_lo rhs_hi borrow
add
// _ lhs_hi diff_lo (rhs_hi + borrow)
// Same negate-and-add pattern for the high limbs.
push -1
mul
swap 1 swap 2
// _ diff_lo -(rhs_hi + borrow) lhs_hi
add
push {TWO_POW_32}
add
// _ diff_lo (lhs_hi - rhs_hi - borrow + 2^32)
split
// _ diff_lo no_overflow diff_hi
swap 1
push 0 eq
// _ diff_lo diff_hi overflow
// Exchange diff_lo and diff_hi while keeping the flag on top.
swap 2 swap 1
swap 2
// _ diff_hi diff_lo overflow
return
)
}
}
impl Closure for OverflowingSub {
    /// Rust reference implementation of the snippet.
    ///
    /// Pops `rhs` and then `lhs` from `stack`, each as two 32-bit limbs with the
    /// low limb on top, and pushes the high limb of the wrapped difference, its
    /// low limb, and finally the overflow flag (`1` on overflow, `0` otherwise).
    ///
    /// # Panics
    ///
    /// Panics if `stack` holds fewer than four elements or if any of the four
    /// popped elements does not fit into a `u32`.
    fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
        let rhs_lo: u32 = stack.pop().unwrap().try_into().unwrap();
        let rhs_hi: u32 = stack.pop().unwrap().try_into().unwrap();
        let lhs_lo: u32 = stack.pop().unwrap().try_into().unwrap();
        let lhs_hi: u32 = stack.pop().unwrap().try_into().unwrap();

        let rhs = (u64::from(rhs_hi) << 32) | u64::from(rhs_lo);
        let lhs = (u64::from(lhs_hi) << 32) | u64::from(lhs_lo);

        let (wrapped_diff, overflow) = lhs.overflowing_sub(rhs);

        stack.push(BFieldElement::new(wrapped_diff >> 32));
        stack.push(BFieldElement::new(wrapped_diff & u64::from(u32::MAX)));
        stack.push(BFieldElement::new(u64::from(overflow)));
    }

    /// Builds an initial stack holding `lhs` and `rhs` on top of the empty stack.
    ///
    /// With a benchmark case the operands are fixed; without one they are drawn
    /// from a `StdRng` seeded with `seed`, so the result is deterministic per seed.
    fn pseudorandom_initial_state(
        &self,
        seed: [u8; 32],
        bench_case: Option<crate::snippet_bencher::BenchmarkCase>,
    ) -> Vec<BFieldElement> {
        let (lhs, rhs) = match bench_case {
            Some(crate::snippet_bencher::BenchmarkCase::CommonCase) => {
                (1u64 << 63, (1u64 << 63) - 1)
            }
            Some(crate::snippet_bencher::BenchmarkCase::WorstCase) => (1u64 << 63, 1u64 << 50),
            None => {
                let mut rng = StdRng::from_seed(seed);
                (rng.next_u64(), rng.next_u64())
            }
        };

        // Lay the limbs out explicitly, high limb first and low limb last, which
        // is exactly the order `rust_shadow` pops them in. Doing this by hand
        // (rather than via `encode()`) keeps the stack layout independent of the
        // limb order the codec happens to produce, so the fixed benchmark
        // operands reach the snippet as the values written above.
        let limbs = |value: u64| {
            vec![
                BFieldElement::new(value >> 32),
                BFieldElement::new(value & u64::from(u32::MAX)),
            ]
        };

        [empty_stack(), limbs(lhs), limbs(rhs)].concat()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::test_rust_equivalence_given_complete_state;
    use crate::traits::closure::ShadowedClosure;
    use crate::traits::rust_shadow::RustShadow;

    /// Splits `value` into its stack representation: high limb first, low limb last.
    fn limbs(value: u64) -> Vec<BFieldElement> {
        vec![
            BFieldElement::new(value >> 32),
            BFieldElement::new(value & u32::MAX as u64),
        ]
    }

    /// Runs the `ShadowedClosure` test harness for [`OverflowingSub`].
    #[test]
    fn u64_overflowing_sub_pbt() {
        ShadowedClosure::new(OverflowingSub).test()
    }

    /// Checks hand-picked operand pairs against [`u64::overflowing_sub`].
    ///
    /// The pairs cover zero, equal operands, borrows across the 32-bit limb
    /// boundary, the extremes of `u64`, and values around `BFieldElement::MAX`.
    #[test]
    fn u64_overflowing_sub_unit_test() {
        let operand_pairs = [
            (0u64, 0u64),
            (0, 1),
            (1, 0),
            (1, 1),
            (1 << 32, 1 << 32),
            (1 << 63, 1 << 63),
            (u64::MAX, u64::MAX),
            (u64::MAX, 0),
            (0, u64::MAX),
            (100, 101),
            (101, 100),
            (1 << 40, (1 << 40) + 1),
            ((1 << 40) + 1, 1 << 40),
            (0, 1 << 40),
            (1 << 40, 0),
            (BFieldElement::MAX, BFieldElement::MAX),
            (BFieldElement::MAX, 0),
            (0, BFieldElement::MAX),
            (0, BFieldElement::MAX + 1),
            (BFieldElement::MAX + 1, 0),
            (BFieldElement::MAX + 1, BFieldElement::MAX),
            (BFieldElement::MAX + 1, BFieldElement::MAX + 1),
            (BFieldElement::MAX, BFieldElement::MAX + 1),
        ];

        for (lhs, rhs) in operand_pairs {
            let init_stack = [empty_stack(), limbs(lhs), limbs(rhs)].concat();

            let (wrapped_diff, overflow) = lhs.overflowing_sub(rhs);
            let expected_final_stack = [
                empty_stack(),
                limbs(wrapped_diff),
                vec![BFieldElement::new(overflow as u64)],
            ]
            .concat();

            let _vm_output_state = test_rust_equivalence_given_complete_state(
                &ShadowedClosure::new(OverflowingSub),
                &init_stack,
                &[],
                &NonDeterminism::default(),
                &None,
                Some(&expected_final_stack),
            );
        }
    }
}
#[cfg(test)]
mod benches {
    use super::*;
    use crate::traits::closure::ShadowedClosure;
    use crate::traits::rust_shadow::RustShadow;

    /// Runs the `ShadowedClosure` benchmark harness for [`OverflowingSub`].
    #[test]
    fn u64_overflowing_sub_bench() {
        ShadowedClosure::new(OverflowingSub).bench()
    }
}