cairo_lang_sierra_to_casm/invocations/
enm.rs

1use cairo_lang_casm::builder::CasmBuilder;
2use cairo_lang_casm::cell_expression::CellExpression;
3use cairo_lang_casm::operand::CellRef;
4use cairo_lang_casm::{casm, casm_build_extend, casm_extend};
5use cairo_lang_sierra::extensions::ConcreteLibfunc;
6use cairo_lang_sierra::extensions::enm::{EnumConcreteLibfunc, EnumInitConcreteLibfunc};
7use cairo_lang_sierra::ids::ConcreteTypeId;
8use cairo_lang_sierra::program::{BranchInfo, BranchTarget};
9use cairo_lang_utils::try_extract_matches;
10use itertools::{chain, repeat_n};
11use num_bigint::BigInt;
12use starknet_types_core::felt::{Felt as Felt252, NonZeroFelt};
13
14use super::{
15    CompiledInvocation, CompiledInvocationBuilder, InvocationError, ReferenceExpressionView,
16};
17use crate::invocations::{CostValidationInfo, ProgramInfo, add_input_variables, misc};
18use crate::references::{ReferenceExpression, ReferencesError};
19use crate::relocations::{Relocation, RelocationEntry};
20
21/// Builds instructions for Sierra enum operations.
22pub fn build(
23    libfunc: &EnumConcreteLibfunc,
24    builder: CompiledInvocationBuilder<'_>,
25) -> Result<CompiledInvocation, InvocationError> {
26    match libfunc {
27        EnumConcreteLibfunc::Init(EnumInitConcreteLibfunc { index, n_variants, .. }) => {
28            build_enum_init(builder, *index, *n_variants)
29        }
30        EnumConcreteLibfunc::FromBoundedInt(libfunc) => {
31            build_enum_from_bounded_int(builder, libfunc.n_variants)
32        }
33        EnumConcreteLibfunc::Match(_) | EnumConcreteLibfunc::SnapshotMatch(_) => {
34            build_enum_match(builder)
35        }
36    }
37}
38
39/// Handles statement for initializing an enum.
40/// For example, with this setup
41/// ```ignore
42/// type felt252_ty = felt252;
43/// type unit_ty = Tuple;
44/// type Option = Enum<felt252_ty, unit_ty>;
45/// libfunc init_option_some = enum_init<Option, 0>;
46/// felt252_const<8>() -> (felt8);
47/// ````
48/// this "Sierra statement"
49/// ```ignore
50/// init_option_some(felt8=[ap-5]) -> (some_id);
51/// ```
52/// translates to these casm instructions:
53/// ```ignore
54/// [ap] = 0; ap++
55/// [ap] = 8; ap++
56/// ```
57fn build_enum_init(
58    builder: CompiledInvocationBuilder<'_>,
59    index: usize,
60    n_variants: usize,
61) -> Result<CompiledInvocation, InvocationError> {
62    let [expression] = builder.try_get_refs()?;
63    let init_arg_cells = &expression.cells;
64    let variant_selector = get_variant_selector(n_variants, index)?;
65
66    let variant_size = builder
67        .program_info
68        .type_sizes
69        .get(&builder.libfunc.param_signatures()[0].ty)
70        .ok_or(InvocationError::UnknownTypeData)?
71        .to_owned();
72    if init_arg_cells.len() != variant_size as usize {
73        return Err(InvocationError::InvalidReferenceExpressionForArgument);
74    }
75    // Pad the variant to match the size of the largest variant
76    let concrete_enum_type = &builder.libfunc.output_types()[0][0];
77    let enum_size = get_enum_size(&builder.program_info, concrete_enum_type)
78        .ok_or(InvocationError::UnknownTypeData)?;
79    let num_padding = enum_size - 1 - variant_size;
80    let inner_value = chain!(
81        repeat_n(CellExpression::Immediate(BigInt::from(0)), num_padding as usize),
82        init_arg_cells.clone(),
83    )
84    .collect();
85
86    let enum_val = EnumView {
87        variant_selector: CellExpression::Immediate(BigInt::from(variant_selector)),
88        inner_value,
89    };
90    let output_expressions = [enum_val.to_reference_expression()].into_iter();
91    Ok(builder.build_only_reference_changes(output_expressions))
92}
93
94/// Returns the variant selector for variant `index` out of `n_variants`.
95pub fn get_variant_selector(n_variants: usize, index: usize) -> Result<usize, InvocationError> {
96    Ok(if n_variants <= 2 {
97        // For num_branches <= 2, we use the index as the variant_selector as the `match`
98        // implementation jumps to the index 0 statement on 0, and to the index 1 statement on
99        // 1.
100        index
101    } else {
102        // For num_branches > 2, the `enum_match` libfunc is implemented using a jump table. In
103        // order to optimize `enum_match`, we define the variant_selector as the relevant
104        // relative jump in case we match the actual variant.
105        //
106        // - To jump to the variant in index 0, we skip the jump table and directly jump to it. Its
107        //   location is (2 * n - 1) CASM steps ahead, where n is the number of variants in this
108        //   enum (2 per variant but the first variant, and 1 for the first jump with a deref
109        //   operand).
110        // - To jump to the variant in index k, we add "jump rel (2 * (n - k) - 1)" as the first
111        //   jump is of size 1 and the rest of the jump instructions are with an immediate operand,
112        //   which makes them of size 2.
113        (n_variants - index).checked_mul(2).ok_or(InvocationError::IntegerOverflow)? - 1
114    })
115}
116
117fn build_enum_from_bounded_int(
118    builder: CompiledInvocationBuilder<'_>,
119    n_variants: usize,
120) -> Result<CompiledInvocation, InvocationError> {
121    if n_variants <= 2 {
122        return misc::build_identity(builder);
123    }
124
125    let [value] = builder.try_get_single_cells()?;
126    let mut casm_builder = CasmBuilder::default();
127    add_input_variables! {casm_builder,
128        deref value;
129    };
130
131    // Given the number of variants, `n`, and the index of the variant `0 <= k < n`:
132    // The variant selector for enums with 3 or more variants is the relative jump to the variant
133    // handle which is `2 * (n - k) - 1`.
134    // `2 * (n - k) - 1 = 2*n - 2*k - 1 = 2 * (2*n - 1) / 2 - 2*k = 2 * ((2*n - 1) / 2 - k)`
135    // Define `(2*n - 1) / 2` as `m` - which is known in compilation time.
136    // Hence the variant selector is `2 * (m - k)` or  alternatively `-2 * (k - m)`
137
138    let m = (Felt252::from(n_variants * 2 - 1).field_div(&NonZeroFelt::TWO)).to_bigint();
139    casm_build_extend! {casm_builder,
140        const m = m;
141        const negative_two = -2;
142        tempvar value_minus_m = value - m;
143        let variant_selector = value_minus_m * negative_two;
144    };
145
146    Ok(builder.build_from_casm_builder(
147        casm_builder,
148        [("Fallthrough", &[&[variant_selector]], None)],
149        CostValidationInfo::default(),
150    ))
151}
152
153/// Handles statement for matching an enum.
154fn build_enum_match(
155    builder: CompiledInvocationBuilder<'_>,
156) -> Result<CompiledInvocation, InvocationError> {
157    let concrete_enum_type = &builder.libfunc.param_signatures()[0].ty;
158    let [expression] = builder.try_get_refs()?;
159    let matched_var = EnumView::try_get_view(expression, &builder.program_info, concrete_enum_type)
160        .map_err(|_| InvocationError::InvalidReferenceExpressionForArgument)?;
161    // Verify variant_selector is of type deref. This is the case with an enum_value
162    // that was validly created and then stored.
163    let variant_selector =
164        try_extract_matches!(matched_var.variant_selector, CellExpression::Deref)
165            .ok_or(InvocationError::InvalidReferenceExpressionForArgument)?;
166
167    let mut branch_output_sizes: Vec<usize> = Vec::new();
168    for branch_outputs in &builder.libfunc.output_types() {
169        // Each branch has a single output.
170        let branch_output = &branch_outputs[0];
171        let branch_output_size = builder
172            .program_info
173            .type_sizes
174            .get(branch_output)
175            .ok_or(InvocationError::UnknownTypeData)?;
176        branch_output_sizes.push(*branch_output_size as usize);
177    }
178    let output_expressions = branch_output_sizes.into_iter().map(|size| {
179        // The size of an output must be smaller than the size of `matched_var.inner_value` as the
180        // size of inner_value is fixed and is calculated as the max of the sizes of all the
181        // variants (which are the outputs in all the branches). Thus it is guaranteed that the
182        // iter we generate here is of size `size` (and not less).
183        let padding_size = matched_var.inner_value.len() - size;
184        vec![ReferenceExpression {
185            cells: matched_var.inner_value.iter().skip(padding_size).cloned().collect(),
186        }]
187        .into_iter()
188    });
189
190    let num_branches = builder.invocation.branches.len();
191    if num_branches <= 2 {
192        build_enum_match_short(builder, variant_selector, output_expressions)
193    } else {
194        build_enum_match_long(builder, variant_selector, output_expressions)
195    }
196}
197
198/// Handles statement for matching an enum with 1 or 2 variants.
199/// For example, with this setup
200/// ```ignore
201/// type felt252_ty = felt252;
202/// type unit_ty = Tuple;
203/// type Option = Enum<felt252_ty, unit_ty>;
204/// libfunc init_option_some = enum_init<Option, 0>;
205/// libfunc match_option = enum_match<Option>;
206/// felt252_const<8>() -> (felt8);
207/// init_option_some(felt8=[ap-5]) -> (enum_var);
208/// ````
209/// this "Sierra statement" (2-variants-enum)
210/// ```ignore
211/// match_option(enum_var=[ap-10]) {fallthrough(some=[ap-9]), 2000(none=[ap-9])};
212/// ```
213/// translates to these casm instructions:
214/// ```ignore
215/// jmp rel <jump_offset_2000> if [ap-10] != 0
216/// jmp rel <jump_offset_fallthrough>
217/// ```
218/// Or this "Sierra statement" (single-variant-enum)
219/// ```ignore
220/// match_option(enum_var=[ap-10]) {fallthrough(var=[ap-9])};
221/// ```
222/// translates to 0 casm instructions.
223///
224/// Assumes that builder.invocation.branches.len() == output_expressions.len() and that
225/// builder.invocation.branches.len() <= 2.
226fn build_enum_match_short(
227    builder: CompiledInvocationBuilder<'_>,
228    variant_selector: CellRef,
229    output_expressions: impl ExactSizeIterator<
230        Item = impl ExactSizeIterator<Item = ReferenceExpression>,
231    >,
232) -> Result<CompiledInvocation, InvocationError> {
233    let mut instructions = Vec::new();
234    let mut relocations = Vec::new();
235
236    // First branch is fallthrough. If there is only one branch, this `match` statement is
237    // translated to nothing in Casm.
238
239    // If we have 2 branches, add the jump_nz instruction to branch 1 if variant_selector != 0.
240    if let Some(branch) = builder.invocation.branches.get(1) {
241        let statement_id = match branch {
242            BranchInfo { target: BranchTarget::Statement(statement_id), .. } => *statement_id,
243            _ => panic!("malformed invocation"),
244        };
245
246        instructions.extend(casm! { jmp rel 0 if variant_selector != 0; }.instructions);
247        relocations.push(RelocationEntry {
248            instruction_idx: 0,
249            relocation: Relocation::RelativeStatementId(statement_id),
250        });
251    }
252
253    Ok(builder.build(instructions, relocations, output_expressions))
254}
255
256/// Handles statement for matching an enum with 3+ variants.
257/// For example, with this setup
258/// ```ignore
259/// type felt252_ty = felt252;
260/// type Positivity = Enum<felt252_ty, felt252_ty, felt252_ty>;
261/// libfunc init_positive = enum_init<Positivity, 0>;
262/// libfunc match_positivity = enum_match<Positivity>;
263/// felt252_const<8>() -> (felt8);
264/// init_positive(felt8=[ap-5]) -> (enum_var);
265/// ````
266/// this "Sierra statement" (3-variants-enum)
267/// ```ignore
268/// match_positivity(enum_var=[ap-10]) {fallthrough(pos=[ap-9]), 2000(neg=[ap-9]), 3000(zero=[ap-9])};
269/// ```
270/// translates to these casm instructions:
271/// ```ignore
272/// jmp rel [ap-10]
273/// jmp rel <jump_offset_2000>
274/// jmp rel <jump_offset_3000>
275/// ```
276/// Where in the first location of the enum_var there will be the jmp_table_idx (2*n-1 for
277/// branch index 0 (where n is the number of variants of this enum), 1 for branch index 1, 3 for
278/// branch index 2 and so on: (2 * k - 1) for branch index k).
279///
280/// Assumes that self.invocation.branches.len() == output_expressions.len() > 2.
281fn build_enum_match_long(
282    builder: CompiledInvocationBuilder<'_>,
283    variant_selector: CellRef,
284    output_expressions: impl ExactSizeIterator<
285        Item = impl ExactSizeIterator<Item = ReferenceExpression>,
286    >,
287) -> Result<CompiledInvocation, InvocationError> {
288    let target_statement_ids = builder.invocation.branches[1..].iter().map(|b| match b {
289        BranchInfo { target: BranchTarget::Statement(stmnt_id), .. } => *stmnt_id,
290        _ => panic!("malformed invocation"),
291    });
292
293    // The first instruction is the jmp to the relevant index in the jmp table.
294    let mut ctx = casm! { jmp rel variant_selector; };
295    let mut relocations = Vec::new();
296
297    // Add a jump-table entry for all the branches but the first one (we directly jump to it from
298    // the first jump above).
299    for (i, stmnt_id) in target_statement_ids.rev().enumerate() {
300        // Add the jump instruction to the relevant target.
301        casm_extend!(ctx, jmp rel 0;);
302        relocations.push(RelocationEntry {
303            instruction_idx: i + 1,
304            relocation: Relocation::RelativeStatementId(stmnt_id),
305        });
306    }
307
308    Ok(builder.build(ctx.instructions, relocations, output_expressions))
309}
310
311/// A struct representing an actual enum value in the Sierra program.
312#[derive(Clone, Debug, Eq, PartialEq)]
313pub struct EnumView {
314    /// This would be ReferenceExpression::Immediate after enum_init, and would be
315    /// ReferenceExpression::Deref after store_*.
316    pub variant_selector: CellExpression,
317    /// The inner value of the enum (a flat vector of cell expressions), padded with
318    /// CellExpression::Padding to match the size of the largest variant.
319    pub inner_value: Vec<CellExpression>,
320}
321
322impl ReferenceExpressionView for EnumView {
323    type Error = ReferencesError;
324
325    fn try_get_view(
326        expr: &ReferenceExpression,
327        program_info: &ProgramInfo<'_>,
328        enum_concrete_type: &ConcreteTypeId,
329    ) -> Result<Self, Self::Error> {
330        let enum_size = get_enum_size(program_info, enum_concrete_type)
331            .ok_or(ReferencesError::InvalidReferenceTypeForArgument)?
332            as usize;
333        // Verify the size.
334        if expr.cells.len() != enum_size {
335            return Err(ReferencesError::InvalidReferenceTypeForArgument);
336        }
337
338        let mut expr_cells_iter = expr.cells.iter();
339        let variant_selector =
340            expr_cells_iter.next().ok_or(ReferencesError::InvalidReferenceTypeForArgument)?.clone();
341
342        Ok(EnumView { variant_selector, inner_value: expr_cells_iter.cloned().collect() })
343    }
344
345    fn to_reference_expression(self) -> ReferenceExpression {
346        ReferenceExpression {
347            cells: chain!(
348                // Variant selector
349                vec![self.variant_selector].into_iter(),
350                // actual value's cells
351                self.inner_value.into_iter(),
352            )
353            .collect(),
354        }
355    }
356}
357
358/// Gets the size of the given concrete enum type.
359fn get_enum_size(
360    program_info: &ProgramInfo<'_>,
361    concrete_enum_type: &ConcreteTypeId,
362) -> Option<i16> {
363    Some(program_info.type_sizes.get(concrete_enum_type)?.to_owned())
364}