use std::collections::{HashMap, HashSet};

use cairo_lang_defs::diagnostic_utils::StableLocation;
use cairo_lang_defs::ids::LanguageElementId;
use cairo_lang_diagnostics::Maybe;
use cairo_lang_semantic as semantic;
use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_utils::{LookupIntern, Upcast};
use itertools::{Itertools, chain, zip_eq};
use semantic::TypeId;

use crate::blocks::Blocks;
use crate::db::{ConcreteSCCRepresentative, LoweringGroup};
use crate::ids::{ConcreteFunctionWithBodyId, FunctionId, FunctionLongId, LocationId};
use crate::lower::context::{VarRequest, VariableAllocator};
use crate::{
    BlockId, DependencyType, FlatBlockEnd, FlatLowered, MatchArm, MatchInfo, Statement, VarUsage,
};
19
20struct Context<'a> {
21 db: &'a dyn LoweringGroup,
22 variables: &'a mut VariableAllocator<'a>,
23 lowered: &'a mut FlatLowered,
24 implicit_index: HashMap<TypeId, usize>,
25 implicits_tys: Vec<TypeId>,
26 implicit_vars_for_block: HashMap<BlockId, Vec<VarUsage>>,
27 visited: HashSet<BlockId>,
28 location: LocationId,
29}
30
31pub fn lower_implicits(
33 db: &dyn LoweringGroup,
34 function_id: ConcreteFunctionWithBodyId,
35 lowered: &mut FlatLowered,
36) {
37 if let Err(diag_added) = inner_lower_implicits(db, function_id, lowered) {
38 lowered.blocks = Blocks::new_errored(diag_added);
39 }
40}
41
42pub fn inner_lower_implicits(
44 db: &dyn LoweringGroup,
45 function_id: ConcreteFunctionWithBodyId,
46 lowered: &mut FlatLowered,
47) -> Maybe<()> {
48 let semantic_function = function_id.function_with_body_id(db).base_semantic_function(db);
49 let location = LocationId::from_stable_location(
50 db,
51 StableLocation::new(semantic_function.untyped_stable_ptr(db.upcast())),
52 );
53 lowered.blocks.has_root()?;
54 let root_block_id = BlockId::root();
55
56 let mut variables = VariableAllocator::new(
57 db,
58 function_id.function_with_body_id(db).base_semantic_function(db),
59 lowered.variables.clone(),
60 )?;
61
62 let implicits_tys = db.function_with_body_implicits(function_id)?;
63
64 let implicit_index =
65 HashMap::from_iter(implicits_tys.iter().enumerate().map(|(i, ty)| (*ty, i)));
66 let mut ctx = Context {
67 db,
68 variables: &mut variables,
69 lowered,
70 implicit_index,
71 implicits_tys,
72 implicit_vars_for_block: Default::default(),
73 visited: Default::default(),
74 location,
75 };
76
77 lower_function_blocks_implicits(&mut ctx, root_block_id)?;
79
80 let implicit_vars = &ctx.implicit_vars_for_block[&root_block_id];
82 ctx.lowered.parameters.splice(0..0, implicit_vars.iter().map(|var_usage| var_usage.var_id));
83
84 lowered.variables = std::mem::take(&mut ctx.variables.variables);
85
86 Ok(())
87}
88
89fn alloc_implicits(
92 ctx: &mut VariableAllocator<'_>,
93 implicits_tys: &[TypeId],
94 location: LocationId,
95) -> Vec<VarUsage> {
96 implicits_tys
97 .iter()
98 .copied()
99 .map(|ty| VarUsage { var_id: ctx.new_var(VarRequest { ty, location }), location })
100 .collect_vec()
101}
102
103fn block_body_implicits(
105 ctx: &mut Context<'_>,
106 block_id: BlockId,
107) -> Result<Vec<VarUsage>, cairo_lang_diagnostics::DiagnosticAdded> {
108 let mut implicits = ctx
109 .implicit_vars_for_block
110 .entry(block_id)
111 .or_insert_with(|| {
112 alloc_implicits(
113 ctx.variables,
114 &ctx.implicits_tys,
115 ctx.location.with_auto_generation_note(ctx.db, "implicits"),
116 )
117 })
118 .clone();
119 let require_implicits_libfunc_id =
120 semantic::corelib::internal_require_implicit(ctx.db.upcast());
121 let mut remove = vec![];
122 for (i, statement) in ctx.lowered.blocks[block_id].statements.iter_mut().enumerate() {
123 if let Statement::Call(stmt) = statement {
124 if matches!(
125 stmt.function.lookup_intern(ctx.db),
126 FunctionLongId::Semantic(func_id)
127 if func_id.get_concrete(ctx.db.upcast()).generic_function == require_implicits_libfunc_id
128 ) {
129 remove.push(i);
130 continue;
131 }
132 let callee_implicits = ctx.db.function_implicits(stmt.function)?;
133 let location = stmt.location.with_auto_generation_note(ctx.db, "implicits");
134
135 let indices = callee_implicits.iter().map(|ty| ctx.implicit_index[ty]).collect_vec();
136
137 let implicit_input_vars = indices.iter().map(|i| implicits[*i]);
138 stmt.inputs.splice(0..0, implicit_input_vars);
139 let implicit_output_vars = callee_implicits
140 .iter()
141 .copied()
142 .map(|ty| ctx.variables.new_var(VarRequest { ty, location }))
143 .collect_vec();
144 for (i, var) in zip_eq(indices, implicit_output_vars.iter()) {
145 implicits[i] = VarUsage { var_id: *var, location: ctx.variables[*var].location };
146 }
147 stmt.outputs.splice(0..0, implicit_output_vars);
148 }
149 }
150 for i in remove.into_iter().rev() {
151 ctx.lowered.blocks[block_id].statements.remove(i);
152 }
153 Ok(implicits)
154}
155
156fn lower_function_blocks_implicits(ctx: &mut Context<'_>, root_block_id: BlockId) -> Maybe<()> {
158 let mut blocks_to_visit = vec![root_block_id];
159 while let Some(block_id) = blocks_to_visit.pop() {
160 if !ctx.visited.insert(block_id) {
161 continue;
162 }
163 let implicits = block_body_implicits(ctx, block_id)?;
164 match &mut ctx.lowered.blocks[block_id].end {
166 FlatBlockEnd::Return(rets, _location) => {
167 rets.splice(0..0, implicits.iter().cloned());
168 }
169 FlatBlockEnd::Panic(_) => {
170 unreachable!("Panics should have been stripped in a previous phase.")
171 }
172 FlatBlockEnd::Goto(block_id, remapping) => {
173 let target_implicits = ctx
174 .implicit_vars_for_block
175 .entry(*block_id)
176 .or_insert_with(|| {
177 alloc_implicits(ctx.variables, &ctx.implicits_tys, ctx.location)
178 })
179 .clone();
180 let old_remapping = std::mem::take(&mut remapping.remapping);
181 remapping.remapping = chain!(
182 zip_eq(
183 target_implicits.into_iter().map(|var_usage| var_usage.var_id),
184 implicits
185 ),
186 old_remapping
187 )
188 .collect();
189 blocks_to_visit.push(*block_id);
190 }
191 FlatBlockEnd::Match { info } => {
192 blocks_to_visit.extend(info.arms().iter().rev().map(|a| a.block_id));
193 match info {
194 MatchInfo::Enum(_) | MatchInfo::Value(_) => {
195 for MatchArm { arm_selector: _, block_id, var_ids: _ } in info.arms() {
196 assert!(
197 ctx.implicit_vars_for_block
198 .insert(*block_id, implicits.clone())
199 .is_none(),
200 "Multiple jumps to arm blocks are not allowed."
201 );
202 }
203 }
204 MatchInfo::Extern(stmt) => {
205 let callee_implicits = ctx.db.function_implicits(stmt.function)?;
206
207 let indices =
208 callee_implicits.iter().map(|ty| ctx.implicit_index[ty]).collect_vec();
209
210 let implicit_input_vars = indices.iter().map(|i| implicits[*i]);
211 stmt.inputs.splice(0..0, implicit_input_vars);
212 let location = stmt.location.with_auto_generation_note(ctx.db, "implicits");
213
214 for MatchArm { arm_selector: _, block_id, var_ids } in stmt.arms.iter_mut()
215 {
216 let mut arm_implicits = implicits.clone();
217 let mut implicit_input_vars = vec![];
218 for ty in callee_implicits.iter().copied() {
219 let var = ctx.variables.new_var(VarRequest { ty, location });
220 implicit_input_vars.push(var);
221 let implicit_index = ctx.implicit_index[&ty];
222 arm_implicits[implicit_index] = VarUsage { var_id: var, location };
223 }
224 assert!(
225 ctx.implicit_vars_for_block
226 .insert(*block_id, arm_implicits)
227 .is_none(),
228 "Multiple jumps to arm blocks are not allowed."
229 );
230
231 var_ids.splice(0..0, implicit_input_vars);
232 }
233 }
234 }
235 }
236 FlatBlockEnd::NotSet => unreachable!(),
237 }
238 }
239 Ok(())
240}
241
242pub fn function_implicits(db: &dyn LoweringGroup, function: FunctionId) -> Maybe<Vec<TypeId>> {
246 if let Some(body) = function.body(db.upcast())? {
247 return db.function_with_body_implicits(body);
248 }
249 Ok(function.signature(db)?.implicits)
250}
251
252pub trait FunctionImplicitsTrait<'a>: Upcast<dyn LoweringGroup + 'a> {
254 fn function_with_body_implicits(
256 &self,
257 function: ConcreteFunctionWithBodyId,
258 ) -> Maybe<Vec<TypeId>> {
259 let db: &dyn LoweringGroup = self.upcast();
260 let semantic_db: &dyn SemanticGroup = db.upcast();
261 let scc_representative = db
262 .concrete_function_with_body_scc_inlined_representative(function, DependencyType::Call);
263 let mut implicits = db.scc_implicits(scc_representative)?;
264
265 let precedence = db.function_declaration_implicit_precedence(
266 function.function_with_body_id(db).base_semantic_function(db),
267 )?;
268 precedence.apply(&mut implicits, semantic_db);
269
270 Ok(implicits)
271 }
272}
273impl<'a, T: Upcast<dyn LoweringGroup + 'a> + ?Sized> FunctionImplicitsTrait<'a> for T {}
274
275pub fn scc_implicits(db: &dyn LoweringGroup, scc: ConcreteSCCRepresentative) -> Maybe<Vec<TypeId>> {
277 let scc_functions = db.concrete_function_with_body_inlined_scc(scc.0, DependencyType::Call);
278 let mut all_implicits = HashSet::new();
279 for function in scc_functions {
280 all_implicits.extend(function.function_id(db)?.signature(db)?.implicits);
282 let direct_callees =
284 db.concrete_function_with_body_inlined_direct_callees(function, DependencyType::Call)?;
285 for direct_callee in direct_callees {
286 if let Some(callee_body) = direct_callee.body(db.upcast())? {
287 let callee_scc = db.concrete_function_with_body_scc_inlined_representative(
288 callee_body,
289 DependencyType::Call,
290 );
291 if callee_scc != scc {
292 all_implicits.extend(db.scc_implicits(callee_scc)?);
293 }
294 } else {
295 all_implicits.extend(direct_callee.signature(db)?.implicits);
296 }
297 }
298 }
299 Ok(all_implicits.into_iter().collect())
300}