// cairo_vm/hint_processor/builtin_hint_processor/field_arithmetic.rs

use crate::Felt252;
use num_bigint::{BigUint, ToBigInt};
use num_integer::Integer;
use num_traits::{One, Zero};

use super::hint_utils::insert_value_from_var_name;
use super::secp::bigint_utils::Uint384;
use super::uint256_utils::Uint256;
use crate::math_utils::{is_quad_residue, mul_inv, sqrt_prime_power};
use crate::serde::deserialize_program::ApTracking;
use crate::stdlib::{collections::HashMap, prelude::*};
use crate::types::errors::math_errors::MathError;
use crate::vm::errors::hint_errors::HintError;
use crate::{
    hint_processor::hint_processor_definition::HintReference, vm::vm_core::VirtualMachine,
};

18/* Implements Hint:
19%{
20    from starkware.python.math_utils import is_quad_residue, sqrt
21
22    def split(num: int, num_bits_shift: int = 128, length: int = 3):
23        a = []
24        for _ in range(length):
25            a.append( num & ((1 << num_bits_shift) - 1) )
26            num = num >> num_bits_shift
27        return tuple(a)
28
29    def pack(z, num_bits_shift: int = 128) -> int:
30        limbs = (z.d0, z.d1, z.d2)
31        return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
32
33
34    generator = pack(ids.generator)
35    x = pack(ids.x)
36    p = pack(ids.p)
37
38    success_x = is_quad_residue(x, p)
39    root_x = sqrt(x, p) if success_x else None
40
41    success_gx = is_quad_residue(generator*x, p)
42    root_gx = sqrt(generator*x, p) if success_gx else None
43
44    # Check that one is 0 and the other is 1
45    if x != 0:
46        assert success_x + success_gx ==1
47
48    # `None` means that no root was found, but we need to transform these into a felt no matter what
49    if root_x == None:
50        root_x = 0
51    if root_gx == None:
52        root_gx = 0
53    ids.success_x = int(success_x)
54    ids.success_gx = int(success_gx)
55    split_root_x = split(root_x)
56    split_root_gx = split(root_gx)
57    ids.sqrt_x.d0 = split_root_x[0]
58    ids.sqrt_x.d1 = split_root_x[1]
59    ids.sqrt_x.d2 = split_root_x[2]
60    ids.sqrt_gx.d0 = split_root_gx[0]
61    ids.sqrt_gx.d1 = split_root_gx[1]
62    ids.sqrt_gx.d2 = split_root_gx[2]
63%}
64*/
65pub fn u384_get_square_root(
66    vm: &mut VirtualMachine,
67    ids_data: &HashMap<String, HintReference>,
68    ap_tracking: &ApTracking,
69) -> Result<(), HintError> {
70    let generator = Uint384::from_var_name("generator", vm, ids_data, ap_tracking)?.pack();
71    let x = Uint384::from_var_name("x", vm, ids_data, ap_tracking)?.pack();
72    let p = Uint384::from_var_name("p", vm, ids_data, ap_tracking)?.pack();
73    let success_x = is_quad_residue(&x, &p)?;
74
75    let root_x = if success_x {
76        sqrt_prime_power(&x, &p).unwrap_or_default()
77    } else {
78        BigUint::zero()
79    };
80
81    let gx = generator * &x;
82    let success_gx = is_quad_residue(&gx, &p)?;
83
84    let root_gx = if success_gx {
85        sqrt_prime_power(&gx, &p).unwrap_or_default()
86    } else {
87        BigUint::zero()
88    };
89
90    if !&x.is_zero() && !(success_x ^ success_gx) {
91        return Err(HintError::AssertionFailed(
92            "assert success_x + success_gx ==1"
93                .to_string()
94                .into_boxed_str(),
95        ));
96    }
97    insert_value_from_var_name(
98        "success_x",
99        Felt252::from(success_x as u8),
100        vm,
101        ids_data,
102        ap_tracking,
103    )?;
104    insert_value_from_var_name(
105        "success_gx",
106        Felt252::from(success_gx as u8),
107        vm,
108        ids_data,
109        ap_tracking,
110    )?;
111    Uint384::split(&root_x).insert_from_var_name("sqrt_x", vm, ids_data, ap_tracking)?;
112    Uint384::split(&root_gx).insert_from_var_name("sqrt_gx", vm, ids_data, ap_tracking)?;
113    Ok(())
114}
115
116/* Implements Hint:
117%{
118    from starkware.python.math_utils import is_quad_residue, sqrt
119
120    def split(a: int):
121        return (a & ((1 << 128) - 1), a >> 128)
122
123    def pack(z) -> int:
124        return z.low + (z.high << 128)
125
126    generator = pack(ids.generator)
127    x = pack(ids.x)
128    p = pack(ids.p)
129
130    success_x = is_quad_residue(x, p)
131    root_x = sqrt(x, p) if success_x else None
132    success_gx = is_quad_residue(generator*x, p)
133    root_gx = sqrt(generator*x, p) if success_gx else None
134
135    # Check that one is 0 and the other is 1
136    if x != 0:
137        assert success_x + success_gx == 1
138
139    # `None` means that no root was found, but we need to transform these into a felt no matter what
140    if root_x == None:
141        root_x = 0
142    if root_gx == None:
143        root_gx = 0
144    ids.success_x = int(success_x)
145    ids.success_gx = int(success_gx)
146    split_root_x = split(root_x)
147    # print('split root x', split_root_x)
148    split_root_gx = split(root_gx)
149    ids.sqrt_x.low = split_root_x[0]
150    ids.sqrt_x.high = split_root_x[1]
151    ids.sqrt_gx.low = split_root_gx[0]
152    ids.sqrt_gx.high = split_root_gx[1]
153%}
154*/
155// TODO: extract UintNNN methods to a trait, and use generics
156//  to merge this with u384_get_square_root
157pub fn u256_get_square_root(
158    vm: &mut VirtualMachine,
159    ids_data: &HashMap<String, HintReference>,
160    ap_tracking: &ApTracking,
161) -> Result<(), HintError> {
162    let generator = Uint256::from_var_name("generator", vm, ids_data, ap_tracking)?.pack();
163    let x = Uint256::from_var_name("x", vm, ids_data, ap_tracking)?.pack();
164    let p = Uint256::from_var_name("p", vm, ids_data, ap_tracking)?.pack();
165    let success_x = is_quad_residue(&x, &p)?;
166
167    let root_x = if success_x {
168        sqrt_prime_power(&x, &p).unwrap_or_default()
169    } else {
170        BigUint::zero()
171    };
172
173    let gx = generator * &x;
174    let success_gx = is_quad_residue(&gx, &p)?;
175
176    let root_gx = if success_gx {
177        sqrt_prime_power(&gx, &p).unwrap_or_default()
178    } else {
179        BigUint::zero()
180    };
181
182    if !&x.is_zero() && !(success_x ^ success_gx) {
183        return Err(HintError::AssertionFailed(
184            "assert success_x + success_gx ==1"
185                .to_string()
186                .into_boxed_str(),
187        ));
188    }
189    insert_value_from_var_name(
190        "success_x",
191        Felt252::from(success_x as u8),
192        vm,
193        ids_data,
194        ap_tracking,
195    )?;
196    insert_value_from_var_name(
197        "success_gx",
198        Felt252::from(success_gx as u8),
199        vm,
200        ids_data,
201        ap_tracking,
202    )?;
203    Uint256::split(&root_x).insert_from_var_name("sqrt_x", vm, ids_data, ap_tracking)?;
204    Uint256::split(&root_gx).insert_from_var_name("sqrt_gx", vm, ids_data, ap_tracking)?;
205    Ok(())
206}
207
208/* Implements Hint:
209 %{
210    from starkware.python.math_utils import div_mod
211
212    def split(num: int, num_bits_shift: int, length: int):
213        a = []
214        for _ in range(length):
215            a.append( num & ((1 << num_bits_shift) - 1) )
216            num = num >> num_bits_shift
217        return tuple(a)
218
219    def pack(z, num_bits_shift: int) -> int:
220        limbs = (z.d0, z.d1, z.d2)
221        return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
222
223    a = pack(ids.a, num_bits_shift = 128)
224    b = pack(ids.b, num_bits_shift = 128)
225    p = pack(ids.p, num_bits_shift = 128)
226    # For python3.8 and above the modular inverse can be computed as follows:
227    # b_inverse_mod_p = pow(b, -1, p)
228    # Instead we use the python3.7-friendly function div_mod from starkware.python.math_utils
229    b_inverse_mod_p = div_mod(1, b, p)
230
231
232    b_inverse_mod_p_split = split(b_inverse_mod_p, num_bits_shift=128, length=3)
233
234    ids.b_inverse_mod_p.d0 = b_inverse_mod_p_split[0]
235    ids.b_inverse_mod_p.d1 = b_inverse_mod_p_split[1]
236    ids.b_inverse_mod_p.d2 = b_inverse_mod_p_split[2]
237%}
238 */
239pub fn uint384_div(
240    vm: &mut VirtualMachine,
241    ids_data: &HashMap<String, HintReference>,
242    ap_tracking: &ApTracking,
243) -> Result<(), HintError> {
244    // Note: ids.a is not used here, nor is it used by following hints, so we dont need to extract it.
245    let b = Uint384::from_var_name("b", vm, ids_data, ap_tracking)?
246        .pack()
247        .to_bigint()
248        .unwrap_or_default();
249    let p = Uint384::from_var_name("p", vm, ids_data, ap_tracking)?
250        .pack()
251        .to_bigint()
252        .unwrap_or_default();
253
254    if b.is_zero() {
255        return Err(MathError::DividedByZero.into());
256    }
257    let b_inverse_mod_p = mul_inv(&b, &p)
258        .mod_floor(&p)
259        .to_biguint()
260        .unwrap_or_default();
261    let b_inverse_mod_p_split = Uint384::split(&b_inverse_mod_p);
262    b_inverse_mod_p_split.insert_from_var_name("b_inverse_mod_p", vm, ids_data, ap_tracking)
263}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hint_processor::builtin_hint_processor::hint_code;
    use crate::vm::errors::memory_errors::MemoryError;
    use crate::{
        any_box,
        hint_processor::{
            builtin_hint_processor::builtin_hint_processor_definition::{
                BuiltinHintProcessor, HintProcessorData,
            },
            hint_processor_definition::HintProcessorLogic,
        },
        utils::test_utils::*,
        vm::vm_core::VirtualMachine,
    };
    use assert_matches::assert_matches;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::*;

    /// `ids` used by the `*_GET_SQUARE_ROOT` tests, relative to fp = 14.
    ///
    /// The operands start at (1, 0) and are spaced three cells apart, so the same
    /// layout serves the `Uint384` (three limbs) and `Uint256` (two limbs) variants;
    /// the outputs are written from (1, 9) onwards.
    fn get_square_root_ids_data() -> HashMap<String, HintReference> {
        non_continuous_ids_data![
            ("p", -14),
            ("x", -11),
            ("generator", -8),
            ("sqrt_x", -5),
            ("sqrt_gx", -2),
            ("success_x", 1),
            ("success_gx", 2),
        ]
    }

    /// `ids` used by the `UINT384_DIV` tests, relative to fp = 11.
    fn uint384_div_ids_data() -> HashMap<String, HintReference> {
        non_continuous_ids_data![("a", -11), ("b", -8), ("p", -5), ("b_inverse_mod_p", -2)]
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn run_u384_get_square_ok_goldilocks_prime() {
        let mut vm = vm_with_range_check!();
        vm.run_context.fp = 14;
        let ids_data = get_square_root_ids_data();
        // x = 25 = 5^2 is a residue mod the Goldilocks prime; generator * x is not.
        vm.segments = segments![
            // p
            ((1, 0), 18446744069414584321),
            ((1, 1), 0),
            ((1, 2), 0),
            // x
            ((1, 3), 25),
            ((1, 4), 0),
            ((1, 5), 0),
            // generator
            ((1, 6), 7),
            ((1, 7), 0),
            ((1, 8), 0)
        ];
        assert_matches!(
            run_hint!(vm, ids_data, hint_code::UINT384_GET_SQUARE_ROOT),
            Ok(())
        );
        check_memory![
            vm.segments.memory,
            // sqrt_x
            ((1, 9), 5),
            ((1, 10), 0),
            ((1, 11), 0),
            // sqrt_gx
            ((1, 12), 0),
            ((1, 13), 0),
            ((1, 14), 0),
            // success_x
            ((1, 15), 1),
            // success_gx
            ((1, 16), 0)
        ];
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn run_u384_get_square_no_successes() {
        let mut vm = vm_with_range_check!();
        vm.run_context.fp = 14;
        let ids_data = get_square_root_ids_data();
        // 17 ≡ 2 (mod 3) is a non-residue and, with generator = 1, so is generator * x.
        vm.segments = segments![
            // p
            ((1, 0), 3),
            ((1, 1), 0),
            ((1, 2), 0),
            // x
            ((1, 3), 17),
            ((1, 4), 0),
            ((1, 5), 0),
            // generator
            ((1, 6), 1),
            ((1, 7), 0),
            ((1, 8), 0)
        ];
        assert_matches!(run_hint!(vm, ids_data, hint_code::UINT384_GET_SQUARE_ROOT),
            Err(HintError::AssertionFailed(bx)) if bx.as_ref() == "assert success_x + success_gx ==1"
        );
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn run_u384_get_square_ok_success_gx() {
        let mut vm = vm_with_range_check!();
        vm.run_context.fp = 14;
        let ids_data = get_square_root_ids_data();
        // 71 * 17 = 1207 ≡ 1 (mod 3): only generator * x is a residue, with root 1.
        vm.segments = segments![
            // p
            ((1, 0), 3),
            ((1, 1), 0),
            ((1, 2), 0),
            // x
            ((1, 3), 17),
            ((1, 4), 0),
            ((1, 5), 0),
            // generator
            ((1, 6), 71),
            ((1, 7), 0),
            ((1, 8), 0),
        ];
        assert_matches!(
            run_hint!(vm, ids_data, hint_code::UINT384_GET_SQUARE_ROOT),
            Ok(())
        );
        check_memory![
            vm.segments.memory,
            // sqrt_x
            ((1, 9), 0),
            ((1, 10), 0),
            ((1, 11), 0),
            // sqrt_gx
            ((1, 12), 1),
            ((1, 13), 0),
            ((1, 14), 0),
            // success_x
            ((1, 15), 0),
            // success_gx
            ((1, 16), 1),
        ];
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn run_u256_get_square_ok_goldilocks_prime() {
        let mut vm = vm_with_range_check!();
        vm.run_context.fp = 14;
        let ids_data = get_square_root_ids_data();
        // Same values as the u384 Goldilocks test, stored as Uint256 (low, high).
        vm.segments = segments![
            // p
            ((1, 0), 18446744069414584321),
            ((1, 1), 0),
            // x
            ((1, 3), 25),
            ((1, 4), 0),
            // generator
            ((1, 6), 7),
            ((1, 7), 0),
        ];
        assert_matches!(
            run_hint!(vm, ids_data, hint_code::UINT256_GET_SQUARE_ROOT),
            Ok(())
        );
        check_memory![
            vm.segments.memory,
            // sqrt_x
            ((1, 9), 5),
            ((1, 10), 0),
            // sqrt_gx
            ((1, 12), 0),
            ((1, 13), 0),
            // success_x
            ((1, 15), 1),
            // success_gx
            ((1, 16), 0),
        ];
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn run_u256_get_square_no_successes() {
        let mut vm = vm_with_range_check!();
        vm.run_context.fp = 14;
        let ids_data = get_square_root_ids_data();
        // Neither 17 nor 1 * 17 is a residue mod 3, so the hint's assertion fails.
        vm.segments = segments![
            // p
            ((1, 0), 3),
            ((1, 1), 0),
            // x
            ((1, 3), 17),
            ((1, 4), 0),
            // generator
            ((1, 6), 1),
            ((1, 7), 0),
        ];
        assert_matches!(run_hint!(vm, ids_data, hint_code::UINT256_GET_SQUARE_ROOT),
            Err(HintError::AssertionFailed(bx)) if bx.as_ref() == "assert success_x + success_gx ==1"
        );
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn run_u256_get_square_ok_success_gx() {
        let mut vm = vm_with_range_check!();
        vm.run_context.fp = 14;
        let ids_data = get_square_root_ids_data();
        // 71 * 17 ≡ 1 (mod 3): only generator * x is a residue, with root 1.
        vm.segments = segments![
            // p
            ((1, 0), 3),
            ((1, 1), 0),
            // x
            ((1, 3), 17),
            ((1, 4), 0),
            // generator
            ((1, 6), 71),
            ((1, 7), 0),
        ];
        assert_matches!(
            run_hint!(vm, ids_data, hint_code::UINT256_GET_SQUARE_ROOT),
            Ok(())
        );
        check_memory![
            vm.segments.memory,
            // sqrt_x
            ((1, 9), 0),
            ((1, 10), 0),
            // sqrt_gx
            ((1, 12), 1),
            ((1, 13), 0),
            // success_x
            ((1, 15), 0),
            // success_gx
            ((1, 16), 1),
        ];
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn run_uint384_div_ok() {
        let mut vm = vm_with_range_check!();
        vm.run_context.fp = 11;
        let ids_data = uint384_div_ids_data();
        vm.segments = segments![
            // a
            ((1, 0), 25),
            ((1, 1), 0),
            ((1, 2), 0),
            // b
            ((1, 3), 5),
            ((1, 4), 0),
            ((1, 5), 0),
            // p
            ((1, 6), 31),
            ((1, 7), 0),
            ((1, 8), 0)
        ];
        assert_matches!(run_hint!(vm, ids_data, hint_code::UINT384_DIV), Ok(()));
        // 5 * 25 = 125 = 4 * 31 + 1, so 25 is the inverse of 5 modulo 31.
        check_memory![
            vm.segments.memory,
            // b_inverse_mod_p
            ((1, 9), 25),
            ((1, 10), 0),
            ((1, 11), 0)
        ];
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn run_uint384_div_b_is_zero() {
        let mut vm = vm_with_range_check!();
        vm.run_context.fp = 11;
        let ids_data = uint384_div_ids_data();
        vm.segments = segments![
            // a
            ((1, 0), 25),
            ((1, 1), 0),
            ((1, 2), 0),
            // b
            ((1, 3), 0),
            ((1, 4), 0),
            ((1, 5), 0),
            // p
            ((1, 6), 31),
            ((1, 7), 0),
            ((1, 8), 0)
        ];
        assert_matches!(
            run_hint!(vm, ids_data, hint_code::UINT384_DIV),
            Err(HintError::Math(MathError::DividedByZero))
        );
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn run_uint384_div_inconsistent_memory() {
        let mut vm = vm_with_range_check!();
        vm.run_context.fp = 11;
        let ids_data = uint384_div_ids_data();
        vm.segments = segments![
            // a
            ((1, 0), 25),
            ((1, 1), 0),
            ((1, 2), 0),
            // b
            ((1, 3), 5),
            ((1, 4), 0),
            ((1, 5), 0),
            // p
            ((1, 6), 31),
            ((1, 7), 0),
            ((1, 8), 0),
            // b_inverse_mod_p.d0 is pre-filled with a value that conflicts with 25
            ((1, 9), 0)
        ];
        assert_matches!(
            run_hint!(vm, ids_data, hint_code::UINT384_DIV),
            Err(HintError::Memory(MemoryError::InconsistentMemory(_)))
        );
    }
}