// cairo_lang_lowering/optimizations/const_folding.rs

1#[cfg(test)]
2#[path = "const_folding_test.rs"]
3mod test;
4
5use std::sync::Arc;
6
7use cairo_lang_defs::ids::{ExternFunctionId, ModuleId, ModuleItemId};
8use cairo_lang_semantic::items::constant::ConstValue;
9use cairo_lang_semantic::items::imp::ImplLookupContext;
10use cairo_lang_semantic::{GenericArgumentId, MatchArmSelector, TypeId, corelib};
11use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
12use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
13use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
14use cairo_lang_utils::{Intern, LookupIntern, extract_matches, try_extract_matches};
15use id_arena::Arena;
16use itertools::{chain, zip_eq};
17use num_bigint::BigInt;
18use num_integer::Integer;
19use num_traits::Zero;
20use smol_str::SmolStr;
21
22use crate::db::LoweringGroup;
23use crate::ids::{FunctionId, FunctionLongId};
24use crate::{
25    BlockId, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo, MatchExternInfo, MatchInfo,
26    Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
27    StatementStructConstruct, StatementStructDestructure, VarUsage, Variable, VariableId,
28};
29
30/// Keeps track of equivalent values that a variables might be replaced with.
31/// Note: We don't keep track of types as we assume the usage is always correct.
32#[derive(Debug, Clone)]
33enum VarInfo {
34    /// The variable is a const value.
35    Const(ConstValue),
36    /// The variable can be replaced by another variable.
37    Var(VarUsage),
38    /// The variable is a snapshot of another variable.
39    Snapshot(Box<VarInfo>),
40    /// The variable is a struct of other variables.
41    /// `None` values represent variables that are not tracked.
42    Struct(Vec<Option<VarInfo>>),
43}
44
/// Performs constant folding on the lowered program.
///
/// Visits every block reachable from the root block (each at most once) and, for every
/// statement and block end:
/// - replaces inputs that are known aliases of other variables (`VarInfo::Var`);
/// - records what is statically known about produced variables in the context's `var_info`;
/// - replaces supported libfunc calls whose result is known by const statements, and rewrites
///   supported extern matches whose outcome is known (see
///   `ConstFoldingContext::handle_statement_call` / `handle_extern_block_end`).
///
/// Returns immediately if `skip_const_folding` is set in the optimization config, or if there
/// are no blocks.
///
/// The optimization works better when the blocks are topologically sorted.
pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
    if db.optimization_config().skip_const_folding || lowered.blocks.is_empty() {
        return;
    }
    let libfunc_info = priv_const_folding_info(db);
    // Note that we can keep the var_info across blocks because the lowering
    // is in static single assignment form.
    let mut ctx = ConstFoldingContext {
        db,
        var_info: UnorderedHashMap::default(),
        variables: &mut lowered.variables,
        libfunc_info: &libfunc_info,
    };
    // Depth-first traversal starting at the root. A block can be pushed by several
    // predecessors, so `visited` makes sure it is processed only once.
    let mut stack = vec![BlockId::root()];
    let mut visited = vec![false; lowered.blocks.len()];
    while let Some(block_id) = stack.pop() {
        if visited[block_id.0] {
            continue;
        }
        visited[block_id.0] = true;

        let block = &mut lowered.blocks[block_id];
        // Extra const statements produced while folding calls (e.g. for a second output of a
        // folded call). They are prepended to the block once all statements are processed.
        let mut additional_consts = vec![];
        for stmt in block.statements.iter_mut() {
            // Substitute known aliases first, so the handling below sees the final inputs.
            ctx.maybe_replace_inputs(stmt.inputs_mut());
            match stmt {
                Statement::Const(StatementConst { value, output }) => {
                    // Preventing the insertion of non-member consts values (such as a `Box` of a
                    // const).
                    if matches!(
                        value,
                        ConstValue::Int(..)
                            | ConstValue::Struct(..)
                            | ConstValue::Enum(..)
                            | ConstValue::NonZero(..)
                    ) {
                        ctx.var_info.insert(*output, VarInfo::Const(value.clone()));
                    }
                }
                Statement::Snapshot(stmt) => {
                    // Both outputs inherit the input's info; the snapshot's info is wrapped in
                    // `VarInfo::Snapshot` so that a later desnap can unwrap it.
                    if let Some(info) = ctx.var_info.get(&stmt.input.var_id).cloned() {
                        ctx.var_info.insert(stmt.original(), info.clone());
                        ctx.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info.into()));
                    }
                }
                Statement::Desnap(StatementDesnap { input, output }) => {
                    // Desnapping a tracked snapshot yields the wrapped info.
                    if let Some(VarInfo::Snapshot(info)) = ctx.var_info.get(&input.var_id) {
                        ctx.var_info.insert(*output, info.as_ref().clone());
                    }
                }
                Statement::Call(call_stmt) => {
                    // A call whose result is fully known is replaced by a const statement.
                    if let Some(updated_stmt) =
                        ctx.handle_statement_call(call_stmt, &mut additional_consts)
                    {
                        *stmt = Statement::Const(updated_stmt);
                    }
                }
                Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
                    // Const values of the members that are known consts.
                    let mut const_args = vec![];
                    // Per-member info. A member without info is recorded as an alias of its
                    // input variable if that variable is copyable, and as `None` otherwise.
                    let mut all_args = vec![];
                    let mut contains_info = false;
                    for input in inputs.iter() {
                        let Some(info) = ctx.var_info.get(&input.var_id) else {
                            all_args.push(
                                ctx.variables[input.var_id]
                                    .copyable
                                    .is_ok()
                                    .then_some(VarInfo::Var(*input)),
                            );
                            continue;
                        };
                        contains_info = true;
                        if let VarInfo::Const(value) = info {
                            const_args.push(value.clone());
                        }
                        all_args.push(Some(info.clone()));
                    }
                    // Every member is a const: the whole struct is a const. Otherwise keep a
                    // partial description, but only if at least one member had info.
                    if const_args.len() == inputs.len() {
                        let value = ConstValue::Struct(const_args, ctx.variables[*output].ty);
                        ctx.var_info.insert(*output, VarInfo::Const(value));
                    } else if contains_info {
                        ctx.var_info.insert(*output, VarInfo::Struct(all_args));
                    }
                }
                Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
                    if let Some(mut info) = ctx.var_info.get(&input.var_id) {
                        // The input may be a (possibly nested) snapshot of a struct: peel off
                        // the snapshot layers, then re-apply the same number of layers to the
                        // info recorded for each member.
                        let mut n_snapshot = 0;
                        while let VarInfo::Snapshot(inner) = info {
                            info = inner.as_ref();
                            n_snapshot += 1;
                        }
                        let wrap_with_snapshots = |mut info| {
                            for _ in 0..n_snapshot {
                                info = VarInfo::Snapshot(Box::new(info));
                            }
                            info
                        };
                        match info {
                            VarInfo::Const(ConstValue::Struct(member_values, _)) => {
                                for (output, value) in zip_eq(outputs, member_values.clone()) {
                                    ctx.var_info.insert(
                                        *output,
                                        wrap_with_snapshots(VarInfo::Const(value)),
                                    );
                                }
                            }
                            VarInfo::Struct(members) => {
                                // Untracked (`None`) members get no info.
                                for (output, member) in zip_eq(outputs, members.clone()) {
                                    if let Some(member) = member {
                                        ctx.var_info.insert(*output, wrap_with_snapshots(member));
                                    }
                                }
                            }
                            _ => {}
                        }
                    }
                }
                Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
                    // Constructing a variant from a const payload yields a const enum.
                    if let Some(VarInfo::Const(val)) = ctx.var_info.get(&input.var_id) {
                        let value = ConstValue::Enum(variant.clone(), val.clone().into());
                        ctx.var_info.insert(*output, VarInfo::Const(value.clone()));
                    }
                }
            }
        }
        // The additional statements are consts (they have no inputs), so placing them at the
        // start of the block is valid.
        block.statements.splice(0..0, additional_consts.into_iter().map(Statement::Const));

        match &mut block.end {
            FlatBlockEnd::Goto(block_id, remappings) => {
                stack.push(*block_id);
                for (_, v) in remappings.iter_mut() {
                    ctx.maybe_replace_input(v);
                }
            }
            FlatBlockEnd::Match { info } => {
                // All arms are scheduled before the match may be rewritten below, so arms that
                // become unreachable are still visited.
                stack.extend(info.arms().iter().map(|arm| arm.block_id));
                ctx.maybe_replace_inputs(info.inputs_mut());
                match info {
                    MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) => {
                        // The matched enum is a known const: record the payload as the const
                        // value of the matching arm's variable. The match itself is kept.
                        if let Some(VarInfo::Const(ConstValue::Enum(variant, value))) =
                            ctx.var_info.get(&input.var_id)
                        {
                            let arm = &arms[variant.idx];
                            ctx.var_info
                                .insert(arm.var_ids[0], VarInfo::Const(value.as_ref().clone()));
                        }
                    }
                    MatchInfo::Extern(info) => {
                        // The extern match may be replaced by a new block end; an optional const
                        // statement (defining the taken arm's output) is appended to the block.
                        if let Some((extra_stmt, updated_end)) = ctx.handle_extern_block_end(info) {
                            if let Some(stmt) = extra_stmt {
                                block.statements.push(Statement::Const(stmt));
                            }
                            block.end = updated_end;
                        }
                    }
                    MatchInfo::Value(..) => {}
                }
            }
            FlatBlockEnd::Return(ref mut inputs, _) => ctx.maybe_replace_inputs(inputs),
            // NOTE(review): presumably these ends cannot appear at this lowering stage — confirm.
            FlatBlockEnd::Panic(_) | FlatBlockEnd::NotSet => unreachable!(),
        }
    }
}
210
211struct ConstFoldingContext<'a> {
212    /// The used database.
213    db: &'a dyn LoweringGroup,
214    /// The variables arena, mostly used to get the type of variables.
215    variables: &'a mut Arena<Variable>,
216    /// The accumulated information about the const values of variables.
217    var_info: UnorderedHashMap<VariableId, VarInfo>,
218    /// The libfunc information.
219    libfunc_info: &'a ConstFoldingLibfuncInfo,
220}
221
222impl ConstFoldingContext<'_> {
223    /// Handles a statement call.
224    ///
225    /// Returns None if no additional changes are required.
226    /// If changes are required, returns an updated const-statement (to override the current
227    /// statement).
228    /// May add an additional const to `additional_consts` if just replacing the current statement
229    /// is not enough.
230    fn handle_statement_call(
231        &mut self,
232        stmt: &mut StatementCall,
233        additional_consts: &mut Vec<StatementConst>,
234    ) -> Option<StatementConst> {
235        let (id, _generic_args) = stmt.function.get_extern(self.db)?;
236        if id == self.felt_sub {
237            // (a - 0) can be replaced by a.
238            let val = self.as_int(stmt.inputs[1].var_id)?;
239            if val.is_zero() {
240                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
241            }
242            None
243        } else if self.wide_mul_fns.contains(&id) {
244            let lhs = self.as_int_ex(stmt.inputs[0].var_id);
245            let rhs = self.as_int(stmt.inputs[1].var_id);
246            let output = stmt.outputs[0];
247            if lhs.map(|(v, _)| v.is_zero()).unwrap_or_default()
248                || rhs.map(Zero::is_zero).unwrap_or_default()
249            {
250                return Some(self.propagate_zero_and_get_statement(output));
251            }
252            let (lhs, nz_ty) = lhs?;
253            Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0], nz_ty))
254        } else if id == self.bounded_int_add || id == self.bounded_int_sub {
255            let lhs = self.as_int(stmt.inputs[0].var_id)?;
256            let rhs = self.as_int(stmt.inputs[1].var_id)?;
257            let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
258            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0], false))
259        } else if self.div_rem_fns.contains(&id) {
260            let lhs = self.as_int(stmt.inputs[0].var_id);
261            if lhs.map(Zero::is_zero).unwrap_or_default() {
262                additional_consts.push(self.propagate_zero_and_get_statement(stmt.outputs[1]));
263                return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
264            }
265            let rhs = self.as_int(stmt.inputs[1].var_id)?;
266            let (q, r) = lhs?.div_rem(rhs);
267            let q_output = stmt.outputs[0];
268            let q_value = ConstValue::Int(q, self.variables[q_output].ty);
269            self.var_info.insert(q_output, VarInfo::Const(q_value.clone()));
270            let r_output = stmt.outputs[1];
271            let r_value = ConstValue::Int(r, self.variables[r_output].ty);
272            self.var_info.insert(r_output, VarInfo::Const(r_value.clone()));
273            additional_consts.push(StatementConst { value: r_value, output: r_output });
274            Some(StatementConst { value: q_value, output: q_output })
275        } else if id == self.storage_base_address_from_felt252 {
276            let input_var = stmt.inputs[0].var_id;
277            if let Some(ConstValue::Int(val, ty)) = self.as_const(input_var) {
278                stmt.inputs.clear();
279                stmt.function = ModuleHelper { db: self.db, id: self.storage_access_module }
280                    .function_id("storage_base_address_const", vec![GenericArgumentId::Constant(
281                        ConstValue::Int(val.clone(), *ty).intern(self.db),
282                    )]);
283            }
284            None
285        } else if id == self.into_box {
286            let const_value = match self.var_info.get(&stmt.inputs[0].var_id)? {
287                VarInfo::Const(val) => val,
288                VarInfo::Snapshot(info) => try_extract_matches!(info.as_ref(), VarInfo::Const)?,
289                _ => return None,
290            };
291            let value = ConstValue::Boxed(const_value.clone().into());
292            // Not inserting the value into the `var_info` map because the
293            // resulting box isn't an actual const at the Sierra level.
294            Some(StatementConst { value, output: stmt.outputs[0] })
295        } else if id == self.upcast {
296            let int_value = self.as_int(stmt.inputs[0].var_id)?;
297            let output = stmt.outputs[0];
298            let value = ConstValue::Int(int_value.clone(), self.variables[output].ty);
299            self.var_info.insert(output, VarInfo::Const(value.clone()));
300            Some(StatementConst { value, output })
301        } else {
302            None
303        }
304    }
305
306    /// Adds `value` as a const to `var_info` and return a const statement for it.
307    fn propagate_const_and_get_statement(
308        &mut self,
309        value: BigInt,
310        output: VariableId,
311        nz_ty: bool,
312    ) -> StatementConst {
313        let mut value = ConstValue::Int(value, self.variables[output].ty);
314        if nz_ty {
315            value = ConstValue::NonZero(Box::new(value));
316        }
317        self.var_info.insert(output, VarInfo::Const(value.clone()));
318        StatementConst { value, output }
319    }
320
321    /// Adds 0 const to `var_info` and return a const statement for it.
322    fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> StatementConst {
323        self.propagate_const_and_get_statement(BigInt::zero(), output, false)
324    }
325
326    /// Handles the end of an extern block.
327    /// Returns None if no additional changes are required.
328    /// If changes are required, returns a possible additional const-statement to the block, as well
329    /// as an updated block end.
330    fn handle_extern_block_end(
331        &mut self,
332        info: &mut MatchExternInfo,
333    ) -> Option<(Option<StatementConst>, FlatBlockEnd)> {
334        let (id, generic_args) = info.function.get_extern(self.db)?;
335        if self.nz_fns.contains(&id) {
336            let val = self.as_const(info.inputs[0].var_id)?;
337            let is_zero = match val {
338                ConstValue::Int(v, _) => v.is_zero(),
339                ConstValue::Struct(s, _) => s.iter().all(|v| {
340                    v.clone().into_int().expect("Expected ConstValue::Int for size").is_zero()
341                }),
342                _ => unreachable!(),
343            };
344            Some(if is_zero {
345                (None, FlatBlockEnd::Goto(info.arms[0].block_id, Default::default()))
346            } else {
347                let arm = &info.arms[1];
348                let nz_var = arm.var_ids[0];
349                let nz_val = ConstValue::NonZero(Box::new(val.clone()));
350                self.var_info.insert(nz_var, VarInfo::Const(nz_val.clone()));
351                (
352                    Some(StatementConst { value: nz_val, output: nz_var }),
353                    FlatBlockEnd::Goto(arm.block_id, Default::default()),
354                )
355            })
356        } else if self.eq_fns.contains(&id) {
357            let lhs = self.as_int(info.inputs[0].var_id);
358            let rhs = self.as_int(info.inputs[1].var_id);
359            if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
360                || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
361            {
362                let db = self.db.upcast();
363                let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
364                let var = &self.variables[nz_input.var_id].clone();
365                let function = self.type_value_ranges.get(&var.ty)?.is_zero;
366                let unused_nz_var = Variable::new(
367                    self.db,
368                    ImplLookupContext::default(),
369                    corelib::core_nonzero_ty(db, var.ty),
370                    var.location,
371                );
372                let unused_nz_var = self.variables.alloc(unused_nz_var);
373                return Some((None, FlatBlockEnd::Match {
374                    info: MatchInfo::Extern(MatchExternInfo {
375                        function,
376                        inputs: vec![nz_input],
377                        arms: vec![
378                            MatchArm {
379                                arm_selector: MatchArmSelector::VariantId(
380                                    corelib::jump_nz_zero_variant(db, var.ty),
381                                ),
382                                block_id: info.arms[1].block_id,
383                                var_ids: vec![],
384                            },
385                            MatchArm {
386                                arm_selector: MatchArmSelector::VariantId(
387                                    corelib::jump_nz_nonzero_variant(db, var.ty),
388                                ),
389                                block_id: info.arms[0].block_id,
390                                var_ids: vec![unused_nz_var],
391                            },
392                        ],
393                        location: info.location,
394                    }),
395                }));
396            }
397            Some((
398                None,
399                FlatBlockEnd::Goto(
400                    info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
401                    Default::default(),
402                ),
403            ))
404        } else if self.uadd_fns.contains(&id)
405            || self.usub_fns.contains(&id)
406            || self.diff_fns.contains(&id)
407            || self.iadd_fns.contains(&id)
408            || self.isub_fns.contains(&id)
409        {
410            let rhs = self.as_int(info.inputs[1].var_id);
411            if rhs.map(Zero::is_zero).unwrap_or_default() && !self.diff_fns.contains(&id) {
412                let arm = &info.arms[0];
413                self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]));
414                return Some((None, FlatBlockEnd::Goto(arm.block_id, Default::default())));
415            }
416            let lhs = self.as_int(info.inputs[0].var_id);
417            let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
418                if lhs.map(Zero::is_zero).unwrap_or_default() {
419                    let arm = &info.arms[0];
420                    self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]));
421                    return Some((None, FlatBlockEnd::Goto(arm.block_id, Default::default())));
422                }
423                lhs? + rhs?
424            } else {
425                lhs? - rhs?
426            };
427            let ty = self.variables[info.arms[0].var_ids[0]].ty;
428            let range = self.type_value_ranges.get(&ty)?;
429            let (arm_index, value) = match range.normalized(value) {
430                NormalizedResult::InRange(value) => (0, value),
431                NormalizedResult::Under(value) => (1, value),
432                NormalizedResult::Over(value) => (
433                    if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) { 2 } else { 1 },
434                    value,
435                ),
436            };
437            let arm = &info.arms[arm_index];
438            let actual_output = arm.var_ids[0];
439            let value = ConstValue::Int(value, ty);
440            self.var_info.insert(actual_output, VarInfo::Const(value.clone()));
441            Some((
442                Some(StatementConst { value, output: actual_output }),
443                FlatBlockEnd::Goto(arm.block_id, Default::default()),
444            ))
445        } else if id == self.downcast {
446            let input_var = info.inputs[0].var_id;
447            let value = self.as_int(input_var)?;
448            let success_output = info.arms[0].var_ids[0];
449            let ty = self.variables[success_output].ty;
450            let range = self.type_value_ranges.get(&ty)?;
451            Some(if let NormalizedResult::InRange(value) = range.normalized(value.clone()) {
452                let value = ConstValue::Int(value, ty);
453                self.var_info.insert(success_output, VarInfo::Const(value.clone()));
454                (
455                    Some(StatementConst { value, output: success_output }),
456                    FlatBlockEnd::Goto(info.arms[0].block_id, Default::default()),
457                )
458            } else {
459                (None, FlatBlockEnd::Goto(info.arms[1].block_id, Default::default()))
460            })
461        } else if id == self.bounded_int_constrain {
462            let input_var = info.inputs[0].var_id;
463            let (value, nz_ty) = self.as_int_ex(input_var)?;
464            let generic_arg = generic_args[1];
465            let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
466                .lookup_intern(self.db)
467                .into_int()
468                .unwrap();
469            let arm_idx = if value < &constrain_value { 0 } else { 1 };
470            let output = info.arms[arm_idx].var_ids[0];
471            Some((
472                Some(self.propagate_const_and_get_statement(value.clone(), output, nz_ty)),
473                FlatBlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()),
474            ))
475        } else if id == self.array_get {
476            if self.as_int(info.inputs[1].var_id)?.is_zero() {
477                if let [success, failure] = info.arms.as_mut_slice() {
478                    let arr = info.inputs[0].var_id;
479                    let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
480                    let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
481                    info.inputs.truncate(1);
482                    info.function = ModuleHelper { db: self.db, id: self.array_module }
483                        .function_id("array_snapshot_pop_front", generic_args);
484                    success.var_ids.insert(0, unused_arr_output0);
485                    failure.var_ids.insert(0, unused_arr_output1);
486                }
487            }
488            None
489        } else {
490            None
491        }
492    }
493
494    /// Returns the const value of a variable if it exists.
495    fn as_const(&self, var_id: VariableId) -> Option<&ConstValue> {
496        try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const)
497    }
498
499    /// Return the const value as an int if it exists and is an integer, additionally, if it is of a
500    /// non-zero type.
501    fn as_int_ex(&self, var_id: VariableId) -> Option<(&BigInt, bool)> {
502        match self.as_const(var_id)? {
503            ConstValue::Int(value, _) => Some((value, false)),
504            ConstValue::NonZero(const_value) => {
505                if let ConstValue::Int(value, _) = const_value.as_ref() {
506                    Some((value, true))
507                } else {
508                    None
509                }
510            }
511            _ => None,
512        }
513    }
514
515    /// Return the const value as a int if it exists and is an integer.
516    fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
517        Some(self.as_int_ex(var_id)?.0)
518    }
519
520    /// Replaces the inputs in place if they are in the var_info map.
521    fn maybe_replace_inputs(&mut self, inputs: &mut [VarUsage]) {
522        for input in inputs {
523            self.maybe_replace_input(input);
524        }
525    }
526
527    /// Replaces the input in place if it is in the var_info map.
528    fn maybe_replace_input(&mut self, input: &mut VarUsage) {
529        if let Some(VarInfo::Var(new_var)) = self.var_info.get(&input.var_id) {
530            *input = *new_var;
531        }
532    }
533}
534
535/// Query implementation of [LoweringGroup::priv_const_folding_info].
536pub fn priv_const_folding_info(
537    db: &dyn LoweringGroup,
538) -> Arc<crate::optimizations::const_folding::ConstFoldingLibfuncInfo> {
539    Arc::new(ConstFoldingLibfuncInfo::new(db))
540}
541
542/// Helper for getting functions in the corelib.
543struct ModuleHelper<'a> {
544    /// The db.
545    db: &'a dyn LoweringGroup,
546    /// The current module id.
547    id: ModuleId,
548}
549impl<'a> ModuleHelper<'a> {
550    /// Returns a helper for the core module.
551    fn core(db: &'a dyn LoweringGroup) -> Self {
552        Self { db, id: corelib::core_module(db.upcast()) }
553    }
554    /// Returns a helper for a submodule named `name` of the current module.
555    fn submodule(&self, name: &str) -> Self {
556        let id = corelib::get_submodule(self.db.upcast(), self.id, name).unwrap_or_else(|| {
557            panic!("`{name}` missing in `{}`.", self.id.full_path(self.db.upcast()))
558        });
559        Self { db: self.db, id }
560    }
561    /// Returns the id of an extern function named `name` in the current module.
562    fn extern_function_id(&self, name: impl Into<SmolStr>) -> ExternFunctionId {
563        let name = name.into();
564        let Ok(Some(ModuleItemId::ExternFunction(id))) =
565            self.db.module_item_by_name(self.id, name.clone())
566        else {
567            panic!("`{}` not found in `{}`.", name, self.id.full_path(self.db.upcast()));
568        };
569        id
570    }
571    /// Returns the id of a function named `name` in the current module, with the given
572    /// `generic_args`.
573    fn function_id(
574        &self,
575        name: impl Into<SmolStr>,
576        generic_args: Vec<GenericArgumentId>,
577    ) -> FunctionId {
578        FunctionLongId::Semantic(corelib::get_function_id(
579            self.db.upcast(),
580            self.id,
581            name.into(),
582            generic_args,
583        ))
584        .intern(self.db)
585    }
586}
587
/// Holds static information about libfuncs required for the optimization.
///
/// Contains the ids of the extern functions the pass recognizes (compared against the callee of
/// each call / extern match), the modules needed to build replacement function ids, and
/// per-type value ranges.
#[derive(Debug, PartialEq, Eq)]
pub struct ConstFoldingLibfuncInfo {
    /// The `felt252_sub` libfunc.
    felt_sub: ExternFunctionId,
    /// The `into_box` libfunc.
    into_box: ExternFunctionId,
    /// The `upcast` libfunc.
    upcast: ExternFunctionId,
    /// The `downcast` libfunc.
    downcast: ExternFunctionId,
    /// The set of functions that check if a number is zero.
    nz_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions that check if numbers are equal.
    eq_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to add unsigned ints.
    uadd_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to subtract unsigned ints.
    usub_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to get the difference of signed ints.
    diff_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to add signed ints.
    iadd_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to subtract signed ints.
    isub_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to multiply integers.
    wide_mul_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to divide and get the remainder of integers.
    div_rem_fns: OrderedHashSet<ExternFunctionId>,
    /// The `bounded_int_add` libfunc.
    bounded_int_add: ExternFunctionId,
    /// The `bounded_int_sub` libfunc.
    bounded_int_sub: ExternFunctionId,
    /// The `bounded_int_constrain` libfunc.
    bounded_int_constrain: ExternFunctionId,
    /// The array module (where `array_snapshot_pop_front` is looked up).
    array_module: ModuleId,
    /// The `array_get` libfunc.
    array_get: ExternFunctionId,
    /// The storage access module (where `storage_base_address_const` is looked up).
    storage_access_module: ModuleId,
    /// The `storage_base_address_from_felt252` libfunc.
    storage_base_address_from_felt252: ExternFunctionId,
    /// Type ranges, keyed by type: used to normalize folded arithmetic results into the type's
    /// range and to find the type's `is_zero` function.
    type_value_ranges: OrderedHashMap<TypeId, TypeInfo>,
}
634impl ConstFoldingLibfuncInfo {
635    fn new(db: &dyn LoweringGroup) -> Self {
636        let core = ModuleHelper::core(db);
637        let felt_sub = core.extern_function_id("felt252_sub");
638        let box_module = core.submodule("box");
639        let into_box = box_module.extern_function_id("into_box");
640        let integer_module = core.submodule("integer");
641        let bounded_int_module = core.submodule("internal").submodule("bounded_int");
642        let upcast = integer_module.extern_function_id("upcast");
643        let downcast = integer_module.extern_function_id("downcast");
644        let array_module = core.submodule("array");
645        let array_get = array_module.extern_function_id("array_get");
646        let starknet_module = core.submodule("starknet");
647        let storage_access_module = starknet_module.submodule("storage_access");
648        let storage_base_address_from_felt252 =
649            storage_access_module.extern_function_id("storage_base_address_from_felt252");
650        let nz_fns = OrderedHashSet::<_>::from_iter(chain!(
651            [
652                core.extern_function_id("felt252_is_zero"),
653                bounded_int_module.extern_function_id("bounded_int_is_zero")
654            ],
655            ["u8", "u16", "u32", "u64", "u128", "u256", "i8", "i16", "i32", "i64", "i128"]
656                .map(|ty| integer_module.extern_function_id(format!("{ty}_is_zero")))
657        ));
658        let utypes = ["u8", "u16", "u32", "u64", "u128"];
659        let itypes = ["i8", "i16", "i32", "i64", "i128"];
660        let eq_fns = OrderedHashSet::<_>::from_iter(
661            chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(format!("{ty}_eq"))),
662        );
663        let uadd_fns = OrderedHashSet::<_>::from_iter(
664            utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_add"))),
665        );
666        let usub_fns = OrderedHashSet::<_>::from_iter(
667            utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_sub"))),
668        );
669        let diff_fns = OrderedHashSet::<_>::from_iter(
670            itypes.map(|ty| integer_module.extern_function_id(format!("{ty}_diff"))),
671        );
672        let iadd_fns = OrderedHashSet::<_>::from_iter(
673            itypes
674                .map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_add_impl"))),
675        );
676        let isub_fns = OrderedHashSet::<_>::from_iter(
677            itypes
678                .map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_sub_impl"))),
679        );
680        let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
681            [bounded_int_module.extern_function_id("bounded_int_mul")],
682            ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
683                .map(|ty| integer_module.extern_function_id(format!("{ty}_wide_mul"))),
684        ));
685        let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
686            [bounded_int_module.extern_function_id("bounded_int_div_rem")],
687            utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_safe_divmod"))),
688        ));
689        let bounded_int_add = bounded_int_module.extern_function_id("bounded_int_add");
690        let bounded_int_sub = bounded_int_module.extern_function_id("bounded_int_sub");
691        let bounded_int_constrain = bounded_int_module.extern_function_id("bounded_int_constrain");
692        let type_value_ranges = OrderedHashMap::from_iter(
693            [
694                ("u8", BigInt::ZERO, u8::MAX.into()),
695                ("u16", BigInt::ZERO, u16::MAX.into()),
696                ("u32", BigInt::ZERO, u32::MAX.into()),
697                ("u64", BigInt::ZERO, u64::MAX.into()),
698                ("u128", BigInt::ZERO, u128::MAX.into()),
699                ("u256", BigInt::ZERO, BigInt::from(1) << 256),
700                ("i8", i8::MIN.into(), i8::MAX.into()),
701                ("i16", i16::MIN.into(), i16::MAX.into()),
702                ("i32", i32::MIN.into(), i32::MAX.into()),
703                ("i64", i64::MIN.into(), i64::MAX.into()),
704                ("i128", i128::MIN.into(), i128::MAX.into()),
705            ]
706            .map(|(ty, min, max): (&str, BigInt, BigInt)| {
707                let info = TypeInfo {
708                    min,
709                    max,
710                    is_zero: integer_module.function_id(format!("{ty}_is_zero"), vec![]),
711                };
712                (corelib::get_core_ty_by_name(db.upcast(), ty.into(), vec![]), info)
713            }),
714        );
715        Self {
716            felt_sub,
717            into_box,
718            upcast,
719            downcast,
720            nz_fns,
721            eq_fns,
722            uadd_fns,
723            usub_fns,
724            diff_fns,
725            iadd_fns,
726            isub_fns,
727            wide_mul_fns,
728            div_rem_fns,
729            bounded_int_add,
730            bounded_int_sub,
731            bounded_int_constrain,
732            array_module: array_module.id,
733            array_get,
734            storage_access_module: storage_access_module.id,
735            storage_base_address_from_felt252,
736            type_value_ranges,
737        }
738    }
739}
740
741impl std::ops::Deref for ConstFoldingContext<'_> {
742    type Target = ConstFoldingLibfuncInfo;
743    fn deref(&self) -> &ConstFoldingLibfuncInfo {
744        self.libfunc_info
745    }
746}
747
748/// The information of a type required for const foldings.
749#[derive(Debug, PartialEq, Eq)]
750struct TypeInfo {
751    /// The minimum value of the type.
752    min: BigInt,
753    /// The maximum value of the type.
754    max: BigInt,
755    /// The function to check if the value is zero for the type.
756    is_zero: FunctionId,
757}
758impl TypeInfo {
759    /// Normalizes the value to the range.
760    /// Assumes the value is within size of range of the range.
761    fn normalized(&self, value: BigInt) -> NormalizedResult {
762        if value < self.min {
763            NormalizedResult::Under(value - &self.min + &self.max + 1)
764        } else if value > self.max {
765            NormalizedResult::Over(value + &self.min - &self.max - 1)
766        } else {
767            NormalizedResult::InRange(value)
768        }
769    }
770}
771
772/// The result of normalizing a value to a range.
773enum NormalizedResult {
774    /// The original value is in the range, carries the value, or an equivalent value.
775    InRange(BigInt),
776    /// The original value is larger than range max, carries the normalized value.
777    Over(BigInt),
778    /// The original value is smaller than range min, carries the normalized value.
779    Under(BigInt),
780}