1use cairo_lang_defs::ids::{
2 EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId, ImplAliasId, ImplDefId,
3 ImplFunctionId, ImplImplDefId, LocalVarId, MemberId, ParamId, StructId, TraitConstantId,
4 TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
5};
6use cairo_lang_proc_macros::SemanticObject;
7use cairo_lang_utils::LookupIntern;
8use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
9
10use super::{
11 ConstVar, ImplVar, ImplVarId, ImplVarTraitItemMappings, Inference, InferenceId, InferenceVar,
12 LocalConstVarId, LocalImplVarId, LocalTypeVarId, TypeVar,
13};
14use crate::db::SemanticGroup;
15use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
16use crate::items::functions::{
17 ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
18 GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
19 ImplGenericFunctionWithBodyId,
20};
21use crate::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType};
22use crate::items::imp::{
23 GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
24 UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
25};
26use crate::items::trt::{
27 ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId, ConcreteTraitTypeId,
28 ConcreteTraitTypeLongId,
29};
30use crate::substitution::{HasDb, RewriteResult, SemanticObject, SemanticRewriter};
31use crate::types::{
32 ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
33 ImplTypeId,
34};
35use crate::{
36 ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
37 ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
38 ExprId, ExprVar, ExprVarMemberPath, FunctionId, FunctionLongId, GenericArgumentId,
39 GenericParam, MatchArmSelector, Parameter, Signature, TypeId, TypeLongId, ValueSelectorArm,
40 add_basic_rewrites,
41};
42
43#[derive(Clone, PartialEq, Hash, Eq, Debug, SemanticObject)]
45pub struct CanonicalTrait {
46 pub id: ConcreteTraitId,
47 pub mappings: ImplVarTraitItemMappings,
48}
49
50impl CanonicalTrait {
51 pub fn canonicalize(
53 db: &dyn SemanticGroup,
54 source_inference_id: InferenceId,
55 trait_id: ConcreteTraitId,
56 impl_var_mappings: ImplVarTraitItemMappings,
57 ) -> (Self, CanonicalMapping) {
58 Canonicalizer::canonicalize(
59 db,
60 source_inference_id,
61 Self { id: trait_id, mappings: impl_var_mappings },
62 )
63 }
64 pub fn embed(&self, inference: &mut Inference<'_>) -> (CanonicalTrait, CanonicalMapping) {
66 Embedder::embed(inference, self.clone())
67 }
68}
69
70#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
72pub struct CanonicalImpl(pub ImplId);
73impl CanonicalImpl {
74 pub fn canonicalize(
77 db: &dyn SemanticGroup,
78 impl_id: ImplId,
79 mapping: &CanonicalMapping,
80 ) -> Result<Self, MapperError> {
81 Ok(Self(Mapper::map(db, impl_id, &mapping.to_canonic)?))
82 }
83 pub fn embed(&self, inference: &Inference<'_>, mapping: &CanonicalMapping) -> ImplId {
86 Mapper::map(inference.db, self.0, &mapping.from_canonic)
87 .expect("Tried to embed a non canonical impl")
88 }
89}
90
91#[derive(Debug)]
94pub struct CanonicalMapping {
95 to_canonic: VarMapping,
96 from_canonic: VarMapping,
97}
98impl CanonicalMapping {
99 fn from_to_canonic(to_canonic: VarMapping) -> CanonicalMapping {
100 let from_canonic = VarMapping {
101 type_var_mapping: to_canonic.type_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
102 const_var_mapping: to_canonic.const_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
103 impl_var_mapping: to_canonic.impl_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
104 source_inference_id: to_canonic.target_inference_id,
105 target_inference_id: to_canonic.source_inference_id,
106 };
107 Self { to_canonic, from_canonic }
108 }
109 fn from_from_canonic(from_canonic: VarMapping) -> CanonicalMapping {
110 let to_canonic = VarMapping {
111 type_var_mapping: from_canonic.type_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
112 const_var_mapping: from_canonic
113 .const_var_mapping
114 .iter()
115 .map(|(k, v)| (*v, *k))
116 .collect(),
117 impl_var_mapping: from_canonic.impl_var_mapping.iter().map(|(k, v)| (*v, *k)).collect(),
118 source_inference_id: from_canonic.target_inference_id,
119 target_inference_id: from_canonic.source_inference_id,
120 };
121 Self { to_canonic, from_canonic }
122 }
123}
124
125#[derive(Debug)]
127pub struct VarMapping {
128 type_var_mapping: OrderedHashMap<LocalTypeVarId, LocalTypeVarId>,
129 const_var_mapping: OrderedHashMap<LocalConstVarId, LocalConstVarId>,
130 impl_var_mapping: OrderedHashMap<LocalImplVarId, LocalImplVarId>,
131 source_inference_id: InferenceId,
132 target_inference_id: InferenceId,
133}
134impl VarMapping {
135 fn new_to_canonic(source_inference_id: InferenceId) -> Self {
136 Self {
137 type_var_mapping: OrderedHashMap::default(),
138 const_var_mapping: OrderedHashMap::default(),
139 impl_var_mapping: OrderedHashMap::default(),
140 source_inference_id,
141 target_inference_id: InferenceId::Canonical,
142 }
143 }
144 fn new_from_canonic(target_inference_id: InferenceId) -> Self {
145 Self {
146 type_var_mapping: OrderedHashMap::default(),
147 const_var_mapping: OrderedHashMap::default(),
148 impl_var_mapping: OrderedHashMap::default(),
149 source_inference_id: InferenceId::Canonical,
150 target_inference_id,
151 }
152 }
153}
154
/// An error type with no values: a `Result<T, NoError>` can only ever be `Ok`.
#[derive(Debug)]
pub enum NoError {}

/// Extension trait for extracting the value of a result that cannot fail.
pub trait ResultNoErrEx<T> {
    /// Returns the contained value.
    fn no_err(self) -> T;
}
impl<T> ResultNoErrEx<T> for Result<T, NoError> {
    fn no_err(self) -> T {
        // `NoError` has no variants, so the empty match proves the closure is never run.
        self.unwrap_or_else(|never| match never {})
    }
}
170
171struct Canonicalizer<'db> {
174 db: &'db dyn SemanticGroup,
175 to_canonic: VarMapping,
176}
177impl<'db> Canonicalizer<'db> {
178 fn canonicalize<T>(
179 db: &'db dyn SemanticGroup,
180 source_inference_id: InferenceId,
181 value: T,
182 ) -> (T, CanonicalMapping)
183 where
184 Self: SemanticRewriter<T, NoError>,
185 {
186 let mut canonicalizer =
187 Self { db, to_canonic: VarMapping::new_to_canonic(source_inference_id) };
188 let value = canonicalizer.rewrite(value).no_err();
189 let mapping = CanonicalMapping::from_to_canonic(canonicalizer.to_canonic);
190 (value, mapping)
191 }
192}
193impl<'a> HasDb<&'a dyn SemanticGroup> for Canonicalizer<'a> {
194 fn get_db(&self) -> &'a dyn SemanticGroup {
195 self.db
196 }
197}
198
199add_basic_rewrites!(
200 <'a>,
201 Canonicalizer<'a>,
202 NoError,
203 @exclude TypeLongId TypeId ImplLongId ImplId ConstValue
204);
205
206impl SemanticRewriter<TypeId, NoError> for Canonicalizer<'_> {
207 fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
208 if value.is_var_free(self.db) {
209 return Ok(RewriteResult::NoChange);
210 }
211 value.default_rewrite(self)
212 }
213}
214impl SemanticRewriter<TypeLongId, NoError> for Canonicalizer<'_> {
215 fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
216 let TypeLongId::Var(var) = value else {
217 return value.default_rewrite(self);
218 };
219 if var.inference_id != self.to_canonic.source_inference_id {
220 return value.default_rewrite(self);
221 }
222 let next_id = LocalTypeVarId(self.to_canonic.type_var_mapping.len());
223 *value = TypeLongId::Var(TypeVar {
224 id: *self.to_canonic.type_var_mapping.entry(var.id).or_insert(next_id),
225 inference_id: InferenceId::Canonical,
226 });
227 Ok(RewriteResult::Modified)
228 }
229}
230impl SemanticRewriter<ConstValue, NoError> for Canonicalizer<'_> {
231 fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
232 let ConstValue::Var(var, mut ty) = value else {
233 return value.default_rewrite(self);
234 };
235 if var.inference_id != self.to_canonic.source_inference_id {
236 return value.default_rewrite(self);
237 }
238 let next_id = LocalConstVarId(self.to_canonic.const_var_mapping.len());
239 ty.default_rewrite(self)?;
240 *value = ConstValue::Var(
241 ConstVar {
242 id: *self.to_canonic.const_var_mapping.entry(var.id).or_insert(next_id),
243 inference_id: InferenceId::Canonical,
244 },
245 ty,
246 );
247 Ok(RewriteResult::Modified)
248 }
249}
250impl SemanticRewriter<ImplId, NoError> for Canonicalizer<'_> {
251 fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
252 if value.is_var_free(self.db) {
253 return Ok(RewriteResult::NoChange);
254 }
255 value.default_rewrite(self)
256 }
257}
258impl SemanticRewriter<ImplLongId, NoError> for Canonicalizer<'_> {
259 fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
260 let ImplLongId::ImplVar(var_id) = value else {
261 if value.is_var_free(self.db) {
262 return Ok(RewriteResult::NoChange);
263 }
264 return value.default_rewrite(self);
265 };
266 let var = var_id.lookup_intern(self.db);
267 if var.inference_id != self.to_canonic.source_inference_id {
268 return value.default_rewrite(self);
269 }
270 let next_id = LocalImplVarId(self.to_canonic.impl_var_mapping.len());
271
272 let mut var = ImplVar {
273 id: *self.to_canonic.impl_var_mapping.entry(var.id).or_insert(next_id),
274 inference_id: InferenceId::Canonical,
275 lookup_context: var.lookup_context,
276 concrete_trait_id: var.concrete_trait_id,
277 };
278 var.concrete_trait_id.default_rewrite(self)?;
279 *value = ImplLongId::ImplVar(var.intern(self.db));
280 Ok(RewriteResult::Modified)
281 }
282}
283
284struct Embedder<'a, 'db> {
286 inference: &'a mut Inference<'db>,
287 from_canonic: VarMapping,
288}
289impl<'a, 'db> Embedder<'a, 'db> {
290 fn embed<T>(inference: &'a mut Inference<'db>, value: T) -> (T, CanonicalMapping)
291 where
292 Self: SemanticRewriter<T, NoError>,
293 {
294 let from_canonic = VarMapping::new_from_canonic(inference.inference_id);
295 let mut embedder = Self { inference, from_canonic };
296 let value = embedder.rewrite(value).no_err();
297 let mapping = CanonicalMapping::from_from_canonic(embedder.from_canonic);
298 (value, mapping)
299 }
300}
301
302impl<'a> HasDb<&'a dyn SemanticGroup> for Embedder<'a, '_> {
303 fn get_db(&self) -> &'a dyn SemanticGroup {
304 self.inference.db
305 }
306}
307
308add_basic_rewrites!(
309 <'a,'b>,
310 Embedder<'a,'b>,
311 NoError,
312 @exclude TypeLongId TypeId ConstValue ImplLongId ImplId
313);
314
315impl SemanticRewriter<TypeId, NoError> for Embedder<'_, '_> {
316 fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
317 if value.is_var_free(self.get_db()) {
318 return Ok(RewriteResult::NoChange);
319 }
320 value.default_rewrite(self)
321 }
322}
323impl SemanticRewriter<TypeLongId, NoError> for Embedder<'_, '_> {
324 fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
325 let TypeLongId::Var(var) = value else {
326 return value.default_rewrite(self);
327 };
328 if var.inference_id != InferenceId::Canonical {
329 return value.default_rewrite(self);
330 }
331 let new_id = self
332 .from_canonic
333 .type_var_mapping
334 .entry(var.id)
335 .or_insert_with(|| self.inference.new_type_var_raw(None).id);
336 *value = TypeLongId::Var(self.inference.type_vars[new_id.0]);
337 Ok(RewriteResult::Modified)
338 }
339}
340impl SemanticRewriter<ConstValue, NoError> for Embedder<'_, '_> {
341 fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
342 let ConstValue::Var(var, mut ty) = value else {
343 return value.default_rewrite(self);
344 };
345 if var.inference_id != InferenceId::Canonical {
346 return value.default_rewrite(self);
347 }
348 ty.default_rewrite(self)?;
349 let new_id = self
350 .from_canonic
351 .const_var_mapping
352 .entry(var.id)
353 .or_insert_with(|| self.inference.new_const_var_raw(None).id);
354 *value = ConstValue::Var(self.inference.const_vars[new_id.0], ty);
355 Ok(RewriteResult::Modified)
356 }
357}
358impl SemanticRewriter<ImplId, NoError> for Embedder<'_, '_> {
359 fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
360 if value.is_var_free(self.get_db()) {
361 return Ok(RewriteResult::NoChange);
362 }
363 value.default_rewrite(self)
364 }
365}
366impl SemanticRewriter<ImplLongId, NoError> for Embedder<'_, '_> {
367 fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
368 let ImplLongId::ImplVar(var_id) = value else {
369 if value.is_var_free(self.get_db()) {
370 return Ok(RewriteResult::NoChange);
371 }
372 return value.default_rewrite(self);
373 };
374 let var = var_id.lookup_intern(self.get_db());
375 if var.inference_id != InferenceId::Canonical {
376 return value.default_rewrite(self);
377 }
378 let concrete_trait_id = self.rewrite(var.concrete_trait_id)?;
379 let new_id = self.from_canonic.impl_var_mapping.entry(var.id).or_insert_with(|| {
380 self.inference.new_impl_var_raw(var.lookup_context.clone(), concrete_trait_id, None)
381 });
382 *value = ImplLongId::ImplVar(self.inference.impl_vars[new_id.0].intern(self.get_db()));
383 Ok(RewriteResult::Modified)
384 }
385}
386
387#[derive(Clone, Debug)]
389pub struct MapperError(pub InferenceVar);
390struct Mapper<'db> {
391 db: &'db dyn SemanticGroup,
392 mapping: &'db VarMapping,
393}
394impl<'db> Mapper<'db> {
395 fn map<T>(
396 db: &'db dyn SemanticGroup,
397 value: T,
398 mapping: &'db VarMapping,
399 ) -> Result<T, MapperError>
400 where
401 Self: SemanticRewriter<T, MapperError>,
402 {
403 let mut mapper = Self { db, mapping };
404 mapper.rewrite(value)
405 }
406}
407
408impl<'db> HasDb<&'db dyn SemanticGroup> for Mapper<'db> {
409 fn get_db(&self) -> &'db dyn SemanticGroup {
410 self.db
411 }
412}
413
414add_basic_rewrites!(
415 <'a>,
416 Mapper<'a>,
417 MapperError,
418 @exclude TypeLongId TypeId ImplLongId ImplId ConstValue
419);
420
421impl SemanticRewriter<TypeId, MapperError> for Mapper<'_> {
422 fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, MapperError> {
423 if value.is_var_free(self.db) {
424 return Ok(RewriteResult::NoChange);
425 }
426 value.default_rewrite(self)
427 }
428}
429impl SemanticRewriter<TypeLongId, MapperError> for Mapper<'_> {
430 fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, MapperError> {
431 let TypeLongId::Var(var) = value else {
432 return value.default_rewrite(self);
433 };
434 let id = self
435 .mapping
436 .type_var_mapping
437 .get(&var.id)
438 .copied()
439 .ok_or(MapperError(InferenceVar::Type(var.id)))?;
440 *value = TypeLongId::Var(TypeVar { id, inference_id: self.mapping.target_inference_id });
441 Ok(RewriteResult::Modified)
442 }
443}
444impl SemanticRewriter<ConstValue, MapperError> for Mapper<'_> {
445 fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, MapperError> {
446 let ConstValue::Var(var, mut ty) = value else {
447 return value.default_rewrite(self);
448 };
449 let id = self
450 .mapping
451 .const_var_mapping
452 .get(&var.id)
453 .copied()
454 .ok_or(MapperError(InferenceVar::Const(var.id)))?;
455 ty.default_rewrite(self)?;
456 *value =
457 ConstValue::Var(ConstVar { id, inference_id: self.mapping.target_inference_id }, ty);
458 Ok(RewriteResult::Modified)
459 }
460}
461impl SemanticRewriter<ImplId, MapperError> for Mapper<'_> {
462 fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, MapperError> {
463 if value.is_var_free(self.db) {
464 return Ok(RewriteResult::NoChange);
465 }
466 value.default_rewrite(self)
467 }
468}
469impl SemanticRewriter<ImplLongId, MapperError> for Mapper<'_> {
470 fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, MapperError> {
471 let ImplLongId::ImplVar(var_id) = value else {
472 return value.default_rewrite(self);
473 };
474 let var = var_id.lookup_intern(self.get_db());
475 let id = self
476 .mapping
477 .impl_var_mapping
478 .get(&var.id)
479 .copied()
480 .ok_or(MapperError(InferenceVar::Impl(var.id)))?;
481 let var = ImplVar { id, inference_id: self.mapping.target_inference_id, ..var };
482
483 *value = ImplLongId::ImplVar(var.intern(self.get_db()));
484 Ok(RewriteResult::Modified)
485 }
486}