cranelift_frontend/
switch.rs

1use super::HashMap;
2use crate::frontend::FunctionBuilder;
3use alloc::vec::Vec;
4use cranelift_codegen::ir::condcodes::IntCC;
5use cranelift_codegen::ir::*;
6
/// Index of a switch case. It is a `u128`, so indices for switched values as wide as `I128`
/// can be expressed.
type EntryIndex = u128;
8
/// A builder that lowers a switch over an integer value to Cranelift IR.
///
/// Unlike with `br_table`, `Switch` cases may be sparse or non-0-based.
/// The generated code uses branches, jump tables, or a combination of both.
///
/// # Example
///
/// ```rust
/// # use cranelift_codegen::ir::types::*;
/// # use cranelift_codegen::ir::{UserFuncName, Function, Signature, InstBuilder};
/// # use cranelift_codegen::isa::CallConv;
/// # use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext, Switch};
/// #
/// # let mut sig = Signature::new(CallConv::SystemV);
/// # let mut fn_builder_ctx = FunctionBuilderContext::new();
/// # let mut func = Function::with_name_signature(UserFuncName::user(0, 0), sig);
/// # let mut builder = FunctionBuilder::new(&mut func, &mut fn_builder_ctx);
/// #
/// # let entry = builder.create_block();
/// # builder.switch_to_block(entry);
/// #
/// let block0 = builder.create_block();
/// let block1 = builder.create_block();
/// let block2 = builder.create_block();
/// let fallback = builder.create_block();
///
/// let val = builder.ins().iconst(I32, 1);
///
/// let mut switch = Switch::new();
/// switch.set_entry(0, block0);
/// switch.set_entry(1, block1);
/// switch.set_entry(7, block2);
/// switch.emit(&mut builder, val, fallback);
/// ```
#[derive(Debug, Default)]
pub struct Switch {
    /// Target block of every case, keyed by its entry index.
    cases: HashMap<EntryIndex, Block>,
}
45
46impl Switch {
47    /// Create a new empty switch
48    pub fn new() -> Self {
49        Self {
50            cases: HashMap::new(),
51        }
52    }
53
54    /// Set a switch entry
55    pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
56        let prev = self.cases.insert(index, block);
57        assert!(prev.is_none(), "Tried to set the same entry {index} twice");
58    }
59
60    /// Get a reference to all existing entries
61    pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
62        &self.cases
63    }
64
65    /// Turn the `cases` `HashMap` into a list of `ContiguousCaseRange`s.
66    ///
67    /// # Postconditions
68    ///
69    /// * Every entry will be represented.
70    /// * The `ContiguousCaseRange`s will not overlap.
71    /// * Between two `ContiguousCaseRange`s there will be at least one entry index.
72    /// * No `ContiguousCaseRange`s will be empty.
73    fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
74        log::trace!("build_contiguous_case_ranges before: {:#?}", self.cases);
75        let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
76        cases.sort_by_key(|&(index, _)| index);
77
78        let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
79        let mut last_index = None;
80        for (index, block) in cases {
81            match last_index {
82                None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),
83                Some(last_index) => {
84                    if index > last_index + 1 {
85                        contiguous_case_ranges.push(ContiguousCaseRange::new(index));
86                    }
87                }
88            }
89            contiguous_case_ranges
90                .last_mut()
91                .unwrap()
92                .blocks
93                .push(block);
94            last_index = Some(index);
95        }
96
97        log::trace!(
98            "build_contiguous_case_ranges after: {:#?}",
99            contiguous_case_ranges
100        );
101
102        contiguous_case_ranges
103    }
104
105    /// Binary search for the right `ContiguousCaseRange`.
106    fn build_search_tree<'a>(
107        bx: &mut FunctionBuilder,
108        val: Value,
109        otherwise: Block,
110        contiguous_case_ranges: &'a [ContiguousCaseRange],
111    ) {
112        // If no switch cases were added to begin with, we can just emit `jump otherwise`.
113        if contiguous_case_ranges.is_empty() {
114            bx.ins().jump(otherwise, &[]);
115            return;
116        }
117
118        // Avoid allocation in the common case
119        if contiguous_case_ranges.len() <= 3 {
120            Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
121            return;
122        }
123
124        let mut stack = Vec::new();
125        stack.push((None, contiguous_case_ranges));
126
127        while let Some((block, contiguous_case_ranges)) = stack.pop() {
128            if let Some(block) = block {
129                bx.switch_to_block(block);
130            }
131
132            if contiguous_case_ranges.len() <= 3 {
133                Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
134            } else {
135                let split_point = contiguous_case_ranges.len() / 2;
136                let (left, right) = contiguous_case_ranges.split_at(split_point);
137
138                let left_block = bx.create_block();
139                let right_block = bx.create_block();
140
141                let first_index = right[0].first_index;
142                let should_take_right_side =
143                    icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
144                bx.ins()
145                    .brif(should_take_right_side, right_block, &[], left_block, &[]);
146
147                bx.seal_block(left_block);
148                bx.seal_block(right_block);
149
150                stack.push((Some(left_block), left));
151                stack.push((Some(right_block), right));
152            }
153        }
154    }
155
156    /// Linear search for the right `ContiguousCaseRange`.
157    fn build_search_branches<'a>(
158        bx: &mut FunctionBuilder,
159        val: Value,
160        otherwise: Block,
161        contiguous_case_ranges: &'a [ContiguousCaseRange],
162    ) {
163        for (ix, range) in contiguous_case_ranges.iter().enumerate().rev() {
164            let alternate = if ix == 0 {
165                otherwise
166            } else {
167                bx.create_block()
168            };
169
170            if range.first_index == 0 {
171                assert_eq!(alternate, otherwise);
172
173                if let Some(block) = range.single_block() {
174                    bx.ins().brif(val, otherwise, &[], block, &[]);
175                } else {
176                    Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);
177                }
178            } else {
179                if let Some(block) = range.single_block() {
180                    let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);
181                    bx.ins().brif(is_good_val, block, &[], alternate, &[]);
182                } else {
183                    let is_good_val = icmp_imm_u128(
184                        bx,
185                        IntCC::UnsignedGreaterThanOrEqual,
186                        val,
187                        range.first_index,
188                    );
189                    let jt_block = bx.create_block();
190                    bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);
191                    bx.seal_block(jt_block);
192                    bx.switch_to_block(jt_block);
193                    Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);
194                }
195            }
196
197            if alternate != otherwise {
198                bx.seal_block(alternate);
199                bx.switch_to_block(alternate);
200            }
201        }
202    }
203
204    fn build_jump_table(
205        bx: &mut FunctionBuilder,
206        val: Value,
207        otherwise: Block,
208        first_index: EntryIndex,
209        blocks: &[Block],
210    ) {
211        // There are currently no 128bit systems supported by rustc, but once we do ensure that
212        // we don't silently ignore a part of the jump table for 128bit integers on 128bit systems.
213        assert!(
214            u32::try_from(blocks.len()).is_ok(),
215            "Jump tables bigger than 2^32-1 are not yet supported"
216        );
217
218        let jt_data = JumpTableData::new(
219            bx.func.dfg.block_call(otherwise, &[]),
220            &blocks
221                .iter()
222                .map(|block| bx.func.dfg.block_call(*block, &[]))
223                .collect::<Vec<_>>(),
224        );
225        let jump_table = bx.create_jump_table(jt_data);
226
227        let discr = if first_index == 0 {
228            val
229        } else {
230            if let Ok(first_index) = u64::try_from(first_index) {
231                bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
232            } else {
233                let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);
234                let lsb = bx.ins().iconst(types::I64, lsb as i64);
235                let msb = bx.ins().iconst(types::I64, msb as i64);
236                let index = bx.ins().iconcat(lsb, msb);
237                bx.ins().isub(val, index)
238            }
239        };
240
241        let discr = match bx.func.dfg.value_type(discr).bits() {
242            bits if bits > 32 => {
243                // Check for overflow of cast to u32. This is the max supported jump table entries.
244                let new_block = bx.create_block();
245                let bigger_than_u32 =
246                    bx.ins()
247                        .icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
248                bx.ins()
249                    .brif(bigger_than_u32, otherwise, &[], new_block, &[]);
250                bx.seal_block(new_block);
251                bx.switch_to_block(new_block);
252
253                // Cast to i32, as br_table is not implemented for i64/i128
254                bx.ins().ireduce(types::I32, discr)
255            }
256            bits if bits < 32 => bx.ins().uextend(types::I32, discr),
257            _ => discr,
258        };
259
260        bx.ins().br_table(discr, jump_table);
261    }
262
263    /// Build the switch
264    ///
265    /// # Arguments
266    ///
267    /// * The function builder to emit to
268    /// * The value to switch on
269    /// * The default block
270    pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
271        // Validate that the type of `val` is sufficiently wide to address all cases.
272        let max = self.cases.keys().max().copied().unwrap_or(0);
273        let val_ty = bx.func.dfg.value_type(val);
274        let val_ty_max = val_ty.bounds(false).1;
275        if max > val_ty_max {
276            panic!("The index type {val_ty} does not fit the maximum switch entry of {max}");
277        }
278
279        let contiguous_case_ranges = self.collect_contiguous_case_ranges();
280        Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);
281    }
282}
283
284fn icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value {
285    if bx.func.dfg.value_type(x) != types::I128 {
286        assert!(u64::try_from(y).is_ok());
287        bx.ins().icmp_imm(cond, x, y as i64)
288    } else if let Ok(index) = i64::try_from(y) {
289        bx.ins().icmp_imm(cond, x, index)
290    } else {
291        let (lsb, msb) = (y as u64, (y >> 64) as u64);
292        let lsb = bx.ins().iconst(types::I64, lsb as i64);
293        let msb = bx.ins().iconst(types::I64, msb as i64);
294        let index = bx.ins().iconcat(lsb, msb);
295        bx.ins().icmp(cond, x, index)
296    }
297}
298
/// A run of cases whose entry indices are consecutive, so that only the first index and the
/// target blocks need to be stored.
///
/// For example 10 => block1, 11 => block2, 12 => block7 will be represented as:
///
/// ```plain
/// ContiguousCaseRange {
///     first_index: 10,
///     blocks: vec![Block::from_u32(1), Block::from_u32(2), Block::from_u32(7)]
/// }
/// ```
#[derive(Debug)]
struct ContiguousCaseRange {
    /// The entry index of the first case. Eg. 10 when the entry indexes are 10, 11, 12 and 13.
    first_index: EntryIndex,

    /// The blocks to jump to sorted in ascending order of entry index.
    /// The block at position `i` belongs to the entry index `first_index + i`.
    blocks: Vec<Block>,
}
317
318impl ContiguousCaseRange {
319    fn new(first_index: EntryIndex) -> Self {
320        Self {
321            first_index,
322            blocks: Vec::new(),
323        }
324    }
325
326    /// Returns `Some` block when there is only a single block in this range.
327    fn single_block(&self) -> Option<Block> {
328        if self.blocks.len() == 1 {
329            Some(self.blocks[0])
330        } else {
331            None
332        }
333    }
334}
335
336#[cfg(test)]
337mod tests {
338    use super::*;
339    use crate::frontend::FunctionBuilderContext;
340    use alloc::string::ToString;
341
    /// Build a function that switches on an `i8` constant 0 and return its printed body.
    ///
    /// `$default` is the number of the default block and each `$index` is an entry index; a
    /// fresh block is created for every index in the order given. The enclosing
    /// `function u0:0() fast {` header and the closing brace are stripped from the result.
    macro_rules! setup {
        ($default:expr, [$($index:expr,)*]) => {{
            let mut func = Function::new();
            let mut func_ctx = FunctionBuilderContext::new();
            {
                let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
                let block = bx.create_block();
                bx.switch_to_block(block);
                let val = bx.ins().iconst(types::I8, 0);
                let mut switch = Switch::new();
                // Presumably keeps `switch` counted as mutably used when the index list is
                // empty and no `set_entry` call is expanded.
                let _ = &mut switch;
                $(
                    let block = bx.create_block();
                    switch.set_entry($index, block);
                )*
                switch.emit(&mut bx, val, Block::with_number($default).unwrap());
            }
            func
                .to_string()
                .trim_start_matches("function u0:0() fast {\n")
                .trim_end_matches("\n}\n")
                .to_string()
        }};
    }
366
    /// A switch without entries is a plain jump to the default block.
    #[test]
    fn switch_empty() {
        let func = setup!(42, []);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    jump block42"
        );
    }
377
    /// A single case at index 0 needs no comparison: the value itself is branched on, going
    /// to the default block when it is non-zero.
    #[test]
    fn switch_zero() {
        let func = setup!(0, [0,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    brif v0, block0, block1  ; v0 = 0"
        );
    }
388
    /// A single case at a non-zero index is matched with one equality comparison.
    #[test]
    fn switch_single() {
        let func = setup!(0, [1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 1  ; v0 = 0
    brif v1, block1, block0"
        );
    }
400
    /// The adjacent indices 0 and 1 share one jump table, indexed by the value extended to
    /// `i32`.
    #[test]
    fn switch_bool() {
        let func = setup!(0, [0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0  ; v0 = 0
    br_table v1, block0, [block1, block2]"
        );
    }
412
    /// Indices 0 and 2 are not adjacent, so each is tested on its own: first the equality
    /// check for 2, then the branch on the value for 0.
    #[test]
    fn switch_two_gap() {
        let func = setup!(0, [0, 2,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 2  ; v0 = 0
    brif v1, block2, block3

block3:
    brif.i8 v0, block0, block1  ; v0 = 0"
        );
    }
427
    /// Seven entries forming four contiguous ranges (0-1, 5, 7 and 10-12): the ranges are
    /// first split in two with a `uge 7` comparison and each half is then searched linearly.
    #[test]
    fn switch_many() {
        let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm uge v0, 7  ; v0 = 0
    brif v1, block9, block8

block9:
    v2 = icmp_imm.i8 uge v0, 10  ; v0 = 0
    brif v2, block11, block10

block11:
    v3 = iadd_imm.i8 v0, -10  ; v0 = 0
    v4 = uextend.i32 v3
    br_table v4, block0, [block5, block6, block7]

block10:
    v5 = icmp_imm.i8 eq v0, 7  ; v0 = 0
    brif v5, block4, block0

block8:
    v6 = icmp_imm.i8 eq v0, 5  ; v0 = 0
    brif v6, block3, block12

block12:
    v7 = uextend.i32 v0  ; v0 = 0
    br_table v7, block0, [block1, block2]"
        );
    }
460
    /// The index 128 (the bit pattern of `i8::MIN`) is accepted for an `i8` value; the
    /// comparison immediate is printed as -128.
    #[test]
    fn switch_min_index_value() {
        let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -128  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        );
    }
476
    /// The index 127 (`i8::MAX`) is matched with an equality comparison like any other
    /// isolated index.
    #[test]
    fn switch_max_index_value() {
        let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 127  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        )
    }
492
    /// The indices 255 (the bit pattern of `-1i8`), 0 and 1: 255 is checked with a single
    /// comparison, while 0 and 1 go through a jump table without any rebasing.
    #[test]
    fn switch_optimal_codegen() {
        let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -1  ; v0 = 0
    brif v1, block1, block4

block4:
    v2 = uextend.i32 v0  ; v0 = 0
    br_table v2, block0, [block2, block3]"
        );
    }
508
    /// Emitting a switch whose largest entry index cannot be represented by the type of the
    /// switched value must panic.
    #[test]
    #[should_panic(
        expected = "The index type i8 does not fit the maximum switch entry of 4683743612477887600"
    )]
    fn switch_rejects_small_inputs() {
        // This is a regression test for a bug that we found where we would emit a cmp
        // with a type that was not able to fully represent a large index.
        //
        // See: https://github.com/bytecodealliance/wasmtime/pull/4502#issuecomment-1191961677
        setup!(1, [0x4100_0000_00bf_d470,]);
    }
520
    /// Every block that `Switch::emit` creates on its own must be sealed by it. This is
    /// checked for a small and a larger set of keys on every integer type from `I8` to `I128`.
    #[test]
    fn switch_seal_generated_blocks() {
        let cases = &[vec![0, 1, 2], vec![0, 1, 2, 10, 11, 12, 20, 30, 40, 50]];

        for case in cases {
            for typ in &[types::I8, types::I16, types::I32, types::I64, types::I128] {
                eprintln!("Testing {typ:?} with keys: {case:?}");
                do_case(case, *typ);
            }
        }

        // Builds a function switching on a constant of type `typ` over `keys`, seals only the
        // blocks created by the test itself and then finalizes the builder.
        fn do_case(keys: &[u128], typ: Type) {
            let mut func = Function::new();
            let mut builder_ctx = FunctionBuilderContext::new();
            let mut builder = FunctionBuilder::new(&mut func, &mut builder_ctx);

            let root_block = builder.create_block();
            let default_block = builder.create_block();
            let mut switch = Switch::new();

            let case_blocks = keys
                .iter()
                .map(|key| {
                    let block = builder.create_block();
                    switch.set_entry(*key, block);
                    block
                })
                .collect::<Vec<_>>();

            builder.seal_block(root_block);
            builder.switch_to_block(root_block);

            let val = builder.ins().iconst(typ, 1);
            switch.emit(&mut builder, val, default_block);

            // Terminate the case blocks and the default block so that the function is complete.
            for &block in case_blocks.iter().chain(std::iter::once(&default_block)) {
                builder.seal_block(block);
                builder.switch_to_block(block);
                builder.ins().return_(&[]);
            }

            builder.finalize(); // Will panic if some blocks are not sealed
        }
    }
565
    /// Switching on an `i64`: the value is checked against `0xffff_ffff` (going to the default
    /// block when larger) before it is reduced to `i32` for the jump table.
    #[test]
    fn switch_64bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let mut switch = Switch::new();
            // The entries are deliberately set out of order: block1 is the target of index 1
            // and block2 the target of index 0.
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = icmp_imm ugt v0, 0xffff_ffff  ; v0 = 0
    brif v1, block3, block4

block4:
    v2 = ireduce.i32 v0  ; v0 = 0
    br_table v2, block3, [block2, block1]"
        );
    }
600
    /// Switching on an `i128` with small indices: same shape as the 64-bit case, a check
    /// against `0xffff_ffff` followed by a reduction to `i32` and a jump table.
    #[test]
    fn switch_128bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            // The `i128` value is obtained by extending an `i64` constant.
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = icmp_imm ugt v1, 0xffff_ffff
    brif v2, block3, block4

block4:
    v3 = ireduce.i32 v1
    br_table v3, block3, [block2, block1]"
        );
    }
637
    /// Switching on an `i128` with the index `u64::MAX`: the index does not fit in an `i64`
    /// immediate, so the constant is built from two `i64` halves with `iconcat` and compared
    /// with a plain `icmp`.
    #[test]
    fn switch_128bit_max_u64() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            // The `i128` value is obtained by extending an `i64` constant.
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(u64::MAX.into(), block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = iconst.i64 -1
    v3 = iconst.i64 0
    v4 = iconcat v2, v3  ; v2 = -1, v3 = 0
    v5 = icmp eq v1, v4
    brif v5, block1, block4

block4:
    brif.i128 v1, block3, block2"
        );
    }
676}