cairo_lang_lowering/lower/
refs.rs

1use cairo_lang_defs::ids::MemberId;
2use cairo_lang_proc_macros::DebugWithDb;
3use cairo_lang_semantic::expr::fmt::ExprFormatter;
4use cairo_lang_semantic::usage::MemberPath;
5use cairo_lang_semantic::{self as semantic};
6use cairo_lang_utils::extract_matches;
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use itertools::chain;
10
11use crate::VariableId;
12use crate::db::LoweringGroup;
13
/// Information about the members captured by a closure and their types.
///
/// Captures are kept in two separate maps, each mapping the captured member path to its type:
/// `members` for paths captured directly and `snapshots` for paths captured as snapshots.
#[derive(Clone, Debug)]
pub struct ClosureInfo {
    // TODO(TomerStarkware): unite copiable members and snapshots into a single map.
    /// The members captured by the closure (not as snapshot).
    pub members: OrderedHashMap<MemberPath, semantic::TypeId>,
    /// The types of the captured snapshot variables.
    pub snapshots: OrderedHashMap<MemberPath, semantic::TypeId>,
}
23
/// Bookkeeping that ties semantic member paths to the lowered variables holding their values,
/// together with the state needed to track which paths are currently captured by closures.
#[derive(Clone, Default, Debug)]
pub struct SemanticLoweringMapping {
    /// Maps member paths ([MemberPath]) to lowered variable ids or scattered variable ids.
    scattered: OrderedHashMap<MemberPath, Value>,
    /// Maps captured member paths to a closure that captured them.
    pub captured: UnorderedHashMap<MemberPath, VariableId>,
    /// Maps captured member paths which are copiable to a closure that captured them.
    pub copiable_captured: UnorderedHashMap<MemberPath, VariableId>,
    /// Maps the variable id of a closure to the closure info.
    pub closures: UnorderedHashMap<VariableId, ClosureInfo>,
}
35impl SemanticLoweringMapping {
36    /// Returns the topmost mapped member path containing the given member path, or None no such
37    /// member path exists in the mapping.
38    pub fn topmost_mapped_containing_member_path(
39        &mut self,
40        mut member_path: MemberPath,
41    ) -> Option<MemberPath> {
42        let mut res = None;
43        loop {
44            if self.scattered.contains_key(&member_path) {
45                res = Some(member_path.clone());
46            }
47            let MemberPath::Member { parent, .. } = member_path else {
48                return res;
49            };
50            member_path = *parent;
51        }
52    }
53
54    /// Returns the scattered members of the given member path, or None if the member path is not
55    /// scattered.
56    pub fn get_scattered_members(&mut self, member_path: &MemberPath) -> Option<Vec<MemberPath>> {
57        let Some(Value::Scattered(scattered)) = self.scattered.get(member_path) else {
58            return None;
59        };
60        Some(
61            scattered
62                .members
63                .iter()
64                .map(|(member_id, _)| MemberPath::Member {
65                    parent: member_path.clone().into(),
66                    member_id: *member_id,
67                    concrete_struct_id: scattered.concrete_struct_id,
68                })
69                .collect(),
70        )
71    }
72
73    pub fn destructure_closure<TContext: StructRecomposer>(
74        &mut self,
75        ctx: &mut TContext,
76        closure_var: VariableId,
77        closure_info: &ClosureInfo,
78    ) -> Vec<VariableId> {
79        ctx.deconstruct_by_types(
80            closure_var,
81            chain!(closure_info.members.values(), closure_info.snapshots.values()).cloned(),
82        )
83    }
84
85    pub fn invalidate_closure<TContext: StructRecomposer>(
86        &mut self,
87        ctx: &mut TContext,
88        closure_var: VariableId,
89    ) {
90        let opt_closure = self.closures.remove(&closure_var);
91        if let Some(closure_info) = opt_closure {
92            let new_vars = self.destructure_closure(ctx, closure_var, &closure_info);
93
94            // Note that members.keys() can be shorter than new_vars, as the members captured
95            // as snapshots don't need to be updated.
96            for (path, new_var) in closure_info.members.keys().zip(new_vars) {
97                if self.captured.remove(path).is_some() {
98                    self.update(ctx, path, new_var).unwrap();
99                } else {
100                    self.copiable_captured.remove(path);
101                }
102            }
103        }
104    }
105
106    pub fn get<TContext: StructRecomposer>(
107        &mut self,
108        mut ctx: TContext,
109        path: &MemberPath,
110    ) -> Option<VariableId> {
111        if let Some(closure_var) = self.captured.get(path) {
112            self.invalidate_closure(&mut ctx, *closure_var);
113        }
114        let value = self.break_into_value(&mut ctx, path)?;
115        Self::assemble_value(&mut ctx, value)
116    }
117
118    pub fn introduce(&mut self, path: MemberPath, var: VariableId) {
119        self.scattered.insert(path, Value::Var(var));
120    }
121
122    pub fn update<TContext: StructRecomposer>(
123        &mut self,
124        ctx: &mut TContext,
125        path: &MemberPath,
126        var: VariableId,
127    ) -> Option<()> {
128        // TODO(TomerStarkware): check if path is captured by a closure and invalidate the closure.
129        // Right now this can only happen if we take a snapshot of the variable (as the
130        // snapshot function returns a new var).
131        // we need the make sure the borrow checker invalidates the closure when mutable capture
132        // is supported.
133
134        let value = self.break_into_value(ctx, path)?;
135        *value = Value::Var(var);
136        Some(())
137    }
138
139    fn assemble_value<TContext: StructRecomposer>(
140        ctx: &mut TContext,
141        value: &mut Value,
142    ) -> Option<VariableId> {
143        Some(match value {
144            Value::Var(var) => *var,
145            Value::Scattered(scattered) => {
146                let members = scattered
147                    .members
148                    .iter_mut()
149                    .map(|(_, value)| Self::assemble_value(ctx, value))
150                    .collect::<Option<_>>()?;
151                let var = ctx.reconstruct(scattered.concrete_struct_id, members);
152                *value = Value::Var(var);
153                var
154            }
155        })
156    }
157
    /// Returns a mutable reference to the [Value] stored for `path`, breaking enclosing values
    /// apart as needed to reach it.
    ///
    /// If `path` is itself a key of the mapping, its entry is returned directly. Otherwise the
    /// parent path is resolved recursively; a parent held in a single variable is deconstructed
    /// through `ctx` and replaced by a scattered value, and the entry of the requested member
    /// is returned.
    ///
    /// Returns None if `path` is not mapped and has no parent, if its parent cannot be
    /// resolved, or if the parent has no entry for the requested member.
    fn break_into_value<TContext: StructRecomposer>(
        &mut self,
        ctx: &mut TContext,
        path: &MemberPath,
    ) -> Option<&mut Value> {
        // NOTE(review): the lookup is done as `contains_key` followed by `get_mut` rather than a
        // single `if let Some(..) = get_mut(..)`; presumably this keeps the mutable borrow of
        // `self` from extending over the rest of the function - confirm before simplifying.
        if self.scattered.contains_key(path) {
            return self.scattered.get_mut(path);
        }

        // Only a member of some parent can be reached through that parent.
        let MemberPath::Member { parent, member_id, concrete_struct_id, .. } = path else {
            return None;
        };

        let parent_value = self.break_into_value(ctx, parent)?;
        match parent_value {
            Value::Var(var) => {
                // The parent is held in one variable: deconstruct it and store every member as
                // its own variable so the requested member can be addressed individually.
                let members = ctx.deconstruct(*concrete_struct_id, *var);
                let members = OrderedHashMap::from_iter(
                    members.into_iter().map(|(member_id, var)| (member_id, Value::Var(var))),
                );
                let scattered = Scattered { concrete_struct_id: *concrete_struct_id, members };
                *parent_value = Value::Scattered(Box::new(scattered));

                // Re-borrow the value that was just stored to hand out a reference into it.
                extract_matches!(parent_value, Value::Scattered).members.get_mut(member_id)
            }
            // The parent is already scattered: look the member up directly.
            Value::Scattered(scattered) => scattered.members.get_mut(member_id),
        }
    }
186}
187
/// A trait for deconstructing and constructing structs.
///
/// [SemanticLoweringMapping] uses an implementor of this trait whenever it needs to split a
/// variable into per-member variables or to build a variable back from its members.
pub trait StructRecomposer {
    /// Deconstructs `value`, a variable of the struct `concrete_struct_id`, and returns the
    /// resulting variables keyed by member id.
    fn deconstruct(
        &mut self,
        concrete_struct_id: semantic::ConcreteStructId,
        value: VariableId,
    ) -> OrderedHashMap<MemberId, VariableId>;

    /// Deconstructs `value` into variables according to the given `types`.
    ///
    /// NOTE(review): callers in this file assume the result holds one variable per given type,
    /// in the same order - confirm against the implementations.
    fn deconstruct_by_types(
        &mut self,
        value: VariableId,
        types: impl Iterator<Item = semantic::TypeId>,
    ) -> Vec<VariableId>;

    /// Constructs a variable of the struct `concrete_struct_id` from the variables holding its
    /// `members`, and returns the new variable.
    fn reconstruct(
        &mut self,
        concrete_struct_id: semantic::ConcreteStructId,
        members: Vec<VariableId>,
    ) -> VariableId;
    /// Returns the type of the variable `var`.
    fn var_ty(&self, var: VariableId) -> semantic::TypeId;
    /// Returns the lowering database.
    fn db(&self) -> &dyn LoweringGroup;
}
210
/// An intermediate value for a member path.
///
/// A path is either held whole in one lowered variable, or broken up into the values of its
/// members.
#[derive(Clone, Debug, DebugWithDb)]
#[debug_db(ExprFormatter<'a>)]
enum Value {
    /// The value of the member path is stored in a lowered variable.
    Var(VariableId),
    /// The value of the member path is not stored. It should be reconstructed from the member
    /// values.
    Scattered(Box<Scattered>),
}
221
/// A value for a non-stored member path. Recursively holds the [Value] for the members.
#[derive(Clone, Debug, DebugWithDb)]
#[debug_db(ExprFormatter<'a>)]
struct Scattered {
    /// The struct whose members are held here; passed to the context when reconstructing.
    concrete_struct_id: semantic::ConcreteStructId,
    /// The value of each member, keyed by member id.
    members: OrderedHashMap<MemberId, Value>,
}