use crate::cfg::CFGInfo;
use crate::entity::PerEntity;
use crate::ir::{Block, FunctionBody, Value, ValueDef};
use std::collections::{BTreeSet, HashMap, HashSet};
pub(crate) fn run(body: &mut FunctionBody, cut_blocks: Option<HashSet<Block>>, cfg: &CFGInfo) {
MaxSSAPass::new(cut_blocks).run(body, cfg);
}
struct MaxSSAPass {
cut_blocks: Option<HashSet<Block>>,
new_args: PerEntity<Block, Vec<Value>>,
value_map: HashMap<(Block, Value), Value>,
}
impl MaxSSAPass {
fn new(cut_blocks: Option<HashSet<Block>>) -> Self {
Self {
cut_blocks,
new_args: PerEntity::default(),
value_map: HashMap::new(),
}
}
fn run(mut self, body: &mut FunctionBody, cfg: &CFGInfo) {
for block in body.blocks.iter() {
self.visit(body, cfg, block);
}
self.update(body);
}
fn visit(&mut self, body: &mut FunctionBody, cfg: &CFGInfo, block: Block) {
let mut uses = BTreeSet::default();
for &inst in &body.blocks[block].insts {
match &body.values[inst] {
&ValueDef::Operator(_, args, _) => {
for &arg in &body.arg_pool[args] {
let arg = body.resolve_alias(arg);
uses.insert(arg);
}
}
&ValueDef::PickOutput(value, ..) => {
let value = body.resolve_alias(value);
uses.insert(value);
}
_ => {}
}
}
body.blocks[block].terminator.visit_uses(|u| {
let u = body.resolve_alias(u);
uses.insert(u);
});
for u in uses {
self.visit_use(body, cfg, block, u);
}
}
fn visit_use(&mut self, body: &mut FunctionBody, cfg: &CFGInfo, block: Block, value: Value) {
if self.value_map.contains_key(&(block, value)) {
return;
}
if cfg.def_block[value] == block {
return;
}
self.new_args[block].push(value);
let ty = body.values[value].ty(&body.type_pool).unwrap();
let blockparam = body.add_blockparam(block, ty);
self.value_map.insert((block, value), blockparam);
for i in 0..body.blocks[block].preds.len() {
let pred = body.blocks[block].preds[i];
self.visit_use(body, cfg, pred, value);
}
if !self.is_cut_block(block) {
if let Some(pred_value) = iter_all_same(
body.blocks[block]
.preds
.iter()
.map(|&pred| *self.value_map.get(&(pred, value)).unwrap_or(&value))
.filter(|&val| val != blockparam),
) {
body.blocks[block].params.pop();
self.new_args[block].pop();
body.values[blockparam] = ValueDef::Alias(pred_value);
self.value_map.insert((block, value), pred_value);
}
}
}
fn is_cut_block(&self, block: Block) -> bool {
self.cut_blocks
.as_ref()
.map(|cut_blocks| cut_blocks.contains(&block))
.unwrap_or(true)
}
fn update_branch_args(&mut self, body: &mut FunctionBody) {
for (block, blockdata) in body.blocks.entries_mut() {
blockdata.terminator.update_targets(|target| {
for &new_arg in &self.new_args[target.block] {
let actual_value = self
.value_map
.get(&(block, new_arg))
.copied()
.unwrap_or(new_arg);
target.args.push(actual_value);
}
});
}
}
fn update_uses(&mut self, body: &mut FunctionBody, block: Block) {
let resolve = |body: &FunctionBody, value: Value| {
let value = body.resolve_alias(value);
self.value_map
.get(&(block, value))
.copied()
.unwrap_or(value)
};
for i in 0..body.blocks[block].insts.len() {
let inst = body.blocks[block].insts[i];
let mut def = std::mem::take(&mut body.values[inst]);
match &mut def {
ValueDef::Operator(_, args, _) => {
for i in 0..args.len() {
let val = body.arg_pool[*args][i];
let val = resolve(body, val);
body.arg_pool[*args][i] = val;
}
}
ValueDef::PickOutput(value, ..) => {
*value = resolve(body, *value);
}
ValueDef::Alias(_) => {
def = ValueDef::None;
}
_ => {}
}
body.values[inst] = def;
}
let mut term = std::mem::take(&mut body.blocks[block].terminator);
term.update_uses(|u| {
*u = resolve(body, *u);
});
body.blocks[block].terminator = term;
}
fn update(&mut self, body: &mut FunctionBody) {
self.update_branch_args(body);
for block in body.blocks.iter() {
self.update_uses(body, block);
}
}
}
fn iter_all_same<Item: PartialEq + Eq + Copy, I: Iterator<Item = Item>>(iter: I) -> Option<Item> {
let mut item = None;
for val in iter {
if *item.get_or_insert(val) != val {
return None;
}
}
item
}