1use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
2
3use crate::ids::LocationId;
4use crate::{
5 BlockId, FlatBlock, FlatBlockEnd, MatchArm, MatchEnumInfo, MatchEnumValue, MatchExternInfo,
6 MatchInfo, Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
7 StatementSnapshot, StatementStructConstruct, StatementStructDestructure, VarRemapping,
8 VarUsage, VariableId,
9};
10
/// Strategy selector consulted when deciding whether to inline functions.
///
/// `InliningStrategy::Default` is the value produced by `Default::default()`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum InliningStrategy {
    /// The default strategy; returned by `InliningStrategy::default()`.
    #[default]
    Default,
    /// Presumably asks the inliner to avoid inlining where it can — TODO confirm against the
    /// inlining pass that consumes this value.
    Avoid,
}
20
21pub trait Rebuilder {
23 fn map_var_id(&mut self, var: VariableId) -> VariableId;
24 fn map_var_usage(&mut self, var_usage: VarUsage) -> VarUsage {
25 VarUsage {
26 var_id: self.map_var_id(var_usage.var_id),
27 location: self.map_location(var_usage.location),
28 }
29 }
30 fn map_location(&mut self, location: LocationId) -> LocationId {
31 location
32 }
33 fn map_block_id(&mut self, block: BlockId) -> BlockId {
34 block
35 }
36 fn transform_statement(&mut self, _statement: &mut Statement) {}
37 fn transform_remapping(&mut self, _remapping: &mut VarRemapping) {}
38 fn transform_end(&mut self, _end: &mut FlatBlockEnd) {}
39 fn transform_block(&mut self, _block: &mut FlatBlock) {}
40}
41
42pub trait RebuilderEx: Rebuilder {
43 fn rebuild_statement(&mut self, statement: &Statement) -> Statement {
45 let mut statement = match statement {
46 Statement::Const(stmt) => Statement::Const(StatementConst {
47 value: stmt.value.clone(),
48 output: self.map_var_id(stmt.output),
49 }),
50 Statement::Call(stmt) => Statement::Call(StatementCall {
51 function: stmt.function,
52 inputs: stmt.inputs.iter().map(|v| self.map_var_usage(*v)).collect(),
53 with_coupon: stmt.with_coupon,
54 outputs: stmt.outputs.iter().map(|v| self.map_var_id(*v)).collect(),
55 location: self.map_location(stmt.location),
56 }),
57 Statement::StructConstruct(stmt) => {
58 Statement::StructConstruct(StatementStructConstruct {
59 inputs: stmt.inputs.iter().map(|v| self.map_var_usage(*v)).collect(),
60 output: self.map_var_id(stmt.output),
61 })
62 }
63 Statement::StructDestructure(stmt) => {
64 Statement::StructDestructure(StatementStructDestructure {
65 input: self.map_var_usage(stmt.input),
66 outputs: stmt.outputs.iter().map(|v| self.map_var_id(*v)).collect(),
67 })
68 }
69 Statement::EnumConstruct(stmt) => Statement::EnumConstruct(StatementEnumConstruct {
70 variant: stmt.variant.clone(),
71 input: self.map_var_usage(stmt.input),
72 output: self.map_var_id(stmt.output),
73 }),
74 Statement::Snapshot(stmt) => Statement::Snapshot(StatementSnapshot::new(
75 self.map_var_usage(stmt.input),
76 self.map_var_id(stmt.original()),
77 self.map_var_id(stmt.snapshot()),
78 )),
79 Statement::Desnap(stmt) => Statement::Desnap(StatementDesnap {
80 input: self.map_var_usage(stmt.input),
81 output: self.map_var_id(stmt.output),
82 }),
83 };
84 self.transform_statement(&mut statement);
85 statement
86 }
87
88 fn rebuild_remapping(&mut self, remapping: &VarRemapping) -> VarRemapping {
90 let mut remapping = VarRemapping {
91 remapping: OrderedHashMap::from_iter(remapping.iter().map(|(dst, src_var_usage)| {
92 (self.map_var_id(*dst), self.map_var_usage(*src_var_usage))
93 })),
94 };
95 self.transform_remapping(&mut remapping);
96 remapping
97 }
98
99 fn rebuild_end(&mut self, end: &FlatBlockEnd) -> FlatBlockEnd {
101 let mut end = match end {
102 FlatBlockEnd::Return(returns, location) => FlatBlockEnd::Return(
103 returns.iter().map(|var_usage| self.map_var_usage(*var_usage)).collect(),
104 self.map_location(*location),
105 ),
106 FlatBlockEnd::Panic(data) => FlatBlockEnd::Panic(self.map_var_usage(*data)),
107 FlatBlockEnd::Goto(block_id, remapping) => {
108 FlatBlockEnd::Goto(self.map_block_id(*block_id), self.rebuild_remapping(remapping))
109 }
110 FlatBlockEnd::NotSet => unreachable!(),
111 FlatBlockEnd::Match { info } => FlatBlockEnd::Match {
112 info: match info {
113 MatchInfo::Extern(stmt) => MatchInfo::Extern(MatchExternInfo {
114 function: stmt.function,
115 inputs: stmt.inputs.iter().map(|v| self.map_var_usage(*v)).collect(),
116 arms: stmt
117 .arms
118 .iter()
119 .map(|arm| MatchArm {
120 arm_selector: arm.arm_selector.clone(),
121 block_id: self.map_block_id(arm.block_id),
122 var_ids: arm
123 .var_ids
124 .iter()
125 .map(|var_id| self.map_var_id(*var_id))
126 .collect(),
127 })
128 .collect(),
129 location: self.map_location(stmt.location),
130 }),
131 MatchInfo::Enum(stmt) => MatchInfo::Enum(MatchEnumInfo {
132 concrete_enum_id: stmt.concrete_enum_id,
133 input: self.map_var_usage(stmt.input),
134 arms: stmt
135 .arms
136 .iter()
137 .map(|arm| MatchArm {
138 arm_selector: arm.arm_selector.clone(),
139 block_id: self.map_block_id(arm.block_id),
140 var_ids: arm
141 .var_ids
142 .iter()
143 .map(|var_id| self.map_var_id(*var_id))
144 .collect(),
145 })
146 .collect(),
147 location: self.map_location(stmt.location),
148 }),
149 MatchInfo::Value(stmt) => MatchInfo::Value(MatchEnumValue {
150 num_of_arms: stmt.num_of_arms,
151 input: self.map_var_usage(stmt.input),
152 arms: stmt
153 .arms
154 .iter()
155 .map(|arm| MatchArm {
156 arm_selector: arm.arm_selector.clone(),
157 block_id: self.map_block_id(arm.block_id),
158 var_ids: arm
159 .var_ids
160 .iter()
161 .map(|var_id| self.map_var_id(*var_id))
162 .collect(),
163 })
164 .collect(),
165 location: self.map_location(stmt.location),
166 }),
167 },
168 },
169 };
170 self.transform_end(&mut end);
171 end
172 }
173
174 fn rebuild_block(&mut self, block: &FlatBlock) -> FlatBlock {
176 let mut statements = vec![];
177 for stmt in &block.statements {
178 statements.push(self.rebuild_statement(stmt));
179 }
180 let end = self.rebuild_end(&block.end);
181 let mut block = FlatBlock { statements, end };
182 self.transform_block(&mut block);
183 block
184 }
185}
186
187impl<T: Rebuilder> RebuilderEx for T {}