1use crate::{
4 constant::{ConstantContent, ConstantValue},
5 context::Context,
6 error::IrError,
7 function::Function,
8 instruction::InstOp,
9 value::ValueDatum,
10 AnalysisResults, BranchToWithArgs, Constant, Instruction, Pass, PassMutability, Predicate,
11 ScopedPass,
12};
13use rustc_hash::FxHashMap;
14
15pub const CONST_FOLDING_NAME: &str = "const-folding";
16
17pub fn create_const_folding_pass() -> Pass {
18 Pass {
19 name: CONST_FOLDING_NAME,
20 descr: "Constant folding",
21 deps: vec![],
22 runner: ScopedPass::FunctionPass(PassMutability::Transform(fold_constants)),
23 }
24}
25
26pub fn fold_constants(
28 context: &mut Context,
29 _: &AnalysisResults,
30 function: Function,
31) -> Result<bool, IrError> {
32 let mut modified = false;
33 loop {
34 if combine_cmp(context, &function) {
35 modified = true;
36 continue;
37 }
38
39 if combine_cbr(context, &function)? {
40 modified = true;
41 continue;
42 }
43
44 if combine_binary_op(context, &function) {
45 modified = true;
46 continue;
47 }
48
49 if remove_useless_binary_op(context, &function) {
50 modified = true;
51 continue;
52 }
53
54 if combine_unary_op(context, &function) {
55 modified = true;
56 continue;
57 }
58
59 break;
61 }
62
63 Ok(modified)
64}
65
66fn combine_cbr(context: &mut Context, function: &Function) -> Result<bool, IrError> {
67 let candidate = function
68 .instruction_iter(context)
69 .find_map(
70 |(in_block, inst_val)| match &context.values[inst_val.0].value {
71 ValueDatum::Instruction(Instruction {
72 op:
73 InstOp::ConditionalBranch {
74 cond_value,
75 true_block,
76 false_block,
77 },
78 ..
79 }) if cond_value.is_constant(context) => {
80 match &cond_value
81 .get_constant(context)
82 .unwrap()
83 .get_content(context)
84 .value
85 {
86 ConstantValue::Bool(true) => Some(Ok((
87 inst_val,
88 in_block,
89 true_block.clone(),
90 false_block.clone(),
91 ))),
92 ConstantValue::Bool(false) => Some(Ok((
93 inst_val,
94 in_block,
95 false_block.clone(),
96 true_block.clone(),
97 ))),
98 _ => Some(Err(IrError::VerifyConditionExprNotABool)),
99 }
100 }
101 _ => None,
102 },
103 )
104 .transpose()?;
105
106 candidate.map_or(
107 Ok(false),
108 |(
109 cbr,
110 from_block,
111 dest,
112 BranchToWithArgs {
113 block: no_more_dest,
114 ..
115 },
116 )| {
117 no_more_dest.remove_pred(context, &from_block);
119 dest.block.add_pred(context, &from_block);
123 cbr.replace(
124 context,
125 ValueDatum::Instruction(Instruction {
126 op: InstOp::Branch(dest),
127 parent: cbr.get_instruction(context).unwrap().parent,
128 }),
129 );
130 Ok(true)
131 },
132 )
133}
134
135fn combine_cmp(context: &mut Context, function: &Function) -> bool {
136 let candidate = function
137 .instruction_iter(context)
138 .find_map(
139 |(block, inst_val)| match &context.values[inst_val.0].value {
140 ValueDatum::Instruction(Instruction {
141 op: InstOp::Cmp(pred, val1, val2),
142 ..
143 }) if val1.is_constant(context) && val2.is_constant(context) => {
144 let val1 = val1.get_constant(context).unwrap();
145 let val2 = val2.get_constant(context).unwrap();
146
147 use ConstantValue::*;
148 match pred {
149 Predicate::Equal => Some((inst_val, block, val1 == val2)),
150 Predicate::GreaterThan => {
151 let r = match (
152 &val1.get_content(context).value,
153 &val2.get_content(context).value,
154 ) {
155 (Uint(val1), Uint(val2)) => val1 > val2,
156 (U256(val1), U256(val2)) => val1 > val2,
157 (B256(val1), B256(val2)) => val1 > val2,
158 _ => {
159 unreachable!(
160 "Type checker allowed non integer value for GreaterThan"
161 )
162 }
163 };
164 Some((inst_val, block, r))
165 }
166 Predicate::LessThan => {
167 let r = match (
168 &val1.get_content(context).value,
169 &val2.get_content(context).value,
170 ) {
171 (Uint(val1), Uint(val2)) => val1 < val2,
172 (U256(val1), U256(val2)) => val1 < val2,
173 (B256(val1), B256(val2)) => val1 < val2,
174 _ => {
175 unreachable!(
176 "Type checker allowed non integer value for GreaterThan"
177 )
178 }
179 };
180 Some((inst_val, block, r))
181 }
182 }
183 }
184 _ => None,
185 },
186 );
187
188 candidate.is_some_and(|(inst_val, block, cn_replace)| {
189 let const_content = ConstantContent::new_bool(context, cn_replace);
190 let constant = crate::Constant::unique(context, const_content);
191 inst_val.replace(context, ValueDatum::Constant(constant));
193 block.remove_instruction(context, inst_val);
194 true
195 })
196}
197
198fn combine_binary_op(context: &mut Context, function: &Function) -> bool {
199 let candidate = function
200 .instruction_iter(context)
201 .find_map(
202 |(block, inst_val)| match &context.values[inst_val.0].value {
203 ValueDatum::Instruction(Instruction {
204 op: InstOp::BinaryOp { op, arg1, arg2 },
205 ..
206 }) if arg1.is_constant(context) && arg2.is_constant(context) => {
207 let val1 = arg1.get_constant(context).unwrap().get_content(context);
208 let val2 = arg2.get_constant(context).unwrap().get_content(context);
209 use crate::BinaryOpKind::*;
210 use ConstantValue::*;
211 let v = match (op, &val1.value, &val2.value) {
212 (Add, Uint(l), Uint(r)) => l.checked_add(*r).map(Uint),
213 (Add, U256(l), U256(r)) => l.checked_add(r).map(U256),
214
215 (Sub, Uint(l), Uint(r)) => l.checked_sub(*r).map(Uint),
216 (Sub, U256(l), U256(r)) => l.checked_sub(r).map(U256),
217
218 (Mul, Uint(l), Uint(r)) => l.checked_mul(*r).map(Uint),
219 (Mul, U256(l), U256(r)) => l.checked_mul(r).map(U256),
220
221 (Div, Uint(l), Uint(r)) => l.checked_div(*r).map(Uint),
222 (Div, U256(l), U256(r)) => l.checked_div(r).map(U256),
223
224 (And, Uint(l), Uint(r)) => Some(Uint(l & r)),
225 (And, U256(l), U256(r)) => Some(U256(l & r)),
226
227 (Or, Uint(l), Uint(r)) => Some(Uint(l | r)),
228 (Or, U256(l), U256(r)) => Some(U256(l | r)),
229
230 (Xor, Uint(l), Uint(r)) => Some(Uint(l ^ r)),
231 (Xor, U256(l), U256(r)) => Some(U256(l ^ r)),
232
233 (Mod, Uint(l), Uint(r)) => l.checked_rem(*r).map(Uint),
234 (Mod, U256(l), U256(r)) => l.checked_rem(r).map(U256),
235
236 (Rsh, Uint(l), Uint(r)) => u32::try_from(*r)
237 .ok()
238 .and_then(|r| l.checked_shr(r).map(Uint)),
239 (Rsh, U256(l), Uint(r)) => Some(U256(l.shr(r))),
240
241 (Lsh, Uint(l), Uint(r)) => u32::try_from(*r)
242 .ok()
243 .and_then(|r| l.checked_shl(r).map(Uint)),
244 (Lsh, U256(l), Uint(r)) => l.checked_shl(r).map(U256),
245 _ => None,
246 };
247 v.map(|value| (inst_val, block, ConstantContent { ty: val1.ty, value }))
248 }
249 _ => None,
250 },
251 );
252
253 candidate.is_some_and(|(inst_val, block, new_value)| {
255 let new_value = Constant::unique(context, new_value);
256 inst_val.replace(context, ValueDatum::Constant(new_value));
257 block.remove_instruction(context, inst_val);
258 true
259 })
260}
261
262fn remove_useless_binary_op(context: &mut Context, function: &Function) -> bool {
263 let candidate =
264 function
265 .instruction_iter(context)
266 .find_map(
267 |(block, candidate)| match &context.values[candidate.0].value {
268 ValueDatum::Instruction(Instruction {
269 op: InstOp::BinaryOp { op, arg1, arg2 },
270 ..
271 }) if arg1.is_constant(context) || arg2.is_constant(context) => {
272 let val1 = arg1
273 .get_constant(context)
274 .map(|x| &x.get_content(context).value);
275 let val2 = arg2
276 .get_constant(context)
277 .map(|x| &x.get_content(context).value);
278
279 use crate::BinaryOpKind::*;
280 use ConstantValue::*;
281 match (op, val1, val2) {
282 (Add, Some(Uint(0)), _) => Some((block, candidate, *arg2)),
284 (Add, _, Some(Uint(0))) => Some((block, candidate, *arg1)),
286 (Mul, Some(Uint(1)), _) => Some((block, candidate, *arg2)),
288 (Mul, _, Some(Uint(1))) => Some((block, candidate, *arg1)),
290 (Div, _, Some(Uint(1))) => Some((block, candidate, *arg1)),
292 (Sub, _, Some(Uint(0))) => Some((block, candidate, *arg1)),
294 _ => None,
295 }
296 }
297 _ => None,
298 },
299 );
300
301 candidate.is_some_and(|(block, old_value, new_value)| {
302 let replace_map = FxHashMap::from_iter([(old_value, new_value)]);
303 function.replace_values(context, &replace_map, None);
304
305 block.remove_instruction(context, old_value);
306 true
307 })
308}
309
310fn combine_unary_op(context: &mut Context, function: &Function) -> bool {
311 let candidate = function
312 .instruction_iter(context)
313 .find_map(
314 |(block, inst_val)| match &context.values[inst_val.0].value {
315 ValueDatum::Instruction(Instruction {
316 op: InstOp::UnaryOp { op, arg },
317 ..
318 }) if arg.is_constant(context) => {
319 let val = arg.get_constant(context).unwrap();
320 use crate::UnaryOpKind::*;
321 use ConstantValue::*;
322 let v = match (op, &val.get_content(context).value) {
323 (Not, Uint(v)) => val
324 .get_content(context)
325 .ty
326 .get_uint_width(context)
327 .and_then(|width| {
328 let max = match width {
329 8 => u8::MAX as u64,
330 16 => u16::MAX as u64,
331 32 => u32::MAX as u64,
332 64 => u64::MAX,
333 _ => return None,
334 };
335 Some(Uint((!v) & max))
336 }),
337 (Not, U256(v)) => Some(U256(!v)),
338 _ => None,
339 };
340 v.map(|value| {
341 (
342 inst_val,
343 block,
344 ConstantContent {
345 ty: val.get_content(context).ty,
346 value,
347 },
348 )
349 })
350 }
351 _ => None,
352 },
353 );
354
355 candidate.is_some_and(|(inst_val, block, new_value)| {
357 let new_value = Constant::unique(context, new_value);
358 inst_val.replace(context, ValueDatum::Constant(new_value));
359 block.remove_instruction(context, inst_val);
360 true
361 })
362}
363
#[cfg(test)]
mod tests {
    use crate::{optimize::tests::*, CONST_FOLDING_NAME};

    /// Runs constant folding over a `main` that applies `opcode` to the constant `l`,
    /// and also to the constant `r` when one is given (binary operators).
    ///
    /// `Some(value)` is passed to `assert_optimization` as the single expected line
    /// `v0 = const {t} {value}`. `None` is forwarded unchanged as the "not folded"
    /// expectation.
    fn assert_operator(t: &str, opcode: &str, l: &str, r: Option<&str>, result: Option<&str>) {
        // Unary operators only use `l`; binary ones also declare and pass `r`.
        let (r_inst, result_inst) = match r {
            Some(r) => (format!("r = const {t} {r}"), " r,"),
            None => (String::new(), ""),
        };
        let body = format!(
            "
    entry fn main() -> {t} {{
        entry():
        l = const {t} {l}
        {r_inst}
        result = {opcode} l, {result_inst} !0
        ret {t} result
    }}
"
        );
        let expected_line = result.map(|value| format!("v0 = const {t} {value}"));
        let expected = expected_line.as_deref().map(|line| vec![line]);
        assert_optimization(&[CONST_FOLDING_NAME], &body, expected);
    }

    #[test]
    fn unary_op_are_optimized() {
        let max = u64::MAX.to_string();
        assert_operator("u64", "not", &max, None, Some("0"));
    }

    #[test]
    fn binary_op_are_optimized() {
        let arithmetic = [
            ("add", "1", "1", "2"),
            ("sub", "1", "1", "0"),
            ("mul", "2", "2", "4"),
            ("div", "10", "5", "2"),
            ("mod", "12", "5", "2"),
            ("rsh", "16", "1", "8"),
            ("lsh", "16", "1", "32"),
        ];
        for (opcode, l, r, expected) in arithmetic {
            assert_operator("u64", opcode, l, Some(r), Some(expected));
        }

        // Bitwise operators over two 12-bit masks that overlap in 0xF00.
        let l = 0x00FFF.to_string();
        let r = 0xFFF00.to_string();
        for (opcode, expected) in [("and", 0xF00), ("or", 0xFFFFF), ("xor", 0xFF0FF)] {
            assert_operator("u64", opcode, &l, Some(&r), Some(&expected.to_string()));
        }
    }

    #[test]
    fn binary_op_are_not_optimized() {
        let max = u64::MAX.to_string();
        // Overflow, underflow, division/remainder by zero, and shifts by the full width.
        let cases = [
            ("add", max.as_str(), "1"),
            ("sub", "0", "1"),
            ("mul", max.as_str(), "2"),
            ("div", "1", "0"),
            ("mod", "1", "0"),
            ("rsh", "1", "64"),
            ("lsh", "1", "64"),
        ];
        for (opcode, l, r) in cases {
            assert_operator("u64", opcode, l, Some(r), None);
        }
    }

    #[test]
    fn ok_chain_optimization() {
        // Each fold exposes the next one: `not`, `not`, then `sub` collapse into a
        // single constant (u64::MAX - 1).
        assert_optimization(
            &[CONST_FOLDING_NAME],
            "
    entry fn main() -> u64 {
        entry():
        a = const u64 18446744073709551615
        b = not a, !0
        c = not b, !0
        d = const u64 1
        result = sub c, d, !0
        ret u64 result
    }
    ",
            Some(["const u64 18446744073709551614"]),
        );

        // Chained additions: (1 + 2) + 3.
        assert_optimization(
            &[CONST_FOLDING_NAME],
            "
    entry fn main() -> u64 {
        entry():
        l0 = const u64 1
        r0 = const u64 2
        l1 = add l0, r0, !0
        r1 = const u64 3
        result = add l1, r1, !0
        ret u64 result
    }
    ",
            Some(["const u64 6"]),
        );
    }

    #[test]
    fn ok_remove_useless_mul() {
        // Every arithmetic op below is an identity (x * 1, 1 * x, x + 0, 0 + x, x / 1,
        // x - 0), so the expected output keeps only the local's address, its load and
        // the return.
        assert_optimization(
            &[CONST_FOLDING_NAME],
            "entry fn main() -> u64 {
        local u64 LOCAL
        entry():
        zero = const u64 0, !0
        one = const u64 1, !0
        l_ptr = get_local ptr u64, LOCAL, !0
        l = load l_ptr, !0
        result1 = mul l, one, !0
        result2 = mul one, result1, !0
        result3 = add result2, zero, !0
        result4 = add zero, result3, !0
        result5 = div result4, one, !0
        result6 = sub result5, zero, !0
        ret u64 result6, !0
    }",
            Some([
                "v0 = get_local ptr u64, LOCAL",
                "v1 = load v0",
                "ret u64 v1",
            ]),
        );
    }
}