1use indexmap::IndexMap;
2use rustc_hash::FxHashMap;
8use std::collections::HashSet;
9use sway_utils::mapped_stack::MappedStack;
10
11use crate::{
12 AnalysisResults, Block, BranchToWithArgs, Constant, Context, DomFronts, DomTree, Function,
13 InstOp, Instruction, IrError, LocalVar, Pass, PassMutability, PostOrder, ScopedPass, Type,
14 Value, ValueDatum, DOMINATORS_NAME, DOM_FRONTS_NAME, POSTORDER_NAME,
15};
16
17pub const MEM2REG_NAME: &str = "mem2reg";
18
19pub fn create_mem2reg_pass() -> Pass {
20 Pass {
21 name: MEM2REG_NAME,
22 descr: "Promotion of memory to SSA registers",
23 deps: vec![POSTORDER_NAME, DOMINATORS_NAME, DOM_FRONTS_NAME],
24 runner: ScopedPass::FunctionPass(PassMutability::Transform(promote_to_registers)),
25 }
26}
27
28fn get_validate_local_var(
30 context: &Context,
31 function: &Function,
32 val: &Value,
33) -> Option<(String, LocalVar)> {
34 match context.values[val.0].value {
35 ValueDatum::Instruction(Instruction {
36 op: InstOp::GetLocal(local_var),
37 ..
38 }) => {
39 let name = function.lookup_local_name(context, &local_var);
40 name.map(|name| (name.clone(), local_var))
41 }
42 _ => None,
43 }
44}
45
46fn is_promotable_type(context: &Context, ty: Type) -> bool {
47 ty.is_unit(context)
48 || ty.is_bool(context)
49 || (ty.is_uint(context) && ty.get_uint_width(context).unwrap() <= 64)
50}
51
52fn filter_usable_locals(context: &mut Context, function: &Function) -> HashSet<String> {
54 let mut locals: HashSet<String> = function
57 .locals_iter(context)
58 .filter_map(|(name, var)| {
59 let ty = var.get_inner_type(context);
60 is_promotable_type(context, ty).then_some(name.clone())
61 })
62 .collect();
63
64 for (_, inst) in function.instruction_iter(context) {
65 match context.values[inst.0].value {
66 ValueDatum::Instruction(Instruction {
67 op: InstOp::Load(_),
68 ..
69 })
70 | ValueDatum::Instruction(Instruction {
71 op: InstOp::Store { .. },
72 ..
73 }) => {
74 }
76 _ => {
77 let operands = inst.get_instruction(context).unwrap().op.get_operands();
79 for opd in operands {
80 if let Some((local, ..)) = get_validate_local_var(context, function, &opd) {
81 locals.remove(&local);
82 }
83 }
84 }
85 }
86 }
87 locals
88}
89
90pub fn compute_livein(
94 context: &mut Context,
95 function: &Function,
96 po: &PostOrder,
97 locals: &HashSet<String>,
98) -> FxHashMap<Block, HashSet<String>> {
99 let mut result = FxHashMap::<Block, HashSet<String>>::default();
100 for block in &po.po_to_block {
101 result.insert(*block, HashSet::<String>::default());
102 }
103
104 let mut changed = true;
105 while changed {
106 changed = false;
107 for block in &po.po_to_block {
108 let mut cur_live = HashSet::<String>::default();
110 for BranchToWithArgs { block: succ, .. } in block.successors(context) {
111 let succ_livein = &result[&succ];
112 cur_live.extend(succ_livein.iter().cloned());
113 }
114 for inst in block.instruction_iter(context).rev() {
116 match context.values[inst.0].value {
117 ValueDatum::Instruction(Instruction {
118 op: InstOp::Load(ptr),
119 ..
120 }) => {
121 let local_var = get_validate_local_var(context, function, &ptr);
122 match local_var {
123 Some((local, ..)) if locals.contains(&local) => {
124 cur_live.insert(local);
125 }
126 _ => {}
127 }
128 }
129 ValueDatum::Instruction(Instruction {
130 op: InstOp::Store { dst_val_ptr, .. },
131 ..
132 }) => {
133 let local_var = get_validate_local_var(context, function, &dst_val_ptr);
134 match local_var {
135 Some((local, _)) if locals.contains(&local) => {
136 cur_live.remove(&local);
137 }
138 _ => (),
139 }
140 }
141 _ => (),
142 }
143 }
144 if result[block] != cur_live {
145 result.get_mut(block).unwrap().extend(cur_live);
147 changed = true;
148 }
149 }
150 }
151 result
152}
153
154fn promote_globals(context: &mut Context, function: &Function) -> Result<bool, IrError> {
157 let mut replacements = FxHashMap::<Value, Constant>::default();
158 for (_, inst) in function.instruction_iter(context) {
159 if let ValueDatum::Instruction(Instruction {
160 op: InstOp::Load(ptr),
161 ..
162 }) = context.values[inst.0].value
163 {
164 if let ValueDatum::Instruction(Instruction {
165 op: InstOp::GetGlobal(global_var),
166 ..
167 }) = context.values[ptr.0].value
168 {
169 if !global_var.is_mutable(context)
170 && is_promotable_type(context, global_var.get_inner_type(context))
171 {
172 let constant = *global_var
173 .get_initializer(context)
174 .expect("`global_var` is not mutable so it must be initialized");
175 replacements.insert(inst, constant);
176 }
177 }
178 }
179 }
180
181 if replacements.is_empty() {
182 return Ok(false);
183 }
184
185 let replacements = replacements
186 .into_iter()
187 .map(|(k, v)| (k, Value::new_constant(context, v)))
188 .collect::<FxHashMap<_, _>>();
189
190 function.replace_values(context, &replacements, None);
191
192 Ok(true)
193}
194
195pub fn promote_to_registers(
197 context: &mut Context,
198 analyses: &AnalysisResults,
199 function: Function,
200) -> Result<bool, IrError> {
201 let mut modified = false;
202 modified |= promote_globals(context, &function)?;
203 modified |= promote_locals(context, analyses, function)?;
204 Ok(modified)
205}
206
207pub fn promote_locals(
211 context: &mut Context,
212 analyses: &AnalysisResults,
213 function: Function,
214) -> Result<bool, IrError> {
215 let safe_locals = filter_usable_locals(context, &function);
216 if safe_locals.is_empty() {
217 return Ok(false);
218 }
219
220 let po: &PostOrder = analyses.get_analysis_result(function);
221 let dom_tree: &DomTree = analyses.get_analysis_result(function);
222 let dom_fronts: &DomFronts = analyses.get_analysis_result(function);
223 let liveins = compute_livein(context, &function, po, &safe_locals);
224
225 let mut new_phi_tracker = HashSet::<(String, Block)>::new();
227 let mut worklist = Vec::<(String, Type, Block)>::new();
229 let mut phi_to_local = FxHashMap::<Value, String>::default();
230 for (block, inst) in po
234 .po_to_block
235 .iter()
236 .rev()
237 .flat_map(|b| b.instruction_iter(context).map(|i| (*b, i)))
238 {
239 if let ValueDatum::Instruction(Instruction {
240 op: InstOp::Store { dst_val_ptr, .. },
241 ..
242 }) = context.values[inst.0].value
243 {
244 match get_validate_local_var(context, &function, &dst_val_ptr) {
245 Some((local, var)) if safe_locals.contains(&local) => {
246 worklist.push((local, var.get_inner_type(context), block));
247 }
248 _ => (),
249 }
250 }
251 }
252 while let Some((local, ty, known_def)) = worklist.pop() {
254 for df in dom_fronts[&known_def].iter() {
255 if !new_phi_tracker.contains(&(local.clone(), *df)) && liveins[df].contains(&local) {
256 let index = df.new_arg(context, ty);
258 phi_to_local.insert(df.get_arg(context, index).unwrap(), local.clone());
259 new_phi_tracker.insert((local.clone(), *df));
260 worklist.push((local.clone(), ty, *df));
262 }
263 }
264 }
265
266 #[allow(clippy::too_many_arguments)]
270 fn record_rewrites(
271 context: &mut Context,
272 function: &Function,
273 dom_tree: &DomTree,
274 node: Block,
275 safe_locals: &HashSet<String>,
276 phi_to_local: &FxHashMap<Value, String>,
277 name_stack: &mut MappedStack<String, Value>,
278 rewrites: &mut FxHashMap<Value, Value>,
279 deletes: &mut Vec<(Block, Value)>,
280 ) {
281 let mut num_local_pushes = IndexMap::<String, u32>::new();
284
285 for arg in node.arg_iter(context) {
287 if let Some(local) = phi_to_local.get(arg) {
288 name_stack.push(local.clone(), *arg);
289 num_local_pushes
290 .entry(local.clone())
291 .and_modify(|count| *count += 1)
292 .or_insert(1);
293 }
294 }
295
296 for inst in node.instruction_iter(context) {
297 match context.values[inst.0].value {
298 ValueDatum::Instruction(Instruction {
299 op: InstOp::Load(ptr),
300 ..
301 }) => {
302 let local_var = get_validate_local_var(context, function, &ptr);
303 match local_var {
304 Some((local, var)) if safe_locals.contains(&local) => {
305 let new_val = match name_stack.get(&local) {
307 Some(val) => *val,
308 None => {
309 let constant = *var
311 .get_initializer(context)
312 .expect("We're dealing with an uninitialized value");
313 Value::new_constant(context, constant)
314 }
315 };
316 rewrites.insert(inst, new_val);
317 deletes.push((node, inst));
318 }
319 _ => (),
320 }
321 }
322 ValueDatum::Instruction(Instruction {
323 op:
324 InstOp::Store {
325 dst_val_ptr,
326 stored_val,
327 },
328 ..
329 }) => {
330 let local_var = get_validate_local_var(context, function, &dst_val_ptr);
331 match local_var {
332 Some((local, _)) if safe_locals.contains(&local) => {
333 name_stack.push(local.clone(), stored_val);
336 num_local_pushes
337 .entry(local)
338 .and_modify(|count| *count += 1)
339 .or_insert(1);
340 deletes.push((node, inst));
341 }
342 _ => (),
343 }
344 }
345 _ => (),
346 }
347 }
348
349 for BranchToWithArgs { block: succ, .. } in node.successors(context) {
351 let args: Vec<_> = succ.arg_iter(context).copied().collect();
352 for arg in args {
355 if let Some(local) = phi_to_local.get(&arg) {
356 let ptr = function.get_local_var(context, local).unwrap();
357 let new_val = match name_stack.get(local) {
358 Some(val) => *val,
359 None => {
360 let constant = *ptr
362 .get_initializer(context)
363 .expect("We're dealing with an uninitialized value");
364 Value::new_constant(context, constant)
365 }
366 };
367 let params = node.get_succ_params_mut(context, &succ).unwrap();
368 params.push(new_val);
369 }
370 }
371 }
372
373 for child in dom_tree.children(node) {
375 record_rewrites(
376 context,
377 function,
378 dom_tree,
379 child,
380 safe_locals,
381 phi_to_local,
382 name_stack,
383 rewrites,
384 deletes,
385 );
386 }
387
388 for (local, pushes) in num_local_pushes.iter() {
390 for _ in 0..*pushes {
391 name_stack.pop(local);
392 }
393 }
394 }
395
396 let mut name_stack = MappedStack::<String, Value>::default();
397 let mut value_replacement = FxHashMap::<Value, Value>::default();
398 let mut delete_insts = Vec::<(Block, Value)>::new();
399 record_rewrites(
400 context,
401 &function,
402 dom_tree,
403 function.get_entry_block(context),
404 &safe_locals,
405 &phi_to_local,
406 &mut name_stack,
407 &mut value_replacement,
408 &mut delete_insts,
409 );
410
411 function.replace_values(context, &value_replacement, None);
413 for (block, inst) in delete_insts {
415 block.remove_instruction(context, inst);
416 }
417
418 Ok(true)
419}