cairo_lang_semantic/expr/inference/
canonic.rs

1use cairo_lang_defs::ids::{
2    EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId, ImplAliasId, ImplDefId,
3    ImplFunctionId, ImplImplDefId, LocalVarId, MemberId, ParamId, StructId, TraitConstantId,
4    TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
5};
6use cairo_lang_proc_macros::SemanticObject;
7use cairo_lang_utils::LookupIntern;
8use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
9
10use super::{
11    ConstVar, ImplVar, ImplVarId, ImplVarTraitItemMappings, Inference, InferenceId, InferenceVar,
12    LocalConstVarId, LocalImplVarId, LocalTypeVarId, TypeVar,
13};
14use crate::db::SemanticGroup;
15use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
16use crate::items::functions::{
17    ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
18    GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
19    ImplGenericFunctionWithBodyId,
20};
21use crate::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType};
22use crate::items::imp::{
23    GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
24    UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
25};
26use crate::items::trt::{ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId};
27use crate::substitution::{HasDb, RewriteResult, SemanticObject, SemanticRewriter};
28use crate::types::{
29    ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
30    ImplTypeId,
31};
32use crate::{
33    ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
34    ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
35    ExprId, ExprVar, ExprVarMemberPath, FunctionId, FunctionLongId, GenericArgumentId,
36    GenericParam, MatchArmSelector, Parameter, Signature, TypeId, TypeLongId, ValueSelectorArm,
37    add_basic_rewrites,
38};
39
40/// A canonical representation of a concrete trait that needs to be solved.
41#[derive(Clone, PartialEq, Hash, Eq, Debug, SemanticObject)]
42pub struct CanonicalTrait {
43    pub id: ConcreteTraitId,
44    pub mappings: ImplVarTraitItemMappings,
45}
46
47impl CanonicalTrait {
48    /// Canonicalizes a concrete trait that is part of an [Inference].
49    pub fn canonicalize(
50        db: &dyn SemanticGroup,
51        source_inference_id: InferenceId,
52        trait_id: ConcreteTraitId,
53        impl_var_mappings: ImplVarTraitItemMappings,
54    ) -> (Self, CanonicalMapping) {
55        Canonicalizer::canonicalize(db, source_inference_id, Self {
56            id: trait_id,
57            mappings: impl_var_mappings,
58        })
59    }
60    /// Embeds a canonical trait into an [Inference].
61    pub fn embed(&self, inference: &mut Inference<'_>) -> (CanonicalTrait, CanonicalMapping) {
62        Embedder::embed(inference, self.clone())
63    }
64}
65
66/// A solution for a [CanonicalTrait].
67#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
68pub struct CanonicalImpl(pub ImplId);
69impl CanonicalImpl {
70    /// Canonicalizes a concrete impl that is part of an [Inference].
71    /// Uses the same canonicalization of the trait, to be consistent.
72    pub fn canonicalize(
73        db: &dyn SemanticGroup,
74        impl_id: ImplId,
75        mapping: &CanonicalMapping,
76    ) -> Result<Self, MapperError> {
77        Ok(Self(Mapper::map(db, impl_id, &mapping.to_canonic)?))
78    }
79    /// Embeds a canonical impl into an [Inference].
80    /// Uses the same embedding of the trait, to be consistent.
81    pub fn embed(&self, inference: &Inference<'_>, mapping: &CanonicalMapping) -> ImplId {
82        Mapper::map(inference.db, self.0, &mapping.from_canonic)
83            .expect("Tried to embed a non canonical impl")
84    }
85}
86
87/// Mapping between canonical space and inference space.
88/// Created by a either canonicalizing or embedding a trait.
89#[derive(Debug)]
90pub struct CanonicalMapping {
91    to_canonic: VarMapping,
92    from_canonic: VarMapping,
93}
94impl CanonicalMapping {
95    fn from_to_canonic(to_canonic: VarMapping) -> CanonicalMapping {
96        let from_canonic = VarMapping {
97            type_var_mapping: to_canonic.type_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
98            const_var_mapping: to_canonic.const_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
99            impl_var_mapping: to_canonic.impl_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
100            source_inference_id: to_canonic.target_inference_id,
101            target_inference_id: to_canonic.source_inference_id,
102        };
103        Self { to_canonic, from_canonic }
104    }
105    fn from_from_canonic(from_canonic: VarMapping) -> CanonicalMapping {
106        let to_canonic = VarMapping {
107            type_var_mapping: from_canonic.type_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
108            const_var_mapping: from_canonic
109                .const_var_mapping
110                .iter()
111                .map(|(k, v)| (*v, *k))
112                .collect(),
113            impl_var_mapping: from_canonic.impl_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
114            source_inference_id: from_canonic.target_inference_id,
115            target_inference_id: from_canonic.source_inference_id,
116        };
117        Self { to_canonic, from_canonic }
118    }
119}
120
121// Mappings.
122#[derive(Debug)]
123pub struct VarMapping {
124    type_var_mapping: OrderedHashMap<LocalTypeVarId, LocalTypeVarId>,
125    const_var_mapping: OrderedHashMap<LocalConstVarId, LocalConstVarId>,
126    impl_var_mapping: OrderedHashMap<LocalImplVarId, LocalImplVarId>,
127    source_inference_id: InferenceId,
128    target_inference_id: InferenceId,
129}
130impl VarMapping {
131    fn new_to_canonic(source_inference_id: InferenceId) -> Self {
132        Self {
133            type_var_mapping: OrderedHashMap::default(),
134            const_var_mapping: OrderedHashMap::default(),
135            impl_var_mapping: OrderedHashMap::default(),
136            source_inference_id,
137            target_inference_id: InferenceId::Canonical,
138        }
139    }
140    fn new_from_canonic(target_inference_id: InferenceId) -> Self {
141        Self {
142            type_var_mapping: OrderedHashMap::default(),
143            const_var_mapping: OrderedHashMap::default(),
144            impl_var_mapping: OrderedHashMap::default(),
145            source_inference_id: InferenceId::Canonical,
146            target_inference_id,
147        }
148    }
149}
150
/// A 'never' error: an enum with no variants, used where a rewrite cannot fail.
#[derive(Debug)]
pub enum NoError {}

/// Unwrapping for results whose error type is uninhabited.
pub trait ResultNoErrEx<T> {
    /// Extracts the success value. This never panics, because no error value can exist.
    fn no_err(self) -> T;
}

impl<T> ResultNoErrEx<T> for Result<T, NoError> {
    fn no_err(self) -> T {
        // `NoError` has no variants, so the empty match is exhaustive and diverges.
        self.unwrap_or_else(|never| match never {})
    }
}
166
167/// Canonicalization rewriter. Each encountered variable is mapped to a new free variable,
168/// in pre-order.
169struct Canonicalizer<'db> {
170    db: &'db dyn SemanticGroup,
171    to_canonic: VarMapping,
172}
173impl<'db> Canonicalizer<'db> {
174    fn canonicalize<T>(
175        db: &'db dyn SemanticGroup,
176        source_inference_id: InferenceId,
177        value: T,
178    ) -> (T, CanonicalMapping)
179    where
180        Self: SemanticRewriter<T, NoError>,
181    {
182        let mut canonicalizer =
183            Self { db, to_canonic: VarMapping::new_to_canonic(source_inference_id) };
184        let value = canonicalizer.rewrite(value).no_err();
185        let mapping = CanonicalMapping::from_to_canonic(canonicalizer.to_canonic);
186        (value, mapping)
187    }
188}
189impl<'a> HasDb<&'a dyn SemanticGroup> for Canonicalizer<'a> {
190    fn get_db(&self) -> &'a dyn SemanticGroup {
191        self.db
192    }
193}
194
195add_basic_rewrites!(
196    <'a>,
197    Canonicalizer<'a>,
198    NoError,
199    @exclude TypeLongId TypeId ImplLongId ImplId ConstValue
200);
201
202impl SemanticRewriter<TypeId, NoError> for Canonicalizer<'_> {
203    fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
204        if value.is_var_free(self.db) {
205            return Ok(RewriteResult::NoChange);
206        }
207        value.default_rewrite(self)
208    }
209}
210impl SemanticRewriter<TypeLongId, NoError> for Canonicalizer<'_> {
211    fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
212        let TypeLongId::Var(var) = value else {
213            return value.default_rewrite(self);
214        };
215        if var.inference_id != self.to_canonic.source_inference_id {
216            return value.default_rewrite(self);
217        }
218        let next_id = LocalTypeVarId(self.to_canonic.type_var_mapping.len());
219        *value = TypeLongId::Var(TypeVar {
220            id: *self.to_canonic.type_var_mapping.entry(var.id).or_insert(next_id),
221            inference_id: InferenceId::Canonical,
222        });
223        Ok(RewriteResult::Modified)
224    }
225}
226impl SemanticRewriter<ConstValue, NoError> for Canonicalizer<'_> {
227    fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
228        let ConstValue::Var(var, mut ty) = value else {
229            return value.default_rewrite(self);
230        };
231        if var.inference_id != self.to_canonic.source_inference_id {
232            return value.default_rewrite(self);
233        }
234        let next_id = LocalConstVarId(self.to_canonic.const_var_mapping.len());
235        ty.default_rewrite(self)?;
236        *value = ConstValue::Var(
237            ConstVar {
238                id: *self.to_canonic.const_var_mapping.entry(var.id).or_insert(next_id),
239                inference_id: InferenceId::Canonical,
240            },
241            ty,
242        );
243        Ok(RewriteResult::Modified)
244    }
245}
246impl SemanticRewriter<ImplId, NoError> for Canonicalizer<'_> {
247    fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
248        if value.is_var_free(self.db) {
249            return Ok(RewriteResult::NoChange);
250        }
251        value.default_rewrite(self)
252    }
253}
254impl SemanticRewriter<ImplLongId, NoError> for Canonicalizer<'_> {
255    fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
256        let ImplLongId::ImplVar(var_id) = value else {
257            if value.is_var_free(self.db) {
258                return Ok(RewriteResult::NoChange);
259            }
260            return value.default_rewrite(self);
261        };
262        let var = var_id.lookup_intern(self.db);
263        if var.inference_id != self.to_canonic.source_inference_id {
264            return value.default_rewrite(self);
265        }
266        let next_id = LocalImplVarId(self.to_canonic.impl_var_mapping.len());
267
268        let mut var = ImplVar {
269            id: *self.to_canonic.impl_var_mapping.entry(var.id).or_insert(next_id),
270            inference_id: InferenceId::Canonical,
271            lookup_context: var.lookup_context,
272            concrete_trait_id: var.concrete_trait_id,
273        };
274        var.concrete_trait_id.default_rewrite(self)?;
275        *value = ImplLongId::ImplVar(var.intern(self.db));
276        Ok(RewriteResult::Modified)
277    }
278}
279
280/// Embedder rewriter. Each canonical variable is mapped to a new inference variable.
281struct Embedder<'a, 'db> {
282    inference: &'a mut Inference<'db>,
283    from_canonic: VarMapping,
284}
285impl<'a, 'db> Embedder<'a, 'db> {
286    fn embed<T>(inference: &'a mut Inference<'db>, value: T) -> (T, CanonicalMapping)
287    where
288        Self: SemanticRewriter<T, NoError>,
289    {
290        let from_canonic = VarMapping::new_from_canonic(inference.inference_id);
291        let mut embedder = Self { inference, from_canonic };
292        let value = embedder.rewrite(value).no_err();
293        let mapping = CanonicalMapping::from_from_canonic(embedder.from_canonic);
294        (value, mapping)
295    }
296}
297
298impl<'a> HasDb<&'a dyn SemanticGroup> for Embedder<'a, '_> {
299    fn get_db(&self) -> &'a dyn SemanticGroup {
300        self.inference.db
301    }
302}
303
304add_basic_rewrites!(
305    <'a,'b>,
306    Embedder<'a,'b>,
307    NoError,
308    @exclude TypeLongId TypeId ConstValue ImplLongId ImplId
309);
310
311impl SemanticRewriter<TypeId, NoError> for Embedder<'_, '_> {
312    fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
313        if value.is_var_free(self.get_db()) {
314            return Ok(RewriteResult::NoChange);
315        }
316        value.default_rewrite(self)
317    }
318}
319impl SemanticRewriter<TypeLongId, NoError> for Embedder<'_, '_> {
320    fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
321        let TypeLongId::Var(var) = value else {
322            return value.default_rewrite(self);
323        };
324        if var.inference_id != InferenceId::Canonical {
325            return value.default_rewrite(self);
326        }
327        let new_id = self
328            .from_canonic
329            .type_var_mapping
330            .entry(var.id)
331            .or_insert_with(|| self.inference.new_type_var_raw(None).id);
332        *value = TypeLongId::Var(self.inference.type_vars[new_id.0]);
333        Ok(RewriteResult::Modified)
334    }
335}
336impl SemanticRewriter<ConstValue, NoError> for Embedder<'_, '_> {
337    fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
338        let ConstValue::Var(var, mut ty) = value else {
339            return value.default_rewrite(self);
340        };
341        if var.inference_id != InferenceId::Canonical {
342            return value.default_rewrite(self);
343        }
344        ty.default_rewrite(self)?;
345        let new_id = self
346            .from_canonic
347            .const_var_mapping
348            .entry(var.id)
349            .or_insert_with(|| self.inference.new_const_var_raw(None).id);
350        *value = ConstValue::Var(self.inference.const_vars[new_id.0], ty);
351        Ok(RewriteResult::Modified)
352    }
353}
354impl SemanticRewriter<ImplId, NoError> for Embedder<'_, '_> {
355    fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
356        if value.is_var_free(self.get_db()) {
357            return Ok(RewriteResult::NoChange);
358        }
359        value.default_rewrite(self)
360    }
361}
362impl SemanticRewriter<ImplLongId, NoError> for Embedder<'_, '_> {
363    fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
364        let ImplLongId::ImplVar(var_id) = value else {
365            if value.is_var_free(self.get_db()) {
366                return Ok(RewriteResult::NoChange);
367            }
368            return value.default_rewrite(self);
369        };
370        let var = var_id.lookup_intern(self.get_db());
371        if var.inference_id != InferenceId::Canonical {
372            return value.default_rewrite(self);
373        }
374        let concrete_trait_id = self.rewrite(var.concrete_trait_id)?;
375        let new_id = self.from_canonic.impl_var_mapping.entry(var.id).or_insert_with(|| {
376            self.inference.new_impl_var_raw(var.lookup_context.clone(), concrete_trait_id, None)
377        });
378        *value = ImplLongId::ImplVar(self.inference.impl_vars[new_id.0].intern(self.get_db()));
379        Ok(RewriteResult::Modified)
380    }
381}
382
383/// Mapper rewriter. Maps variables according to a given [VarMapping].
384#[derive(Clone, Debug)]
385pub struct MapperError(pub InferenceVar);
386struct Mapper<'db> {
387    db: &'db dyn SemanticGroup,
388    mapping: &'db VarMapping,
389}
390impl<'db> Mapper<'db> {
391    fn map<T>(
392        db: &'db dyn SemanticGroup,
393        value: T,
394        mapping: &'db VarMapping,
395    ) -> Result<T, MapperError>
396    where
397        Self: SemanticRewriter<T, MapperError>,
398    {
399        let mut mapper = Self { db, mapping };
400        mapper.rewrite(value)
401    }
402}
403
404impl<'db> HasDb<&'db dyn SemanticGroup> for Mapper<'db> {
405    fn get_db(&self) -> &'db dyn SemanticGroup {
406        self.db
407    }
408}
409
410add_basic_rewrites!(
411    <'a>,
412    Mapper<'a>,
413    MapperError,
414    @exclude TypeLongId TypeId ImplLongId ImplId ConstValue
415);
416
417impl SemanticRewriter<TypeId, MapperError> for Mapper<'_> {
418    fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, MapperError> {
419        if value.is_var_free(self.db) {
420            return Ok(RewriteResult::NoChange);
421        }
422        value.default_rewrite(self)
423    }
424}
425impl SemanticRewriter<TypeLongId, MapperError> for Mapper<'_> {
426    fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, MapperError> {
427        let TypeLongId::Var(var) = value else {
428            return value.default_rewrite(self);
429        };
430        let id = self
431            .mapping
432            .type_var_mapping
433            .get(&var.id)
434            .copied()
435            .ok_or(MapperError(InferenceVar::Type(var.id)))?;
436        *value = TypeLongId::Var(TypeVar { id, inference_id: self.mapping.target_inference_id });
437        Ok(RewriteResult::Modified)
438    }
439}
440impl SemanticRewriter<ConstValue, MapperError> for Mapper<'_> {
441    fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, MapperError> {
442        let ConstValue::Var(var, mut ty) = value else {
443            return value.default_rewrite(self);
444        };
445        let id = self
446            .mapping
447            .const_var_mapping
448            .get(&var.id)
449            .copied()
450            .ok_or(MapperError(InferenceVar::Const(var.id)))?;
451        ty.default_rewrite(self)?;
452        *value =
453            ConstValue::Var(ConstVar { id, inference_id: self.mapping.target_inference_id }, ty);
454        Ok(RewriteResult::Modified)
455    }
456}
457impl SemanticRewriter<ImplId, MapperError> for Mapper<'_> {
458    fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, MapperError> {
459        if value.is_var_free(self.db) {
460            return Ok(RewriteResult::NoChange);
461        }
462        value.default_rewrite(self)
463    }
464}
465impl SemanticRewriter<ImplLongId, MapperError> for Mapper<'_> {
466    fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, MapperError> {
467        let ImplLongId::ImplVar(var_id) = value else {
468            return value.default_rewrite(self);
469        };
470        let var = var_id.lookup_intern(self.get_db());
471        let id = self
472            .mapping
473            .impl_var_mapping
474            .get(&var.id)
475            .copied()
476            .ok_or(MapperError(InferenceVar::Impl(var.id)))?;
477        let var = ImplVar { id, inference_id: self.mapping.target_inference_id, ..var };
478
479        *value = ImplLongId::ImplVar(var.intern(self.get_db()));
480        Ok(RewriteResult::Modified)
481    }
482}