use super::HashMap;
use crate::frontend::FunctionBuilder;
use alloc::vec::Vec;
use core::convert::TryFrom;
use cranelift_codegen::ir::condcodes::IntCC;
use cranelift_codegen::ir::*;
use log::debug;
// Key type for switch cases. It is 128 bits wide; indices that do not fit in a `u64` are
// materialised in this file as two `I64` constants joined with `iconcat`.
type EntryIndex = u128;
/// Collects `(index, block)` cases and lowers them into branch instructions.
///
/// Cases are registered with `set_entry` and turned into code by `emit`, which consumes
/// the switch.
#[derive(Debug, Default)]
pub struct Switch {
// Case index -> destination block. `set_entry` asserts that an index is inserted only once.
cases: HashMap<EntryIndex, Block>,
}
impl Switch {
/// Creates a switch that has no cases registered yet.
pub fn new() -> Self {
    let cases = HashMap::new();
    Self { cases }
}
/// Registers `block` as the destination for the case `index`.
///
/// # Panics
///
/// Panics if `index` already had a destination registered.
pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
    // The insertion happens first; a displaced previous value means a duplicate key.
    if self.cases.insert(index, block).is_some() {
        panic!("Tried to set the same entry {} twice", index);
    }
}
pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
&self.cases
}
fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
debug!("build_contiguous_case_ranges before: {:#?}", self.cases);
let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
cases.sort_by_key(|&(index, _)| index);
let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
let mut last_index = None;
for (index, block) in cases {
match last_index {
None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),
Some(last_index) => {
if index > last_index + 1 {
contiguous_case_ranges.push(ContiguousCaseRange::new(index));
}
}
}
contiguous_case_ranges
.last_mut()
.unwrap()
.blocks
.push(block);
last_index = Some(index);
}
debug!(
"build_contiguous_case_ranges after: {:#?}",
contiguous_case_ranges
);
contiguous_case_ranges
}
/// Emits a binary search over the contiguous case ranges.
///
/// Groups of at most three ranges are handed to `build_search_branches`. Larger groups are
/// halved: an unsigned `>=` comparison against the first index of the upper half picks the
/// half to continue in, and each half gets its own freshly created (and sealed) block.
///
/// Returns the `(first_index, jt_block, blocks)` triples recorded by
/// `build_search_branches`, i.e. the jump tables that still have to be emitted.
fn build_search_tree(
    bx: &mut FunctionBuilder,
    val: Value,
    otherwise: Block,
    contiguous_case_ranges: Vec<ContiguousCaseRange>,
) -> Vec<(EntryIndex, Block, Vec<Block>)> {
    let mut pending_tables = Vec::new();

    // Worklist of (block to emit into, ranges to dispatch over). `None` means "keep emitting
    // into whatever block the builder is currently positioned at", which is only the case
    // for the initial item.
    let mut worklist: Vec<(Option<Block>, Vec<ContiguousCaseRange>)> = Vec::new();
    worklist.push((None, contiguous_case_ranges));

    while let Some((target, mut ranges)) = worklist.pop() {
        if let Some(target) = target {
            bx.switch_to_block(target);
        }

        if ranges.len() <= 3 {
            Self::build_search_branches(bx, val, otherwise, ranges, &mut pending_tables);
            continue;
        }

        let middle = ranges.len() / 2;
        let upper = ranges.split_off(middle);
        let lower = ranges;

        let lower_block = bx.create_block();
        let upper_block = bx.create_block();
        let pivot = upper[0].first_index;
        let goes_upper = icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, pivot);
        bx.ins().brnz(goes_upper, upper_block, &[]);
        bx.ins().jump(lower_block, &[]);
        // Both halves have exactly the one predecessor emitted above.
        bx.seal_block(lower_block);
        bx.seal_block(upper_block);

        // Pushed lower-then-upper, so the upper half is popped and emitted first.
        worklist.push((Some(lower_block), lower));
        worklist.push((Some(upper_block), upper));
    }

    pending_tables
}
/// Emits a linear chain of tests for a small group of contiguous case ranges.
///
/// Ranges are visited from the highest `first_index` to the lowest:
///
/// * a range with a single block becomes an equality test (`brz` when the index is 0);
/// * a range with several blocks becomes an unsigned `>=` test that branches to a new
///   jump-table block, which is recorded in `cases_and_jt_blocks` for later emission;
/// * a range with several blocks that starts at index 0 needs no test at all and ends the
///   chain with an unconditional jump to its jump-table block.
///
/// Unless the chain ended that way, it finishes with a jump to `otherwise`.
fn build_search_branches(
    bx: &mut FunctionBuilder,
    val: Value,
    otherwise: Block,
    contiguous_case_ranges: Vec<ContiguousCaseRange>,
    cases_and_jt_blocks: &mut Vec<(EntryIndex, Block, Vec<Block>)>,
) {
    let mut emitted_branch = false;

    for range in contiguous_case_ranges.into_iter().rev() {
        let ContiguousCaseRange {
            first_index,
            blocks,
        } = range;
        let is_single = blocks.len() == 1;

        if !is_single && first_index == 0 {
            // The current block is terminated by this jump, so there is neither a
            // fallthrough block nor a trailing jump to `otherwise`.
            let jt_block = bx.create_block();
            bx.ins().jump(jt_block, &[]);
            bx.seal_block(jt_block);
            cases_and_jt_blocks.push((first_index, jt_block, blocks));
            return;
        }

        if emitted_branch {
            // A conditional branch was already emitted: continue in a fresh block reached
            // by an explicit jump.
            let next_block = bx.create_block();
            bx.ins().jump(next_block, &[]);
            bx.seal_block(next_block);
            bx.switch_to_block(next_block);
        }

        if is_single {
            let target = blocks[0];
            if first_index == 0 {
                bx.ins().brz(val, target, &[]);
            } else {
                let is_match = icmp_imm_u128(bx, IntCC::Equal, val, first_index);
                bx.ins().brnz(is_match, target, &[]);
            }
        } else {
            let jt_block = bx.create_block();
            let at_or_above =
                icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
            bx.ins().brnz(at_or_above, jt_block, &[]);
            bx.seal_block(jt_block);
            cases_and_jt_blocks.push((first_index, jt_block, blocks));
        }

        emitted_branch = true;
    }

    bx.ins().jump(otherwise, &[]);
}
/// Emits the jump tables recorded while building the search tree.
///
/// For every `(first_index, jt_block, blocks)` triple (processed in reverse order) this
/// creates a jump table holding `blocks`, switches to `jt_block`, subtracts `first_index`
/// from `val` so that the first case maps to table slot 0, and ends the block with a
/// `br_table` whose default destination is `otherwise`.
///
/// When the rebased discriminant is wider than 64 bits, an extra block is inserted: values
/// that compare unsigned-greater-than the `u64::max_value() as i64` immediate branch to
/// `otherwise`, and the remaining ones are reduced to `I64` before the `br_table`.
///
/// # Panics
///
/// Panics if the number of blocks in a range does not fit in a `u64`.
fn build_jump_tables(
    bx: &mut FunctionBuilder,
    val: Value,
    otherwise: Block,
    cases_and_jt_blocks: Vec<(EntryIndex, Block, Vec<Block>)>,
) {
    for (first_index, jt_block, blocks) in cases_and_jt_blocks.into_iter().rev() {
        assert!(
            u64::try_from(blocks.len()).is_ok(),
            "Jump tables bigger than 2^64-1 are not yet supported"
        );

        let mut jt_data = JumpTableData::new();
        for block in blocks {
            jt_data.push_entry(block);
        }
        let jump_table = bx.create_jump_table(jt_data);

        bx.switch_to_block(jt_block);

        // Rebase the value so that `first_index` becomes table slot 0.
        let discr = if first_index == 0 {
            val
        } else if let Ok(first_index) = u64::try_from(first_index) {
            // Adding the wrapped negation is a subtraction with an immediate operand.
            bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
        } else {
            // The index needs more than 64 bits: build it from its two 64-bit halves.
            let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);
            let lsb = bx.ins().iconst(types::I64, lsb as i64);
            let msb = bx.ins().iconst(types::I64, msb as i64);
            let index = bx.ins().iconcat(lsb, msb);
            bx.ins().isub(val, index)
        };

        let discr = if bx.func.dfg.value_type(discr).bits() > 64 {
            let new_block = bx.create_block();
            // NOTE(review): `u64::max_value() as i64` is -1; confirm how `icmp_imm` widens a
            // negative immediate for operands wider than 64 bits.
            let bigger_than_u64 =
                bx.ins()
                    .icmp_imm(IntCC::UnsignedGreaterThan, discr, u64::max_value() as i64);
            bx.ins().brnz(bigger_than_u64, otherwise, &[]);
            bx.ins().jump(new_block, &[]);
            // The jump above is the only predecessor this block will ever get. Seal it here,
            // like every other block generated in this file, so that callers are not left
            // with an unsealed block they have no handle to.
            bx.seal_block(new_block);
            bx.switch_to_block(new_block);

            bx.ins().ireduce(types::I64, discr)
        } else {
            discr
        };

        bx.ins().br_table(discr, otherwise, jump_table);
    }
}
/// Lowers the switch into instructions appended through `bx`.
///
/// * `bx` - builder positioned where the switch should be emitted.
/// * `val` - the value to switch on; `I8` and `I16` values are zero-extended to `I32` first.
/// * `otherwise` - destination used when no registered case matches.
///
/// The switch is consumed by this call.
pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
    let val_ty = bx.func.dfg.value_type(val);
    let val = if val_ty == types::I8 || val_ty == types::I16 {
        bx.ins().uextend(types::I32, val)
    } else {
        val
    };

    let ranges = self.collect_contiguous_case_ranges();
    let pending_tables = Self::build_search_tree(bx, val, otherwise, ranges);
    Self::build_jump_tables(bx, val, otherwise, pending_tables);
}
}
/// Emits a comparison of `x` against the constant `y` using condition `cond`.
///
/// A constant that fits in a `u64` is passed as the immediate of `icmp_imm` (reinterpreted
/// as `i64`). Otherwise the constant is built from two `I64` halves joined with `iconcat`
/// and compared with a plain `icmp`. Returns the value holding the comparison result.
fn icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value {
    match u64::try_from(y) {
        Ok(small) => bx.ins().icmp_imm(cond, x, small as i64),
        Err(_) => {
            let low_half = y as u64;
            let high_half = (y >> 64) as u64;
            let low = bx.ins().iconst(types::I64, low_half as i64);
            let high = bx.ins().iconst(types::I64, high_half as i64);
            let wide = bx.ins().iconcat(low, high);
            bx.ins().icmp(cond, x, wide)
        }
    }
}
/// A run of cases whose indices are consecutive.
#[derive(Debug)]
struct ContiguousCaseRange {
// Index of the first case in the run.
first_index: EntryIndex,
// Destination blocks of the run; `blocks[i]` belongs to index `first_index + i`.
blocks: Vec<Block>,
}
impl ContiguousCaseRange {
    /// Creates a range that starts at `first_index` and does not contain any block yet.
    fn new(first_index: EntryIndex) -> Self {
        let blocks = Vec::new();
        Self {
            first_index,
            blocks,
        }
    }
}
// Unit tests: each test emits a `Switch` into a fresh function and, except for
// `switch_seal_generated_blocks`, compares the function's textual form with an expected string.
#[cfg(test)]
mod tests {
use super::*;
use crate::frontend::FunctionBuilderContext;
use alloc::string::ToString;
use cranelift_codegen::ir::Function;
// Builds a function that switches on an `i8` constant 0, with one freshly created block per
// listed index and block number `$default` as the fallback, and evaluates to the function
// text with the `function u0:0() fast {` header and the closing brace trimmed off.
// NOTE(review): the expected strings in the tests below carry no leading indentation or blank
// lines; confirm they match the textual IR output verbatim.
macro_rules! setup {
($default:expr, [$($index:expr,)*]) => {{
let mut func = Function::new();
let mut func_ctx = FunctionBuilderContext::new();
{
let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
let block = bx.create_block();
bx.switch_to_block(block);
let val = bx.ins().iconst(types::I8, 0);
let mut switch = Switch::new();
$(
let block = bx.create_block();
switch.set_entry($index, block);
)*
switch.emit(&mut bx, val, Block::with_number($default).unwrap());
}
func
.to_string()
.trim_start_matches("function u0:0() fast {\n")
.trim_end_matches("\n}\n")
.to_string()
}};
}
// A single case at index 0 is expected to lower to a `brz`.
#[test]
fn switch_zero() {
let func = setup!(0, [0,]);
assert_eq!(
func,
"block0:
v0 = iconst.i8 0
v1 = uextend.i32 v0
brz v1, block1
jump block0"
);
}
// A single non-zero case is expected to lower to an equality compare plus `brnz`.
#[test]
fn switch_single() {
let func = setup!(0, [1,]);
assert_eq!(
func,
"block0:
v0 = iconst.i8 0
v1 = uextend.i32 v0
v2 = icmp_imm eq v1, 1
brnz v2, block1
jump block0"
);
}
// Two consecutive cases starting at 0 are expected to become one jump table reached by an
// unconditional jump.
#[test]
fn switch_bool() {
let func = setup!(0, [0, 1,]);
assert_eq!(
func,
" jt0 = jump_table [block1, block2]
block0:
v0 = iconst.i8 0
v1 = uextend.i32 v0
jump block3
block3:
br_table.i32 v1, block0, jt0"
);
}
// Two non-adjacent cases are expected to be tested one after the other, highest index first.
#[test]
fn switch_two_gap() {
let func = setup!(0, [0, 2,]);
assert_eq!(
func,
"block0:
v0 = iconst.i8 0
v1 = uextend.i32 v0
v2 = icmp_imm eq v1, 2
brnz v2, block2
jump block3
block3:
brz.i32 v1, block1
jump block0"
);
}
// Four contiguous ranges ([0,1], [5], [7], [10..=12]) exceed the three-range limit, so a
// binary split on `>= 7` is expected before the per-range tests.
#[test]
fn switch_many() {
let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
assert_eq!(
func,
" jt0 = jump_table [block1, block2]
jt1 = jump_table [block5, block6, block7]
block0:
v0 = iconst.i8 0
v1 = uextend.i32 v0
v2 = icmp_imm uge v1, 7
brnz v2, block9
jump block8
block9:
v3 = icmp_imm.i32 uge v1, 10
brnz v3, block10
jump block11
block11:
v4 = icmp_imm.i32 eq v1, 7
brnz v4, block4
jump block0
block8:
v5 = icmp_imm.i32 eq v1, 5
brnz v5, block3
jump block12
block12:
br_table.i32 v1, block0, jt0
block10:
v6 = iadd_imm.i32 v1, -10
br_table v6, block0, jt1"
);
}
// Uses the bit pattern of `i64::MIN` as a case index.
#[test]
fn switch_min_index_value() {
let func = setup!(0, [::core::i64::MIN as u64 as u128, 1,]);
assert_eq!(
func,
"block0:
v0 = iconst.i8 0
v1 = uextend.i32 v0
v2 = icmp_imm eq v1, 0x8000_0000_0000_0000
brnz v2, block1
jump block3
block3:
v3 = icmp_imm.i32 eq v1, 1
brnz v3, block2
jump block0"
);
}
// Uses `i64::MAX` as a case index.
#[test]
fn switch_max_index_value() {
let func = setup!(0, [::core::i64::MAX as u64 as u128, 1,]);
assert_eq!(
func,
"block0:
v0 = iconst.i8 0
v1 = uextend.i32 v0
v2 = icmp_imm eq v1, 0x7fff_ffff_ffff_ffff
brnz v2, block1
jump block3
block3:
v3 = icmp_imm.i32 eq v1, 1
brnz v3, block2
jump block0"
)
}
// The all-ones 64-bit index sorts last, so it is expected to be tested first, followed by a
// jump table for the indices 0 and 1.
#[test]
fn switch_optimal_codegen() {
let func = setup!(0, [-1i64 as u64 as u128, 0, 1,]);
assert_eq!(
func,
" jt0 = jump_table [block2, block3]
block0:
v0 = iconst.i8 0
v1 = uextend.i32 v0
v2 = icmp_imm eq v1, -1
brnz v2, block1
jump block4
block4:
br_table.i32 v1, block0, jt0"
);
}
// Emits a switch over ten keys, then seals and fills only the blocks created by the test
// itself before calling `finalize`.
// NOTE(review): presumably this relies on `finalize` rejecting blocks that `emit` created but
// left unsealed - confirm against `FunctionBuilder::finalize`.
#[test]
fn switch_seal_generated_blocks() {
let keys = [0, 1, 2, 10, 11, 12, 20, 30, 40, 50];
let mut func = Function::new();
let mut builder_ctx = FunctionBuilderContext::new();
let mut builder = FunctionBuilder::new(&mut func, &mut builder_ctx);
let root_block = builder.create_block();
let default_block = builder.create_block();
let mut switch = Switch::new();
let case_blocks = keys
.iter()
.map(|key| {
let block = builder.create_block();
switch.set_entry(*key, block);
block
})
.collect::<Vec<_>>();
builder.seal_block(root_block);
builder.switch_to_block(root_block);
let val = builder.ins().iconst(types::I32, 1);
switch.emit(&mut builder, val, default_block);
for &block in case_blocks.iter().chain(std::iter::once(&default_block)) {
builder.seal_block(block);
builder.switch_to_block(block);
builder.ins().return_(&[]);
}
builder.finalize();
}
// Switches on an `I128` value: the expected output contains the extra overflow-check block
// and the `ireduce` to `i64` in front of the `br_table`.
#[test]
fn switch_128bit() {
let mut func = Function::new();
let mut func_ctx = FunctionBuilderContext::new();
{
let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
let block0 = bx.create_block();
bx.switch_to_block(block0);
let val = bx.ins().iconst(types::I128, 0);
let mut switch = Switch::new();
let block1 = bx.create_block();
switch.set_entry(1, block1);
let block2 = bx.create_block();
switch.set_entry(0, block2);
let block3 = bx.create_block();
switch.emit(&mut bx, val, block3);
}
let func = func
.to_string()
.trim_start_matches("function u0:0() fast {\n")
.trim_end_matches("\n}\n")
.to_string();
assert_eq!(
func,
" jt0 = jump_table [block2, block1]
block0:
v0 = iconst.i128 0
jump block4
block4:
v1 = icmp_imm.i128 ugt v0, -1
brnz v1, block3
jump block5
block5:
v2 = ireduce.i64 v0
br_table v2, block3, jt0"
);
}
}