1use std::iter;
2
3use cairo_lang_casm::ap_change::{ApChangeError, ApplyApChange};
4use cairo_lang_sierra::edit_state::{put_results, take_args};
5use cairo_lang_sierra::ids::{ConcreteTypeId, FunctionId, VarId};
6use cairo_lang_sierra::program::{BranchInfo, Function, StatementIdx};
7use cairo_lang_sierra_type_size::TypeSizeMap;
8use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
9use itertools::{chain, zip_eq};
10use thiserror::Error;
11
12use crate::environment::ap_tracking::update_ap_tracking;
13use crate::environment::frame_state::FrameStateError;
14use crate::environment::gas_wallet::{GasWallet, GasWalletError};
15use crate::environment::{
16 ApTracking, ApTrackingBase, Environment, EnvironmentError, validate_environment_equality,
17 validate_final_environment,
18};
19use crate::invocations::{ApTrackingChange, BranchChanges};
20use crate::metadata::Metadata;
21use crate::references::{
22 IntroductionPoint, OutputReferenceValueIntroductionPoint, ReferenceExpression, ReferenceValue,
23 ReferencesError, StatementRefs, build_function_parameters_refs, check_types_match,
24};
25
/// Errors raised while computing or validating the per-statement annotations of a program.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum AnnotationError {
    /// The references reaching a statement disagree with the ones already recorded for it.
    #[error("#{statement_idx}: Inconsistent references annotations: {error}")]
    InconsistentReferencesAnnotation {
        statement_idx: StatementIdx,
        error: InconsistentReferenceError,
    },
    /// A branch that must be the first to annotate its destination found the destination
    /// already annotated.
    #[error("#{source_statement_idx}->#{destination_statement_idx}: Annotation was already set.")]
    AnnotationAlreadySet {
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
    },
    /// Two environments that must match differ, or a final environment failed validation.
    #[error("#{statement_idx}: {error}")]
    InconsistentEnvironments { statement_idx: StatementIdx, error: EnvironmentError },
    /// A statement was reached from two different functions.
    #[error("#{statement_idx}: Belongs to two different functions.")]
    InconsistentFunctionId { statement_idx: StatementIdx },
    /// A statement was reached again although its recorded annotations forbid convergence.
    #[error("#{statement_idx}: Invalid convergence.")]
    InvalidConvergence { statement_idx: StatementIdx },
    /// A statement index is outside the range of the program's statements.
    #[error("InvalidStatementIdx")]
    InvalidStatementIdx,
    /// Annotations were requested for a statement that was never annotated.
    #[error("MissingAnnotationsForStatement")]
    MissingAnnotationsForStatement(StatementIdx),
    /// A variable consumed by a statement is not defined at that statement.
    #[error("#{statement_idx}: {var_id} is undefined.")]
    MissingReferenceError { statement_idx: StatementIdx, var_id: VarId },
    /// A branch output would redefine a variable that is still defined.
    #[error("#{source_statement_idx}->#{destination_statement_idx}: {var_id} was overridden.")]
    OverrideReferenceError {
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
        var_id: VarId,
    },
    /// Wraps a [`FrameStateError`].
    #[error(transparent)]
    FrameStateError(#[from] FrameStateError),
    /// Applying a branch's gas change to the gas wallet failed.
    #[error("#{source_statement_idx}->#{destination_statement_idx}: {error}")]
    GasWalletError {
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
        error: GasWalletError,
    },
    /// Building or checking references (e.g. function parameters or return values) failed.
    #[error("#{statement_idx}: {error}")]
    ReferencesError { statement_idx: StatementIdx, error: ReferencesError },
    /// A branch requested enabling ap tracking while it was not disabled.
    #[error("#{statement_idx}: Attempting to enable ap tracking when already enabled.")]
    ApTrackingAlreadyEnabled { statement_idx: StatementIdx },
    /// Applying a branch's ap change to a variable's reference expression failed.
    // NOTE(review): the rendered message has a space before the final period ("{} ."); confirm
    // no expected-output tests depend on it before tidying the format string.
    #[error(
        "#{source_statement_idx}->#{destination_statement_idx}: Got '{error}' error while moving \
         {var_id} introduced at {} .", {introduction_point}
    )]
    ApChangeError {
        var_id: VarId,
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
        introduction_point: IntroductionPoint,
        error: ApChangeError,
    },
    /// Updating the ap tracking state with a branch's ap change failed.
    #[error("#{source_statement_idx} -> #{destination_statement_idx}: Ap tracking error")]
    ApTrackingError {
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
        error: ApChangeError,
    },
    /// The ap tracking state at a return differs from the one the metadata expects for the
    /// function.
    #[error(
        "#{statement_idx}: Invalid function ap change annotation. Expected ap tracking: \
         {expected:?}, got: {actual:?}."
    )]
    InvalidFunctionApChange {
        statement_idx: StatementIdx,
        expected: ApTracking,
        actual: ApTracking,
    },
}
95
96impl AnnotationError {
97 pub fn stmt_indices(&self) -> Vec<StatementIdx> {
98 match self {
99 AnnotationError::ApChangeError {
100 source_statement_idx,
101 destination_statement_idx,
102 introduction_point,
103 ..
104 } => chain!(
105 [source_statement_idx, destination_statement_idx],
106 &introduction_point.source_statement_idx,
107 [&introduction_point.destination_statement_idx]
108 )
109 .cloned()
110 .collect(),
111 _ => vec![],
112 }
113 }
114}
115
/// Reasons why the references reaching a statement are inconsistent with the references already
/// recorded for that statement.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum InconsistentReferenceError {
    /// A variable's type differs from the recorded one.
    #[error("Variable {var} type mismatch. Expected `{expected}`, got `{actual}`.")]
    TypeMismatch { var: VarId, expected: ConcreteTypeId, actual: ConcreteTypeId },
    /// A variable's reference expression differs from the recorded one.
    #[error("Variable {var} expression mismatch. Expected `{expected}`, got `{actual}`.")]
    ExpressionMismatch { var: VarId, expected: ReferenceExpression, actual: ReferenceExpression },
    /// A variable's stack index differs from the recorded one.
    #[error("Variable {var} stack index mismatch. Expected `{expected:?}`, got `{actual:?}`.")]
    StackIndexMismatch { var: VarId, expected: Option<usize>, actual: Option<usize> },
    /// A variable's introduction point differs from the recorded one in a case where the two
    /// must match.
    #[error("Variable {var} introduction point mismatch. Expected `{expected}`, got `{actual}`.")]
    IntroductionPointMismatch { var: VarId, expected: IntroductionPoint, actual: IntroductionPoint },
    /// The number of variables differs from the recorded one.
    #[error("Variable count mismatch.")]
    VariableCountMismatch,
    /// A variable in the new annotations is absent from the recorded ones.
    #[error("Missing expected variable {0}.")]
    VariableMissing(VarId),
    /// A variable that needs an introduction-point check could not be merged because ap tracking
    /// is disabled.
    #[error("Ap tracking is disabled while trying to merge {0}.")]
    ApTrackingDisabled(VarId),
}
134
/// The annotations of a single statement: the state that holds when execution reaches it.
#[derive(Clone, Debug)]
pub struct StatementAnnotations {
    /// The variables available at the statement, keyed by variable id.
    pub refs: StatementRefs,
    /// The function the statement belongs to.
    pub function_id: FunctionId,
    /// Whether another path may still reach this statement once it is annotated. When `false`,
    /// `ProgramAnnotations::set_or_assert` rejects a later, otherwise-consistent annotation with
    /// `AnnotationError::InvalidConvergence`.
    pub convergence_allowed: bool,
    /// The environment (ap tracking, stack size, frame state and gas wallet) at the statement.
    pub environment: Environment,
}
145
/// The annotations of every statement in a program, filled in as branches are propagated.
pub struct ProgramAnnotations {
    /// The annotations of each statement, indexed by statement index. An entry stays `None` until
    /// it is first set.
    per_statement_annotations: Vec<Option<StatementAnnotations>>,
    /// Statements that are targets of backwards jumps. `get_annotations_after_take_args` clones
    /// their annotations instead of moving them out, so they remain recorded.
    backwards_jump_indices: UnorderedHashSet<StatementIdx>,
}
154impl ProgramAnnotations {
155 fn new(n_statements: usize, backwards_jump_indices: UnorderedHashSet<StatementIdx>) -> Self {
156 ProgramAnnotations {
157 per_statement_annotations: iter::repeat_with(|| None).take(n_statements).collect(),
158 backwards_jump_indices,
159 }
160 }
161
162 pub fn create(
165 n_statements: usize,
166 backwards_jump_indices: UnorderedHashSet<StatementIdx>,
167 functions: &[Function],
168 metadata: &Metadata,
169 gas_usage_check: bool,
170 type_sizes: &TypeSizeMap,
171 ) -> Result<Self, AnnotationError> {
172 let mut annotations = ProgramAnnotations::new(n_statements, backwards_jump_indices);
173 for func in functions {
174 annotations.set_or_assert(
175 func.entry_point,
176 StatementAnnotations {
177 refs: build_function_parameters_refs(func, type_sizes).map_err(|error| {
178 AnnotationError::ReferencesError { statement_idx: func.entry_point, error }
179 })?,
180 function_id: func.id.clone(),
181 convergence_allowed: false,
182 environment: Environment::new(if gas_usage_check {
183 GasWallet::Value(metadata.gas_info.function_costs[&func.id].clone())
184 } else {
185 GasWallet::Disabled
186 }),
187 },
188 )?
189 }
190
191 Ok(annotations)
192 }
193
194 pub fn set_or_assert(
199 &mut self,
200 statement_idx: StatementIdx,
201 annotations: StatementAnnotations,
202 ) -> Result<(), AnnotationError> {
203 let idx = statement_idx.0;
204 match self.per_statement_annotations.get(idx).ok_or(AnnotationError::InvalidStatementIdx)? {
205 None => self.per_statement_annotations[idx] = Some(annotations),
206 Some(expected_annotations) => {
207 if expected_annotations.function_id != annotations.function_id {
208 return Err(AnnotationError::InconsistentFunctionId { statement_idx });
209 }
210 validate_environment_equality(
211 &expected_annotations.environment,
212 &annotations.environment,
213 )
214 .map_err(|error| AnnotationError::InconsistentEnvironments {
215 statement_idx,
216 error,
217 })?;
218 self.test_references_consistency(&annotations, expected_annotations).map_err(
219 |error| AnnotationError::InconsistentReferencesAnnotation {
220 statement_idx,
221 error,
222 },
223 )?;
224
225 if !expected_annotations.convergence_allowed {
228 return Err(AnnotationError::InvalidConvergence { statement_idx });
229 }
230 }
231 };
232 Ok(())
233 }
234
235 fn test_references_consistency(
238 &self,
239 actual: &StatementAnnotations,
240 expected: &StatementAnnotations,
241 ) -> Result<(), InconsistentReferenceError> {
242 if actual.refs.len() != expected.refs.len() {
244 return Err(InconsistentReferenceError::VariableCountMismatch);
245 }
246 let ap_tracking_enabled =
247 matches!(actual.environment.ap_tracking, ApTracking::Enabled { .. });
248 for (var_id, actual_ref) in actual.refs.iter() {
249 let Some(expected_ref) = expected.refs.get(var_id) else {
251 return Err(InconsistentReferenceError::VariableMissing(var_id.clone()));
252 };
253 if actual_ref.ty != expected_ref.ty {
255 return Err(InconsistentReferenceError::TypeMismatch {
256 var: var_id.clone(),
257 expected: expected_ref.ty.clone(),
258 actual: actual_ref.ty.clone(),
259 });
260 }
261 if actual_ref.expression != expected_ref.expression {
262 return Err(InconsistentReferenceError::ExpressionMismatch {
263 var: var_id.clone(),
264 expected: expected_ref.expression.clone(),
265 actual: actual_ref.expression.clone(),
266 });
267 }
268 if actual_ref.stack_idx != expected_ref.stack_idx {
269 return Err(InconsistentReferenceError::StackIndexMismatch {
270 var: var_id.clone(),
271 expected: expected_ref.stack_idx,
272 actual: actual_ref.stack_idx,
273 });
274 }
275 test_var_consistency(var_id, actual_ref, expected_ref, ap_tracking_enabled)?;
276 }
277 Ok(())
278 }
279
280 pub fn get_annotations_after_take_args<'a>(
284 &mut self,
285 statement_idx: StatementIdx,
286 ref_ids: impl Iterator<Item = &'a VarId>,
287 ) -> Result<(StatementAnnotations, Vec<ReferenceValue>), AnnotationError> {
288 let existing = self.per_statement_annotations[statement_idx.0]
289 .as_mut()
290 .ok_or(AnnotationError::MissingAnnotationsForStatement(statement_idx))?;
291 let mut updated = if self.backwards_jump_indices.contains(&statement_idx) {
292 existing.clone()
293 } else {
294 std::mem::replace(
295 existing,
296 StatementAnnotations {
297 refs: Default::default(),
298 function_id: existing.function_id.clone(),
299 convergence_allowed: false,
301 environment: existing.environment.clone(),
302 },
303 )
304 };
305 let refs = std::mem::take(&mut updated.refs);
306 let (statement_refs, taken_refs) = take_args(refs, ref_ids).map_err(|error| {
307 AnnotationError::MissingReferenceError { statement_idx, var_id: error.var_id() }
308 })?;
309 updated.refs = statement_refs;
310 Ok((updated, taken_refs))
311 }
312
313 pub fn propagate_annotations(
319 &mut self,
320 source_statement_idx: StatementIdx,
321 destination_statement_idx: StatementIdx,
322 mut annotations: StatementAnnotations,
323 branch_info: &BranchInfo,
324 branch_changes: BranchChanges,
325 must_set: bool,
326 ) -> Result<(), AnnotationError> {
327 if must_set && self.per_statement_annotations[destination_statement_idx.0].is_some() {
328 return Err(AnnotationError::AnnotationAlreadySet {
329 source_statement_idx,
330 destination_statement_idx,
331 });
332 }
333
334 for (var_id, ref_value) in annotations.refs.iter_mut() {
335 if branch_changes.clear_old_stack {
336 ref_value.stack_idx = None;
337 }
338 ref_value.expression =
339 std::mem::replace(&mut ref_value.expression, ReferenceExpression::zero_sized())
340 .apply_ap_change(branch_changes.ap_change)
341 .map_err(|error| AnnotationError::ApChangeError {
342 var_id: var_id.clone(),
343 source_statement_idx,
344 destination_statement_idx,
345 introduction_point: ref_value.introduction_point.clone(),
346 error,
347 })?;
348 }
349 let mut refs = put_results(
350 annotations.refs,
351 zip_eq(
352 &branch_info.results,
353 branch_changes.refs.into_iter().map(|value| ReferenceValue {
354 expression: value.expression,
355 ty: value.ty,
356 stack_idx: value.stack_idx,
357 introduction_point: match value.introduction_point {
358 OutputReferenceValueIntroductionPoint::New(output_idx) => {
359 IntroductionPoint {
360 source_statement_idx: Some(source_statement_idx),
361 destination_statement_idx,
362 output_idx,
363 }
364 }
365 OutputReferenceValueIntroductionPoint::Existing(introduction_point) => {
366 introduction_point
367 }
368 },
369 }),
370 ),
371 )
372 .map_err(|error| AnnotationError::OverrideReferenceError {
373 source_statement_idx,
374 destination_statement_idx,
375 var_id: error.var_id(),
376 })?;
377
378 let available_stack_indices: UnorderedHashSet<_> =
382 refs.values().flat_map(|r| r.stack_idx).collect();
383 let new_stack_size_opt = (0..branch_changes.new_stack_size)
384 .find(|i| !available_stack_indices.contains(&(branch_changes.new_stack_size - 1 - i)));
385 let stack_size = if let Some(new_stack_size) = new_stack_size_opt {
386 let stack_removal = branch_changes.new_stack_size - new_stack_size;
388 for (_, r) in refs.iter_mut() {
389 r.stack_idx =
392 r.stack_idx.and_then(|stack_idx| stack_idx.checked_sub(stack_removal));
393 }
394 new_stack_size
395 } else {
396 branch_changes.new_stack_size
397 };
398
399 let ap_tracking = match branch_changes.ap_tracking_change {
400 ApTrackingChange::Disable => ApTracking::Disabled,
401 ApTrackingChange::Enable => {
402 if !matches!(annotations.environment.ap_tracking, ApTracking::Disabled) {
403 return Err(AnnotationError::ApTrackingAlreadyEnabled {
404 statement_idx: source_statement_idx,
405 });
406 }
407 ApTracking::Enabled {
408 ap_change: 0,
409 base: ApTrackingBase::Statement(destination_statement_idx),
410 }
411 }
412 ApTrackingChange::None => {
413 update_ap_tracking(annotations.environment.ap_tracking, branch_changes.ap_change)
414 .map_err(|error| AnnotationError::ApTrackingError {
415 source_statement_idx,
416 destination_statement_idx,
417 error,
418 })?
419 }
420 };
421
422 self.set_or_assert(
423 destination_statement_idx,
424 StatementAnnotations {
425 refs,
426 function_id: annotations.function_id,
427 convergence_allowed: !must_set,
428 environment: Environment {
429 ap_tracking,
430 stack_size,
431 frame_state: annotations.environment.frame_state,
432 gas_wallet: annotations
433 .environment
434 .gas_wallet
435 .update(branch_changes.gas_change)
436 .map_err(|error| AnnotationError::GasWalletError {
437 source_statement_idx,
438 destination_statement_idx,
439 error,
440 })?,
441 },
442 },
443 )
444 }
445
446 pub fn validate_return_properties(
448 &self,
449 statement_idx: StatementIdx,
450 annotations: &StatementAnnotations,
451 functions: &[Function],
452 metadata: &Metadata,
453 return_refs: &[ReferenceValue],
454 ) -> Result<(), AnnotationError> {
455 let func = &functions.iter().find(|func| func.id == annotations.function_id).unwrap();
457
458 let expected_ap_tracking = match metadata.ap_change_info.function_ap_change.get(&func.id) {
459 Some(x) => ApTracking::Enabled { ap_change: *x, base: ApTrackingBase::FunctionStart },
460 None => ApTracking::Disabled,
461 };
462 if annotations.environment.ap_tracking != expected_ap_tracking {
463 return Err(AnnotationError::InvalidFunctionApChange {
464 statement_idx,
465 expected: expected_ap_tracking,
466 actual: annotations.environment.ap_tracking,
467 });
468 }
469
470 check_types_match(return_refs, &func.signature.ret_types)
472 .map_err(|error| AnnotationError::ReferencesError { statement_idx, error })?;
473 Ok(())
474 }
475
476 pub fn validate_final_annotations(
478 &self,
479 statement_idx: StatementIdx,
480 annotations: &StatementAnnotations,
481 functions: &[Function],
482 metadata: &Metadata,
483 return_refs: &[ReferenceValue],
484 ) -> Result<(), AnnotationError> {
485 self.validate_return_properties(
486 statement_idx,
487 annotations,
488 functions,
489 metadata,
490 return_refs,
491 )?;
492 validate_final_environment(&annotations.environment)
493 .map_err(|error| AnnotationError::InconsistentEnvironments { statement_idx, error })
494 }
495}
496
497fn test_var_consistency(
501 var_id: &VarId,
502 actual: &ReferenceValue,
503 expected: &ReferenceValue,
504 ap_tracking_enabled: bool,
505) -> Result<(), InconsistentReferenceError> {
506 if actual.stack_idx.is_some() {
508 return Ok(());
509 }
510 if actual.expression.can_apply_unknown() {
513 return Ok(());
514 }
515 if !ap_tracking_enabled {
517 return Err(InconsistentReferenceError::ApTrackingDisabled(var_id.clone()));
518 }
519 if actual.introduction_point == expected.introduction_point {
521 Ok(())
522 } else {
523 Err(InconsistentReferenceError::IntroductionPointMismatch {
524 var: var_id.clone(),
525 expected: expected.introduction_point.clone(),
526 actual: actual.introduction_point.clone(),
527 })
528 }
529}