cairo_lang_sierra/extensions/modules/
enm.rs

1//! Sierra example:
2//! ```ignore
3//! type felt252_ty = felt252;
4//! type unit_ty = Tuple;
5//! type Option = Enum<felt252_ty, unit_ty>;
6//! libfunc init_option_some = enum_init<Option, 0>;
7//! libfunc init_option_none = enum_init<Option, 1>;
8//! libfunc match_option = enum_match<Option>;
9//! ...
10//! felt252_const<0>() -> (felt0);
11//! tuple_const() -> (unit);
12//! init_option_some(felt0) -> (some_id);
13//! init_option_none(unit) -> (none_id);
14//! match_option(some_id) {1000(some), 2000(none)};
15//! match_option(none_id) {1000(some), 2000(none)};
16//! ```
17
18use cairo_lang_utils::try_extract_matches;
19use num_bigint::ToBigInt;
20use num_traits::Signed;
21
22use super::snapshot::snapshot_ty;
23use super::structure::StructType;
24use super::utils::reinterpret_cast_signature;
25use crate::define_libfunc_hierarchy;
26use crate::extensions::bounded_int::bounded_int_ty;
27use crate::extensions::lib_func::{
28    BranchSignature, DeferredOutputKind, LibfuncSignature, OutputVarInfo, ParamSignature,
29    SierraApChange, SignatureOnlyGenericLibfunc, SignatureSpecializationContext,
30    SpecializationContext,
31};
32use crate::extensions::type_specialization_context::TypeSpecializationContext;
33use crate::extensions::types::TypeInfo;
34use crate::extensions::{
35    ConcreteType, NamedLibfunc, NamedType, OutputVarReferenceInfo, SignatureBasedConcreteLibfunc,
36    SpecializationError, args_as_single_type,
37};
38use crate::ids::{ConcreteTypeId, GenericTypeId};
39use crate::program::{ConcreteTypeLongId, GenericArg};
40
41/// Type representing an enum.
42#[derive(Default)]
43pub struct EnumType {}
44impl NamedType for EnumType {
45    type Concrete = EnumConcreteType;
46    const ID: GenericTypeId = GenericTypeId::new_inline("Enum");
47
48    fn specialize(
49        &self,
50        context: &dyn TypeSpecializationContext,
51        args: &[GenericArg],
52    ) -> Result<Self::Concrete, SpecializationError> {
53        Self::Concrete::new(context, args)
54    }
55}
56
57pub struct EnumConcreteType {
58    pub info: TypeInfo,
59    pub variants: Vec<ConcreteTypeId>,
60}
61impl EnumConcreteType {
62    fn new(
63        context: &dyn TypeSpecializationContext,
64        args: &[GenericArg],
65    ) -> Result<Self, SpecializationError> {
66        let mut args_iter = args.iter();
67        args_iter
68            .next()
69            .and_then(|arg| try_extract_matches!(arg, GenericArg::UserType))
70            .ok_or(SpecializationError::UnsupportedGenericArg)?;
71        let mut duplicatable = true;
72        let mut droppable = true;
73        let mut variants: Vec<ConcreteTypeId> = Vec::new();
74        for arg in args_iter {
75            let ty = try_extract_matches!(arg, GenericArg::Type)
76                .ok_or(SpecializationError::UnsupportedGenericArg)?
77                .clone();
78            let info = context.get_type_info(ty.clone())?;
79            if !info.storable {
80                return Err(SpecializationError::UnsupportedGenericArg);
81            }
82            if !info.duplicatable {
83                duplicatable = false;
84            }
85            if !info.droppable {
86                droppable = false;
87            }
88            variants.push(ty);
89        }
90        Ok(EnumConcreteType {
91            info: TypeInfo {
92                long_id: ConcreteTypeLongId {
93                    generic_id: "Enum".into(),
94                    generic_args: args.to_vec(),
95                },
96                duplicatable,
97                droppable,
98                storable: true,
99                zero_sized: false,
100            },
101            variants,
102        })
103    }
104
105    /// Returns the EnumConcreteType of the given type, or a specialization error if not possible.
106    fn try_from_concrete_type(
107        context: &dyn SignatureSpecializationContext,
108        ty: &ConcreteTypeId,
109    ) -> Result<Self, SpecializationError> {
110        let long_id = context.get_type_info(ty.clone())?.long_id;
111        if long_id.generic_id != EnumType::ID {
112            return Err(SpecializationError::UnsupportedGenericArg);
113        }
114        Self::new(context.as_type_specialization_context(), &long_id.generic_args)
115    }
116}
117
118impl ConcreteType for EnumConcreteType {
119    fn info(&self) -> &TypeInfo {
120        &self.info
121    }
122}
123
// The hierarchy of enum libfuncs: `enum_init`, `enum_from_bounded_int`, `enum_match` and
// `enum_snapshot_match`. The trailing `EnumConcreteLibfunc` names the corresponding
// concrete-libfunc enum (presumably generated by the macro - see `define_libfunc_hierarchy!`).
define_libfunc_hierarchy! {
    pub enum EnumLibfunc {
        Init(EnumInitLibfunc),
        FromBoundedInt(EnumFromBoundedIntLibfunc),
        Match(EnumMatchLibfunc),
        SnapshotMatch(EnumSnapshotMatchLibfunc),
    }, EnumConcreteLibfunc
}
132
133pub struct EnumInitConcreteLibfunc {
134    pub signature: LibfuncSignature,
135    /// The number of variants of the enum.
136    pub n_variants: usize,
137    /// The index of the relevant variant from the enum.
138    pub index: usize,
139}
140impl SignatureBasedConcreteLibfunc for EnumInitConcreteLibfunc {
141    fn signature(&self) -> &LibfuncSignature {
142        &self.signature
143    }
144}
145
146/// Libfunc for setting a value to an enum.
147#[derive(Default)]
148pub struct EnumInitLibfunc {}
149impl EnumInitLibfunc {
150    /// Creates the specialization of the enum-init libfunc with the given template arguments.
151    fn specialize_concrete_lib_func(
152        &self,
153        context: &dyn SignatureSpecializationContext,
154        args: &[GenericArg],
155    ) -> Result<EnumInitConcreteLibfunc, SpecializationError> {
156        let (enum_type, index) = match args {
157            [GenericArg::Type(enum_type), GenericArg::Value(index)] => {
158                (enum_type.clone(), index.clone())
159            }
160            [_, _] => return Err(SpecializationError::UnsupportedGenericArg),
161            _ => return Err(SpecializationError::WrongNumberOfGenericArgs),
162        };
163        let variant_types = EnumConcreteType::try_from_concrete_type(context, &enum_type)?.variants;
164        let n_variants = variant_types.len();
165        if index.is_negative() || index >= n_variants.to_bigint().unwrap() {
166            return Err(SpecializationError::IndexOutOfRange { index, range_size: n_variants });
167        }
168        let index: usize = index.try_into().unwrap();
169        let variant_type = variant_types[index].clone();
170        Ok(EnumInitConcreteLibfunc {
171            signature: LibfuncSignature::new_non_branch_ex(
172                vec![ParamSignature {
173                    ty: variant_type,
174                    allow_deferred: true,
175                    allow_add_const: true,
176                    allow_const: true,
177                }],
178                vec![OutputVarInfo {
179                    ty: enum_type,
180                    ref_info: OutputVarReferenceInfo::Deferred(DeferredOutputKind::Generic),
181                }],
182                SierraApChange::Known { new_vars_only: true },
183            ),
184            n_variants,
185            index,
186        })
187    }
188}
189impl NamedLibfunc for EnumInitLibfunc {
190    type Concrete = EnumInitConcreteLibfunc;
191    const STR_ID: &'static str = "enum_init";
192
193    fn specialize_signature(
194        &self,
195        context: &dyn SignatureSpecializationContext,
196        args: &[GenericArg],
197    ) -> Result<LibfuncSignature, SpecializationError> {
198        Ok(self.specialize_concrete_lib_func(context, args)?.signature)
199    }
200
201    fn specialize(
202        &self,
203        context: &dyn SpecializationContext,
204        args: &[GenericArg],
205    ) -> Result<Self::Concrete, SpecializationError> {
206        self.specialize_concrete_lib_func(context.upcast(), args)
207    }
208}
209
210pub struct EnumFromBoundedIntConcreteLibfunc {
211    pub signature: LibfuncSignature,
212    /// The number of variants of the enum.
213    pub n_variants: usize,
214}
215impl SignatureBasedConcreteLibfunc for EnumFromBoundedIntConcreteLibfunc {
216    fn signature(&self) -> &LibfuncSignature {
217        &self.signature
218    }
219}
220
221/// Libfunc for creating an enum from a `BoundedInt` type.
222/// Will only work where there are the same number of empty variants as in the range of the
223/// `BoundedInt` type, and the range starts from 0.
224#[derive(Default)]
225pub struct EnumFromBoundedIntLibfunc {}
226impl EnumFromBoundedIntLibfunc {
227    /// Creates the specialization of the enum-from-bounded-int libfunc with the given template
228    /// arguments.
229    fn specialize_concrete_lib_func(
230        &self,
231        context: &dyn SignatureSpecializationContext,
232        args: &[GenericArg],
233    ) -> Result<EnumFromBoundedIntConcreteLibfunc, SpecializationError> {
234        let enum_type = args_as_single_type(args)?;
235        let variant_types = EnumConcreteType::try_from_concrete_type(context, &enum_type)?.variants;
236        let n_variants = variant_types.len();
237        if n_variants == 0 {
238            return Err(SpecializationError::UnsupportedGenericArg);
239        }
240
241        for v in variant_types {
242            let long_id = context.get_type_info(v)?.long_id;
243            // Only trivial empty structs are allowed as variant types.
244            if !(long_id.generic_id == StructType::ID && long_id.generic_args.len() == 1) {
245                return Err(SpecializationError::UnsupportedGenericArg);
246            }
247        }
248        let input_ty = bounded_int_ty(context, 0.into(), (n_variants - 1).into())?;
249        if n_variants <= 2 {
250            Ok(EnumFromBoundedIntConcreteLibfunc {
251                signature: reinterpret_cast_signature(input_ty, enum_type),
252                n_variants,
253            })
254        } else {
255            Ok(EnumFromBoundedIntConcreteLibfunc {
256                signature: LibfuncSignature::new_non_branch_ex(
257                    vec![ParamSignature::new(input_ty)],
258                    vec![OutputVarInfo {
259                        ty: enum_type,
260                        ref_info: OutputVarReferenceInfo::Deferred(DeferredOutputKind::Generic),
261                    }],
262                    SierraApChange::Known { new_vars_only: false },
263                ),
264                n_variants,
265            })
266        }
267    }
268}
269impl NamedLibfunc for EnumFromBoundedIntLibfunc {
270    type Concrete = EnumFromBoundedIntConcreteLibfunc;
271    const STR_ID: &'static str = "enum_from_bounded_int";
272
273    fn specialize_signature(
274        &self,
275        context: &dyn SignatureSpecializationContext,
276        args: &[GenericArg],
277    ) -> Result<LibfuncSignature, SpecializationError> {
278        Ok(self.specialize_concrete_lib_func(context, args)?.signature)
279    }
280
281    fn specialize(
282        &self,
283        context: &dyn SpecializationContext,
284        args: &[GenericArg],
285    ) -> Result<Self::Concrete, SpecializationError> {
286        self.specialize_concrete_lib_func(context.upcast(), args)
287    }
288}
289
290/// Libfunc for matching an enum.
291#[derive(Default)]
292pub struct EnumMatchLibfunc {}
293impl SignatureOnlyGenericLibfunc for EnumMatchLibfunc {
294    const STR_ID: &'static str = "enum_match";
295
296    fn specialize_signature(
297        &self,
298        context: &dyn SignatureSpecializationContext,
299        args: &[GenericArg],
300    ) -> Result<LibfuncSignature, SpecializationError> {
301        let enum_type = args_as_single_type(args)?;
302        let variant_types = EnumConcreteType::try_from_concrete_type(context, &enum_type)?.variants;
303        let is_empty = variant_types.is_empty();
304        let branch_signatures = variant_types
305            .into_iter()
306            .map(|ty| {
307                Ok(BranchSignature {
308                    vars: vec![OutputVarInfo {
309                        ty: ty.clone(),
310                        ref_info: if context.get_type_info(ty)?.zero_sized {
311                            OutputVarReferenceInfo::ZeroSized
312                        } else {
313                            OutputVarReferenceInfo::PartialParam { param_idx: 0 }
314                        },
315                    }],
316                    ap_change: SierraApChange::Known { new_vars_only: true },
317                })
318            })
319            .collect::<Result<Vec<_>, _>>()?;
320
321        Ok(LibfuncSignature {
322            param_signatures: vec![enum_type.into()],
323            branch_signatures,
324            fallthrough: if is_empty { None } else { Some(0) },
325        })
326    }
327}
328
329/// Libfunc for matching an enum snapshot.
330#[derive(Default)]
331pub struct EnumSnapshotMatchLibfunc {}
332impl SignatureOnlyGenericLibfunc for EnumSnapshotMatchLibfunc {
333    const STR_ID: &'static str = "enum_snapshot_match";
334
335    fn specialize_signature(
336        &self,
337        context: &dyn SignatureSpecializationContext,
338        args: &[GenericArg],
339    ) -> Result<LibfuncSignature, SpecializationError> {
340        let enum_type = args_as_single_type(args)?;
341        let variant_types = EnumConcreteType::try_from_concrete_type(context, &enum_type)?.variants;
342        let branch_signatures = variant_types
343            .into_iter()
344            .map(|ty| {
345                Ok(BranchSignature {
346                    vars: vec![OutputVarInfo {
347                        ty: snapshot_ty(context, ty.clone())?,
348                        ref_info: if context.get_type_info(ty)?.zero_sized {
349                            OutputVarReferenceInfo::ZeroSized
350                        } else {
351                            // All memory of the deconstruction would have the same lifetime as the
352                            // first param - as it is its deconstruction.
353                            OutputVarReferenceInfo::PartialParam { param_idx: 0 }
354                        },
355                    }],
356                    ap_change: SierraApChange::Known { new_vars_only: true },
357                })
358            })
359            .collect::<Result<Vec<_>, _>>()?;
360
361        Ok(LibfuncSignature {
362            param_signatures: vec![snapshot_ty(context, enum_type)?.into()],
363            branch_signatures,
364            fallthrough: Some(0),
365        })
366    }
367}