cairo_lang_semantic/expr/inference/
canonic.rs

1use cairo_lang_defs::ids::{
2    EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId, ImplAliasId, ImplDefId,
3    ImplFunctionId, ImplImplDefId, LocalVarId, MemberId, ParamId, StructId, TraitConstantId,
4    TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
5};
6use cairo_lang_proc_macros::SemanticObject;
7use cairo_lang_utils::LookupIntern;
8use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
9
10use super::{
11    ConstVar, ImplVar, ImplVarId, ImplVarTraitItemMappings, Inference, InferenceId, InferenceVar,
12    LocalConstVarId, LocalImplVarId, LocalTypeVarId, TypeVar,
13};
14use crate::db::SemanticGroup;
15use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
16use crate::items::functions::{
17    ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
18    GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
19    ImplGenericFunctionWithBodyId,
20};
21use crate::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType};
22use crate::items::imp::{
23    GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
24    UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
25};
26use crate::items::trt::{
27    ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId, ConcreteTraitTypeId,
28    ConcreteTraitTypeLongId,
29};
30use crate::substitution::{HasDb, RewriteResult, SemanticObject, SemanticRewriter};
31use crate::types::{
32    ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
33    ImplTypeId,
34};
35use crate::{
36    ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
37    ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
38    ExprId, ExprVar, ExprVarMemberPath, FunctionId, FunctionLongId, GenericArgumentId,
39    GenericParam, MatchArmSelector, Parameter, Signature, TypeId, TypeLongId, ValueSelectorArm,
40    add_basic_rewrites,
41};
42
43/// A canonical representation of a concrete trait that needs to be solved.
44#[derive(Clone, PartialEq, Hash, Eq, Debug, SemanticObject)]
45pub struct CanonicalTrait {
46    pub id: ConcreteTraitId,
47    pub mappings: ImplVarTraitItemMappings,
48}
49
50impl CanonicalTrait {
51    /// Canonicalizes a concrete trait that is part of an [Inference].
52    pub fn canonicalize(
53        db: &dyn SemanticGroup,
54        source_inference_id: InferenceId,
55        trait_id: ConcreteTraitId,
56        impl_var_mappings: ImplVarTraitItemMappings,
57    ) -> (Self, CanonicalMapping) {
58        Canonicalizer::canonicalize(
59            db,
60            source_inference_id,
61            Self { id: trait_id, mappings: impl_var_mappings },
62        )
63    }
64    /// Embeds a canonical trait into an [Inference].
65    pub fn embed(&self, inference: &mut Inference<'_>) -> (CanonicalTrait, CanonicalMapping) {
66        Embedder::embed(inference, self.clone())
67    }
68}
69
70/// A solution for a [CanonicalTrait].
71#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
72pub struct CanonicalImpl(pub ImplId);
73impl CanonicalImpl {
74    /// Canonicalizes a concrete impl that is part of an [Inference].
75    /// Uses the same canonicalization of the trait, to be consistent.
76    pub fn canonicalize(
77        db: &dyn SemanticGroup,
78        impl_id: ImplId,
79        mapping: &CanonicalMapping,
80    ) -> Result<Self, MapperError> {
81        Ok(Self(Mapper::map(db, impl_id, &mapping.to_canonic)?))
82    }
83    /// Embeds a canonical impl into an [Inference].
84    /// Uses the same embedding of the trait, to be consistent.
85    pub fn embed(&self, inference: &Inference<'_>, mapping: &CanonicalMapping) -> ImplId {
86        Mapper::map(inference.db, self.0, &mapping.from_canonic)
87            .expect("Tried to embed a non canonical impl")
88    }
89}
90
91/// Mapping between canonical space and inference space.
92/// Created by either canonicalizing or embedding a trait.
93#[derive(Debug)]
94pub struct CanonicalMapping {
95    to_canonic: VarMapping,
96    from_canonic: VarMapping,
97}
98impl CanonicalMapping {
99    fn from_to_canonic(to_canonic: VarMapping) -> CanonicalMapping {
100        let from_canonic = VarMapping {
101            type_var_mapping: to_canonic.type_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
102            const_var_mapping: to_canonic.const_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
103            impl_var_mapping: to_canonic.impl_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
104            source_inference_id: to_canonic.target_inference_id,
105            target_inference_id: to_canonic.source_inference_id,
106        };
107        Self { to_canonic, from_canonic }
108    }
109    fn from_from_canonic(from_canonic: VarMapping) -> CanonicalMapping {
110        let to_canonic = VarMapping {
111            type_var_mapping: from_canonic.type_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
112            const_var_mapping: from_canonic
113                .const_var_mapping
114                .iter()
115                .map(|(k, v)| (*v, *k))
116                .collect(),
117            impl_var_mapping: from_canonic.impl_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
118            source_inference_id: from_canonic.target_inference_id,
119            target_inference_id: from_canonic.source_inference_id,
120        };
121        Self { to_canonic, from_canonic }
122    }
123}
124
125// Mappings.
126#[derive(Debug)]
127pub struct VarMapping {
128    type_var_mapping: OrderedHashMap<LocalTypeVarId, LocalTypeVarId>,
129    const_var_mapping: OrderedHashMap<LocalConstVarId, LocalConstVarId>,
130    impl_var_mapping: OrderedHashMap<LocalImplVarId, LocalImplVarId>,
131    source_inference_id: InferenceId,
132    target_inference_id: InferenceId,
133}
134impl VarMapping {
135    fn new_to_canonic(source_inference_id: InferenceId) -> Self {
136        Self {
137            type_var_mapping: OrderedHashMap::default(),
138            const_var_mapping: OrderedHashMap::default(),
139            impl_var_mapping: OrderedHashMap::default(),
140            source_inference_id,
141            target_inference_id: InferenceId::Canonical,
142        }
143    }
144    fn new_from_canonic(target_inference_id: InferenceId) -> Self {
145        Self {
146            type_var_mapping: OrderedHashMap::default(),
147            const_var_mapping: OrderedHashMap::default(),
148            impl_var_mapping: OrderedHashMap::default(),
149            source_inference_id: InferenceId::Canonical,
150            target_inference_id,
151        }
152    }
153}
154
/// An error type with no values, for rewriters that cannot fail.
#[derive(Debug)]
pub enum NoError {}

/// Extension for results whose error type is [NoError].
pub trait ResultNoErrEx<T> {
    /// Extracts the success value. This cannot fail, since [NoError] has no values.
    fn no_err(self) -> T;
}
impl<T> ResultNoErrEx<T> for Result<T, NoError> {
    fn no_err(self) -> T {
        // `NoError` is uninhabited: the empty match type-checks as any `T` and never runs.
        self.unwrap_or_else(|never| match never {})
    }
}
170
171/// Canonicalization rewriter. Each encountered variable is mapped to a new free variable,
172/// in pre-order.
173struct Canonicalizer<'db> {
174    db: &'db dyn SemanticGroup,
175    to_canonic: VarMapping,
176}
177impl<'db> Canonicalizer<'db> {
178    fn canonicalize<T>(
179        db: &'db dyn SemanticGroup,
180        source_inference_id: InferenceId,
181        value: T,
182    ) -> (T, CanonicalMapping)
183    where
184        Self: SemanticRewriter<T, NoError>,
185    {
186        let mut canonicalizer =
187            Self { db, to_canonic: VarMapping::new_to_canonic(source_inference_id) };
188        let value = canonicalizer.rewrite(value).no_err();
189        let mapping = CanonicalMapping::from_to_canonic(canonicalizer.to_canonic);
190        (value, mapping)
191    }
192}
193impl<'a> HasDb<&'a dyn SemanticGroup> for Canonicalizer<'a> {
194    fn get_db(&self) -> &'a dyn SemanticGroup {
195        self.db
196    }
197}
198
199add_basic_rewrites!(
200    <'a>,
201    Canonicalizer<'a>,
202    NoError,
203    @exclude TypeLongId TypeId ImplLongId ImplId ConstValue
204);
205
206impl SemanticRewriter<TypeId, NoError> for Canonicalizer<'_> {
207    fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
208        if value.is_var_free(self.db) {
209            return Ok(RewriteResult::NoChange);
210        }
211        value.default_rewrite(self)
212    }
213}
214impl SemanticRewriter<TypeLongId, NoError> for Canonicalizer<'_> {
215    fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
216        let TypeLongId::Var(var) = value else {
217            return value.default_rewrite(self);
218        };
219        if var.inference_id != self.to_canonic.source_inference_id {
220            return value.default_rewrite(self);
221        }
222        let next_id = LocalTypeVarId(self.to_canonic.type_var_mapping.len());
223        *value = TypeLongId::Var(TypeVar {
224            id: *self.to_canonic.type_var_mapping.entry(var.id).or_insert(next_id),
225            inference_id: InferenceId::Canonical,
226        });
227        Ok(RewriteResult::Modified)
228    }
229}
230impl SemanticRewriter<ConstValue, NoError> for Canonicalizer<'_> {
231    fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
232        let ConstValue::Var(var, mut ty) = value else {
233            return value.default_rewrite(self);
234        };
235        if var.inference_id != self.to_canonic.source_inference_id {
236            return value.default_rewrite(self);
237        }
238        let next_id = LocalConstVarId(self.to_canonic.const_var_mapping.len());
239        ty.default_rewrite(self)?;
240        *value = ConstValue::Var(
241            ConstVar {
242                id: *self.to_canonic.const_var_mapping.entry(var.id).or_insert(next_id),
243                inference_id: InferenceId::Canonical,
244            },
245            ty,
246        );
247        Ok(RewriteResult::Modified)
248    }
249}
250impl SemanticRewriter<ImplId, NoError> for Canonicalizer<'_> {
251    fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
252        if value.is_var_free(self.db) {
253            return Ok(RewriteResult::NoChange);
254        }
255        value.default_rewrite(self)
256    }
257}
258impl SemanticRewriter<ImplLongId, NoError> for Canonicalizer<'_> {
259    fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
260        let ImplLongId::ImplVar(var_id) = value else {
261            if value.is_var_free(self.db) {
262                return Ok(RewriteResult::NoChange);
263            }
264            return value.default_rewrite(self);
265        };
266        let var = var_id.lookup_intern(self.db);
267        if var.inference_id != self.to_canonic.source_inference_id {
268            return value.default_rewrite(self);
269        }
270        let next_id = LocalImplVarId(self.to_canonic.impl_var_mapping.len());
271
272        let mut var = ImplVar {
273            id: *self.to_canonic.impl_var_mapping.entry(var.id).or_insert(next_id),
274            inference_id: InferenceId::Canonical,
275            lookup_context: var.lookup_context,
276            concrete_trait_id: var.concrete_trait_id,
277        };
278        var.concrete_trait_id.default_rewrite(self)?;
279        *value = ImplLongId::ImplVar(var.intern(self.db));
280        Ok(RewriteResult::Modified)
281    }
282}
283
284/// Embedder rewriter. Each canonical variable is mapped to a new inference variable.
285struct Embedder<'a, 'db> {
286    inference: &'a mut Inference<'db>,
287    from_canonic: VarMapping,
288}
289impl<'a, 'db> Embedder<'a, 'db> {
290    fn embed<T>(inference: &'a mut Inference<'db>, value: T) -> (T, CanonicalMapping)
291    where
292        Self: SemanticRewriter<T, NoError>,
293    {
294        let from_canonic = VarMapping::new_from_canonic(inference.inference_id);
295        let mut embedder = Self { inference, from_canonic };
296        let value = embedder.rewrite(value).no_err();
297        let mapping = CanonicalMapping::from_from_canonic(embedder.from_canonic);
298        (value, mapping)
299    }
300}
301
302impl<'a> HasDb<&'a dyn SemanticGroup> for Embedder<'a, '_> {
303    fn get_db(&self) -> &'a dyn SemanticGroup {
304        self.inference.db
305    }
306}
307
308add_basic_rewrites!(
309    <'a,'b>,
310    Embedder<'a,'b>,
311    NoError,
312    @exclude TypeLongId TypeId ConstValue ImplLongId ImplId
313);
314
315impl SemanticRewriter<TypeId, NoError> for Embedder<'_, '_> {
316    fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
317        if value.is_var_free(self.get_db()) {
318            return Ok(RewriteResult::NoChange);
319        }
320        value.default_rewrite(self)
321    }
322}
323impl SemanticRewriter<TypeLongId, NoError> for Embedder<'_, '_> {
324    fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
325        let TypeLongId::Var(var) = value else {
326            return value.default_rewrite(self);
327        };
328        if var.inference_id != InferenceId::Canonical {
329            return value.default_rewrite(self);
330        }
331        let new_id = self
332            .from_canonic
333            .type_var_mapping
334            .entry(var.id)
335            .or_insert_with(|| self.inference.new_type_var_raw(None).id);
336        *value = TypeLongId::Var(self.inference.type_vars[new_id.0]);
337        Ok(RewriteResult::Modified)
338    }
339}
340impl SemanticRewriter<ConstValue, NoError> for Embedder<'_, '_> {
341    fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
342        let ConstValue::Var(var, mut ty) = value else {
343            return value.default_rewrite(self);
344        };
345        if var.inference_id != InferenceId::Canonical {
346            return value.default_rewrite(self);
347        }
348        ty.default_rewrite(self)?;
349        let new_id = self
350            .from_canonic
351            .const_var_mapping
352            .entry(var.id)
353            .or_insert_with(|| self.inference.new_const_var_raw(None).id);
354        *value = ConstValue::Var(self.inference.const_vars[new_id.0], ty);
355        Ok(RewriteResult::Modified)
356    }
357}
358impl SemanticRewriter<ImplId, NoError> for Embedder<'_, '_> {
359    fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
360        if value.is_var_free(self.get_db()) {
361            return Ok(RewriteResult::NoChange);
362        }
363        value.default_rewrite(self)
364    }
365}
366impl SemanticRewriter<ImplLongId, NoError> for Embedder<'_, '_> {
367    fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
368        let ImplLongId::ImplVar(var_id) = value else {
369            if value.is_var_free(self.get_db()) {
370                return Ok(RewriteResult::NoChange);
371            }
372            return value.default_rewrite(self);
373        };
374        let var = var_id.lookup_intern(self.get_db());
375        if var.inference_id != InferenceId::Canonical {
376            return value.default_rewrite(self);
377        }
378        let concrete_trait_id = self.rewrite(var.concrete_trait_id)?;
379        let new_id = self.from_canonic.impl_var_mapping.entry(var.id).or_insert_with(|| {
380            self.inference.new_impl_var_raw(var.lookup_context.clone(), concrete_trait_id, None)
381        });
382        *value = ImplLongId::ImplVar(self.inference.impl_vars[new_id.0].intern(self.get_db()));
383        Ok(RewriteResult::Modified)
384    }
385}
386
387/// Mapper rewriter. Maps variables according to a given [VarMapping].
388#[derive(Clone, Debug)]
389pub struct MapperError(pub InferenceVar);
390struct Mapper<'db> {
391    db: &'db dyn SemanticGroup,
392    mapping: &'db VarMapping,
393}
394impl<'db> Mapper<'db> {
395    fn map<T>(
396        db: &'db dyn SemanticGroup,
397        value: T,
398        mapping: &'db VarMapping,
399    ) -> Result<T, MapperError>
400    where
401        Self: SemanticRewriter<T, MapperError>,
402    {
403        let mut mapper = Self { db, mapping };
404        mapper.rewrite(value)
405    }
406}
407
408impl<'db> HasDb<&'db dyn SemanticGroup> for Mapper<'db> {
409    fn get_db(&self) -> &'db dyn SemanticGroup {
410        self.db
411    }
412}
413
414add_basic_rewrites!(
415    <'a>,
416    Mapper<'a>,
417    MapperError,
418    @exclude TypeLongId TypeId ImplLongId ImplId ConstValue
419);
420
421impl SemanticRewriter<TypeId, MapperError> for Mapper<'_> {
422    fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, MapperError> {
423        if value.is_var_free(self.db) {
424            return Ok(RewriteResult::NoChange);
425        }
426        value.default_rewrite(self)
427    }
428}
429impl SemanticRewriter<TypeLongId, MapperError> for Mapper<'_> {
430    fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, MapperError> {
431        let TypeLongId::Var(var) = value else {
432            return value.default_rewrite(self);
433        };
434        let id = self
435            .mapping
436            .type_var_mapping
437            .get(&var.id)
438            .copied()
439            .ok_or(MapperError(InferenceVar::Type(var.id)))?;
440        *value = TypeLongId::Var(TypeVar { id, inference_id: self.mapping.target_inference_id });
441        Ok(RewriteResult::Modified)
442    }
443}
444impl SemanticRewriter<ConstValue, MapperError> for Mapper<'_> {
445    fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, MapperError> {
446        let ConstValue::Var(var, mut ty) = value else {
447            return value.default_rewrite(self);
448        };
449        let id = self
450            .mapping
451            .const_var_mapping
452            .get(&var.id)
453            .copied()
454            .ok_or(MapperError(InferenceVar::Const(var.id)))?;
455        ty.default_rewrite(self)?;
456        *value =
457            ConstValue::Var(ConstVar { id, inference_id: self.mapping.target_inference_id }, ty);
458        Ok(RewriteResult::Modified)
459    }
460}
461impl SemanticRewriter<ImplId, MapperError> for Mapper<'_> {
462    fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, MapperError> {
463        if value.is_var_free(self.db) {
464            return Ok(RewriteResult::NoChange);
465        }
466        value.default_rewrite(self)
467    }
468}
469impl SemanticRewriter<ImplLongId, MapperError> for Mapper<'_> {
470    fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, MapperError> {
471        let ImplLongId::ImplVar(var_id) = value else {
472            return value.default_rewrite(self);
473        };
474        let var = var_id.lookup_intern(self.get_db());
475        let id = self
476            .mapping
477            .impl_var_mapping
478            .get(&var.id)
479            .copied()
480            .ok_or(MapperError(InferenceVar::Impl(var.id)))?;
481        let var = ImplVar { id, inference_id: self.mapping.target_inference_id, ..var };
482
483        *value = ImplLongId::ImplVar(var.intern(self.get_db()));
484        Ok(RewriteResult::Modified)
485    }
486}