use rustc_hash::FxHashSet;
use crate::{
get_symbols, AnalysisResults, Context, EscapedSymbols, FuelVmInstruction, Function,
Instruction, IrError, LocalVar, Module, Pass, PassMutability, ScopedPass, Symbol, Value,
ValueDatum, ESCAPED_SYMBOLS_NAME,
};
use std::collections::{HashMap, HashSet};
pub const DCE_NAME: &str = "dce";
pub fn create_dce_pass() -> Pass {
Pass {
name: DCE_NAME,
descr: "Dead code elimination.",
runner: ScopedPass::FunctionPass(PassMutability::Transform(dce)),
deps: vec![ESCAPED_SYMBOLS_NAME],
}
}
pub const FUNC_DCE_NAME: &str = "func_dce";
pub fn create_func_dce_pass() -> Pass {
Pass {
name: FUNC_DCE_NAME,
descr: "Dead function elimination.",
deps: vec![],
runner: ScopedPass::ModulePass(PassMutability::Transform(func_dce)),
}
}
fn can_eliminate_instruction(
context: &Context,
val: Value,
num_symbol_uses: &HashMap<Symbol, u32>,
escaped_symbols: &EscapedSymbols,
) -> bool {
let inst = val.get_instruction(context).unwrap();
(!inst.is_terminator() && !inst.may_have_side_effect())
|| is_removable_store(context, val, num_symbol_uses, escaped_symbols)
}
fn get_loaded_symbols(context: &Context, val: Value) -> Vec<Symbol> {
match val.get_instruction(context).unwrap() {
Instruction::UnaryOp { .. }
| Instruction::BinaryOp { .. }
| Instruction::BitCast(_, _)
| Instruction::Branch(_)
| Instruction::ConditionalBranch { .. }
| Instruction::Cmp(_, _, _)
| Instruction::Nop
| Instruction::CastPtr(_, _)
| Instruction::GetLocal(_)
| Instruction::GetElemPtr { .. }
| Instruction::IntToPtr(_, _) => vec![],
Instruction::PtrToInt(src_val_ptr, _) => get_symbols(context, *src_val_ptr).to_vec(),
Instruction::ContractCall {
params,
coins,
asset_id,
..
} => [*params, *coins, *asset_id]
.iter()
.flat_map(|val| get_symbols(context, *val).to_vec())
.collect(),
Instruction::Call(_, args) => args
.iter()
.flat_map(|val| get_symbols(context, *val).to_vec())
.collect(),
Instruction::AsmBlock(_, args) => args
.iter()
.filter_map(|val| {
val.initializer
.map(|val| get_symbols(context, val).to_vec())
})
.flatten()
.collect(),
Instruction::MemCopyBytes { src_val_ptr, .. }
| Instruction::MemCopyVal { src_val_ptr, .. }
| Instruction::Ret(src_val_ptr, _)
| Instruction::Load(src_val_ptr)
| Instruction::FuelVm(FuelVmInstruction::Log {
log_val: src_val_ptr,
..
})
| Instruction::FuelVm(FuelVmInstruction::StateLoadWord(src_val_ptr))
| Instruction::FuelVm(FuelVmInstruction::StateStoreWord {
key: src_val_ptr, ..
})
| Instruction::FuelVm(FuelVmInstruction::StateLoadQuadWord {
key: src_val_ptr, ..
})
| Instruction::FuelVm(FuelVmInstruction::StateClear {
key: src_val_ptr, ..
}) => get_symbols(context, *src_val_ptr).to_vec(),
Instruction::FuelVm(FuelVmInstruction::StateStoreQuadWord {
stored_val: memopd1,
key: memopd2,
..
})
| Instruction::FuelVm(FuelVmInstruction::Smo {
recipient: memopd1,
message: memopd2,
..
}) => get_symbols(context, *memopd1)
.iter()
.cloned()
.chain(get_symbols(context, *memopd2).iter().cloned())
.collect(),
Instruction::Store { dst_val_ptr: _, .. } => vec![],
Instruction::FuelVm(FuelVmInstruction::Gtf { .. })
| Instruction::FuelVm(FuelVmInstruction::ReadRegister(_))
| Instruction::FuelVm(FuelVmInstruction::Revert(_)) => vec![],
Instruction::FuelVm(FuelVmInstruction::WideUnaryOp { arg, .. }) => {
get_symbols(context, *arg).to_vec()
}
Instruction::FuelVm(FuelVmInstruction::WideBinaryOp { arg1, arg2, .. })
| Instruction::FuelVm(FuelVmInstruction::WideCmpOp { arg1, arg2, .. }) => {
get_symbols(context, *arg1)
.iter()
.cloned()
.chain(get_symbols(context, *arg2).iter().cloned())
.collect()
}
Instruction::FuelVm(FuelVmInstruction::WideModularOp {
arg1, arg2, arg3, ..
}) => get_symbols(context, *arg1)
.iter()
.cloned()
.chain(get_symbols(context, *arg2).iter().cloned())
.chain(get_symbols(context, *arg3).iter().cloned())
.collect(),
}
}
fn get_stored_symbols(context: &Context, val: Value) -> Vec<Symbol> {
match val.get_instruction(context).unwrap() {
Instruction::UnaryOp { .. }
| Instruction::BinaryOp { .. }
| Instruction::BitCast(_, _)
| Instruction::Branch(_)
| Instruction::ConditionalBranch { .. }
| Instruction::Cmp(_, _, _)
| Instruction::Nop
| Instruction::PtrToInt(_, _)
| Instruction::Ret(_, _)
| Instruction::CastPtr(_, _)
| Instruction::GetLocal(_)
| Instruction::GetElemPtr { .. }
| Instruction::IntToPtr(_, _) => vec![],
Instruction::ContractCall { params, .. } => get_symbols(context, *params),
Instruction::Call(_, args) => args
.iter()
.flat_map(|val| get_symbols(context, *val).to_vec())
.collect(),
Instruction::AsmBlock(_, args) => args
.iter()
.filter_map(|val| {
val.initializer
.map(|val| get_symbols(context, val).to_vec())
})
.flatten()
.collect(),
Instruction::MemCopyBytes { dst_val_ptr, .. }
| Instruction::MemCopyVal { dst_val_ptr, .. }
| Instruction::Store { dst_val_ptr, .. } => get_symbols(context, *dst_val_ptr).to_vec(),
Instruction::Load(_) => vec![],
Instruction::FuelVm(vmop) => match vmop {
FuelVmInstruction::Gtf { .. }
| FuelVmInstruction::Log { .. }
| FuelVmInstruction::ReadRegister(_)
| FuelVmInstruction::Revert(_)
| FuelVmInstruction::Smo { .. }
| FuelVmInstruction::StateClear { .. } => vec![],
FuelVmInstruction::StateLoadQuadWord { load_val, .. } => {
get_symbols(context, *load_val).to_vec()
}
FuelVmInstruction::StateLoadWord(_) | FuelVmInstruction::StateStoreWord { .. } => {
vec![]
}
FuelVmInstruction::StateStoreQuadWord { stored_val: _, .. } => vec![],
FuelVmInstruction::WideUnaryOp { result, .. } => get_symbols(context, *result).to_vec(),
FuelVmInstruction::WideBinaryOp { result, .. } => {
get_symbols(context, *result).to_vec()
}
FuelVmInstruction::WideModularOp { result, .. } => {
get_symbols(context, *result).to_vec()
}
FuelVmInstruction::WideCmpOp { .. } => vec![],
},
}
}
fn is_removable_store(
context: &Context,
val: Value,
num_symbol_uses: &HashMap<Symbol, u32>,
escaped_symbols: &EscapedSymbols,
) -> bool {
match val.get_instruction(context).unwrap() {
Instruction::MemCopyBytes { dst_val_ptr, .. }
| Instruction::MemCopyVal { dst_val_ptr, .. }
| Instruction::Store { dst_val_ptr, .. } => {
let syms = get_symbols(context, *dst_val_ptr);
syms.iter().all(|sym| {
!escaped_symbols.contains(sym)
&& num_symbol_uses.get(sym).map_or(0, |uses| *uses) == 0
})
}
_ => false,
}
}
pub fn dce(
context: &mut Context,
analyses: &AnalysisResults,
function: Function,
) -> Result<bool, IrError> {
let escaped_symbols: &EscapedSymbols = analyses.get_analysis_result(function);
let mut num_uses: HashMap<Value, u32> = HashMap::new();
let mut num_local_uses: HashMap<LocalVar, u32> = HashMap::new();
let mut num_symbol_uses: HashMap<Symbol, u32> = HashMap::new();
let mut stores_of_sym: HashMap<Symbol, Vec<Value>> = HashMap::new();
for sym in function
.args_iter(context)
.flat_map(|arg| get_symbols(context, arg.1))
{
num_symbol_uses
.entry(sym)
.and_modify(|count| *count += 1)
.or_insert(1);
}
for (_block, inst) in function.instruction_iter(context) {
for sym in get_loaded_symbols(context, inst) {
num_symbol_uses
.entry(sym)
.and_modify(|count| *count += 1)
.or_insert(1);
}
for stored_sym in get_stored_symbols(context, inst) {
stores_of_sym
.entry(stored_sym)
.and_modify(|stores| stores.push(inst))
.or_insert(vec![inst]);
}
let inst = inst.get_instruction(context).unwrap();
if let Instruction::GetLocal(local) = inst {
num_local_uses
.entry(*local)
.and_modify(|count| *count += 1)
.or_insert(1);
}
let opds = inst.get_operands();
for v in opds {
match context.values[v.0].value {
ValueDatum::Instruction(_) => {
num_uses
.entry(v)
.and_modify(|count| *count += 1)
.or_insert(1);
}
ValueDatum::Configurable(_) | ValueDatum::Constant(_) | ValueDatum::Argument(_) => {
}
}
}
}
let mut worklist = function
.instruction_iter(context)
.filter_map(|(_block, inst)| {
(num_uses.get(&inst).is_none()
|| is_removable_store(context, inst, &num_symbol_uses, escaped_symbols))
.then_some(inst)
})
.collect::<Vec<_>>();
let mut modified = false;
let mut cemetery = FxHashSet::default();
while let Some(dead) = worklist.pop() {
if !can_eliminate_instruction(context, dead, &num_symbol_uses, escaped_symbols)
|| cemetery.contains(&dead)
{
continue;
}
let opds = dead.get_instruction(context).unwrap().get_operands();
for v in opds {
match context.values[v.0].value {
ValueDatum::Instruction(_) => {
let nu = num_uses.get_mut(&v).unwrap();
*nu -= 1;
if *nu == 0 {
worklist.push(v);
}
}
ValueDatum::Configurable(_) | ValueDatum::Constant(_) | ValueDatum::Argument(_) => {
}
}
}
for sym in get_loaded_symbols(context, dead) {
let nu = num_symbol_uses.get_mut(&sym).unwrap();
*nu -= 1;
if *nu == 0 {
for store in stores_of_sym.get(&sym).unwrap_or(&vec![]) {
worklist.push(*store);
}
}
}
cemetery.insert(dead);
if let ValueDatum::Instruction(Instruction::GetLocal(local)) = context.values[dead.0].value
{
let count = num_local_uses.get_mut(&local).unwrap();
*count -= 1;
}
modified = true;
}
for block in function.block_iter(context) {
block.remove_instructions(context, |inst| cemetery.contains(&inst));
}
let local_removals: Vec<_> = function
.locals_iter(context)
.filter_map(|(name, local)| {
(num_local_uses.get(local).cloned().unwrap_or(0) == 0).then_some(name.clone())
})
.collect();
if !local_removals.is_empty() {
modified = true;
function.remove_locals(context, &local_removals);
}
Ok(modified)
}
pub fn func_dce(
context: &mut Context,
_: &AnalysisResults,
module: Module,
) -> Result<bool, IrError> {
let entry_fns = module
.function_iter(context)
.filter(|func| func.is_entry(context))
.collect::<Vec<_>>();
fn grow_called_function_set(
context: &Context,
caller: Function,
called_set: &mut HashSet<Function>,
) {
if called_set.insert(caller) {
for func in caller
.instruction_iter(context)
.filter_map(|(_block, ins_value)| {
ins_value
.get_instruction(context)
.and_then(|ins| match ins {
Instruction::Call(f, _args) => Some(f),
_otherwise => None,
})
})
{
grow_called_function_set(context, *func, called_set);
}
}
}
let mut called_fns: HashSet<Function> = HashSet::new();
for entry_fn in entry_fns {
grow_called_function_set(context, entry_fn, &mut called_fns);
}
let dead_fns = module
.function_iter(context)
.filter(|f| !called_fns.contains(f))
.collect::<Vec<_>>();
let modified = !dead_fns.is_empty();
for dead_fn in dead_fns {
module.remove_function(context, &dead_fn);
}
Ok(modified)
}