cairo_vm/hint_processor/builtin_hint_processor/
keccak_utils.rs

1use crate::stdlib::{boxed::Box, cmp, collections::HashMap, prelude::*};
2
3use crate::types::errors::math_errors::MathError;
4use crate::Felt252;
5use crate::{
6    hint_processor::{
7        builtin_hint_processor::hint_utils::{
8            get_integer_from_var_name, get_ptr_from_var_name, get_relocatable_from_var_name,
9        },
10        hint_processor_definition::HintReference,
11    },
12    math_utils::pow2_const_nz,
13    serde::deserialize_program::ApTracking,
14    types::{exec_scope::ExecutionScopes, relocatable::Relocatable},
15    vm::{errors::hint_errors::HintError, vm_core::VirtualMachine},
16};
17use num_integer::Integer;
18use num_traits::ToPrimitive;
19use sha3::{Digest, Keccak256};
20
21use super::hint_utils::insert_value_from_var_name;
22
/// Fully qualified identifier used as the key when looking up the `BYTES_IN_WORD`
/// constant in the hint constants map.
const BYTES_IN_WORD: &str = "starkware.cairo.common.builtin_keccak.keccak.BYTES_IN_WORD";
24
25/* Implements hint:
26   %{
27       from eth_hash.auto import keccak
28
29       data, length = ids.data, ids.length
30
31       if '__keccak_max_size' in globals():
32           assert length <= __keccak_max_size, \
33               f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \
34               f'Got: length={length}.'
35
36       keccak_input = bytearray()
37       for word_i, byte_i in enumerate(range(0, length, 16)):
38           word = memory[data + word_i]
39           n_bytes = min(16, length - byte_i)
40           assert 0 <= word < 2 ** (8 * n_bytes)
41           keccak_input += word.to_bytes(n_bytes, 'big')
42
43       hashed = keccak(keccak_input)
44       ids.high = int.from_bytes(hashed[:16], 'big')
45       ids.low = int.from_bytes(hashed[16:32], 'big')
46   %}
47*/
48pub fn unsafe_keccak(
49    vm: &mut VirtualMachine,
50    exec_scopes: &mut ExecutionScopes,
51    ids_data: &HashMap<String, HintReference>,
52    ap_tracking: &ApTracking,
53) -> Result<(), HintError> {
54    let length = get_integer_from_var_name("length", vm, ids_data, ap_tracking)?;
55
56    if let Ok(keccak_max_size) = exec_scopes.get::<Felt252>("__keccak_max_size") {
57        if length.as_ref() > &keccak_max_size {
58            return Err(HintError::KeccakMaxSize(Box::new((
59                length,
60                keccak_max_size,
61            ))));
62        }
63    }
64
65    // `data` is an array, represented by a pointer to the first element.
66    let data = get_ptr_from_var_name("data", vm, ids_data, ap_tracking)?;
67
68    let high_addr = get_relocatable_from_var_name("high", vm, ids_data, ap_tracking)?;
69    let low_addr = get_relocatable_from_var_name("low", vm, ids_data, ap_tracking)?;
70
71    // transform to u64 to make ranges cleaner in the for loop below
72    let u64_length = length
73        .to_u64()
74        .ok_or_else(|| HintError::InvalidKeccakInputLength(Box::new(length)))?;
75
76    const ZEROES: [u8; 32] = [0u8; 32];
77    let mut keccak_input = Vec::new();
78    for (word_i, byte_i) in (0..u64_length).step_by(16).enumerate() {
79        let word_addr = Relocatable {
80            segment_index: data.segment_index,
81            offset: data.offset + word_i,
82        };
83
84        let word = vm.get_integer(word_addr)?;
85        let bytes = word.to_bytes_be();
86        let n_bytes = cmp::min(16, u64_length - byte_i);
87        let start = 32 - n_bytes as usize;
88
89        // word <= 2^(8 * n_bytes) <=> `start` leading zeroes
90        if !ZEROES.starts_with(&bytes[..start]) {
91            return Err(HintError::InvalidWordSize(Box::new(word.into_owned())));
92        }
93
94        keccak_input.extend_from_slice(&bytes[start..]);
95    }
96
97    let mut hasher = Keccak256::new();
98    hasher.update(keccak_input);
99
100    let hashed = hasher.finalize();
101
102    let mut high_bytes = [0; 16].to_vec();
103    let mut low_bytes = [0; 16].to_vec();
104    high_bytes.extend_from_slice(&hashed[0..16]);
105    low_bytes.extend_from_slice(&hashed[16..32]);
106
107    let high = Felt252::from_bytes_be_slice(&high_bytes);
108    let low = Felt252::from_bytes_be_slice(&low_bytes);
109
110    vm.insert_value(high_addr, high)?;
111    vm.insert_value(low_addr, low)?;
112    Ok(())
113}
114
115/*
116Implements hint:
117
118    %{
119        from eth_hash.auto import keccak
120        keccak_input = bytearray()
121        n_elms = ids.keccak_state.end_ptr - ids.keccak_state.start_ptr
122        for word in memory.get_range(ids.keccak_state.start_ptr, n_elms):
123            keccak_input += word.to_bytes(16, 'big')
124        hashed = keccak(keccak_input)
125        ids.high = int.from_bytes(hashed[:16], 'big')
126        ids.low = int.from_bytes(hashed[16:32], 'big')
127    %}
128
129 */
130pub fn unsafe_keccak_finalize(
131    vm: &mut VirtualMachine,
132    ids_data: &HashMap<String, HintReference>,
133    ap_tracking: &ApTracking,
134) -> Result<(), HintError> {
135    /* -----------------------------
136    Just for reference (cairo code):
137    struct KeccakState:
138        member start_ptr : felt*
139        member end_ptr : felt*
140    end
141    ----------------------------- */
142
143    let keccak_state_ptr =
144        get_relocatable_from_var_name("keccak_state", vm, ids_data, ap_tracking)?;
145
146    // as `keccak_state` is a struct, the pointer to the struct is the same as the pointer to the first element.
147    // this is why to get the pointer stored in the field `start_ptr` it is enough to pass the variable name as
148    // `keccak_state`, which is the one that appears in the reference manager of the compiled JSON.
149    let start_ptr = get_ptr_from_var_name("keccak_state", vm, ids_data, ap_tracking)?;
150
151    // in the KeccakState struct, the field `end_ptr` is the second one, so this variable should be get from
152    // the memory cell contiguous to the one where KeccakState is pointing to.
153    let end_ptr = vm.get_relocatable(Relocatable {
154        segment_index: keccak_state_ptr.segment_index,
155        offset: keccak_state_ptr.offset + 1,
156    })?;
157
158    let n_elems = (end_ptr - start_ptr)?;
159
160    let mut keccak_input = Vec::new();
161    let range = vm.get_integer_range(start_ptr, n_elems)?;
162
163    for word in range.into_iter() {
164        keccak_input.extend_from_slice(&word.to_bytes_be()[16..]);
165    }
166
167    let mut hasher = Keccak256::new();
168    hasher.update(keccak_input);
169
170    let hashed = hasher.finalize();
171
172    let mut high_bytes = [0; 16].to_vec();
173    let mut low_bytes = [0; 16].to_vec();
174    high_bytes.extend_from_slice(&hashed[0..16]);
175    low_bytes.extend_from_slice(&hashed[16..32]);
176
177    let high_addr = get_relocatable_from_var_name("high", vm, ids_data, ap_tracking)?;
178    let low_addr = get_relocatable_from_var_name("low", vm, ids_data, ap_tracking)?;
179
180    let high = Felt252::from_bytes_be_slice(&high_bytes);
181    let low = Felt252::from_bytes_be_slice(&low_bytes);
182
183    vm.insert_value(high_addr, high)?;
184    vm.insert_value(low_addr, low)?;
185    Ok(())
186}
187
188// Implements hints of type : ids.output{num}_low = ids.output{num} & ((1 << 128) - 1)
189// ids.output{num}_high = ids.output{num} >> 128
190pub fn split_output(
191    vm: &mut VirtualMachine,
192    ids_data: &HashMap<String, HintReference>,
193    ap_tracking: &ApTracking,
194    num: u32,
195) -> Result<(), HintError> {
196    let output_name = format!("output{}", num);
197    let output = get_integer_from_var_name(&output_name, vm, ids_data, ap_tracking)?;
198    let (high, low) = output.div_rem(pow2_const_nz(128));
199    insert_value_from_var_name(
200        &format!("output{}_high", num),
201        high,
202        vm,
203        ids_data,
204        ap_tracking,
205    )?;
206    insert_value_from_var_name(
207        &format!("output{}_low", num),
208        low,
209        vm,
210        ids_data,
211        ap_tracking,
212    )
213}
214
215// Implements hints of type: ids.high{input_key}, ids.low{input_key} = divmod(memory[ids.inputs + {input_key}], 256 ** {exponent})
216pub fn split_input(
217    vm: &mut VirtualMachine,
218    ids_data: &HashMap<String, HintReference>,
219    ap_tracking: &ApTracking,
220    input_key: usize,
221    exponent: u32,
222) -> Result<(), HintError> {
223    let inputs_ptr = get_ptr_from_var_name("inputs", vm, ids_data, ap_tracking)?;
224    let binding = vm.get_integer((inputs_ptr + input_key)?)?;
225    let split = pow2_const_nz(8 * exponent);
226    let (high, low) = binding.div_rem(split);
227    insert_value_from_var_name(
228        &format!("high{}", input_key),
229        high,
230        vm,
231        ids_data,
232        ap_tracking,
233    )?;
234    insert_value_from_var_name(&format!("low{}", input_key), low, vm, ids_data, ap_tracking)
235}
236
237// Implements hint: ids.n_words_to_copy, ids.n_bytes_left = divmod(ids.n_bytes, ids.BYTES_IN_WORD)
238pub fn split_n_bytes(
239    vm: &mut VirtualMachine,
240    ids_data: &HashMap<String, HintReference>,
241    ap_tracking: &ApTracking,
242    constants: &HashMap<String, Felt252>,
243) -> Result<(), HintError> {
244    let n_bytes =
245        get_integer_from_var_name("n_bytes", vm, ids_data, ap_tracking).and_then(|x| {
246            x.to_u64()
247                .ok_or_else(|| HintError::Math(MathError::Felt252ToU64Conversion(Box::new(x))))
248        })?;
249    let bytes_in_word = constants
250        .get(BYTES_IN_WORD)
251        .and_then(|x| x.to_u64())
252        .ok_or_else(|| HintError::MissingConstant(Box::new(BYTES_IN_WORD)))?;
253    let (high, low) = n_bytes.div_mod_floor(&bytes_in_word);
254    insert_value_from_var_name(
255        "n_words_to_copy",
256        Felt252::from(high),
257        vm,
258        ids_data,
259        ap_tracking,
260    )?;
261    insert_value_from_var_name(
262        "n_bytes_left",
263        Felt252::from(low),
264        vm,
265        ids_data,
266        ap_tracking,
267    )
268}
269
270// Implements hint:
271// tmp, ids.output1_low = divmod(ids.output1, 256 ** 7)
272// ids.output1_high, ids.output1_mid = divmod(tmp, 2 ** 128)
273pub fn split_output_mid_low_high(
274    vm: &mut VirtualMachine,
275    ids_data: &HashMap<String, HintReference>,
276    ap_tracking: &ApTracking,
277) -> Result<(), HintError> {
278    let binding = get_integer_from_var_name("output1", vm, ids_data, ap_tracking)?;
279    let output1 = binding.as_ref();
280    let (tmp, output1_low) = output1.div_rem(pow2_const_nz(8 * 7));
281    let (output1_high, output1_mid) = tmp.div_rem(pow2_const_nz(128));
282    insert_value_from_var_name("output1_high", output1_high, vm, ids_data, ap_tracking)?;
283    insert_value_from_var_name("output1_mid", output1_mid, vm, ids_data, ap_tracking)?;
284    insert_value_from_var_name("output1_low", output1_low, vm, ids_data, ap_tracking)
285}
286
#[cfg(test)]
mod tests {
    //! Unit tests for the keccak splitting hints. Each test seeds VM memory, runs one hint
    //! through the builtin hint processor, and checks the memory cells the hint wrote.

    use super::*;
    use crate::any_box;
    use crate::{
        hint_processor::{
            builtin_hint_processor::{
                builtin_hint_processor_definition::{BuiltinHintProcessor, HintProcessorData},
                hint_code,
                keccak_utils::HashMap,
            },
            hint_processor_definition::{HintProcessorLogic, HintReference},
        },
        utils::test_utils::*,
    };
    use assert_matches::assert_matches;

    /// 24 is below 2^128, so the high part is 0 and the low part is 24.
    #[test]
    fn split_output_0() {
        let mut vm = vm!();
        vm.segments = segments![((1, 0), 24)];
        vm.set_fp(3);
        let ids = ids_data!["output0", "output0_high", "output0_low"];
        assert_matches!(run_hint!(vm, ids, hint_code::SPLIT_OUTPUT_0), Ok(()));
        check_memory!(vm.segments.memory, ((1, 1), 0), ((1, 2), 24));
    }

    /// Same scenario as `split_output_0`, for the `output1` variables.
    #[test]
    fn split_output_1() {
        let mut vm = vm!();
        vm.segments = segments![((1, 0), 24)];
        vm.set_fp(3);
        let ids = ids_data!["output1", "output1_high", "output1_low"];
        assert_matches!(run_hint!(vm, ids, hint_code::SPLIT_OUTPUT_1), Ok(()));
        check_memory!(vm.segments.memory, ((1, 1), 0), ((1, 2), 24));
    }

    /// The expected values match divmod(300, 256) == (1, 44).
    #[test]
    fn split_input_3() {
        let mut vm = vm!();
        vm.segments = segments![((1, 2), (2, 0)), ((2, 3), 300)];
        vm.set_fp(3);
        let ids = ids_data!["high3", "low3", "inputs"];
        assert_matches!(run_hint!(vm, ids, hint_code::SPLIT_INPUT_3), Ok(()));
        check_memory!(vm.segments.memory, ((1, 0), 1), ((1, 1), 44));
    }

    /// The expected values match divmod(66036, 65536) == (1, 500).
    #[test]
    fn split_input_6() {
        let mut vm = vm!();
        vm.segments = segments![((1, 2), (2, 0)), ((2, 6), 66036)];
        vm.set_fp(3);
        let ids = ids_data!["high6", "low6", "inputs"];
        assert_matches!(run_hint!(vm, ids, hint_code::SPLIT_INPUT_6), Ok(()));
        check_memory!(vm.segments.memory, ((1, 0), 1), ((1, 1), 500));
    }

    /// The input is smaller than the divisor, so the high part is 0 and the low part is the
    /// input itself.
    #[test]
    fn split_input_15() {
        let mut vm = vm!();
        vm.segments = segments![((1, 2), (2, 0)), ((2, 15), 15150315)];
        vm.set_fp(3);
        let ids = ids_data!["high15", "low15", "inputs"];
        assert_matches!(run_hint!(vm, ids, hint_code::SPLIT_INPUT_15), Ok(()));
        check_memory!(vm.segments.memory, ((1, 0), 0), ((1, 1), 15150315));
    }

    /// With BYTES_IN_WORD == 8, divmod(17, 8) == (2, 1).
    #[test]
    fn split_n_bytes() {
        let mut vm = vm!();
        vm.segments = segments![((1, 2), 17)];
        vm.set_fp(3);
        let ids = ids_data!["n_words_to_copy", "n_bytes_left", "n_bytes"];
        assert_matches!(
            run_hint!(
                vm,
                ids,
                hint_code::SPLIT_N_BYTES,
                exec_scopes_ref!(),
                &HashMap::from([(String::from(BYTES_IN_WORD), Felt252::from(8))])
            ),
            Ok(())
        );
        check_memory!(vm.segments.memory, ((1, 0), 2), ((1, 1), 1));
    }

    /// 72057594037927938 == 2^56 + 2, so low == 2, mid == 1 and high == 0.
    #[test]
    fn split_output_mid_low_high() {
        let mut vm = vm!();
        vm.segments = segments![((1, 0), 72057594037927938)];
        vm.set_fp(4);
        let ids = ids_data!["output1", "output1_low", "output1_mid", "output1_high"];
        assert_matches!(
            run_hint!(
                vm,
                ids,
                hint_code::SPLIT_OUTPUT_MID_LOW_HIGH,
                exec_scopes_ref!(),
                &HashMap::from([(String::from(BYTES_IN_WORD), Felt252::from(8))])
            ),
            Ok(())
        );
        check_memory!(vm.segments.memory, ((1, 1), 2), ((1, 2), 1), ((1, 3), 0));
    }
}