cairo_lang_starknet/
contract.rs

1use anyhow::{Context, bail};
2use cairo_lang_defs::ids::{
3    FreeFunctionId, LanguageElementId, LookupItemId, ModuleId, ModuleItemId,
4    NamedLanguageElementId, SubmoduleId,
5};
6use cairo_lang_diagnostics::ToOption;
7use cairo_lang_filesystem::ids::CrateId;
8use cairo_lang_semantic::db::SemanticGroup;
9use cairo_lang_semantic::diagnostic::{NotFoundItemType, SemanticDiagnostics};
10use cairo_lang_semantic::expr::inference::InferenceId;
11use cairo_lang_semantic::expr::inference::canonic::ResultNoErrEx;
12use cairo_lang_semantic::items::functions::{
13    ConcreteFunctionWithBodyId as SemanticConcreteFunctionWithBodyId, GenericFunctionId,
14};
15use cairo_lang_semantic::items::us::SemanticUseEx;
16use cairo_lang_semantic::resolve::{ResolvedConcreteItem, ResolvedGenericItem, Resolver};
17use cairo_lang_semantic::substitution::SemanticRewriter;
18use cairo_lang_sierra::ids::FunctionId;
19use cairo_lang_sierra_generator::db::SierraGenGroup;
20use cairo_lang_sierra_generator::replace_ids::SierraIdReplacer;
21use cairo_lang_starknet_classes::keccak::starknet_keccak;
22use cairo_lang_syntax::node::helpers::{GetIdentifier, PathSegmentEx, QueryAttrs};
23use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode};
24use cairo_lang_utils::ordered_hash_map::{
25    OrderedHashMap, deserialize_ordered_hashmap_vec, serialize_ordered_hashmap_vec,
26};
27use cairo_lang_utils::{Intern, LookupIntern, extract_matches};
28use itertools::chain;
29use serde::{Deserialize, Serialize};
30use starknet_types_core::felt::Felt as Felt252;
31use {cairo_lang_lowering as lowering, cairo_lang_semantic as semantic};
32
33use crate::aliased::Aliased;
34use crate::compile::{SemanticEntryPoints, extract_semantic_entrypoints};
35use crate::plugin::aux_data::StarknetContractAuxData;
36use crate::plugin::consts::{ABI_ATTR, ABI_ATTR_EMBED_V0_ARG};
37
38#[cfg(test)]
39#[path = "contract_test.rs"]
40mod test;
41
42/// Represents a declaration of a contract.
43#[derive(Clone)]
44pub struct ContractDeclaration {
45    /// The id of the module that defines the contract.
46    pub submodule_id: SubmoduleId,
47}
48
49impl ContractDeclaration {
50    pub fn module_id(&self) -> ModuleId {
51        ModuleId::Submodule(self.submodule_id)
52    }
53}
54
55/// Returns the contract declaration of a given module if it is a contract module.
56pub fn module_contract(db: &dyn SemanticGroup, module_id: ModuleId) -> Option<ContractDeclaration> {
57    let all_aux_data = db.module_generated_file_aux_data(module_id).ok()?;
58
59    // When a module is generated by a plugin the same aux data appears in two
60    // places:
61    //   1. db.module_generated_file_aux_data(*original_module_id)?[k] (with k > 0).
62    //   2. db.module_generated_file_aux_data(*generated_module_id)?[0].
63    // We are interested in modules that the plugin acted on and not modules that were
64    // created by the plugin, so we skip all_aux_data[0].
65    // For example if we have
66    // mod a {
67    //    #[starknet::contract]
68    //    mod b {
69    //    }
70    // }
71    // Then we want lookup b inside a and not inside b.
72    all_aux_data.iter().skip(1).find_map(|aux_data| {
73        let StarknetContractAuxData { contract_name } =
74            aux_data.as_ref()?.as_any().downcast_ref()?;
75        if let ModuleId::Submodule(submodule_id) = module_id {
76            Some(ContractDeclaration { submodule_id })
77        } else {
78            unreachable!("Contract `{contract_name}` was not found.");
79        }
80    })
81}
82
83/// Finds the inline modules annotated as contracts in the given crate_ids and
84/// returns the corresponding ContractDeclarations.
85pub fn find_contracts(db: &dyn SemanticGroup, crate_ids: &[CrateId]) -> Vec<ContractDeclaration> {
86    let mut contract_declarations = vec![];
87    for crate_id in crate_ids {
88        let modules = db.crate_modules(*crate_id);
89        for module_id in modules.iter() {
90            contract_declarations.extend(module_contract(db, *module_id));
91        }
92    }
93    contract_declarations
94}
95
96/// Returns the ABI functions of a given contract.
97/// Assumes the given module is a contract module.
98pub fn get_contract_abi_functions(
99    db: &dyn SemanticGroup,
100    contract: &ContractDeclaration,
101    module_name: &str,
102) -> anyhow::Result<Vec<Aliased<semantic::ConcreteFunctionWithBodyId>>> {
103    Ok(chain!(
104        get_contract_internal_module_abi_functions(db, contract, module_name)?,
105        get_impl_aliases_abi_functions(db, contract, module_name)?
106    )
107    .collect())
108}
109
110/// Returns the ABI functions in a given internal module in the contract.
111fn get_contract_internal_module_abi_functions(
112    db: &dyn SemanticGroup,
113    contract: &ContractDeclaration,
114    module_name: &str,
115) -> anyhow::Result<Vec<Aliased<SemanticConcreteFunctionWithBodyId>>> {
116    let generated_module_id = get_generated_contract_module(db, contract)?;
117    let module_id = get_submodule_id(db.upcast(), generated_module_id, module_name)?;
118    get_module_aliased_functions(db, module_id)?
119        .into_iter()
120        .map(|f| f.try_map(|f| semantic::ConcreteFunctionWithBodyId::from_no_generics_free(db, f)))
121        .collect::<Option<Vec<_>>>()
122        .with_context(|| "Generics are not allowed in wrapper functions")
123}
124
125/// Returns the list of functions in a given module with their aliases.
126/// Assumes the given module is a generated module containing `use` items pointing to wrapper ABI
127/// functions.
128fn get_module_aliased_functions(
129    db: &dyn SemanticGroup,
130    module_id: ModuleId,
131) -> anyhow::Result<Vec<Aliased<FreeFunctionId>>> {
132    db.module_uses(module_id)
133        .to_option()
134        .with_context(|| "Failed to get external module uses.")?
135        .iter()
136        .map(|(use_id, leaf)| {
137            if let ResolvedGenericItem::GenericFunction(GenericFunctionId::Free(function_id)) = db
138                .use_resolved_item(*use_id)
139                .to_option()
140                .with_context(|| "Failed to fetch used function.")?
141            {
142                Ok(Aliased {
143                    value: function_id,
144                    alias: leaf.stable_ptr().identifier(db.upcast()).to_string(),
145                })
146            } else {
147                bail!("Expected a free function.")
148            }
149        })
150        .collect::<Result<Vec<_>, _>>()
151}
152
153/// Returns the abi functions of the impl aliases embedded in the given contract.
154/// `module_prefix` is the prefix of the generated module name outside of the contract, the rest of
155/// the name is defined by the name of the aliased impl.
156fn get_impl_aliases_abi_functions(
157    db: &dyn SemanticGroup,
158    contract: &ContractDeclaration,
159    module_prefix: &str,
160) -> anyhow::Result<Vec<Aliased<SemanticConcreteFunctionWithBodyId>>> {
161    let syntax_db = db.upcast();
162    let generated_module_id = get_generated_contract_module(db, contract)?;
163    let mut diagnostics = SemanticDiagnostics::default();
164    let mut all_abi_functions = vec![];
165    for (impl_alias_id, impl_alias) in db
166        .module_impl_aliases(generated_module_id)
167        .to_option()
168        .with_context(|| "Failed to get external module impl aliases.")?
169        .iter()
170    {
171        if !impl_alias.has_attr_with_arg(db.upcast(), ABI_ATTR, ABI_ATTR_EMBED_V0_ARG) {
172            continue;
173        }
174        let resolver_data = db
175            .impl_alias_resolver_data(*impl_alias_id)
176            .to_option()
177            .with_context(|| "Internal error: Failed to get impl alias resolver data.")?;
178        let mut resolver = Resolver::with_data(
179            db,
180            resolver_data.clone_with_inference_id(
181                db,
182                InferenceId::LookupItemDeclaration(LookupItemId::ModuleItem(
183                    ModuleItemId::ImplAlias(*impl_alias_id),
184                )),
185            ),
186        );
187
188        let impl_path_elements = impl_alias.impl_path(syntax_db).elements(syntax_db);
189        let Some((impl_final_part, impl_module)) = impl_path_elements.split_last() else {
190            unreachable!("impl_path should have at least one segment")
191        };
192        let impl_name = impl_final_part.identifier(syntax_db);
193        let generic_args = impl_final_part.generic_args(syntax_db).unwrap_or_default();
194        let ResolvedConcreteItem::Module(impl_module) = resolver
195            .resolve_concrete_path(
196                &mut diagnostics,
197                impl_module.to_vec(),
198                NotFoundItemType::Identifier,
199            )
200            .to_option()
201            .with_context(|| "Internal error: Failed to resolve impl module.")?
202        else {
203            bail!("Internal error: Impl alias pointed to an object with non module parent.");
204        };
205        let module_id = get_submodule_id(db, impl_module, &format!("{module_prefix}_{impl_name}"))?;
206        for abi_function in get_module_aliased_functions(db, module_id)? {
207            all_abi_functions.extend(abi_function.try_map(|f| {
208                let concrete_wrapper = resolver
209                    .specialize_function(
210                        &mut diagnostics,
211                        impl_alias.stable_ptr().untyped(),
212                        GenericFunctionId::Free(f),
213                        &generic_args,
214                    )
215                    .to_option()?
216                    .get_concrete(db)
217                    .body(db)
218                    .to_option()??;
219                let inference = &mut resolver.inference();
220                assert_eq!(
221                    inference.finalize_without_reporting(),
222                    Ok(()),
223                    "All inferences should be solved at this point."
224                );
225                Some(inference.rewrite(concrete_wrapper).no_err())
226            }));
227        }
228    }
229    diagnostics
230        .build()
231        .expect_with_db(db.elongate(), "Internal error: Inference for wrappers generics failed.");
232    Ok(all_abi_functions)
233}
234
235/// Returns the generated contract module.
236fn get_generated_contract_module(
237    db: &dyn SemanticGroup,
238    contract: &ContractDeclaration,
239) -> anyhow::Result<ModuleId> {
240    let parent_module_id = contract.submodule_id.parent_module(db.upcast());
241    let contract_name = contract.submodule_id.name(db.upcast());
242
243    match db
244        .module_item_by_name(parent_module_id, contract_name.clone())
245        .to_option()
246        .with_context(|| "Failed to initiate a lookup in the root module.")?
247    {
248        Some(ModuleItemId::Submodule(generated_module_id)) => {
249            Ok(ModuleId::Submodule(generated_module_id))
250        }
251        _ => anyhow::bail!(format!("Failed to get generated module {contract_name}.")),
252    }
253}
254
255/// Returns the module id of the submodule of a module.
256fn get_submodule_id(
257    db: &dyn SemanticGroup,
258    module_id: ModuleId,
259    submodule_name: &str,
260) -> anyhow::Result<ModuleId> {
261    match db
262        .module_item_by_name(module_id, submodule_name.into())
263        .to_option()
264        .with_context(|| "Failed to initiate a lookup in the {module_name} module.")?
265    {
266        Some(ModuleItemId::Submodule(submodule_id)) => Ok(ModuleId::Submodule(submodule_id)),
267        _ => anyhow::bail!(
268            "Failed to get the submodule `{submodule_name}` of `{}`.",
269            module_id.full_path(db.upcast())
270        ),
271    }
272}
273
274/// Sierra information of a contract.
275#[derive(Clone, Serialize, Deserialize, PartialEq, Debug, Eq)]
276pub struct ContractInfo {
277    /// Sierra function of the constructor.
278    pub constructor: Option<FunctionId>,
279    /// Sierra functions of the external functions.
280    #[serde(
281        serialize_with = "serialize_ordered_hashmap_vec",
282        deserialize_with = "deserialize_ordered_hashmap_vec"
283    )]
284    pub externals: OrderedHashMap<Felt252, FunctionId>,
285    /// Sierra functions of the l1 handler functions.
286    #[serde(
287        serialize_with = "serialize_ordered_hashmap_vec",
288        deserialize_with = "deserialize_ordered_hashmap_vec"
289    )]
290    pub l1_handlers: OrderedHashMap<Felt252, FunctionId>,
291}
292
293/// Returns the list of functions in a given module.
294pub fn get_contracts_info<T: SierraIdReplacer>(
295    db: &dyn SierraGenGroup,
296    contracts: Vec<ContractDeclaration>,
297    replacer: &T,
298) -> Result<OrderedHashMap<Felt252, ContractInfo>, anyhow::Error> {
299    let mut contracts_info = OrderedHashMap::default();
300    for contract in contracts {
301        let (class_hash, contract_info) = analyze_contract(db, &contract, replacer)?;
302        contracts_info.insert(class_hash, contract_info);
303    }
304    Ok(contracts_info)
305}
306
307/// Analyzes a contract and returns its class hash and a list of its functions.
308fn analyze_contract<T: SierraIdReplacer>(
309    db: &dyn SierraGenGroup,
310    contract: &ContractDeclaration,
311    replacer: &T,
312) -> anyhow::Result<(Felt252, ContractInfo)> {
313    // Extract class hash.
314    let item =
315        db.module_item_by_name(contract.module_id(), "TEST_CLASS_HASH".into()).unwrap().unwrap();
316    let constant_id = extract_matches!(item, ModuleItemId::Constant);
317    let class_hash: Felt252 =
318        db.constant_const_value(constant_id).unwrap().lookup_intern(db).into_int().unwrap().into();
319
320    // Extract functions.
321    let SemanticEntryPoints { external, l1_handler, constructor } =
322        extract_semantic_entrypoints(db.upcast(), contract)?;
323    let externals =
324        external.into_iter().map(|f| get_selector_and_sierra_function(db, &f, replacer)).collect();
325    let l1_handlers = l1_handler
326        .into_iter()
327        .map(|f| get_selector_and_sierra_function(db, &f, replacer))
328        .collect();
329    let constructors: Vec<_> = constructor
330        .into_iter()
331        .map(|f| get_selector_and_sierra_function(db, &f, replacer))
332        .collect();
333
334    let contract_info = ContractInfo {
335        externals,
336        l1_handlers,
337        constructor: constructors.into_iter().next().map(|x| x.1),
338    };
339    Ok((class_hash, contract_info))
340}
341
342/// Converts a function to a Sierra function.
343/// Returns the selector and the sierra function id.
344pub fn get_selector_and_sierra_function<T: SierraIdReplacer>(
345    db: &dyn SierraGenGroup,
346    function_with_body: &Aliased<lowering::ids::ConcreteFunctionWithBodyId>,
347    replacer: &T,
348) -> (Felt252, FunctionId) {
349    let function_id = function_with_body.value.function_id(db.upcast()).expect("Function error.");
350    let sierra_id = replacer.replace_function_id(&function_id.intern(db));
351    let selector: Felt252 = starknet_keccak(function_with_body.alias.as_bytes()).into();
352    (selector, sierra_id)
353}