cairo_lang_lowering/optimizations/
remappings.rs

//! Optimization that removes unnecessary variable remappings.
//!
//! At each convergence point, one or more branches arrive with remappings of variables.
//! A destination variable `dest` introduced by the remappings must be remapped at every branch
//! `b_i`, by mapping some source variable `src_i -> dest`.
//! Every use of `dest` must observe the value of the `src_i` of the branch that was actually taken.
//! Consequently, the remappings to `dest` are unnecessary in either of these cases:
//! 1. No flow uses the "value" of `dest` after the convergence.
//! 2. All the `src_i` variables hold the same "value" (so `dest` can simply be replaced by it).
10
11use std::collections::{HashMap, HashSet};
12
13use itertools::Itertools;
14
15use crate::utils::{Rebuilder, RebuilderEx};
16use crate::{BlockId, FlatBlockEnd, FlatLowered, VarRemapping, VariableId};
17
18/// Visits all the reachable remappings in the function, calls `f` on each one and returns a vector
19/// indicating which blocks are reachable.
20pub(crate) fn visit_remappings<F: FnMut(&VarRemapping)>(
21    lowered: &mut FlatLowered,
22    mut f: F,
23) -> Vec<bool> {
24    let mut stack = vec![BlockId::root()];
25    let mut visited = vec![false; lowered.blocks.len()];
26    while let Some(block_id) = stack.pop() {
27        if visited[block_id.0] {
28            continue;
29        }
30        visited[block_id.0] = true;
31        match &lowered.blocks[block_id].end {
32            FlatBlockEnd::Goto(target_block_id, remapping) => {
33                stack.push(*target_block_id);
34                f(remapping)
35            }
36            FlatBlockEnd::Match { info } => {
37                stack.extend(info.arms().iter().map(|arm| arm.block_id));
38            }
39            FlatBlockEnd::Return(..) | FlatBlockEnd::Panic(_) => {}
40            FlatBlockEnd::NotSet => unreachable!(),
41        }
42    }
43
44    visited
45}
46
47/// Context for the optimize remappings optimization.
48#[derive(Default)]
49pub(crate) struct Context {
50    /// Maps a destination variable to the source variables that are remapped to it.
51    pub dest_to_srcs: HashMap<VariableId, Vec<VariableId>>,
52    /// Cache of a mapping from variable id in the old lowering to variable id in the new lowering.
53    /// This mapping is built on demand.
54    var_representatives: HashMap<VariableId, VariableId>,
55    /// The set of variables that is used by a reachable blocks.
56    variable_used: HashSet<VariableId>,
57}
58impl Context {
59    /// Find the `canonical` variable that `var` maps to and mark it as used.
60    pub fn set_used(&mut self, var: VariableId) {
61        let var = self.map_var_id(var);
62        if self.variable_used.insert(var) {
63            for src in self.dest_to_srcs.get(&var).cloned().unwrap_or_default() {
64                self.set_used(src);
65            }
66        }
67    }
68}
69
70impl Rebuilder for Context {
71    fn map_var_id(&mut self, var: VariableId) -> VariableId {
72        if let Some(res) = self.var_representatives.get(&var) {
73            *res
74        } else {
75            let srcs = self.dest_to_srcs.get(&var).cloned().unwrap_or_default();
76            let src_representatives: HashSet<_> =
77                srcs.iter().map(|src| self.map_var_id(*src)).collect();
78            let src_representatives = src_representatives.into_iter().collect_vec();
79            let new_var =
80                if let [single_var] = &src_representatives[..] { *single_var } else { var };
81            self.var_representatives.insert(var, new_var);
82            new_var
83        }
84    }
85
86    fn transform_remapping(&mut self, remapping: &mut VarRemapping) {
87        let mut new_remapping = VarRemapping::default();
88        for (dst, src) in remapping.iter() {
89            if dst != &src.var_id && self.variable_used.contains(dst) {
90                new_remapping.insert(*dst, *src);
91            }
92        }
93        *remapping = new_remapping;
94    }
95}
96
97pub fn optimize_remappings(lowered: &mut FlatLowered) {
98    if lowered.blocks.has_root().is_err() {
99        return;
100    }
101
102    // Find condition 1 (see module doc).
103    let mut ctx = Context::default();
104    let reachable = visit_remappings(lowered, |remapping| {
105        for (dst, src) in remapping.iter() {
106            ctx.dest_to_srcs.entry(*dst).or_default().push(src.var_id);
107        }
108    });
109
110    // Find condition 2 (see module doc).
111    for ((_, block), is_reachable) in lowered.blocks.iter().zip(reachable) {
112        if !is_reachable {
113            continue;
114        }
115
116        for stmt in &block.statements {
117            for var_usage in stmt.inputs() {
118                ctx.set_used(var_usage.var_id);
119            }
120        }
121        match &block.end {
122            FlatBlockEnd::Return(returns, _location) => {
123                for var_usage in returns {
124                    ctx.set_used(var_usage.var_id);
125                }
126            }
127            FlatBlockEnd::Panic(data) => {
128                ctx.set_used(data.var_id);
129            }
130            FlatBlockEnd::Goto(_, _) => {}
131            FlatBlockEnd::Match { info } => {
132                for var_usage in info.inputs() {
133                    ctx.set_used(var_usage.var_id);
134                }
135            }
136            FlatBlockEnd::NotSet => unreachable!(),
137        }
138    }
139
140    // Rebuild the blocks without unnecessary remappings.
141    for block in lowered.blocks.iter_mut() {
142        *block = ctx.rebuild_block(block);
143    }
144}