use std::iter;
use cairo_lang_casm::ap_change::{ApChangeError, ApplyApChange};
use cairo_lang_sierra::edit_state::{put_results, take_args};
use cairo_lang_sierra::ids::{ConcreteTypeId, FunctionId, VarId};
use cairo_lang_sierra::program::{BranchInfo, Function, StatementIdx};
use cairo_lang_sierra_type_size::TypeSizeMap;
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use itertools::zip_eq;
use thiserror::Error;
use crate::environment::ap_tracking::update_ap_tracking;
use crate::environment::frame_state::FrameStateError;
use crate::environment::gas_wallet::{GasWallet, GasWalletError};
use crate::environment::{
validate_environment_equality, validate_final_environment, ApTracking, ApTrackingBase,
Environment, EnvironmentError,
};
use crate::invocations::{ApTrackingChange, BranchChanges};
use crate::metadata::Metadata;
use crate::references::{
build_function_parameters_refs, check_types_match, IntroductionPoint,
OutputReferenceValueIntroductionPoint, ReferenceExpression, ReferenceValue, ReferencesError,
StatementRefs,
};
#[derive(Error, Debug, Eq, PartialEq)]
pub enum AnnotationError {
#[error("#{statement_idx}: Inconsistent references annotations: {error}")]
InconsistentReferencesAnnotation {
statement_idx: StatementIdx,
error: InconsistentReferenceError,
},
#[error("#{source_statement_idx}->#{destination_statement_idx}: Annotation was already set.")]
AnnotationAlreadySet {
source_statement_idx: StatementIdx,
destination_statement_idx: StatementIdx,
},
#[error("#{statement_idx}: {error}")]
InconsistentEnvironments { statement_idx: StatementIdx, error: EnvironmentError },
#[error("#{statement_idx}: Belongs to two different functions.")]
InconsistentFunctionId { statement_idx: StatementIdx },
#[error("#{statement_idx}: Invalid convergence.")]
InvalidConvergence { statement_idx: StatementIdx },
#[error("InvalidStatementIdx")]
InvalidStatementIdx,
#[error("MissingAnnotationsForStatement")]
MissingAnnotationsForStatement(StatementIdx),
#[error("#{statement_idx}: {var_id} is undefined.")]
MissingReferenceError { statement_idx: StatementIdx, var_id: VarId },
#[error("#{source_statement_idx}->#{destination_statement_idx}: {var_id} was overridden.")]
OverrideReferenceError {
source_statement_idx: StatementIdx,
destination_statement_idx: StatementIdx,
var_id: VarId,
},
#[error(transparent)]
FrameStateError(#[from] FrameStateError),
#[error("#{source_statement_idx}->#{destination_statement_idx}: {error}")]
GasWalletError {
source_statement_idx: StatementIdx,
destination_statement_idx: StatementIdx,
error: GasWalletError,
},
#[error("#{statement_idx}: {error}")]
ReferencesError { statement_idx: StatementIdx, error: ReferencesError },
#[error("#{statement_idx}: Attempting to enable ap tracking when already enabled.")]
ApTrackingAlreadyEnabled { statement_idx: StatementIdx },
#[error(
"#{source_statement_idx}->#{destination_statement_idx}: Got '{error}' error while moving \
{var_id}."
)]
ApChangeError {
var_id: VarId,
source_statement_idx: StatementIdx,
destination_statement_idx: StatementIdx,
error: ApChangeError,
},
#[error("#{source_statement_idx} -> #{destination_statement_idx}: Ap tracking error")]
ApTrackingError {
source_statement_idx: StatementIdx,
destination_statement_idx: StatementIdx,
error: ApChangeError,
},
#[error(
"#{statement_idx}: Invalid function ap change annotation. Expected ap tracking: \
{expected:?}, got: {actual:?}."
)]
InvalidFunctionApChange {
statement_idx: StatementIdx,
expected: ApTracking,
actual: ApTracking,
},
}
/// Describes why the references of two annotations converging into the same statement
/// cannot be merged.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum InconsistentReferenceError {
/// The variable has different concrete types in the two annotations.
#[error("Variable {var} type mismatch. Expected `{expected}`, got `{actual}`.")]
TypeMismatch { var: VarId, expected: ConcreteTypeId, actual: ConcreteTypeId },
/// The variable has different reference expressions in the two annotations.
#[error("Variable {var} expression mismatch. Expected `{expected}`, got `{actual}`.")]
ExpressionMismatch { var: VarId, expected: ReferenceExpression, actual: ReferenceExpression },
/// The variable has different (optional) stack indices in the two annotations.
#[error("Variable {var} stack index mismatch. Expected `{expected:?}`, got `{actual:?}`.")]
StackIndexMismatch { var: VarId, expected: Option<usize>, actual: Option<usize> },
/// The variable was introduced at different points in the two annotations.
#[error("Variable {var} introduction point mismatch. Expected `{expected}`, got `{actual}`.")]
IntroductionPointMismatch { var: VarId, expected: IntroductionPoint, actual: IntroductionPoint },
/// The two annotations hold a different number of variables.
#[error("Variable count mismatch.")]
VariableCountMismatch,
/// A variable present in one annotation is missing from the other.
#[error("Missing expected variable {0}.")]
VariableMissing(VarId),
/// A variable that needs ap tracking to be merged was found while ap tracking is disabled.
#[error("Ap tracking is disabled while trying to merge {0}.")]
ApTrackingDisabled(VarId),
}
/// The annotations attached to a single Sierra statement during compilation.
#[derive(Clone, Debug)]
pub struct StatementAnnotations {
/// The variables (references) available at this statement.
pub refs: StatementRefs,
/// The function this statement belongs to.
pub function_id: FunctionId,
/// Whether other annotations may later converge into (be checked against) these annotations.
pub convergence_allowed: bool,
/// The environment at this statement (ap tracking, stack size, frame state and gas wallet).
pub environment: Environment,
}
/// The annotations of every statement of a program.
pub struct ProgramAnnotations {
// Indexed by statement index; `None` until annotations are first set for the statement.
per_statement_annotations: Vec<Option<StatementAnnotations>>,
// Statement indices (presumably targets of backward jumps, per the name) whose annotations
// are cloned rather than moved out when taken, so they remain available afterwards.
backwards_jump_indices: UnorderedHashSet<StatementIdx>,
}
impl ProgramAnnotations {
    /// Creates annotations for `n_statements` statements, none of which is annotated yet.
    fn new(n_statements: usize, backwards_jump_indices: UnorderedHashSet<StatementIdx>) -> Self {
        ProgramAnnotations {
            per_statement_annotations: iter::repeat_with(|| None).take(n_statements).collect(),
            backwards_jump_indices,
        }
    }

    /// Creates the initial annotations of a program by annotating the entry point of every
    /// function in `functions`.
    ///
    /// Each entry point receives:
    /// - the function's parameter references;
    /// - convergence disallowed;
    /// - a fresh [`Environment`] whose gas wallet holds the function's cost from `metadata`
    ///   when `gas_usage_check` is set, and is [`GasWallet::Disabled`] otherwise.
    ///
    /// # Errors
    /// Returns [`AnnotationError::ReferencesError`] if the parameter references of a function
    /// cannot be built. Propagates any error of [`Self::set_or_assert`], e.g. when two
    /// functions share an entry point.
    ///
    /// # Panics
    /// With `gas_usage_check` set, `metadata.gas_info.function_costs` is indexed by function
    /// id. This presumably panics if a function has no cost entry — confirm against the map
    /// type.
    pub fn create(
        n_statements: usize,
        backwards_jump_indices: UnorderedHashSet<StatementIdx>,
        functions: &[Function],
        metadata: &Metadata,
        gas_usage_check: bool,
        type_sizes: &TypeSizeMap,
    ) -> Result<Self, AnnotationError> {
        let mut annotations = ProgramAnnotations::new(n_statements, backwards_jump_indices);
        for func in functions {
            annotations.set_or_assert(
                func.entry_point,
                StatementAnnotations {
                    refs: build_function_parameters_refs(func, type_sizes).map_err(|error| {
                        AnnotationError::ReferencesError { statement_idx: func.entry_point, error }
                    })?,
                    function_id: func.id.clone(),
                    convergence_allowed: false,
                    environment: Environment::new(if gas_usage_check {
                        GasWallet::Value(metadata.gas_info.function_costs[&func.id].clone())
                    } else {
                        GasWallet::Disabled
                    }),
                },
            )?;
        }
        Ok(annotations)
    }

    /// Sets the annotations of `statement_idx`. If the statement is already annotated,
    /// asserts instead that `annotations` can converge into the existing ones.
    ///
    /// Convergence requires:
    /// - the same function;
    /// - equal environments (per `validate_environment_equality`);
    /// - consistent references;
    /// - existing annotations that allow convergence.
    ///
    /// On successful convergence the existing annotations are kept unchanged.
    ///
    /// # Errors
    /// Returns [`AnnotationError::InvalidStatementIdx`] if `statement_idx` is out of range.
    /// Otherwise returns the first failing check among
    /// [`AnnotationError::InconsistentFunctionId`],
    /// [`AnnotationError::InconsistentEnvironments`],
    /// [`AnnotationError::InconsistentReferencesAnnotation`] and
    /// [`AnnotationError::InvalidConvergence`].
    pub fn set_or_assert(
        &mut self,
        statement_idx: StatementIdx,
        annotations: StatementAnnotations,
    ) -> Result<(), AnnotationError> {
        let idx = statement_idx.0;
        match self.per_statement_annotations.get(idx).ok_or(AnnotationError::InvalidStatementIdx)? {
            None => self.per_statement_annotations[idx] = Some(annotations),
            Some(expected_annotations) => {
                if expected_annotations.function_id != annotations.function_id {
                    return Err(AnnotationError::InconsistentFunctionId { statement_idx });
                }
                validate_environment_equality(
                    &expected_annotations.environment,
                    &annotations.environment,
                )
                .map_err(|error| AnnotationError::InconsistentEnvironments {
                    statement_idx,
                    error,
                })?;
                self.test_references_consistency(&annotations, expected_annotations).map_err(
                    |error| AnnotationError::InconsistentReferencesAnnotation {
                        statement_idx,
                        error,
                    },
                )?;
                // Checked last, so that a real inconsistency is reported in preference to a
                // generic convergence failure.
                if !expected_annotations.convergence_allowed {
                    return Err(AnnotationError::InvalidConvergence { statement_idx });
                }
            }
        };
        Ok(())
    }

    /// Checks that the references of `actual` are consistent with those of `expected`.
    ///
    /// Both must hold the same variables. For each variable, type, expression and stack index
    /// must match, and the per-variable merge rules of `test_var_consistency` must hold.
    /// Returns the first inconsistency found.
    fn test_references_consistency(
        &self,
        actual: &StatementAnnotations,
        expected: &StatementAnnotations,
    ) -> Result<(), InconsistentReferenceError> {
        // With equal counts, checking that every `actual` variable exists in `expected`
        // suffices to ensure both hold the same variables.
        if actual.refs.len() != expected.refs.len() {
            return Err(InconsistentReferenceError::VariableCountMismatch);
        }
        // `set_or_assert` validates environment equality before calling this, so reading ap
        // tracking from `actual` is expected to agree with `expected`.
        let ap_tracking_enabled =
            matches!(actual.environment.ap_tracking, ApTracking::Enabled { .. });
        for (var_id, actual_ref) in actual.refs.iter() {
            let Some(expected_ref) = expected.refs.get(var_id) else {
                return Err(InconsistentReferenceError::VariableMissing(var_id.clone()));
            };
            if actual_ref.ty != expected_ref.ty {
                return Err(InconsistentReferenceError::TypeMismatch {
                    var: var_id.clone(),
                    expected: expected_ref.ty.clone(),
                    actual: actual_ref.ty.clone(),
                });
            }
            if actual_ref.expression != expected_ref.expression {
                return Err(InconsistentReferenceError::ExpressionMismatch {
                    var: var_id.clone(),
                    expected: expected_ref.expression.clone(),
                    actual: actual_ref.expression.clone(),
                });
            }
            if actual_ref.stack_idx != expected_ref.stack_idx {
                return Err(InconsistentReferenceError::StackIndexMismatch {
                    var: var_id.clone(),
                    expected: expected_ref.stack_idx,
                    actual: actual_ref.stack_idx,
                });
            }
            test_var_consistency(var_id, actual_ref, expected_ref, ap_tracking_enabled)?;
        }
        Ok(())
    }

    /// Returns the annotations of `statement_idx` with the references `ref_ids` taken out,
    /// together with the taken references (in `ref_ids` order).
    ///
    /// If `statement_idx` is in `backwards_jump_indices`, the stored annotations are cloned
    /// and remain in place. Otherwise they are moved out and replaced by a placeholder that has
    /// no references and disallows convergence.
    ///
    /// # Errors
    /// - [`AnnotationError::InvalidStatementIdx`] if `statement_idx` is out of range.
    /// - [`AnnotationError::MissingAnnotationsForStatement`] if the statement is not annotated.
    /// - [`AnnotationError::MissingReferenceError`] if one of `ref_ids` is not available.
    pub fn get_annotations_after_take_args<'a>(
        &mut self,
        statement_idx: StatementIdx,
        ref_ids: impl Iterator<Item = &'a VarId>,
    ) -> Result<(StatementAnnotations, Vec<ReferenceValue>), AnnotationError> {
        // Checked lookup: an out-of-range index is an error (as in `set_or_assert`), not a panic.
        let existing = self
            .per_statement_annotations
            .get_mut(statement_idx.0)
            .ok_or(AnnotationError::InvalidStatementIdx)?
            .as_mut()
            .ok_or(AnnotationError::MissingAnnotationsForStatement(statement_idx))?;
        let mut updated = if self.backwards_jump_indices.contains(&statement_idx) {
            existing.clone()
        } else {
            // Move the annotations out instead of cloning; the placeholder left behind keeps
            // the function id and environment but no references.
            std::mem::replace(
                existing,
                StatementAnnotations {
                    refs: Default::default(),
                    function_id: existing.function_id.clone(),
                    convergence_allowed: false,
                    environment: existing.environment.clone(),
                },
            )
        };
        let refs = std::mem::take(&mut updated.refs);
        let (statement_refs, taken_refs) = take_args(refs, ref_ids).map_err(|error| {
            AnnotationError::MissingReferenceError { statement_idx, var_id: error.var_id() }
        })?;
        updated.refs = statement_refs;
        Ok((updated, taken_refs))
    }

    /// Propagates `annotations` from `source_statement_idx` along one branch to
    /// `destination_statement_idx`, applying `branch_changes`.
    ///
    /// The destination annotations are built as follows:
    /// - Existing references are rebased by the branch's ap change; if requested, they also
    ///   lose their stack positions.
    /// - The branch outputs are added under the names in `branch_info.results`.
    /// - The tracked stack is compacted.
    /// - The ap tracking and gas wallet are updated.
    ///
    /// The result is stored via [`Self::set_or_assert`]. It allows later convergence only when
    /// `must_set` is false. When `must_set` is true, the destination must not be annotated yet.
    ///
    /// # Errors
    /// Returns [`AnnotationError::InvalidStatementIdx`] for an out-of-range destination,
    /// [`AnnotationError::AnnotationAlreadySet`], [`AnnotationError::ApChangeError`],
    /// [`AnnotationError::OverrideReferenceError`],
    /// [`AnnotationError::ApTrackingAlreadyEnabled`], [`AnnotationError::ApTrackingError`],
    /// [`AnnotationError::GasWalletError`], or any error of [`Self::set_or_assert`].
    pub fn propagate_annotations(
        &mut self,
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
        mut annotations: StatementAnnotations,
        branch_info: &BranchInfo,
        branch_changes: BranchChanges,
        must_set: bool,
    ) -> Result<(), AnnotationError> {
        // Checked lookup: an out-of-range destination is reported as an error (as in
        // `set_or_assert`) instead of panicking on the index.
        let destination_is_set = self
            .per_statement_annotations
            .get(destination_statement_idx.0)
            .ok_or(AnnotationError::InvalidStatementIdx)?
            .is_some();
        if must_set && destination_is_set {
            return Err(AnnotationError::AnnotationAlreadySet {
                source_statement_idx,
                destination_statement_idx,
            });
        }

        // Carry the surviving references across the branch.
        for (var_id, ref_value) in annotations.refs.iter_mut() {
            if branch_changes.clear_old_stack {
                ref_value.stack_idx = None;
            }
            ref_value.expression =
                std::mem::replace(&mut ref_value.expression, ReferenceExpression::zero_sized())
                    .apply_ap_change(branch_changes.ap_change)
                    .map_err(|error| AnnotationError::ApChangeError {
                        var_id: var_id.clone(),
                        source_statement_idx,
                        destination_statement_idx,
                        error,
                    })?;
        }

        // Add the branch outputs. Newly created outputs are introduced at this
        // `source -> destination` edge; outputs that forward an existing value keep their
        // original introduction point.
        let mut refs = put_results(
            annotations.refs,
            zip_eq(
                &branch_info.results,
                branch_changes.refs.into_iter().map(|value| ReferenceValue {
                    expression: value.expression,
                    ty: value.ty,
                    stack_idx: value.stack_idx,
                    introduction_point: match value.introduction_point {
                        OutputReferenceValueIntroductionPoint::New(output_idx) => {
                            IntroductionPoint {
                                source_statement_idx: Some(source_statement_idx),
                                destination_statement_idx,
                                output_idx,
                            }
                        }
                        OutputReferenceValueIntroductionPoint::Existing(introduction_point) => {
                            introduction_point
                        }
                    },
                }),
            ),
        )
        .map_err(|error| AnnotationError::OverrideReferenceError {
            source_statement_idx,
            destination_statement_idx,
            var_id: error.var_id(),
        })?;

        // Compact the tracked stack.
        // - Scan from its top (index `new_stack_size - 1`) downwards for the first slot that
        //   no reference occupies.
        // - The slots above that hole become the new, contiguous tracked stack, and their
        //   indices shift down.
        // - References at or below the hole drop off the stack (`checked_sub` yields `None`).
        let available_stack_indices: UnorderedHashSet<_> =
            refs.values().flat_map(|r| r.stack_idx).collect();
        let new_stack_size_opt = (0..branch_changes.new_stack_size)
            .find(|i| !available_stack_indices.contains(&(branch_changes.new_stack_size - 1 - i)));
        let stack_size = if let Some(new_stack_size) = new_stack_size_opt {
            let stack_removal = branch_changes.new_stack_size - new_stack_size;
            for (_, r) in refs.iter_mut() {
                r.stack_idx =
                    r.stack_idx.and_then(|stack_idx| stack_idx.checked_sub(stack_removal));
            }
            new_stack_size
        } else {
            // Every slot is occupied: the stack is already contiguous.
            branch_changes.new_stack_size
        };

        let ap_tracking = match branch_changes.ap_tracking_change {
            ApTrackingChange::Disable => ApTracking::Disabled,
            ApTrackingChange::Enable => {
                // Re-enabling is only valid from a disabled state. Tracking restarts at the
                // destination statement.
                if !matches!(annotations.environment.ap_tracking, ApTracking::Disabled) {
                    return Err(AnnotationError::ApTrackingAlreadyEnabled {
                        statement_idx: source_statement_idx,
                    });
                }
                ApTracking::Enabled {
                    ap_change: 0,
                    base: ApTrackingBase::Statement(destination_statement_idx),
                }
            }
            ApTrackingChange::None => {
                update_ap_tracking(annotations.environment.ap_tracking, branch_changes.ap_change)
                    .map_err(|error| AnnotationError::ApTrackingError {
                        source_statement_idx,
                        destination_statement_idx,
                        error,
                    })?
            }
        };

        self.set_or_assert(
            destination_statement_idx,
            StatementAnnotations {
                refs,
                function_id: annotations.function_id,
                convergence_allowed: !must_set,
                environment: Environment {
                    ap_tracking,
                    stack_size,
                    frame_state: annotations.environment.frame_state,
                    gas_wallet: annotations
                        .environment
                        .gas_wallet
                        .update(branch_changes.gas_change)
                        .map_err(|error| AnnotationError::GasWalletError {
                            source_statement_idx,
                            destination_statement_idx,
                            error,
                        })?,
                },
            },
        )
    }

    /// Validates the properties required at a return statement.
    ///
    /// - The ap tracking must equal the function's ap change from `metadata`, relative to the
    ///   function start, or be disabled if no ap change is recorded.
    /// - `return_refs` must match the function's return types.
    ///
    /// # Errors
    /// Returns [`AnnotationError::InvalidFunctionApChange`] or
    /// [`AnnotationError::ReferencesError`].
    ///
    /// # Panics
    /// Panics if `annotations.function_id` is not in `functions`. This is an invariant
    /// violation, since annotations are seeded from the program's functions in
    /// [`Self::create`].
    pub fn validate_return_properties(
        &self,
        statement_idx: StatementIdx,
        annotations: &StatementAnnotations,
        functions: &[Function],
        metadata: &Metadata,
        return_refs: &[ReferenceValue],
    ) -> Result<(), AnnotationError> {
        let func = functions
            .iter()
            .find(|func| func.id == annotations.function_id)
            .expect("Annotations must belong to one of the program's functions.");
        let expected_ap_tracking = match metadata.ap_change_info.function_ap_change.get(&func.id) {
            Some(x) => ApTracking::Enabled { ap_change: *x, base: ApTrackingBase::FunctionStart },
            None => ApTracking::Disabled,
        };
        if annotations.environment.ap_tracking != expected_ap_tracking {
            return Err(AnnotationError::InvalidFunctionApChange {
                statement_idx,
                expected: expected_ap_tracking,
                actual: annotations.environment.ap_tracking,
            });
        }
        check_types_match(return_refs, &func.signature.ret_types)
            .map_err(|error| AnnotationError::ReferencesError { statement_idx, error })?;
        Ok(())
    }

    /// Validates a function's final annotations.
    ///
    /// First runs [`Self::validate_return_properties`]. Then checks the final environment
    /// with `validate_final_environment`, reporting a failure as
    /// [`AnnotationError::InconsistentEnvironments`].
    pub fn validate_final_annotations(
        &self,
        statement_idx: StatementIdx,
        annotations: &StatementAnnotations,
        functions: &[Function],
        metadata: &Metadata,
        return_refs: &[ReferenceValue],
    ) -> Result<(), AnnotationError> {
        self.validate_return_properties(
            statement_idx,
            annotations,
            functions,
            metadata,
            return_refs,
        )?;
        validate_final_environment(&annotations.environment)
            .map_err(|error| AnnotationError::InconsistentEnvironments { statement_idx, error })
    }
}
/// Decides whether one variable may be merged when two annotations converge.
///
/// Merging is always allowed when the variable is:
/// - on the tracked stack (`actual.stack_idx` is set), or
/// - described by an expression that `can_apply_unknown`.
///
/// Otherwise ap tracking must be enabled, and both annotations must agree on where the
/// variable was introduced.
fn test_var_consistency(
    var_id: &VarId,
    actual: &ReferenceValue,
    expected: &ReferenceValue,
    ap_tracking_enabled: bool,
) -> Result<(), InconsistentReferenceError> {
    // Short-circuit `||` keeps the original evaluation order: the stack check comes first.
    let trivially_mergeable = actual.stack_idx.is_some() || actual.expression.can_apply_unknown();
    if trivially_mergeable {
        return Ok(());
    }

    if !ap_tracking_enabled {
        return Err(InconsistentReferenceError::ApTrackingDisabled(var_id.clone()));
    }

    let same_origin = actual.introduction_point == expected.introduction_point;
    if !same_origin {
        return Err(InconsistentReferenceError::IntroductionPointMismatch {
            var: var_id.clone(),
            expected: expected.introduction_point.clone(),
            actual: actual.introduction_point.clone(),
        });
    }
    Ok(())
}