cairo_lang_sierra_to_casm/invocations/enm.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364
use cairo_lang_casm::builder::CasmBuilder;
use cairo_lang_casm::cell_expression::CellExpression;
use cairo_lang_casm::operand::CellRef;
use cairo_lang_casm::{casm, casm_build_extend, casm_extend};
use cairo_lang_sierra::extensions::enm::{EnumConcreteLibfunc, EnumInitConcreteLibfunc};
use cairo_lang_sierra::extensions::ConcreteLibfunc;
use cairo_lang_sierra::ids::ConcreteTypeId;
use cairo_lang_sierra::program::{BranchInfo, BranchTarget};
use cairo_lang_utils::try_extract_matches;
use itertools::{chain, repeat_n};
use num_bigint::BigInt;
use starknet_types_core::felt::{Felt as Felt252, NonZeroFelt};
use super::{
CompiledInvocation, CompiledInvocationBuilder, InvocationError, ReferenceExpressionView,
};
use crate::invocations::{add_input_variables, misc, CostValidationInfo, ProgramInfo};
use crate::references::{ReferenceExpression, ReferencesError};
use crate::relocations::{Relocation, RelocationEntry};
/// Entry point for compiling Sierra enum libfuncs into CASM.
///
/// Dispatches on the concrete libfunc:
/// - `Init` builds the enum value from a variant's payload.
/// - `FromBoundedInt` converts a bounded integer into an enum.
/// - `Match` / `SnapshotMatch` both branch on the enum's variant selector and share one
///   implementation.
pub fn build(
    libfunc: &EnumConcreteLibfunc,
    builder: CompiledInvocationBuilder<'_>,
) -> Result<CompiledInvocation, InvocationError> {
    match libfunc {
        EnumConcreteLibfunc::Match(_) | EnumConcreteLibfunc::SnapshotMatch(_) => {
            build_enum_match(builder)
        }
        EnumConcreteLibfunc::FromBoundedInt(bounded) => {
            build_enum_from_bounded_int(builder, bounded.n_variants)
        }
        EnumConcreteLibfunc::Init(init) => build_enum_init(builder, init.index, init.n_variants),
    }
}
/// Handles statement for initializing an enum.
///
/// For example, with this setup
/// ```ignore
/// type felt252_ty = felt252;
/// type unit_ty = Tuple;
/// type Option = Enum<felt252_ty, unit_ty>;
/// libfunc init_option_some = enum_init<Option, 0>;
/// felt252_const<8>() -> (felt8);
/// ```
/// this "Sierra statement"
/// ```ignore
/// init_option_some(felt8=[ap-5]) -> (some_id);
/// ```
/// produces the reference expression `(selector, padding.., payload..)`: the variant selector as
/// an immediate, followed by zero immediates so that every variant occupies the same number of
/// cells, followed by the cells of the argument.
///
/// This function finishes with `build_only_reference_changes`, i.e. it only describes the new
/// reference. Presumably the actual writes (e.g. `[ap] = 0; ap++` / `[ap] = 8; ap++`) are emitted
/// when the value is later stored -- confirm against the store libfuncs.
///
/// # Errors
/// - `InvocationError::UnknownTypeData` if the size of the variant or of the enum is unknown.
/// - `InvocationError::InvalidReferenceExpressionForArgument` if the argument does not have
///   exactly as many cells as the variant type, or if the variant does not fit in the enum's
///   payload (enum size minus the selector cell).
/// - Any error returned by `get_variant_selector`.
fn build_enum_init(
    builder: CompiledInvocationBuilder<'_>,
    index: usize,
    n_variants: usize,
) -> Result<CompiledInvocation, InvocationError> {
    let [expression] = builder.try_get_refs()?;
    let init_arg_cells = &expression.cells;
    let variant_selector = get_variant_selector(n_variants, index)?;
    let variant_size = *builder
        .program_info
        .type_sizes
        .get(&builder.libfunc.param_signatures()[0].ty)
        .ok_or(InvocationError::UnknownTypeData)?;
    // A negative size can never match a cell count; reject it before casting to `usize`.
    if variant_size < 0 || init_arg_cells.len() != variant_size as usize {
        return Err(InvocationError::InvalidReferenceExpressionForArgument);
    }
    // Pad the variant to match the size of the largest variant.
    let concrete_enum_type = &builder.libfunc.output_types()[0][0];
    let enum_size = get_enum_size(&builder.program_info, concrete_enum_type)
        .ok_or(InvocationError::UnknownTypeData)?;
    // One cell of the enum is the selector; the rest is the payload. Compute in `i32` so the
    // subtraction cannot overflow, and refuse a negative result instead of letting an `as usize`
    // cast turn it into an enormous repetition count.
    let num_padding = i32::from(enum_size) - 1 - i32::from(variant_size);
    if num_padding < 0 {
        return Err(InvocationError::InvalidReferenceExpressionForArgument);
    }
    let inner_value = chain!(
        repeat_n(CellExpression::Immediate(BigInt::from(0)), num_padding as usize),
        init_arg_cells.clone(),
    )
    .collect();
    let enum_val = EnumView {
        variant_selector: CellExpression::Immediate(BigInt::from(variant_selector)),
        inner_value,
    };
    let output_expressions = [enum_val.to_reference_expression()].into_iter();
    Ok(builder.build_only_reference_changes(output_expressions))
}
/// Returns the variant selector for variant `index` out of `n_variants`.
///
/// - For `n_variants <= 2` the selector is `index` itself: the `match` implementation jumps to
///   the index 0 statement on 0, and to the index 1 statement on 1.
/// - For `n_variants > 2` the selector is `2 * (n_variants - index) - 1`, the relative jump that
///   `enum_match` performs through its jump table.
///
/// # Errors
/// Returns `InvocationError::IntegerOverflow` when `n_variants > 2` and the selector cannot be
/// computed in `usize`: either `index >= n_variants` (the subtraction would underflow) or the
/// doubling overflows.
pub fn get_variant_selector(n_variants: usize, index: usize) -> Result<usize, InvocationError> {
    if n_variants <= 2 {
        // For num_branches <= 2, we use the index as the variant_selector.
        return Ok(index);
    }
    // For num_branches > 2, the `enum_match` libfunc is implemented using a jump table. In
    // order to optimize `enum_match`, we define the variant_selector as the relevant
    // relative jump in case we match the actual variant.
    //
    // - To jump to the variant in index 0, we skip the jump table and directly jump to it. Its
    //   location is (2 * n - 1) CASM steps ahead, where n is the number of variants in this
    //   enum (2 per variant but the first variant, and 1 for the first jump with a deref
    //   operand).
    // - To jump to the variant in index k, we add "jump rel (2 * (n - k) - 1)" as the first
    //   jump is of size 1 and the rest of the jump instructions are with an immediate operand,
    //   which makes them of size 2.
    //
    // Every step is checked: an out-of-range `index` must surface as an error rather than as a
    // debug-build panic or a wrapped (bogus) jump offset in release builds.
    n_variants
        .checked_sub(index)
        .and_then(|remaining| remaining.checked_mul(2))
        .and_then(|doubled| doubled.checked_sub(1))
        .ok_or(InvocationError::IntegerOverflow)
}
fn build_enum_from_bounded_int(
builder: CompiledInvocationBuilder<'_>,
n_variants: usize,
) -> Result<CompiledInvocation, InvocationError> {
if n_variants <= 2 {
return misc::build_identity(builder);
}
let [value] = builder.try_get_single_cells()?;
let mut casm_builder = CasmBuilder::default();
add_input_variables! {casm_builder,
deref value;
};
// Given the number of variants, `n`, and the index of the variant `0 <= k < n`:
// The variant selector for enums with 3 or more variants is the relative jump to the variant
// handle which is `2 * (n - k) - 1`.
// `2 * (n - k) - 1 = 2*n - 2*k - 1 = 2 * (2*n - 1) / 2 - 2*k = 2 * ((2*n - 1) / 2 - k)`
// Define `(2*n - 1) / 2` as `m` - which is known in compilation time.
// Hence the variant selector is `2 * (m - k)` or alternatively `-2 * (k - m)`
let m = (Felt252::from(n_variants * 2 - 1).field_div(&NonZeroFelt::TWO)).to_bigint();
casm_build_extend! {casm_builder,
const m = m;
const negative_two = -2;
tempvar value_minus_m = value - m;
let variant_selector = value_minus_m * negative_two;
};
Ok(builder.build_from_casm_builder(
casm_builder,
[("Fallthrough", &[&[variant_selector]], None)],
CostValidationInfo::default(),
))
}
/// Handles statement for matching an enum.
fn build_enum_match(
builder: CompiledInvocationBuilder<'_>,
) -> Result<CompiledInvocation, InvocationError> {
let concrete_enum_type = &builder.libfunc.param_signatures()[0].ty;
let [expression] = builder.try_get_refs()?;
let matched_var = EnumView::try_get_view(expression, &builder.program_info, concrete_enum_type)
.map_err(|_| InvocationError::InvalidReferenceExpressionForArgument)?;
// Verify variant_selector is of type deref. This is the case with an enum_value
// that was validly created and then stored.
let variant_selector =
try_extract_matches!(matched_var.variant_selector, CellExpression::Deref)
.ok_or(InvocationError::InvalidReferenceExpressionForArgument)?;
let mut branch_output_sizes: Vec<usize> = Vec::new();
for branch_outputs in &builder.libfunc.output_types() {
// Each branch has a single output.
let branch_output = &branch_outputs[0];
let branch_output_size = builder
.program_info
.type_sizes
.get(branch_output)
.ok_or(InvocationError::UnknownTypeData)?;
branch_output_sizes.push(*branch_output_size as usize);
}
let output_expressions = branch_output_sizes.into_iter().map(|size| {
// The size of an output must be smaller than the size of `matched_var.inner_value` as the
// size of inner_value is fixed and is calculated as the max of the sizes of all the
// variants (which are the outputs in all the branches). Thus it is guaranteed that the
// iter we generate here is of size `size` (and not less).
let padding_size = matched_var.inner_value.len() - size;
vec![ReferenceExpression {
cells: matched_var.inner_value.iter().skip(padding_size).cloned().collect(),
}]
.into_iter()
});
let num_branches = builder.invocation.branches.len();
if num_branches <= 2 {
build_enum_match_short(builder, variant_selector, output_expressions)
} else {
build_enum_match_long(builder, variant_selector, output_expressions)
}
}
/// Handles statement for matching an enum with 1 or 2 variants.
///
/// For example, with this setup
/// ```ignore
/// type felt252_ty = felt252;
/// type unit_ty = Tuple;
/// type Option = Enum<felt252_ty, unit_ty>;
/// libfunc init_option_some = enum_init<Option, 0>;
/// libfunc match_option = enum_match<Option>;
/// felt252_const<8>() -> (felt8);
/// init_option_some(felt8=[ap-5]) -> (enum_var);
/// ```
/// this "Sierra statement" (2-variants-enum)
/// ```ignore
/// match_option(enum_var=[ap-10]) {fallthrough(some=[ap-9]), 2000(none=[ap-9])};
/// ```
/// makes this function emit a single conditional jump whose offset is filled in by a relocation
/// to the second branch's statement:
/// ```ignore
/// jmp rel <jump_offset_2000> if [ap-10] != 0
/// ```
/// The first branch is reached by not jumping.
///
/// For this "Sierra statement" (single-variant-enum)
/// ```ignore
/// match_option(enum_var=[ap-10]) {fallthrough(var=[ap-9])};
/// ```
/// this function emits no instructions.
///
/// Assumes that builder.invocation.branches.len() == output_expressions.len() and that
/// builder.invocation.branches.len() <= 2.
///
/// # Panics
/// Panics with "malformed invocation" if the second branch does not target a statement.
fn build_enum_match_short(
    builder: CompiledInvocationBuilder<'_>,
    variant_selector: CellRef,
    output_expressions: impl ExactSizeIterator<
        Item = impl ExactSizeIterator<Item = ReferenceExpression>,
    >,
) -> Result<CompiledInvocation, InvocationError> {
    let (instructions, relocations) = match builder.invocation.branches.get(1) {
        // A single branch: nothing to decide, so nothing is emitted.
        None => (Vec::new(), Vec::new()),
        // Two branches: jump to the second one when the selector is non-zero.
        Some(second_branch) => {
            let target = match second_branch {
                BranchInfo { target: BranchTarget::Statement(statement_id), .. } => *statement_id,
                _ => panic!("malformed invocation"),
            };
            (
                casm! { jmp rel 0 if variant_selector != 0; }.instructions,
                // The placeholder offset of instruction 0 is resolved to the target statement.
                vec![RelocationEntry {
                    instruction_idx: 0,
                    relocation: Relocation::RelativeStatementId(target),
                }],
            )
        }
    };
    Ok(builder.build(instructions, relocations, output_expressions))
}
/// Handles statement for matching an enum with 3+ variants.
///
/// For example, with this setup
/// ```ignore
/// type felt252_ty = felt252;
/// type Positivity = Enum<felt252_ty, felt252_ty, felt252_ty>;
/// libfunc init_positive = enum_init<Positivity, 0>;
/// libfunc match_positivity = enum_match<Positivity>;
/// felt252_const<8>() -> (felt8);
/// init_positive(felt8=[ap-5]) -> (enum_var);
/// ```
/// this "Sierra statement" (3-variants-enum)
/// ```ignore
/// match_positivity(enum_var=[ap-10]) {fallthrough(pos=[ap-9]), 2000(neg=[ap-9]), 3000(zero=[ap-9])};
/// ```
/// translates to these casm instructions:
/// ```ignore
/// jmp rel [ap-10]
/// jmp rel <jump_offset_3000>
/// jmp rel <jump_offset_2000>
/// ```
/// The jump-table entries are emitted in reverse branch order (last branch first), and the first
/// branch has no entry. The first cell of `enum_var` holds the relative jump chosen by
/// `get_variant_selector`, `2 * (n - k) - 1` for branch index `k` out of `n` variants: in the
/// example above 5 for branch 0 (past the whole table), 3 for branch 1 and 1 for branch 2.
///
/// Assumes that builder.invocation.branches.len() == output_expressions.len() > 2.
///
/// # Panics
/// Panics with "malformed invocation" if any branch but the first does not target a statement.
fn build_enum_match_long(
    builder: CompiledInvocationBuilder<'_>,
    variant_selector: CellRef,
    output_expressions: impl ExactSizeIterator<
        Item = impl ExactSizeIterator<Item = ReferenceExpression>,
    >,
) -> Result<CompiledInvocation, InvocationError> {
    // Instruction 0 jumps by the selector value, landing on a table entry or past the table.
    let mut ctx = casm! { jmp rel variant_selector; };
    let mut relocations = Vec::new();
    // One table entry per branch except the first, walking the branches from last to first.
    // Entry number `instruction_idx` (1-based, since instruction 0 is the jump above) gets a
    // relocation that resolves its placeholder offset to the branch's statement.
    let table_branches = builder.invocation.branches[1..].iter().rev();
    for (instruction_idx, branch) in (1..).zip(table_branches) {
        let target = match branch {
            BranchInfo { target: BranchTarget::Statement(stmnt_id), .. } => *stmnt_id,
            _ => panic!("malformed invocation"),
        };
        casm_extend!(ctx, jmp rel 0;);
        relocations.push(RelocationEntry {
            instruction_idx,
            relocation: Relocation::RelativeStatementId(target),
        });
    }
    Ok(builder.build(ctx.instructions, relocations, output_expressions))
}
/// A struct representing an actual enum value in the Sierra program.
///
/// The flat reference-expression layout is the variant selector cell followed by the cells of
/// `inner_value` (see `to_reference_expression`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EnumView {
/// The cell identifying the active variant.
/// `build_enum_init` creates it as a `CellExpression::Immediate`, while `build_enum_match`
/// only accepts a `CellExpression::Deref` (presumably the form it takes after the value was
/// stored by one of the store_* libfuncs).
pub variant_selector: CellExpression,
/// The inner value of the enum (a flat vector of cell expressions), padded to the size of the
/// largest variant. `build_enum_init` puts the padding (zero immediates) before the variant's
/// own cells, and `build_enum_match` reads a variant from the trailing cells.
pub inner_value: Vec<CellExpression>,
}
impl ReferenceExpressionView for EnumView {
    type Error = ReferencesError;

    /// Splits `expr` into the variant selector (first cell) and the payload (remaining cells).
    ///
    /// Fails with `ReferencesError::InvalidReferenceTypeForArgument` when the size of
    /// `enum_concrete_type` is unknown, when `expr` does not have exactly that many cells, or
    /// when `expr` has no cells at all.
    fn try_get_view(
        expr: &ReferenceExpression,
        program_info: &ProgramInfo<'_>,
        enum_concrete_type: &ConcreteTypeId,
    ) -> Result<Self, Self::Error> {
        let expected_len = match get_enum_size(program_info, enum_concrete_type) {
            Some(size) => size as usize,
            None => return Err(ReferencesError::InvalidReferenceTypeForArgument),
        };
        // Verify the size.
        if expr.cells.len() != expected_len {
            return Err(ReferencesError::InvalidReferenceTypeForArgument);
        }
        match expr.cells.split_first() {
            Some((selector, payload)) => Ok(EnumView {
                variant_selector: selector.clone(),
                inner_value: payload.to_vec(),
            }),
            None => Err(ReferencesError::InvalidReferenceTypeForArgument),
        }
    }

    /// Flattens the view back into a reference expression: the variant selector first, then
    /// the payload cells in order.
    fn to_reference_expression(self) -> ReferenceExpression {
        let EnumView { variant_selector, inner_value } = self;
        let mut cells = Vec::with_capacity(inner_value.len() + 1);
        cells.push(variant_selector);
        cells.extend(inner_value);
        ReferenceExpression { cells }
    }
}
/// Looks up the size of `concrete_enum_type` in the program's type-size table.
///
/// Returns `None` when the table has no entry for that type.
fn get_enum_size(
    program_info: &ProgramInfo<'_>,
    concrete_enum_type: &ConcreteTypeId,
) -> Option<i16> {
    program_info.type_sizes.get(concrete_enum_type).copied()
}