tasm_lib/arithmetic/u64/
overflowing_sub.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
use std::collections::HashMap;

use triton_vm::prelude::*;

use crate::prelude::*;
use crate::traits::basic_snippet::Reviewer;
use crate::traits::basic_snippet::SignOffFingerprint;

/// [Overflowing subtraction][sub] for unsigned 64-bit integers.
///
/// # Behavior
///
/// ```text
/// BEFORE: _ [subtrahend: u64] [minuend: u64]
/// AFTER:  _ [difference: u64] [is_overflow: bool]
/// ```
///
/// # Preconditions
///
/// - all input arguments are properly [`BFieldCodec`] encoded
///
/// # Postconditions
///
/// - the output `difference` is the `minuend` minus the `subtrahend`, wrapped
///   around modulo 2^64 if the subtraction overflows
/// - the output `is_overflow` is `true` if and only if the subtrahend is
///   greater than the minuend, _i.e._, if and only if the subtraction wraps
/// - the output is properly [`BFieldCodec`] encoded
///
/// Taken together, the outputs equal `minuend.overflowing_sub(subtrahend)`.
///
/// [sub]: u64::overflowing_sub
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct OverflowingSub;

impl OverflowingSub {
    /// The code shared between [`crate::arithmetic::u64::sub::Sub`],
    /// [`crate::arithmetic::u64::wrapping_sub::WrappingSub`], and
    /// [`OverflowingSub`].
    ///
    /// It computes the low limb of the difference completely. It leaves the
    /// high limb "raw": the top stack element is
    /// `minuend_hi - subtrahend_hi - carry`, computed in the field, where
    /// `carry` is 1 if and only if `minuend_lo < subtrahend_lo`.
    ///
    /// Assuming every limb is a u32 (as it is for properly encoded inputs), this
    /// raw value is an integer in the range [-2^32, 2^32). It is negative
    /// (_i.e._, it wraps around the field modulus) exactly when the full 64-bit
    /// subtraction overflows. Take care to treat the `difference_hi` correctly,
    /// depending on how you want to handle overflow.
    ///
    /// ```text
    /// BEFORE: _ subtrahend_hi subtrahend_lo minuend_hi minuend_lo
    /// AFTER:  _ difference_lo (minuend_hi - subtrahend_hi - carry)
    /// ```
    pub(crate) fn common_subtraction_code() -> Vec<LabelledInstruction> {
        triton_asm! {
            // BEFORE: _ subtrahend_hi subtrahend_lo minuend_hi minuend_lo
            // AFTER:  _ difference_lo (minuend_hi - subtrahend_hi - carry)
            pick 2
            // _ subtrahend_hi minuend_hi minuend_lo subtrahend_lo

            push -1
            mul
            add
            // _ subtrahend_hi minuend_hi (minuend_lo - subtrahend_lo)

            /* Compute the low limb and the borrow ("carry") in one go.
             *
             * Both low limbs are u32s, so (minuend_lo - subtrahend_lo) is an
             * integer in (-2^32, 2^32). Adding 2^32 moves it into [1, 2^33),
             * where it is a small non-negative integer that `split` decomposes
             * exactly:
             * - the high part is 1 if minuend_lo >= subtrahend_lo (no borrow
             *   from the high limb is needed) and 0 otherwise; this is `!carry`
             * - the low part is (minuend_lo - subtrahend_lo) mod 2^32, which is
             *   `difference_lo`
             */
            addi {1_u64 << 32}
            split
            // _ subtrahend_hi minuend_hi !carry difference_lo

            place 3
            // _ difference_lo subtrahend_hi minuend_hi !carry

            // Turn the "no borrow" bit into the borrow bit itself.
            push 0
            eq
            // _ difference_lo subtrahend_hi minuend_hi carry

            pick 2
            add
            // _ difference_lo minuend_hi (subtrahend_hi + carry)

            // Subtract the high limbs including the borrow. The result is left
            // unnormalized; see the doc comment for its range.
            push -1
            mul
            add
            // _ difference_lo (minuend_hi - subtrahend_hi - carry)
        }
    }
}

impl BasicSnippet for OverflowingSub {
    fn inputs(&self) -> Vec<(DataType, String)> {
        vec![
            (DataType::U64, "subtrahend".to_string()),
            (DataType::U64, "minuend".to_string()),
        ]
    }

    fn outputs(&self) -> Vec<(DataType, String)> {
        [(DataType::U64, "wrapped_diff"), (DataType::Bool, "is_overflow")]
            .into_iter()
            .map(|(data_type, name)| (data_type, name.to_owned()))
            .collect()
    }

    fn entrypoint(&self) -> String {
        String::from("tasmlib_arithmetic_u64_overflowing_sub")
    }

    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
        triton_asm!(
            {self.entrypoint()}:
                // _ subtrahend_hi subtrahend_lo minuend_hi minuend_lo
                {&Self::common_subtraction_code()}
                // _ difference_lo raw_hi
                // with raw_hi = minuend_hi - subtrahend_hi - carry, which, as an
                // integer, lies in [-2^32, 2^32) and is negative exactly when
                // the u64 subtraction overflows

                // Shift raw_hi into [0, 2^33) so that `split` separates the
                // "no overflow" bit from the wrapped high limb.
                addi {1_u64 << 32}
                split
                // _ difference_lo !is_overflow difference_hi

                place 2
                // _ difference_hi difference_lo !is_overflow

                // Negate the bit to get the overflow flag.
                push 0
                eq
                // _ difference_hi difference_lo is_overflow

                return
        )
    }

    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
        HashMap::from([(Reviewer("ferdinand"), 0x4e4c796ae06e4400.into())])
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl OverflowingSub {
        /// Values clustered around points where `u64` subtraction is likely to
        /// misbehave. For each anchor point (in the listed order), yields the
        /// values at offsets -3 through +3, in ascending order, skipping any
        /// that fall outside the `u64` range.
        pub fn edge_case_values() -> Vec<u64> {
            const ANCHORS: [u64; 7] = [1, 1 << 32, 1 << 33, 1 << 34, 1 << 40, 1 << 63, u64::MAX];
            const MAX_OFFSET: i64 = 3;

            ANCHORS
                .into_iter()
                .flat_map(|anchor| {
                    (-MAX_OFFSET..=MAX_OFFSET)
                        .filter_map(move |offset| anchor.checked_add_signed(offset))
                })
                .collect()
        }
    }

    impl Closure for OverflowingSub {
        type Args = (u64, u64);

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (subtrahend, minuend) = pop_encodable::<Self::Args>(stack);
            let (difference, is_overflow) = minuend.overflowing_sub(subtrahend);
            push_encodable(stack, &(difference, is_overflow));
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            match bench_case {
                None => StdRng::from_seed(seed).random(),
                Some(BenchmarkCase::CommonCase) => (u64::MAX >> 1, 1 << 63),
                Some(BenchmarkCase::WorstCase) => (1 << 50, 1 << 63),
            }
        }

        /// Every ordered pair `(subtrahend, minuend)` drawn from the edge-case
        /// values, with the subtrahend varying slowest.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            let values = Self::edge_case_values();

            values
                .iter()
                .flat_map(|&subtrahend| values.iter().map(move |&minuend| (subtrahend, minuend)))
                .collect()
        }
    }

    #[test]
    fn rust_shadow() {
        let shadowed_snippet = ShadowedClosure::new(OverflowingSub);
        shadowed_snippet.test();
    }
}

#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark cases of [`OverflowingSub`] through its shadowed
    /// closure.
    #[test]
    fn benchmark() {
        let shadowed_snippet = ShadowedClosure::new(OverflowingSub);
        shadowed_snippet.bench();
    }
}