cairo_lang_lowering/optimizations/
reorder_statements.rs

1#[cfg(test)]
2#[path = "reorder_statements_test.rs"]
3mod test;
4
5use std::cmp::Reverse;
6
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::{Entry, UnorderedHashMap};
9use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
10use itertools::{Itertools, zip_eq};
11
12use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
13use crate::db::LoweringGroup;
14use crate::ids::FunctionId;
15use crate::{
16    BlockId, FlatLowered, MatchInfo, Statement, StatementCall, VarRemapping, VarUsage, VariableId,
17};
18
19/// Reorder the statements in the lowering in order to move variable definitions closer to their
20/// usage. Statement with no side effects and unused outputs are removed.
21///
22/// The list of call statements that can be moved is currently hardcoded.
23///
24/// Removing unnecessary remapping before this optimization will result in better code.
25pub fn reorder_statements(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
26    if lowered.blocks.is_empty() {
27        return;
28    }
29    let ctx = ReorderStatementsContext {
30        lowered: &*lowered,
31        moveable_functions: &db.priv_movable_function_ids(),
32        statement_to_move: vec![],
33    };
34    let mut analysis = BackAnalysis::new(lowered, ctx);
35    analysis.get_root_info();
36    let ctx = analysis.analyzer;
37
38    let mut changes_by_block =
39        OrderedHashMap::<BlockId, Vec<(usize, Option<Statement>)>>::default();
40
41    for (src, opt_dst) in ctx.statement_to_move.into_iter() {
42        changes_by_block.entry(src.0).or_insert_with(Vec::new).push((src.1, None));
43
44        if let Some(dst) = opt_dst {
45            let statement = lowered.blocks[src.0].statements[src.1].clone();
46            changes_by_block.entry(dst.0).or_insert_with(Vec::new).push((dst.1, Some(statement)));
47        }
48    }
49
50    for (block_id, block_changes) in changes_by_block.into_iter() {
51        let statements = &mut lowered.blocks[block_id].statements;
52        let block_len = statements.len();
53
54        // Apply block changes in reverse order to prevent a change from invalidating the
55        // indices of the other changes.
56        for (index, opt_statement) in
57            block_changes.into_iter().sorted_by_key(|(index, _)| Reverse(*index))
58        {
59            match opt_statement {
60                Some(stmt) => {
61                    // If index > block_len, we insert the statement at the end of the block.
62                    statements.insert(std::cmp::min(index, block_len), stmt)
63                }
64                None => {
65                    statements.remove(index);
66                }
67            }
68        }
69    }
70}
71
72#[derive(Clone, Default)]
73pub struct ReorderStatementsInfo {
74    // A mapping from var_id to a candidate location that it can be moved to.
75    // If the variable is used in multiple match arms we define the next use to be
76    // the match.
77
78    // Note that StatementLocation.0 might >= block.len() and it means that
79    // the variable should be inserted at the end of the block.
80    next_use: UnorderedHashMap<VariableId, StatementLocation>,
81}
82
83pub struct ReorderStatementsContext<'a> {
84    lowered: &'a FlatLowered,
85    // A list of function that can be moved.
86    moveable_functions: &'a UnorderedHashSet<FunctionId>,
87    statement_to_move: Vec<(StatementLocation, Option<StatementLocation>)>,
88}
89impl ReorderStatementsContext<'_> {
90    fn call_can_be_moved(&mut self, stmt: &StatementCall) -> bool {
91        self.moveable_functions.contains(&stmt.function)
92    }
93}
94impl Analyzer<'_> for ReorderStatementsContext<'_> {
95    type Info = ReorderStatementsInfo;
96
97    fn visit_stmt(
98        &mut self,
99        info: &mut Self::Info,
100        statement_location: StatementLocation,
101        stmt: &Statement,
102    ) {
103        let mut immovable = matches!(stmt, Statement::Call(stmt) if !self.call_can_be_moved(stmt));
104        let mut optional_target_location = None;
105        for var_to_move in stmt.outputs() {
106            let Some((block_id, index)) = info.next_use.remove(var_to_move) else { continue };
107            if let Some((target_block_id, target_index)) = &mut optional_target_location {
108                *target_index = std::cmp::min(*target_index, index);
109                // If the output is used in multiple places we can't move their creation point.
110                immovable |= target_block_id != &block_id;
111            } else {
112                optional_target_location = Some((block_id, index));
113            }
114        }
115        if immovable {
116            for var_usage in stmt.inputs() {
117                info.next_use.insert(var_usage.var_id, statement_location);
118            }
119            return;
120        }
121
122        if let Some(target_location) = optional_target_location {
123            // If the statement is not removed add demand for its inputs.
124            for var_usage in stmt.inputs() {
125                match info.next_use.entry(var_usage.var_id) {
126                    Entry::Occupied(mut e) => {
127                        // Since we don't know where `e.get()` and `target_location` converge
128                        // we use `statement_location` as a conservative estimate.
129                        &e.insert(statement_location)
130                    }
131                    Entry::Vacant(e) => e.insert(target_location),
132                };
133            }
134
135            self.statement_to_move.push((statement_location, Some(target_location)))
136        } else if stmt.inputs().iter().all(|v| self.lowered.variables[v.var_id].droppable.is_ok()) {
137            // If a movable statement is unused, and all its inputs are droppable removing it is
138            // valid.
139            self.statement_to_move.push((statement_location, None))
140        } else {
141            // Statement is unused but can't be removed.
142            for var_usage in stmt.inputs() {
143                info.next_use.insert(var_usage.var_id, statement_location);
144            }
145        }
146    }
147
148    fn visit_goto(
149        &mut self,
150        info: &mut Self::Info,
151        statement_location: StatementLocation,
152        _target_block_id: BlockId,
153        remapping: &VarRemapping,
154    ) {
155        for VarUsage { var_id, .. } in remapping.values() {
156            info.next_use.insert(*var_id, statement_location);
157        }
158    }
159
160    fn merge_match(
161        &mut self,
162        statement_location: StatementLocation,
163        match_info: &MatchInfo,
164        infos: impl Iterator<Item = Self::Info>,
165    ) -> Self::Info {
166        let mut infos = zip_eq(infos, match_info.arms()).map(|(mut info, arm)| {
167            for var_id in &arm.var_ids {
168                info.next_use.remove(var_id);
169            }
170            info
171        });
172        let mut info = infos.next().unwrap_or_default();
173        for arm_info in infos {
174            info.next_use.merge(&arm_info.next_use, |e, _| {
175                *e.into_mut() = statement_location;
176            });
177        }
178
179        for var_usage in match_info.inputs() {
180            // Make sure we insert the match inputs after the variables that are used in the arms.
181            info.next_use
182                .insert(var_usage.var_id, (statement_location.0, statement_location.1 + 1));
183        }
184
185        info
186    }
187
188    fn info_from_return(
189        &mut self,
190        statement_location: StatementLocation,
191        vars: &[VarUsage],
192    ) -> Self::Info {
193        let mut info = Self::Info::default();
194        for var_usage in vars {
195            info.next_use.insert(var_usage.var_id, statement_location);
196        }
197        info
198    }
199}