sway_ir/optimize/
constants.rs

//! Optimization passes for manipulating constant values.
2
3use crate::{
4    constant::{ConstantContent, ConstantValue},
5    context::Context,
6    error::IrError,
7    function::Function,
8    instruction::InstOp,
9    value::ValueDatum,
10    AnalysisResults, BranchToWithArgs, Constant, Instruction, Pass, PassMutability, Predicate,
11    ScopedPass,
12};
13use rustc_hash::FxHashMap;
14
15pub const CONST_FOLDING_NAME: &str = "const-folding";
16
17pub fn create_const_folding_pass() -> Pass {
18    Pass {
19        name: CONST_FOLDING_NAME,
20        descr: "Constant folding",
21        deps: vec![],
22        runner: ScopedPass::FunctionPass(PassMutability::Transform(fold_constants)),
23    }
24}
25
26/// Find constant expressions which can be reduced to fewer operations.
27pub fn fold_constants(
28    context: &mut Context,
29    _: &AnalysisResults,
30    function: Function,
31) -> Result<bool, IrError> {
32    let mut modified = false;
33    loop {
34        if combine_cmp(context, &function) {
35            modified = true;
36            continue;
37        }
38
39        if combine_cbr(context, &function)? {
40            modified = true;
41            continue;
42        }
43
44        if combine_binary_op(context, &function) {
45            modified = true;
46            continue;
47        }
48
49        if remove_useless_binary_op(context, &function) {
50            modified = true;
51            continue;
52        }
53
54        if combine_unary_op(context, &function) {
55            modified = true;
56            continue;
57        }
58
59        // Other passes here... always continue to the top if pass returns true.
60        break;
61    }
62
63    Ok(modified)
64}
65
66fn combine_cbr(context: &mut Context, function: &Function) -> Result<bool, IrError> {
67    let candidate = function
68        .instruction_iter(context)
69        .find_map(
70            |(in_block, inst_val)| match &context.values[inst_val.0].value {
71                ValueDatum::Instruction(Instruction {
72                    op:
73                        InstOp::ConditionalBranch {
74                            cond_value,
75                            true_block,
76                            false_block,
77                        },
78                    ..
79                }) if cond_value.is_constant(context) => {
80                    match &cond_value
81                        .get_constant(context)
82                        .unwrap()
83                        .get_content(context)
84                        .value
85                    {
86                        ConstantValue::Bool(true) => Some(Ok((
87                            inst_val,
88                            in_block,
89                            true_block.clone(),
90                            false_block.clone(),
91                        ))),
92                        ConstantValue::Bool(false) => Some(Ok((
93                            inst_val,
94                            in_block,
95                            false_block.clone(),
96                            true_block.clone(),
97                        ))),
98                        _ => Some(Err(IrError::VerifyConditionExprNotABool)),
99                    }
100                }
101                _ => None,
102            },
103        )
104        .transpose()?;
105
106    candidate.map_or(
107        Ok(false),
108        |(
109            cbr,
110            from_block,
111            dest,
112            BranchToWithArgs {
113                block: no_more_dest,
114                ..
115            },
116        )| {
117            // `no_more_dest` will no longer have from_block as a predecessor.
118            no_more_dest.remove_pred(context, &from_block);
119            // Although our cbr already branched to `dest`, in case
120            // `no_more_dest` and `dest` are the same, we'll need to re-add
121            // `from_block` as a predecessor for `dest`.
122            dest.block.add_pred(context, &from_block);
123            cbr.replace(
124                context,
125                ValueDatum::Instruction(Instruction {
126                    op: InstOp::Branch(dest),
127                    parent: cbr.get_instruction(context).unwrap().parent,
128                }),
129            );
130            Ok(true)
131        },
132    )
133}
134
135fn combine_cmp(context: &mut Context, function: &Function) -> bool {
136    let candidate = function
137        .instruction_iter(context)
138        .find_map(
139            |(block, inst_val)| match &context.values[inst_val.0].value {
140                ValueDatum::Instruction(Instruction {
141                    op: InstOp::Cmp(pred, val1, val2),
142                    ..
143                }) if val1.is_constant(context) && val2.is_constant(context) => {
144                    let val1 = val1.get_constant(context).unwrap();
145                    let val2 = val2.get_constant(context).unwrap();
146
147                    use ConstantValue::*;
148                    match pred {
149                        Predicate::Equal => Some((inst_val, block, val1 == val2)),
150                        Predicate::GreaterThan => {
151                            let r = match (
152                                &val1.get_content(context).value,
153                                &val2.get_content(context).value,
154                            ) {
155                                (Uint(val1), Uint(val2)) => val1 > val2,
156                                (U256(val1), U256(val2)) => val1 > val2,
157                                (B256(val1), B256(val2)) => val1 > val2,
158                                _ => {
159                                    unreachable!(
160                                        "Type checker allowed non integer value for GreaterThan"
161                                    )
162                                }
163                            };
164                            Some((inst_val, block, r))
165                        }
166                        Predicate::LessThan => {
167                            let r = match (
168                                &val1.get_content(context).value,
169                                &val2.get_content(context).value,
170                            ) {
171                                (Uint(val1), Uint(val2)) => val1 < val2,
172                                (U256(val1), U256(val2)) => val1 < val2,
173                                (B256(val1), B256(val2)) => val1 < val2,
174                                _ => {
175                                    unreachable!(
176                                        "Type checker allowed non integer value for GreaterThan"
177                                    )
178                                }
179                            };
180                            Some((inst_val, block, r))
181                        }
182                    }
183                }
184                _ => None,
185            },
186        );
187
188    candidate.is_some_and(|(inst_val, block, cn_replace)| {
189        let const_content = ConstantContent::new_bool(context, cn_replace);
190        let constant = crate::Constant::unique(context, const_content);
191        // Replace this `cmp` instruction with a constant.
192        inst_val.replace(context, ValueDatum::Constant(constant));
193        block.remove_instruction(context, inst_val);
194        true
195    })
196}
197
198fn combine_binary_op(context: &mut Context, function: &Function) -> bool {
199    let candidate = function
200        .instruction_iter(context)
201        .find_map(
202            |(block, inst_val)| match &context.values[inst_val.0].value {
203                ValueDatum::Instruction(Instruction {
204                    op: InstOp::BinaryOp { op, arg1, arg2 },
205                    ..
206                }) if arg1.is_constant(context) && arg2.is_constant(context) => {
207                    let val1 = arg1.get_constant(context).unwrap().get_content(context);
208                    let val2 = arg2.get_constant(context).unwrap().get_content(context);
209                    use crate::BinaryOpKind::*;
210                    use ConstantValue::*;
211                    let v = match (op, &val1.value, &val2.value) {
212                        (Add, Uint(l), Uint(r)) => l.checked_add(*r).map(Uint),
213                        (Add, U256(l), U256(r)) => l.checked_add(r).map(U256),
214
215                        (Sub, Uint(l), Uint(r)) => l.checked_sub(*r).map(Uint),
216                        (Sub, U256(l), U256(r)) => l.checked_sub(r).map(U256),
217
218                        (Mul, Uint(l), Uint(r)) => l.checked_mul(*r).map(Uint),
219                        (Mul, U256(l), U256(r)) => l.checked_mul(r).map(U256),
220
221                        (Div, Uint(l), Uint(r)) => l.checked_div(*r).map(Uint),
222                        (Div, U256(l), U256(r)) => l.checked_div(r).map(U256),
223
224                        (And, Uint(l), Uint(r)) => Some(Uint(l & r)),
225                        (And, U256(l), U256(r)) => Some(U256(l & r)),
226
227                        (Or, Uint(l), Uint(r)) => Some(Uint(l | r)),
228                        (Or, U256(l), U256(r)) => Some(U256(l | r)),
229
230                        (Xor, Uint(l), Uint(r)) => Some(Uint(l ^ r)),
231                        (Xor, U256(l), U256(r)) => Some(U256(l ^ r)),
232
233                        (Mod, Uint(l), Uint(r)) => l.checked_rem(*r).map(Uint),
234                        (Mod, U256(l), U256(r)) => l.checked_rem(r).map(U256),
235
236                        (Rsh, Uint(l), Uint(r)) => u32::try_from(*r)
237                            .ok()
238                            .and_then(|r| l.checked_shr(r).map(Uint)),
239                        (Rsh, U256(l), Uint(r)) => Some(U256(l.shr(r))),
240
241                        (Lsh, Uint(l), Uint(r)) => u32::try_from(*r)
242                            .ok()
243                            .and_then(|r| l.checked_shl(r).map(Uint)),
244                        (Lsh, U256(l), Uint(r)) => l.checked_shl(r).map(U256),
245                        _ => None,
246                    };
247                    v.map(|value| (inst_val, block, ConstantContent { ty: val1.ty, value }))
248                }
249                _ => None,
250            },
251        );
252
253    // Replace this binary op instruction with a constant.
254    candidate.is_some_and(|(inst_val, block, new_value)| {
255        let new_value = Constant::unique(context, new_value);
256        inst_val.replace(context, ValueDatum::Constant(new_value));
257        block.remove_instruction(context, inst_val);
258        true
259    })
260}
261
262fn remove_useless_binary_op(context: &mut Context, function: &Function) -> bool {
263    let candidate =
264        function
265            .instruction_iter(context)
266            .find_map(
267                |(block, candidate)| match &context.values[candidate.0].value {
268                    ValueDatum::Instruction(Instruction {
269                        op: InstOp::BinaryOp { op, arg1, arg2 },
270                        ..
271                    }) if arg1.is_constant(context) || arg2.is_constant(context) => {
272                        let val1 = arg1
273                            .get_constant(context)
274                            .map(|x| &x.get_content(context).value);
275                        let val2 = arg2
276                            .get_constant(context)
277                            .map(|x| &x.get_content(context).value);
278
279                        use crate::BinaryOpKind::*;
280                        use ConstantValue::*;
281                        match (op, val1, val2) {
282                            // 0 + arg2
283                            (Add, Some(Uint(0)), _) => Some((block, candidate, *arg2)),
284                            // arg1 + 0
285                            (Add, _, Some(Uint(0))) => Some((block, candidate, *arg1)),
286                            // 1 * arg2
287                            (Mul, Some(Uint(1)), _) => Some((block, candidate, *arg2)),
288                            // arg1 * 1
289                            (Mul, _, Some(Uint(1))) => Some((block, candidate, *arg1)),
290                            // arg1 / 1
291                            (Div, _, Some(Uint(1))) => Some((block, candidate, *arg1)),
292                            // arg1 - 0
293                            (Sub, _, Some(Uint(0))) => Some((block, candidate, *arg1)),
294                            _ => None,
295                        }
296                    }
297                    _ => None,
298                },
299            );
300
301    candidate.is_some_and(|(block, old_value, new_value)| {
302        let replace_map = FxHashMap::from_iter([(old_value, new_value)]);
303        function.replace_values(context, &replace_map, None);
304
305        block.remove_instruction(context, old_value);
306        true
307    })
308}
309
310fn combine_unary_op(context: &mut Context, function: &Function) -> bool {
311    let candidate = function
312        .instruction_iter(context)
313        .find_map(
314            |(block, inst_val)| match &context.values[inst_val.0].value {
315                ValueDatum::Instruction(Instruction {
316                    op: InstOp::UnaryOp { op, arg },
317                    ..
318                }) if arg.is_constant(context) => {
319                    let val = arg.get_constant(context).unwrap();
320                    use crate::UnaryOpKind::*;
321                    use ConstantValue::*;
322                    let v = match (op, &val.get_content(context).value) {
323                        (Not, Uint(v)) => val
324                            .get_content(context)
325                            .ty
326                            .get_uint_width(context)
327                            .and_then(|width| {
328                                let max = match width {
329                                    8 => u8::MAX as u64,
330                                    16 => u16::MAX as u64,
331                                    32 => u32::MAX as u64,
332                                    64 => u64::MAX,
333                                    _ => return None,
334                                };
335                                Some(Uint((!v) & max))
336                            }),
337                        (Not, U256(v)) => Some(U256(!v)),
338                        _ => None,
339                    };
340                    v.map(|value| {
341                        (
342                            inst_val,
343                            block,
344                            ConstantContent {
345                                ty: val.get_content(context).ty,
346                                value,
347                            },
348                        )
349                    })
350                }
351                _ => None,
352            },
353        );
354
355    // Replace this unary op instruction with a constant.
356    candidate.is_some_and(|(inst_val, block, new_value)| {
357        let new_value = Constant::unique(context, new_value);
358        inst_val.replace(context, ValueDatum::Constant(new_value));
359        block.remove_instruction(context, inst_val);
360        true
361    })
362}
363
#[cfg(test)]
mod tests {
    use crate::{optimize::tests::*, CONST_FOLDING_NAME};

    /// Runs const-folding on a small `main` that declares the constant `l`
    /// (and the constant `r`, for binary operators) of type `t`, applies
    /// `opcode`, and returns the result.
    ///
    /// `result` is forwarded to `assert_optimization` as the expected output.
    /// `Some(x)` expects the single line `v0 = const {t} x`. `None` is used
    /// by the `*_not_optimized` tests, presumably meaning "no change expected".
    fn assert_operator(t: &str, opcode: &str, l: &str, r: Option<&str>, result: Option<&str>) {
        // Unary operators have no second operand to declare or pass.
        let (r_inst, result_inst) = match r {
            Some(r) => (format!("r = const {t} {r}"), " r,"),
            None => (String::new(), ""),
        };
        let body = format!(
            "
    entry fn main() -> {t} {{
        entry():
        l = const {t} {l}
        {r_inst}
        result = {opcode} l, {result_inst} !0
        ret {t} result
    }}
"
        );
        let expected_line = result.map(|value| format!("v0 = const {t} {value}"));
        let expected = expected_line.as_deref().map(|line| vec![line]);
        assert_optimization(&[CONST_FOLDING_NAME], &body, expected);
    }

    #[test]
    fn unary_op_are_optimized() {
        let all_ones = u64::MAX.to_string();
        assert_operator("u64", "not", &all_ones, None, Some("0"));
    }

    #[test]
    fn binary_op_are_optimized() {
        // (opcode, lhs, rhs, folded result), all on `u64`.
        let arithmetic = [
            ("add", "1", "1", "2"),
            ("sub", "1", "1", "0"),
            ("mul", "2", "2", "4"),
            ("div", "10", "5", "2"),
            ("mod", "12", "5", "2"),
            ("rsh", "16", "1", "8"),
            ("lsh", "16", "1", "32"),
        ];
        for (opcode, l, r, folded) in arithmetic {
            assert_operator("u64", opcode, l, Some(r), Some(folded));
        }

        // Bitwise operators on overlapping masks, expected values in hex.
        let (lhs, rhs) = (0x00FFF.to_string(), 0xFFF00.to_string());
        for (opcode, folded) in [("and", 0xF00), ("or", 0xFFFFF), ("xor", 0xFF0FF)] {
            assert_operator("u64", opcode, &lhs, Some(&rhs), Some(&folded.to_string()));
        }
    }

    #[test]
    fn binary_op_are_not_optimized() {
        let max = u64::MAX.to_string();
        let max = max.as_str();
        // (opcode, lhs, rhs): each would overflow, underflow, divide by zero or
        // shift out of range, so it must be left for runtime.
        let cases = [
            ("add", max, "1"),
            ("sub", "0", "1"),
            ("mul", max, "2"),
            ("div", "1", "0"),
            ("mod", "1", "0"),
            ("rsh", "1", "64"),
            ("lsh", "1", "64"),
        ];
        for (opcode, l, r) in cases {
            assert_operator("u64", opcode, l, Some(r), None);
        }
    }

    #[test]
    fn ok_chain_optimization() {
        // Unary operator: `not not MAX` folds back to MAX.
        // `sub 1` is used to guarantee that the assert string is unique
        assert_optimization(
            &[CONST_FOLDING_NAME],
            "
        entry fn main() -> u64 {
            entry():
            a = const u64 18446744073709551615
            b = not a, !0
            c = not b, !0
            d = const u64 1
            result = sub c, d, !0
            ret u64 result
        }
    ",
            Some(["const u64 18446744073709551614"]),
        );

        // Binary operators: the result of one fold feeds the next.
        assert_optimization(
            &[CONST_FOLDING_NAME],
            "
        entry fn main() -> u64 {
            entry():
            l0 = const u64 1
            r0 = const u64 2
            l1 = add l0, r0, !0
            r1 = const u64 3
            result = add l1, r1, !0
            ret u64 result
        }
    ",
            Some(["const u64 6"]),
        );
    }

    #[test]
    fn ok_remove_useless_mul() {
        // Every arithmetic instruction is an identity on the loaded value, so
        // only the load and the return survive.
        assert_optimization(
            &[CONST_FOLDING_NAME],
            "entry fn main() -> u64 {
                local u64 LOCAL
            entry():
                zero = const u64 0, !0
                one = const u64 1, !0
                l_ptr = get_local ptr u64, LOCAL, !0
                l = load l_ptr, !0
                result1 = mul l, one, !0
                result2 = mul one, result1, !0
                result3 = add result2, zero, !0
                result4 = add zero, result3, !0
                result5 = div result4, one, !0
                result6 = sub result5, zero, !0
                ret u64 result6, !0
         }",
            Some([
                "v0 = get_local ptr u64, LOCAL",
                "v1 = load v0",
                "ret u64 v1",
            ]),
        );
    }
}