cairo_lang_sierra_to_casm/invocations/
enm.rs1use cairo_lang_casm::builder::CasmBuilder;
2use cairo_lang_casm::cell_expression::CellExpression;
3use cairo_lang_casm::operand::CellRef;
4use cairo_lang_casm::{casm, casm_build_extend, casm_extend};
5use cairo_lang_sierra::extensions::ConcreteLibfunc;
6use cairo_lang_sierra::extensions::enm::{EnumConcreteLibfunc, EnumInitConcreteLibfunc};
7use cairo_lang_sierra::ids::ConcreteTypeId;
8use cairo_lang_sierra::program::{BranchInfo, BranchTarget};
9use cairo_lang_utils::try_extract_matches;
10use itertools::{chain, repeat_n};
11use num_bigint::BigInt;
12use starknet_types_core::felt::{Felt as Felt252, NonZeroFelt};
13
14use super::{
15 CompiledInvocation, CompiledInvocationBuilder, InvocationError, ReferenceExpressionView,
16};
17use crate::invocations::{CostValidationInfo, ProgramInfo, add_input_variables, misc};
18use crate::references::{ReferenceExpression, ReferencesError};
19use crate::relocations::{Relocation, RelocationEntry};
20
21pub fn build(
23 libfunc: &EnumConcreteLibfunc,
24 builder: CompiledInvocationBuilder<'_>,
25) -> Result<CompiledInvocation, InvocationError> {
26 match libfunc {
27 EnumConcreteLibfunc::Init(EnumInitConcreteLibfunc { index, n_variants, .. }) => {
28 build_enum_init(builder, *index, *n_variants)
29 }
30 EnumConcreteLibfunc::FromBoundedInt(libfunc) => {
31 build_enum_from_bounded_int(builder, libfunc.n_variants)
32 }
33 EnumConcreteLibfunc::Match(_) | EnumConcreteLibfunc::SnapshotMatch(_) => {
34 build_enum_match(builder)
35 }
36 }
37}
38
39fn build_enum_init(
58 builder: CompiledInvocationBuilder<'_>,
59 index: usize,
60 n_variants: usize,
61) -> Result<CompiledInvocation, InvocationError> {
62 let [expression] = builder.try_get_refs()?;
63 let init_arg_cells = &expression.cells;
64 let variant_selector = get_variant_selector(n_variants, index)?;
65
66 let variant_size = builder
67 .program_info
68 .type_sizes
69 .get(&builder.libfunc.param_signatures()[0].ty)
70 .ok_or(InvocationError::UnknownTypeData)?
71 .to_owned();
72 if init_arg_cells.len() != variant_size as usize {
73 return Err(InvocationError::InvalidReferenceExpressionForArgument);
74 }
75 let concrete_enum_type = &builder.libfunc.output_types()[0][0];
77 let enum_size = get_enum_size(&builder.program_info, concrete_enum_type)
78 .ok_or(InvocationError::UnknownTypeData)?;
79 let num_padding = enum_size - 1 - variant_size;
80 let inner_value = chain!(
81 repeat_n(CellExpression::Immediate(BigInt::from(0)), num_padding as usize),
82 init_arg_cells.clone(),
83 )
84 .collect();
85
86 let enum_val = EnumView {
87 variant_selector: CellExpression::Immediate(BigInt::from(variant_selector)),
88 inner_value,
89 };
90 let output_expressions = [enum_val.to_reference_expression()].into_iter();
91 Ok(builder.build_only_reference_changes(output_expressions))
92}
93
94pub fn get_variant_selector(n_variants: usize, index: usize) -> Result<usize, InvocationError> {
96 Ok(if n_variants <= 2 {
97 index
101 } else {
102 (n_variants - index).checked_mul(2).ok_or(InvocationError::IntegerOverflow)? - 1
114 })
115}
116
117fn build_enum_from_bounded_int(
118 builder: CompiledInvocationBuilder<'_>,
119 n_variants: usize,
120) -> Result<CompiledInvocation, InvocationError> {
121 if n_variants <= 2 {
122 return misc::build_identity(builder);
123 }
124
125 let [value] = builder.try_get_single_cells()?;
126 let mut casm_builder = CasmBuilder::default();
127 add_input_variables! {casm_builder,
128 deref value;
129 };
130
131 let m = (Felt252::from(n_variants * 2 - 1).field_div(&NonZeroFelt::TWO)).to_bigint();
139 casm_build_extend! {casm_builder,
140 const m = m;
141 const negative_two = -2;
142 tempvar value_minus_m = value - m;
143 let variant_selector = value_minus_m * negative_two;
144 };
145
146 Ok(builder.build_from_casm_builder(
147 casm_builder,
148 [("Fallthrough", &[&[variant_selector]], None)],
149 CostValidationInfo::default(),
150 ))
151}
152
153fn build_enum_match(
155 builder: CompiledInvocationBuilder<'_>,
156) -> Result<CompiledInvocation, InvocationError> {
157 let concrete_enum_type = &builder.libfunc.param_signatures()[0].ty;
158 let [expression] = builder.try_get_refs()?;
159 let matched_var = EnumView::try_get_view(expression, &builder.program_info, concrete_enum_type)
160 .map_err(|_| InvocationError::InvalidReferenceExpressionForArgument)?;
161 let variant_selector =
164 try_extract_matches!(matched_var.variant_selector, CellExpression::Deref)
165 .ok_or(InvocationError::InvalidReferenceExpressionForArgument)?;
166
167 let mut branch_output_sizes: Vec<usize> = Vec::new();
168 for branch_outputs in &builder.libfunc.output_types() {
169 let branch_output = &branch_outputs[0];
171 let branch_output_size = builder
172 .program_info
173 .type_sizes
174 .get(branch_output)
175 .ok_or(InvocationError::UnknownTypeData)?;
176 branch_output_sizes.push(*branch_output_size as usize);
177 }
178 let output_expressions = branch_output_sizes.into_iter().map(|size| {
179 let padding_size = matched_var.inner_value.len() - size;
184 vec![ReferenceExpression {
185 cells: matched_var.inner_value.iter().skip(padding_size).cloned().collect(),
186 }]
187 .into_iter()
188 });
189
190 let num_branches = builder.invocation.branches.len();
191 if num_branches <= 2 {
192 build_enum_match_short(builder, variant_selector, output_expressions)
193 } else {
194 build_enum_match_long(builder, variant_selector, output_expressions)
195 }
196}
197
198fn build_enum_match_short(
227 builder: CompiledInvocationBuilder<'_>,
228 variant_selector: CellRef,
229 output_expressions: impl ExactSizeIterator<
230 Item = impl ExactSizeIterator<Item = ReferenceExpression>,
231 >,
232) -> Result<CompiledInvocation, InvocationError> {
233 let mut instructions = Vec::new();
234 let mut relocations = Vec::new();
235
236 if let Some(branch) = builder.invocation.branches.get(1) {
241 let statement_id = match branch {
242 BranchInfo { target: BranchTarget::Statement(statement_id), .. } => *statement_id,
243 _ => panic!("malformed invocation"),
244 };
245
246 instructions.extend(casm! { jmp rel 0 if variant_selector != 0; }.instructions);
247 relocations.push(RelocationEntry {
248 instruction_idx: 0,
249 relocation: Relocation::RelativeStatementId(statement_id),
250 });
251 }
252
253 Ok(builder.build(instructions, relocations, output_expressions))
254}
255
256fn build_enum_match_long(
282 builder: CompiledInvocationBuilder<'_>,
283 variant_selector: CellRef,
284 output_expressions: impl ExactSizeIterator<
285 Item = impl ExactSizeIterator<Item = ReferenceExpression>,
286 >,
287) -> Result<CompiledInvocation, InvocationError> {
288 let target_statement_ids = builder.invocation.branches[1..].iter().map(|b| match b {
289 BranchInfo { target: BranchTarget::Statement(stmnt_id), .. } => *stmnt_id,
290 _ => panic!("malformed invocation"),
291 });
292
293 let mut ctx = casm! { jmp rel variant_selector; };
295 let mut relocations = Vec::new();
296
297 for (i, stmnt_id) in target_statement_ids.rev().enumerate() {
300 casm_extend!(ctx, jmp rel 0;);
302 relocations.push(RelocationEntry {
303 instruction_idx: i + 1,
304 relocation: Relocation::RelativeStatementId(stmnt_id),
305 });
306 }
307
308 Ok(builder.build(ctx.instructions, relocations, output_expressions))
309}
310
311#[derive(Clone, Debug, Eq, PartialEq)]
313pub struct EnumView {
314 pub variant_selector: CellExpression,
317 pub inner_value: Vec<CellExpression>,
320}
321
322impl ReferenceExpressionView for EnumView {
323 type Error = ReferencesError;
324
325 fn try_get_view(
326 expr: &ReferenceExpression,
327 program_info: &ProgramInfo<'_>,
328 enum_concrete_type: &ConcreteTypeId,
329 ) -> Result<Self, Self::Error> {
330 let enum_size = get_enum_size(program_info, enum_concrete_type)
331 .ok_or(ReferencesError::InvalidReferenceTypeForArgument)?
332 as usize;
333 if expr.cells.len() != enum_size {
335 return Err(ReferencesError::InvalidReferenceTypeForArgument);
336 }
337
338 let mut expr_cells_iter = expr.cells.iter();
339 let variant_selector =
340 expr_cells_iter.next().ok_or(ReferencesError::InvalidReferenceTypeForArgument)?.clone();
341
342 Ok(EnumView { variant_selector, inner_value: expr_cells_iter.cloned().collect() })
343 }
344
345 fn to_reference_expression(self) -> ReferenceExpression {
346 ReferenceExpression {
347 cells: chain!(
348 vec![self.variant_selector].into_iter(),
350 self.inner_value.into_iter(),
352 )
353 .collect(),
354 }
355 }
356}
357
358fn get_enum_size(
360 program_info: &ProgramInfo<'_>,
361 concrete_enum_type: &ConcreteTypeId,
362) -> Option<i16> {
363 Some(program_info.type_sizes.get(concrete_enum_type)?.to_owned())
364}