1use std::iter;
2
3use cairo_lang_casm::ap_change::{ApChangeError, ApplyApChange};
4use cairo_lang_sierra::edit_state::{put_results, take_args};
5use cairo_lang_sierra::ids::{ConcreteTypeId, FunctionId, VarId};
6use cairo_lang_sierra::program::{BranchInfo, Function, StatementIdx};
7use cairo_lang_sierra_type_size::TypeSizeMap;
8use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
9use itertools::zip_eq;
10use thiserror::Error;
11
12use crate::environment::ap_tracking::update_ap_tracking;
13use crate::environment::frame_state::FrameStateError;
14use crate::environment::gas_wallet::{GasWallet, GasWalletError};
15use crate::environment::{
16 ApTracking, ApTrackingBase, Environment, EnvironmentError, validate_environment_equality,
17 validate_final_environment,
18};
19use crate::invocations::{ApTrackingChange, BranchChanges};
20use crate::metadata::Metadata;
21use crate::references::{
22 IntroductionPoint, OutputReferenceValueIntroductionPoint, ReferenceExpression, ReferenceValue,
23 ReferencesError, StatementRefs, build_function_parameters_refs, check_types_match,
24};
25
/// Errors raised while building, propagating or validating the per-statement annotations of a
/// program.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum AnnotationError {
    /// The references reaching an already-annotated statement disagree with the recorded ones.
    #[error("#{statement_idx}: Inconsistent references annotations: {error}")]
    InconsistentReferencesAnnotation {
        statement_idx: StatementIdx,
        error: InconsistentReferenceError,
    },
    /// A branch that must be the only way into its destination found the destination already
    /// annotated.
    #[error("#{source_statement_idx}->#{destination_statement_idx}: Annotation was already set.")]
    AnnotationAlreadySet {
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
    },
    /// The environments of two paths into a statement differ, or a final environment is invalid.
    #[error("#{statement_idx}: {error}")]
    InconsistentEnvironments { statement_idx: StatementIdx, error: EnvironmentError },
    /// The statement was reached with annotations of two different functions.
    #[error("#{statement_idx}: Belongs to two different functions.")]
    InconsistentFunctionId { statement_idx: StatementIdx },
    /// The statement was reached again although its recorded annotations do not allow
    /// convergence.
    #[error("#{statement_idx}: Invalid convergence.")]
    InvalidConvergence { statement_idx: StatementIdx },
    /// A statement index outside the annotations table was used.
    #[error("InvalidStatementIdx")]
    InvalidStatementIdx,
    /// The statement has no annotations (it was never reached by propagation).
    #[error("MissingAnnotationsForStatement")]
    MissingAnnotationsForStatement(StatementIdx),
    /// An argument of the statement refers to a variable that is not defined at that point.
    #[error("#{statement_idx}: {var_id} is undefined.")]
    MissingReferenceError { statement_idx: StatementIdx, var_id: VarId },
    /// Binding a branch result would override a variable that is still defined.
    #[error("#{source_statement_idx}->#{destination_statement_idx}: {var_id} was overridden.")]
    OverrideReferenceError {
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
        var_id: VarId,
    },
    /// A frame-state error, converted via `From`.
    #[error(transparent)]
    FrameStateError(#[from] FrameStateError),
    /// Updating the gas wallet along a branch failed.
    #[error("#{source_statement_idx}->#{destination_statement_idx}: {error}")]
    GasWalletError {
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
        error: GasWalletError,
    },
    /// Building parameter references or checking return types failed.
    #[error("#{statement_idx}: {error}")]
    ReferencesError { statement_idx: StatementIdx, error: ReferencesError },
    /// A branch requested enabling ap tracking while it was not disabled.
    #[error("#{statement_idx}: Attempting to enable ap tracking when already enabled.")]
    ApTrackingAlreadyEnabled { statement_idx: StatementIdx },
    /// Applying a branch's ap change to a variable's expression failed.
    #[error(
        "#{source_statement_idx}->#{destination_statement_idx}: Got '{error}' error while moving \
         {var_id}."
    )]
    ApChangeError {
        var_id: VarId,
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
        error: ApChangeError,
    },
    /// Advancing the ap-tracking state by a branch's ap change failed.
    #[error("#{source_statement_idx} -> #{destination_statement_idx}: Ap tracking error")]
    ApTrackingError {
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
        error: ApChangeError,
    },
    /// The ap tracking at a return does not match the function's ap change in the metadata.
    #[error(
        "#{statement_idx}: Invalid function ap change annotation. Expected ap tracking: \
         {expected:?}, got: {actual:?}."
    )]
    InvalidFunctionApChange {
        statement_idx: StatementIdx,
        expected: ApTracking,
        actual: ApTracking,
    },
}
94
/// The reason two sets of references meeting at the same statement cannot be merged.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum InconsistentReferenceError {
    /// The variable has a different type on each path.
    #[error("Variable {var} type mismatch. Expected `{expected}`, got `{actual}`.")]
    TypeMismatch { var: VarId, expected: ConcreteTypeId, actual: ConcreteTypeId },
    /// The variable's reference expression differs between the paths.
    #[error("Variable {var} expression mismatch. Expected `{expected}`, got `{actual}`.")]
    ExpressionMismatch { var: VarId, expected: ReferenceExpression, actual: ReferenceExpression },
    /// The variable's stack index differs between the paths.
    #[error("Variable {var} stack index mismatch. Expected `{expected:?}`, got `{actual:?}`.")]
    StackIndexMismatch { var: VarId, expected: Option<usize>, actual: Option<usize> },
    /// The variable was introduced at different points on each path.
    #[error("Variable {var} introduction point mismatch. Expected `{expected}`, got `{actual}`.")]
    IntroductionPointMismatch { var: VarId, expected: IntroductionPoint, actual: IntroductionPoint },
    /// The paths carry a different number of variables.
    #[error("Variable count mismatch.")]
    VariableCountMismatch,
    /// A variable present on the incoming path is missing from the recorded references.
    #[error("Missing expected variable {0}.")]
    VariableMissing(VarId),
    /// Merging the variable requires ap tracking, which is disabled.
    #[error("Ap tracking is disabled while trying to merge {0}.")]
    ApTrackingDisabled(VarId),
}
113
/// The state known at the start of a single statement.
#[derive(Clone, Debug)]
pub struct StatementAnnotations {
    /// The variables available at the statement and their reference values.
    pub refs: StatementRefs,
    /// The function the statement belongs to.
    pub function_id: FunctionId,
    /// Whether another path may later reach this statement; when `false`, `set_or_assert`
    /// rejects a second arrival with `InvalidConvergence`.
    pub convergence_allowed: bool,
    /// The environment (ap tracking, stack size, frame state and gas wallet) at the statement.
    pub environment: Environment,
}
124
/// Annotations for all statements of a program.
pub struct ProgramAnnotations {
    /// One slot per statement, `None` until the statement is first reached.
    per_statement_annotations: Vec<Option<StatementAnnotations>>,
    /// Statements that are targets of backward jumps. Their annotations are cloned rather than
    /// moved out when their arguments are taken (see `get_annotations_after_take_args`).
    backwards_jump_indices: UnorderedHashSet<StatementIdx>,
}
133impl ProgramAnnotations {
134 fn new(n_statements: usize, backwards_jump_indices: UnorderedHashSet<StatementIdx>) -> Self {
135 ProgramAnnotations {
136 per_statement_annotations: iter::repeat_with(|| None).take(n_statements).collect(),
137 backwards_jump_indices,
138 }
139 }
140
141 pub fn create(
144 n_statements: usize,
145 backwards_jump_indices: UnorderedHashSet<StatementIdx>,
146 functions: &[Function],
147 metadata: &Metadata,
148 gas_usage_check: bool,
149 type_sizes: &TypeSizeMap,
150 ) -> Result<Self, AnnotationError> {
151 let mut annotations = ProgramAnnotations::new(n_statements, backwards_jump_indices);
152 for func in functions {
153 annotations.set_or_assert(func.entry_point, StatementAnnotations {
154 refs: build_function_parameters_refs(func, type_sizes).map_err(|error| {
155 AnnotationError::ReferencesError { statement_idx: func.entry_point, error }
156 })?,
157 function_id: func.id.clone(),
158 convergence_allowed: false,
159 environment: Environment::new(if gas_usage_check {
160 GasWallet::Value(metadata.gas_info.function_costs[&func.id].clone())
161 } else {
162 GasWallet::Disabled
163 }),
164 })?
165 }
166
167 Ok(annotations)
168 }
169
170 pub fn set_or_assert(
175 &mut self,
176 statement_idx: StatementIdx,
177 annotations: StatementAnnotations,
178 ) -> Result<(), AnnotationError> {
179 let idx = statement_idx.0;
180 match self.per_statement_annotations.get(idx).ok_or(AnnotationError::InvalidStatementIdx)? {
181 None => self.per_statement_annotations[idx] = Some(annotations),
182 Some(expected_annotations) => {
183 if expected_annotations.function_id != annotations.function_id {
184 return Err(AnnotationError::InconsistentFunctionId { statement_idx });
185 }
186 validate_environment_equality(
187 &expected_annotations.environment,
188 &annotations.environment,
189 )
190 .map_err(|error| AnnotationError::InconsistentEnvironments {
191 statement_idx,
192 error,
193 })?;
194 self.test_references_consistency(&annotations, expected_annotations).map_err(
195 |error| AnnotationError::InconsistentReferencesAnnotation {
196 statement_idx,
197 error,
198 },
199 )?;
200
201 if !expected_annotations.convergence_allowed {
204 return Err(AnnotationError::InvalidConvergence { statement_idx });
205 }
206 }
207 };
208 Ok(())
209 }
210
211 fn test_references_consistency(
214 &self,
215 actual: &StatementAnnotations,
216 expected: &StatementAnnotations,
217 ) -> Result<(), InconsistentReferenceError> {
218 if actual.refs.len() != expected.refs.len() {
220 return Err(InconsistentReferenceError::VariableCountMismatch);
221 }
222 let ap_tracking_enabled =
223 matches!(actual.environment.ap_tracking, ApTracking::Enabled { .. });
224 for (var_id, actual_ref) in actual.refs.iter() {
225 let Some(expected_ref) = expected.refs.get(var_id) else {
227 return Err(InconsistentReferenceError::VariableMissing(var_id.clone()));
228 };
229 if actual_ref.ty != expected_ref.ty {
231 return Err(InconsistentReferenceError::TypeMismatch {
232 var: var_id.clone(),
233 expected: expected_ref.ty.clone(),
234 actual: actual_ref.ty.clone(),
235 });
236 }
237 if actual_ref.expression != expected_ref.expression {
238 return Err(InconsistentReferenceError::ExpressionMismatch {
239 var: var_id.clone(),
240 expected: expected_ref.expression.clone(),
241 actual: actual_ref.expression.clone(),
242 });
243 }
244 if actual_ref.stack_idx != expected_ref.stack_idx {
245 return Err(InconsistentReferenceError::StackIndexMismatch {
246 var: var_id.clone(),
247 expected: expected_ref.stack_idx,
248 actual: actual_ref.stack_idx,
249 });
250 }
251 test_var_consistency(var_id, actual_ref, expected_ref, ap_tracking_enabled)?;
252 }
253 Ok(())
254 }
255
256 pub fn get_annotations_after_take_args<'a>(
260 &mut self,
261 statement_idx: StatementIdx,
262 ref_ids: impl Iterator<Item = &'a VarId>,
263 ) -> Result<(StatementAnnotations, Vec<ReferenceValue>), AnnotationError> {
264 let existing = self.per_statement_annotations[statement_idx.0]
265 .as_mut()
266 .ok_or(AnnotationError::MissingAnnotationsForStatement(statement_idx))?;
267 let mut updated = if self.backwards_jump_indices.contains(&statement_idx) {
268 existing.clone()
269 } else {
270 std::mem::replace(existing, StatementAnnotations {
271 refs: Default::default(),
272 function_id: existing.function_id.clone(),
273 convergence_allowed: false,
275 environment: existing.environment.clone(),
276 })
277 };
278 let refs = std::mem::take(&mut updated.refs);
279 let (statement_refs, taken_refs) = take_args(refs, ref_ids).map_err(|error| {
280 AnnotationError::MissingReferenceError { statement_idx, var_id: error.var_id() }
281 })?;
282 updated.refs = statement_refs;
283 Ok((updated, taken_refs))
284 }
285
286 pub fn propagate_annotations(
292 &mut self,
293 source_statement_idx: StatementIdx,
294 destination_statement_idx: StatementIdx,
295 mut annotations: StatementAnnotations,
296 branch_info: &BranchInfo,
297 branch_changes: BranchChanges,
298 must_set: bool,
299 ) -> Result<(), AnnotationError> {
300 if must_set && self.per_statement_annotations[destination_statement_idx.0].is_some() {
301 return Err(AnnotationError::AnnotationAlreadySet {
302 source_statement_idx,
303 destination_statement_idx,
304 });
305 }
306
307 for (var_id, ref_value) in annotations.refs.iter_mut() {
308 if branch_changes.clear_old_stack {
309 ref_value.stack_idx = None;
310 }
311 ref_value.expression =
312 std::mem::replace(&mut ref_value.expression, ReferenceExpression::zero_sized())
313 .apply_ap_change(branch_changes.ap_change)
314 .map_err(|error| AnnotationError::ApChangeError {
315 var_id: var_id.clone(),
316 source_statement_idx,
317 destination_statement_idx,
318 error,
319 })?;
320 }
321 let mut refs = put_results(
322 annotations.refs,
323 zip_eq(
324 &branch_info.results,
325 branch_changes.refs.into_iter().map(|value| ReferenceValue {
326 expression: value.expression,
327 ty: value.ty,
328 stack_idx: value.stack_idx,
329 introduction_point: match value.introduction_point {
330 OutputReferenceValueIntroductionPoint::New(output_idx) => {
331 IntroductionPoint {
332 source_statement_idx: Some(source_statement_idx),
333 destination_statement_idx,
334 output_idx,
335 }
336 }
337 OutputReferenceValueIntroductionPoint::Existing(introduction_point) => {
338 introduction_point
339 }
340 },
341 }),
342 ),
343 )
344 .map_err(|error| AnnotationError::OverrideReferenceError {
345 source_statement_idx,
346 destination_statement_idx,
347 var_id: error.var_id(),
348 })?;
349
350 let available_stack_indices: UnorderedHashSet<_> =
354 refs.values().flat_map(|r| r.stack_idx).collect();
355 let new_stack_size_opt = (0..branch_changes.new_stack_size)
356 .find(|i| !available_stack_indices.contains(&(branch_changes.new_stack_size - 1 - i)));
357 let stack_size = if let Some(new_stack_size) = new_stack_size_opt {
358 let stack_removal = branch_changes.new_stack_size - new_stack_size;
360 for (_, r) in refs.iter_mut() {
361 r.stack_idx =
364 r.stack_idx.and_then(|stack_idx| stack_idx.checked_sub(stack_removal));
365 }
366 new_stack_size
367 } else {
368 branch_changes.new_stack_size
369 };
370
371 let ap_tracking = match branch_changes.ap_tracking_change {
372 ApTrackingChange::Disable => ApTracking::Disabled,
373 ApTrackingChange::Enable => {
374 if !matches!(annotations.environment.ap_tracking, ApTracking::Disabled) {
375 return Err(AnnotationError::ApTrackingAlreadyEnabled {
376 statement_idx: source_statement_idx,
377 });
378 }
379 ApTracking::Enabled {
380 ap_change: 0,
381 base: ApTrackingBase::Statement(destination_statement_idx),
382 }
383 }
384 ApTrackingChange::None => {
385 update_ap_tracking(annotations.environment.ap_tracking, branch_changes.ap_change)
386 .map_err(|error| AnnotationError::ApTrackingError {
387 source_statement_idx,
388 destination_statement_idx,
389 error,
390 })?
391 }
392 };
393
394 self.set_or_assert(destination_statement_idx, StatementAnnotations {
395 refs,
396 function_id: annotations.function_id,
397 convergence_allowed: !must_set,
398 environment: Environment {
399 ap_tracking,
400 stack_size,
401 frame_state: annotations.environment.frame_state,
402 gas_wallet: annotations
403 .environment
404 .gas_wallet
405 .update(branch_changes.gas_change)
406 .map_err(|error| AnnotationError::GasWalletError {
407 source_statement_idx,
408 destination_statement_idx,
409 error,
410 })?,
411 },
412 })
413 }
414
415 pub fn validate_return_properties(
417 &self,
418 statement_idx: StatementIdx,
419 annotations: &StatementAnnotations,
420 functions: &[Function],
421 metadata: &Metadata,
422 return_refs: &[ReferenceValue],
423 ) -> Result<(), AnnotationError> {
424 let func = &functions.iter().find(|func| func.id == annotations.function_id).unwrap();
426
427 let expected_ap_tracking = match metadata.ap_change_info.function_ap_change.get(&func.id) {
428 Some(x) => ApTracking::Enabled { ap_change: *x, base: ApTrackingBase::FunctionStart },
429 None => ApTracking::Disabled,
430 };
431 if annotations.environment.ap_tracking != expected_ap_tracking {
432 return Err(AnnotationError::InvalidFunctionApChange {
433 statement_idx,
434 expected: expected_ap_tracking,
435 actual: annotations.environment.ap_tracking,
436 });
437 }
438
439 check_types_match(return_refs, &func.signature.ret_types)
441 .map_err(|error| AnnotationError::ReferencesError { statement_idx, error })?;
442 Ok(())
443 }
444
445 pub fn validate_final_annotations(
447 &self,
448 statement_idx: StatementIdx,
449 annotations: &StatementAnnotations,
450 functions: &[Function],
451 metadata: &Metadata,
452 return_refs: &[ReferenceValue],
453 ) -> Result<(), AnnotationError> {
454 self.validate_return_properties(
455 statement_idx,
456 annotations,
457 functions,
458 metadata,
459 return_refs,
460 )?;
461 validate_final_environment(&annotations.environment)
462 .map_err(|error| AnnotationError::InconsistentEnvironments { statement_idx, error })
463 }
464}
465
466fn test_var_consistency(
470 var_id: &VarId,
471 actual: &ReferenceValue,
472 expected: &ReferenceValue,
473 ap_tracking_enabled: bool,
474) -> Result<(), InconsistentReferenceError> {
475 if actual.stack_idx.is_some() {
477 return Ok(());
478 }
479 if actual.expression.can_apply_unknown() {
482 return Ok(());
483 }
484 if !ap_tracking_enabled {
486 return Err(InconsistentReferenceError::ApTrackingDisabled(var_id.clone()));
487 }
488 if actual.introduction_point == expected.introduction_point {
490 Ok(())
491 } else {
492 Err(InconsistentReferenceError::IntroductionPointMismatch {
493 var: var_id.clone(),
494 expected: expected.introduction_point.clone(),
495 actual: actual.introduction_point.clone(),
496 })
497 }
498}