sway_ir/optimize/
simplify_cfg.rs

//! ## Simplify Control Flow Graph
//!
//! The optimizations here aim to reduce the complexity in control flow by removing basic blocks.
//! This may be done by removing 'dead' blocks which are unreachable from the entry block, by
//! merging blocks which are linked by a single unconditional branch, or by bypassing empty blocks
//! which consist of nothing but an unconditional branch.
//!
//! Removing blocks makes the IR neater and more efficient, and also removes indirection of data
//! flow via block arguments (this IR's equivalent of PHI instructions), which in turn can make
//! analyses for passes like constant folding much simpler.
10
11use rustc_hash::{FxHashMap, FxHashSet};
12
13use crate::{
14    block::Block, context::Context, error::IrError, function::Function, instruction::InstOp,
15    value::ValueDatum, AnalysisResults, BranchToWithArgs, Instruction, InstructionInserter, Pass,
16    PassMutability, ScopedPass, Value,
17};
18
19pub const SIMPLIFY_CFG_NAME: &str = "simplify-cfg";
20
21pub fn create_simplify_cfg_pass() -> Pass {
22    Pass {
23        name: SIMPLIFY_CFG_NAME,
24        descr: "Simplify the control flow graph (CFG)",
25        deps: vec![],
26        runner: ScopedPass::FunctionPass(PassMutability::Transform(simplify_cfg)),
27    }
28}
29
30pub fn simplify_cfg(
31    context: &mut Context,
32    _: &AnalysisResults,
33    function: Function,
34) -> Result<bool, IrError> {
35    let mut modified = false;
36    modified |= remove_dead_blocks(context, &function)?;
37    modified |= merge_blocks(context, &function)?;
38    modified |= unlink_empty_blocks(context, &function)?;
39    modified |= remove_dead_blocks(context, &function)?;
40    Ok(modified)
41}
42
43fn unlink_empty_blocks(context: &mut Context, function: &Function) -> Result<bool, IrError> {
44    let mut modified = false;
45    let candidates: Vec<_> = function
46        .block_iter(context)
47        .skip(1)
48        .filter_map(|block| {
49            match block.get_terminator(context) {
50                // Except for a branch, we don't want anything else.
51                // If the block has PHI nodes, then values merge here. Cannot remove the block.
52                Some(Instruction {
53                    op: InstOp::Branch(to_block),
54                    ..
55                }) if block.num_instructions(context) <= 1 && block.num_args(context) == 0 => {
56                    Some((block, to_block.clone()))
57                }
58                _ => None,
59            }
60        })
61        .collect();
62    for (
63        block,
64        BranchToWithArgs {
65            block: to_block,
66            args: cur_params,
67        },
68    ) in candidates
69    {
70        // If `to_block`'s predecessors and `block`'s predecessors intersect,
71        // AND `to_block` has an arg, then we have that pred branching to to_block
72        // with different args. While that's valid IR, it's harder to generate
73        // ASM for it, so let's just skip that for now.
74        if to_block.num_args(context) > 0
75            && to_block.pred_iter(context).any(|to_block_pred| {
76                block
77                    .pred_iter(context)
78                    .any(|block_pred| block_pred == to_block_pred)
79            })
80        {
81            // We cannot filter this out in candidates itself because this condition
82            // may get updated *during* this optimization (i.e., inside this loop).
83            continue;
84        }
85        let preds: Vec<_> = block.pred_iter(context).copied().collect();
86        for pred in preds {
87            // Whatever parameters "block" passed to "to_block", that
88            // should now go from "pred" to "to_block".
89            let params_from_pred = pred.get_succ_params(context, &block);
90            let new_params = cur_params
91                .iter()
92                .map(|cur_param| match &context.values[cur_param.0].value {
93                    ValueDatum::Argument(arg) if arg.block == block => {
94                        // An argument should map to the actual parameter passed.
95                        params_from_pred[arg.idx]
96                    }
97                    _ => *cur_param,
98                })
99                .collect();
100
101            pred.replace_successor(context, block, to_block, new_params);
102            modified = true;
103        }
104    }
105    Ok(modified)
106}
107
108fn remove_dead_blocks(context: &mut Context, function: &Function) -> Result<bool, IrError> {
109    let mut worklist = Vec::<Block>::new();
110    let mut reachable = std::collections::HashSet::<Block>::new();
111
112    // The entry is always reachable. Let's begin with that.
113    let entry_block = function.get_entry_block(context);
114    reachable.insert(entry_block);
115    worklist.push(entry_block);
116
117    // Mark reachable nodes.
118    while let Some(block) = worklist.pop() {
119        let succs = block.successors(context);
120        for BranchToWithArgs { block: succ, .. } in succs {
121            // If this isn't already marked reachable, we mark it and add to the worklist.
122            if !reachable.contains(&succ) {
123                reachable.insert(succ);
124                worklist.push(succ);
125            }
126        }
127    }
128
129    // Delete all unreachable nodes.
130    let mut modified = false;
131    for block in function.block_iter(context) {
132        if !reachable.contains(&block) {
133            modified = true;
134
135            for BranchToWithArgs { block: succ, .. } in block.successors(context) {
136                succ.remove_pred(context, &block);
137            }
138
139            function.remove_block(context, &block)?;
140        }
141    }
142
143    Ok(modified)
144}
145
146fn merge_blocks(context: &mut Context, function: &Function) -> Result<bool, IrError> {
147    // Check if block branches solely to another block B, and that B has exactly one predecessor.
148    fn check_candidate(context: &Context, from_block: Block) -> Option<(Block, Block)> {
149        from_block
150            .get_terminator(context)
151            .and_then(|term| match term {
152                Instruction {
153                    op:
154                        InstOp::Branch(BranchToWithArgs {
155                            block: to_block, ..
156                        }),
157                    ..
158                } if to_block.num_predecessors(context) == 1 => Some((from_block, *to_block)),
159                _ => None,
160            })
161    }
162
163    let blocks: Vec<_> = function.block_iter(context).collect();
164    let mut deleted_blocks = FxHashSet::<Block>::default();
165    let mut replace_map: FxHashMap<Value, Value> = FxHashMap::default();
166    let mut modified = false;
167
168    for from_block in blocks {
169        if deleted_blocks.contains(&from_block) {
170            continue;
171        }
172
173        // Find a block with an unconditional branch terminator which branches to a block with that
174        // single predecessor.
175        let twin_blocks = check_candidate(context, from_block);
176
177        // If not found then abort here.
178        let mut block_chain = match twin_blocks {
179            Some((from_block, to_block)) => vec![from_block, to_block],
180            None => continue,
181        };
182
183        // There may be more blocks which are also singly paired with these twins, so iteratively
184        // search for more blocks in a chain which can be all merged into one.
185        loop {
186            match check_candidate(context, block_chain.last().copied().unwrap()) {
187                None => {
188                    // There is no twin for this block.
189                    break;
190                }
191                Some(next_pair) => {
192                    block_chain.push(next_pair.1);
193                }
194            }
195        }
196
197        // Keep a copy of the final block in the chain so we can adjust the successors below.
198        let final_to_block = block_chain.last().copied().unwrap();
199        let final_to_block_succs = final_to_block.successors(context);
200
201        // The first block in the chain will be extended with the contents of the rest of the blocks in
202        // the chain, which we'll call `from_block` since we're branching from here to the next one.
203        let mut block_chain = block_chain.into_iter();
204        let from_block = block_chain.next().unwrap();
205
206        // Loop for the rest of the chain, to all the `to_block`s.
207        for to_block in block_chain {
208            let from_params = from_block.get_succ_params(context, &to_block);
209            // We collect here so that we can have &mut Context later on.
210            let to_blocks: Vec<_> = to_block.arg_iter(context).copied().enumerate().collect();
211            for (arg_idx, to_block_arg) in to_blocks {
212                // replace all uses of `to_block_arg` with the parameter from `from_block`.
213                replace_map.insert(to_block_arg, from_params[arg_idx]);
214            }
215
216            // Update the parent block field for every instruction
217            // in `to_block` to `from_block`.
218            for val in to_block.instruction_iter(context) {
219                let instr = val.get_instruction_mut(context).unwrap();
220                instr.parent = from_block;
221            }
222
223            // Drop the terminator from `from_block`.
224            from_block.remove_last_instruction(context);
225
226            // Move instructions from `to_block` to `from_block`.
227            let to_block_instructions = to_block.instruction_iter(context).collect::<Vec<_>>();
228            let mut inserter =
229                InstructionInserter::new(context, from_block, crate::InsertionPosition::End);
230            inserter.insert_slice(&to_block_instructions);
231
232            // Remove `to_block`.
233            function.remove_block(context, &to_block)?;
234            deleted_blocks.insert(to_block);
235        }
236
237        // Adjust the successors to the final `to_block` to now be successors of the fully merged
238        // `from_block`.
239        for BranchToWithArgs { block: succ, .. } in final_to_block_succs {
240            succ.replace_pred(context, &final_to_block, &from_block)
241        }
242        modified = true;
243    }
244
245    if !replace_map.is_empty() {
246        assert!(modified);
247        function.replace_values(context, &replace_map, None);
248    }
249
250    Ok(modified)
251}