cairo_lang_starknet/
contract.rs

1use anyhow::{Context, bail};
2use cairo_lang_defs::ids::{
3    FreeFunctionId, LanguageElementId, LookupItemId, ModuleId, ModuleItemId,
4    NamedLanguageElementId, SubmoduleId,
5};
6use cairo_lang_diagnostics::ToOption;
7use cairo_lang_filesystem::ids::CrateId;
8use cairo_lang_semantic::Expr;
9use cairo_lang_semantic::db::SemanticGroup;
10use cairo_lang_semantic::diagnostic::{NotFoundItemType, SemanticDiagnostics};
11use cairo_lang_semantic::expr::inference::InferenceId;
12use cairo_lang_semantic::expr::inference::canonic::ResultNoErrEx;
13use cairo_lang_semantic::items::functions::{
14    ConcreteFunctionWithBodyId as SemanticConcreteFunctionWithBodyId, GenericFunctionId,
15};
16use cairo_lang_semantic::items::us::SemanticUseEx;
17use cairo_lang_semantic::resolve::{ResolvedConcreteItem, ResolvedGenericItem, Resolver};
18use cairo_lang_semantic::substitution::SemanticRewriter;
19use cairo_lang_sierra::ids::FunctionId;
20use cairo_lang_sierra_generator::db::SierraGenGroup;
21use cairo_lang_sierra_generator::replace_ids::SierraIdReplacer;
22use cairo_lang_starknet_classes::keccak::starknet_keccak;
23use cairo_lang_syntax::node::helpers::{GetIdentifier, PathSegmentEx, QueryAttrs};
24use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode};
25use cairo_lang_utils::ordered_hash_map::{
26    OrderedHashMap, deserialize_ordered_hashmap_vec, serialize_ordered_hashmap_vec,
27};
28use cairo_lang_utils::{Intern, extract_matches};
29use itertools::chain;
30use serde::{Deserialize, Serialize};
31use starknet_types_core::felt::Felt as Felt252;
32use {cairo_lang_lowering as lowering, cairo_lang_semantic as semantic};
33
34use crate::aliased::Aliased;
35use crate::compile::{SemanticEntryPoints, extract_semantic_entrypoints};
36use crate::plugin::aux_data::StarkNetContractAuxData;
37use crate::plugin::consts::{ABI_ATTR, ABI_ATTR_EMBED_V0_ARG};
38
39#[cfg(test)]
40#[path = "contract_test.rs"]
41mod test;
42
43/// Represents a declaration of a contract.
44#[derive(Clone)]
45pub struct ContractDeclaration {
46    /// The id of the module that defines the contract.
47    pub submodule_id: SubmoduleId,
48}
49
50impl ContractDeclaration {
51    pub fn module_id(&self) -> ModuleId {
52        ModuleId::Submodule(self.submodule_id)
53    }
54}
55
56/// Returns the contract declaration of a given module if it is a contract module.
57pub fn module_contract(db: &dyn SemanticGroup, module_id: ModuleId) -> Option<ContractDeclaration> {
58    let all_aux_data = db.module_generated_file_aux_data(module_id).ok()?;
59
60    // When a module is generated by a plugin the same aux data appears in two
61    // places:
62    //   1. db.module_generated_file_aux_data(*original_module_id)?[k] (with k > 0).
63    //   2. db.module_generated_file_aux_data(*generated_module_id)?[0].
64    // We are interested in modules that the plugin acted on and not modules that were
65    // created by the plugin, so we skip all_aux_data[0].
66    // For example if we have
67    // mod a {
68    //    #[starknet::contract]
69    //    mod b {
70    //    }
71    // }
72    // Then we want lookup b inside a and not inside b.
73    all_aux_data.iter().skip(1).find_map(|aux_data| {
74        let StarkNetContractAuxData { contract_name } =
75            aux_data.as_ref()?.as_any().downcast_ref()?;
76        if let ModuleId::Submodule(submodule_id) = module_id {
77            Some(ContractDeclaration { submodule_id })
78        } else {
79            unreachable!("Contract `{contract_name}` was not found.");
80        }
81    })
82}
83
84/// Finds the inline modules annotated as contracts in the given crate_ids and
85/// returns the corresponding ContractDeclarations.
86pub fn find_contracts(db: &dyn SemanticGroup, crate_ids: &[CrateId]) -> Vec<ContractDeclaration> {
87    let mut contract_declarations = vec![];
88    for crate_id in crate_ids {
89        let modules = db.crate_modules(*crate_id);
90        for module_id in modules.iter() {
91            contract_declarations.extend(module_contract(db, *module_id));
92        }
93    }
94    contract_declarations
95}
96
97/// Returns the ABI functions of a given contract.
98/// Assumes the given module is a contract module.
99pub fn get_contract_abi_functions(
100    db: &dyn SemanticGroup,
101    contract: &ContractDeclaration,
102    module_name: &str,
103) -> anyhow::Result<Vec<Aliased<semantic::ConcreteFunctionWithBodyId>>> {
104    Ok(chain!(
105        get_contract_internal_module_abi_functions(db, contract, module_name)?,
106        get_impl_aliases_abi_functions(db, contract, module_name)?
107    )
108    .collect())
109}
110
111/// Returns the ABI functions in a given internal module in the contract.
112fn get_contract_internal_module_abi_functions(
113    db: &dyn SemanticGroup,
114    contract: &ContractDeclaration,
115    module_name: &str,
116) -> anyhow::Result<Vec<Aliased<SemanticConcreteFunctionWithBodyId>>> {
117    let generated_module_id = get_generated_contract_module(db, contract)?;
118    let module_id = get_submodule_id(db.upcast(), generated_module_id, module_name)?;
119    get_module_aliased_functions(db, module_id)?
120        .into_iter()
121        .map(|f| f.try_map(|f| semantic::ConcreteFunctionWithBodyId::from_no_generics_free(db, f)))
122        .collect::<Option<Vec<_>>>()
123        .with_context(|| "Generics are not allowed in wrapper functions")
124}
125
126/// Returns the list of functions in a given module with their aliases.
127/// Assumes the given module is a generated module containing `use` items pointing to wrapper ABI
128/// functions.
129fn get_module_aliased_functions(
130    db: &dyn SemanticGroup,
131    module_id: ModuleId,
132) -> anyhow::Result<Vec<Aliased<FreeFunctionId>>> {
133    db.module_uses(module_id)
134        .to_option()
135        .with_context(|| "Failed to get external module uses.")?
136        .iter()
137        .map(|(use_id, leaf)| {
138            if let ResolvedGenericItem::GenericFunction(GenericFunctionId::Free(function_id)) = db
139                .use_resolved_item(*use_id)
140                .to_option()
141                .with_context(|| "Failed to fetch used function.")?
142            {
143                Ok(Aliased {
144                    value: function_id,
145                    alias: leaf.stable_ptr().identifier(db.upcast()).to_string(),
146                })
147            } else {
148                bail!("Expected a free function.")
149            }
150        })
151        .collect::<Result<Vec<_>, _>>()
152}
153
/// Returns the abi functions of the impl aliases embedded in the given contract.
/// `module_prefix` is the prefix of the generated module name outside of the contract, the rest of
/// the name is defined by the name of the aliased impl.
///
/// Only impl aliases of the generated contract module that carry the `ABI_ATTR` attribute with
/// the `ABI_ATTR_EMBED_V0_ARG` argument are considered. For each of them, the wrapper functions
/// are read from the submodule named `{module_prefix}_{impl_name}` of the module that contains
/// the aliased impl. Each wrapper is then specialized with the generic arguments written on the
/// last segment of the alias' impl path.
///
/// # Errors
/// Fails if the generated contract module, its impl aliases, an alias' resolver data, the impl's
/// parent module, or the wrapper submodule and its `use` items cannot be obtained. Also fails if
/// the impl path's parent does not resolve to a module.
///
/// # Panics
/// - Panics if an impl path has no segments.
/// - Panics if inference cannot be finalized after specializing a wrapper.
/// - Presumably panics (via `expect_with_db`) if any diagnostics were collected while resolving
///   or specializing.
fn get_impl_aliases_abi_functions(
    db: &dyn SemanticGroup,
    contract: &ContractDeclaration,
    module_prefix: &str,
) -> anyhow::Result<Vec<Aliased<SemanticConcreteFunctionWithBodyId>>> {
    let syntax_db = db.upcast();
    let generated_module_id = get_generated_contract_module(db, contract)?;
    // Diagnostics of all resolutions/specializations below are accumulated here and checked once
    // at the end.
    let mut diagnostics = SemanticDiagnostics::default();
    let mut all_abi_functions = vec![];
    for (impl_alias_id, impl_alias) in db
        .module_impl_aliases(generated_module_id)
        .to_option()
        .with_context(|| "Failed to get external module impl aliases.")?
        .iter()
    {
        // Only impl aliases marked for ABI embedding contribute ABI functions.
        if !impl_alias.has_attr_with_arg(db.upcast(), ABI_ATTR, ABI_ATTR_EMBED_V0_ARG) {
            continue;
        }
        // Resolver seeded with the alias' own resolver data, using the alias declaration as the
        // inference id.
        let resolver_data = db
            .impl_alias_resolver_data(*impl_alias_id)
            .to_option()
            .with_context(|| "Internal error: Failed to get impl alias resolver data.")?;
        let mut resolver = Resolver::with_data(
            db,
            resolver_data.clone_with_inference_id(
                db,
                InferenceId::LookupItemDeclaration(LookupItemId::ModuleItem(
                    ModuleItemId::ImplAlias(*impl_alias_id),
                )),
            ),
        );

        // Split the impl path into its last segment (the impl itself, with its generic args) and
        // the leading segments (the path of the module containing the impl).
        let impl_path_elements = impl_alias.impl_path(syntax_db).elements(syntax_db);
        let Some((impl_final_part, impl_module)) = impl_path_elements.split_last() else {
            unreachable!("impl_path should have at least one segment")
        };
        let impl_name = impl_final_part.identifier(syntax_db);
        let generic_args = impl_final_part.generic_args(syntax_db).unwrap_or_default();
        let ResolvedConcreteItem::Module(impl_module) = resolver
            .resolve_concrete_path(
                &mut diagnostics,
                impl_module.to_vec(),
                NotFoundItemType::Identifier,
            )
            .to_option()
            .with_context(|| "Internal error: Failed to resolve impl module.")?
        else {
            bail!("Internal error: Impl alias pointed to an object with non module parent.");
        };
        // The wrappers of the embedded impl live in the `{module_prefix}_{impl_name}` submodule
        // of the impl's module.
        let module_id = get_submodule_id(db, impl_module, &format!("{module_prefix}_{impl_name}"))?;
        for abi_function in get_module_aliased_functions(db, module_id)? {
            // NOTE(review): if specialization or the body lookup yields `None`, `try_map` returns
            // `None` and `extend` silently skips the function. Such failures are presumably
            // expected to surface through `diagnostics` below; confirm.
            all_abi_functions.extend(abi_function.try_map(|f| {
                let concrete_wrapper = resolver
                    .specialize_function(
                        &mut diagnostics,
                        impl_alias.stable_ptr().untyped(),
                        GenericFunctionId::Free(f),
                        &generic_args,
                    )
                    .to_option()?
                    .get_concrete(db)
                    .body(db)
                    .to_option()??;
                // All inference variables must be solved so the wrapper can be rewritten into a
                // fully concrete function.
                let inference = &mut resolver.inference();
                assert_eq!(
                    inference.finalize_without_reporting(),
                    Ok(()),
                    "All inferences should be solved at this point."
                );
                Some(inference.rewrite(concrete_wrapper).no_err())
            }));
        }
    }
    diagnostics
        .build()
        .expect_with_db(db.elongate(), "Internal error: Inference for wrappers generics failed.");
    Ok(all_abi_functions)
}
235
236/// Returns the generated contract module.
237fn get_generated_contract_module(
238    db: &dyn SemanticGroup,
239    contract: &ContractDeclaration,
240) -> anyhow::Result<ModuleId> {
241    let parent_module_id = contract.submodule_id.parent_module(db.upcast());
242    let contract_name = contract.submodule_id.name(db.upcast());
243
244    match db
245        .module_item_by_name(parent_module_id, contract_name.clone())
246        .to_option()
247        .with_context(|| "Failed to initiate a lookup in the root module.")?
248    {
249        Some(ModuleItemId::Submodule(generated_module_id)) => {
250            Ok(ModuleId::Submodule(generated_module_id))
251        }
252        _ => anyhow::bail!(format!("Failed to get generated module {contract_name}.")),
253    }
254}
255
256/// Returns the module id of the submodule of a module.
257fn get_submodule_id(
258    db: &dyn SemanticGroup,
259    module_id: ModuleId,
260    submodule_name: &str,
261) -> anyhow::Result<ModuleId> {
262    match db
263        .module_item_by_name(module_id, submodule_name.into())
264        .to_option()
265        .with_context(|| "Failed to initiate a lookup in the {module_name} module.")?
266    {
267        Some(ModuleItemId::Submodule(submodule_id)) => Ok(ModuleId::Submodule(submodule_id)),
268        _ => anyhow::bail!(
269            "Failed to get the submodule `{submodule_name}` of `{}`.",
270            module_id.full_path(db.upcast())
271        ),
272    }
273}
274
/// Sierra information of a contract: the Sierra function ids of its entry points.
///
/// `externals` and `l1_handlers` are keyed by the entry point's selector, i.e. the
/// `starknet_keccak` of the entry point's alias (see `get_selector_and_sierra_function`).
/// Both maps are (de)serialized through `serialize_ordered_hashmap_vec` /
/// `deserialize_ordered_hashmap_vec`; presumably as a sequence of `(key, value)` pairs.
/// TODO confirm.
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug, Eq)]
pub struct ContractInfo {
    /// Sierra function of the constructor, or `None` if the contract has no constructor.
    pub constructor: Option<FunctionId>,
    /// Sierra functions of the external functions, keyed by selector.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub externals: OrderedHashMap<Felt252, FunctionId>,
    /// Sierra functions of the l1 handler functions, keyed by selector.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub l1_handlers: OrderedHashMap<Felt252, FunctionId>,
}
293
294/// Returns the list of functions in a given module.
295pub fn get_contracts_info<T: SierraIdReplacer>(
296    db: &dyn SierraGenGroup,
297    contracts: Vec<ContractDeclaration>,
298    replacer: &T,
299) -> Result<OrderedHashMap<Felt252, ContractInfo>, anyhow::Error> {
300    let mut contracts_info = OrderedHashMap::default();
301    for contract in contracts {
302        let (class_hash, contract_info) = analyze_contract(db, &contract, replacer)?;
303        contracts_info.insert(class_hash, contract_info);
304    }
305    Ok(contracts_info)
306}
307
308/// Analyzes a contract and returns its class hash and a list of its functions.
309fn analyze_contract<T: SierraIdReplacer>(
310    db: &dyn SierraGenGroup,
311    contract: &ContractDeclaration,
312    replacer: &T,
313) -> anyhow::Result<(Felt252, ContractInfo)> {
314    // Extract class hash.
315    let item =
316        db.module_item_by_name(contract.module_id(), "TEST_CLASS_HASH".into()).unwrap().unwrap();
317    let constant_id = extract_matches!(item, ModuleItemId::Constant);
318    let constant = db.constant_semantic_data(constant_id).unwrap();
319    let class_hash: Felt252 =
320        extract_matches!(&constant.exprs[constant.value], Expr::Literal).value.clone().into();
321
322    // Extract functions.
323    let SemanticEntryPoints { external, l1_handler, constructor } =
324        extract_semantic_entrypoints(db.upcast(), contract)?;
325    let externals =
326        external.into_iter().map(|f| get_selector_and_sierra_function(db, &f, replacer)).collect();
327    let l1_handlers = l1_handler
328        .into_iter()
329        .map(|f| get_selector_and_sierra_function(db, &f, replacer))
330        .collect();
331    let constructors: Vec<_> = constructor
332        .into_iter()
333        .map(|f| get_selector_and_sierra_function(db, &f, replacer))
334        .collect();
335
336    let contract_info = ContractInfo {
337        externals,
338        l1_handlers,
339        constructor: constructors.into_iter().next().map(|x| x.1),
340    };
341    Ok((class_hash, contract_info))
342}
343
344/// Converts a function to a Sierra function.
345/// Returns the selector and the sierra function id.
346pub fn get_selector_and_sierra_function<T: SierraIdReplacer>(
347    db: &dyn SierraGenGroup,
348    function_with_body: &Aliased<lowering::ids::ConcreteFunctionWithBodyId>,
349    replacer: &T,
350) -> (Felt252, FunctionId) {
351    let function_id = function_with_body.value.function_id(db.upcast()).expect("Function error.");
352    let sierra_id = replacer.replace_function_id(&function_id.intern(db));
353    let selector: Felt252 = starknet_keccak(function_with_body.alias.as_bytes()).into();
354    (selector, sierra_id)
355}