1use cairo_lang_defs::ids::LanguageElementId;
8use cairo_lang_semantic as semantic;
9use cairo_lang_semantic::ConcreteFunction;
10use cairo_lang_semantic::corelib::{core_array_felt252_ty, core_module, get_ty_by_name, unit_ty};
11use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
12use cairo_lang_semantic::items::imp::ImplId;
13use cairo_lang_utils::{Intern, LookupIntern};
14use itertools::{Itertools, chain, zip_eq};
15use semantic::{TypeId, TypeLongId};
16
17use crate::borrow_check::Demand;
18use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
19use crate::borrow_check::demand::{AuxCombine, DemandReporter};
20use crate::db::LoweringGroup;
21use crate::ids::{ConcreteFunctionWithBodyId, SemanticFunctionIdEx};
22use crate::lower::context::{VarRequest, VariableAllocator};
23use crate::{
24 BlockId, FlatBlockEnd, FlatLowered, MatchInfo, Statement, StatementCall,
25 StatementStructConstruct, StatementStructDestructure, VarRemapping, VarUsage, VariableId,
26};
27
/// The demand type used by [`DestructAdder`]: demands are tracked per [`VariableId`], uses carry
/// no extra payload (`()`, matching `UsePosition = ()` in the reporter impl), and the auxiliary
/// state is a [`PanicState`].
pub type DestructAdderDemand = Demand<VariableId, (), PanicState>;
29
/// The way a group of destruction calls is spliced into the lowered function by
/// `add_destructs`.
///
/// NOTE: the variant order is significant. The derived `Ord` is part of the sort key used to
/// group and order insertions (`Plain < PanicVar < PanicPostMatch`), so do not reorder variants.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum AddDestructFlowType {
    /// Plain destruct calls inserted at a recorded statement index.
    Plain,
    /// Panic-destruct calls chained through a panic value produced by a statement, or through
    /// the first returned value when the location is the end of the block.
    PanicVar,
    /// Panic-destruct calls inserted at the start of the target block of a match's second arm
    /// (index 1).
    PanicPostMatch,
}
40
41pub struct DestructAdder<'a> {
43 db: &'a dyn LoweringGroup,
44 lowered: &'a FlatLowered,
45 destructions: Vec<DestructionEntry>,
46 panic_ty: TypeId,
47 never_fn_actual_return_ty: TypeId,
49 is_panic_destruct_fn: bool,
50}
51
52enum DestructionEntry {
54 Plain(PlainDestructionEntry),
56 Panic(PanicDeconstructionEntry),
58}
59
60struct PlainDestructionEntry {
61 position: StatementLocation,
62 var_id: VariableId,
63 impl_id: ImplId,
64}
65struct PanicDeconstructionEntry {
66 panic_location: PanicLocation,
67 var_id: VariableId,
68 impl_id: ImplId,
69}
70
71impl DestructAdder<'_> {
72 fn set_post_stmt_destruct(
74 &mut self,
75 introductions: &[VariableId],
76 info: &mut DestructAdderDemand,
77 block_id: BlockId,
78 statement_index: usize,
79 ) {
80 if let [panic_var] = introductions[..] {
81 let var = &self.lowered.variables[panic_var];
82 if [self.panic_ty, self.never_fn_actual_return_ty].contains(&var.ty) {
83 info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicVar {
84 statement_location: (block_id, statement_index),
85 }]);
86 }
87 }
88 }
89
90 fn set_post_match_state(
93 &mut self,
94 introduced_vars: &[VariableId],
95 info: &mut DestructAdderDemand,
96 match_block_id: BlockId,
97 target_block_id: BlockId,
98 arm_idx: usize,
99 ) {
100 if arm_idx != 1 {
101 return;
103 }
104 if let [err_var] = introduced_vars[..] {
105 let var = &self.lowered.variables[err_var];
106
107 let long_ty = var.ty.lookup_intern(self.db);
108 let TypeLongId::Tuple(tys) = long_ty else {
109 return;
110 };
111 if tys.first() == Some(&self.panic_ty) {
112 info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicMatch {
113 match_block_id,
114 target_block_id,
115 }]);
116 }
117 }
118 }
119}
120
121impl DemandReporter<VariableId, PanicState> for DestructAdder<'_> {
122 type IntroducePosition = StatementLocation;
123 type UsePosition = ();
124
125 fn drop_aux(
126 &mut self,
127 position: StatementLocation,
128 var_id: VariableId,
129 panic_state: PanicState,
130 ) {
131 let var = &self.lowered.variables[var_id];
132 if var.droppable.is_ok() {
136 return;
137 };
138 if let Ok(impl_id) = var.destruct_impl.clone() {
140 self.destructions.push(DestructionEntry::Plain(PlainDestructionEntry {
141 position,
142 var_id,
143 impl_id,
144 }));
145 return;
146 }
147 if let Ok(impl_id) = var.panic_destruct_impl.clone() {
149 if let PanicState::EndsWithPanic(panic_locations) = panic_state {
150 for panic_location in panic_locations {
151 self.destructions.push(DestructionEntry::Panic(PanicDeconstructionEntry {
152 panic_location,
153 var_id,
154 impl_id,
155 }));
156 }
157 return;
158 }
159 }
160
161 panic!("Borrow checker should have caught this.")
162 }
163}
164
165#[derive(Clone, Default)]
168pub enum PanicState {
169 EndsWithPanic(Vec<PanicLocation>),
173 #[default]
174 Otherwise,
175}
176impl AuxCombine for PanicState {
178 fn merge<'a, I: Iterator<Item = &'a Self>>(iter: I) -> Self
179 where
180 Self: 'a,
181 {
182 let mut panic_locations = vec![];
183 for state in iter {
184 if let Self::EndsWithPanic(locations) = state {
185 panic_locations.extend_from_slice(locations);
186 } else {
187 return Self::Otherwise;
188 }
189 }
190
191 Self::EndsWithPanic(panic_locations)
192 }
193}
194
195#[derive(Clone)]
197pub enum PanicLocation {
198 PanicVar { statement_location: StatementLocation },
200 PanicMatch { match_block_id: BlockId, target_block_id: BlockId },
202}
203
204impl Analyzer<'_> for DestructAdder<'_> {
205 type Info = DestructAdderDemand;
206
207 fn visit_stmt(
208 &mut self,
209 info: &mut Self::Info,
210 (block_id, statement_index): StatementLocation,
211 stmt: &Statement,
212 ) {
213 self.set_post_stmt_destruct(stmt.outputs(), info, block_id, statement_index);
214 info.variables_introduced(self, stmt.outputs(), (block_id, statement_index + 1));
216 info.variables_used(self, stmt.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())));
217 }
218
219 fn visit_goto(
220 &mut self,
221 info: &mut Self::Info,
222 _statement_location: StatementLocation,
223 _target_block_id: BlockId,
224 remapping: &VarRemapping,
225 ) {
226 info.apply_remapping(self, remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))));
227 }
228
229 fn merge_match(
230 &mut self,
231 (block_id, _statement_index): StatementLocation,
232 match_info: &MatchInfo,
233 infos: impl Iterator<Item = Self::Info>,
234 ) -> Self::Info {
235 let arm_demands = zip_eq(match_info.arms(), infos)
236 .enumerate()
237 .map(|(arm_idx, (arm, mut demand))| {
238 let use_position = (arm.block_id, 0);
239 self.set_post_match_state(
240 &arm.var_ids,
241 &mut demand,
242 block_id,
243 arm.block_id,
244 arm_idx,
245 );
246 demand.variables_introduced(self, &arm.var_ids, use_position);
247 (demand, use_position)
248 })
249 .collect_vec();
250 let mut demand = DestructAdderDemand::merge_demands(&arm_demands, self);
251 demand.variables_used(
252 self,
253 match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
254 );
255 demand
256 }
257
258 fn info_from_return(
259 &mut self,
260 statement_location: StatementLocation,
261 vars: &[VarUsage],
262 ) -> Self::Info {
263 let mut info = DestructAdderDemand::default();
264 if self.is_panic_destruct_fn {
266 info.aux =
267 PanicState::EndsWithPanic(vec![PanicLocation::PanicVar { statement_location }]);
268 }
269
270 info.variables_used(self, vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())));
271 info
272 }
273}
274
275fn panic_ty(db: &dyn LoweringGroup) -> semantic::TypeId {
276 get_ty_by_name(db.upcast(), core_module(db.upcast()), "Panic".into(), vec![])
277}
278
/// Adds explicit destruction calls to `lowered` for variables that stop being used without being
/// droppable.
///
/// A backward analysis ([`DestructAdder`] driven by [`BackAnalysis`]) records a destruction entry
/// for each such variable:
/// - If the variable has a `destruct_impl`, the entry is a plain destruct.
/// - Otherwise, if it has a `panic_destruct_impl` and the flow ends with a panic, there is one
///   panic-destruct per recorded [`PanicLocation`].
///
/// The entries are then grouped by insertion point and materialized as `StatementCall`s:
/// - `Plain`: the destruct calls are spliced in at the recorded statement index.
/// - `PanicVar`: the panic-destruct calls are chained through the panic value produced by a
///   struct-construct or call statement. When the location is the block end, they are chained
///   through the first returned value.
/// - `PanicPostMatch`: the steps are:
///   1. The tuple bound by the match's second arm is destructured at the start of the arm
///      block.
///   2. The calls are chained through its first element.
///   3. The tuple is rebuilt into the original variable.
///
/// Returns without changes if `lowered` has no blocks or if `is_panic_destruct_fn` fails for
/// `function_id`.
///
/// # Panics
/// Panics in any of these cases:
/// - The analysis reports an undefined variable.
/// - `VariableAllocator::new` fails.
/// - A recorded location does not have the expected statement or block-end shape.
pub fn add_destructs(
    db: &dyn LoweringGroup,
    function_id: ConcreteFunctionWithBodyId,
    lowered: &mut FlatLowered,
) {
    if lowered.blocks.is_empty() {
        return;
    }

    let Ok(is_panic_destruct_fn) = function_id.is_panic_destruct_fn(db) else {
        return;
    };

    let panic_ty = panic_ty(db.upcast());
    let felt_arr_ty = core_array_felt252_ty(db.upcast());
    // The `(Panic, Array<felt252>)` tuple type, also recognized as a panic value by the analyzer.
    let never_fn_actual_return_ty = TypeLongId::Tuple(vec![panic_ty, felt_arr_ty]).intern(db);
    let checker = DestructAdder {
        db,
        lowered,
        destructions: vec![],
        panic_ty,
        never_fn_actual_return_ty,
        is_panic_destruct_fn,
    };
    // Run the backward analysis; `drop_aux` fills `destructions` as it goes.
    let mut analysis = BackAnalysis::new(lowered, checker);
    let mut root_demand = analysis.get_root_info();
    // Function parameters are introduced at the very start of the root block.
    root_demand.variables_introduced(
        &mut analysis.analyzer,
        &lowered.parameters,
        (BlockId::root(), 0),
    );
    assert!(root_demand.finalize(), "Undefined variable should not happen at this stage");

    // Allocator for the new variables created while materializing the destruct calls. It starts
    // from a copy of the existing variables and is written back at the end.
    let mut variables = VariableAllocator::new(
        db,
        function_id.function_with_body_id(db).base_semantic_function(db),
        lowered.variables.clone(),
    )
    .unwrap();

    let info = db.core_info();
    let plain_trait_function = info.destruct_fn;
    let panic_trait_function = info.panic_destruct_fn;

    // The new variables and panic-destruct calls use the function's own location.
    let stable_ptr = function_id
        .function_with_body_id(db.upcast())
        .base_semantic_function(db)
        .untyped_stable_ptr(db.upcast());

    let location = variables.get_location(stable_ptr);

    let destructions = analysis.analyzer.destructions;

    // Grouping and sorting key: `(block, statement index, flow type, match block)`.
    // Panic-match entries are keyed by the arm's target block at index 0, with the match block
    // as the last component. Other entries use 0 there.
    let as_tuple = |entry: &DestructionEntry| match entry {
        DestructionEntry::Plain(plain_destruct) => {
            (plain_destruct.position.0.0, plain_destruct.position.1, AddDestructFlowType::Plain, 0)
        }
        DestructionEntry::Panic(panic_destruct) => match panic_destruct.panic_location {
            PanicLocation::PanicMatch { target_block_id, match_block_id } => {
                (target_block_id.0, 0, AddDestructFlowType::PanicPostMatch, match_block_id.0)
            }
            PanicLocation::PanicVar { statement_location } => {
                (statement_location.0.0, statement_location.1, AddDestructFlowType::PanicVar, 0)
            }
        },
    };

    // Groups are processed in descending key order. Splicing into a block therefore never
    // invalidates the recorded statement indices of groups processed later.
    for ((block_id, statement_idx, destruct_type, match_block_id), destructions) in
        destructions.into_iter().sorted_by_key(as_tuple).rev().chunk_by(as_tuple).into_iter()
    {
        let mut stmts = vec![];

        // Head of the panic chain. The first panic-destruct call consumes this variable, and
        // each call outputs a fresh panic variable that the next call consumes. It is allocated
        // for every group, including `Plain` ones, where it stays unused.
        let first_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });
        let mut last_panic_var = first_panic_var;

        for destruction in destructions {
            let output_var = variables.new_var(VarRequest { ty: unit_ty(db.upcast()), location });

            match destruction {
                DestructionEntry::Plain(plain_destruct) => {
                    let semantic_function = semantic::FunctionLongId {
                        function: ConcreteFunction {
                            generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
                                impl_id: plain_destruct.impl_id,
                                function: plain_trait_function,
                            }),
                            generic_args: vec![],
                        },
                    }
                    .intern(db);

                    stmts.push(StatementCall {
                        function: semantic_function.lowered(db),
                        inputs: vec![VarUsage { var_id: plain_destruct.var_id, location }],
                        with_coupon: false,
                        outputs: vec![output_var],
                        location: lowered.variables[plain_destruct.var_id].location,
                    })
                }

                DestructionEntry::Panic(panic_destruct) => {
                    let semantic_function = semantic::FunctionLongId {
                        function: ConcreteFunction {
                            generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
                                impl_id: panic_destruct.impl_id,
                                function: panic_trait_function,
                            }),
                            generic_args: vec![],
                        },
                    }
                    .intern(db);

                    let new_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });

                    // Consume the previous panic value in the chain and produce the next one.
                    stmts.push(StatementCall {
                        function: semantic_function.lowered(db),
                        inputs: vec![
                            VarUsage { var_id: panic_destruct.var_id, location },
                            VarUsage { var_id: last_panic_var, location },
                        ],
                        with_coupon: false,
                        outputs: vec![new_panic_var, output_var],
                        location,
                    });
                    last_panic_var = new_panic_var;
                }
            }
        }

        match destruct_type {
            AddDestructFlowType::Plain => {
                let block = &mut lowered.blocks[BlockId(block_id)];
                block
                    .statements
                    .splice(statement_idx..statement_idx, stmts.into_iter().map(Statement::Call));
            }
            AddDestructFlowType::PanicPostMatch => {
                let block = &mut lowered.blocks[BlockId(match_block_id)];
                let FlatBlockEnd::Match { info: MatchInfo::Enum(info) } = &mut block.end else {
                    unreachable!();
                };

                // Rebind the second arm's tuple to a fresh variable. At the start of the arm
                // block we then:
                // - destructure the fresh tuple, sending its panic element to `first_panic_var`;
                // - run the chain;
                // - rebuild the original tuple variable with the chain's final panic value.
                let arm = &mut info.arms[1];
                let tuple_var = &mut arm.var_ids[0];
                let tuple_ty = lowered.variables[*tuple_var].ty;
                let new_tuple_var = variables.new_var(VarRequest { ty: tuple_ty, location });
                let orig_tuple_var = *tuple_var;
                *tuple_var = new_tuple_var;
                let long_ty = tuple_ty.lookup_intern(db);
                let TypeLongId::Tuple(tys) = long_ty else { unreachable!() };

                let vars = tys
                    .iter()
                    .copied()
                    .map(|ty| variables.new_var(VarRequest { ty, location }))
                    .collect::<Vec<_>>();

                // The last call of the chain writes the final panic value into `vars[0]`.
                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = vars[0];

                let target_block_id = arm.block_id;

                let block = &mut lowered.blocks[target_block_id];

                block.statements.splice(
                    0..0,
                    chain!(
                        [Statement::StructDestructure(StatementStructDestructure {
                            input: VarUsage { var_id: new_tuple_var, location },
                            outputs: chain!([first_panic_var], vars.iter().skip(1).cloned())
                                .collect(),
                        })],
                        stmts.into_iter().map(Statement::Call),
                        [Statement::StructConstruct(StatementStructConstruct {
                            inputs: vars
                                .into_iter()
                                .map(|var_id| VarUsage { var_id, location })
                                .collect(),
                            output: orig_tuple_var,
                        })]
                    ),
                );
            }
            AddDestructFlowType::PanicVar => {
                let block = &mut lowered.blocks[BlockId(block_id)];

                let idx = match block.statements.get_mut(statement_idx) {
                    Some(stmt) => {
                        match stmt {
                            Statement::StructConstruct(stmt) => {
                                // Redirect the constructed panic value into `first_panic_var`.
                                // The chain's last call then produces the original variable.
                                let panic_var = &mut stmt.output;
                                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = *panic_var;
                                *panic_var = first_panic_var;
                            }
                            Statement::Call(stmt) => {
                                // The call's first output is a tuple, assumed to be
                                // `(Panic, Array<felt252>)`. The fresh tuple variable is
                                // allocated with that type; TODO confirm that only such calls
                                // reach this branch. Right after the call, the fresh tuple is:
                                // - split into its panic and array parts;
                                // - chained through its panic element;
                                // - rebuilt into the original tuple variable.
                                let tuple_var = &mut stmt.outputs[0];
                                let new_tuple_var = variables.new_var(VarRequest {
                                    ty: never_fn_actual_return_ty,
                                    location,
                                });
                                let orig_tuple_var = *tuple_var;
                                *tuple_var = new_tuple_var;
                                let new_panic_var =
                                    variables.new_var(VarRequest { ty: panic_ty, location });
                                let new_arr_var =
                                    variables.new_var(VarRequest { ty: felt_arr_ty, location });
                                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() =
                                    new_panic_var;
                                let idx = statement_idx + 1;
                                block.statements.splice(
                                    idx..idx,
                                    chain!(
                                        [Statement::StructDestructure(
                                            StatementStructDestructure {
                                                input: VarUsage { var_id: new_tuple_var, location },
                                                outputs: vec![first_panic_var, new_arr_var],
                                            }
                                        )],
                                        stmts.into_iter().map(Statement::Call),
                                        [Statement::StructConstruct(StatementStructConstruct {
                                            inputs: [new_panic_var, new_arr_var]
                                                .into_iter()
                                                .map(|var_id| VarUsage { var_id, location })
                                                .collect(),
                                            output: orig_tuple_var,
                                        })]
                                    ),
                                );
                                // The calls were already spliced above. Clear `stmts` so the
                                // shared splice below inserts nothing.
                                stmts = vec![];
                            }
                            _ => unreachable!("Expected a struct construct or a call statement."),
                        }
                        statement_idx + 1
                    }
                    None => {
                        // The location is the end of the block, which must be a return. The
                        // returned first value feeds the first call of the chain, and the
                        // return uses the chain's final panic value instead. `first_panic_var`
                        // is unused here.
                        assert_eq!(statement_idx, block.statements.len());
                        let panic_var = match &mut block.end {
                            FlatBlockEnd::Return(vars, _) => &mut vars[0].var_id,
                            _ => unreachable!("Expected a return statement."),
                        };

                        stmts.first_mut().unwrap().inputs.get_mut(1).unwrap().var_id = *panic_var;
                        *panic_var = last_panic_var;
                        statement_idx
                    }
                };

                block.statements.splice(idx..idx, stmts.into_iter().map(Statement::Call));
            }
        };
    }

    lowered.variables = variables.variables;
}