cairo_lang_lowering/optimizations/
reorder_statements.rs

1#[cfg(test)]
2#[path = "reorder_statements_test.rs"]
3mod test;
4
5use std::cmp::Reverse;
6
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::{Entry, UnorderedHashMap};
9use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
10use itertools::{Itertools, zip_eq};
11
12use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
13use crate::db::LoweringGroup;
14use crate::ids::FunctionId;
15use crate::{
16    BlockId, FlatLowered, MatchInfo, Statement, StatementCall, VarRemapping, VarUsage, VariableId,
17};
18
19/// Reorder the statements in the lowering in order to move variable definitions closer to their
20/// usage. Statement with no side effects and unused outputs are removed.
21///
22/// The list of call statements that can be moved is currently hardcoded.
23///
24/// Removing unnecessary remapping before this optimization will result in better code.
25pub fn reorder_statements(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
26    if lowered.blocks.is_empty() {
27        return;
28    }
29    let ctx = ReorderStatementsContext {
30        lowered: &*lowered,
31        moveable_functions: &db.priv_movable_function_ids(),
32        statement_to_move: vec![],
33    };
34    let mut analysis = BackAnalysis::new(lowered, ctx);
35    analysis.get_root_info();
36    let ctx = analysis.analyzer;
37
38    let mut changes_by_block =
39        OrderedHashMap::<BlockId, Vec<(usize, Option<Statement>)>>::default();
40
41    for (src, opt_dst) in ctx.statement_to_move.into_iter() {
42        changes_by_block.entry(src.0).or_insert_with(Vec::new).push((src.1, None));
43
44        if let Some(dst) = opt_dst {
45            let statement = lowered.blocks[src.0].statements[src.1].clone();
46            changes_by_block.entry(dst.0).or_insert_with(Vec::new).push((dst.1, Some(statement)));
47        }
48    }
49
50    for (block_id, block_changes) in changes_by_block.into_iter() {
51        let statements = &mut lowered.blocks[block_id].statements;
52
53        // Apply block changes in reverse order to prevent a change from invalidating the
54        // indices of the other changes.
55        for (index, opt_statement) in
56            block_changes.into_iter().sorted_by_key(|(index, _)| Reverse(*index))
57        {
58            match opt_statement {
59                Some(stmt) => statements.insert(index, stmt),
60                None => {
61                    statements.remove(index);
62                }
63            }
64        }
65    }
66}
67
68#[derive(Clone, Default)]
69pub struct ReorderStatementsInfo {
70    // A mapping from var_id to a candidate location that it can be moved to.
71    // If the variable is used in multiple match arms we define the next use to be
72    // the match.
73    next_use: UnorderedHashMap<VariableId, StatementLocation>,
74}
75
76pub struct ReorderStatementsContext<'a> {
77    lowered: &'a FlatLowered,
78    // A list of function that can be moved.
79    moveable_functions: &'a UnorderedHashSet<FunctionId>,
80    statement_to_move: Vec<(StatementLocation, Option<StatementLocation>)>,
81}
82impl ReorderStatementsContext<'_> {
83    fn call_can_be_moved(&mut self, stmt: &StatementCall) -> bool {
84        self.moveable_functions.contains(&stmt.function)
85    }
86}
87impl Analyzer<'_> for ReorderStatementsContext<'_> {
88    type Info = ReorderStatementsInfo;
89
90    fn visit_stmt(
91        &mut self,
92        info: &mut Self::Info,
93        statement_location: StatementLocation,
94        stmt: &Statement,
95    ) {
96        let mut immovable = matches!(stmt, Statement::Call(stmt) if !self.call_can_be_moved(stmt));
97        let mut optional_target_location = None;
98        for var_to_move in stmt.outputs() {
99            let Some((block_id, index)) = info.next_use.remove(var_to_move) else { continue };
100            if let Some((target_block_id, target_index)) = &mut optional_target_location {
101                *target_index = std::cmp::min(*target_index, index);
102                // If the output is used in multiple places we can't move their creation point.
103                immovable |= target_block_id != &block_id;
104            } else {
105                optional_target_location = Some((block_id, index));
106            }
107        }
108        if immovable {
109            for var_usage in stmt.inputs() {
110                info.next_use.insert(var_usage.var_id, statement_location);
111            }
112            return;
113        }
114
115        if let Some(target_location) = optional_target_location {
116            // If the statement is not removed add demand for its inputs.
117            for var_usage in stmt.inputs() {
118                match info.next_use.entry(var_usage.var_id) {
119                    Entry::Occupied(mut e) => {
120                        // Since we don't know where `e.get()` and `target_location` converge
121                        // we use `statement_location` as a conservative estimate.
122                        &e.insert(statement_location)
123                    }
124                    Entry::Vacant(e) => e.insert(target_location),
125                };
126            }
127
128            self.statement_to_move.push((statement_location, Some(target_location)))
129        } else if stmt.inputs().iter().all(|v| self.lowered.variables[v.var_id].droppable.is_ok()) {
130            // If a movable statement is unused, and all its inputs are droppable removing it is
131            // valid.
132            self.statement_to_move.push((statement_location, None))
133        } else {
134            // Statement is unused but can't be removed.
135            for var_usage in stmt.inputs() {
136                info.next_use.insert(var_usage.var_id, statement_location);
137            }
138        }
139    }
140
141    fn visit_goto(
142        &mut self,
143        info: &mut Self::Info,
144        statement_location: StatementLocation,
145        _target_block_id: BlockId,
146        remapping: &VarRemapping,
147    ) {
148        for VarUsage { var_id, .. } in remapping.values() {
149            info.next_use.insert(*var_id, statement_location);
150        }
151    }
152
153    fn merge_match(
154        &mut self,
155        statement_location: StatementLocation,
156        match_info: &MatchInfo,
157        infos: impl Iterator<Item = Self::Info>,
158    ) -> Self::Info {
159        let mut infos = zip_eq(infos, match_info.arms()).map(|(mut info, arm)| {
160            for var_id in &arm.var_ids {
161                info.next_use.remove(var_id);
162            }
163            info
164        });
165        let mut info = infos.next().unwrap_or_default();
166        for arm_info in infos {
167            info.next_use.merge(&arm_info.next_use, |e, _| {
168                *e.into_mut() = statement_location;
169            });
170        }
171
172        for var_usage in match_info.inputs() {
173            info.next_use.insert(var_usage.var_id, statement_location);
174        }
175
176        info
177    }
178
179    fn info_from_return(
180        &mut self,
181        statement_location: StatementLocation,
182        vars: &[VarUsage],
183    ) -> Self::Info {
184        let mut info = Self::Info::default();
185        for var_usage in vars {
186            info.next_use.insert(var_usage.var_id, statement_location);
187        }
188        info
189    }
190}