ark_r1cs_std/uint/
select.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
use super::*;
use crate::select::CondSelectGadget;

impl<const N: usize, T: PrimUInt, ConstraintF: PrimeField> CondSelectGadget<ConstraintF>
    for UInt<N, T, ConstraintF>
{
    /// Builds a new `UInt` whose bits are chosen pairwise: each output bit is
    /// `cond.select(t, f)` for the corresponding bits `t` of `true_value` and
    /// `f` of `false_value`.
    ///
    /// The cached `value` of the result mirrors that choice. It is `None` when
    /// the value of `cond` cannot be read, or when the value of the chosen
    /// operand cannot be read.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a per-bit `select` call.
    #[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))]
    fn conditionally_select(
        cond: &Boolean<ConstraintF>,
        true_value: &Self,
        false_value: &Self,
    ) -> Result<Self, SynthesisError> {
        // Start from an all-false array and overwrite every slot in order,
        // stopping at the first bit whose selection fails.
        let mut bits = [Boolean::FALSE; N];
        let operand_bits = true_value.bits.iter().zip(&false_value.bits);
        for (slot, (if_true, if_false)) in bits.iter_mut().zip(operand_bits) {
            *slot = cond.select(if_true, if_false)?;
        }

        // Only the operand picked by the condition is asked for its value.
        let value = match cond.value() {
            Ok(true) => true_value.value().ok(),
            Ok(false) => false_value.value().ok(),
            Err(_) => None,
        };

        Ok(Self { bits, value })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive, run_binary_random},
    };
    use ark_ff::PrimeField;
    use ark_test_curves::bls12_381::Fr;

    /// Exercises `Boolean::select` on `(a, b)` for both condition values.
    ///
    /// For each condition, an `expected` variable is allocated directly from
    /// the value of the operand that should win, then compared against the
    /// output of `select`, both by value and via `enforce_equal`. When at
    /// least one operand is not a constant, the constraint system must also
    /// be satisfied afterwards.
    fn uint_select<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: UInt<N, T, F>,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs().or(b.cs());
        let any_variable = !a.is_constant() || !b.is_constant();
        // The condition and the expected result are allocated as constants
        // only when both inputs are constants; otherwise as witnesses.
        let mode = if any_variable {
            AllocationMode::Witness
        } else {
            AllocationMode::Constant
        };

        for flag in [true, false] {
            let winner = if flag { &a } else { &b };
            let expected = UInt::new_variable(cs.clone(), || winner.value(), mode)?;
            let cond = Boolean::new_variable(cs.clone(), || Ok(flag), mode)?;
            let computed = cond.select(&a, &b)?;

            assert_eq!(expected.value(), computed.value());
            expected.enforce_equal(&computed)?;
            if any_variable {
                assert!(cs.is_satisfied().unwrap());
            }
        }
        Ok(())
    }

    /// Runs the select check over 8-bit operands via the exhaustive helper.
    #[test]
    fn u8_select() {
        run_binary_exhaustive(uint_select::<u8, 8, Fr>).unwrap()
    }

    /// Runs the select check over 16-bit operands via the random helper.
    #[test]
    fn u16_select() {
        run_binary_random::<1000, 16, _, _>(uint_select::<u16, 16, Fr>).unwrap()
    }

    /// Runs the select check over 32-bit operands via the random helper.
    #[test]
    fn u32_select() {
        run_binary_random::<1000, 32, _, _>(uint_select::<u32, 32, Fr>).unwrap()
    }

    /// Runs the select check over 64-bit operands via the random helper.
    #[test]
    fn u64_select() {
        run_binary_random::<1000, 64, _, _>(uint_select::<u64, 64, Fr>).unwrap()
    }

    /// Runs the select check over 128-bit operands via the random helper.
    #[test]
    fn u128_select() {
        run_binary_random::<1000, 128, _, _>(uint_select::<u128, 128, Fr>).unwrap()
    }
}