ark_r1cs_std/uint/add/saturating.rs

1use ark_ff::PrimeField;
2use ark_relations::r1cs::SynthesisError;
3
4use crate::uint::*;
5use crate::{boolean::Boolean, R1CSVar};
6
7impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
8    /// Compute `*self = self.wrapping_add(other)`.
9    pub fn saturating_add_in_place(&mut self, other: &Self) {
10        let result = Self::saturating_add_many(&[self.clone(), other.clone()]).unwrap();
11        *self = result;
12    }
13
14    /// Compute `self.wrapping_add(other)`.
15    pub fn saturating_add(&self, other: &Self) -> Self {
16        let mut result = self.clone();
17        result.saturating_add_in_place(other);
18        result
19    }
20
21    /// Perform wrapping addition of `operands`.
22    /// Computes `operands[0].wrapping_add(operands[1]).wrapping_add(operands[2])...`.
23    ///
24    /// The user must ensure that overflow does not occur.
25    #[tracing::instrument(target = "r1cs", skip(operands))]
26    pub fn saturating_add_many(operands: &[Self]) -> Result<Self, SynthesisError>
27    where
28        F: PrimeField,
29    {
30        let (sum_bits, value) = Self::add_many_helper(operands, |a, b| a.saturating_add(b))?;
31        if operands.is_constant() {
32            // If all operands are constant, then the result is also constant.
33            // In this case, we can return early.
34            Ok(UInt::constant(value.unwrap()))
35        } else if sum_bits.len() == N {
36            // No overflow occurred.
37            Ok(UInt::from_bits_le(&sum_bits))
38        } else {
39            // Split the sum into the bottom `N` bits and the top bits.
40            let (bottom_bits, top_bits) = sum_bits.split_at(N);
41
42            // Construct a candidate result assuming that no overflow occurred.
43            let bits = TryFrom::try_from(bottom_bits.to_vec()).unwrap();
44            let candidate_result = UInt { bits, value };
45
46            // Check if any of the top bits is set.
47            // If any of them is set, then overflow occurred.
48            let overflow_occurred = Boolean::kary_or(&top_bits)?;
49
50            // If overflow occurred, return the maximum value.
51            overflow_occurred.select(&Self::MAX, &candidate_result)
52        }
53    }
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive, run_binary_random},
        R1CSVar,
    };
    use ark_ff::PrimeField;
    use ark_test_curves::bls12_381::Fr;

    /// Checks the gadget's `saturating_add` of `a` and `b` against the native
    /// `saturating_add` of their values.
    ///
    /// The reference result is allocated as a constant when both inputs are
    /// constant and as a witness otherwise. It must agree with the gadget's output
    /// both natively and under `enforce_equal`. For non-constant inputs the
    /// constraint system must be satisfied afterwards.
    fn uint_saturating_add<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs().or(b.cs());
        let all_constant = a.is_constant() && b.is_constant();

        // Result produced by the gadget under test.
        let circuit_sum = a.saturating_add(&b);

        // Reference result, computed natively and allocated in the matching mode.
        let reference = UInt::new_variable(
            cs.clone(),
            || Ok(a.value()?.saturating_add(b.value()?)),
            if all_constant {
                AllocationMode::Constant
            } else {
                AllocationMode::Witness
            },
        )?;

        assert_eq!(reference.value(), circuit_sum.value());
        reference.enforce_equal(&circuit_sum)?;

        // Satisfiability is only checked when a constraint system is involved.
        if !all_constant {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    #[test]
    fn u8_saturating_add() {
        run_binary_exhaustive(uint_saturating_add::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_saturating_add() {
        run_binary_random::<1000, 16, _, _>(uint_saturating_add::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_saturating_add() {
        run_binary_random::<1000, 32, _, _>(uint_saturating_add::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_saturating_add() {
        run_binary_random::<1000, 64, _, _>(uint_saturating_add::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_saturating_add() {
        run_binary_random::<1000, 128, _, _>(uint_saturating_add::<u128, 128, Fr>).unwrap()
    }
}