cairo_lang_sierra_to_casm/invocations/mod.rs

1use assert_matches::assert_matches;
2use cairo_lang_casm::ap_change::ApChange;
3use cairo_lang_casm::builder::{CasmBuildResult, CasmBuilder, Var};
4use cairo_lang_casm::cell_expression::CellExpression;
5use cairo_lang_casm::instructions::Instruction;
6use cairo_lang_casm::operand::{CellRef, Register};
7use cairo_lang_sierra::extensions::circuit::CircuitInfo;
8use cairo_lang_sierra::extensions::core::CoreConcreteLibfunc::{self, *};
9use cairo_lang_sierra::extensions::coupon::CouponConcreteLibfunc;
10use cairo_lang_sierra::extensions::gas::CostTokenType;
11use cairo_lang_sierra::extensions::lib_func::{BranchSignature, OutputVarInfo, SierraApChange};
12use cairo_lang_sierra::extensions::{ConcreteLibfunc, OutputVarReferenceInfo};
13use cairo_lang_sierra::ids::ConcreteTypeId;
14use cairo_lang_sierra::program::{BranchInfo, BranchTarget, Invocation, StatementIdx};
15use cairo_lang_sierra_ap_change::core_libfunc_ap_change::{
16    InvocationApChangeInfoProvider, core_libfunc_ap_change,
17};
18use cairo_lang_sierra_gas::core_libfunc_cost::{InvocationCostInfoProvider, core_libfunc_cost};
19use cairo_lang_sierra_gas::objects::ConstCost;
20use cairo_lang_sierra_type_size::TypeSizeMap;
21use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
22use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
23use itertools::{Itertools, chain, zip_eq};
24use num_bigint::BigInt;
25use thiserror::Error;
26
27use crate::circuit::CircuitsInfo;
28use crate::environment::Environment;
29use crate::environment::frame_state::{FrameState, FrameStateError};
30use crate::metadata::Metadata;
31use crate::references::{
32    OutputReferenceValue, OutputReferenceValueIntroductionPoint, ReferenceExpression,
33    ReferenceValue,
34};
35use crate::relocations::{InstructionsWithRelocations, Relocation, RelocationEntry};
36
37mod array;
38mod bitwise;
39mod boolean;
40mod boxing;
41mod bytes31;
42mod casts;
43mod circuit;
44mod const_type;
45mod debug;
46mod ec;
47pub mod enm;
48mod felt252;
49mod felt252_dict;
50mod function_call;
51mod gas;
52mod int;
53mod mem;
54mod misc;
55mod nullable;
56mod pedersen;
57mod poseidon;
58mod range;
59mod range_reduction;
60mod starknet;
61mod structure;
62
63#[cfg(test)]
64mod test_utils;
65
66#[derive(Error, Debug, Eq, PartialEq)]
67pub enum InvocationError {
68    #[error("One of the arguments does not satisfy the requirements of the libfunc.")]
69    InvalidReferenceExpressionForArgument,
70    #[error("Unexpected error - an unregistered type id used.")]
71    UnknownTypeId(ConcreteTypeId),
72    #[error("Expected a different number of arguments.")]
73    WrongNumberOfArguments { expected: usize, actual: usize },
74    #[error("The requested functionality is not implemented yet.")]
75    NotImplemented(Invocation),
76    #[error("The requested functionality is not implemented yet: {message}")]
77    NotImplementedStr { invocation: Invocation, message: String },
78    #[error("The functionality is supported only for sized types.")]
79    NotSized(Invocation),
80    #[error("Expected type data not found.")]
81    UnknownTypeData,
82    #[error("Expected variable data for statement not found.")]
83    UnknownVariableData,
84    #[error("An integer overflow occurred.")]
85    InvalidGenericArg,
86    #[error("Invalid generic argument for libfunc.")]
87    IntegerOverflow,
88    #[error(transparent)]
89    FrameStateError(#[from] FrameStateError),
90    // TODO(lior): Remove this error once not used.
91    #[error("This libfunc does not support pre-cost metadata yet.")]
92    PreCostMetadataNotSupported,
93    #[error("{output_ty} is not a contained in the circuit {circuit_ty}.")]
94    InvalidCircuitOutput { output_ty: ConcreteTypeId, circuit_ty: ConcreteTypeId },
95}
96
/// How an invocation affects the ap tracking status itself.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum ApTrackingChange {
    /// Turns ap tracking on, unless it is already on.
    Enable,
    /// Turns ap tracking off.
    Disable,
    /// Leaves the ap tracking status untouched.
    None,
}
107
108/// Describes the changes to the set of references at a single branch target, as well as changes to
109/// the environment.
110#[derive(Clone, Debug, Eq, PartialEq)]
111pub struct BranchChanges {
112    /// New references defined at a given branch.
113    /// should correspond to BranchInfo.results.
114    pub refs: Vec<OutputReferenceValue>,
115    /// The change to AP caused by the libfunc in the branch.
116    pub ap_change: ApChange,
117    /// A change to the ap tracking status.
118    pub ap_tracking_change: ApTrackingChange,
119    /// The change to the remaining gas value in the wallet.
120    pub gas_change: OrderedHashMap<CostTokenType, i64>,
121    /// Should the stack be cleared due to a gap between stack items.
122    pub clear_old_stack: bool,
123    /// The expected size of the known stack after the change.
124    pub new_stack_size: usize,
125}
126impl BranchChanges {
127    /// Creates a `BranchChanges` object.
128    /// `param_ref` is used to fetch the reference value of a param of the libfunc.
129    fn new<'a, ParamRef: Fn(usize) -> &'a ReferenceValue>(
130        ap_change: ApChange,
131        ap_tracking_change: ApTrackingChange,
132        gas_change: OrderedHashMap<CostTokenType, i64>,
133        expressions: impl ExactSizeIterator<Item = ReferenceExpression>,
134        branch_signature: &BranchSignature,
135        prev_env: &Environment,
136        param_ref: ParamRef,
137    ) -> Self {
138        assert_eq!(
139            expressions.len(),
140            branch_signature.vars.len(),
141            "The number of expressions does not match the number of expected results in the \
142             branch."
143        );
144        let clear_old_stack =
145            !matches!(&branch_signature.ap_change, SierraApChange::Known { new_vars_only: true });
146        let stack_base = if clear_old_stack { 0 } else { prev_env.stack_size };
147        let mut new_stack_size = stack_base;
148
149        let refs: Vec<_> = zip_eq(expressions, &branch_signature.vars)
150            .enumerate()
151            .map(|(output_idx, (expression, OutputVarInfo { ref_info, ty }))| {
152                validate_output_var_refs(ref_info, &expression);
153                let stack_idx =
154                    calc_output_var_stack_idx(ref_info, stack_base, clear_old_stack, &param_ref);
155                if let Some(stack_idx) = stack_idx {
156                    new_stack_size = new_stack_size.max(stack_idx + 1);
157                }
158                let introduction_point =
159                    if let OutputVarReferenceInfo::SameAsParam { param_idx } = ref_info {
160                        OutputReferenceValueIntroductionPoint::Existing(
161                            param_ref(*param_idx).introduction_point.clone(),
162                        )
163                    } else {
164                        // Marking the statement as unknown to be fixed later.
165                        OutputReferenceValueIntroductionPoint::New(output_idx)
166                    };
167                OutputReferenceValue { expression, ty: ty.clone(), stack_idx, introduction_point }
168            })
169            .collect();
170        validate_stack_top(ap_change, branch_signature, &refs);
171        Self { refs, ap_change, ap_tracking_change, gas_change, clear_old_stack, new_stack_size }
172    }
173}
174
175/// Validates that a new temp or local var have valid references in their matching expression.
176fn validate_output_var_refs(ref_info: &OutputVarReferenceInfo, expression: &ReferenceExpression) {
177    match ref_info {
178        OutputVarReferenceInfo::SameAsParam { .. } => {}
179        _ if expression.cells.is_empty() => {
180            assert_matches!(ref_info, OutputVarReferenceInfo::ZeroSized);
181        }
182        OutputVarReferenceInfo::ZeroSized => {
183            unreachable!("Non empty ReferenceExpression for zero sized variable.")
184        }
185        OutputVarReferenceInfo::NewTempVar { .. } => {
186            expression.cells.iter().for_each(|cell| {
187                assert_matches!(cell, CellExpression::Deref(CellRef { register: Register::AP, .. }))
188            });
189        }
190        OutputVarReferenceInfo::NewLocalVar => {
191            expression.cells.iter().for_each(|cell| {
192                assert_matches!(cell, CellExpression::Deref(CellRef { register: Register::FP, .. }))
193            });
194        }
195        OutputVarReferenceInfo::SimpleDerefs => {
196            expression
197                .cells
198                .iter()
199                .for_each(|cell| assert_matches!(cell, CellExpression::Deref(_)));
200        }
201        OutputVarReferenceInfo::PartialParam { .. } | OutputVarReferenceInfo::Deferred(_) => {}
202    };
203}
204
205/// Validates that the variables that are now on the top of the stack are contiguous and that if the
206/// stack was not broken the size of all the variables is consistent with the ap change.
207fn validate_stack_top(
208    ap_change: ApChange,
209    branch_signature: &BranchSignature,
210    refs: &[OutputReferenceValue],
211) {
212    // A mapping for the new temp vars allocated on the top of the stack from their index on the
213    // top of the stack to their index in the `refs` vector.
214    let stack_top_vars = UnorderedHashMap::<usize, usize>::from_iter(
215        branch_signature.vars.iter().enumerate().filter_map(|(arg_idx, var)| {
216            if let OutputVarReferenceInfo::NewTempVar { idx: stack_idx } = var.ref_info {
217                Some((stack_idx, arg_idx))
218            } else {
219                None
220            }
221        }),
222    );
223    let mut prev_ap_offset = None;
224    let mut stack_top_size = 0;
225    for i in 0..stack_top_vars.len() {
226        let Some(arg) = stack_top_vars.get(&i) else {
227            panic!("Missing top stack var #{i} out of {}.", stack_top_vars.len());
228        };
229        let cells = &refs[*arg].expression.cells;
230        stack_top_size += cells.len();
231        for cell in cells {
232            let ap_offset = match cell {
233                CellExpression::Deref(CellRef { register: Register::AP, offset }) => *offset,
234                _ => unreachable!("Tested in `validate_output_var_refs`."),
235            };
236            if let Some(prev_ap_offset) = prev_ap_offset {
237                assert_eq!(ap_offset, prev_ap_offset + 1, "Top stack vars are not contiguous.");
238            }
239            prev_ap_offset = Some(ap_offset);
240        }
241    }
242    if matches!(branch_signature.ap_change, SierraApChange::Known { new_vars_only: true }) {
243        assert_eq!(
244            ap_change,
245            ApChange::Known(stack_top_size),
246            "New tempvar variables are not contiguous with the old stack."
247        );
248    }
249    // TODO(orizi): Add assertion for the non-new_vars_only case, that it is optimal.
250}
251
252/// Calculates the continuous stack index for an output var of a branch.
253/// `param_ref` is used to fetch the reference value of a param of the libfunc.
254fn calc_output_var_stack_idx<'a, ParamRef: Fn(usize) -> &'a ReferenceValue>(
255    ref_info: &OutputVarReferenceInfo,
256    stack_base: usize,
257    clear_old_stack: bool,
258    param_ref: &ParamRef,
259) -> Option<usize> {
260    match ref_info {
261        OutputVarReferenceInfo::NewTempVar { idx } => Some(stack_base + idx),
262        OutputVarReferenceInfo::SameAsParam { param_idx } if !clear_old_stack => {
263            param_ref(*param_idx).stack_idx
264        }
265        OutputVarReferenceInfo::SameAsParam { .. }
266        | OutputVarReferenceInfo::SimpleDerefs
267        | OutputVarReferenceInfo::NewLocalVar
268        | OutputVarReferenceInfo::PartialParam { .. }
269        | OutputVarReferenceInfo::Deferred(_)
270        | OutputVarReferenceInfo::ZeroSized => None,
271    }
272}
273
/// The result from a compilation of a single invocation statement.
#[derive(Debug)]
pub struct CompiledInvocation {
    /// A vector of instructions that implement the invocation.
    pub instructions: Vec<Instruction>,
    /// A vector of static relocations, each referring to an index in `instructions`.
    pub relocations: Vec<RelocationEntry>,
    /// A vector of `BranchChanges`, one per branch of the invocation statement, in the same
    /// order as the statement's branches.
    pub results: Vec<BranchChanges>,
    /// The environment after the invocation statement.
    pub environment: Environment,
}
287
288/// Checks that the list of references is contiguous on the stack and ends at ap - 1.
289/// This is the requirement for function call and return statements.
290pub fn check_references_on_stack(refs: &[ReferenceValue]) -> Result<(), InvocationError> {
291    let mut expected_offset: i16 = -1;
292    for reference in refs.iter().rev() {
293        for cell_expr in reference.expression.cells.iter().rev() {
294            match cell_expr {
295                CellExpression::Deref(CellRef { register: Register::AP, offset })
296                    if *offset == expected_offset =>
297                {
298                    expected_offset -= 1;
299                }
300                _ => return Err(InvocationError::InvalidReferenceExpressionForArgument),
301            }
302        }
303    }
304    Ok(())
305}
306
/// The cells of a single returned Sierra variable, as casm-builder vars.
type VarCells = [Var];
/// The cells of all Sierra variables returned from a libfunc on one branch, one entry per
/// variable.
type AllVars<'a> = [&'a VarCells];
311
312impl InvocationApChangeInfoProvider for CompiledInvocationBuilder<'_> {
313    fn type_size(&self, ty: &ConcreteTypeId) -> usize {
314        self.program_info.type_sizes[ty] as usize
315    }
316
317    fn token_usages(&self, token_type: CostTokenType) -> usize {
318        self.program_info
319            .metadata
320            .gas_info
321            .variable_values
322            .get(&(self.idx, token_type))
323            .copied()
324            .unwrap_or(0) as usize
325    }
326}
327
328impl InvocationCostInfoProvider for CompiledInvocationBuilder<'_> {
329    fn type_size(&self, ty: &ConcreteTypeId) -> usize {
330        self.program_info.type_sizes[ty] as usize
331    }
332
333    fn ap_change_var_value(&self) -> usize {
334        self.program_info
335            .metadata
336            .ap_change_info
337            .variable_values
338            .get(&self.idx)
339            .copied()
340            .unwrap_or_default()
341    }
342
343    fn token_usages(&self, token_type: CostTokenType) -> usize {
344        InvocationApChangeInfoProvider::token_usages(self, token_type)
345    }
346
347    fn circuit_info(&self, ty: &ConcreteTypeId) -> &CircuitInfo {
348        self.program_info.circuits_info.circuits.get(ty).unwrap()
349    }
350}
351
/// Cost validation info for a builtin.
///
/// On each branch, the builtin's usage is measured as the offset difference between `end` and
/// `start`; both must be based on the same cell. Only `RangeCheck` and `RangeCheck96` token types
/// are currently handled by the validation.
struct BuiltinInfo {
    /// The cost token type associated with the builtin.
    cost_token_ty: CostTokenType,
    /// The builtin pointer at the start of the libfunc.
    start: Var,
    /// The builtin pointer at the end of all the libfunc branches.
    end: Var,
}
361
/// Information required for validating libfunc cost.
#[derive(Default)]
struct CostValidationInfo<const BRANCH_COUNT: usize> {
    /// Information about builtin usage, checked on every branch.
    pub builtin_infos: Vec<BuiltinInfo>,
    /// Possible extra cost per branch, added on top of the cost measured from the casm builder.
    /// Useful for amortized costs, as well as gas withdrawal libfuncs.
    /// `None` means no extra cost on any branch.
    pub extra_costs: Option<[i32; BRANCH_COUNT]>,
}
371
/// Helper for building compiled invocations.
///
/// Bundles everything needed to compile a single invocation statement.
pub struct CompiledInvocationBuilder<'a> {
    /// Program-level information (metadata, type sizes, circuits, const values).
    pub program_info: ProgramInfo<'a>,
    /// The Sierra invocation statement being compiled.
    pub invocation: &'a Invocation,
    /// The concrete libfunc invoked by the statement.
    pub libfunc: &'a CoreConcreteLibfunc,
    /// The index of the invocation statement.
    pub idx: StatementIdx,
    /// The arguments of the libfunc.
    pub refs: &'a [ReferenceValue],
    /// The environment of the invocation. `build` passes it as the previous environment of
    /// every branch and stores it as the resulting `CompiledInvocation::environment`.
    pub environment: Environment,
}
impl CompiledInvocationBuilder<'_> {
    /// Creates a new invocation from its instructions, relocations and per-branch output
    /// expressions.
    ///
    /// The gas change and ap change of each branch come from `core_libfunc_cost` and
    /// `core_libfunc_ap_change`, resolved against the program metadata.
    ///
    /// # Panics
    /// If the number of branches in `output_expressions`, in the ap changes or in the gas changes
    /// differs from the number of branch signatures of the libfunc, or if a `FinalizeLocals` ap
    /// change is encountered while the frame state is not finalized.
    fn build(
        self,
        instructions: Vec<Instruction>,
        relocations: Vec<RelocationEntry>,
        output_expressions: impl ExactSizeIterator<
            Item = impl ExactSizeIterator<Item = ReferenceExpression>,
        >,
    ) -> CompiledInvocation {
        let gas_changes =
            core_libfunc_cost(&self.program_info.metadata.gas_info, &self.idx, self.libfunc, &self);

        let branch_signatures = self.libfunc.branch_signatures();
        assert_eq!(
            branch_signatures.len(),
            output_expressions.len(),
            "The number of output expressions does not match signature."
        );
        let ap_changes = core_libfunc_ap_change(self.libfunc, &self);
        assert_eq!(
            branch_signatures.len(),
            ap_changes.len(),
            "The number of ap changes does not match signature."
        );
        assert_eq!(
            branch_signatures.len(),
            gas_changes.len(),
            "The number of gas changes does not match signature."
        );

        CompiledInvocation {
            instructions,
            relocations,
            results: zip_eq(
                zip_eq(branch_signatures, gas_changes),
                zip_eq(output_expressions, ap_changes),
            )
            .map(|((branch_signature, gas_change), (expressions, ap_change))| {
                // Enabling/disabling ap tracking is reported separately from the ap change itself.
                let ap_tracking_change = match ap_change {
                    cairo_lang_sierra_ap_change::ApChange::EnableApTracking => {
                        ApTrackingChange::Enable
                    }
                    cairo_lang_sierra_ap_change::ApChange::DisableApTracking => {
                        ApTrackingChange::Disable
                    }
                    _ => ApTrackingChange::None,
                };
                // Resolve the Sierra-level ap change into a CASM ap change.
                let ap_change = match ap_change {
                    cairo_lang_sierra_ap_change::ApChange::Known(x) => ApChange::Known(x),
                    cairo_lang_sierra_ap_change::ApChange::AtLocalsFinalization(_)
                    | cairo_lang_sierra_ap_change::ApChange::EnableApTracking
                    | cairo_lang_sierra_ap_change::ApChange::DisableApTracking => {
                        ApChange::Known(0)
                    }
                    // The ap change is the number of locals allocated by the finalized frame.
                    cairo_lang_sierra_ap_change::ApChange::FinalizeLocals => {
                        if let FrameState::Finalized { allocated } = self.environment.frame_state {
                            ApChange::Known(allocated)
                        } else {
                            panic!("Unexpected frame state.")
                        }
                    }
                    // Functions without a known ap change in the metadata yield `Unknown`.
                    // NOTE(review): the `+ 2` presumably accounts for the cells pushed by the
                    // `call` instruction itself - confirm against the calling convention.
                    cairo_lang_sierra_ap_change::ApChange::FunctionCall(id) => self
                        .program_info
                        .metadata
                        .ap_change_info
                        .function_ap_change
                        .get(&id)
                        .map_or(ApChange::Unknown, |x| ApChange::Known(x + 2)),
                    // Looked up per statement in the ap-change metadata, defaulting to 0.
                    cairo_lang_sierra_ap_change::ApChange::FromMetadata => ApChange::Known(
                        *self
                            .program_info
                            .metadata
                            .ap_change_info
                            .variable_values
                            .get(&self.idx)
                            .unwrap_or(&0),
                    ),
                    cairo_lang_sierra_ap_change::ApChange::Unknown => ApChange::Unknown,
                };

                BranchChanges::new(
                    ap_change,
                    ap_tracking_change,
                    // The change to the wallet is the negation of the libfunc's cost.
                    gas_change.iter().map(|(token_type, val)| (*token_type, -val)).collect(),
                    expressions,
                    branch_signature,
                    &self.environment,
                    |idx| &self.refs[idx],
                )
            })
            .collect(),
            environment: self.environment,
        }
    }

    /// Builds a `CompiledInvocation` from a casm builder and branch extractions.
    /// Per branch requires `(name, result_variables, target_statement_id)`.
    ///
    /// Same as `build_from_casm_builder_ex` with default pre-instructions.
    fn build_from_casm_builder<const BRANCH_COUNT: usize>(
        self,
        casm_builder: CasmBuilder,
        branch_extractions: [(&str, &AllVars<'_>, Option<StatementIdx>); BRANCH_COUNT],
        cost_validation: CostValidationInfo<BRANCH_COUNT>,
    ) -> CompiledInvocation {
        self.build_from_casm_builder_ex(
            casm_builder,
            branch_extractions,
            cost_validation,
            Default::default(),
        )
    }

    /// Builds a `CompiledInvocation` from a casm builder and branch extractions.
    /// Per branch requires `(name, result_variables, target_statement_id)`.
    ///
    /// `pre_instructions` - Instructions to execute before the ones created by the builder.
    ///
    /// # Panics
    /// - If the ap change of any branch, as computed by the casm builder, differs from the one
    ///   declared by `core_libfunc_ap_change`.
    /// - If the `Const` cost declared by `core_libfunc_cost` differs from the measured cost
    ///   (steps, builtin usage from `cost_validation.builtin_infos`, `extra_costs` and the cost
    ///   of `pre_instructions`).
    /// - If a branch has relocations but no target, or a target but no relocations.
    fn build_from_casm_builder_ex<const BRANCH_COUNT: usize>(
        self,
        casm_builder: CasmBuilder,
        branch_extractions: [(&str, &AllVars<'_>, Option<StatementIdx>); BRANCH_COUNT],
        cost_validation: CostValidationInfo<BRANCH_COUNT>,
        pre_instructions: InstructionsWithRelocations,
    ) -> CompiledInvocation {
        let CasmBuildResult { instructions, branches } =
            casm_builder.build(branch_extractions.map(|(name, _, _)| name));
        // Validate that the builder's ap changes match the libfunc's declared ap changes.
        let expected_ap_changes = core_libfunc_ap_change(self.libfunc, &self);
        let actual_ap_changes = branches
            .iter()
            .map(|(state, _)| cairo_lang_sierra_ap_change::ApChange::Known(state.ap_change));
        if !itertools::equal(expected_ap_changes.iter().cloned(), actual_ap_changes.clone()) {
            panic!(
                "Wrong ap changes for {}. Expected: {expected_ap_changes:?}, actual: {:?}.",
                self.invocation,
                actual_ap_changes.collect_vec(),
            );
        }
        // Only the `Const` token cost is validated against the builder's measured cost.
        let gas_changes =
            core_libfunc_cost(&self.program_info.metadata.gas_info, &self.idx, self.libfunc, &self)
                .into_iter()
                .map(|costs| costs.get(&CostTokenType::Const).copied().unwrap_or_default());
        let mut final_costs: [ConstCost; BRANCH_COUNT] =
            std::array::from_fn(|_| Default::default());
        for (cost, (state, _)) in final_costs.iter_mut().zip(branches.iter()) {
            cost.steps += state.steps as i32;
        }

        // Builtin usage per branch is the advance of the builtin pointer from `start` to `end`.
        for BuiltinInfo { cost_token_ty, start, end } in cost_validation.builtin_infos {
            for (cost, (state, _)) in final_costs.iter_mut().zip(branches.iter()) {
                let (start_base, start_offset) =
                    state.get_adjusted(start).to_deref_with_offset().unwrap();
                let (end_base, end_offset) =
                    state.get_adjusted(end).to_deref_with_offset().unwrap();
                assert_eq!(start_base, end_base);
                let diff = end_offset - start_offset;
                match cost_token_ty {
                    CostTokenType::RangeCheck => {
                        cost.range_checks += diff;
                    }
                    CostTokenType::RangeCheck96 => {
                        cost.range_checks96 += diff;
                    }
                    _ => panic!("Cost token type not supported."),
                }
            }
        }

        // Missing extra costs default to zero on every branch.
        let extra_costs =
            cost_validation.extra_costs.unwrap_or(std::array::from_fn(|_| Default::default()));
        let final_costs_with_extra =
            final_costs.iter().zip(extra_costs).map(|(final_cost, extra)| {
                (final_cost.cost() + extra + pre_instructions.cost.cost()) as i64
            });
        if !itertools::equal(gas_changes.clone(), final_costs_with_extra.clone()) {
            panic!(
                "Wrong costs for {}. Expected: {:?}, actual: {:?}, Costs from casm_builder: {:?}.",
                self.invocation,
                gas_changes.collect_vec(),
                final_costs_with_extra.collect_vec(),
                final_costs,
            );
        }
        // Builder instruction indices are shifted by the pre-instructions, which are emitted
        // first.
        let branch_relocations = branches.iter().zip_eq(branch_extractions.iter()).flat_map(
            |((_, relocations), (_, _, target))| {
                assert_eq!(
                    relocations.is_empty(),
                    target.is_none(),
                    "No relocations if nowhere to relocate to."
                );
                relocations.iter().map(|idx| RelocationEntry {
                    instruction_idx: pre_instructions.instructions.len() + *idx,
                    relocation: Relocation::RelativeStatementId(target.unwrap()),
                })
            },
        );
        let relocations = chain!(pre_instructions.relocations, branch_relocations).collect();
        // Output expressions are the requested builder vars, adjusted to each branch's end state.
        let output_expressions =
            zip_eq(branches, branch_extractions).map(|((state, _), (_, vars, _))| {
                vars.iter().map(move |var_cells| ReferenceExpression {
                    cells: var_cells.iter().map(|cell| state.get_adjusted(*cell)).collect(),
                })
            });
        self.build(
            chain!(pre_instructions.instructions, instructions).collect(),
            relocations,
            output_expressions,
        )
    }

    /// Creates a new invocation with only reference changes.
    ///
    /// The result has a single branch, no instructions and no relocations.
    fn build_only_reference_changes(
        self,
        output_expressions: impl ExactSizeIterator<Item = ReferenceExpression>,
    ) -> CompiledInvocation {
        self.build(vec![], vec![], [output_expressions].into_iter())
    }

    /// Returns the reference expressions if the size is correct.
    ///
    /// # Errors
    /// `InvocationError::WrongNumberOfArguments` if the invocation does not have exactly `COUNT`
    /// arguments.
    pub fn try_get_refs<const COUNT: usize>(
        &self,
    ) -> Result<[&ReferenceExpression; COUNT], InvocationError> {
        if self.refs.len() == COUNT {
            Ok(core::array::from_fn(|i| &self.refs[i].expression))
        } else {
            Err(InvocationError::WrongNumberOfArguments {
                expected: COUNT,
                actual: self.refs.len(),
            })
        }
    }

    /// Returns the reference expressions, assuming all contains one cell if the size is correct.
    ///
    /// # Errors
    /// `InvocationError::WrongNumberOfArguments` if the argument count differs from `COUNT`.
    /// Otherwise, if some arguments are not single cells, returns the error of the last such
    /// argument.
    pub fn try_get_single_cells<const COUNT: usize>(
        &self,
    ) -> Result<[&CellExpression; COUNT], InvocationError> {
        let refs = self.try_get_refs::<COUNT>()?;
        let mut last_err = None;
        // Placeholder that fills the array for failed unpacks; discarded when an error is returned.
        const FAKE_CELL: CellExpression =
            CellExpression::Deref(CellRef { register: Register::AP, offset: 0 });
        // TODO(orizi): Use `refs.try_map` once it is a stable feature.
        let result = refs.map(|r| match r.try_unpack_single() {
            Ok(cell) => cell,
            Err(err) => {
                last_err = Some(err);
                &FAKE_CELL
            }
        });
        if let Some(err) = last_err { Err(err) } else { Ok(result) }
    }
}
631
/// Information in the program level required for compiling an invocation.
pub struct ProgramInfo<'a> {
    /// The program metadata (used here for its ap-change and gas information).
    pub metadata: &'a Metadata,
    /// The sizes of the program's concrete types.
    pub type_sizes: &'a TypeSizeMap,
    /// Information about the circuits in the program.
    pub circuits_info: &'a CircuitsInfo,
    /// Given a const type, returns the values of the cells representing it.
    pub const_data_values: &'a dyn Fn(&ConcreteTypeId) -> Vec<BigInt>,
}
641
642/// Given a Sierra invocation statement and concrete libfunc, creates a compiled casm representation
643/// of the Sierra statement.
644pub fn compile_invocation(
645    program_info: ProgramInfo<'_>,
646    invocation: &Invocation,
647    libfunc: &CoreConcreteLibfunc,
648    idx: StatementIdx,
649    refs: &[ReferenceValue],
650    environment: Environment,
651) -> Result<CompiledInvocation, InvocationError> {
652    let builder =
653        CompiledInvocationBuilder { program_info, invocation, libfunc, idx, refs, environment };
654    match libfunc {
655        Felt252(libfunc) => felt252::build(libfunc, builder),
656        Bool(libfunc) => boolean::build(libfunc, builder),
657        Cast(libfunc) => casts::build(libfunc, builder),
658        Ec(libfunc) => ec::build(libfunc, builder),
659        Uint8(libfunc) => int::unsigned::build_uint::<_, 0x100>(libfunc, builder),
660        Uint16(libfunc) => int::unsigned::build_uint::<_, 0x10000>(libfunc, builder),
661        Uint32(libfunc) => int::unsigned::build_uint::<_, 0x100000000>(libfunc, builder),
662        Uint64(libfunc) => int::unsigned::build_uint::<_, 0x10000000000000000>(libfunc, builder),
663        Uint128(libfunc) => int::unsigned128::build(libfunc, builder),
664        Uint256(libfunc) => int::unsigned256::build(libfunc, builder),
665        Uint512(libfunc) => int::unsigned512::build(libfunc, builder),
666        Sint8(libfunc) => {
667            int::signed::build_sint::<_, { i8::MIN as i128 }, { i8::MAX as i128 }>(libfunc, builder)
668        }
669        Sint16(libfunc) => {
670            int::signed::build_sint::<_, { i16::MIN as i128 }, { i16::MAX as i128 }>(
671                libfunc, builder,
672            )
673        }
674        Sint32(libfunc) => {
675            int::signed::build_sint::<_, { i32::MIN as i128 }, { i32::MAX as i128 }>(
676                libfunc, builder,
677            )
678        }
679        Sint64(libfunc) => {
680            int::signed::build_sint::<_, { i64::MIN as i128 }, { i64::MAX as i128 }>(
681                libfunc, builder,
682            )
683        }
684        Sint128(libfunc) => int::signed128::build(libfunc, builder),
685        Gas(libfunc) => gas::build(libfunc, builder),
686        BranchAlign(_) => misc::build_branch_align(builder),
687        Array(libfunc) => array::build(libfunc, builder),
688        Drop(_) => misc::build_drop(builder),
689        Dup(_) => misc::build_dup(builder),
690        Mem(libfunc) => mem::build(libfunc, builder),
691        UnwrapNonZero(_) => misc::build_identity(builder),
692        FunctionCall(libfunc) | CouponCall(libfunc) => function_call::build(libfunc, builder),
693        UnconditionalJump(_) => misc::build_jump(builder),
694        ApTracking(_) => misc::build_update_ap_tracking(builder),
695        Box(libfunc) => boxing::build(libfunc, builder),
696        Enum(libfunc) => enm::build(libfunc, builder),
697        Struct(libfunc) => structure::build(libfunc, builder),
698        Felt252Dict(libfunc) => felt252_dict::build_dict(libfunc, builder),
699        Pedersen(libfunc) => pedersen::build(libfunc, builder),
700        Poseidon(libfunc) => poseidon::build(libfunc, builder),
701        StarkNet(libfunc) => starknet::build(libfunc, builder),
702        Nullable(libfunc) => nullable::build(libfunc, builder),
703        Debug(libfunc) => debug::build(libfunc, builder),
704        SnapshotTake(_) => misc::build_dup(builder),
705        Felt252DictEntry(libfunc) => felt252_dict::build_entry(libfunc, builder),
706        Bytes31(libfunc) => bytes31::build(libfunc, builder),
707        Const(libfunc) => const_type::build(libfunc, builder),
708        Coupon(libfunc) => match libfunc {
709            CouponConcreteLibfunc::Buy(_) => Ok(builder
710                .build_only_reference_changes([ReferenceExpression::zero_sized()].into_iter())),
711            CouponConcreteLibfunc::Refund(_) => {
712                Ok(builder.build_only_reference_changes([].into_iter()))
713            }
714        },
715        BoundedInt(libfunc) => int::bounded::build(libfunc, builder),
716        Circuit(libfunc) => circuit::build(libfunc, builder),
717        IntRange(libfunc) => range::build(libfunc, builder),
718    }
719}
720
/// A trait for views of complex `ReferenceExpression`s as specific data structures (e.g.
/// enum/array).
trait ReferenceExpressionView: Sized {
    /// The error returned when an expression cannot be viewed as `Self`.
    type Error;
    /// Extracts the specific view from the reference expressions. Can include validations and thus
    /// returns a result.
    /// `concrete_type_id` - the concrete type this view should represent.
    fn try_get_view(
        expr: &ReferenceExpression,
        program_info: &ProgramInfo<'_>,
        concrete_type_id: &ConcreteTypeId,
    ) -> Result<Self, Self::Error>;
    /// Converts the view back into a `ReferenceExpression`.
    fn to_reference_expression(self) -> ReferenceExpression;
}
736
737/// Fetches the non-fallthrough jump target of the invocation, assuming this invocation is a
738/// conditional jump.
739pub fn get_non_fallthrough_statement_id(builder: &CompiledInvocationBuilder<'_>) -> StatementIdx {
740    match builder.invocation.branches.as_slice() {
741        [
742            BranchInfo { target: BranchTarget::Fallthrough, results: _ },
743            BranchInfo { target: BranchTarget::Statement(target_statement_id), results: _ },
744        ] => *target_statement_id,
745        _ => panic!("malformed invocation"),
746    }
747}
748
/// Adds input variables into the builder while validating their type.
///
/// Takes a casm builder identifier followed by entries of the form `<kind> <var>;`. Each `<var>`
/// must be an identifier bound to a value providing `to_deref`, `to_deref_or_immediate` or
/// `to_buffer` (presumably a `ReferenceExpression`). `<kind>` is one of:
/// - `deref` - the value must convert via `to_deref()`;
/// - `deref_or_immediate` - the value must convert via `to_deref_or_immediate()`;
/// - `buffer(slack)` - the value must convert via `to_buffer(slack)`.
///
/// Each `<var>` is shadowed by the casm-builder `Var` added for it. If a conversion fails, the
/// enclosing function returns `InvocationError::InvalidReferenceExpressionForArgument` via `?`.
/// The error path is fully qualified, so call sites need not import `InvocationError`.
macro_rules! add_input_variables {
    ($casm_builder:ident,) => {};
    ($casm_builder:ident, deref $var:ident; $($tok:tt)*) => {
        let $var = $casm_builder.add_var(cairo_lang_casm::cell_expression::CellExpression::Deref(
            $var.to_deref().ok_or(
                $crate::invocations::InvocationError::InvalidReferenceExpressionForArgument,
            )?,
        ));
        $crate::invocations::add_input_variables!($casm_builder, $($tok)*)
    };
    ($casm_builder:ident, deref_or_immediate $var:ident; $($tok:tt)*) => {
        let $var = $casm_builder.add_var(
            match $var
                .to_deref_or_immediate()
                .ok_or($crate::invocations::InvocationError::InvalidReferenceExpressionForArgument)?
            {
                cairo_lang_casm::operand::DerefOrImmediate::Deref(cell) => {
                    cairo_lang_casm::cell_expression::CellExpression::Deref(cell)
                }
                cairo_lang_casm::operand::DerefOrImmediate::Immediate(cell) => {
                    cairo_lang_casm::cell_expression::CellExpression::Immediate(cell.value)
                }
            },
        );
        $crate::invocations::add_input_variables!($casm_builder, $($tok)*)
    };
    ($casm_builder:ident, buffer($slack:expr) $var:ident; $($tok:tt)*) => {
        let $var = $casm_builder.add_var(
            $var.to_buffer($slack).ok_or(
                $crate::invocations::InvocationError::InvalidReferenceExpressionForArgument,
            )?,
        );
        $crate::invocations::add_input_variables!($casm_builder, $($tok)*)
    };
}
// Makes the macro addressable by path, as the recursive invocations above do.
use add_input_variables;