sway_ir/optimize/
dce.rs

//! ## Dead Code Elimination
//!
//! This optimization removes unused definitions. The pass combines:
//!   1. A liveness analysis that keeps track of the uses of each definition.
//!   2. A removal step: when a definition is inspected and has no uses, it is removed.
//!
//! This pass does not do CFG transformations. That is handled by `simplify_cfg`.
8
9use itertools::Itertools;
10use rustc_hash::FxHashSet;
11
12use crate::{
13    get_gep_referred_symbols, get_referred_symbols, memory_utils, AnalysisResults, Context,
14    EscapedSymbols, Function, GlobalVar, InstOp, Instruction, IrError, LocalVar, Module, Pass,
15    PassMutability, ReferredSymbols, ScopedPass, Symbol, Value, ValueDatum, ESCAPED_SYMBOLS_NAME,
16};
17
18use std::collections::{HashMap, HashSet};
19
20pub const DCE_NAME: &str = "dce";
21
22pub fn create_dce_pass() -> Pass {
23    Pass {
24        name: DCE_NAME,
25        descr: "Dead code elimination",
26        runner: ScopedPass::FunctionPass(PassMutability::Transform(dce)),
27        deps: vec![ESCAPED_SYMBOLS_NAME],
28    }
29}
30
31pub const GLOBALS_DCE_NAME: &str = "globals-dce";
32
33pub fn create_globals_dce_pass() -> Pass {
34    Pass {
35        name: GLOBALS_DCE_NAME,
36        descr: "Dead globals (functions and variables) elimination",
37        deps: vec![],
38        runner: ScopedPass::ModulePass(PassMutability::Transform(globals_dce)),
39    }
40}
41
42fn can_eliminate_value(
43    context: &Context,
44    val: Value,
45    num_symbol_loaded: &NumSymbolLoaded,
46    escaped_symbols: &EscapedSymbols,
47) -> bool {
48    let Some(inst) = val.get_instruction(context) else {
49        return true;
50    };
51
52    (!inst.op.is_terminator() && !inst.op.may_have_side_effect())
53        || is_removable_store(context, val, num_symbol_loaded, escaped_symbols)
54}
55
56fn is_removable_store(
57    context: &Context,
58    val: Value,
59    num_symbol_loaded: &NumSymbolLoaded,
60    escaped_symbols: &EscapedSymbols,
61) -> bool {
62    let escaped_symbols = match escaped_symbols {
63        EscapedSymbols::Complete(syms) => syms,
64        EscapedSymbols::Incomplete(_) => return false,
65    };
66
67    let num_symbol_loaded = match num_symbol_loaded {
68        NumSymbolLoaded::Unknown => return false,
69        NumSymbolLoaded::Known(known_num_symbol_loaded) => known_num_symbol_loaded,
70    };
71
72    match val.get_instruction(context).unwrap().op {
73        InstOp::MemCopyBytes { dst_val_ptr, .. }
74        | InstOp::MemCopyVal { dst_val_ptr, .. }
75        | InstOp::Store { dst_val_ptr, .. } => {
76            let syms = get_referred_symbols(context, dst_val_ptr);
77            match syms {
78                ReferredSymbols::Complete(syms) => syms.iter().all(|sym| {
79                    !escaped_symbols.contains(sym)
80                        && num_symbol_loaded.get(sym).map_or(0, |uses| *uses) == 0
81                }),
82                // We cannot guarantee that the destination is not used.
83                ReferredSymbols::Incomplete(_) => false,
84            }
85        }
86        _ => false,
87    }
88}
89
/// How many times a [Symbol] gets loaded from, directly or indirectly.
/// This number is either exactly `Known` for all the symbols loaded from, or is
/// considered to be `Unknown` for all the symbols.
enum NumSymbolLoaded {
    /// The load counts cannot be relied on. `dce` switches to this state as soon
    /// as the loaded symbols of any instruction cannot be completely determined.
    Unknown,
    /// Exact load count per symbol. A symbol without an entry is treated as
    /// never loaded.
    Known(HashMap<Symbol, u32>),
}
97
/// Instructions that store to a [Symbol], directly or indirectly.
/// These instructions are either exactly `Known` for all the symbols stored to, or are
/// considered to be `Unknown` for all the symbols.
enum StoresOfSymbol {
    /// The storing instructions cannot be relied on. `dce` switches to this state
    /// as soon as the stored symbols of any instruction cannot be completely
    /// determined.
    Unknown,
    /// For each symbol, the instructions that store to it.
    Known(HashMap<Symbol, Vec<Value>>),
}
105
106fn get_operands(value: Value, context: &Context) -> Vec<Value> {
107    if let Some(inst) = value.get_instruction(context) {
108        inst.op.get_operands()
109    } else if let Some(arg) = value.get_argument(context) {
110        arg.block
111            .pred_iter(context)
112            .map(|pred| {
113                arg.get_val_coming_from(context, pred)
114                    .expect("Block arg doesn't have value passed from predecessor")
115            })
116            .collect()
117    } else {
118        vec![]
119    }
120}
121
122/// Perform dead code (if any) elimination and return true if the `function` is modified.
123pub fn dce(
124    context: &mut Context,
125    analyses: &AnalysisResults,
126    function: Function,
127) -> Result<bool, IrError> {
128    // For DCE, we need to proceed with the analysis even if we have
129    // incomplete list of escaped symbols, because we could have
130    // unused instructions in code. Removing unused instructions is
131    // independent of having any escaping symbols.
132    let escaped_symbols: &EscapedSymbols = analyses.get_analysis_result(function);
133
134    // Number of uses that an instruction / block arg has. This number is always known.
135    let mut num_ssa_uses: HashMap<Value, u32> = HashMap::new();
136    // Number of times a local is accessed via `get_local`. This number is always known.
137    let mut num_local_uses: HashMap<LocalVar, u32> = HashMap::new();
138    // Number of times a symbol, local or a function argument, is loaded, directly or indirectly. This number can be unknown.
139    let mut num_symbol_loaded: NumSymbolLoaded = NumSymbolLoaded::Known(HashMap::new());
140    // Instructions that store to a symbol, directly or indirectly. This information can be unknown.
141    let mut stores_of_sym: StoresOfSymbol = StoresOfSymbol::Known(HashMap::new());
142
143    // TODO-IG: Update this logic once `mut arg: T`s are implemented.
144    //          Currently, only `ref mut arg` arguments can be stored to,
145    //          which means they can be loaded from the caller.
146    //          Once we support `mut arg` in general, this will not be
147    //          the case anymore and we will need to distinguish between
148    //          `mut arg: T`, `arg: &mut T`, etc.
149    // Every argument is assumed to be loaded from (from the caller),
150    // so stores to it shouldn't be eliminated.
151    if let NumSymbolLoaded::Known(known_num_symbol_loaded) = &mut num_symbol_loaded {
152        for sym in function
153            .args_iter(context)
154            .flat_map(|arg| get_gep_referred_symbols(context, arg.1))
155        {
156            known_num_symbol_loaded
157                .entry(sym)
158                .and_modify(|count| *count += 1)
159                .or_insert(1);
160        }
161    }
162
163    // Go through each instruction and update use counters.
164    for (_block, inst) in function.instruction_iter(context) {
165        if let NumSymbolLoaded::Known(known_num_symbol_loaded) = &mut num_symbol_loaded {
166            match memory_utils::get_loaded_symbols(context, inst) {
167                ReferredSymbols::Complete(loaded_symbols) => {
168                    for sym in loaded_symbols {
169                        known_num_symbol_loaded
170                            .entry(sym)
171                            .and_modify(|count| *count += 1)
172                            .or_insert(1);
173                    }
174                }
175                ReferredSymbols::Incomplete(_) => num_symbol_loaded = NumSymbolLoaded::Unknown,
176            }
177        }
178
179        if let StoresOfSymbol::Known(known_stores_of_sym) = &mut stores_of_sym {
180            match memory_utils::get_stored_symbols(context, inst) {
181                ReferredSymbols::Complete(stored_symbols) => {
182                    for stored_sym in stored_symbols {
183                        known_stores_of_sym
184                            .entry(stored_sym)
185                            .and_modify(|stores| stores.push(inst))
186                            .or_insert(vec![inst]);
187                    }
188                }
189                ReferredSymbols::Incomplete(_) => stores_of_sym = StoresOfSymbol::Unknown,
190            }
191        }
192
193        // A local is used if it is accessed via `get_local`.
194        let inst = inst.get_instruction(context).unwrap();
195        if let InstOp::GetLocal(local) = inst.op {
196            num_local_uses
197                .entry(local)
198                .and_modify(|count| *count += 1)
199                .or_insert(1);
200        }
201
202        // An instruction or block-arg is used if it is an operand in another instruction.
203        let opds = inst.op.get_operands();
204        for opd in opds {
205            match context.values[opd.0].value {
206                ValueDatum::Instruction(_) | ValueDatum::Argument(_) => {
207                    num_ssa_uses
208                        .entry(opd)
209                        .and_modify(|count| *count += 1)
210                        .or_insert(1);
211                }
212                ValueDatum::Constant(_) => {}
213            }
214        }
215    }
216
217    // The list of all unused or `Store` instruction. Note that the `Store` instruction does
218    // not result in a value, and will, thus, always be treated as unused and will not
219    // have an entry in `num_inst_uses`. So, to collect unused or `Store` instructions it
220    // is sufficient to filter those that are not used.
221    let mut worklist = function
222        .instruction_iter(context)
223        .filter_map(|(_, inst)| (!num_ssa_uses.contains_key(&inst)).then_some(inst))
224        .collect::<Vec<_>>();
225    let dead_args = function
226        .block_iter(context)
227        .flat_map(|block| {
228            block
229                .arg_iter(context)
230                .filter_map(|arg| (!num_ssa_uses.contains_key(arg)).then_some(*arg))
231                .collect_vec()
232        })
233        .collect_vec();
234    worklist.extend(dead_args);
235
236    let mut modified = false;
237    let mut cemetery = FxHashSet::default();
238    while let Some(dead) = worklist.pop() {
239        if !can_eliminate_value(context, dead, &num_symbol_loaded, escaped_symbols)
240            || cemetery.contains(&dead)
241        {
242            continue;
243        }
244        // Process dead's operands.
245        let opds = get_operands(dead, context);
246        for opd in opds {
247            // Reduce the use count of the operand used in the dead instruction.
248            // If it reaches 0, add it to the worklist, since it is not used
249            // anywhere else.
250            match context.values[opd.0].value {
251                ValueDatum::Instruction(_) | ValueDatum::Argument(_) => {
252                    let nu = num_ssa_uses.get_mut(&opd).unwrap();
253                    *nu -= 1;
254                    if *nu == 0 {
255                        worklist.push(opd);
256                    }
257                }
258                ValueDatum::Constant(_) => {}
259            }
260        }
261
262        if dead.get_instruction(context).is_some() {
263            // If the `dead` instruction was the only instruction loading from a `sym`bol,
264            // after removing it, there will be no loads anymore, so all the stores to
265            // that `sym`bol can be added to the worklist.
266            if let ReferredSymbols::Complete(loaded_symbols) =
267                memory_utils::get_loaded_symbols(context, dead)
268            {
269                if let (
270                    NumSymbolLoaded::Known(known_num_symbol_loaded),
271                    StoresOfSymbol::Known(known_stores_of_sym),
272                ) = (&mut num_symbol_loaded, &mut stores_of_sym)
273                {
274                    for sym in loaded_symbols {
275                        let nu = known_num_symbol_loaded.get_mut(&sym).unwrap();
276                        *nu -= 1;
277                        if *nu == 0 {
278                            for store in known_stores_of_sym.get(&sym).unwrap_or(&vec![]) {
279                                worklist.push(*store);
280                            }
281                        }
282                    }
283                }
284            }
285        }
286
287        cemetery.insert(dead);
288
289        if let ValueDatum::Instruction(Instruction {
290            op: InstOp::GetLocal(local),
291            ..
292        }) = context.values[dead.0].value
293        {
294            let count = num_local_uses.get_mut(&local).unwrap();
295            *count -= 1;
296        }
297
298        modified = true;
299    }
300
301    // Remove all dead instructions and arguments.
302    // We collect here and below because we want &mut Context for modifications.
303    for block in function.block_iter(context).collect_vec() {
304        if block != function.get_entry_block(context) {
305            // dead_args[arg_idx] indicates whether the argument is dead.
306            let dead_args = block
307                .arg_iter(context)
308                .map(|arg| cemetery.contains(arg))
309                .collect_vec();
310            for pred in block.pred_iter(context).cloned().collect_vec() {
311                let params = pred
312                    .get_succ_params_mut(context, &block)
313                    .expect("Invalid IR");
314                let mut index = 0;
315                // Remove parameters passed to a dead argument.
316                params.retain(|_| {
317                    let retain = !dead_args[index];
318                    index += 1;
319                    retain
320                });
321            }
322            // Remove the dead argument itself.
323            let mut index = 0;
324            context.blocks[block.0].args.retain(|_| {
325                let retain = !dead_args[index];
326                index += 1;
327                retain
328            });
329            // Update the self-index stored in each arg.
330            for (arg_idx, arg) in block.arg_iter(context).cloned().enumerate().collect_vec() {
331                let arg = arg.get_argument_mut(context).unwrap();
332                arg.idx = arg_idx;
333            }
334        }
335        block.remove_instructions(context, |inst| cemetery.contains(&inst));
336    }
337
338    let local_removals: Vec<_> = function
339        .locals_iter(context)
340        .filter_map(|(name, local)| {
341            (num_local_uses.get(local).cloned().unwrap_or(0) == 0).then_some(name.clone())
342        })
343        .collect();
344    if !local_removals.is_empty() {
345        modified = true;
346        function.remove_locals(context, &local_removals);
347    }
348
349    Ok(modified)
350}
351
352/// Remove entire functions and globals from a module based on whether they are called / used or not,
353/// using a list of root 'entry' functions to perform a search.
354///
355/// Functions which are `pub` will not be removed and only functions within the passed [`Module`]
356/// are considered for removal.
357pub fn globals_dce(
358    context: &mut Context,
359    _: &AnalysisResults,
360    module: Module,
361) -> Result<bool, IrError> {
362    let mut called_fns: HashSet<Function> = HashSet::new();
363    let mut used_globals: HashSet<GlobalVar> = HashSet::new();
364
365    // config decode fns
366    for config in context.modules[module.0].configs.iter() {
367        if let crate::ConfigContent::V1 { decode_fn, .. } = config.1 {
368            grow_called_function_used_globals_set(
369                context,
370                decode_fn.get(),
371                &mut called_fns,
372                &mut used_globals,
373            );
374        }
375    }
376
377    // expand all called fns
378    for entry_fn in module
379        .function_iter(context)
380        .filter(|func| func.is_entry(context) || func.is_fallback(context))
381    {
382        grow_called_function_used_globals_set(
383            context,
384            entry_fn,
385            &mut called_fns,
386            &mut used_globals,
387        );
388    }
389
390    let mut modified = false;
391
392    // Remove dead globals
393    let m = &mut context.modules[module.0];
394    let cur_num_globals = m.global_variables.len();
395    m.global_variables.retain(|_, g| used_globals.contains(g));
396    modified |= cur_num_globals != m.global_variables.len();
397
398    // Gather the functions in the module which aren't called.  It's better to collect them
399    // separately first so as to avoid any issues with invalidating the function iterator.
400    let dead_fns = module
401        .function_iter(context)
402        .filter(|f| !called_fns.contains(f))
403        .collect::<Vec<_>>();
404    for dead_fn in &dead_fns {
405        module.remove_function(context, dead_fn);
406    }
407
408    modified |= !dead_fns.is_empty();
409
410    Ok(modified)
411}
412
413// Recursively find all the functions called by an entry function.
414fn grow_called_function_used_globals_set(
415    context: &Context,
416    caller: Function,
417    called_set: &mut HashSet<Function>,
418    used_globals: &mut HashSet<GlobalVar>,
419) {
420    if called_set.insert(caller) {
421        // We haven't seen caller before.  Iterate for all that it calls.
422        let mut callees = HashSet::new();
423        for (_block, value) in caller.instruction_iter(context) {
424            let inst = value.get_instruction(context).unwrap();
425            match &inst.op {
426                InstOp::Call(f, _args) => {
427                    callees.insert(*f);
428                }
429                InstOp::GetGlobal(g) => {
430                    used_globals.insert(*g);
431                }
432                _otherwise => (),
433            }
434        }
435        callees.into_iter().for_each(|func| {
436            grow_called_function_used_globals_set(context, func, called_set, used_globals);
437        });
438    }
439}