// cairo_lang_sierra_to_casm/invocations/mod.rs

1use assert_matches::assert_matches;
2use cairo_lang_casm::ap_change::ApChange;
3use cairo_lang_casm::builder::{CasmBuildResult, CasmBuilder, Var};
4use cairo_lang_casm::cell_expression::CellExpression;
5use cairo_lang_casm::instructions::Instruction;
6use cairo_lang_casm::operand::{CellRef, Register};
7use cairo_lang_sierra::extensions::circuit::CircuitInfo;
8use cairo_lang_sierra::extensions::core::CoreConcreteLibfunc::{self, *};
9use cairo_lang_sierra::extensions::coupon::CouponConcreteLibfunc;
10use cairo_lang_sierra::extensions::gas::CostTokenType;
11use cairo_lang_sierra::extensions::lib_func::{BranchSignature, OutputVarInfo, SierraApChange};
12use cairo_lang_sierra::extensions::{ConcreteLibfunc, OutputVarReferenceInfo};
13use cairo_lang_sierra::ids::ConcreteTypeId;
14use cairo_lang_sierra::program::{BranchInfo, BranchTarget, Invocation, StatementIdx};
15use cairo_lang_sierra_ap_change::core_libfunc_ap_change::{
16    InvocationApChangeInfoProvider, core_libfunc_ap_change,
17};
18use cairo_lang_sierra_gas::core_libfunc_cost::{InvocationCostInfoProvider, core_libfunc_cost};
19use cairo_lang_sierra_gas::objects::ConstCost;
20use cairo_lang_sierra_type_size::TypeSizeMap;
21use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
22use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
23use itertools::{Itertools, chain, zip_eq};
24use num_bigint::BigInt;
25use thiserror::Error;
26
27use crate::circuit::CircuitsInfo;
28use crate::environment::Environment;
29use crate::environment::frame_state::{FrameState, FrameStateError};
30use crate::metadata::Metadata;
31use crate::references::{
32    OutputReferenceValue, OutputReferenceValueIntroductionPoint, ReferenceExpression,
33    ReferenceValue,
34};
35use crate::relocations::{InstructionsWithRelocations, Relocation, RelocationEntry};
36
37mod array;
38mod bitwise;
39mod blake;
40mod boolean;
41mod boxing;
42mod bytes31;
43mod casts;
44mod circuit;
45mod const_type;
46mod debug;
47mod ec;
48pub mod enm;
49mod felt252;
50mod felt252_dict;
51mod function_call;
52mod gas;
53mod int;
54mod mem;
55mod misc;
56mod nullable;
57mod pedersen;
58mod poseidon;
59mod range;
60mod range_reduction;
61mod squashed_felt252_dict;
62mod starknet;
63mod structure;
64mod trace;
65
66#[cfg(test)]
67mod test_utils;
68
69#[derive(Error, Debug, Eq, PartialEq)]
70pub enum InvocationError {
71    #[error("One of the arguments does not satisfy the requirements of the libfunc.")]
72    InvalidReferenceExpressionForArgument,
73    #[error("Unexpected error - an unregistered type id used.")]
74    UnknownTypeId(ConcreteTypeId),
75    #[error("Expected a different number of arguments.")]
76    WrongNumberOfArguments { expected: usize, actual: usize },
77    #[error("The requested functionality is not implemented yet.")]
78    NotImplemented(Invocation),
79    #[error("The requested functionality is not implemented yet: {message}")]
80    NotImplementedStr { invocation: Invocation, message: String },
81    #[error("The functionality is supported only for sized types.")]
82    NotSized(Invocation),
83    #[error("Expected type data not found.")]
84    UnknownTypeData,
85    #[error("Expected variable data for statement not found.")]
86    UnknownVariableData,
87    #[error("An integer overflow occurred.")]
88    InvalidGenericArg,
89    #[error("Invalid generic argument for libfunc.")]
90    IntegerOverflow,
91    #[error(transparent)]
92    FrameStateError(#[from] FrameStateError),
93    // TODO(lior): Remove this error once not used.
94    #[error("This libfunc does not support pre-cost metadata yet.")]
95    PreCostMetadataNotSupported,
96    #[error("{output_ty} is not a contained in the circuit {circuit_ty}.")]
97    InvalidCircuitOutput { output_ty: ConcreteTypeId, circuit_ty: ConcreteTypeId },
98}
99
/// A simple change applied to the ap-tracking state itself (as opposed to the ap value).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ApTrackingChange {
    /// Turn ap tracking on, if it is not already on.
    Enable,
    /// Turn ap tracking off.
    Disable,
    /// Leave ap tracking as it is.
    None,
}
110
/// Describes the changes to the set of references at a single branch target, as well as changes to
/// the environment.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BranchChanges {
    /// New references defined at a given branch.
    /// Should correspond to `BranchInfo.results` (one entry per output variable of the branch).
    pub refs: Vec<OutputReferenceValue>,
    /// The change to AP caused by the libfunc in the branch.
    pub ap_change: ApChange,
    /// A change to the ap tracking status.
    pub ap_tracking_change: ApTrackingChange,
    /// The change to the remaining gas value in the wallet, per cost token type.
    pub gas_change: OrderedHashMap<CostTokenType, i64>,
    /// Should the stack be cleared due to a gap between stack items.
    /// Set whenever the branch's Sierra ap change is not `Known { new_vars_only: true }`.
    pub clear_old_stack: bool,
    /// The expected size of the known stack after the change.
    pub new_stack_size: usize,
}
129impl BranchChanges {
130    /// Creates a `BranchChanges` object.
131    /// `param_ref` is used to fetch the reference value of a param of the libfunc.
132    fn new<'a, ParamRef: Fn(usize) -> &'a ReferenceValue>(
133        ap_change: ApChange,
134        ap_tracking_change: ApTrackingChange,
135        gas_change: OrderedHashMap<CostTokenType, i64>,
136        expressions: impl ExactSizeIterator<Item = ReferenceExpression>,
137        branch_signature: &BranchSignature,
138        prev_env: &Environment,
139        param_ref: ParamRef,
140    ) -> Self {
141        assert_eq!(
142            expressions.len(),
143            branch_signature.vars.len(),
144            "The number of expressions does not match the number of expected results in the \
145             branch."
146        );
147        let clear_old_stack =
148            !matches!(&branch_signature.ap_change, SierraApChange::Known { new_vars_only: true });
149        let stack_base = if clear_old_stack { 0 } else { prev_env.stack_size };
150        let mut new_stack_size = stack_base;
151
152        let refs: Vec<_> = zip_eq(expressions, &branch_signature.vars)
153            .enumerate()
154            .map(|(output_idx, (expression, OutputVarInfo { ref_info, ty }))| {
155                validate_output_var_refs(ref_info, &expression);
156                let stack_idx =
157                    calc_output_var_stack_idx(ref_info, stack_base, clear_old_stack, &param_ref);
158                if let Some(stack_idx) = stack_idx {
159                    new_stack_size = new_stack_size.max(stack_idx + 1);
160                }
161                let introduction_point =
162                    if let OutputVarReferenceInfo::SameAsParam { param_idx } = ref_info {
163                        OutputReferenceValueIntroductionPoint::Existing(
164                            param_ref(*param_idx).introduction_point.clone(),
165                        )
166                    } else {
167                        // Marking the statement as unknown to be fixed later.
168                        OutputReferenceValueIntroductionPoint::New(output_idx)
169                    };
170                OutputReferenceValue { expression, ty: ty.clone(), stack_idx, introduction_point }
171            })
172            .collect();
173        validate_stack_top(ap_change, branch_signature, &refs);
174        Self { refs, ap_change, ap_tracking_change, gas_change, clear_old_stack, new_stack_size }
175    }
176}
177
178/// Validates that a new temp or local var have valid references in their matching expression.
179fn validate_output_var_refs(ref_info: &OutputVarReferenceInfo, expression: &ReferenceExpression) {
180    match ref_info {
181        OutputVarReferenceInfo::SameAsParam { .. } => {}
182        _ if expression.cells.is_empty() => {
183            assert_matches!(ref_info, OutputVarReferenceInfo::ZeroSized);
184        }
185        OutputVarReferenceInfo::ZeroSized => {
186            unreachable!("Non empty ReferenceExpression for zero sized variable.")
187        }
188        OutputVarReferenceInfo::NewTempVar { .. } => {
189            expression.cells.iter().for_each(|cell| {
190                assert_matches!(cell, CellExpression::Deref(CellRef { register: Register::AP, .. }))
191            });
192        }
193        OutputVarReferenceInfo::NewLocalVar => {
194            expression.cells.iter().for_each(|cell| {
195                assert_matches!(cell, CellExpression::Deref(CellRef { register: Register::FP, .. }))
196            });
197        }
198        OutputVarReferenceInfo::SimpleDerefs => {
199            expression
200                .cells
201                .iter()
202                .for_each(|cell| assert_matches!(cell, CellExpression::Deref(_)));
203        }
204        OutputVarReferenceInfo::PartialParam { .. } | OutputVarReferenceInfo::Deferred(_) => {}
205    };
206}
207
208/// Validates that the variables that are now on the top of the stack are contiguous and that if the
209/// stack was not broken the size of all the variables is consistent with the ap change.
210fn validate_stack_top(
211    ap_change: ApChange,
212    branch_signature: &BranchSignature,
213    refs: &[OutputReferenceValue],
214) {
215    // A mapping for the new temp vars allocated on the top of the stack from their index on the
216    // top of the stack to their index in the `refs` vector.
217    let stack_top_vars = UnorderedHashMap::<usize, usize>::from_iter(
218        branch_signature.vars.iter().enumerate().filter_map(|(arg_idx, var)| {
219            if let OutputVarReferenceInfo::NewTempVar { idx: stack_idx } = var.ref_info {
220                Some((stack_idx, arg_idx))
221            } else {
222                None
223            }
224        }),
225    );
226    let mut prev_ap_offset = None;
227    let mut stack_top_size = 0;
228    for i in 0..stack_top_vars.len() {
229        let Some(arg) = stack_top_vars.get(&i) else {
230            panic!("Missing top stack var #{i} out of {}.", stack_top_vars.len());
231        };
232        let cells = &refs[*arg].expression.cells;
233        stack_top_size += cells.len();
234        for cell in cells {
235            let ap_offset = match cell {
236                CellExpression::Deref(CellRef { register: Register::AP, offset }) => *offset,
237                _ => unreachable!("Tested in `validate_output_var_refs`."),
238            };
239            if let Some(prev_ap_offset) = prev_ap_offset {
240                assert_eq!(ap_offset, prev_ap_offset + 1, "Top stack vars are not contiguous.");
241            }
242            prev_ap_offset = Some(ap_offset);
243        }
244    }
245    if matches!(branch_signature.ap_change, SierraApChange::Known { new_vars_only: true }) {
246        assert_eq!(
247            ap_change,
248            ApChange::Known(stack_top_size),
249            "New tempvar variables are not contiguous with the old stack."
250        );
251    }
252    // TODO(orizi): Add assertion for the non-new_vars_only case, that it is optimal.
253}
254
255/// Calculates the continuous stack index for an output var of a branch.
256/// `param_ref` is used to fetch the reference value of a param of the libfunc.
257fn calc_output_var_stack_idx<'a, ParamRef: Fn(usize) -> &'a ReferenceValue>(
258    ref_info: &OutputVarReferenceInfo,
259    stack_base: usize,
260    clear_old_stack: bool,
261    param_ref: &ParamRef,
262) -> Option<usize> {
263    match ref_info {
264        OutputVarReferenceInfo::NewTempVar { idx } => Some(stack_base + idx),
265        OutputVarReferenceInfo::SameAsParam { param_idx } if !clear_old_stack => {
266            param_ref(*param_idx).stack_idx
267        }
268        OutputVarReferenceInfo::SameAsParam { .. }
269        | OutputVarReferenceInfo::SimpleDerefs
270        | OutputVarReferenceInfo::NewLocalVar
271        | OutputVarReferenceInfo::PartialParam { .. }
272        | OutputVarReferenceInfo::Deferred(_)
273        | OutputVarReferenceInfo::ZeroSized => None,
274    }
275}
276
/// The result from a compilation of a single invocation statement.
#[derive(Debug)]
pub struct CompiledInvocation {
    /// A vector of instructions that implement the invocation.
    pub instructions: Vec<Instruction>,
    /// A vector of static relocations.
    pub relocations: Vec<RelocationEntry>,
    /// A vector of `BranchChanges`; should correspond to the branches of the invocation
    /// statement.
    pub results: Vec<BranchChanges>,
    /// The environment after the invocation statement.
    pub environment: Environment,
}
290
291/// Checks that the list of references is contiguous on the stack and ends at ap - 1.
292/// This is the requirement for function call and return statements.
293pub fn check_references_on_stack(refs: &[ReferenceValue]) -> Result<(), InvocationError> {
294    let mut expected_offset: i16 = -1;
295    for reference in refs.iter().rev() {
296        for cell_expr in reference.expression.cells.iter().rev() {
297            match cell_expr {
298                CellExpression::Deref(CellRef { register: Register::AP, offset })
299                    if *offset == expected_offset =>
300                {
301                    expected_offset -= 1;
302                }
303                _ => return Err(InvocationError::InvalidReferenceExpressionForArgument),
304            }
305        }
306    }
307    Ok(())
308}
309
/// The cells of a single returned Sierra variable, as casm-builder vars.
type VarCells = [Var];
/// The cells of all Sierra variables returned from a single libfunc branch (one `VarCells` per
/// variable).
type AllVars<'a> = [&'a VarCells];
314
315impl InvocationApChangeInfoProvider for CompiledInvocationBuilder<'_> {
316    fn type_size(&self, ty: &ConcreteTypeId) -> usize {
317        self.program_info.type_sizes[ty] as usize
318    }
319
320    fn token_usages(&self, token_type: CostTokenType) -> usize {
321        self.program_info
322            .metadata
323            .gas_info
324            .variable_values
325            .get(&(self.idx, token_type))
326            .copied()
327            .unwrap_or(0) as usize
328    }
329}
330
331impl InvocationCostInfoProvider for CompiledInvocationBuilder<'_> {
332    fn type_size(&self, ty: &ConcreteTypeId) -> usize {
333        self.program_info.type_sizes[ty] as usize
334    }
335
336    fn ap_change_var_value(&self) -> usize {
337        self.program_info
338            .metadata
339            .ap_change_info
340            .variable_values
341            .get(&self.idx)
342            .copied()
343            .unwrap_or_default()
344    }
345
346    fn token_usages(&self, token_type: CostTokenType) -> usize {
347        InvocationApChangeInfoProvider::token_usages(self, token_type)
348    }
349
350    fn circuit_info(&self, ty: &ConcreteTypeId) -> &CircuitInfo {
351        self.program_info.circuits_info.circuits.get(ty).unwrap()
352    }
353}
354
/// Cost validation info for a builtin.
///
/// The builtin usage of a branch is measured as the offset difference between `end` and `start`
/// in that branch's final state.
struct BuiltinInfo {
    /// The cost token type associated with the builtin.
    cost_token_ty: CostTokenType,
    /// The builtin pointer at the start of the libfunc.
    start: Var,
    /// The builtin pointer at the end of all the libfunc branches.
    end: Var,
}
364
/// Information required for validating libfunc cost.
/// `BRANCH_COUNT` is the number of branches of the libfunc.
#[derive(Default)]
struct CostValidationInfo<const BRANCH_COUNT: usize> {
    /// Infos about builtin usage.
    pub builtin_infos: Vec<BuiltinInfo>,
    /// Possible extra cost per branch, added to the measured cost before comparing it with the
    /// expected one. `None` means no extra cost in any branch.
    /// Useful for amortized costs, as well as gas withdrawal libfuncs.
    pub extra_costs: Option<[i32; BRANCH_COUNT]>,
}
374
/// Helper for building compiled invocations.
pub struct CompiledInvocationBuilder<'a> {
    /// Program-level information (metadata, type sizes, circuits, const values).
    pub program_info: ProgramInfo<'a>,
    /// The Sierra invocation statement being compiled.
    pub invocation: &'a Invocation,
    /// The concrete libfunc invoked by the statement.
    pub libfunc: &'a CoreConcreteLibfunc,
    /// The index of the invocation statement.
    pub idx: StatementIdx,
    /// The arguments of the libfunc.
    pub refs: &'a [ReferenceValue],
    /// The environment before the invocation statement.
    pub environment: Environment,
}
impl CompiledInvocationBuilder<'_> {
    /// Creates a new invocation.
    ///
    /// Assembles a `CompiledInvocation` from already generated `instructions` and `relocations`,
    /// plus `output_expressions`: one iterator of output reference expressions per branch of the
    /// libfunc. The per-branch gas and ap changes come from `core_libfunc_cost` and
    /// `core_libfunc_ap_change`.
    ///
    /// # Panics
    /// - If the number of output expressions, ap changes, or gas changes differs from the number
    ///   of branch signatures.
    /// - If a `FinalizeLocals` ap change is met while the frame state is not `Finalized`.
    fn build(
        self,
        instructions: Vec<Instruction>,
        relocations: Vec<RelocationEntry>,
        output_expressions: impl ExactSizeIterator<
            Item = impl ExactSizeIterator<Item = ReferenceExpression>,
        >,
    ) -> CompiledInvocation {
        let gas_changes =
            core_libfunc_cost(&self.program_info.metadata.gas_info, &self.idx, self.libfunc, &self);

        let branch_signatures = self.libfunc.branch_signatures();
        assert_eq!(
            branch_signatures.len(),
            output_expressions.len(),
            "The number of output expressions does not match signature."
        );
        let ap_changes = core_libfunc_ap_change(self.libfunc, &self);
        assert_eq!(
            branch_signatures.len(),
            ap_changes.len(),
            "The number of ap changes does not match signature."
        );
        assert_eq!(
            branch_signatures.len(),
            gas_changes.len(),
            "The number of gas changes does not match signature."
        );

        CompiledInvocation {
            instructions,
            relocations,
            results: zip_eq(
                zip_eq(branch_signatures, gas_changes),
                zip_eq(output_expressions, ap_changes),
            )
            .map(|((branch_signature, gas_change), (expressions, ap_change))| {
                // Enabling/disabling ap tracking is reported separately from the ap change value.
                let ap_tracking_change = match ap_change {
                    cairo_lang_sierra_ap_change::ApChange::EnableApTracking => {
                        ApTrackingChange::Enable
                    }
                    cairo_lang_sierra_ap_change::ApChange::DisableApTracking => {
                        ApTrackingChange::Disable
                    }
                    _ => ApTrackingChange::None,
                };
                // Resolve the Sierra-level ap change into a CASM `ApChange`.
                let ap_change = match ap_change {
                    cairo_lang_sierra_ap_change::ApChange::Known(x) => ApChange::Known(x),
                    cairo_lang_sierra_ap_change::ApChange::AtLocalsFinalization(_)
                    | cairo_lang_sierra_ap_change::ApChange::EnableApTracking
                    | cairo_lang_sierra_ap_change::ApChange::DisableApTracking => {
                        ApChange::Known(0)
                    }
                    cairo_lang_sierra_ap_change::ApChange::FinalizeLocals => {
                        if let FrameState::Finalized { allocated } = self.environment.frame_state {
                            ApChange::Known(allocated)
                        } else {
                            panic!("Unexpected frame state.")
                        }
                    }
                    // Unknown when the callee's ap change was not computed. The `+ 2` presumably
                    // accounts for the two cells pushed by the call itself - TODO confirm.
                    cairo_lang_sierra_ap_change::ApChange::FunctionCall(id) => self
                        .program_info
                        .metadata
                        .ap_change_info
                        .function_ap_change
                        .get(&id)
                        .map_or(ApChange::Unknown, |x| ApChange::Known(x + 2)),
                    // Taken from the metadata of this statement, defaulting to 0.
                    cairo_lang_sierra_ap_change::ApChange::FromMetadata => ApChange::Known(
                        *self
                            .program_info
                            .metadata
                            .ap_change_info
                            .variable_values
                            .get(&self.idx)
                            .unwrap_or(&0),
                    ),
                    cairo_lang_sierra_ap_change::ApChange::Unknown => ApChange::Unknown,
                };

                // Costs are negated: `BranchChanges::gas_change` is the change to the remaining
                // gas in the wallet.
                BranchChanges::new(
                    ap_change,
                    ap_tracking_change,
                    gas_change.iter().map(|(token_type, val)| (*token_type, -val)).collect(),
                    expressions,
                    branch_signature,
                    &self.environment,
                    |idx| &self.refs[idx],
                )
            })
            .collect(),
            environment: self.environment,
        }
    }

    /// Builds a `CompiledInvocation` from a casm builder and branch extractions.
    /// Per branch requires `(name, result_variables, target_statement_id)`.
    ///
    /// Equivalent to `build_from_casm_builder_ex` with no pre-instructions.
    fn build_from_casm_builder<const BRANCH_COUNT: usize>(
        self,
        casm_builder: CasmBuilder,
        branch_extractions: [(&str, &AllVars<'_>, Option<StatementIdx>); BRANCH_COUNT],
        cost_validation: CostValidationInfo<BRANCH_COUNT>,
    ) -> CompiledInvocation {
        self.build_from_casm_builder_ex(
            casm_builder,
            branch_extractions,
            cost_validation,
            Default::default(),
        )
    }

    /// Builds a `CompiledInvocation` from a casm builder and branch extractions.
    /// Per branch requires `(name, result_variables, target_statement_id)`.
    ///
    /// `pre_instructions` - Instructions to execute before the ones created by the builder.
    ///
    /// # Panics
    /// - If the ap change of any branch, as computed by the builder, differs from the one from
    ///   `core_libfunc_ap_change`.
    /// - If the measured const cost (steps, builtin usage, extra costs and the cost of
    ///   `pre_instructions`) differs from the one from `core_libfunc_cost`.
    /// - If a branch has relocations but no target statement, or a target but no relocations.
    fn build_from_casm_builder_ex<const BRANCH_COUNT: usize>(
        self,
        casm_builder: CasmBuilder,
        branch_extractions: [(&str, &AllVars<'_>, Option<StatementIdx>); BRANCH_COUNT],
        cost_validation: CostValidationInfo<BRANCH_COUNT>,
        pre_instructions: InstructionsWithRelocations,
    ) -> CompiledInvocation {
        // Each branch yields its final state and the indices of its instructions that need a
        // relocation.
        let CasmBuildResult { instructions, branches } =
            casm_builder.build(branch_extractions.map(|(name, _, _)| name));
        // Validate the builder's per-branch ap change against the declared one.
        let expected_ap_changes = core_libfunc_ap_change(self.libfunc, &self);
        let actual_ap_changes = branches
            .iter()
            .map(|(state, _)| cairo_lang_sierra_ap_change::ApChange::Known(state.ap_change));
        if !itertools::equal(expected_ap_changes.iter().cloned(), actual_ap_changes.clone()) {
            panic!(
                "Wrong ap changes for {}. Expected: {expected_ap_changes:?}, actual: {:?}.",
                self.invocation,
                actual_ap_changes.collect_vec(),
            );
        }
        // The expected const gas cost per branch.
        let gas_changes =
            core_libfunc_cost(&self.program_info.metadata.gas_info, &self.idx, self.libfunc, &self)
                .into_iter()
                .map(|costs| costs.get(&CostTokenType::Const).copied().unwrap_or_default());
        // The measured cost per branch, starting with the steps counted by the builder.
        let mut final_costs: [ConstCost; BRANCH_COUNT] =
            std::array::from_fn(|_| Default::default());
        for (cost, (state, _)) in final_costs.iter_mut().zip(branches.iter()) {
            cost.steps += state.steps as i32;
        }

        // Add the builtin usage: the distance the builtin pointer advanced in each branch.
        for BuiltinInfo { cost_token_ty, start, end } in cost_validation.builtin_infos {
            for (cost, (state, _)) in final_costs.iter_mut().zip(branches.iter()) {
                let (start_base, start_offset) =
                    state.get_adjusted(start).to_deref_with_offset().unwrap();
                let (end_base, end_offset) =
                    state.get_adjusted(end).to_deref_with_offset().unwrap();
                assert_eq!(start_base, end_base);
                let diff = end_offset - start_offset;
                // Only the range-check builtins are supported here.
                match cost_token_ty {
                    CostTokenType::RangeCheck => {
                        cost.range_checks += diff;
                    }
                    CostTokenType::RangeCheck96 => {
                        cost.range_checks96 += diff;
                    }
                    _ => panic!("Cost token type not supported."),
                }
            }
        }

        // Compare the expected costs against the measured ones, including the extra costs and
        // the cost of `pre_instructions`.
        let extra_costs =
            cost_validation.extra_costs.unwrap_or(std::array::from_fn(|_| Default::default()));
        let final_costs_with_extra =
            final_costs.iter().zip(extra_costs).map(|(final_cost, extra)| {
                (final_cost.cost() + extra + pre_instructions.cost.cost()) as i64
            });
        if !itertools::equal(gas_changes.clone(), final_costs_with_extra.clone()) {
            panic!(
                "Wrong costs for {}. Expected: {:?}, actual: {:?}, Costs from casm_builder: {:?}.",
                self.invocation,
                gas_changes.collect_vec(),
                final_costs_with_extra.collect_vec(),
                final_costs,
            );
        }
        // Builder relocation indices are shifted by the number of `pre_instructions`, which are
        // placed before the builder's instructions.
        let branch_relocations = branches.iter().zip_eq(branch_extractions.iter()).flat_map(
            |((_, relocations), (_, _, target))| {
                assert_eq!(
                    relocations.is_empty(),
                    target.is_none(),
                    "No relocations if nowhere to relocate to."
                );
                relocations.iter().map(|idx| RelocationEntry {
                    instruction_idx: pre_instructions.instructions.len() + *idx,
                    relocation: Relocation::RelativeStatementId(target.unwrap()),
                })
            },
        );
        let relocations = chain!(pre_instructions.relocations, branch_relocations).collect();
        // The outputs of each branch are its requested vars, as resolved in that branch's final
        // state by `get_adjusted`.
        let output_expressions =
            zip_eq(branches, branch_extractions).map(|((state, _), (_, vars, _))| {
                vars.iter().map(move |var_cells| ReferenceExpression {
                    cells: var_cells.iter().map(|cell| state.get_adjusted(*cell)).collect(),
                })
            });
        self.build(
            chain!(pre_instructions.instructions, instructions).collect(),
            relocations,
            output_expressions,
        )
    }

    /// Creates a new invocation with only reference changes.
    /// No instructions or relocations are emitted, and the libfunc must have a single branch.
    fn build_only_reference_changes(
        self,
        output_expressions: impl ExactSizeIterator<Item = ReferenceExpression>,
    ) -> CompiledInvocation {
        self.build(vec![], vec![], [output_expressions].into_iter())
    }

    /// Returns the reference expressions if the size is correct.
    ///
    /// # Errors
    /// `InvocationError::WrongNumberOfArguments` if the libfunc got a number of arguments other
    /// than `COUNT`.
    pub fn try_get_refs<const COUNT: usize>(
        &self,
    ) -> Result<[&ReferenceExpression; COUNT], InvocationError> {
        if self.refs.len() == COUNT {
            Ok(core::array::from_fn(|i| &self.refs[i].expression))
        } else {
            Err(InvocationError::WrongNumberOfArguments {
                expected: COUNT,
                actual: self.refs.len(),
            })
        }
    }

    /// Returns the reference expressions, assuming all contains one cell if the size is correct.
    ///
    /// # Errors
    /// The error from `try_get_refs`, or, if some expressions are not single cells, the error
    /// from `try_unpack_single` of the last such expression.
    pub fn try_get_single_cells<const COUNT: usize>(
        &self,
    ) -> Result<[&CellExpression; COUNT], InvocationError> {
        let refs = self.try_get_refs::<COUNT>()?;
        let mut last_err = None;
        // Placeholder for failed unpacks; never returned, as any failure leads to `Err` below.
        const FAKE_CELL: CellExpression =
            CellExpression::Deref(CellRef { register: Register::AP, offset: 0 });
        // TODO(orizi): Use `refs.try_map` once it is a stable feature.
        let result = refs.map(|r| match r.try_unpack_single() {
            Ok(cell) => cell,
            Err(err) => {
                last_err = Some(err);
                &FAKE_CELL
            }
        });
        if let Some(err) = last_err { Err(err) } else { Ok(result) }
    }
}
634
/// Program-level information required for compiling an invocation.
pub struct ProgramInfo<'a> {
    /// The program metadata, including the gas and ap-change information.
    pub metadata: &'a Metadata,
    /// The sizes of the program's concrete types.
    pub type_sizes: &'a TypeSizeMap,
    /// Information about the circuits in the program.
    pub circuits_info: &'a CircuitsInfo,
    /// Given a const type, returns the values of the cells representing it.
    pub const_data_values: &'a dyn Fn(&ConcreteTypeId) -> Vec<BigInt>,
}
644
/// Given a Sierra invocation statement and concrete libfunc, creates a compiled casm representation
/// of the Sierra statement.
///
/// `refs` are the reference values of the libfunc arguments, and `environment` is the
/// environment before the statement. Compilation is dispatched to the builder of the libfunc's
/// family.
///
/// # Errors
/// Any `InvocationError` returned by the family-specific builder.
pub fn compile_invocation(
    program_info: ProgramInfo<'_>,
    invocation: &Invocation,
    libfunc: &CoreConcreteLibfunc,
    idx: StatementIdx,
    refs: &[ReferenceValue],
    environment: Environment,
) -> Result<CompiledInvocation, InvocationError> {
    let builder =
        CompiledInvocationBuilder { program_info, invocation, libfunc, idx, refs, environment };
    match libfunc {
        Felt252(libfunc) => felt252::build(libfunc, builder),
        Felt252SquashedDict(libfunc) => squashed_felt252_dict::build(libfunc, builder),
        Bool(libfunc) => boolean::build(libfunc, builder),
        Cast(libfunc) => casts::build(libfunc, builder),
        Ec(libfunc) => ec::build(libfunc, builder),
        // Small unsigned ints share one generic builder, parametrized by 2^bits of the type.
        Uint8(libfunc) => int::unsigned::build_uint::<_, 0x100>(libfunc, builder),
        Uint16(libfunc) => int::unsigned::build_uint::<_, 0x10000>(libfunc, builder),
        Uint32(libfunc) => int::unsigned::build_uint::<_, 0x100000000>(libfunc, builder),
        Uint64(libfunc) => int::unsigned::build_uint::<_, 0x10000000000000000>(libfunc, builder),
        Uint128(libfunc) => int::unsigned128::build(libfunc, builder),
        Uint256(libfunc) => int::unsigned256::build(libfunc, builder),
        Uint512(libfunc) => int::unsigned512::build(libfunc, builder),
        // Small signed ints share one generic builder, parametrized by the type's min and max.
        Sint8(libfunc) => {
            int::signed::build_sint::<_, { i8::MIN as i128 }, { i8::MAX as i128 }>(libfunc, builder)
        }
        Sint16(libfunc) => {
            int::signed::build_sint::<_, { i16::MIN as i128 }, { i16::MAX as i128 }>(
                libfunc, builder,
            )
        }
        Sint32(libfunc) => {
            int::signed::build_sint::<_, { i32::MIN as i128 }, { i32::MAX as i128 }>(
                libfunc, builder,
            )
        }
        Sint64(libfunc) => {
            int::signed::build_sint::<_, { i64::MIN as i128 }, { i64::MAX as i128 }>(
                libfunc, builder,
            )
        }
        Sint128(libfunc) => int::signed128::build(libfunc, builder),
        Gas(libfunc) => gas::build(libfunc, builder),
        BranchAlign(_) => misc::build_branch_align(builder),
        Array(libfunc) => array::build(libfunc, builder),
        Drop(_) => misc::build_drop(builder),
        Dup(_) => misc::build_dup(builder),
        Mem(libfunc) => mem::build(libfunc, builder),
        UnwrapNonZero(_) => misc::build_identity(builder),
        FunctionCall(libfunc) | CouponCall(libfunc) => function_call::build(libfunc, builder),
        UnconditionalJump(_) => misc::build_jump(builder),
        ApTracking(_) => misc::build_update_ap_tracking(builder),
        Box(libfunc) => boxing::build(libfunc, builder),
        Enum(libfunc) => enm::build(libfunc, builder),
        Struct(libfunc) => structure::build(libfunc, builder),
        Felt252Dict(libfunc) => felt252_dict::build_dict(libfunc, builder),
        Pedersen(libfunc) => pedersen::build(libfunc, builder),
        Poseidon(libfunc) => poseidon::build(libfunc, builder),
        Starknet(libfunc) => starknet::build(libfunc, builder),
        Nullable(libfunc) => nullable::build(libfunc, builder),
        Debug(libfunc) => debug::build(libfunc, builder),
        // Taking a snapshot is compiled the same way as `dup`.
        SnapshotTake(_) => misc::build_dup(builder),
        Felt252DictEntry(libfunc) => felt252_dict::build_entry(libfunc, builder),
        Bytes31(libfunc) => bytes31::build(libfunc, builder),
        Const(libfunc) => const_type::build(libfunc, builder),
        // Coupons emit no instructions: buying only introduces a zero-sized output, and refunding
        // has no outputs.
        Coupon(libfunc) => match libfunc {
            CouponConcreteLibfunc::Buy(_) => Ok(builder
                .build_only_reference_changes([ReferenceExpression::zero_sized()].into_iter())),
            CouponConcreteLibfunc::Refund(_) => {
                Ok(builder.build_only_reference_changes([].into_iter()))
            }
        },
        BoundedInt(libfunc) => int::bounded::build(libfunc, builder),
        Circuit(libfunc) => circuit::build(libfunc, builder),
        IntRange(libfunc) => range::build(libfunc, builder),
        Blake(libfunc) => blake::build(libfunc, builder),
        Trace(libfunc) => trace::build(libfunc, builder),
    }
}
726
/// A trait for views of complex `ReferenceExpression`s as specific data structures (e.g.
/// enum/array).
trait ReferenceExpressionView: Sized {
    /// The error returned when an expression cannot be viewed as `Self`.
    type Error;
    /// Extracts the specific view from the reference expressions. Can include validations and thus
    /// returns a result.
    /// `concrete_type_id` - the concrete type this view should represent.
    fn try_get_view(
        expr: &ReferenceExpression,
        program_info: &ProgramInfo<'_>,
        concrete_type_id: &ConcreteTypeId,
    ) -> Result<Self, Self::Error>;
    /// Converts the view into a `ReferenceExpression`.
    fn to_reference_expression(self) -> ReferenceExpression;
}
742
743/// Fetches the non-fallthrough jump target of the invocation, assuming this invocation is a
744/// conditional jump.
745pub fn get_non_fallthrough_statement_id(builder: &CompiledInvocationBuilder<'_>) -> StatementIdx {
746    match builder.invocation.branches.as_slice() {
747        [
748            BranchInfo { target: BranchTarget::Fallthrough, results: _ },
749            BranchInfo { target: BranchTarget::Statement(target_statement_id), results: _ },
750        ] => *target_statement_id,
751        _ => panic!("malformed invocation"),
752    }
753}
754
/// Adds input variables into the builder while validating their representation.
///
/// Takes the casm builder followed by a list of `;`-terminated entries. Each entry has one of
/// these forms: `deref $var`, `deref_or_immediate $var`, or `buffer($slack) $var`.
///
/// For each entry, `$var` is converted with the matching method (`to_deref`,
/// `to_deref_or_immediate`, or `to_buffer($slack)`). The result is added to the builder with
/// `add_var`, and `$var` is shadowed by the returned builder variable. If the conversion fails,
/// the enclosing function returns `InvocationError::InvalidReferenceExpressionForArgument` via `?`.
macro_rules! add_input_variables {
    // Base case: no entries left.
    ($casm_builder:ident,) => {};
    // `deref $var;` - the argument must be convertible to a plain cell dereference.
    ($casm_builder:ident, deref $var:ident; $($tok:tt)*) => {
        let $var = $casm_builder.add_var(cairo_lang_casm::cell_expression::CellExpression::Deref(
            $var.to_deref().ok_or(InvocationError::InvalidReferenceExpressionForArgument)?,
        ));
        $crate::invocations::add_input_variables!($casm_builder, $($tok)*)
    };
    // `deref_or_immediate $var;` - the argument may be a cell dereference or an immediate.
    ($casm_builder:ident, deref_or_immediate $var:ident; $($tok:tt)*) => {
        let $var = $casm_builder.add_var(
            match $var
                .to_deref_or_immediate()
                .ok_or(InvocationError::InvalidReferenceExpressionForArgument)?
            {
                cairo_lang_casm::operand::DerefOrImmediate::Deref(cell) => {
                    cairo_lang_casm::cell_expression::CellExpression::Deref(cell)
                }
                cairo_lang_casm::operand::DerefOrImmediate::Immediate(cell) => {
                    cairo_lang_casm::cell_expression::CellExpression::Immediate(cell.value)
                }
            },
        );
        $crate::invocations::add_input_variables!($casm_builder, $($tok)*)
    };
    // `buffer($slack) $var;` - the argument is converted with `to_buffer($slack)`.
    ($casm_builder:ident, buffer($slack:expr) $var:ident; $($tok:tt)*) => {
        let $var = $casm_builder.add_var(
            $var.to_buffer($slack).ok_or(InvocationError::InvalidReferenceExpressionForArgument)?,
        );
        $crate::invocations::add_input_variables!($casm_builder, $($tok)*)
    };
}
// Makes the macro usable by path (e.g. `$crate::invocations::add_input_variables!`).
use add_input_variables;