1use anyhow::{Context, bail};
2use cairo_lang_defs::ids::{
3 FreeFunctionId, LanguageElementId, LookupItemId, ModuleId, ModuleItemId,
4 NamedLanguageElementId, SubmoduleId,
5};
6use cairo_lang_diagnostics::ToOption;
7use cairo_lang_filesystem::ids::CrateId;
8use cairo_lang_semantic::Expr;
9use cairo_lang_semantic::db::SemanticGroup;
10use cairo_lang_semantic::diagnostic::{NotFoundItemType, SemanticDiagnostics};
11use cairo_lang_semantic::expr::inference::InferenceId;
12use cairo_lang_semantic::expr::inference::canonic::ResultNoErrEx;
13use cairo_lang_semantic::items::functions::{
14 ConcreteFunctionWithBodyId as SemanticConcreteFunctionWithBodyId, GenericFunctionId,
15};
16use cairo_lang_semantic::items::us::SemanticUseEx;
17use cairo_lang_semantic::resolve::{ResolvedConcreteItem, ResolvedGenericItem, Resolver};
18use cairo_lang_semantic::substitution::SemanticRewriter;
19use cairo_lang_sierra::ids::FunctionId;
20use cairo_lang_sierra_generator::db::SierraGenGroup;
21use cairo_lang_sierra_generator::replace_ids::SierraIdReplacer;
22use cairo_lang_starknet_classes::keccak::starknet_keccak;
23use cairo_lang_syntax::node::helpers::{GetIdentifier, PathSegmentEx, QueryAttrs};
24use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode};
25use cairo_lang_utils::ordered_hash_map::{
26 OrderedHashMap, deserialize_ordered_hashmap_vec, serialize_ordered_hashmap_vec,
27};
28use cairo_lang_utils::{Intern, extract_matches};
29use itertools::chain;
30use serde::{Deserialize, Serialize};
31use starknet_types_core::felt::Felt as Felt252;
32use {cairo_lang_lowering as lowering, cairo_lang_semantic as semantic};
33
34use crate::aliased::Aliased;
35use crate::compile::{SemanticEntryPoints, extract_semantic_entrypoints};
36use crate::plugin::aux_data::StarkNetContractAuxData;
37use crate::plugin::consts::{ABI_ATTR, ABI_ATTR_EMBED_V0_ARG};
38
39#[cfg(test)]
40#[path = "contract_test.rs"]
41mod test;
42
/// A Starknet contract module, as discovered by `module_contract` / `find_contracts`.
#[derive(Clone)]
pub struct ContractDeclaration {
    /// The id of the submodule that declares the contract.
    pub submodule_id: SubmoduleId,
}
49
50impl ContractDeclaration {
51 pub fn module_id(&self) -> ModuleId {
52 ModuleId::Submodule(self.submodule_id)
53 }
54}
55
56pub fn module_contract(db: &dyn SemanticGroup, module_id: ModuleId) -> Option<ContractDeclaration> {
58 let all_aux_data = db.module_generated_file_aux_data(module_id).ok()?;
59
60 all_aux_data.iter().skip(1).find_map(|aux_data| {
74 let StarkNetContractAuxData { contract_name } =
75 aux_data.as_ref()?.as_any().downcast_ref()?;
76 if let ModuleId::Submodule(submodule_id) = module_id {
77 Some(ContractDeclaration { submodule_id })
78 } else {
79 unreachable!("Contract `{contract_name}` was not found.");
80 }
81 })
82}
83
84pub fn find_contracts(db: &dyn SemanticGroup, crate_ids: &[CrateId]) -> Vec<ContractDeclaration> {
87 let mut contract_declarations = vec![];
88 for crate_id in crate_ids {
89 let modules = db.crate_modules(*crate_id);
90 for module_id in modules.iter() {
91 contract_declarations.extend(module_contract(db, *module_id));
92 }
93 }
94 contract_declarations
95}
96
97pub fn get_contract_abi_functions(
100 db: &dyn SemanticGroup,
101 contract: &ContractDeclaration,
102 module_name: &str,
103) -> anyhow::Result<Vec<Aliased<semantic::ConcreteFunctionWithBodyId>>> {
104 Ok(chain!(
105 get_contract_internal_module_abi_functions(db, contract, module_name)?,
106 get_impl_aliases_abi_functions(db, contract, module_name)?
107 )
108 .collect())
109}
110
111fn get_contract_internal_module_abi_functions(
113 db: &dyn SemanticGroup,
114 contract: &ContractDeclaration,
115 module_name: &str,
116) -> anyhow::Result<Vec<Aliased<SemanticConcreteFunctionWithBodyId>>> {
117 let generated_module_id = get_generated_contract_module(db, contract)?;
118 let module_id = get_submodule_id(db.upcast(), generated_module_id, module_name)?;
119 get_module_aliased_functions(db, module_id)?
120 .into_iter()
121 .map(|f| f.try_map(|f| semantic::ConcreteFunctionWithBodyId::from_no_generics_free(db, f)))
122 .collect::<Option<Vec<_>>>()
123 .with_context(|| "Generics are not allowed in wrapper functions")
124}
125
126fn get_module_aliased_functions(
130 db: &dyn SemanticGroup,
131 module_id: ModuleId,
132) -> anyhow::Result<Vec<Aliased<FreeFunctionId>>> {
133 db.module_uses(module_id)
134 .to_option()
135 .with_context(|| "Failed to get external module uses.")?
136 .iter()
137 .map(|(use_id, leaf)| {
138 if let ResolvedGenericItem::GenericFunction(GenericFunctionId::Free(function_id)) = db
139 .use_resolved_item(*use_id)
140 .to_option()
141 .with_context(|| "Failed to fetch used function.")?
142 {
143 Ok(Aliased {
144 value: function_id,
145 alias: leaf.stable_ptr().identifier(db.upcast()).to_string(),
146 })
147 } else {
148 bail!("Expected a free function.")
149 }
150 })
151 .collect::<Result<Vec<_>, _>>()
152}
153
/// Collects the ABI wrapper functions contributed by embedded impls of `contract`.
///
/// An impl alias in the contract's generated module contributes wrappers only if it carries
/// the `ABI_ATTR` attribute with the `ABI_ATTR_EMBED_V0_ARG` argument. For each such alias:
/// - The impl path is split into its module part and its final (impl) segment.
/// - The wrappers are taken from the submodule named `{module_prefix}_{impl_name}` of the
///   module resolved from the module part.
/// - Each wrapper is specialized with the generic arguments written on the final segment.
///
/// Wrappers whose specialization or body lookup yields no value are skipped, because the
/// `try_map` closure returns `None`. Diagnostics collected while resolving and specializing
/// are checked once at the end with `expect_with_db`, presumably panicking if any were
/// reported — TODO confirm.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the generated contract module, its impl aliases, an alias's resolver data, the impl
///   module or the wrapper submodule cannot be obtained;
/// - the impl path's module part resolves to a non-module item;
/// - a wrapper `use` does not resolve to a free function.
///
/// # Panics
/// Panics if an impl path has no segments, or if inference cannot be finalized after
/// specializing a wrapper.
fn get_impl_aliases_abi_functions(
    db: &dyn SemanticGroup,
    contract: &ContractDeclaration,
    module_prefix: &str,
) -> anyhow::Result<Vec<Aliased<SemanticConcreteFunctionWithBodyId>>> {
    let syntax_db = db.upcast();
    let generated_module_id = get_generated_contract_module(db, contract)?;
    // Shared across all impl aliases; validated once after the loop.
    let mut diagnostics = SemanticDiagnostics::default();
    let mut all_abi_functions = vec![];
    for (impl_alias_id, impl_alias) in db
        .module_impl_aliases(generated_module_id)
        .to_option()
        .with_context(|| "Failed to get external module impl aliases.")?
        .iter()
    {
        // Only impl aliases marked with the embed ABI attribute contribute wrappers.
        if !impl_alias.has_attr_with_arg(db.upcast(), ABI_ATTR, ABI_ATTR_EMBED_V0_ARG) {
            continue;
        }
        let resolver_data = db
            .impl_alias_resolver_data(*impl_alias_id)
            .to_option()
            .with_context(|| "Internal error: Failed to get impl alias resolver data.")?;
        // Resolve in the impl alias's context, under an inference id owned by its declaration.
        let mut resolver = Resolver::with_data(
            db,
            resolver_data.clone_with_inference_id(
                db,
                InferenceId::LookupItemDeclaration(LookupItemId::ModuleItem(
                    ModuleItemId::ImplAlias(*impl_alias_id),
                )),
            ),
        );

        // Split `path::to::Impl<args>` into the module path and the final impl segment.
        let impl_path_elements = impl_alias.impl_path(syntax_db).elements(syntax_db);
        let Some((impl_final_part, impl_module)) = impl_path_elements.split_last() else {
            unreachable!("impl_path should have at least one segment")
        };
        let impl_name = impl_final_part.identifier(syntax_db);
        let generic_args = impl_final_part.generic_args(syntax_db).unwrap_or_default();
        let ResolvedConcreteItem::Module(impl_module) = resolver
            .resolve_concrete_path(
                &mut diagnostics,
                impl_module.to_vec(),
                NotFoundItemType::Identifier,
            )
            .to_option()
            .with_context(|| "Internal error: Failed to resolve impl module.")?
        else {
            bail!("Internal error: Impl alias pointed to an object with non module parent.");
        };
        // The wrappers live in a sibling submodule of the impl named `{prefix}_{impl_name}`.
        let module_id = get_submodule_id(db, impl_module, &format!("{module_prefix}_{impl_name}"))?;
        for abi_function in get_module_aliased_functions(db, module_id)? {
            // `try_map` yields `None` when any step below fails. Extending with a `None` adds
            // nothing, so such wrappers are skipped here.
            all_abi_functions.extend(abi_function.try_map(|f| {
                let concrete_wrapper = resolver
                    .specialize_function(
                        &mut diagnostics,
                        impl_alias.stable_ptr().untyped(),
                        GenericFunctionId::Free(f),
                        &generic_args,
                    )
                    .to_option()?
                    .get_concrete(db)
                    .body(db)
                    // `body` returns an optional body wrapped in `Maybe`; both layers must be
                    // present.
                    .to_option()??;
                let inference = &mut resolver.inference();
                assert_eq!(
                    inference.finalize_without_reporting(),
                    Ok(()),
                    "All inferences should be solved at this point."
                );
                // Substitute the solved inference variables into the concrete wrapper id.
                Some(inference.rewrite(concrete_wrapper).no_err())
            }));
        }
    }
    diagnostics
        .build()
        .expect_with_db(db.elongate(), "Internal error: Inference for wrappers generics failed.");
    Ok(all_abi_functions)
}
235
236fn get_generated_contract_module(
238 db: &dyn SemanticGroup,
239 contract: &ContractDeclaration,
240) -> anyhow::Result<ModuleId> {
241 let parent_module_id = contract.submodule_id.parent_module(db.upcast());
242 let contract_name = contract.submodule_id.name(db.upcast());
243
244 match db
245 .module_item_by_name(parent_module_id, contract_name.clone())
246 .to_option()
247 .with_context(|| "Failed to initiate a lookup in the root module.")?
248 {
249 Some(ModuleItemId::Submodule(generated_module_id)) => {
250 Ok(ModuleId::Submodule(generated_module_id))
251 }
252 _ => anyhow::bail!(format!("Failed to get generated module {contract_name}.")),
253 }
254}
255
256fn get_submodule_id(
258 db: &dyn SemanticGroup,
259 module_id: ModuleId,
260 submodule_name: &str,
261) -> anyhow::Result<ModuleId> {
262 match db
263 .module_item_by_name(module_id, submodule_name.into())
264 .to_option()
265 .with_context(|| "Failed to initiate a lookup in the {module_name} module.")?
266 {
267 Some(ModuleItemId::Submodule(submodule_id)) => Ok(ModuleId::Submodule(submodule_id)),
268 _ => anyhow::bail!(
269 "Failed to get the submodule `{submodule_name}` of `{}`.",
270 module_id.full_path(db.upcast())
271 ),
272 }
273}
274
/// Sierra-level entry point information of a single contract.
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug, Eq)]
pub struct ContractInfo {
    /// Sierra function id of the contract's constructor, if it has one.
    pub constructor: Option<FunctionId>,
    /// Sierra function ids of the external entry points, keyed by selector.
    // (De)serialized through the `*_ordered_hashmap_vec` helpers — presumably as a list of
    // key/value pairs; confirm in `cairo_lang_utils::ordered_hash_map`.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub externals: OrderedHashMap<Felt252, FunctionId>,
    /// Sierra function ids of the L1 handler entry points, keyed by selector.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub l1_handlers: OrderedHashMap<Felt252, FunctionId>,
}
293
294pub fn get_contracts_info<T: SierraIdReplacer>(
296 db: &dyn SierraGenGroup,
297 contracts: Vec<ContractDeclaration>,
298 replacer: &T,
299) -> Result<OrderedHashMap<Felt252, ContractInfo>, anyhow::Error> {
300 let mut contracts_info = OrderedHashMap::default();
301 for contract in contracts {
302 let (class_hash, contract_info) = analyze_contract(db, &contract, replacer)?;
303 contracts_info.insert(class_hash, contract_info);
304 }
305 Ok(contracts_info)
306}
307
308fn analyze_contract<T: SierraIdReplacer>(
310 db: &dyn SierraGenGroup,
311 contract: &ContractDeclaration,
312 replacer: &T,
313) -> anyhow::Result<(Felt252, ContractInfo)> {
314 let item =
316 db.module_item_by_name(contract.module_id(), "TEST_CLASS_HASH".into()).unwrap().unwrap();
317 let constant_id = extract_matches!(item, ModuleItemId::Constant);
318 let constant = db.constant_semantic_data(constant_id).unwrap();
319 let class_hash: Felt252 =
320 extract_matches!(&constant.exprs[constant.value], Expr::Literal).value.clone().into();
321
322 let SemanticEntryPoints { external, l1_handler, constructor } =
324 extract_semantic_entrypoints(db.upcast(), contract)?;
325 let externals =
326 external.into_iter().map(|f| get_selector_and_sierra_function(db, &f, replacer)).collect();
327 let l1_handlers = l1_handler
328 .into_iter()
329 .map(|f| get_selector_and_sierra_function(db, &f, replacer))
330 .collect();
331 let constructors: Vec<_> = constructor
332 .into_iter()
333 .map(|f| get_selector_and_sierra_function(db, &f, replacer))
334 .collect();
335
336 let contract_info = ContractInfo {
337 externals,
338 l1_handlers,
339 constructor: constructors.into_iter().next().map(|x| x.1),
340 };
341 Ok((class_hash, contract_info))
342}
343
344pub fn get_selector_and_sierra_function<T: SierraIdReplacer>(
347 db: &dyn SierraGenGroup,
348 function_with_body: &Aliased<lowering::ids::ConcreteFunctionWithBodyId>,
349 replacer: &T,
350) -> (Felt252, FunctionId) {
351 let function_id = function_with_body.value.function_id(db.upcast()).expect("Function error.");
352 let sierra_id = replacer.replace_function_id(&function_id.intern(db));
353 let selector: Felt252 = starknet_keccak(function_with_body.alias.as_bytes()).into();
354 (selector, sierra_id)
355}