// tasm_lib/arithmetic/u128/safe_mul.rs
use std::collections::HashMap;

use triton_vm::prelude::*;

use crate::prelude::*;
use crate::traits::basic_snippet::Reviewer;
use crate::traits::basic_snippet::SignOffFingerprint;

/// Checked multiplication of two `u128`s: the VM crashes if the product overflows.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [right: u128] [left: u128]
/// AFTER:  _ [left · right: u128]
/// ```
///
/// ### Preconditions
///
/// - both arguments are properly [`BFieldCodec`] encoded
/// - `left · right` does not exceed [`u128::MAX`]
///
/// ### Postconditions
///
/// - the output equals the product of the two inputs
/// - the output is properly [`BFieldCodec`] encoded
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SafeMul;

impl BasicSnippet for SafeMul {
    /// Two `u128` arguments. `right` is listed first and `left` second, matching the
    /// `BEFORE` stack layout in which `left` sits on top of `right`.
    fn inputs(&self) -> Vec<(DataType, String)> {
        ["right", "left"]
            .map(|side| (DataType::U128, side.to_string()))
            .to_vec()
    }

    /// A single `u128`: the product `left · right`.
    fn outputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::U128, "product".to_string())]
    }

    /// The label under which this snippet is called.
    fn entrypoint(&self) -> String {
        "tasmlib_arithmetic_u128_safe_mul".to_string()
    }

    /// Schoolbook multiplication on 32-bit limbs, followed by overflow checks.
    ///
    /// Each operand consists of four limbs; limb 0 is the least significant one and
    /// sits on top of its `u128` on the stack. The output limbs are computed column
    /// by column. Each finished limb is moved below the inputs with `place`, so that
    /// the product ends up where the inputs were. Any contribution to bit 128 or
    /// beyond must be zero; otherwise the VM crashes with one of these error IDs:
    ///
    /// - 500: the carry out of the top column, or a high limb of a product that
    ///   feeds the top column, is non-zero
    /// - 501–506: one of the six limb products `l_i·r_j` with `i + j ≥ 4` is
    ///   non-zero
    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
        triton_asm!(
            // BEFORE: _ r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0
            // AFTER:  _ p_3 p_2 p_1 p_0
            {self.entrypoint()}:
                /*
                 * Limb index 0 is the least significant limb; it is on top of the stack.
                 *
                 * p_0 is low limb, c_0 high limb of
                 *        l_0·r_0
                 *
                 * p_1 is low limb, c_1 high limb of
                 *        (l_1·r_0)_lo + (l_0·r_1)_lo
                 *      + c_0
                 *
                 * p_2 is low limb, c_2 high limb of
                 *        (l_1·r_0)_hi + (l_0·r_1)_hi
                 *      + (l_2·r_0)_lo + (l_1·r_1)_lo + (l_0·r_2)_lo
                 *      + c_1
                 *
                 * p_3 is low limb, c_3 high limb of
                 *        (l_2·r_0)_hi + (l_1·r_1)_hi + (l_0·r_2)_hi
                 *      + (l_3·r_0)_lo + (l_2·r_1)_lo + (l_1·r_2)_lo + (l_0·r_3)_lo
                 *      + c_2
                 *
                 * All remaining limb combinations (l_3·r_1, l_3·r_2, l_3·r_3, l_2·r_2,
                 * l_2·r_3, and l_1·r_3), the carry c_3, and the high limbs
                 * (l_3·r_0)_hi, (l_2·r_1)_hi, (l_1·r_2)_hi, and (l_0·r_3)_hi
                 * must all be 0.
                 */

                /* p_0 */
                dup 0 dup 5 mul split
                // _ r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 p_0

                // p_0 is final: stash it below the inputs.
                place 9
                // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0

                /* p_1 */
                dup 2 dup 6 mul split
                // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 (l_1·r_0)_hi (l_1·r_0)_lo

                dup 3 dup 9 mul split
                // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 c_0 (l_1·r_0)_hi (l_1·r_0)_lo (l_0·r_1)_hi (l_0·r_1)_lo
                //                                       ^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^

                // Sum the marked elements; splitting the sum yields the carry c_1.
                pick 2 pick 4
                add add
                split
                // _ p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_1·r_0)_hi (l_0·r_1)_hi c_1 p_1

                // p_1 is final: stash it below p_0.
                place 12
                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_1·r_0)_hi (l_0·r_1)_hi c_1

                /* p_2 */
                add add
                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip

                dup 3 dup 6 mul split
                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo

                dup 4 dup 9 mul split
                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo (l_1·r_1)_hi (l_1·r_1)_lo

                dup 5 dup 12 mul split
                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_2_wip (l_2·r_0)_hi (l_2·r_0)_lo (l_1·r_1)_hi (l_1·r_1)_lo (l_0·r_2)_hi (l_0·r_2)_lo
                //                                           ^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^

                // Sum the marked elements; splitting the sum yields the carry c_2.
                pick 2 pick 4 pick 6
                add add add
                split
                // _ p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_2·r_0)_hi (l_1·r_1)_hi (l_0·r_2)_hi c_2 p_2

                // p_2 is final: stash it below p_1.
                place 14
                // _ p_2 p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 (l_2·r_0)_hi (l_1·r_1)_hi (l_0·r_2)_hi c_2

                /* p_3 */
                add add add
                // _ p_2 p_1 p_0 r_3 r_2 r_1 r_0 l_3 l_2 l_1 l_0 p_3_wip

                // r_0 is not needed after this product, so it is consumed (`pick`)
                // rather than copied (`dup`).
                dup 4 pick 6 mul split
                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo

                dup 5 dup 8 mul split
                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo

                dup 6 dup 11 mul split
                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 l_0 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo (l_1·r_2)_hi (l_1·r_2)_lo

                // l_0 is not needed after this product, so it is consumed (`pick`).
                pick 7 dup 13 mul split
                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 p_3_wip (l_3·r_0)_hi (l_3·r_0)_lo (l_2·r_1)_hi (l_2·r_1)_lo (l_1·r_2)_hi (l_1·r_2)_lo (l_0·r_3)_hi (l_0·r_3)_lo
                //                                       ^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^              ^^^^^^^^^^^^

                // Sum the marked elements; splitting the sum yields the carry c_3.
                pick 2 pick 4 pick 6 pick 8
                add add add add
                split
                // _ p_2 p_1 p_0 r_3 r_2 r_1 l_3 l_2 l_1 (l_3·r_0)_hi (l_2·r_1)_hi (l_1·r_2)_hi (l_0·r_3)_hi c_3 p_3

                // p_3 is final: stash it below p_2.
                place 14
                // _ [p; 4] r_3 r_2 r_1 l_3 l_2 l_1 (l_3·r_0)_hi (l_2·r_1)_hi (l_1·r_2)_hi (l_0·r_3)_hi c_3

                /* overflow checks
                 *
                 * Carry c_3 and the high limbs still on stack are guaranteed to be smaller than
                 * 2^32 since they resulted from instruction `split`. The sum of those 5 elements
                 * cannot “wrap around” `BFieldElement::P`, so the sum is 0 if and only if
                 * every one of the 5 elements is 0.
                 */
                add add add add
                push 0 eq assert error_id 500
                // _ [p; 4] r_3 r_2 r_1 l_3 l_2 l_1

                /* The remaining checks each multiply two input limbs. For properly
                 * encoded inputs, every limb is below 2^32, so such a product is at most
                 * (2^32 - 1)^2 = 2^64 - 2^33 + 1. That is less than `BFieldElement::P`
                 * (2^64 - 2^32 + 1). Hence the field product is 0 if and only if one
                 * of the two limbs is 0.
                 */

                /* l_3·r_1 */
                dup 2 pick 4 mul
                push 0 eq assert error_id 501
                // _ [p; 4] r_3 r_2 l_3 l_2 l_1

                /* l_2·r_2 */
                dup 1 dup 4 mul
                push 0 eq assert error_id 502
                // _ [p; 4] r_3 r_2 l_3 l_2 l_1

                /* l_1·r_3 */
                dup 4 mul
                push 0 eq assert error_id 503
                // _ [p; 4] r_3 r_2 l_3 l_2

                /* l_3·r_2 */
                dup 1 pick 3 mul
                push 0 eq assert error_id 504
                // _ [p; 4] r_3 l_3 l_2

                /* l_2·r_3 */
                dup 2 mul
                push 0 eq assert error_id 505
                // _ [p; 4] r_3 l_3

                /* l_3·r_3 */
                mul
                push 0 eq assert error_id 506
                // _ [p; 4]

                return
        )
    }

    /// Reviewer sign-offs recorded for this snippet, keyed by reviewer.
    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
        let mut sign_offs = HashMap::new();
        sign_offs.insert(Reviewer("ferdinand"), 0x6a6ab0928dd2f0e4.into());
        sign_offs
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;
    use rand::rngs::StdRng;

    impl SafeMul {
        /// Run the snippet on `left · right` and assert that the VM crashes with one
        /// of the given `error_ids`.
        fn test_assertion_failure(&self, left: u128, right: u128, error_ids: &[i128]) {
            test_assertion_failure(
                &ShadowedClosure::new(Self),
                InitVmState::with_stack(self.set_up_test_stack((right, left))),
                error_ids,
            );
        }
    }

    impl Closure for SafeMul {
        type Args = (u128, u128);

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (right, left) = pop_encodable::<Self::Args>(stack);
            let product = left.checked_mul(right).unwrap();
            push_encodable(stack, &product);
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            let Some(bench_case) = bench_case else {
                let mut rng = StdRng::from_seed(seed);
                // Choose `right` such that `left · right` never overflows.
                let left = rng.random_range(1..=u128::MAX);
                let right = rng.random_range(0..=u128::MAX / left);

                return (right, left);
            };

            match bench_case {
                BenchmarkCase::CommonCase => (1 << 63, (1 << 45) - 1),
                BenchmarkCase::WorstCase => (1 << 63, (1 << 63) - 1),
            }
        }

        fn corner_case_args(&self) -> Vec<Self::Args> {
            const LEFT_NOISE: u128 = 0xfd4e_3f84_8677_df6b_da64_b83c_8267_c72d;
            const RIGHT_NOISE: u128 = 0x538e_e051_c430_3e7a_0a29_a45a_5efb_67fa;

            // Non-overflowing products at the edges of the input domain: a zero
            // factor, and the identity against the largest value. Neither 0 nor
            // `u128::MAX` is produced by the bit-length sweep below.
            let boundary_cases = [
                (0, 0),
                (0, u128::MAX),
                (u128::MAX, 0),
                (1, u128::MAX),
                (u128::MAX, 1),
            ];

            // Sweep over all combinations of bit lengths, filling the lower bits with
            // fixed noise, and keep only the non-overflowing pairs.
            (0..u128::BITS)
                .cartesian_product(0..u128::BITS)
                .map(|(l, r)| {
                    let left = (1 << l) | ((1 << l) - 1) & LEFT_NOISE;
                    let right = (1 << r) | ((1 << r) - 1) & RIGHT_NOISE;
                    (right, left)
                })
                .filter(|&(right, left)| left.checked_mul(right).is_some())
                .step_by(5) // test performance is atrocious otherwise
                .chain(boundary_cases)
                .collect()
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafeMul).test();
    }

    /// Each overflow-detecting assertion must be reachable, and overflow must be
    /// caught regardless of which limbs are involved.
    #[test]
    fn overflow_crashes_vm() {
        SafeMul.test_assertion_failure(1 << 127, 1 << 1, &[500]);
        SafeMul.test_assertion_failure(1 << 96, 1 << 32, &[501]);
        SafeMul.test_assertion_failure(1 << 64, 1 << 64, &[502]);
        SafeMul.test_assertion_failure(1 << 32, 1 << 96, &[503]);
        SafeMul.test_assertion_failure(1 << 96, 1 << 64, &[504]);
        SafeMul.test_assertion_failure(1 << 64, 1 << 96, &[505]);
        SafeMul.test_assertion_failure(1 << 96, 1 << 96, &[506]);

        // The largest possible overflow, and the smallest factor that overflows the
        // largest value. In all three, a high limb of a top-column product is
        // non-zero, so the first check (error 500) fires.
        SafeMul.test_assertion_failure(u128::MAX, u128::MAX, &[500]);
        SafeMul.test_assertion_failure(u128::MAX, 2, &[500]);
        SafeMul.test_assertion_failure(2, u128::MAX, &[500]);

        for i in 1..64 {
            let left = u128::MAX >> i;
            let right = (1 << i) + 1;
            SafeMul.test_assertion_failure(left, right, &[500]);
            SafeMul.test_assertion_failure(right, left, &[500]);
        }

        for i in 1..128 {
            let left = 1 << i;
            let right = 1 << (128 - i);
            SafeMul.test_assertion_failure(left, right, &[500, 501, 502, 503]);
        }
    }

    #[proptest(cases = 1_000)]
    fn arbitrary_overflow_crashes_vm(
        #[strategy(2_u128..)] left: u128,
        #[strategy(u128::MAX / #left + 1..)] right: u128,
    ) {
        SafeMul.test_assertion_failure(left, right, &[500, 501, 502, 503, 504, 505, 506]);
    }
}

#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks [`SafeMul`] through its shadowed-closure harness.
    #[test]
    fn benchmark() {
        let snippet = SafeMul::default();
        ShadowedClosure::new(snippet).bench();
    }
}