1use cairo_lang_defs::ids::{
2 EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId, ImplAliasId, ImplDefId,
3 ImplFunctionId, ImplImplDefId, LocalVarId, MemberId, ParamId, StructId, TraitConstantId,
4 TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
5};
6use cairo_lang_proc_macros::SemanticObject;
7use cairo_lang_utils::LookupIntern;
8use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
9
10use super::{
11 ConstVar, ImplVar, ImplVarId, ImplVarTraitItemMappings, Inference, InferenceId, InferenceVar,
12 LocalConstVarId, LocalImplVarId, LocalTypeVarId, TypeVar,
13};
14use crate::db::SemanticGroup;
15use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
16use crate::items::functions::{
17 ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
18 GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
19 ImplGenericFunctionWithBodyId,
20};
21use crate::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType};
22use crate::items::imp::{
23 GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
24 UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
25};
26use crate::items::trt::{ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId};
27use crate::substitution::{HasDb, RewriteResult, SemanticObject, SemanticRewriter};
28use crate::types::{
29 ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
30 ImplTypeId,
31};
32use crate::{
33 ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
34 ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
35 ExprId, ExprVar, ExprVarMemberPath, FunctionId, FunctionLongId, GenericArgumentId,
36 GenericParam, MatchArmSelector, Parameter, Signature, TypeId, TypeLongId, ValueSelectorArm,
37 add_basic_rewrites,
38};
39
/// A concrete trait, together with the trait-item mappings of its impl variables, whose
/// inference variables have been renumbered into the canonical space (`InferenceId::Canonical`).
///
/// Produced by [`CanonicalTrait::canonicalize`] and brought back into an inference by
/// [`CanonicalTrait::embed`].
#[derive(Clone, PartialEq, Hash, Eq, Debug, SemanticObject)]
pub struct CanonicalTrait {
    /// The concrete trait, expressed with canonical variables.
    pub id: ConcreteTraitId,
    /// Trait-item mappings of impl variables; canonicalized together with `id`.
    pub mappings: ImplVarTraitItemMappings,
}
46
47impl CanonicalTrait {
48 pub fn canonicalize(
50 db: &dyn SemanticGroup,
51 source_inference_id: InferenceId,
52 trait_id: ConcreteTraitId,
53 impl_var_mappings: ImplVarTraitItemMappings,
54 ) -> (Self, CanonicalMapping) {
55 Canonicalizer::canonicalize(db, source_inference_id, Self {
56 id: trait_id,
57 mappings: impl_var_mappings,
58 })
59 }
60 pub fn embed(&self, inference: &mut Inference<'_>) -> (CanonicalTrait, CanonicalMapping) {
62 Embedder::embed(inference, self.clone())
63 }
64}
65
/// An `ImplId` expressed in the canonical variable space, translated with the same
/// [`CanonicalMapping`] as a related [`CanonicalTrait`].
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub struct CanonicalImpl(pub ImplId);
69impl CanonicalImpl {
70 pub fn canonicalize(
73 db: &dyn SemanticGroup,
74 impl_id: ImplId,
75 mapping: &CanonicalMapping,
76 ) -> Result<Self, MapperError> {
77 Ok(Self(Mapper::map(db, impl_id, &mapping.to_canonic)?))
78 }
79 pub fn embed(&self, inference: &Inference<'_>, mapping: &CanonicalMapping) -> ImplId {
82 Mapper::map(inference.db, self.0, &mapping.from_canonic)
83 .expect("Tried to embed a non canonical impl")
84 }
85}
86
/// A pair of mutually inverse variable mappings between an inference and the canonical space.
///
/// The variable ids involved are local to one particular inference, so a mapping should only
/// be used together with the inference it was built for.
#[derive(Debug)]
pub struct CanonicalMapping {
    /// Inference variables -> canonical variables.
    to_canonic: VarMapping,
    /// Canonical variables -> inference variables.
    from_canonic: VarMapping,
}
94impl CanonicalMapping {
95 fn from_to_canonic(to_canonic: VarMapping) -> CanonicalMapping {
96 let from_canonic = VarMapping {
97 type_var_mapping: to_canonic.type_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
98 const_var_mapping: to_canonic.const_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
99 impl_var_mapping: to_canonic.impl_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
100 source_inference_id: to_canonic.target_inference_id,
101 target_inference_id: to_canonic.source_inference_id,
102 };
103 Self { to_canonic, from_canonic }
104 }
105 fn from_from_canonic(from_canonic: VarMapping) -> CanonicalMapping {
106 let to_canonic = VarMapping {
107 type_var_mapping: from_canonic.type_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
108 const_var_mapping: from_canonic
109 .const_var_mapping
110 .iter()
111 .map(|(k, v)| (*v, *k))
112 .collect(),
113 impl_var_mapping: from_canonic.impl_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
114 source_inference_id: from_canonic.target_inference_id,
115 target_inference_id: from_canonic.source_inference_id,
116 };
117 Self { to_canonic, from_canonic }
118 }
119}
120
/// A mapping of local variable ids (types, consts and impls) from the variables of
/// `source_inference_id` to the variables of `target_inference_id`.
#[derive(Debug)]
pub struct VarMapping {
    type_var_mapping: OrderedHashMap<LocalTypeVarId, LocalTypeVarId>,
    const_var_mapping: OrderedHashMap<LocalConstVarId, LocalConstVarId>,
    impl_var_mapping: OrderedHashMap<LocalImplVarId, LocalImplVarId>,
    /// The inference whose variables appear as keys.
    source_inference_id: InferenceId,
    /// The inference whose variables appear as values.
    target_inference_id: InferenceId,
}
130impl VarMapping {
131 fn new_to_canonic(source_inference_id: InferenceId) -> Self {
132 Self {
133 type_var_mapping: OrderedHashMap::default(),
134 const_var_mapping: OrderedHashMap::default(),
135 impl_var_mapping: OrderedHashMap::default(),
136 source_inference_id,
137 target_inference_id: InferenceId::Canonical,
138 }
139 }
140 fn new_from_canonic(target_inference_id: InferenceId) -> Self {
141 Self {
142 type_var_mapping: OrderedHashMap::default(),
143 const_var_mapping: OrderedHashMap::default(),
144 impl_var_mapping: OrderedHashMap::default(),
145 source_inference_id: InferenceId::Canonical,
146 target_inference_id,
147 }
148 }
149}
150
/// Error type for rewriters that cannot fail. It has no variants, so no value of it exists.
#[derive(Debug)]
pub enum NoError {}

/// Extension for unwrapping a `Result` whose error type is [`NoError`].
pub trait ResultNoErrEx<T> {
    /// Returns the `Ok` value; the `Err` case cannot occur.
    fn no_err(self) -> T;
}

impl<T> ResultNoErrEx<T> for Result<T, NoError> {
    fn no_err(self) -> T {
        // `NoError` is uninhabited, so the empty `match` type-checks as `T` and the closure
        // can never actually be invoked.
        self.unwrap_or_else(|never| match never {})
    }
}
166
/// Rewriter that moves a value from an inference into the canonical space.
///
/// Every variable belonging to `to_canonic.source_inference_id` gets a canonical id, assigned
/// in order of first encounter; variables of other inferences are not renumbered.
struct Canonicalizer<'db> {
    db: &'db dyn SemanticGroup,
    /// Mapping accumulated while rewriting.
    to_canonic: VarMapping,
}
173impl<'db> Canonicalizer<'db> {
174 fn canonicalize<T>(
175 db: &'db dyn SemanticGroup,
176 source_inference_id: InferenceId,
177 value: T,
178 ) -> (T, CanonicalMapping)
179 where
180 Self: SemanticRewriter<T, NoError>,
181 {
182 let mut canonicalizer =
183 Self { db, to_canonic: VarMapping::new_to_canonic(source_inference_id) };
184 let value = canonicalizer.rewrite(value).no_err();
185 let mapping = CanonicalMapping::from_to_canonic(canonicalizer.to_canonic);
186 (value, mapping)
187 }
188}
/// Gives the generic rewriting machinery access to the semantic database.
impl<'a> HasDb<&'a dyn SemanticGroup> for Canonicalizer<'a> {
    fn get_db(&self) -> &'a dyn SemanticGroup {
        self.db
    }
}
194
// Basic `SemanticRewriter<_, NoError>` impls for `Canonicalizer`. The `@exclude`d types get
// hand-written impls below, which handle their variable forms.
add_basic_rewrites!(
    <'a>,
    Canonicalizer<'a>,
    NoError,
    @exclude TypeLongId TypeId ImplLongId ImplId ConstValue
);
201
202impl SemanticRewriter<TypeId, NoError> for Canonicalizer<'_> {
203 fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
204 if value.is_var_free(self.db) {
205 return Ok(RewriteResult::NoChange);
206 }
207 value.default_rewrite(self)
208 }
209}
210impl SemanticRewriter<TypeLongId, NoError> for Canonicalizer<'_> {
211 fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
212 let TypeLongId::Var(var) = value else {
213 return value.default_rewrite(self);
214 };
215 if var.inference_id != self.to_canonic.source_inference_id {
216 return value.default_rewrite(self);
217 }
218 let next_id = LocalTypeVarId(self.to_canonic.type_var_mapping.len());
219 *value = TypeLongId::Var(TypeVar {
220 id: *self.to_canonic.type_var_mapping.entry(var.id).or_insert(next_id),
221 inference_id: InferenceId::Canonical,
222 });
223 Ok(RewriteResult::Modified)
224 }
225}
226impl SemanticRewriter<ConstValue, NoError> for Canonicalizer<'_> {
227 fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
228 let ConstValue::Var(var, mut ty) = value else {
229 return value.default_rewrite(self);
230 };
231 if var.inference_id != self.to_canonic.source_inference_id {
232 return value.default_rewrite(self);
233 }
234 let next_id = LocalConstVarId(self.to_canonic.const_var_mapping.len());
235 ty.default_rewrite(self)?;
236 *value = ConstValue::Var(
237 ConstVar {
238 id: *self.to_canonic.const_var_mapping.entry(var.id).or_insert(next_id),
239 inference_id: InferenceId::Canonical,
240 },
241 ty,
242 );
243 Ok(RewriteResult::Modified)
244 }
245}
246impl SemanticRewriter<ImplId, NoError> for Canonicalizer<'_> {
247 fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
248 if value.is_var_free(self.db) {
249 return Ok(RewriteResult::NoChange);
250 }
251 value.default_rewrite(self)
252 }
253}
254impl SemanticRewriter<ImplLongId, NoError> for Canonicalizer<'_> {
255 fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
256 let ImplLongId::ImplVar(var_id) = value else {
257 if value.is_var_free(self.db) {
258 return Ok(RewriteResult::NoChange);
259 }
260 return value.default_rewrite(self);
261 };
262 let var = var_id.lookup_intern(self.db);
263 if var.inference_id != self.to_canonic.source_inference_id {
264 return value.default_rewrite(self);
265 }
266 let next_id = LocalImplVarId(self.to_canonic.impl_var_mapping.len());
267
268 let mut var = ImplVar {
269 id: *self.to_canonic.impl_var_mapping.entry(var.id).or_insert(next_id),
270 inference_id: InferenceId::Canonical,
271 lookup_context: var.lookup_context,
272 concrete_trait_id: var.concrete_trait_id,
273 };
274 var.concrete_trait_id.default_rewrite(self)?;
275 *value = ImplLongId::ImplVar(var.intern(self.db));
276 Ok(RewriteResult::Modified)
277 }
278}
279
/// Rewriter that moves a canonical value into `inference`.
///
/// Each distinct canonical variable is replaced by a fresh variable allocated in `inference`;
/// the correspondence is recorded in `from_canonic`.
struct Embedder<'a, 'db> {
    inference: &'a mut Inference<'db>,
    /// Mapping accumulated while rewriting (canonical -> `inference`).
    from_canonic: VarMapping,
}
285impl<'a, 'db> Embedder<'a, 'db> {
286 fn embed<T>(inference: &'a mut Inference<'db>, value: T) -> (T, CanonicalMapping)
287 where
288 Self: SemanticRewriter<T, NoError>,
289 {
290 let from_canonic = VarMapping::new_from_canonic(inference.inference_id);
291 let mut embedder = Self { inference, from_canonic };
292 let value = embedder.rewrite(value).no_err();
293 let mapping = CanonicalMapping::from_from_canonic(embedder.from_canonic);
294 (value, mapping)
295 }
296}
297
/// Gives the generic rewriting machinery access to the semantic database, taken from the
/// inference being embedded into.
impl<'a> HasDb<&'a dyn SemanticGroup> for Embedder<'a, '_> {
    fn get_db(&self) -> &'a dyn SemanticGroup {
        self.inference.db
    }
}
303
// Basic `SemanticRewriter<_, NoError>` impls for `Embedder`. The `@exclude`d types get
// hand-written impls below, which handle their variable forms.
add_basic_rewrites!(
    <'a,'b>,
    Embedder<'a,'b>,
    NoError,
    @exclude TypeLongId TypeId ConstValue ImplLongId ImplId
);
310
311impl SemanticRewriter<TypeId, NoError> for Embedder<'_, '_> {
312 fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
313 if value.is_var_free(self.get_db()) {
314 return Ok(RewriteResult::NoChange);
315 }
316 value.default_rewrite(self)
317 }
318}
319impl SemanticRewriter<TypeLongId, NoError> for Embedder<'_, '_> {
320 fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
321 let TypeLongId::Var(var) = value else {
322 return value.default_rewrite(self);
323 };
324 if var.inference_id != InferenceId::Canonical {
325 return value.default_rewrite(self);
326 }
327 let new_id = self
328 .from_canonic
329 .type_var_mapping
330 .entry(var.id)
331 .or_insert_with(|| self.inference.new_type_var_raw(None).id);
332 *value = TypeLongId::Var(self.inference.type_vars[new_id.0]);
333 Ok(RewriteResult::Modified)
334 }
335}
336impl SemanticRewriter<ConstValue, NoError> for Embedder<'_, '_> {
337 fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
338 let ConstValue::Var(var, mut ty) = value else {
339 return value.default_rewrite(self);
340 };
341 if var.inference_id != InferenceId::Canonical {
342 return value.default_rewrite(self);
343 }
344 ty.default_rewrite(self)?;
345 let new_id = self
346 .from_canonic
347 .const_var_mapping
348 .entry(var.id)
349 .or_insert_with(|| self.inference.new_const_var_raw(None).id);
350 *value = ConstValue::Var(self.inference.const_vars[new_id.0], ty);
351 Ok(RewriteResult::Modified)
352 }
353}
354impl SemanticRewriter<ImplId, NoError> for Embedder<'_, '_> {
355 fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
356 if value.is_var_free(self.get_db()) {
357 return Ok(RewriteResult::NoChange);
358 }
359 value.default_rewrite(self)
360 }
361}
362impl SemanticRewriter<ImplLongId, NoError> for Embedder<'_, '_> {
363 fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
364 let ImplLongId::ImplVar(var_id) = value else {
365 if value.is_var_free(self.get_db()) {
366 return Ok(RewriteResult::NoChange);
367 }
368 return value.default_rewrite(self);
369 };
370 let var = var_id.lookup_intern(self.get_db());
371 if var.inference_id != InferenceId::Canonical {
372 return value.default_rewrite(self);
373 }
374 let concrete_trait_id = self.rewrite(var.concrete_trait_id)?;
375 let new_id = self.from_canonic.impl_var_mapping.entry(var.id).or_insert_with(|| {
376 self.inference.new_impl_var_raw(var.lookup_context.clone(), concrete_trait_id, None)
377 });
378 *value = ImplLongId::ImplVar(self.inference.impl_vars[new_id.0].intern(self.get_db()));
379 Ok(RewriteResult::Modified)
380 }
381}
382
/// Error returned when applying a [`VarMapping`] (e.g. in [`CanonicalImpl::canonicalize`])
/// meets a variable that has no entry in the mapping; holds that variable.
#[derive(Clone, Debug)]
pub struct MapperError(pub InferenceVar);
/// Rewriter that applies an existing [`VarMapping`] to a value without allocating any new
/// variables.
struct Mapper<'db> {
    db: &'db dyn SemanticGroup,
    mapping: &'db VarMapping,
}
390impl<'db> Mapper<'db> {
391 fn map<T>(
392 db: &'db dyn SemanticGroup,
393 value: T,
394 mapping: &'db VarMapping,
395 ) -> Result<T, MapperError>
396 where
397 Self: SemanticRewriter<T, MapperError>,
398 {
399 let mut mapper = Self { db, mapping };
400 mapper.rewrite(value)
401 }
402}
403
/// Gives the generic rewriting machinery access to the semantic database.
impl<'db> HasDb<&'db dyn SemanticGroup> for Mapper<'db> {
    fn get_db(&self) -> &'db dyn SemanticGroup {
        self.db
    }
}
409
// Basic `SemanticRewriter<_, MapperError>` impls for `Mapper`. The `@exclude`d types get
// hand-written impls below, which handle their variable forms.
add_basic_rewrites!(
    <'a>,
    Mapper<'a>,
    MapperError,
    @exclude TypeLongId TypeId ImplLongId ImplId ConstValue
);
416
417impl SemanticRewriter<TypeId, MapperError> for Mapper<'_> {
418 fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, MapperError> {
419 if value.is_var_free(self.db) {
420 return Ok(RewriteResult::NoChange);
421 }
422 value.default_rewrite(self)
423 }
424}
425impl SemanticRewriter<TypeLongId, MapperError> for Mapper<'_> {
426 fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, MapperError> {
427 let TypeLongId::Var(var) = value else {
428 return value.default_rewrite(self);
429 };
430 let id = self
431 .mapping
432 .type_var_mapping
433 .get(&var.id)
434 .copied()
435 .ok_or(MapperError(InferenceVar::Type(var.id)))?;
436 *value = TypeLongId::Var(TypeVar { id, inference_id: self.mapping.target_inference_id });
437 Ok(RewriteResult::Modified)
438 }
439}
440impl SemanticRewriter<ConstValue, MapperError> for Mapper<'_> {
441 fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, MapperError> {
442 let ConstValue::Var(var, mut ty) = value else {
443 return value.default_rewrite(self);
444 };
445 let id = self
446 .mapping
447 .const_var_mapping
448 .get(&var.id)
449 .copied()
450 .ok_or(MapperError(InferenceVar::Const(var.id)))?;
451 ty.default_rewrite(self)?;
452 *value =
453 ConstValue::Var(ConstVar { id, inference_id: self.mapping.target_inference_id }, ty);
454 Ok(RewriteResult::Modified)
455 }
456}
457impl SemanticRewriter<ImplId, MapperError> for Mapper<'_> {
458 fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, MapperError> {
459 if value.is_var_free(self.db) {
460 return Ok(RewriteResult::NoChange);
461 }
462 value.default_rewrite(self)
463 }
464}
465impl SemanticRewriter<ImplLongId, MapperError> for Mapper<'_> {
466 fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, MapperError> {
467 let ImplLongId::ImplVar(var_id) = value else {
468 return value.default_rewrite(self);
469 };
470 let var = var_id.lookup_intern(self.get_db());
471 let id = self
472 .mapping
473 .impl_var_mapping
474 .get(&var.id)
475 .copied()
476 .ok_or(MapperError(InferenceVar::Impl(var.id)))?;
477 let var = ImplVar { id, inference_id: self.mapping.target_inference_id, ..var };
478
479 *value = ImplLongId::ImplVar(var.intern(self.get_db()));
480 Ok(RewriteResult::Modified)
481 }
482}