cairo_lang_sierra/
program_registry.rs

1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
5use itertools::{chain, izip};
6use thiserror::Error;
7
8use crate::extensions::lib_func::{
9    SierraApChange, SignatureSpecializationContext, SpecializationContext,
10};
11use crate::extensions::type_specialization_context::TypeSpecializationContext;
12use crate::extensions::types::TypeInfo;
13use crate::extensions::{
14    ConcreteLibfunc, ConcreteType, ExtensionError, GenericLibfunc, GenericLibfuncEx, GenericType,
15    GenericTypeEx,
16};
17use crate::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId, GenericTypeId};
18use crate::program::{
19    BranchTarget, DeclaredTypeInfo, Function, FunctionSignature, GenericArg, Program, Statement,
20    StatementIdx, TypeDeclaration,
21};
22
// Unit tests for this module; compiled only under `cfg(test)` and loaded from
// `program_registry_test.rs`.
#[cfg(test)]
#[path = "program_registry_test.rs"]
mod test;
26
/// Errors encountered in the program registry.
///
/// Returned boxed (`Box<ProgramRegistryError>`) by registry construction and by the registry's
/// lookup methods.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum ProgramRegistryError {
    /// The same function id appears more than once in the program's function list.
    #[error("used the same function id twice `{0}`.")]
    FunctionIdAlreadyExists(FunctionId),
    /// A lookup asked for a function id that the registry does not contain.
    #[error("Could not find the requested function `{0}`.")]
    MissingFunction(FunctionId),
    /// Specializing the generic type of the type declaration `concrete_id` failed.
    #[error("Error during type specialization of `{concrete_id}`: {error}")]
    TypeSpecialization { concrete_id: ConcreteTypeId, error: ExtensionError },
    /// The same concrete type id is used by more than one type declaration.
    #[error("Used concrete type id `{0}` twice")]
    TypeConcreteIdAlreadyExists(ConcreteTypeId),
    /// Two type declarations share the same generic type id and generic arguments.
    #[error("Declared concrete type `{0}` twice")]
    TypeAlreadyDeclared(Box<TypeDeclaration>),
    /// A lookup asked for a concrete type id that the registry does not contain.
    #[error("Could not find requested type `{0}`.")]
    MissingType(ConcreteTypeId),
    /// Specializing the generic libfunc of the libfunc declaration `concrete_id` failed.
    #[error("Error during libfunc specialization of {concrete_id}: {error}")]
    LibfuncSpecialization { concrete_id: ConcreteLibfuncId, error: ExtensionError },
    /// The same concrete libfunc id is used by more than one libfunc declaration.
    #[error("Used concrete libfunc id `{0}` twice.")]
    LibfuncConcreteIdAlreadyExists(ConcreteLibfuncId),
    /// A lookup asked for a concrete libfunc id that the registry does not contain.
    #[error("Could not find requested libfunc `{0}`.")]
    MissingLibfunc(ConcreteLibfuncId),
    /// A type's explicitly declared info differs from the info produced by specializing it.
    #[error("Type info declaration mismatch for `{0}`.")]
    TypeInfoDeclarationMismatch(ConcreteTypeId),
    /// A function's signature uses a type that is not storable.
    ///
    /// Raised for both parameter and return types, although the message only mentions
    /// parameters.
    #[error("Function `{func_id}`'s parameter type `{ty}` is not storable.")]
    FunctionWithUnstorableType { func_id: FunctionId, ty: ConcreteTypeId },
    /// A function's entry point is not a valid index into the program's statements.
    #[error("Function `{0}` points to non existing entry point statement.")]
    FunctionNonExistingEntryPoint(FunctionId),
    /// The invocation passes a different number of arguments than the libfunc's parameter
    /// signatures.
    #[error("#{0}: Libfunc invocation input count mismatch")]
    LibfuncInvocationInputCountMismatch(StatementIdx),
    /// The invocation lists a different number of branches than the libfunc's branch
    /// signatures.
    #[error("#{0}: Libfunc invocation branch count mismatch")]
    LibfuncInvocationBranchCountMismatch(StatementIdx),
    /// The branch with the given index binds a different number of results than the libfunc's
    /// matching branch signature.
    #[error("#{0}: Libfunc invocation branch #{1} result count mismatch")]
    LibfuncInvocationBranchResultCountMismatch(StatementIdx, usize),
    /// The branch with the given index is a fallthrough in the invocation but not the libfunc's
    /// fallthrough branch, or vice versa.
    #[error("#{0}: Libfunc invocation branch #{1} target mismatch")]
    LibfuncInvocationBranchTargetMismatch(StatementIdx, usize),
    /// A statement with more than one branch has a branch that does not jump forward.
    #[error("#{src}: Branch jump backwards to {dst}")]
    BranchBackwards { src: StatementIdx, dst: StatementIdx },
    /// The branching statement `src` targets `dst`, whose libfunc branch ap-change is not
    /// `BranchAlign`.
    #[error("#{src}: Branch jump to a non-branch align statement #{dst}")]
    BranchNotToBranchAlign { src: StatementIdx, dst: StatementIdx },
    /// Two branches (from `src1` and `src2`, which may be the same statement) target `dst`.
    #[error("#{src1}, #{src2}: Jump to the same statement #{dst}")]
    MultipleJumpsToSameStatement { src1: StatementIdx, src2: StatementIdx, dst: StatementIdx },
    /// A branch of the statement targets an index past the end of the program.
    #[error("#{0}: Jump out of range")]
    JumpOutOfRange(StatementIdx),
}
71
/// Map keyed by concrete type id; used both for specialized concrete types and for explicitly
/// declared type infos.
type TypeMap<TType> = HashMap<ConcreteTypeId, TType>;
/// Map from concrete libfunc id to its specialized concrete libfunc.
type LibfuncMap<TLibfunc> = HashMap<ConcreteLibfuncId, TLibfunc>;
/// Map from function id to the user function declared under it.
type FunctionMap = HashMap<FunctionId, Function>;
/// Mapping from the arguments for generating a concrete type (the generic-id and the arguments) to
/// the concrete-id that points to it.
type ConcreteTypeIdMap<'a> = HashMap<(GenericTypeId, &'a [GenericArg]), ConcreteTypeId>;
78
/// Registry for the data of the compiler, for all program specific data.
///
/// Built from a [Program] by [ProgramRegistry::new] or [ProgramRegistry::new_with_ap_change],
/// which specialize every declared type and libfunc and then validate the program's statements.
pub struct ProgramRegistry<TType: GenericType, TLibfunc: GenericLibfunc> {
    /// Mapping ids to the corresponding user function declaration from the program.
    functions: FunctionMap,
    /// Mapping ids to the concrete types represented by them.
    concrete_types: TypeMap<TType::Concrete>,
    /// Mapping ids to the concrete libfuncs represented by them.
    concrete_libfuncs: LibfuncMap<TLibfunc::Concrete>,
}
88impl<TType: GenericType, TLibfunc: GenericLibfunc> ProgramRegistry<TType, TLibfunc> {
89    /// Create a registry for the program.
90    pub fn new_with_ap_change(
91        program: &Program,
92        function_ap_change: OrderedHashMap<FunctionId, usize>,
93    ) -> Result<ProgramRegistry<TType, TLibfunc>, Box<ProgramRegistryError>> {
94        let functions = get_functions(program)?;
95        let (concrete_types, concrete_type_ids) = get_concrete_types_maps::<TType>(program)?;
96        let concrete_libfuncs =
97            get_concrete_libfuncs::<TType, TLibfunc>(program, &SpecializationContextForRegistry {
98                functions: &functions,
99                concrete_type_ids: &concrete_type_ids,
100                concrete_types: &concrete_types,
101                function_ap_change,
102            })?;
103        let registry = ProgramRegistry { functions, concrete_types, concrete_libfuncs };
104        registry.validate(program)?;
105        Ok(registry)
106    }
107
108    pub fn new(
109        program: &Program,
110    ) -> Result<ProgramRegistry<TType, TLibfunc>, Box<ProgramRegistryError>> {
111        Self::new_with_ap_change(program, Default::default())
112    }
113    /// Gets a function from the input program.
114    pub fn get_function<'a>(
115        &'a self,
116        id: &FunctionId,
117    ) -> Result<&'a Function, Box<ProgramRegistryError>> {
118        self.functions
119            .get(id)
120            .ok_or_else(|| Box::new(ProgramRegistryError::MissingFunction(id.clone())))
121    }
122    /// Gets a type from the input program.
123    pub fn get_type<'a>(
124        &'a self,
125        id: &ConcreteTypeId,
126    ) -> Result<&'a TType::Concrete, Box<ProgramRegistryError>> {
127        self.concrete_types
128            .get(id)
129            .ok_or_else(|| Box::new(ProgramRegistryError::MissingType(id.clone())))
130    }
131    /// Gets a libfunc from the input program.
132    pub fn get_libfunc<'a>(
133        &'a self,
134        id: &ConcreteLibfuncId,
135    ) -> Result<&'a TLibfunc::Concrete, Box<ProgramRegistryError>> {
136        self.concrete_libfuncs
137            .get(id)
138            .ok_or_else(|| Box::new(ProgramRegistryError::MissingLibfunc(id.clone())))
139    }
140
141    /// Checks the validity of the [ProgramRegistry] and runs validations on the program.
142    ///
143    /// Later compilation stages may perform more validations as well as repeat these validations.
144    fn validate(&self, program: &Program) -> Result<(), Box<ProgramRegistryError>> {
145        // Check that all the parameter and return types are storable.
146        for func in self.functions.values() {
147            for ty in chain!(func.signature.param_types.iter(), func.signature.ret_types.iter()) {
148                if !self.get_type(ty)?.info().storable {
149                    return Err(Box::new(ProgramRegistryError::FunctionWithUnstorableType {
150                        func_id: func.id.clone(),
151                        ty: ty.clone(),
152                    }));
153                }
154            }
155            if func.entry_point.0 >= program.statements.len() {
156                return Err(Box::new(ProgramRegistryError::FunctionNonExistingEntryPoint(
157                    func.id.clone(),
158                )));
159            }
160        }
161        // A branch map, mapping from a destination statement to the statement that jumps to it.
162        // A branch is considered a branch only if it has more than one target.
163        // Assuming branches into branch alignments only, this should be a bijection.
164        let mut branches: HashMap<StatementIdx, StatementIdx> =
165            HashMap::<StatementIdx, StatementIdx>::default();
166        for (i, statement) in program.statements.iter().enumerate() {
167            self.validate_statement(program, StatementIdx(i), statement, &mut branches)?;
168        }
169        Ok(())
170    }
171
172    /// Checks the validity of a statement.
173    fn validate_statement(
174        &self,
175        program: &Program,
176        index: StatementIdx,
177        statement: &Statement,
178        branches: &mut HashMap<StatementIdx, StatementIdx>,
179    ) -> Result<(), Box<ProgramRegistryError>> {
180        let Statement::Invocation(invocation) = statement else {
181            return Ok(());
182        };
183        let libfunc = self.get_libfunc(&invocation.libfunc_id)?;
184        if invocation.args.len() != libfunc.param_signatures().len() {
185            return Err(Box::new(ProgramRegistryError::LibfuncInvocationInputCountMismatch(index)));
186        }
187        let libfunc_branches = libfunc.branch_signatures();
188        if invocation.branches.len() != libfunc_branches.len() {
189            return Err(Box::new(ProgramRegistryError::LibfuncInvocationBranchCountMismatch(
190                index,
191            )));
192        }
193        let libfunc_fallthrough = libfunc.fallthrough();
194        for (branch_index, (invocation_branch, libfunc_branch)) in
195            izip!(&invocation.branches, libfunc_branches).enumerate()
196        {
197            if invocation_branch.results.len() != libfunc_branch.vars.len() {
198                return Err(Box::new(
199                    ProgramRegistryError::LibfuncInvocationBranchResultCountMismatch(
200                        index,
201                        branch_index,
202                    ),
203                ));
204            }
205            if matches!(libfunc_fallthrough, Some(target) if target == branch_index)
206                != (invocation_branch.target == BranchTarget::Fallthrough)
207            {
208                return Err(Box::new(ProgramRegistryError::LibfuncInvocationBranchTargetMismatch(
209                    index,
210                    branch_index,
211                )));
212            }
213            if !matches!(libfunc_branch.ap_change, SierraApChange::BranchAlign) {
214                if let Some(prev) = branches.get(&index) {
215                    return Err(Box::new(ProgramRegistryError::BranchNotToBranchAlign {
216                        src: *prev,
217                        dst: index,
218                    }));
219                }
220            }
221            let next = index.next(&invocation_branch.target);
222            if next.0 >= program.statements.len() {
223                return Err(Box::new(ProgramRegistryError::JumpOutOfRange(index)));
224            }
225            if libfunc_branches.len() > 1 {
226                if next.0 < index.0 {
227                    return Err(Box::new(ProgramRegistryError::BranchBackwards {
228                        src: index,
229                        dst: next,
230                    }));
231                }
232                match branches.entry(next) {
233                    Entry::Occupied(e) => {
234                        return Err(Box::new(ProgramRegistryError::MultipleJumpsToSameStatement {
235                            src1: *e.get(),
236                            src2: index,
237                            dst: next,
238                        }));
239                    }
240                    Entry::Vacant(e) => {
241                        e.insert(index);
242                    }
243                }
244            }
245        }
246        Ok(())
247    }
248}
249
250/// Creates the functions map.
251fn get_functions(program: &Program) -> Result<FunctionMap, Box<ProgramRegistryError>> {
252    let mut functions = FunctionMap::new();
253    for func in &program.funcs {
254        match functions.entry(func.id.clone()) {
255            Entry::Occupied(_) => {
256                Err(ProgramRegistryError::FunctionIdAlreadyExists(func.id.clone()))
257            }
258            Entry::Vacant(entry) => Ok(entry.insert(func.clone())),
259        }?;
260    }
261    Ok(functions)
262}
263
264struct TypeSpecializationContextForRegistry<'a, TType: GenericType> {
265    pub concrete_types: &'a TypeMap<TType::Concrete>,
266    pub declared_type_info: &'a TypeMap<TypeInfo>,
267}
268impl<TType: GenericType> TypeSpecializationContext
269    for TypeSpecializationContextForRegistry<'_, TType>
270{
271    fn try_get_type_info(&self, id: ConcreteTypeId) -> Option<TypeInfo> {
272        self.declared_type_info
273            .get(&id)
274            .or_else(|| self.concrete_types.get(&id).map(|ty| ty.info()))
275            .cloned()
276    }
277}
278
279/// Creates the type-id to concrete type map, and the reverse map from generic-id and arguments to
280/// concrete-id.
281fn get_concrete_types_maps<TType: GenericType>(
282    program: &Program,
283) -> Result<(TypeMap<TType::Concrete>, ConcreteTypeIdMap<'_>), Box<ProgramRegistryError>> {
284    let mut concrete_types = HashMap::new();
285    let mut concrete_type_ids = HashMap::<(GenericTypeId, &[GenericArg]), ConcreteTypeId>::new();
286    let declared_type_info = program
287        .type_declarations
288        .iter()
289        .filter_map(|declaration| {
290            let TypeDeclaration { id, long_id, declared_type_info } = declaration;
291            let DeclaredTypeInfo { storable, droppable, duplicatable, zero_sized } =
292                declared_type_info.as_ref().cloned()?;
293            Some((id.clone(), TypeInfo {
294                long_id: long_id.clone(),
295                storable,
296                droppable,
297                duplicatable,
298                zero_sized,
299            }))
300        })
301        .collect();
302    for declaration in &program.type_declarations {
303        let concrete_type = TType::specialize_by_id(
304            &TypeSpecializationContextForRegistry::<TType> {
305                concrete_types: &concrete_types,
306                declared_type_info: &declared_type_info,
307            },
308            &declaration.long_id.generic_id,
309            &declaration.long_id.generic_args,
310        )
311        .map_err(|error| {
312            Box::new(ProgramRegistryError::TypeSpecialization {
313                concrete_id: declaration.id.clone(),
314                error,
315            })
316        })?;
317        // Check that the info is consistent with declaration.
318        if let Some(declared_info) = declared_type_info.get(&declaration.id) {
319            if concrete_type.info() != declared_info {
320                return Err(Box::new(ProgramRegistryError::TypeInfoDeclarationMismatch(
321                    declaration.id.clone(),
322                )));
323            }
324        }
325
326        match concrete_types.entry(declaration.id.clone()) {
327            Entry::Occupied(_) => Err(Box::new(ProgramRegistryError::TypeConcreteIdAlreadyExists(
328                declaration.id.clone(),
329            ))),
330            Entry::Vacant(entry) => Ok(entry.insert(concrete_type)),
331        }?;
332        match concrete_type_ids
333            .entry((declaration.long_id.generic_id.clone(), &declaration.long_id.generic_args[..]))
334        {
335            Entry::Occupied(_) => Err(Box::new(ProgramRegistryError::TypeAlreadyDeclared(
336                Box::new(declaration.clone()),
337            ))),
338            Entry::Vacant(entry) => Ok(entry.insert(declaration.id.clone())),
339        }?;
340    }
341    Ok((concrete_types, concrete_type_ids))
342}
343
/// Context required for specialization process.
///
/// Used to specialize the program's libfuncs once all functions and concrete types are known.
pub struct SpecializationContextForRegistry<'a, TType: GenericType> {
    /// The program's user functions, by id.
    pub functions: &'a FunctionMap,
    /// Reverse type map: (generic-id, generic-args) to the concrete-id declared for them.
    pub concrete_type_ids: &'a ConcreteTypeIdMap<'a>,
    /// The program's specialized concrete types, by concrete-id.
    pub concrete_types: &'a TypeMap<TType::Concrete>,
    /// AP changes information for Sierra user functions.
    ///
    /// Only key membership is consulted in this file; the `usize` values are not read here.
    pub function_ap_change: OrderedHashMap<FunctionId, usize>,
}
352impl<TType: GenericType> TypeSpecializationContext for SpecializationContextForRegistry<'_, TType> {
353    fn try_get_type_info(&self, id: ConcreteTypeId) -> Option<TypeInfo> {
354        self.concrete_types.get(&id).map(|ty| ty.info().clone())
355    }
356}
357impl<TType: GenericType> SignatureSpecializationContext
358    for SpecializationContextForRegistry<'_, TType>
359{
360    fn try_get_concrete_type(
361        &self,
362        id: GenericTypeId,
363        generic_args: &[GenericArg],
364    ) -> Option<ConcreteTypeId> {
365        self.concrete_type_ids.get(&(id, generic_args)).cloned()
366    }
367
368    fn try_get_function_signature(&self, function_id: &FunctionId) -> Option<FunctionSignature> {
369        self.try_get_function(function_id).map(|f| f.signature)
370    }
371
372    fn as_type_specialization_context(&self) -> &dyn TypeSpecializationContext {
373        self
374    }
375
376    fn try_get_function_ap_change(&self, function_id: &FunctionId) -> Option<SierraApChange> {
377        Some(if self.function_ap_change.contains_key(function_id) {
378            SierraApChange::Known { new_vars_only: false }
379        } else {
380            SierraApChange::Unknown
381        })
382    }
383}
384impl<TType: GenericType> SpecializationContext for SpecializationContextForRegistry<'_, TType> {
385    fn try_get_function(&self, function_id: &FunctionId) -> Option<Function> {
386        self.functions.get(function_id).cloned()
387    }
388
389    fn upcast(&self) -> &dyn SignatureSpecializationContext {
390        self
391    }
392}
393
394/// Creates the libfuncs map.
395fn get_concrete_libfuncs<TType: GenericType, TLibfunc: GenericLibfunc>(
396    program: &Program,
397    context: &SpecializationContextForRegistry<'_, TType>,
398) -> Result<LibfuncMap<TLibfunc::Concrete>, Box<ProgramRegistryError>> {
399    let mut concrete_libfuncs = HashMap::new();
400    for declaration in &program.libfunc_declarations {
401        let concrete_libfunc = TLibfunc::specialize_by_id(
402            context,
403            &declaration.long_id.generic_id,
404            &declaration.long_id.generic_args,
405        )
406        .map_err(|error| ProgramRegistryError::LibfuncSpecialization {
407            concrete_id: declaration.id.clone(),
408            error,
409        })?;
410        match concrete_libfuncs.entry(declaration.id.clone()) {
411            Entry::Occupied(_) => {
412                Err(ProgramRegistryError::LibfuncConcreteIdAlreadyExists(declaration.id.clone()))
413            }
414            Entry::Vacant(entry) => Ok(entry.insert(concrete_libfunc)),
415        }?;
416    }
417    Ok(concrete_libfuncs)
418}