cairo_lang_lowering/optimizations/
return_optimization.rs1#[cfg(test)]
2#[path = "return_optimization_test.rs"]
3mod test;
4
5use cairo_lang_semantic as semantic;
6use cairo_lang_utils::{extract_matches, require};
7use itertools::Itertools;
8use semantic::MatchArmSelector;
9
10use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
11use crate::db::LoweringGroup;
12use crate::ids::LocationId;
13use crate::{
14 BlockId, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
15 StatementEnumConstruct, StatementStructConstruct, StatementStructDestructure, VarRemapping,
16 VarUsage, VariableId,
17};
18
19pub fn return_optimization(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
25 if lowered.blocks.is_empty() {
26 return;
27 }
28 let ctx = ReturnOptimizerContext { db, lowered, fixes: vec![] };
29 let mut analysis = BackAnalysis::new(lowered, ctx);
30 let info = analysis.get_root_info();
31 let mut ctx = analysis.analyzer;
32
33 if info.early_return_possible() {
34 ctx.fixes.push(FixInfo {
35 location: (BlockId::root(), 0),
36 return_info: info.opt_return_info.clone().unwrap(),
37 });
38 }
39
40 for FixInfo { location: (block_id, statement_idx), return_info } in ctx.fixes.into_iter() {
41 let block = &mut lowered.blocks[block_id];
42 block.statements.truncate(statement_idx);
43 block.end = FlatBlockEnd::Return(
44 return_info
45 .returned_vars
46 .iter()
47 .map(|var_info| *extract_matches!(var_info, ValueInfo::Var))
48 .collect_vec(),
49 return_info.location,
50 )
51 }
52}
53
54pub struct ReturnOptimizerContext<'a> {
55 db: &'a dyn LoweringGroup,
56 lowered: &'a FlatLowered,
57
58 fixes: Vec<FixInfo>,
60}
61impl ReturnOptimizerContext<'_> {
62 fn get_var_info(&self, var_usage: &VarUsage) -> ValueInfo {
64 let var_ty = &self.lowered.variables[var_usage.var_id].ty;
65 if self.is_droppable(var_usage.var_id) && self.db.single_value_type(*var_ty).unwrap() {
66 ValueInfo::Interchangeable(*var_ty)
67 } else {
68 ValueInfo::Var(*var_usage)
69 }
70 }
71
72 fn is_droppable(&self, var_id: VariableId) -> bool {
74 self.lowered.variables[var_id].droppable.is_ok()
75 }
76
77 fn try_merge_match(
80 &mut self,
81 match_info: &MatchInfo,
82 infos: &[AnalyzerInfo],
83 ) -> Option<ReturnInfo> {
84 let MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) = match_info else {
85 return None;
86 };
87 require(!arms.is_empty())?;
88
89 let input_info = self.get_var_info(input);
90 let mut opt_last_info = None;
91 for (arm, info) in arms.iter().zip(infos) {
92 let mut curr_info = info.clone();
93 curr_info.apply_match_arm(self.is_droppable(input.var_id), &input_info, arm);
94
95 require(curr_info.early_return_possible())?;
96
97 match curr_info.opt_return_info {
98 Some(return_info)
99 if opt_last_info
100 .map(|x: ReturnInfo| x.returned_vars == return_info.returned_vars)
101 .unwrap_or(true) =>
102 {
103 opt_last_info = Some(return_info)
106 }
107 _ => return None,
108 }
109 }
110
111 Some(opt_last_info.unwrap())
112 }
113}
114
115pub struct FixInfo {
117 location: StatementLocation,
119 return_info: ReturnInfo,
121}
122
123#[derive(Clone, Debug, PartialEq, Eq)]
125pub enum ValueInfo {
126 Var(VarUsage),
128 Interchangeable(semantic::TypeId),
130 StructConstruct {
132 ty: semantic::TypeId,
134 var_infos: Vec<ValueInfo>,
136 },
137 EnumConstruct {
139 var_info: Box<ValueInfo>,
141 variant: semantic::ConcreteVariant,
143 },
144}
145
/// The effect of rewinding a `ValueInfo` across one operation.
enum OpResult {
    /// The value now uses the operation's input (a construct was cancelled out by it).
    InputConsumed,
    /// The value depends on a variable produced by the operation, so it cannot be expressed
    /// before the operation.
    ValueInvalidated,
    /// The operation does not affect the value.
    NoChange,
}
155
156impl ValueInfo {
157 fn apply<F>(&mut self, f: &F)
159 where
160 F: Fn(&VarUsage) -> ValueInfo,
161 {
162 match self {
163 ValueInfo::Var(var_usage) => *self = f(var_usage),
164 ValueInfo::StructConstruct { ty: _, ref mut var_infos } => {
165 for var_info in var_infos.iter_mut() {
166 var_info.apply(f);
167 }
168 }
169 ValueInfo::EnumConstruct { ref mut var_info, .. } => {
170 var_info.apply(f);
171 }
172 ValueInfo::Interchangeable(_) => {}
173 }
174 }
175
176 fn apply_deconstruct(
179 &mut self,
180 ctx: &ReturnOptimizerContext<'_>,
181 stmt: &StatementStructDestructure,
182 ) -> OpResult {
183 match self {
184 ValueInfo::Var(var_usage) => {
185 if stmt.outputs.contains(&var_usage.var_id) {
186 OpResult::ValueInvalidated
187 } else {
188 OpResult::NoChange
189 }
190 }
191 ValueInfo::StructConstruct { ty, var_infos } => {
192 let mut cancels_out = ty == &ctx.lowered.variables[stmt.input.var_id].ty
193 && var_infos.len() == stmt.outputs.len();
194 for (var_info, output) in var_infos.iter().zip(stmt.outputs.iter()) {
195 if !cancels_out {
196 break;
197 }
198
199 match var_info {
200 ValueInfo::Var(var_usage) if &var_usage.var_id == output => {}
201 ValueInfo::Interchangeable(ty)
202 if &ctx.lowered.variables[*output].ty == ty => {}
203 _ => cancels_out = false,
204 }
205 }
206
207 if cancels_out {
208 *self = ValueInfo::Var(stmt.input);
211 return OpResult::InputConsumed;
212 }
213
214 let mut input_consumed = false;
215 for var_info in var_infos.iter_mut() {
216 match var_info.apply_deconstruct(ctx, stmt) {
217 OpResult::InputConsumed => {
218 input_consumed = true;
219 }
220 OpResult::ValueInvalidated => {
221 return OpResult::ValueInvalidated;
224 }
225 OpResult::NoChange => {}
226 }
227 }
228
229 match input_consumed {
230 true => OpResult::InputConsumed,
231 false => OpResult::NoChange,
232 }
233 }
234 ValueInfo::EnumConstruct { ref mut var_info, .. } => {
235 var_info.apply_deconstruct(ctx, stmt)
236 }
237 ValueInfo::Interchangeable(_) => OpResult::NoChange,
238 }
239 }
240
241 fn apply_match_arm(&mut self, input: &ValueInfo, arm: &MatchArm) -> OpResult {
244 match self {
245 ValueInfo::Var(var_usage) => {
246 if arm.var_ids == [var_usage.var_id] {
247 OpResult::ValueInvalidated
248 } else {
249 OpResult::NoChange
250 }
251 }
252 ValueInfo::StructConstruct { ty: _, ref mut var_infos } => {
253 let mut input_consumed = false;
254 for var_info in var_infos.iter_mut() {
255 match var_info.apply_match_arm(input, arm) {
256 OpResult::InputConsumed => {
257 input_consumed = true;
258 }
259 OpResult::ValueInvalidated => return OpResult::ValueInvalidated,
260 OpResult::NoChange => {}
261 }
262 }
263
264 if input_consumed {
265 return OpResult::InputConsumed;
266 }
267 OpResult::NoChange
268 }
269 ValueInfo::EnumConstruct { ref mut var_info, variant } => {
270 let MatchArmSelector::VariantId(arm_variant) = &arm.arm_selector else {
271 panic!("Enum construct should not appear in value match");
272 };
273
274 if *variant == *arm_variant {
275 let cancels_out = match **var_info {
276 ValueInfo::Interchangeable(_) => true,
277 ValueInfo::Var(var_usage) if arm.var_ids == [var_usage.var_id] => true,
278 _ => false,
279 };
280
281 if cancels_out {
282 *self = input.clone();
285 return OpResult::InputConsumed;
286 }
287 }
288
289 var_info.apply_match_arm(input, arm)
290 }
291 ValueInfo::Interchangeable(_) => OpResult::NoChange,
292 }
293 }
294}
295
296#[derive(Clone, Debug, PartialEq, Eq)]
300pub struct ReturnInfo {
301 returned_vars: Vec<ValueInfo>,
302 location: LocationId,
303}
304
305#[derive(Clone, Debug, PartialEq, Eq)]
311pub struct AnalyzerInfo {
312 opt_return_info: Option<ReturnInfo>,
313}
314
315impl AnalyzerInfo {
316 fn invalidated() -> Self {
318 AnalyzerInfo { opt_return_info: None }
319 }
320
321 fn invalidate(&mut self) {
323 *self = Self::invalidated();
324 }
325
326 fn apply<F>(&mut self, f: &F)
328 where
329 F: Fn(&VarUsage) -> ValueInfo,
330 {
331 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else {
332 return;
333 };
334
335 for var_info in returned_vars.iter_mut() {
336 var_info.apply(f)
337 }
338 }
339
340 fn replace(&mut self, var_id: VariableId, var_info: ValueInfo) {
342 self.apply(&|var_usage| {
343 if var_usage.var_id == var_id { var_info.clone() } else { ValueInfo::Var(*var_usage) }
344 });
345 }
346
347 fn apply_deconstruct(
349 &mut self,
350 ctx: &ReturnOptimizerContext<'_>,
351 stmt: &StatementStructDestructure,
352 ) {
353 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
354
355 let mut input_consumed = false;
356 for var_info in returned_vars.iter_mut() {
357 match var_info.apply_deconstruct(ctx, stmt) {
358 OpResult::InputConsumed => {
359 input_consumed = true;
360 }
361 OpResult::ValueInvalidated => {
362 self.invalidate();
363 return;
364 }
365 OpResult::NoChange => {}
366 };
367 }
368
369 if !(input_consumed || ctx.is_droppable(stmt.input.var_id)) {
370 self.invalidate();
371 }
372 }
373
374 fn apply_match_arm(&mut self, is_droppable: bool, input: &ValueInfo, arm: &MatchArm) {
376 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
377
378 let mut input_consumed = false;
379 for var_info in returned_vars.iter_mut() {
380 match var_info.apply_match_arm(input, arm) {
381 OpResult::InputConsumed => {
382 input_consumed = true;
383 }
384 OpResult::ValueInvalidated => {
385 self.invalidate();
386 return;
387 }
388 OpResult::NoChange => {}
389 };
390 }
391
392 if !(input_consumed || is_droppable) {
393 self.invalidate();
394 }
395 }
396
397 fn early_return_possible(&self) -> bool {
399 let Some(ReturnInfo { ref returned_vars, .. }) = self.opt_return_info else { return false };
400
401 returned_vars.iter().all(|var_info| match var_info {
402 ValueInfo::Var(_) => true,
403 ValueInfo::StructConstruct { .. } => false,
404 ValueInfo::EnumConstruct { .. } => false,
405 ValueInfo::Interchangeable(_) => false,
406 })
407 }
408}
409
410impl<'a> Analyzer<'a> for ReturnOptimizerContext<'_> {
411 type Info = AnalyzerInfo;
412
413 fn visit_stmt(
414 &mut self,
415 info: &mut Self::Info,
416 (block_idx, statement_idx): StatementLocation,
417 stmt: &'a Statement,
418 ) {
419 let opt_orig_info = if info.early_return_possible() { Some(info.clone()) } else { None };
420
421 match stmt {
422 Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
423 info.replace(*output, ValueInfo::StructConstruct {
427 ty: self.lowered.variables[*output].ty,
428 var_infos: inputs.iter().map(|input| self.get_var_info(input)).collect(),
429 });
430 }
431
432 Statement::StructDestructure(stmt) => info.apply_deconstruct(self, stmt),
433 Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
434 info.replace(*output, ValueInfo::EnumConstruct {
435 var_info: Box::new(self.get_var_info(input)),
436 variant: variant.clone(),
437 });
438 }
439 _ => info.invalidate(),
440 }
441
442 if let Some(return_info) = opt_orig_info {
443 if !info.early_return_possible() {
444 self.fixes.push(FixInfo {
445 location: (block_idx, statement_idx + 1),
446 return_info: return_info.opt_return_info.unwrap(),
447 });
448 }
449 }
450 }
451
452 fn visit_goto(
453 &mut self,
454 info: &mut Self::Info,
455 _statement_location: StatementLocation,
456 _target_block_id: BlockId,
457 remapping: &VarRemapping,
458 ) {
459 info.apply(&|var_usage| {
460 if let Some(usage) = remapping.get(&var_usage.var_id) {
461 ValueInfo::Var(*usage)
462 } else {
463 ValueInfo::Var(*var_usage)
464 }
465 });
466 }
467
468 fn merge_match(
469 &mut self,
470 _statement_location: StatementLocation,
471 match_info: &'a MatchInfo,
472 infos: impl Iterator<Item = Self::Info>,
473 ) -> Self::Info {
474 let infos: Vec<_> = infos.collect();
475 let opt_return_info = self.try_merge_match(match_info, &infos);
476 if opt_return_info.is_none() {
477 for (arm, info) in match_info.arms().iter().zip(infos) {
480 if info.early_return_possible() {
481 self.fixes.push(FixInfo {
482 location: (arm.block_id, 0),
483 return_info: info.opt_return_info.unwrap(),
484 });
485 }
486 }
487 }
488 Self::Info { opt_return_info }
489 }
490
491 fn info_from_return(
492 &mut self,
493 (block_id, _statement_idx): StatementLocation,
494 vars: &'a [VarUsage],
495 ) -> Self::Info {
496 let location = match &self.lowered.blocks[block_id].end {
497 FlatBlockEnd::Return(_vars, location) => *location,
498 _ => unreachable!(),
499 };
500
501 AnalyzerInfo {
504 opt_return_info: Some(ReturnInfo {
505 returned_vars: vars.iter().map(|var_usage| ValueInfo::Var(*var_usage)).collect(),
506 location,
507 }),
508 }
509 }
510}