tasm_lib/arithmetic/u128/
shift_left_static.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
use triton_vm::prelude::*;

use crate::prelude::*;

/// [Shift left][shl] for unsigned 128-bit integers, with the shift amount
/// specified at compile time.
///
/// The shift amount is the const generic parameter `N`, which must be in
/// range `0..=32`. Bits shifted out past the most significant end of the
/// `u128` are discarded.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [v: u128]
/// AFTER:  _ [v << N: u128]
/// ```
///
/// ### Preconditions
///
/// - input argument `v` is properly [`BFieldCodec`] encoded
///
/// ### Postconditions
///
/// - the output is properly [`BFieldCodec`] encoded
///
/// ### Panics
///
/// Generating the snippet's code panics if `N` is greater than 32.
///
/// [shl]: core::ops::Shl
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
pub struct ShiftLeftStatic<const N: u8>;

impl<const N: u8> BasicSnippet for ShiftLeftStatic<N> {
    /// The snippet consumes a single `u128`, the value to be shifted.
    fn inputs(&self) -> Vec<(DataType, String)> {
        let value = (DataType::U128, String::from("value"));
        vec![value]
    }

    /// The snippet produces a single `u128`, the shifted value.
    fn outputs(&self) -> Vec<(DataType, String)> {
        let shifted_value = (DataType::U128, String::from("shifted_value"));
        vec![shifted_value]
    }

    /// The label of the snippet; it embeds `N`, so every shift amount gets
    /// its own label.
    fn entrypoint(&self) -> String {
        format!("tasmlib_arithmetic_u128_shift_left_static_{N}")
    }

    /// Emits the assembly that shifts the `u128` on top of the stack left by
    /// `N` bits.
    ///
    /// The shift is realized as a multiplication of every limb by `2^N`,
    /// followed by splitting each product into a high and a low half and
    /// adding each high half onto the low half of the next-higher limb.
    ///
    /// # Panics
    ///
    /// Panics if `N` is greater than 32.
    fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
        assert!(N <= 32, "shift amount must be in range 0..=32");

        let entrypoint = self.entrypoint();
        // Multiplying a limb by 2^N is the same as shifting it left by N bits.
        let shift_factor = 1_u64 << N;

        triton_asm!(
            // BEFORE: _ [value: u128]
            // AFTER:  _ [value << N: u128]
            //
            // Notation: `vi` is the i-th limb of the input (v0 on top),
            // `vis` is that limb multiplied by 2^N, and `vi'` is the i-th
            // limb of the result.
            {entrypoint}:
                // scale the three topmost limbs in one go…
                push {shift_factor}
                xb_mul          // _ v3 v2s v1s v0s

                // …and the remaining limb on its own
                pick 3          // _ v2s v1s v0s v3
                push {shift_factor}
                mul             // _ v2s v1s v0s v3s

                // break every scaled limb into (hi, lo), highest limb first
                split           // _ v2s v1s v0s v3s_hi v3s_lo
                pick 4          // _ v1s v0s v3s_hi v3s_lo v2s
                split           // _ v1s v0s v3s_hi v3s_lo v2s_hi v2s_lo
                pick 5          // _ v0s v3s_hi v3s_lo v2s_hi v2s_lo v1s
                split           // _ v0s v3s_hi v3s_lo v2s_hi v2s_lo v1s_hi v1s_lo
                pick 6          // _ v3s_hi v3s_lo v2s_hi v2s_lo v1s_hi v1s_lo v0s
                split           // _ v3s_hi v3s_lo v2s_hi v2s_lo v1s_hi v1s_lo v0s_hi v0'

                // fold each hi half into the lo half of the next-higher limb,
                // parking every finished limb at the bottom of the working area
                place 7         // _ v0' v3s_hi v3s_lo v2s_hi v2s_lo v1s_hi v1s_lo v0s_hi
                add             // _ v0' v3s_hi v3s_lo v2s_hi v2s_lo v1s_hi v1'
                place 6         // _ v1' v0' v3s_hi v3s_lo v2s_hi v2s_lo v1s_hi
                add             // _ v1' v0' v3s_hi v3s_lo v2s_hi v2'
                place 5         // _ v2' v1' v0' v3s_hi v3s_lo v2s_hi
                add             // _ v2' v1' v0' v3s_hi v3'
                place 4         // _ v3' v2' v1' v0' v3s_hi

                // whatever was shifted out of the highest limb is dropped
                pop 1           // _ v3' v2' v1' v0'
                return
        )
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::test_prelude::*;
    use rand::rngs::StdRng;

    impl<const N: u8> Closure for ShiftLeftStatic<N> {
        type Args = u128;

        /// Reference implementation: pops the operand, shifts it left by `N`
        /// using native `u128` arithmetic, and pushes the result.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let operand = pop_encodable::<Self::Args>(stack);
            let shifted = operand << N;
            push_encodable(stack, &shifted);
        }

        /// Fixed operands for the benchmark cases; otherwise an operand drawn
        /// from a generator seeded with `seed`.
        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            let Some(case) = bench_case else {
                return StdRng::from_seed(seed).random();
            };

            match case {
                BenchmarkCase::CommonCase => 0x1282,
                BenchmarkCase::WorstCase => 0x123456789abcdef,
            }
        }

        /// Operands at and around the limb boundaries of a `u128`.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            vec![
                0,
                1,
                8,
                u128::from(u32::MAX),
                u128::from(u64::MAX),
                u128::MAX,
            ]
        }
    }

    /// Checks the snippet against its rust shadow for every legal shift
    /// amount.
    #[test]
    fn rust_shadow() {
        // Const generics cannot be iterated over at runtime, so every
        // instantiation is spelled out through a macro.
        macro_rules! check_shift_amounts {
            ($($shift:expr),* $(,)?) => {
                $(
                    ShadowedClosure::new(ShiftLeftStatic::<$shift>).test();
                )*
            };
        }

        check_shift_amounts!(
            0, 1, 2, 3, 4, 5, 6, 7,
            8, 9, 10, 11, 12, 13, 14, 15,
            16, 17, 18, 19, 20, 21, 22, 23,
            24, 25, 26, 27, 28, 29, 30, 31,
            32,
        );
    }

    /// A shift amount above the supported maximum of 32 must be rejected.
    #[test]
    #[should_panic]
    fn shift_beyond_limit() {
        let out_of_range = ShiftLeftStatic::<33>;
        ShadowedClosure::new(out_of_range).test();
    }
}

#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet for a shift amount of 5.
    #[test]
    fn benchmark() {
        let snippet = ShiftLeftStatic::<5>;
        ShadowedClosure::new(snippet).bench();
    }
}