cairo_lang_lowering/optimizations/
return_optimization.rs

1#[cfg(test)]
2#[path = "return_optimization_test.rs"]
3mod test;
4
5use cairo_lang_semantic as semantic;
6use cairo_lang_utils::{extract_matches, require};
7use itertools::Itertools;
8use semantic::MatchArmSelector;
9
10use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
11use crate::db::LoweringGroup;
12use crate::ids::LocationId;
13use crate::{
14    BlockId, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
15    StatementEnumConstruct, StatementStructConstruct, StatementStructDestructure, VarRemapping,
16    VarUsage, VariableId,
17};
18
19/// Adds early returns when applicable.
20///
21/// This optimization does backward analysis from return statement and keeps track of
22/// each returned value (see `ValueInfo`), whenever all the returned values are available at a block
23/// end and there was no side effects later, the end is replaced with a return statement.
24pub fn return_optimization(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
25    if lowered.blocks.is_empty() {
26        return;
27    }
28    let ctx = ReturnOptimizerContext { db, lowered, fixes: vec![] };
29    let mut analysis = BackAnalysis::new(lowered, ctx);
30    let info = analysis.get_root_info();
31    let mut ctx = analysis.analyzer;
32
33    if info.early_return_possible() {
34        ctx.fixes.push(FixInfo {
35            location: (BlockId::root(), 0),
36            return_info: info.opt_return_info.clone().unwrap(),
37        });
38    }
39
40    for FixInfo { location: (block_id, statement_idx), return_info } in ctx.fixes.into_iter() {
41        let block = &mut lowered.blocks[block_id];
42        block.statements.truncate(statement_idx);
43        block.end = FlatBlockEnd::Return(
44            return_info
45                .returned_vars
46                .iter()
47                .map(|var_info| *extract_matches!(var_info, ValueInfo::Var))
48                .collect_vec(),
49            return_info.location,
50        )
51    }
52}
53
54pub struct ReturnOptimizerContext<'a> {
55    db: &'a dyn LoweringGroup,
56    lowered: &'a FlatLowered,
57
58    /// The list of fixes that should be applied.
59    fixes: Vec<FixInfo>,
60}
61impl ReturnOptimizerContext<'_> {
62    /// Given a VarUsage, returns the ValueInfo that corresponds to it.
63    fn get_var_info(&self, var_usage: &VarUsage) -> ValueInfo {
64        let var_ty = &self.lowered.variables[var_usage.var_id].ty;
65        if self.is_droppable(var_usage.var_id) && self.db.single_value_type(*var_ty).unwrap() {
66            ValueInfo::Interchangeable(*var_ty)
67        } else {
68            ValueInfo::Var(*var_usage)
69        }
70    }
71
72    /// Returns true if the variable is droppable
73    fn is_droppable(&self, var_id: VariableId) -> bool {
74        self.lowered.variables[var_id].droppable.is_ok()
75    }
76
77    /// Helper function for `merge_match`.
78    /// Returns `Option<ReturnInfo>` rather then `AnalyzerInfo` to simplify early return.
79    fn try_merge_match(
80        &mut self,
81        match_info: &MatchInfo,
82        infos: &[AnalyzerInfo],
83    ) -> Option<ReturnInfo> {
84        let MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) = match_info else {
85            return None;
86        };
87        require(!arms.is_empty())?;
88
89        let input_info = self.get_var_info(input);
90        let mut opt_last_info = None;
91        for (arm, info) in arms.iter().zip(infos) {
92            let mut curr_info = info.clone();
93            curr_info.apply_match_arm(self.is_droppable(input.var_id), &input_info, arm);
94
95            require(curr_info.early_return_possible())?;
96
97            match curr_info.opt_return_info {
98                Some(return_info)
99                    if opt_last_info
100                        .map(|x: ReturnInfo| x.returned_vars == return_info.returned_vars)
101                        .unwrap_or(true) =>
102                {
103                    // If this is the first iteration or the returned var are the same as the
104                    // previous iteration, then the optimization is still applicable.
105                    opt_last_info = Some(return_info)
106                }
107                _ => return None,
108            }
109        }
110
111        Some(opt_last_info.unwrap())
112    }
113}
114
115/// Information about a fix that should be applied to the lowering.
116pub struct FixInfo {
117    /// A location where we `return_vars` can be returned.
118    location: StatementLocation,
119    /// The return info at the fix location.
120    return_info: ReturnInfo,
121}
122
123/// Information about the value that should be returned from the function.
124#[derive(Clone, Debug, PartialEq, Eq)]
125pub enum ValueInfo {
126    /// The value is available through the given var usage.
127    Var(VarUsage),
128    /// The can be replaced with other values of the same type.
129    Interchangeable(semantic::TypeId),
130    /// The value is the result of a StructConstruct statement.
131    StructConstruct {
132        /// The type of the struct.
133        ty: semantic::TypeId,
134        /// The inputs to the StructConstruct statement.
135        var_infos: Vec<ValueInfo>,
136    },
137    /// The value is the result of an EnumConstruct statement.
138    EnumConstruct {
139        /// The input to the EnumConstruct.
140        var_info: Box<ValueInfo>,
141        /// The constructed variant.
142        variant: semantic::ConcreteVariant,
143    },
144}
145
/// The outcome of applying an operation to a `ValueInfo`.
enum OpResult {
    /// The operation's input was consumed by the value.
    InputConsumed,
    /// One of the values is produced by the operation, so it is invalid before the operation.
    ValueInvalidated,
    /// The operation left the value info unchanged.
    NoChange,
}
155
156impl ValueInfo {
157    /// Applies the given function to the value info.
158    fn apply<F>(&mut self, f: &F)
159    where
160        F: Fn(&VarUsage) -> ValueInfo,
161    {
162        match self {
163            ValueInfo::Var(var_usage) => *self = f(var_usage),
164            ValueInfo::StructConstruct { ty: _, ref mut var_infos } => {
165                for var_info in var_infos.iter_mut() {
166                    var_info.apply(f);
167                }
168            }
169            ValueInfo::EnumConstruct { ref mut var_info, .. } => {
170                var_info.apply(f);
171            }
172            ValueInfo::Interchangeable(_) => {}
173        }
174    }
175
176    /// Updates the value to the state before the StructDeconstruct statement.
177    /// Returns OpResult.
178    fn apply_deconstruct(
179        &mut self,
180        ctx: &ReturnOptimizerContext<'_>,
181        stmt: &StatementStructDestructure,
182    ) -> OpResult {
183        match self {
184            ValueInfo::Var(var_usage) => {
185                if stmt.outputs.contains(&var_usage.var_id) {
186                    OpResult::ValueInvalidated
187                } else {
188                    OpResult::NoChange
189                }
190            }
191            ValueInfo::StructConstruct { ty, var_infos } => {
192                let mut cancels_out = ty == &ctx.lowered.variables[stmt.input.var_id].ty
193                    && var_infos.len() == stmt.outputs.len();
194                for (var_info, output) in var_infos.iter().zip(stmt.outputs.iter()) {
195                    if !cancels_out {
196                        break;
197                    }
198
199                    match var_info {
200                        ValueInfo::Var(var_usage) if &var_usage.var_id == output => {}
201                        ValueInfo::Interchangeable(ty)
202                            if &ctx.lowered.variables[*output].ty == ty => {}
203                        _ => cancels_out = false,
204                    }
205                }
206
207                if cancels_out {
208                    // If the StructDeconstruct cancels out the StructConstruct, then we don't need
209                    // to `apply_deconstruct` to the inner var infos.
210                    *self = ValueInfo::Var(stmt.input);
211                    return OpResult::InputConsumed;
212                }
213
214                let mut input_consumed = false;
215                for var_info in var_infos.iter_mut() {
216                    match var_info.apply_deconstruct(ctx, stmt) {
217                        OpResult::InputConsumed => {
218                            input_consumed = true;
219                        }
220                        OpResult::ValueInvalidated => {
221                            // If one of the values is invalidated the optimization is no longer
222                            // applicable.
223                            return OpResult::ValueInvalidated;
224                        }
225                        OpResult::NoChange => {}
226                    }
227                }
228
229                match input_consumed {
230                    true => OpResult::InputConsumed,
231                    false => OpResult::NoChange,
232                }
233            }
234            ValueInfo::EnumConstruct { ref mut var_info, .. } => {
235                var_info.apply_deconstruct(ctx, stmt)
236            }
237            ValueInfo::Interchangeable(_) => OpResult::NoChange,
238        }
239    }
240
241    /// Updates the value to the expected value before the match arm.
242    /// Returns OpResult.
243    fn apply_match_arm(&mut self, input: &ValueInfo, arm: &MatchArm) -> OpResult {
244        match self {
245            ValueInfo::Var(var_usage) => {
246                if arm.var_ids == [var_usage.var_id] {
247                    OpResult::ValueInvalidated
248                } else {
249                    OpResult::NoChange
250                }
251            }
252            ValueInfo::StructConstruct { ty: _, ref mut var_infos } => {
253                let mut input_consumed = false;
254                for var_info in var_infos.iter_mut() {
255                    match var_info.apply_match_arm(input, arm) {
256                        OpResult::InputConsumed => {
257                            input_consumed = true;
258                        }
259                        OpResult::ValueInvalidated => return OpResult::ValueInvalidated,
260                        OpResult::NoChange => {}
261                    }
262                }
263
264                if input_consumed {
265                    return OpResult::InputConsumed;
266                }
267                OpResult::NoChange
268            }
269            ValueInfo::EnumConstruct { ref mut var_info, variant } => {
270                let MatchArmSelector::VariantId(arm_variant) = &arm.arm_selector else {
271                    panic!("Enum construct should not appear in value match");
272                };
273
274                if *variant == *arm_variant {
275                    let cancels_out = match **var_info {
276                        ValueInfo::Interchangeable(_) => true,
277                        ValueInfo::Var(var_usage) if arm.var_ids == [var_usage.var_id] => true,
278                        _ => false,
279                    };
280
281                    if cancels_out {
282                        // If the arm recreates the relevant enum variant, then the arm
283                        // assuming the other arms also cancel out.
284                        *self = input.clone();
285                        return OpResult::InputConsumed;
286                    }
287                }
288
289                var_info.apply_match_arm(input, arm)
290            }
291            ValueInfo::Interchangeable(_) => OpResult::NoChange,
292        }
293    }
294}
295
296/// Information about the current state of the analyzer.
297/// Used to track the value that should be returned from the function at the current
298/// analysis point
299#[derive(Clone, Debug, PartialEq, Eq)]
300pub struct ReturnInfo {
301    returned_vars: Vec<ValueInfo>,
302    location: LocationId,
303}
304
305/// A wrapper around `ReturnInfo` that makes it optional.
306///
307/// None indicates that the return info is unknown.
308/// If early_return_possible() returns true, the function can return early as the return value is
309/// already known.
310#[derive(Clone, Debug, PartialEq, Eq)]
311pub struct AnalyzerInfo {
312    opt_return_info: Option<ReturnInfo>,
313}
314
315impl AnalyzerInfo {
316    /// Creates a state of the analyzer where the return optimization is not applicable.
317    fn invalidated() -> Self {
318        AnalyzerInfo { opt_return_info: None }
319    }
320
321    /// Invalidates the state of the analyzer, identifying early return is no longer possible.
322    fn invalidate(&mut self) {
323        *self = Self::invalidated();
324    }
325
326    /// Applies the given function to the returned_vars
327    fn apply<F>(&mut self, f: &F)
328    where
329        F: Fn(&VarUsage) -> ValueInfo,
330    {
331        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else {
332            return;
333        };
334
335        for var_info in returned_vars.iter_mut() {
336            var_info.apply(f)
337        }
338    }
339
340    /// Replaces occurrences of `var_id` with `var_info`.
341    fn replace(&mut self, var_id: VariableId, var_info: ValueInfo) {
342        self.apply(&|var_usage| {
343            if var_usage.var_id == var_id { var_info.clone() } else { ValueInfo::Var(*var_usage) }
344        });
345    }
346
347    /// Updates the info to the state before the StructDeconstruct statement.
348    fn apply_deconstruct(
349        &mut self,
350        ctx: &ReturnOptimizerContext<'_>,
351        stmt: &StatementStructDestructure,
352    ) {
353        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
354
355        let mut input_consumed = false;
356        for var_info in returned_vars.iter_mut() {
357            match var_info.apply_deconstruct(ctx, stmt) {
358                OpResult::InputConsumed => {
359                    input_consumed = true;
360                }
361                OpResult::ValueInvalidated => {
362                    self.invalidate();
363                    return;
364                }
365                OpResult::NoChange => {}
366            };
367        }
368
369        if !(input_consumed || ctx.is_droppable(stmt.input.var_id)) {
370            self.invalidate();
371        }
372    }
373
374    /// Updates the info to the state before match arm.
375    fn apply_match_arm(&mut self, is_droppable: bool, input: &ValueInfo, arm: &MatchArm) {
376        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
377
378        let mut input_consumed = false;
379        for var_info in returned_vars.iter_mut() {
380            match var_info.apply_match_arm(input, arm) {
381                OpResult::InputConsumed => {
382                    input_consumed = true;
383                }
384                OpResult::ValueInvalidated => {
385                    self.invalidate();
386                    return;
387                }
388                OpResult::NoChange => {}
389            };
390        }
391
392        if !(input_consumed || is_droppable) {
393            self.invalidate();
394        }
395    }
396
397    /// Returns true if an early return is possible according to 'self'.
398    fn early_return_possible(&self) -> bool {
399        let Some(ReturnInfo { ref returned_vars, .. }) = self.opt_return_info else { return false };
400
401        returned_vars.iter().all(|var_info| match var_info {
402            ValueInfo::Var(_) => true,
403            ValueInfo::StructConstruct { .. } => false,
404            ValueInfo::EnumConstruct { .. } => false,
405            ValueInfo::Interchangeable(_) => false,
406        })
407    }
408}
409
410impl<'a> Analyzer<'a> for ReturnOptimizerContext<'_> {
411    type Info = AnalyzerInfo;
412
413    fn visit_stmt(
414        &mut self,
415        info: &mut Self::Info,
416        (block_idx, statement_idx): StatementLocation,
417        stmt: &'a Statement,
418    ) {
419        let opt_orig_info = if info.early_return_possible() { Some(info.clone()) } else { None };
420
421        match stmt {
422            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
423                // Note that the ValueInfo::StructConstruct can only be removed by
424                // a StructDeconstruct statement that produces its non-interchangeable inputs so
425                // allowing undroppable inputs is ok here.
426                info.replace(*output, ValueInfo::StructConstruct {
427                    ty: self.lowered.variables[*output].ty,
428                    var_infos: inputs.iter().map(|input| self.get_var_info(input)).collect(),
429                });
430            }
431
432            Statement::StructDestructure(stmt) => info.apply_deconstruct(self, stmt),
433            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
434                info.replace(*output, ValueInfo::EnumConstruct {
435                    var_info: Box::new(self.get_var_info(input)),
436                    variant: variant.clone(),
437                });
438            }
439            _ => info.invalidate(),
440        }
441
442        if let Some(return_info) = opt_orig_info {
443            if !info.early_return_possible() {
444                self.fixes.push(FixInfo {
445                    location: (block_idx, statement_idx + 1),
446                    return_info: return_info.opt_return_info.unwrap(),
447                });
448            }
449        }
450    }
451
452    fn visit_goto(
453        &mut self,
454        info: &mut Self::Info,
455        _statement_location: StatementLocation,
456        _target_block_id: BlockId,
457        remapping: &VarRemapping,
458    ) {
459        info.apply(&|var_usage| {
460            if let Some(usage) = remapping.get(&var_usage.var_id) {
461                ValueInfo::Var(*usage)
462            } else {
463                ValueInfo::Var(*var_usage)
464            }
465        });
466    }
467
468    fn merge_match(
469        &mut self,
470        _statement_location: StatementLocation,
471        match_info: &'a MatchInfo,
472        infos: impl Iterator<Item = Self::Info>,
473    ) -> Self::Info {
474        let infos: Vec<_> = infos.collect();
475        let opt_return_info = self.try_merge_match(match_info, &infos);
476        if opt_return_info.is_none() {
477            // If the optimization is not applicable before the match, check if it is applicable
478            // in the arms.
479            for (arm, info) in match_info.arms().iter().zip(infos) {
480                if info.early_return_possible() {
481                    self.fixes.push(FixInfo {
482                        location: (arm.block_id, 0),
483                        return_info: info.opt_return_info.unwrap(),
484                    });
485                }
486            }
487        }
488        Self::Info { opt_return_info }
489    }
490
491    fn info_from_return(
492        &mut self,
493        (block_id, _statement_idx): StatementLocation,
494        vars: &'a [VarUsage],
495    ) -> Self::Info {
496        let location = match &self.lowered.blocks[block_id].end {
497            FlatBlockEnd::Return(_vars, location) => *location,
498            _ => unreachable!(),
499        };
500
501        // Note that `self.get_var_info` is not used here because ValueInfo::Interchangeable is
502        // supported only inside other ValueInfo variants.
503        AnalyzerInfo {
504            opt_return_info: Some(ReturnInfo {
505                returned_vars: vars.iter().map(|var_usage| ValueInfo::Var(*var_usage)).collect(),
506                location,
507            }),
508        }
509    }
510}