sway_ir/optimize/
mem2reg.rs

1use indexmap::IndexMap;
2/// Promote local memory to SSA registers.
3/// This pass is essentially SSA construction. A good readable reference is:
4/// https://www.cs.princeton.edu/~appel/modern/c/
5/// We use block arguments instead of explicit PHI nodes. Conceptually,
6/// they are both the same.
7use rustc_hash::FxHashMap;
8use std::collections::HashSet;
9use sway_utils::mapped_stack::MappedStack;
10
11use crate::{
12    AnalysisResults, Block, BranchToWithArgs, Constant, Context, DomFronts, DomTree, Function,
13    InstOp, Instruction, IrError, LocalVar, Pass, PassMutability, PostOrder, ScopedPass, Type,
14    Value, ValueDatum, DOMINATORS_NAME, DOM_FRONTS_NAME, POSTORDER_NAME,
15};
16
17pub const MEM2REG_NAME: &str = "mem2reg";
18
19pub fn create_mem2reg_pass() -> Pass {
20    Pass {
21        name: MEM2REG_NAME,
22        descr: "Promotion of memory to SSA registers",
23        deps: vec![POSTORDER_NAME, DOMINATORS_NAME, DOM_FRONTS_NAME],
24        runner: ScopedPass::FunctionPass(PassMutability::Transform(promote_to_registers)),
25    }
26}
27
28// Check if a value is a valid (for our optimization) local pointer
29fn get_validate_local_var(
30    context: &Context,
31    function: &Function,
32    val: &Value,
33) -> Option<(String, LocalVar)> {
34    match context.values[val.0].value {
35        ValueDatum::Instruction(Instruction {
36            op: InstOp::GetLocal(local_var),
37            ..
38        }) => {
39            let name = function.lookup_local_name(context, &local_var);
40            name.map(|name| (name.clone(), local_var))
41        }
42        _ => None,
43    }
44}
45
46fn is_promotable_type(context: &Context, ty: Type) -> bool {
47    ty.is_unit(context)
48        || ty.is_bool(context)
49        || (ty.is_uint(context) && ty.get_uint_width(context).unwrap() <= 64)
50}
51
52// Returns those locals that can be promoted to SSA registers.
53fn filter_usable_locals(context: &mut Context, function: &Function) -> HashSet<String> {
54    // The size of an SSA register is target specific.  Here we're going to just stick with atomic
55    // types which can fit in 64-bits.
56    let mut locals: HashSet<String> = function
57        .locals_iter(context)
58        .filter_map(|(name, var)| {
59            let ty = var.get_inner_type(context);
60            is_promotable_type(context, ty).then_some(name.clone())
61        })
62        .collect();
63
64    for (_, inst) in function.instruction_iter(context) {
65        match context.values[inst.0].value {
66            ValueDatum::Instruction(Instruction {
67                op: InstOp::Load(_),
68                ..
69            })
70            | ValueDatum::Instruction(Instruction {
71                op: InstOp::Store { .. },
72                ..
73            }) => {
74                // We understand load and store, so no problem.
75            }
76            _ => {
77                // Make sure that no local escapes into instructions we don't understand.
78                let operands = inst.get_instruction(context).unwrap().op.get_operands();
79                for opd in operands {
80                    if let Some((local, ..)) = get_validate_local_var(context, function, &opd) {
81                        locals.remove(&local);
82                    }
83                }
84            }
85        }
86    }
87    locals
88}
89
90// For each block, compute the set of locals that are live-in.
91// TODO: Use rustc_index::bit_set::ChunkedBitSet by mapping local names to indices.
92//       This will allow more efficient set operations.
93pub fn compute_livein(
94    context: &mut Context,
95    function: &Function,
96    po: &PostOrder,
97    locals: &HashSet<String>,
98) -> FxHashMap<Block, HashSet<String>> {
99    let mut result = FxHashMap::<Block, HashSet<String>>::default();
100    for block in &po.po_to_block {
101        result.insert(*block, HashSet::<String>::default());
102    }
103
104    let mut changed = true;
105    while changed {
106        changed = false;
107        for block in &po.po_to_block {
108            // we begin by unioning the liveins at successor blocks.
109            let mut cur_live = HashSet::<String>::default();
110            for BranchToWithArgs { block: succ, .. } in block.successors(context) {
111                let succ_livein = &result[&succ];
112                cur_live.extend(succ_livein.iter().cloned());
113            }
114            // Scan the instructions, in reverse.
115            for inst in block.instruction_iter(context).rev() {
116                match context.values[inst.0].value {
117                    ValueDatum::Instruction(Instruction {
118                        op: InstOp::Load(ptr),
119                        ..
120                    }) => {
121                        let local_var = get_validate_local_var(context, function, &ptr);
122                        match local_var {
123                            Some((local, ..)) if locals.contains(&local) => {
124                                cur_live.insert(local);
125                            }
126                            _ => {}
127                        }
128                    }
129                    ValueDatum::Instruction(Instruction {
130                        op: InstOp::Store { dst_val_ptr, .. },
131                        ..
132                    }) => {
133                        let local_var = get_validate_local_var(context, function, &dst_val_ptr);
134                        match local_var {
135                            Some((local, _)) if locals.contains(&local) => {
136                                cur_live.remove(&local);
137                            }
138                            _ => (),
139                        }
140                    }
141                    _ => (),
142                }
143            }
144            if result[block] != cur_live {
145                // Whatever's live now, is the live-in for the block.
146                result.get_mut(block).unwrap().extend(cur_live);
147                changed = true;
148            }
149        }
150    }
151    result
152}
153
154/// Promote loads of globals constants to SSA registers
155/// We promote only non-mutable globals of copy types
156fn promote_globals(context: &mut Context, function: &Function) -> Result<bool, IrError> {
157    let mut replacements = FxHashMap::<Value, Constant>::default();
158    for (_, inst) in function.instruction_iter(context) {
159        if let ValueDatum::Instruction(Instruction {
160            op: InstOp::Load(ptr),
161            ..
162        }) = context.values[inst.0].value
163        {
164            if let ValueDatum::Instruction(Instruction {
165                op: InstOp::GetGlobal(global_var),
166                ..
167            }) = context.values[ptr.0].value
168            {
169                if !global_var.is_mutable(context)
170                    && is_promotable_type(context, global_var.get_inner_type(context))
171                {
172                    let constant = *global_var
173                        .get_initializer(context)
174                        .expect("`global_var` is not mutable so it must be initialized");
175                    replacements.insert(inst, constant);
176                }
177            }
178        }
179    }
180
181    if replacements.is_empty() {
182        return Ok(false);
183    }
184
185    let replacements = replacements
186        .into_iter()
187        .map(|(k, v)| (k, Value::new_constant(context, v)))
188        .collect::<FxHashMap<_, _>>();
189
190    function.replace_values(context, &replacements, None);
191
192    Ok(true)
193}
194
195/// Promote memory values that are accessed via load/store to SSA registers.
196pub fn promote_to_registers(
197    context: &mut Context,
198    analyses: &AnalysisResults,
199    function: Function,
200) -> Result<bool, IrError> {
201    let mut modified = false;
202    modified |= promote_globals(context, &function)?;
203    modified |= promote_locals(context, analyses, function)?;
204    Ok(modified)
205}
206
207/// Promote locals to registers. We promote only locals of copy types,
208/// whose every use is in a `get_local` without offsets, and the result of
209/// such a `get_local` is used only in a load or a store.
210pub fn promote_locals(
211    context: &mut Context,
212    analyses: &AnalysisResults,
213    function: Function,
214) -> Result<bool, IrError> {
215    let safe_locals = filter_usable_locals(context, &function);
216    if safe_locals.is_empty() {
217        return Ok(false);
218    }
219
220    let po: &PostOrder = analyses.get_analysis_result(function);
221    let dom_tree: &DomTree = analyses.get_analysis_result(function);
222    let dom_fronts: &DomFronts = analyses.get_analysis_result(function);
223    let liveins = compute_livein(context, &function, po, &safe_locals);
224
225    // A list of the PHIs we insert in this transform.
226    let mut new_phi_tracker = HashSet::<(String, Block)>::new();
227    // A map from newly inserted block args to the Local that it's a PHI for.
228    let mut worklist = Vec::<(String, Type, Block)>::new();
229    let mut phi_to_local = FxHashMap::<Value, String>::default();
230    // Insert PHIs for each definition (store) at its dominance frontiers.
231    // Start by adding the existing definitions (stores) to a worklist,
232    // in program order (reverse post order). This is for faster convergence (or maybe not).
233    for (block, inst) in po
234        .po_to_block
235        .iter()
236        .rev()
237        .flat_map(|b| b.instruction_iter(context).map(|i| (*b, i)))
238    {
239        if let ValueDatum::Instruction(Instruction {
240            op: InstOp::Store { dst_val_ptr, .. },
241            ..
242        }) = context.values[inst.0].value
243        {
244            match get_validate_local_var(context, &function, &dst_val_ptr) {
245                Some((local, var)) if safe_locals.contains(&local) => {
246                    worklist.push((local, var.get_inner_type(context), block));
247                }
248                _ => (),
249            }
250        }
251    }
252    // Transitively add PHIs, till nothing more to do.
253    while let Some((local, ty, known_def)) = worklist.pop() {
254        for df in dom_fronts[&known_def].iter() {
255            if !new_phi_tracker.contains(&(local.clone(), *df)) && liveins[df].contains(&local) {
256                // Insert PHI for this local at block df.
257                let index = df.new_arg(context, ty);
258                phi_to_local.insert(df.get_arg(context, index).unwrap(), local.clone());
259                new_phi_tracker.insert((local.clone(), *df));
260                // Add df to the worklist.
261                worklist.push((local.clone(), ty, *df));
262            }
263        }
264    }
265
266    // We're just left with rewriting the loads and stores into SSA.
267    // For efficiency, we first collect the rewrites
268    // and then apply them all together in the next step.
269    #[allow(clippy::too_many_arguments)]
270    fn record_rewrites(
271        context: &mut Context,
272        function: &Function,
273        dom_tree: &DomTree,
274        node: Block,
275        safe_locals: &HashSet<String>,
276        phi_to_local: &FxHashMap<Value, String>,
277        name_stack: &mut MappedStack<String, Value>,
278        rewrites: &mut FxHashMap<Value, Value>,
279        deletes: &mut Vec<(Block, Value)>,
280    ) {
281        // Whatever new definitions we find in this block, they must be popped
282        // when we're done. So let's keep track of that locally as a count.
283        let mut num_local_pushes = IndexMap::<String, u32>::new();
284
285        // Start with relevant block args, they are new definitions.
286        for arg in node.arg_iter(context) {
287            if let Some(local) = phi_to_local.get(arg) {
288                name_stack.push(local.clone(), *arg);
289                num_local_pushes
290                    .entry(local.clone())
291                    .and_modify(|count| *count += 1)
292                    .or_insert(1);
293            }
294        }
295
296        for inst in node.instruction_iter(context) {
297            match context.values[inst.0].value {
298                ValueDatum::Instruction(Instruction {
299                    op: InstOp::Load(ptr),
300                    ..
301                }) => {
302                    let local_var = get_validate_local_var(context, function, &ptr);
303                    match local_var {
304                        Some((local, var)) if safe_locals.contains(&local) => {
305                            // We should replace all uses of inst with new_stack[local].
306                            let new_val = match name_stack.get(&local) {
307                                Some(val) => *val,
308                                None => {
309                                    // Nothing on the stack, let's attempt to get the initializer
310                                    let constant = *var
311                                        .get_initializer(context)
312                                        .expect("We're dealing with an uninitialized value");
313                                    Value::new_constant(context, constant)
314                                }
315                            };
316                            rewrites.insert(inst, new_val);
317                            deletes.push((node, inst));
318                        }
319                        _ => (),
320                    }
321                }
322                ValueDatum::Instruction(Instruction {
323                    op:
324                        InstOp::Store {
325                            dst_val_ptr,
326                            stored_val,
327                        },
328                    ..
329                }) => {
330                    let local_var = get_validate_local_var(context, function, &dst_val_ptr);
331                    match local_var {
332                        Some((local, _)) if safe_locals.contains(&local) => {
333                            // Henceforth, everything that's dominated by this inst must use stored_val
334                            // instead of loading from dst_val.
335                            name_stack.push(local.clone(), stored_val);
336                            num_local_pushes
337                                .entry(local)
338                                .and_modify(|count| *count += 1)
339                                .or_insert(1);
340                            deletes.push((node, inst));
341                        }
342                        _ => (),
343                    }
344                }
345                _ => (),
346            }
347        }
348
349        // Update arguments to successor blocks (i.e., PHI args).
350        for BranchToWithArgs { block: succ, .. } in node.successors(context) {
351            let args: Vec<_> = succ.arg_iter(context).copied().collect();
352            // For every arg of succ, if it's in phi_to_local,
353            // we pass, as arg, the top value of local
354            for arg in args {
355                if let Some(local) = phi_to_local.get(&arg) {
356                    let ptr = function.get_local_var(context, local).unwrap();
357                    let new_val = match name_stack.get(local) {
358                        Some(val) => *val,
359                        None => {
360                            // Nothing on the stack, let's attempt to get the initializer
361                            let constant = *ptr
362                                .get_initializer(context)
363                                .expect("We're dealing with an uninitialized value");
364                            Value::new_constant(context, constant)
365                        }
366                    };
367                    let params = node.get_succ_params_mut(context, &succ).unwrap();
368                    params.push(new_val);
369                }
370            }
371        }
372
373        // Process dominator children.
374        for child in dom_tree.children(node) {
375            record_rewrites(
376                context,
377                function,
378                dom_tree,
379                child,
380                safe_locals,
381                phi_to_local,
382                name_stack,
383                rewrites,
384                deletes,
385            );
386        }
387
388        // Pop from the names stack.
389        for (local, pushes) in num_local_pushes.iter() {
390            for _ in 0..*pushes {
391                name_stack.pop(local);
392            }
393        }
394    }
395
396    let mut name_stack = MappedStack::<String, Value>::default();
397    let mut value_replacement = FxHashMap::<Value, Value>::default();
398    let mut delete_insts = Vec::<(Block, Value)>::new();
399    record_rewrites(
400        context,
401        &function,
402        dom_tree,
403        function.get_entry_block(context),
404        &safe_locals,
405        &phi_to_local,
406        &mut name_stack,
407        &mut value_replacement,
408        &mut delete_insts,
409    );
410
411    // Apply the rewrites.
412    function.replace_values(context, &value_replacement, None);
413    // Delete the loads and stores.
414    for (block, inst) in delete_insts {
415        block.remove_instruction(context, inst);
416    }
417
418    Ok(true)
419}