cairo_lang_lowering/implicits/mod.rs

1use std::collections::{HashMap, HashSet};
2
3use cairo_lang_defs::diagnostic_utils::StableLocation;
4use cairo_lang_defs::ids::LanguageElementId;
5use cairo_lang_diagnostics::Maybe;
6use cairo_lang_semantic as semantic;
7use cairo_lang_semantic::db::SemanticGroup;
8use cairo_lang_utils::{LookupIntern, Upcast};
9use itertools::{Itertools, chain, zip_eq};
10use semantic::TypeId;
11
12use crate::blocks::Blocks;
13use crate::db::{ConcreteSCCRepresentative, LoweringGroup};
14use crate::ids::{ConcreteFunctionWithBodyId, FunctionId, FunctionLongId, LocationId};
15use crate::lower::context::{VarRequest, VariableAllocator};
16use crate::{
17    BlockId, DependencyType, FlatBlockEnd, FlatLowered, MatchArm, MatchInfo, Statement, VarUsage,
18};
19
20struct Context<'a> {
21    db: &'a dyn LoweringGroup,
22    variables: &'a mut VariableAllocator<'a>,
23    lowered: &'a mut FlatLowered,
24    implicit_index: HashMap<TypeId, usize>,
25    implicits_tys: Vec<TypeId>,
26    implicit_vars_for_block: HashMap<BlockId, Vec<VarUsage>>,
27    visited: HashSet<BlockId>,
28    location: LocationId,
29}
30
31/// Lowering phase that adds implicits.
32pub fn lower_implicits(
33    db: &dyn LoweringGroup,
34    function_id: ConcreteFunctionWithBodyId,
35    lowered: &mut FlatLowered,
36) {
37    if let Err(diag_added) = inner_lower_implicits(db, function_id, lowered) {
38        lowered.blocks = Blocks::new_errored(diag_added);
39    }
40}
41
42/// Similar to lower_implicits, but uses Maybe<> for convenience.
43pub fn inner_lower_implicits(
44    db: &dyn LoweringGroup,
45    function_id: ConcreteFunctionWithBodyId,
46    lowered: &mut FlatLowered,
47) -> Maybe<()> {
48    let semantic_function = function_id.function_with_body_id(db).base_semantic_function(db);
49    let location = LocationId::from_stable_location(
50        db,
51        StableLocation::new(semantic_function.untyped_stable_ptr(db.upcast())),
52    );
53    lowered.blocks.has_root()?;
54    let root_block_id = BlockId::root();
55
56    let mut variables = VariableAllocator::new(
57        db,
58        function_id.function_with_body_id(db).base_semantic_function(db),
59        lowered.variables.clone(),
60    )?;
61
62    let implicits_tys = db.function_with_body_implicits(function_id)?;
63
64    let implicit_index =
65        HashMap::from_iter(implicits_tys.iter().enumerate().map(|(i, ty)| (*ty, i)));
66    let mut ctx = Context {
67        db,
68        variables: &mut variables,
69        lowered,
70        implicit_index,
71        implicits_tys,
72        implicit_vars_for_block: Default::default(),
73        visited: Default::default(),
74        location,
75    };
76
77    // Start from root block.
78    lower_function_blocks_implicits(&mut ctx, root_block_id)?;
79
80    // Introduce new input variables in the root block.
81    let implicit_vars = &ctx.implicit_vars_for_block[&root_block_id];
82    ctx.lowered.parameters.splice(0..0, implicit_vars.iter().map(|var_usage| var_usage.var_id));
83
84    lowered.variables = std::mem::take(&mut ctx.variables.variables);
85
86    Ok(())
87}
88
89/// Allocates and returns new variables with usage location for each of the current function's
90/// implicits.
91fn alloc_implicits(
92    ctx: &mut VariableAllocator<'_>,
93    implicits_tys: &[TypeId],
94    location: LocationId,
95) -> Vec<VarUsage> {
96    implicits_tys
97        .iter()
98        .copied()
99        .map(|ty| VarUsage { var_id: ctx.new_var(VarRequest { ty, location }), location })
100        .collect_vec()
101}
102
103/// Returns the implicits that are used in the statements of a block.
104fn block_body_implicits(
105    ctx: &mut Context<'_>,
106    block_id: BlockId,
107) -> Result<Vec<VarUsage>, cairo_lang_diagnostics::DiagnosticAdded> {
108    let mut implicits = ctx
109        .implicit_vars_for_block
110        .entry(block_id)
111        .or_insert_with(|| {
112            alloc_implicits(
113                ctx.variables,
114                &ctx.implicits_tys,
115                ctx.location.with_auto_generation_note(ctx.db, "implicits"),
116            )
117        })
118        .clone();
119    let require_implicits_libfunc_id =
120        semantic::corelib::internal_require_implicit(ctx.db.upcast());
121    let mut remove = vec![];
122    for (i, statement) in ctx.lowered.blocks[block_id].statements.iter_mut().enumerate() {
123        if let Statement::Call(stmt) = statement {
124            if matches!(
125                stmt.function.lookup_intern(ctx.db),
126                FunctionLongId::Semantic(func_id)
127                    if func_id.get_concrete(ctx.db.upcast()).generic_function == require_implicits_libfunc_id
128            ) {
129                remove.push(i);
130                continue;
131            }
132            let callee_implicits = ctx.db.function_implicits(stmt.function)?;
133            let location = stmt.location.with_auto_generation_note(ctx.db, "implicits");
134
135            let indices = callee_implicits.iter().map(|ty| ctx.implicit_index[ty]).collect_vec();
136
137            let implicit_input_vars = indices.iter().map(|i| implicits[*i]);
138            stmt.inputs.splice(0..0, implicit_input_vars);
139            let implicit_output_vars = callee_implicits
140                .iter()
141                .copied()
142                .map(|ty| ctx.variables.new_var(VarRequest { ty, location }))
143                .collect_vec();
144            for (i, var) in zip_eq(indices, implicit_output_vars.iter()) {
145                implicits[i] = VarUsage { var_id: *var, location: ctx.variables[*var].location };
146            }
147            stmt.outputs.splice(0..0, implicit_output_vars);
148        }
149    }
150    for i in remove.into_iter().rev() {
151        ctx.lowered.blocks[block_id].statements.remove(i);
152    }
153    Ok(implicits)
154}
155
/// Finds the implicits for a function's blocks starting from the root.
///
/// Blocks reachable from `root_block_id` are processed once each (tracked in `ctx.visited`),
/// using an explicit stack. For each block, `block_body_implicits` threads the implicits through
/// the block's statements and returns the implicit values at the block end. The end is then
/// rewritten:
/// * `Return`: the current implicits are prepended to the returned values.
/// * `Goto`: the target block's implicit variables are looked up, or allocated if the target has
///   none yet. Remapping entries from them to the current implicits are placed ahead of the
///   existing remapping.
/// * `Match`: every arm block is assigned its entry implicits. For enum and value matches these
///   are the current implicits unchanged. For extern matches, the callee's current implicit
///   values are prepended to the call inputs. Each arm also gets fresh variables for the callee's
///   implicits, which are prepended to the arm's `var_ids`.
///
/// # Errors
/// Propagates failures from `block_body_implicits` and from computing a callee's implicits.
///
/// # Panics
/// Panics on a `Panic` or `NotSet` block end. Also panics if a match arm block already has
/// entry implicits assigned ("Multiple jumps to arm blocks are not allowed.").
fn lower_function_blocks_implicits(ctx: &mut Context<'_>, root_block_id: BlockId) -> Maybe<()> {
    let mut blocks_to_visit = vec![root_block_id];
    while let Some(block_id) = blocks_to_visit.pop() {
        // A block can be pushed more than once (e.g. a shared `Goto` target); process it once.
        if !ctx.visited.insert(block_id) {
            continue;
        }
        let implicits = block_body_implicits(ctx, block_id)?;
        // End.
        match &mut ctx.lowered.blocks[block_id].end {
            FlatBlockEnd::Return(rets, _location) => {
                rets.splice(0..0, implicits.iter().cloned());
            }
            FlatBlockEnd::Panic(_) => {
                unreachable!("Panics should have been stripped in a previous phase.")
            }
            FlatBlockEnd::Goto(block_id, remapping) => {
                // NOTE(review): `block_body_implicits` allocates missing entry implicits with
                // `ctx.location.with_auto_generation_note(ctx.db, "implicits")`, but this path
                // uses the bare `ctx.location`. Confirm whether the difference is intended.
                let target_implicits = ctx
                    .implicit_vars_for_block
                    .entry(*block_id)
                    .or_insert_with(|| {
                        alloc_implicits(ctx.variables, &ctx.implicits_tys, ctx.location)
                    })
                    .clone();
                // Map each target implicit variable to the current implicit value, ahead of
                // the block's original remapping entries.
                let old_remapping = std::mem::take(&mut remapping.remapping);
                remapping.remapping = chain!(
                    zip_eq(
                        target_implicits.into_iter().map(|var_usage| var_usage.var_id),
                        implicits
                    ),
                    old_remapping
                )
                .collect();
                blocks_to_visit.push(*block_id);
            }
            FlatBlockEnd::Match { info } => {
                // Pushed in reverse so the arms are popped (visited) in their original order.
                blocks_to_visit.extend(info.arms().iter().rev().map(|a| a.block_id));
                match info {
                    MatchInfo::Enum(_) | MatchInfo::Value(_) => {
                        // Non-call matches do not touch the implicits: every arm starts with
                        // the implicits at the end of this block.
                        for MatchArm { arm_selector: _, block_id, var_ids: _ } in info.arms() {
                            assert!(
                                ctx.implicit_vars_for_block
                                    .insert(*block_id, implicits.clone())
                                    .is_none(),
                                "Multiple jumps to arm blocks are not allowed."
                            );
                        }
                    }
                    MatchInfo::Extern(stmt) => {
                        let callee_implicits = ctx.db.function_implicits(stmt.function)?;

                        // Position of each callee implicit within the caller's implicits.
                        let indices =
                            callee_implicits.iter().map(|ty| ctx.implicit_index[ty]).collect_vec();

                        // The current values of the callee's implicits become the leading
                        // inputs of the extern call.
                        let implicit_input_vars = indices.iter().map(|i| implicits[*i]);
                        stmt.inputs.splice(0..0, implicit_input_vars);
                        let location = stmt.location.with_auto_generation_note(ctx.db, "implicits");

                        for MatchArm { arm_selector: _, block_id, var_ids } in stmt.arms.iter_mut()
                        {
                            // Each arm receives fresh variables for the callee's implicits. These
                            // replace the corresponding entries of the arm's entry implicits.
                            let mut arm_implicits = implicits.clone();
                            let mut implicit_input_vars = vec![];
                            for ty in callee_implicits.iter().copied() {
                                let var = ctx.variables.new_var(VarRequest { ty, location });
                                implicit_input_vars.push(var);
                                let implicit_index = ctx.implicit_index[&ty];
                                arm_implicits[implicit_index] = VarUsage { var_id: var, location };
                            }
                            assert!(
                                ctx.implicit_vars_for_block
                                    .insert(*block_id, arm_implicits)
                                    .is_none(),
                                "Multiple jumps to arm blocks are not allowed."
                            );

                            // The new variables are bound as the arm's leading variables.
                            var_ids.splice(0..0, implicit_input_vars);
                        }
                    }
                }
            }
            FlatBlockEnd::NotSet => unreachable!(),
        }
    }
    Ok(())
}
241
242// =========== Query implementations ===========
243
244/// Query implementation of [crate::db::LoweringGroup::function_implicits].
245pub fn function_implicits(db: &dyn LoweringGroup, function: FunctionId) -> Maybe<Vec<TypeId>> {
246    if let Some(body) = function.body(db.upcast())? {
247        return db.function_with_body_implicits(body);
248    }
249    Ok(function.signature(db)?.implicits)
250}
251
252/// A trait to add helper methods in [LoweringGroup].
253pub trait FunctionImplicitsTrait<'a>: Upcast<dyn LoweringGroup + 'a> {
254    /// Returns all the implicits used by a [ConcreteFunctionWithBodyId].
255    fn function_with_body_implicits(
256        &self,
257        function: ConcreteFunctionWithBodyId,
258    ) -> Maybe<Vec<TypeId>> {
259        let db: &dyn LoweringGroup = self.upcast();
260        let semantic_db: &dyn SemanticGroup = db.upcast();
261        let scc_representative = db
262            .concrete_function_with_body_scc_inlined_representative(function, DependencyType::Call);
263        let mut implicits = db.scc_implicits(scc_representative)?;
264
265        let precedence = db.function_declaration_implicit_precedence(
266            function.function_with_body_id(db).base_semantic_function(db),
267        )?;
268        precedence.apply(&mut implicits, semantic_db);
269
270        Ok(implicits)
271    }
272}
273impl<'a, T: Upcast<dyn LoweringGroup + 'a> + ?Sized> FunctionImplicitsTrait<'a> for T {}
274
275/// Query implementation of [LoweringGroup::scc_implicits].
276pub fn scc_implicits(db: &dyn LoweringGroup, scc: ConcreteSCCRepresentative) -> Maybe<Vec<TypeId>> {
277    let scc_functions = db.concrete_function_with_body_inlined_scc(scc.0, DependencyType::Call);
278    let mut all_implicits = HashSet::new();
279    for function in scc_functions {
280        // Add the function's explicit implicits.
281        all_implicits.extend(function.function_id(db)?.signature(db)?.implicits);
282        // For each direct callee, add its implicits.
283        let direct_callees =
284            db.concrete_function_with_body_inlined_direct_callees(function, DependencyType::Call)?;
285        for direct_callee in direct_callees {
286            if let Some(callee_body) = direct_callee.body(db.upcast())? {
287                let callee_scc = db.concrete_function_with_body_scc_inlined_representative(
288                    callee_body,
289                    DependencyType::Call,
290                );
291                if callee_scc != scc {
292                    all_implicits.extend(db.scc_implicits(callee_scc)?);
293                }
294            } else {
295                all_implicits.extend(direct_callee.signature(db)?.implicits);
296            }
297        }
298    }
299    Ok(all_implicits.into_iter().collect())
300}