cairo_lang_lowering/optimizations/
remappings.rs1use std::collections::{HashMap, HashSet};
12
13use itertools::Itertools;
14
15use crate::utils::{Rebuilder, RebuilderEx};
16use crate::{BlockId, FlatBlockEnd, FlatLowered, VarRemapping, VariableId};
17
18pub(crate) fn visit_remappings<F: FnMut(&VarRemapping)>(
21 lowered: &mut FlatLowered,
22 mut f: F,
23) -> Vec<bool> {
24 let mut stack = vec![BlockId::root()];
25 let mut visited = vec![false; lowered.blocks.len()];
26 while let Some(block_id) = stack.pop() {
27 if visited[block_id.0] {
28 continue;
29 }
30 visited[block_id.0] = true;
31 match &lowered.blocks[block_id].end {
32 FlatBlockEnd::Goto(target_block_id, remapping) => {
33 stack.push(*target_block_id);
34 f(remapping)
35 }
36 FlatBlockEnd::Match { info } => {
37 stack.extend(info.arms().iter().map(|arm| arm.block_id));
38 }
39 FlatBlockEnd::Return(..) | FlatBlockEnd::Panic(_) => {}
40 FlatBlockEnd::NotSet => unreachable!(),
41 }
42 }
43
44 visited
45}
46
47#[derive(Default)]
49pub(crate) struct Context {
50 pub dest_to_srcs: HashMap<VariableId, Vec<VariableId>>,
52 var_representatives: HashMap<VariableId, VariableId>,
55 variable_used: HashSet<VariableId>,
57}
58impl Context {
59 pub fn set_used(&mut self, var: VariableId) {
61 let var = self.map_var_id(var);
62 if self.variable_used.insert(var) {
63 for src in self.dest_to_srcs.get(&var).cloned().unwrap_or_default() {
64 self.set_used(src);
65 }
66 }
67 }
68}
69
70impl Rebuilder for Context {
71 fn map_var_id(&mut self, var: VariableId) -> VariableId {
72 if let Some(res) = self.var_representatives.get(&var) {
73 *res
74 } else {
75 let srcs = self.dest_to_srcs.get(&var).cloned().unwrap_or_default();
76 let src_representatives: HashSet<_> =
77 srcs.iter().map(|src| self.map_var_id(*src)).collect();
78 let src_representatives = src_representatives.into_iter().collect_vec();
79 let new_var =
80 if let [single_var] = &src_representatives[..] { *single_var } else { var };
81 self.var_representatives.insert(var, new_var);
82 new_var
83 }
84 }
85
86 fn transform_remapping(&mut self, remapping: &mut VarRemapping) {
87 let mut new_remapping = VarRemapping::default();
88 for (dst, src) in remapping.iter() {
89 if dst != &src.var_id && self.variable_used.contains(dst) {
90 new_remapping.insert(*dst, *src);
91 }
92 }
93 *remapping = new_remapping;
94 }
95}
96
97pub fn optimize_remappings(lowered: &mut FlatLowered) {
98 if lowered.blocks.has_root().is_err() {
99 return;
100 }
101
102 let mut ctx = Context::default();
104 let reachable = visit_remappings(lowered, |remapping| {
105 for (dst, src) in remapping.iter() {
106 ctx.dest_to_srcs.entry(*dst).or_default().push(src.var_id);
107 }
108 });
109
110 for ((_, block), is_reachable) in lowered.blocks.iter().zip(reachable) {
112 if !is_reachable {
113 continue;
114 }
115
116 for stmt in &block.statements {
117 for var_usage in stmt.inputs() {
118 ctx.set_used(var_usage.var_id);
119 }
120 }
121 match &block.end {
122 FlatBlockEnd::Return(returns, _location) => {
123 for var_usage in returns {
124 ctx.set_used(var_usage.var_id);
125 }
126 }
127 FlatBlockEnd::Panic(data) => {
128 ctx.set_used(data.var_id);
129 }
130 FlatBlockEnd::Goto(_, _) => {}
131 FlatBlockEnd::Match { info } => {
132 for var_usage in info.inputs() {
133 ctx.set_used(var_usage.var_id);
134 }
135 }
136 FlatBlockEnd::NotSet => unreachable!(),
137 }
138 }
139
140 for block in lowered.blocks.iter_mut() {
142 *block = ctx.rebuild_block(block);
143 }
144}