1use itertools::Itertools;
10use rustc_hash::FxHashSet;
11
12use crate::{
13 get_gep_referred_symbols, get_referred_symbols, memory_utils, AnalysisResults, Context,
14 EscapedSymbols, Function, GlobalVar, InstOp, Instruction, IrError, LocalVar, Module, Pass,
15 PassMutability, ReferredSymbols, ScopedPass, Symbol, Value, ValueDatum, ESCAPED_SYMBOLS_NAME,
16};
17
18use std::collections::{HashMap, HashSet};
19
20pub const DCE_NAME: &str = "dce";
21
22pub fn create_dce_pass() -> Pass {
23 Pass {
24 name: DCE_NAME,
25 descr: "Dead code elimination",
26 runner: ScopedPass::FunctionPass(PassMutability::Transform(dce)),
27 deps: vec![ESCAPED_SYMBOLS_NAME],
28 }
29}
30
31pub const GLOBALS_DCE_NAME: &str = "globals-dce";
32
33pub fn create_globals_dce_pass() -> Pass {
34 Pass {
35 name: GLOBALS_DCE_NAME,
36 descr: "Dead globals (functions and variables) elimination",
37 deps: vec![],
38 runner: ScopedPass::ModulePass(PassMutability::Transform(globals_dce)),
39 }
40}
41
42fn can_eliminate_value(
43 context: &Context,
44 val: Value,
45 num_symbol_loaded: &NumSymbolLoaded,
46 escaped_symbols: &EscapedSymbols,
47) -> bool {
48 let Some(inst) = val.get_instruction(context) else {
49 return true;
50 };
51
52 (!inst.op.is_terminator() && !inst.op.may_have_side_effect())
53 || is_removable_store(context, val, num_symbol_loaded, escaped_symbols)
54}
55
56fn is_removable_store(
57 context: &Context,
58 val: Value,
59 num_symbol_loaded: &NumSymbolLoaded,
60 escaped_symbols: &EscapedSymbols,
61) -> bool {
62 let escaped_symbols = match escaped_symbols {
63 EscapedSymbols::Complete(syms) => syms,
64 EscapedSymbols::Incomplete(_) => return false,
65 };
66
67 let num_symbol_loaded = match num_symbol_loaded {
68 NumSymbolLoaded::Unknown => return false,
69 NumSymbolLoaded::Known(known_num_symbol_loaded) => known_num_symbol_loaded,
70 };
71
72 match val.get_instruction(context).unwrap().op {
73 InstOp::MemCopyBytes { dst_val_ptr, .. }
74 | InstOp::MemCopyVal { dst_val_ptr, .. }
75 | InstOp::Store { dst_val_ptr, .. } => {
76 let syms = get_referred_symbols(context, dst_val_ptr);
77 match syms {
78 ReferredSymbols::Complete(syms) => syms.iter().all(|sym| {
79 !escaped_symbols.contains(sym)
80 && num_symbol_loaded.get(sym).map_or(0, |uses| *uses) == 0
81 }),
82 ReferredSymbols::Incomplete(_) => false,
84 }
85 }
86 _ => false,
87 }
88}
89
90enum NumSymbolLoaded {
94 Unknown,
95 Known(HashMap<Symbol, u32>),
96}
97
98enum StoresOfSymbol {
102 Unknown,
103 Known(HashMap<Symbol, Vec<Value>>),
104}
105
106fn get_operands(value: Value, context: &Context) -> Vec<Value> {
107 if let Some(inst) = value.get_instruction(context) {
108 inst.op.get_operands()
109 } else if let Some(arg) = value.get_argument(context) {
110 arg.block
111 .pred_iter(context)
112 .map(|pred| {
113 arg.get_val_coming_from(context, pred)
114 .expect("Block arg doesn't have value passed from predecessor")
115 })
116 .collect()
117 } else {
118 vec![]
119 }
120}
121
122pub fn dce(
124 context: &mut Context,
125 analyses: &AnalysisResults,
126 function: Function,
127) -> Result<bool, IrError> {
128 let escaped_symbols: &EscapedSymbols = analyses.get_analysis_result(function);
133
134 let mut num_ssa_uses: HashMap<Value, u32> = HashMap::new();
136 let mut num_local_uses: HashMap<LocalVar, u32> = HashMap::new();
138 let mut num_symbol_loaded: NumSymbolLoaded = NumSymbolLoaded::Known(HashMap::new());
140 let mut stores_of_sym: StoresOfSymbol = StoresOfSymbol::Known(HashMap::new());
142
143 if let NumSymbolLoaded::Known(known_num_symbol_loaded) = &mut num_symbol_loaded {
152 for sym in function
153 .args_iter(context)
154 .flat_map(|arg| get_gep_referred_symbols(context, arg.1))
155 {
156 known_num_symbol_loaded
157 .entry(sym)
158 .and_modify(|count| *count += 1)
159 .or_insert(1);
160 }
161 }
162
163 for (_block, inst) in function.instruction_iter(context) {
165 if let NumSymbolLoaded::Known(known_num_symbol_loaded) = &mut num_symbol_loaded {
166 match memory_utils::get_loaded_symbols(context, inst) {
167 ReferredSymbols::Complete(loaded_symbols) => {
168 for sym in loaded_symbols {
169 known_num_symbol_loaded
170 .entry(sym)
171 .and_modify(|count| *count += 1)
172 .or_insert(1);
173 }
174 }
175 ReferredSymbols::Incomplete(_) => num_symbol_loaded = NumSymbolLoaded::Unknown,
176 }
177 }
178
179 if let StoresOfSymbol::Known(known_stores_of_sym) = &mut stores_of_sym {
180 match memory_utils::get_stored_symbols(context, inst) {
181 ReferredSymbols::Complete(stored_symbols) => {
182 for stored_sym in stored_symbols {
183 known_stores_of_sym
184 .entry(stored_sym)
185 .and_modify(|stores| stores.push(inst))
186 .or_insert(vec![inst]);
187 }
188 }
189 ReferredSymbols::Incomplete(_) => stores_of_sym = StoresOfSymbol::Unknown,
190 }
191 }
192
193 let inst = inst.get_instruction(context).unwrap();
195 if let InstOp::GetLocal(local) = inst.op {
196 num_local_uses
197 .entry(local)
198 .and_modify(|count| *count += 1)
199 .or_insert(1);
200 }
201
202 let opds = inst.op.get_operands();
204 for opd in opds {
205 match context.values[opd.0].value {
206 ValueDatum::Instruction(_) | ValueDatum::Argument(_) => {
207 num_ssa_uses
208 .entry(opd)
209 .and_modify(|count| *count += 1)
210 .or_insert(1);
211 }
212 ValueDatum::Constant(_) => {}
213 }
214 }
215 }
216
217 let mut worklist = function
222 .instruction_iter(context)
223 .filter_map(|(_, inst)| (!num_ssa_uses.contains_key(&inst)).then_some(inst))
224 .collect::<Vec<_>>();
225 let dead_args = function
226 .block_iter(context)
227 .flat_map(|block| {
228 block
229 .arg_iter(context)
230 .filter_map(|arg| (!num_ssa_uses.contains_key(arg)).then_some(*arg))
231 .collect_vec()
232 })
233 .collect_vec();
234 worklist.extend(dead_args);
235
236 let mut modified = false;
237 let mut cemetery = FxHashSet::default();
238 while let Some(dead) = worklist.pop() {
239 if !can_eliminate_value(context, dead, &num_symbol_loaded, escaped_symbols)
240 || cemetery.contains(&dead)
241 {
242 continue;
243 }
244 let opds = get_operands(dead, context);
246 for opd in opds {
247 match context.values[opd.0].value {
251 ValueDatum::Instruction(_) | ValueDatum::Argument(_) => {
252 let nu = num_ssa_uses.get_mut(&opd).unwrap();
253 *nu -= 1;
254 if *nu == 0 {
255 worklist.push(opd);
256 }
257 }
258 ValueDatum::Constant(_) => {}
259 }
260 }
261
262 if dead.get_instruction(context).is_some() {
263 if let ReferredSymbols::Complete(loaded_symbols) =
267 memory_utils::get_loaded_symbols(context, dead)
268 {
269 if let (
270 NumSymbolLoaded::Known(known_num_symbol_loaded),
271 StoresOfSymbol::Known(known_stores_of_sym),
272 ) = (&mut num_symbol_loaded, &mut stores_of_sym)
273 {
274 for sym in loaded_symbols {
275 let nu = known_num_symbol_loaded.get_mut(&sym).unwrap();
276 *nu -= 1;
277 if *nu == 0 {
278 for store in known_stores_of_sym.get(&sym).unwrap_or(&vec![]) {
279 worklist.push(*store);
280 }
281 }
282 }
283 }
284 }
285 }
286
287 cemetery.insert(dead);
288
289 if let ValueDatum::Instruction(Instruction {
290 op: InstOp::GetLocal(local),
291 ..
292 }) = context.values[dead.0].value
293 {
294 let count = num_local_uses.get_mut(&local).unwrap();
295 *count -= 1;
296 }
297
298 modified = true;
299 }
300
301 for block in function.block_iter(context).collect_vec() {
304 if block != function.get_entry_block(context) {
305 let dead_args = block
307 .arg_iter(context)
308 .map(|arg| cemetery.contains(arg))
309 .collect_vec();
310 for pred in block.pred_iter(context).cloned().collect_vec() {
311 let params = pred
312 .get_succ_params_mut(context, &block)
313 .expect("Invalid IR");
314 let mut index = 0;
315 params.retain(|_| {
317 let retain = !dead_args[index];
318 index += 1;
319 retain
320 });
321 }
322 let mut index = 0;
324 context.blocks[block.0].args.retain(|_| {
325 let retain = !dead_args[index];
326 index += 1;
327 retain
328 });
329 for (arg_idx, arg) in block.arg_iter(context).cloned().enumerate().collect_vec() {
331 let arg = arg.get_argument_mut(context).unwrap();
332 arg.idx = arg_idx;
333 }
334 }
335 block.remove_instructions(context, |inst| cemetery.contains(&inst));
336 }
337
338 let local_removals: Vec<_> = function
339 .locals_iter(context)
340 .filter_map(|(name, local)| {
341 (num_local_uses.get(local).cloned().unwrap_or(0) == 0).then_some(name.clone())
342 })
343 .collect();
344 if !local_removals.is_empty() {
345 modified = true;
346 function.remove_locals(context, &local_removals);
347 }
348
349 Ok(modified)
350}
351
352pub fn globals_dce(
358 context: &mut Context,
359 _: &AnalysisResults,
360 module: Module,
361) -> Result<bool, IrError> {
362 let mut called_fns: HashSet<Function> = HashSet::new();
363 let mut used_globals: HashSet<GlobalVar> = HashSet::new();
364
365 for config in context.modules[module.0].configs.iter() {
367 if let crate::ConfigContent::V1 { decode_fn, .. } = config.1 {
368 grow_called_function_used_globals_set(
369 context,
370 decode_fn.get(),
371 &mut called_fns,
372 &mut used_globals,
373 );
374 }
375 }
376
377 for entry_fn in module
379 .function_iter(context)
380 .filter(|func| func.is_entry(context) || func.is_fallback(context))
381 {
382 grow_called_function_used_globals_set(
383 context,
384 entry_fn,
385 &mut called_fns,
386 &mut used_globals,
387 );
388 }
389
390 let mut modified = false;
391
392 let m = &mut context.modules[module.0];
394 let cur_num_globals = m.global_variables.len();
395 m.global_variables.retain(|_, g| used_globals.contains(g));
396 modified |= cur_num_globals != m.global_variables.len();
397
398 let dead_fns = module
401 .function_iter(context)
402 .filter(|f| !called_fns.contains(f))
403 .collect::<Vec<_>>();
404 for dead_fn in &dead_fns {
405 module.remove_function(context, dead_fn);
406 }
407
408 modified |= !dead_fns.is_empty();
409
410 Ok(modified)
411}
412
413fn grow_called_function_used_globals_set(
415 context: &Context,
416 caller: Function,
417 called_set: &mut HashSet<Function>,
418 used_globals: &mut HashSet<GlobalVar>,
419) {
420 if called_set.insert(caller) {
421 let mut callees = HashSet::new();
423 for (_block, value) in caller.instruction_iter(context) {
424 let inst = value.get_instruction(context).unwrap();
425 match &inst.op {
426 InstOp::Call(f, _args) => {
427 callees.insert(*f);
428 }
429 InstOp::GetGlobal(g) => {
430 used_globals.insert(*g);
431 }
432 _otherwise => (),
433 }
434 }
435 callees.into_iter().for_each(|func| {
436 grow_called_function_used_globals_set(context, func, called_set, used_globals);
437 });
438 }
439}