use cairo_lang_utils::try_extract_matches;
use num_bigint::ToBigInt;
use num_traits::Signed;
use super::snapshot::snapshot_ty;
use crate::define_libfunc_hierarchy;
use crate::extensions::lib_func::{
BranchSignature, DeferredOutputKind, LibfuncSignature, OutputVarInfo, ParamSignature,
SierraApChange, SignatureOnlyGenericLibfunc, SignatureSpecializationContext,
SpecializationContext,
};
use crate::extensions::type_specialization_context::TypeSpecializationContext;
use crate::extensions::types::TypeInfo;
use crate::extensions::{
args_as_single_type, ConcreteType, NamedLibfunc, NamedType, OutputVarReferenceInfo,
SignatureBasedConcreteLibfunc, SpecializationError,
};
use crate::ids::{ConcreteTypeId, GenericTypeId};
use crate::program::{ConcreteTypeLongId, GenericArg};
#[derive(Default)]
pub struct EnumType {}
impl NamedType for EnumType {
type Concrete = EnumConcreteType;
const ID: GenericTypeId = GenericTypeId::new_inline("Enum");
fn specialize(
&self,
context: &dyn TypeSpecializationContext,
args: &[GenericArg],
) -> Result<Self::Concrete, SpecializationError> {
Self::Concrete::new(context, args)
}
}
pub struct EnumConcreteType {
pub info: TypeInfo,
pub variants: Vec<ConcreteTypeId>,
}
impl EnumConcreteType {
fn new(
context: &dyn TypeSpecializationContext,
args: &[GenericArg],
) -> Result<Self, SpecializationError> {
let mut args_iter = args.iter();
args_iter
.next()
.and_then(|arg| try_extract_matches!(arg, GenericArg::UserType))
.ok_or(SpecializationError::UnsupportedGenericArg)?;
let mut duplicatable = true;
let mut droppable = true;
let mut variants: Vec<ConcreteTypeId> = Vec::new();
for arg in args_iter {
let ty = try_extract_matches!(arg, GenericArg::Type)
.ok_or(SpecializationError::UnsupportedGenericArg)?
.clone();
let info = context.get_type_info(ty.clone())?;
if !info.storable {
return Err(SpecializationError::UnsupportedGenericArg);
}
if !info.duplicatable {
duplicatable = false;
}
if !info.droppable {
droppable = false;
}
variants.push(ty);
}
Ok(EnumConcreteType {
info: TypeInfo {
long_id: ConcreteTypeLongId {
generic_id: "Enum".into(),
generic_args: args.to_vec(),
},
duplicatable,
droppable,
storable: true,
zero_sized: false,
},
variants,
})
}
}
impl ConcreteType for EnumConcreteType {
fn info(&self) -> &TypeInfo {
&self.info
}
}
define_libfunc_hierarchy! {
pub enum EnumLibfunc {
Init(EnumInitLibfunc),
Match(EnumMatchLibfunc),
SnapshotMatch(EnumSnapshotMatchLibfunc),
}, EnumConcreteLibfunc
}
pub struct EnumInitConcreteLibfunc {
pub signature: LibfuncSignature,
pub num_variants: usize,
pub index: usize,
}
impl SignatureBasedConcreteLibfunc for EnumInitConcreteLibfunc {
fn signature(&self) -> &LibfuncSignature {
&self.signature
}
}
#[derive(Default)]
pub struct EnumInitLibfunc {}
impl EnumInitLibfunc {
fn specialize_concrete_lib_func(
&self,
context: &dyn SignatureSpecializationContext,
args: &[GenericArg],
) -> Result<EnumInitConcreteLibfunc, SpecializationError> {
let (enum_type, index) = match args {
[GenericArg::Type(enum_type), GenericArg::Value(index)] => {
(enum_type.clone(), index.clone())
}
[_, _] => return Err(SpecializationError::UnsupportedGenericArg),
_ => return Err(SpecializationError::WrongNumberOfGenericArgs),
};
let generic_args = context.get_type_info(enum_type.clone())?.long_id.generic_args;
let variant_types =
EnumConcreteType::new(context.as_type_specialization_context(), &generic_args)?
.variants;
let num_variants = variant_types.len();
if index.is_negative() || index >= num_variants.to_bigint().unwrap() {
return Err(SpecializationError::IndexOutOfRange { index, range_size: num_variants });
}
let index: usize = index.try_into().unwrap();
let variant_type = variant_types[index].clone();
Ok(EnumInitConcreteLibfunc {
signature: LibfuncSignature::new_non_branch_ex(
vec![ParamSignature {
ty: variant_type,
allow_deferred: true,
allow_add_const: true,
allow_const: true,
}],
vec![OutputVarInfo {
ty: enum_type,
ref_info: OutputVarReferenceInfo::Deferred(DeferredOutputKind::Generic),
}],
SierraApChange::Known { new_vars_only: true },
),
num_variants,
index,
})
}
}
impl NamedLibfunc for EnumInitLibfunc {
type Concrete = EnumInitConcreteLibfunc;
const STR_ID: &'static str = "enum_init";
fn specialize_signature(
&self,
context: &dyn SignatureSpecializationContext,
args: &[GenericArg],
) -> Result<LibfuncSignature, SpecializationError> {
Ok(self.specialize_concrete_lib_func(context, args)?.signature)
}
fn specialize(
&self,
context: &dyn SpecializationContext,
args: &[GenericArg],
) -> Result<Self::Concrete, SpecializationError> {
self.specialize_concrete_lib_func(context.upcast(), args)
}
}
#[derive(Default)]
pub struct EnumMatchLibfunc {}
impl SignatureOnlyGenericLibfunc for EnumMatchLibfunc {
const STR_ID: &'static str = "enum_match";
fn specialize_signature(
&self,
context: &dyn SignatureSpecializationContext,
args: &[GenericArg],
) -> Result<LibfuncSignature, SpecializationError> {
let enum_type = args_as_single_type(args)?;
let generic_args = context.get_type_info(enum_type.clone())?.long_id.generic_args;
let variant_types =
EnumConcreteType::new(context.as_type_specialization_context(), &generic_args)?
.variants;
let is_empty = variant_types.is_empty();
let branch_signatures = variant_types
.into_iter()
.map(|ty| BranchSignature {
vars: vec![OutputVarInfo {
ty,
ref_info: OutputVarReferenceInfo::PartialParam { param_idx: 0 },
}],
ap_change: SierraApChange::Known { new_vars_only: true },
})
.collect();
Ok(LibfuncSignature {
param_signatures: vec![enum_type.into()],
branch_signatures,
fallthrough: if is_empty { None } else { Some(0) },
})
}
}
#[derive(Default)]
pub struct EnumSnapshotMatchLibfunc {}
impl SignatureOnlyGenericLibfunc for EnumSnapshotMatchLibfunc {
const STR_ID: &'static str = "enum_snapshot_match";
fn specialize_signature(
&self,
context: &dyn SignatureSpecializationContext,
args: &[GenericArg],
) -> Result<LibfuncSignature, SpecializationError> {
let enum_type = args_as_single_type(args)?;
let generic_args = context.get_type_info(enum_type.clone())?.long_id.generic_args;
let variant_types =
EnumConcreteType::new(context.as_type_specialization_context(), &generic_args)?
.variants;
let branch_signatures = variant_types
.into_iter()
.map(|ty| {
Ok(BranchSignature {
vars: vec![OutputVarInfo {
ty: snapshot_ty(context, ty)?,
ref_info: OutputVarReferenceInfo::PartialParam { param_idx: 0 },
}],
ap_change: SierraApChange::Known { new_vars_only: true },
})
})
.collect::<Result<Vec<_>, _>>()?;
Ok(LibfuncSignature {
param_signatures: vec![snapshot_ty(context, enum_type)?.into()],
branch_signatures,
fallthrough: Some(0),
})
}
}