1use super::HashMap;
2use crate::frontend::FunctionBuilder;
3use alloc::vec::Vec;
4use cranelift_codegen::ir::condcodes::IntCC;
5use cranelift_codegen::ir::*;
6
/// Index of a switch case. Stored as `u128`; `Switch::emit` checks that the
/// largest index in use fits the type of the value being switched on.
type EntryIndex = u128;
8
/// A collection of `index -> block` switch cases that [`Switch::emit`] lowers
/// into Cranelift IR.
///
/// Case indices may be sparse and need not start at zero: consecutive indices
/// are grouped into ranges, and the generated code dispatches with a mix of
/// compare-and-branch instructions and `br_table` jump tables.
#[derive(Debug, Default)]
pub struct Switch {
    /// Target block of every case, keyed by case index.
    cases: HashMap<EntryIndex, Block>,
}
45
46impl Switch {
47 pub fn new() -> Self {
49 Self {
50 cases: HashMap::new(),
51 }
52 }
53
54 pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
56 let prev = self.cases.insert(index, block);
57 assert!(prev.is_none(), "Tried to set the same entry {index} twice");
58 }
59
60 pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
62 &self.cases
63 }
64
65 fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
74 log::trace!("build_contiguous_case_ranges before: {:#?}", self.cases);
75 let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
76 cases.sort_by_key(|&(index, _)| index);
77
78 let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
79 let mut last_index = None;
80 for (index, block) in cases {
81 match last_index {
82 None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),
83 Some(last_index) => {
84 if index > last_index + 1 {
85 contiguous_case_ranges.push(ContiguousCaseRange::new(index));
86 }
87 }
88 }
89 contiguous_case_ranges
90 .last_mut()
91 .unwrap()
92 .blocks
93 .push(block);
94 last_index = Some(index);
95 }
96
97 log::trace!(
98 "build_contiguous_case_ranges after: {:#?}",
99 contiguous_case_ranges
100 );
101
102 contiguous_case_ranges
103 }
104
105 fn build_search_tree<'a>(
107 bx: &mut FunctionBuilder,
108 val: Value,
109 otherwise: Block,
110 contiguous_case_ranges: &'a [ContiguousCaseRange],
111 ) {
112 if contiguous_case_ranges.is_empty() {
114 bx.ins().jump(otherwise, &[]);
115 return;
116 }
117
118 if contiguous_case_ranges.len() <= 3 {
120 Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
121 return;
122 }
123
124 let mut stack = Vec::new();
125 stack.push((None, contiguous_case_ranges));
126
127 while let Some((block, contiguous_case_ranges)) = stack.pop() {
128 if let Some(block) = block {
129 bx.switch_to_block(block);
130 }
131
132 if contiguous_case_ranges.len() <= 3 {
133 Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
134 } else {
135 let split_point = contiguous_case_ranges.len() / 2;
136 let (left, right) = contiguous_case_ranges.split_at(split_point);
137
138 let left_block = bx.create_block();
139 let right_block = bx.create_block();
140
141 let first_index = right[0].first_index;
142 let should_take_right_side =
143 icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
144 bx.ins()
145 .brif(should_take_right_side, right_block, &[], left_block, &[]);
146
147 bx.seal_block(left_block);
148 bx.seal_block(right_block);
149
150 stack.push((Some(left_block), left));
151 stack.push((Some(right_block), right));
152 }
153 }
154 }
155
156 fn build_search_branches<'a>(
158 bx: &mut FunctionBuilder,
159 val: Value,
160 otherwise: Block,
161 contiguous_case_ranges: &'a [ContiguousCaseRange],
162 ) {
163 for (ix, range) in contiguous_case_ranges.iter().enumerate().rev() {
164 let alternate = if ix == 0 {
165 otherwise
166 } else {
167 bx.create_block()
168 };
169
170 if range.first_index == 0 {
171 assert_eq!(alternate, otherwise);
172
173 if let Some(block) = range.single_block() {
174 bx.ins().brif(val, otherwise, &[], block, &[]);
175 } else {
176 Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);
177 }
178 } else {
179 if let Some(block) = range.single_block() {
180 let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);
181 bx.ins().brif(is_good_val, block, &[], alternate, &[]);
182 } else {
183 let is_good_val = icmp_imm_u128(
184 bx,
185 IntCC::UnsignedGreaterThanOrEqual,
186 val,
187 range.first_index,
188 );
189 let jt_block = bx.create_block();
190 bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);
191 bx.seal_block(jt_block);
192 bx.switch_to_block(jt_block);
193 Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);
194 }
195 }
196
197 if alternate != otherwise {
198 bx.seal_block(alternate);
199 bx.switch_to_block(alternate);
200 }
201 }
202 }
203
204 fn build_jump_table(
205 bx: &mut FunctionBuilder,
206 val: Value,
207 otherwise: Block,
208 first_index: EntryIndex,
209 blocks: &[Block],
210 ) {
211 assert!(
214 u32::try_from(blocks.len()).is_ok(),
215 "Jump tables bigger than 2^32-1 are not yet supported"
216 );
217
218 let jt_data = JumpTableData::new(
219 bx.func.dfg.block_call(otherwise, &[]),
220 &blocks
221 .iter()
222 .map(|block| bx.func.dfg.block_call(*block, &[]))
223 .collect::<Vec<_>>(),
224 );
225 let jump_table = bx.create_jump_table(jt_data);
226
227 let discr = if first_index == 0 {
228 val
229 } else {
230 if let Ok(first_index) = u64::try_from(first_index) {
231 bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
232 } else {
233 let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);
234 let lsb = bx.ins().iconst(types::I64, lsb as i64);
235 let msb = bx.ins().iconst(types::I64, msb as i64);
236 let index = bx.ins().iconcat(lsb, msb);
237 bx.ins().isub(val, index)
238 }
239 };
240
241 let discr = match bx.func.dfg.value_type(discr).bits() {
242 bits if bits > 32 => {
243 let new_block = bx.create_block();
245 let bigger_than_u32 =
246 bx.ins()
247 .icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
248 bx.ins()
249 .brif(bigger_than_u32, otherwise, &[], new_block, &[]);
250 bx.seal_block(new_block);
251 bx.switch_to_block(new_block);
252
253 bx.ins().ireduce(types::I32, discr)
255 }
256 bits if bits < 32 => bx.ins().uextend(types::I32, discr),
257 _ => discr,
258 };
259
260 bx.ins().br_table(discr, jump_table);
261 }
262
263 pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
271 let max = self.cases.keys().max().copied().unwrap_or(0);
273 let val_ty = bx.func.dfg.value_type(val);
274 let val_ty_max = val_ty.bounds(false).1;
275 if max > val_ty_max {
276 panic!("The index type {val_ty} does not fit the maximum switch entry of {max}");
277 }
278
279 let contiguous_case_ranges = self.collect_contiguous_case_ranges();
280 Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);
281 }
282}
283
284fn icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value {
285 if bx.func.dfg.value_type(x) != types::I128 {
286 assert!(u64::try_from(y).is_ok());
287 bx.ins().icmp_imm(cond, x, y as i64)
288 } else if let Ok(index) = i64::try_from(y) {
289 bx.ins().icmp_imm(cond, x, index)
290 } else {
291 let (lsb, msb) = (y as u64, (y >> 64) as u64);
292 let lsb = bx.ins().iconst(types::I64, lsb as i64);
293 let msb = bx.ins().iconst(types::I64, msb as i64);
294 let index = bx.ins().iconcat(lsb, msb);
295 bx.ins().icmp(cond, x, index)
296 }
297}
298
299#[derive(Debug)]
310struct ContiguousCaseRange {
311 first_index: EntryIndex,
313
314 blocks: Vec<Block>,
316}
317
318impl ContiguousCaseRange {
319 fn new(first_index: EntryIndex) -> Self {
320 Self {
321 first_index,
322 blocks: Vec::new(),
323 }
324 }
325
326 fn single_block(&self) -> Option<Block> {
328 if self.blocks.len() == 1 {
329 Some(self.blocks[0])
330 } else {
331 None
332 }
333 }
334}
335
336#[cfg(test)]
337mod tests {
338 use super::*;
339 use crate::frontend::FunctionBuilderContext;
340 use alloc::string::ToString;
341
    // Builds a function whose entry block switches on the `i8` constant 0.
    // One fresh block is created per listed index (in list order), and the
    // block numbered `$default` is the fallback. Evaluates to the printed
    // function with the `function u0:0() fast {` header and the closing brace
    // trimmed off.
    macro_rules! setup {
        ($default:expr, [$($index:expr,)*]) => {{
            let mut func = Function::new();
            let mut func_ctx = FunctionBuilderContext::new();
            {
                let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
                let block = bx.create_block();
                bx.switch_to_block(block);
                let val = bx.ins().iconst(types::I8, 0);
                let mut switch = Switch::new();
                // Presumably keeps `mut` from being reported as unused when the
                // index list is empty and `set_entry` is never expanded.
                let _ = &mut switch;
                $(
                    let block = bx.create_block();
                    switch.set_entry($index, block);
                )*
                switch.emit(&mut bx, val, Block::with_number($default).unwrap());
            }
            func
                .to_string()
                .trim_start_matches("function u0:0() fast {\n")
                .trim_end_matches("\n}\n")
                .to_string()
        }};
    }
366
    // A switch without entries is just a jump to the default block.
    #[test]
    fn switch_empty() {
        let func = setup!(42, []);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    jump block42"
        );
    }
377
    // A lone case 0 needs no comparison: `brif` on the value itself sends
    // non-zero values to the default (`block0`) and zero to the case block.
    #[test]
    fn switch_zero() {
        let func = setup!(0, [0,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    brif v0, block0, block1  ; v0 = 0"
        );
    }
388
    // A lone non-zero case becomes one equality test followed by a branch.
    #[test]
    fn switch_single() {
        let func = setup!(0, [1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 1  ; v0 = 0
    brif v1, block1, block0"
        );
    }
400
    // Cases 0 and 1 are contiguous, so they share one jump table; the `i8`
    // value is zero-extended to `i32` before `br_table`.
    #[test]
    fn switch_bool() {
        let func = setup!(0, [0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0  ; v0 = 0
    br_table v1, block0, [block1, block2]"
        );
    }
412
    // Cases 0 and 2 are not contiguous: the higher one is tested first, then
    // the case-0 shortcut follows in a separate block.
    #[test]
    fn switch_two_gap() {
        let func = setup!(0, [0, 2,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 2  ; v0 = 0
    brif v1, block2, block3

block3:
    brif.i8 v0, block0, block1  ; v0 = 0"
        );
    }
427
    // Four ranges ([0, 1], [5], [7], [10..=12]) exceed the linear-search limit
    // of three, so the output starts with a binary split at index 7.
    #[test]
    fn switch_many() {
        let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm uge v0, 7  ; v0 = 0
    brif v1, block9, block8

block9:
    v2 = icmp_imm.i8 uge v0, 10  ; v0 = 0
    brif v2, block11, block10

block11:
    v3 = iadd_imm.i8 v0, -10  ; v0 = 0
    v4 = uextend.i32 v3
    br_table v4, block0, [block5, block6, block7]

block10:
    v5 = icmp_imm.i8 eq v0, 7  ; v0 = 0
    brif v5, block4, block0

block8:
    v6 = icmp_imm.i8 eq v0, 5  ; v0 = 0
    brif v6, block3, block12

block12:
    v7 = uextend.i32 v0  ; v0 = 0
    br_table v7, block0, [block1, block2]"
        );
    }
460
    // Index 128 (the bit pattern of `i8::MIN`) is accepted for an `i8` value
    // and shows up as `-128` in the printed comparison.
    #[test]
    fn switch_min_index_value() {
        let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -128  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        );
    }
476
    // Index 127 (`i8::MAX`) together with index 1: two single-case ranges,
    // each matched with an equality test, highest index first.
    #[test]
    fn switch_max_index_value() {
        let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 127  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        )
    }
492
    // Index 255 (`-1i8` as unsigned) is not adjacent to 0, so it gets its own
    // equality test ahead of the jump table for the [0, 1] range.
    #[test]
    fn switch_optimal_codegen() {
        let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -1  ; v0 = 0
    brif v1, block1, block4

block4:
    v2 = uextend.i32 v0  ; v0 = 0
    br_table v2, block0, [block2, block3]"
        );
    }
508
    // `emit` must panic when a case index cannot be represented by the type of
    // the switched-on value (here a 63-bit index against an `i8`).
    #[test]
    #[should_panic(
        expected = "The index type i8 does not fit the maximum switch entry of 4683743612477887600"
    )]
    fn switch_rejects_small_inputs() {
        setup!(1, [0x4100_0000_00bf_d470,]);
    }
520
    // Runs `emit` for a small and a larger key set over every integer width,
    // sealing only the blocks the test itself created before finalizing.
    #[test]
    fn switch_seal_generated_blocks() {
        let cases = &[vec![0, 1, 2], vec![0, 1, 2, 10, 11, 12, 20, 30, 40, 50]];

        for case in cases {
            for typ in &[types::I8, types::I16, types::I32, types::I64, types::I128] {
                eprintln!("Testing {typ:?} with keys: {case:?}");
                do_case(case, *typ);
            }
        }

        // Builds one function that switches on a constant of type `typ` over
        // `keys`, with every case block and the default block returning.
        fn do_case(keys: &[u128], typ: Type) {
            let mut func = Function::new();
            let mut builder_ctx = FunctionBuilderContext::new();
            let mut builder = FunctionBuilder::new(&mut func, &mut builder_ctx);

            let root_block = builder.create_block();
            let default_block = builder.create_block();
            let mut switch = Switch::new();

            let case_blocks = keys
                .iter()
                .map(|key| {
                    let block = builder.create_block();
                    switch.set_entry(*key, block);
                    block
                })
                .collect::<Vec<_>>();

            builder.seal_block(root_block);
            builder.switch_to_block(root_block);

            let val = builder.ins().iconst(typ, 1);
            switch.emit(&mut builder, val, default_block);

            for &block in case_blocks.iter().chain(std::iter::once(&default_block)) {
                builder.seal_block(block);
                builder.switch_to_block(block);
                builder.ins().return_(&[]);
            }

            // NOTE(review): presumably panics if `emit` left one of its own
            // blocks unsealed - confirm against `FunctionBuilder::finalize`.
            builder.finalize();
        }
    }
565
    // An `i64` discriminant is range-checked against `u32::MAX` and reduced to
    // `i32` before it indexes the jump table.
    #[test]
    fn switch_64bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        // Strip the function header and closing brace, leaving only the body.
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = icmp_imm ugt v0, 0xffff_ffff  ; v0 = 0
    brif v1, block3, block4

block4:
    v2 = ireduce.i32 v0  ; v0 = 0
    br_table v2, block3, [block2, block1]"
        );
    }
600
    // Same as the 64-bit case, but the discriminant is an `i128` obtained by
    // zero-extending an `i64` constant.
    #[test]
    fn switch_128bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        // Strip the function header and closing brace, leaving only the body.
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = icmp_imm ugt v1, 0xffff_ffff
    brif v2, block3, block4

block4:
    v3 = ireduce.i32 v1
    br_table v3, block3, [block2, block1]"
        );
    }
637
    // An `i128` switch with an index of `u64::MAX`: the constant does not fit
    // a signed 64-bit immediate, so it is built from two `i64` halves with
    // `iconcat` and compared with a plain `icmp`.
    #[test]
    fn switch_128bit_max_u64() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(u64::MAX.into(), block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        // Strip the function header and closing brace, leaving only the body.
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = iconst.i64 -1
    v3 = iconst.i64 0
    v4 = iconcat v2, v3  ; v2 = -1, v3 = 0
    v5 = icmp eq v1, v4
    brif v5, block1, block4

block4:
    brif.i128 v1, block3, block2"
        );
    }
676}