1use anyhow::{Context, bail};
2use cairo_lang_defs::ids::{
3 FreeFunctionId, LanguageElementId, LookupItemId, ModuleId, ModuleItemId,
4 NamedLanguageElementId, SubmoduleId,
5};
6use cairo_lang_diagnostics::ToOption;
7use cairo_lang_filesystem::ids::CrateId;
8use cairo_lang_semantic::db::SemanticGroup;
9use cairo_lang_semantic::diagnostic::{NotFoundItemType, SemanticDiagnostics};
10use cairo_lang_semantic::expr::inference::InferenceId;
11use cairo_lang_semantic::expr::inference::canonic::ResultNoErrEx;
12use cairo_lang_semantic::items::functions::{
13 ConcreteFunctionWithBodyId as SemanticConcreteFunctionWithBodyId, GenericFunctionId,
14};
15use cairo_lang_semantic::items::us::SemanticUseEx;
16use cairo_lang_semantic::resolve::{ResolvedConcreteItem, ResolvedGenericItem, Resolver};
17use cairo_lang_semantic::substitution::SemanticRewriter;
18use cairo_lang_sierra::ids::FunctionId;
19use cairo_lang_sierra_generator::db::SierraGenGroup;
20use cairo_lang_sierra_generator::replace_ids::SierraIdReplacer;
21use cairo_lang_starknet_classes::keccak::starknet_keccak;
22use cairo_lang_syntax::node::helpers::{GetIdentifier, PathSegmentEx, QueryAttrs};
23use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode};
24use cairo_lang_utils::ordered_hash_map::{
25 OrderedHashMap, deserialize_ordered_hashmap_vec, serialize_ordered_hashmap_vec,
26};
27use cairo_lang_utils::{Intern, LookupIntern, extract_matches};
28use itertools::chain;
29use serde::{Deserialize, Serialize};
30use starknet_types_core::felt::Felt as Felt252;
31use {cairo_lang_lowering as lowering, cairo_lang_semantic as semantic};
32
33use crate::aliased::Aliased;
34use crate::compile::{SemanticEntryPoints, extract_semantic_entrypoints};
35use crate::plugin::aux_data::StarknetContractAuxData;
36use crate::plugin::consts::{ABI_ATTR, ABI_ATTR_EMBED_V0_ARG};
37
38#[cfg(test)]
39#[path = "contract_test.rs"]
40mod test;
41
42#[derive(Clone)]
44pub struct ContractDeclaration {
45 pub submodule_id: SubmoduleId,
47}
48
49impl ContractDeclaration {
50 pub fn module_id(&self) -> ModuleId {
51 ModuleId::Submodule(self.submodule_id)
52 }
53}
54
55pub fn module_contract(db: &dyn SemanticGroup, module_id: ModuleId) -> Option<ContractDeclaration> {
57 let all_aux_data = db.module_generated_file_aux_data(module_id).ok()?;
58
59 all_aux_data.iter().skip(1).find_map(|aux_data| {
73 let StarknetContractAuxData { contract_name } =
74 aux_data.as_ref()?.as_any().downcast_ref()?;
75 if let ModuleId::Submodule(submodule_id) = module_id {
76 Some(ContractDeclaration { submodule_id })
77 } else {
78 unreachable!("Contract `{contract_name}` was not found.");
79 }
80 })
81}
82
83pub fn find_contracts(db: &dyn SemanticGroup, crate_ids: &[CrateId]) -> Vec<ContractDeclaration> {
86 let mut contract_declarations = vec![];
87 for crate_id in crate_ids {
88 let modules = db.crate_modules(*crate_id);
89 for module_id in modules.iter() {
90 contract_declarations.extend(module_contract(db, *module_id));
91 }
92 }
93 contract_declarations
94}
95
96pub fn get_contract_abi_functions(
99 db: &dyn SemanticGroup,
100 contract: &ContractDeclaration,
101 module_name: &str,
102) -> anyhow::Result<Vec<Aliased<semantic::ConcreteFunctionWithBodyId>>> {
103 Ok(chain!(
104 get_contract_internal_module_abi_functions(db, contract, module_name)?,
105 get_impl_aliases_abi_functions(db, contract, module_name)?
106 )
107 .collect())
108}
109
110fn get_contract_internal_module_abi_functions(
112 db: &dyn SemanticGroup,
113 contract: &ContractDeclaration,
114 module_name: &str,
115) -> anyhow::Result<Vec<Aliased<SemanticConcreteFunctionWithBodyId>>> {
116 let generated_module_id = get_generated_contract_module(db, contract)?;
117 let module_id = get_submodule_id(db.upcast(), generated_module_id, module_name)?;
118 get_module_aliased_functions(db, module_id)?
119 .into_iter()
120 .map(|f| f.try_map(|f| semantic::ConcreteFunctionWithBodyId::from_no_generics_free(db, f)))
121 .collect::<Option<Vec<_>>>()
122 .with_context(|| "Generics are not allowed in wrapper functions")
123}
124
125fn get_module_aliased_functions(
129 db: &dyn SemanticGroup,
130 module_id: ModuleId,
131) -> anyhow::Result<Vec<Aliased<FreeFunctionId>>> {
132 db.module_uses(module_id)
133 .to_option()
134 .with_context(|| "Failed to get external module uses.")?
135 .iter()
136 .map(|(use_id, leaf)| {
137 if let ResolvedGenericItem::GenericFunction(GenericFunctionId::Free(function_id)) = db
138 .use_resolved_item(*use_id)
139 .to_option()
140 .with_context(|| "Failed to fetch used function.")?
141 {
142 Ok(Aliased {
143 value: function_id,
144 alias: leaf.stable_ptr().identifier(db.upcast()).to_string(),
145 })
146 } else {
147 bail!("Expected a free function.")
148 }
149 })
150 .collect::<Result<Vec<_>, _>>()
151}
152
153fn get_impl_aliases_abi_functions(
157 db: &dyn SemanticGroup,
158 contract: &ContractDeclaration,
159 module_prefix: &str,
160) -> anyhow::Result<Vec<Aliased<SemanticConcreteFunctionWithBodyId>>> {
161 let syntax_db = db.upcast();
162 let generated_module_id = get_generated_contract_module(db, contract)?;
163 let mut diagnostics = SemanticDiagnostics::default();
164 let mut all_abi_functions = vec![];
165 for (impl_alias_id, impl_alias) in db
166 .module_impl_aliases(generated_module_id)
167 .to_option()
168 .with_context(|| "Failed to get external module impl aliases.")?
169 .iter()
170 {
171 if !impl_alias.has_attr_with_arg(db.upcast(), ABI_ATTR, ABI_ATTR_EMBED_V0_ARG) {
172 continue;
173 }
174 let resolver_data = db
175 .impl_alias_resolver_data(*impl_alias_id)
176 .to_option()
177 .with_context(|| "Internal error: Failed to get impl alias resolver data.")?;
178 let mut resolver = Resolver::with_data(
179 db,
180 resolver_data.clone_with_inference_id(
181 db,
182 InferenceId::LookupItemDeclaration(LookupItemId::ModuleItem(
183 ModuleItemId::ImplAlias(*impl_alias_id),
184 )),
185 ),
186 );
187
188 let impl_path_elements = impl_alias.impl_path(syntax_db).elements(syntax_db);
189 let Some((impl_final_part, impl_module)) = impl_path_elements.split_last() else {
190 unreachable!("impl_path should have at least one segment")
191 };
192 let impl_name = impl_final_part.identifier(syntax_db);
193 let generic_args = impl_final_part.generic_args(syntax_db).unwrap_or_default();
194 let ResolvedConcreteItem::Module(impl_module) = resolver
195 .resolve_concrete_path(
196 &mut diagnostics,
197 impl_module.to_vec(),
198 NotFoundItemType::Identifier,
199 )
200 .to_option()
201 .with_context(|| "Internal error: Failed to resolve impl module.")?
202 else {
203 bail!("Internal error: Impl alias pointed to an object with non module parent.");
204 };
205 let module_id = get_submodule_id(db, impl_module, &format!("{module_prefix}_{impl_name}"))?;
206 for abi_function in get_module_aliased_functions(db, module_id)? {
207 all_abi_functions.extend(abi_function.try_map(|f| {
208 let concrete_wrapper = resolver
209 .specialize_function(
210 &mut diagnostics,
211 impl_alias.stable_ptr().untyped(),
212 GenericFunctionId::Free(f),
213 &generic_args,
214 )
215 .to_option()?
216 .get_concrete(db)
217 .body(db)
218 .to_option()??;
219 let inference = &mut resolver.inference();
220 assert_eq!(
221 inference.finalize_without_reporting(),
222 Ok(()),
223 "All inferences should be solved at this point."
224 );
225 Some(inference.rewrite(concrete_wrapper).no_err())
226 }));
227 }
228 }
229 diagnostics
230 .build()
231 .expect_with_db(db.elongate(), "Internal error: Inference for wrappers generics failed.");
232 Ok(all_abi_functions)
233}
234
235fn get_generated_contract_module(
237 db: &dyn SemanticGroup,
238 contract: &ContractDeclaration,
239) -> anyhow::Result<ModuleId> {
240 let parent_module_id = contract.submodule_id.parent_module(db.upcast());
241 let contract_name = contract.submodule_id.name(db.upcast());
242
243 match db
244 .module_item_by_name(parent_module_id, contract_name.clone())
245 .to_option()
246 .with_context(|| "Failed to initiate a lookup in the root module.")?
247 {
248 Some(ModuleItemId::Submodule(generated_module_id)) => {
249 Ok(ModuleId::Submodule(generated_module_id))
250 }
251 _ => anyhow::bail!(format!("Failed to get generated module {contract_name}.")),
252 }
253}
254
255fn get_submodule_id(
257 db: &dyn SemanticGroup,
258 module_id: ModuleId,
259 submodule_name: &str,
260) -> anyhow::Result<ModuleId> {
261 match db
262 .module_item_by_name(module_id, submodule_name.into())
263 .to_option()
264 .with_context(|| "Failed to initiate a lookup in the {module_name} module.")?
265 {
266 Some(ModuleItemId::Submodule(submodule_id)) => Ok(ModuleId::Submodule(submodule_id)),
267 _ => anyhow::bail!(
268 "Failed to get the submodule `{submodule_name}` of `{}`.",
269 module_id.full_path(db.upcast())
270 ),
271 }
272}
273
274#[derive(Clone, Serialize, Deserialize, PartialEq, Debug, Eq)]
276pub struct ContractInfo {
277 pub constructor: Option<FunctionId>,
279 #[serde(
281 serialize_with = "serialize_ordered_hashmap_vec",
282 deserialize_with = "deserialize_ordered_hashmap_vec"
283 )]
284 pub externals: OrderedHashMap<Felt252, FunctionId>,
285 #[serde(
287 serialize_with = "serialize_ordered_hashmap_vec",
288 deserialize_with = "deserialize_ordered_hashmap_vec"
289 )]
290 pub l1_handlers: OrderedHashMap<Felt252, FunctionId>,
291}
292
293pub fn get_contracts_info<T: SierraIdReplacer>(
295 db: &dyn SierraGenGroup,
296 contracts: Vec<ContractDeclaration>,
297 replacer: &T,
298) -> Result<OrderedHashMap<Felt252, ContractInfo>, anyhow::Error> {
299 let mut contracts_info = OrderedHashMap::default();
300 for contract in contracts {
301 let (class_hash, contract_info) = analyze_contract(db, &contract, replacer)?;
302 contracts_info.insert(class_hash, contract_info);
303 }
304 Ok(contracts_info)
305}
306
307fn analyze_contract<T: SierraIdReplacer>(
309 db: &dyn SierraGenGroup,
310 contract: &ContractDeclaration,
311 replacer: &T,
312) -> anyhow::Result<(Felt252, ContractInfo)> {
313 let item =
315 db.module_item_by_name(contract.module_id(), "TEST_CLASS_HASH".into()).unwrap().unwrap();
316 let constant_id = extract_matches!(item, ModuleItemId::Constant);
317 let class_hash: Felt252 =
318 db.constant_const_value(constant_id).unwrap().lookup_intern(db).into_int().unwrap().into();
319
320 let SemanticEntryPoints { external, l1_handler, constructor } =
322 extract_semantic_entrypoints(db.upcast(), contract)?;
323 let externals =
324 external.into_iter().map(|f| get_selector_and_sierra_function(db, &f, replacer)).collect();
325 let l1_handlers = l1_handler
326 .into_iter()
327 .map(|f| get_selector_and_sierra_function(db, &f, replacer))
328 .collect();
329 let constructors: Vec<_> = constructor
330 .into_iter()
331 .map(|f| get_selector_and_sierra_function(db, &f, replacer))
332 .collect();
333
334 let contract_info = ContractInfo {
335 externals,
336 l1_handlers,
337 constructor: constructors.into_iter().next().map(|x| x.1),
338 };
339 Ok((class_hash, contract_info))
340}
341
342pub fn get_selector_and_sierra_function<T: SierraIdReplacer>(
345 db: &dyn SierraGenGroup,
346 function_with_body: &Aliased<lowering::ids::ConcreteFunctionWithBodyId>,
347 replacer: &T,
348) -> (Felt252, FunctionId) {
349 let function_id = function_with_body.value.function_id(db.upcast()).expect("Function error.");
350 let sierra_id = replacer.replace_function_id(&function_id.intern(db));
351 let selector: Felt252 = starknet_keccak(function_with_body.alias.as_bytes()).into();
352 (selector, sierra_id)
353}