1use std::collections::{BTreeMap, HashMap, VecDeque};
4use std::hash::Hash;
5use std::mem;
6use std::ops::{Deref, DerefMut};
7use std::sync::Arc;
8
9use cairo_lang_debug::DebugWithDb;
10use cairo_lang_defs::ids::{
11 ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId,
12 GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LanguageElementId,
13 LocalVarId, LookupItemId, MemberId, NamedLanguageElementId, ParamId, StructId, TraitConstantId,
14 TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
15};
16use cairo_lang_diagnostics::{DiagnosticAdded, Maybe, skip_diagnostic};
17use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
18use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
19use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
20use cairo_lang_utils::{
21 Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches,
22};
23
24use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
25use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
26use crate::db::SemanticGroup;
27use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
28use crate::expr::inference::canonic::ResultNoErrEx;
29use crate::expr::inference::conform::InferenceConform;
30use crate::expr::objects::*;
31use crate::expr::pattern::*;
32use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
33use crate::items::functions::{
34 ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
35 GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
36 ImplGenericFunctionWithBodyId,
37};
38use crate::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType};
39use crate::items::imp::{
40 GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
41 ImplLookupContext, UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
42};
43use crate::items::trt::{
44 ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId, ConcreteTraitTypeId,
45 ConcreteTraitTypeLongId,
46};
47use crate::substitution::{HasDb, RewriteResult, SemanticRewriter};
48use crate::types::{
49 ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
50 ImplTypeById, ImplTypeId,
51};
52use crate::{
53 ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
54 ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
55 FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
56 Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
57 add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
58};
59
60pub mod canonic;
61pub mod conform;
62pub mod infers;
63pub mod solver;
64
/// A type variable: a placeholder for a type that still has to be inferred.
///
/// A variable is identified by the inference context that created it together with an id that is
/// local to that context.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct TypeVar {
    /// Id of the inference context that created this variable.
    pub inference_id: InferenceId,
    /// Local id of the variable; `Inference::new_type_var_raw` sets it to the variable's index in
    /// the context's `type_vars` list.
    pub id: LocalTypeVarId,
}
72
/// A const variable: a placeholder for a constant value that still has to be inferred.
///
/// A variable is identified by the inference context that created it together with an id that is
/// local to that context.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ConstVar {
    /// Id of the inference context that created this variable.
    pub inference_id: InferenceId,
    /// Local id of the variable; `Inference::new_const_var_raw` sets it to the variable's index in
    /// the context's `const_vars` list.
    pub id: LocalConstVarId,
}
80
/// Identifies an inference context.
///
/// Every inference variable stores the id of the context that created it. `Inference::assign_ty`,
/// `Inference::assign_const` and `Inference::assign_impl` refuse to assign a variable whose id
/// differs from the id of the context performing the assignment.
// NOTE(review): each variant presumably names the language element (or special situation) whose
// analysis owns the context - confirm against the call sites that construct these values.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
#[debug_db(dyn SemanticGroup + 'static)]
pub enum InferenceId {
    LookupItemDeclaration(LookupItemId),
    LookupItemGenerics(LookupItemId),
    LookupItemDefinition(LookupItemId),
    ImplDefTrait(ImplDefId),
    ImplAliasImplDef(ImplAliasId),
    GenericParam(GenericParamId),
    GenericImplParamTrait(GenericParamId),
    GlobalUseStar(GlobalUseId),
    Canonical,
    NoContext,
}
97
/// An impl variable: a placeholder for an impl of `concrete_trait_id` that still has to be
/// inferred.
#[derive(Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
#[debug_db(dyn SemanticGroup + 'static)]
pub struct ImplVar {
    /// Id of the inference context that created this variable.
    pub inference_id: InferenceId,
    /// Local id of the variable; `Inference::new_impl_var_raw` sets it to the variable's index in
    /// the context's `impl_vars` list.
    #[dont_rewrite]
    pub id: LocalImplVarId,
    /// The concrete trait the inferred impl has to implement.
    pub concrete_trait_id: ConcreteTraitId,
    /// The lookup context handed to the solver when searching for a matching impl.
    #[dont_rewrite]
    pub lookup_context: ImplLookupContext,
}
110impl ImplVar {
111 pub fn intern(&self, db: &dyn SemanticGroup) -> ImplVarId {
112 self.clone().intern(db)
113 }
114}
115
/// Id of a type variable, local to a single inference context (an index into its `type_vars`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalTypeVarId(pub usize);
/// Id of an impl variable, local to a single inference context (an index into its `impl_vars`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalImplVarId(pub usize);
120
/// Id of a const variable, local to a single inference context (an index into its `const_vars`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalConstVarId(pub usize);
123
// Declares `ImplVarId`, the interned short id of `ImplVar`, wired to the `intern_impl_var` /
// `lookup_intern_impl_var` names.
// NOTE(review): the exact items generated are determined by `define_short_id!` - confirm there.
define_short_id!(ImplVarId, ImplVar, SemanticGroup, lookup_intern_impl_var, intern_impl_var);
125impl ImplVarId {
126 pub fn id(&self, db: &dyn SemanticGroup) -> LocalImplVarId {
127 self.lookup_intern(db).id
128 }
129 pub fn concrete_trait_id(&self, db: &dyn SemanticGroup) -> ConcreteTraitId {
130 self.lookup_intern(db).concrete_trait_id
131 }
132 pub fn lookup_context(&self, db: &dyn SemanticGroup) -> ImplLookupContext {
133 self.lookup_intern(db).lookup_context
134 }
135}
// NOTE(review): presumably makes `ImplVarId` a `SemanticObject` by looking up the interned
// `ImplVar`, rewriting it and re-interning the result - confirm against the macro definition.
semantic_object_for_id!(ImplVarId, lookup_intern_impl_var, intern_impl_var, ImplVar);
137
/// A reference to any inference variable of a context, by kind and local id.
///
/// Used as the key of `InferenceData::stable_ptrs` and to name the offending variable in
/// `InferenceError::Cycle`.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject)]
pub enum InferenceVar {
    /// A type variable.
    Type(LocalTypeVarId),
    /// A const variable.
    Const(LocalConstVarId),
    /// An impl variable.
    Impl(LocalImplVarId),
}
144
/// An error raised during inference.
///
/// `InferenceError::format` renders each variant as a human readable message.
#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb)]
#[debug_db(dyn SemanticGroup + 'static)]
pub enum InferenceError {
    /// An error for which a diagnostic was already added; `report` returns the stored
    /// `DiagnosticAdded` instead of reporting again.
    Reported(DiagnosticAdded),
    /// Assigning the variable would make it contain itself.
    Cycle(InferenceVar),
    /// Two types could not be matched.
    TypeKindMismatch {
        ty0: TypeId,
        ty1: TypeId,
    },
    /// Two constants could not be matched.
    ConstKindMismatch {
        const0: ConstValueId,
        const1: ConstValueId,
    },
    /// Two impls could not be matched.
    ImplKindMismatch {
        impl0: ImplId,
        impl1: ImplId,
    },
    /// Two generic arguments could not be matched.
    GenericArgMismatch {
        garg0: GenericArgumentId,
        garg1: GenericArgumentId,
    },
    /// Two traits could not be matched.
    TraitMismatch {
        trt0: TraitId,
        trt1: TraitId,
    },
    /// The type `ty0` required for trait type `trait_type_id` does not match the type `ty1`
    /// that `impl_id` provides for it (set by `Inference::assign_local_impl`).
    ImplTypeMismatch {
        impl_id: ImplId,
        trait_type_id: TraitTypeId,
        ty0: TypeId,
        ty1: TypeId,
    },
    /// Two generic functions could not be matched.
    GenericFunctionMismatch {
        func0: GenericFunctionId,
        func1: GenericFunctionId,
    },
    /// Inference of const generics is not supported.
    ConstInferenceNotSupported,

    /// No impl of the concrete trait was found (the impl variable was refuted).
    NoImplsFound(ConcreteTraitId),
    /// The solver could not pick a single impl.
    Ambiguity(Ambiguity),
    /// A type variable was left without an assignment.
    TypeNotInferred(TypeId),
}
190impl InferenceError {
191 pub fn format(&self, db: &(dyn SemanticGroup + 'static)) -> String {
192 match self {
193 InferenceError::Reported(_) => "Inference error occurred.".into(),
194 InferenceError::Cycle(_var) => "Inference cycle detected".into(),
195 InferenceError::TypeKindMismatch { ty0, ty1 } => {
196 format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
197 }
198 InferenceError::ConstKindMismatch { const0, const1 } => {
199 format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
200 }
201 InferenceError::ImplKindMismatch { impl0, impl1 } => {
202 format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
203 }
204 InferenceError::GenericArgMismatch { garg0, garg1 } => {
205 format!(
206 "Generic arg mismatch: `{:?}` and `{:?}`.",
207 garg0.debug(db),
208 garg1.debug(db)
209 )
210 }
211 InferenceError::TraitMismatch { trt0, trt1 } => {
212 format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
213 }
214 InferenceError::ConstInferenceNotSupported => {
215 "Const generic inference not yet supported.".into()
216 }
217 InferenceError::NoImplsFound(concrete_trait_id) => {
218 let info = db.core_info();
219 let trait_id = concrete_trait_id.trait_id(db);
220 if trait_id == info.numeric_literal_trt {
221 let generic_type = extract_matches!(
222 concrete_trait_id.generic_args(db)[0],
223 GenericArgumentId::Type
224 );
225 return format!(
226 "Mismatched types. The type `{:?}` cannot be created from a numeric \
227 literal.",
228 generic_type.debug(db)
229 );
230 } else if trait_id == info.string_literal_trt {
231 let generic_type = extract_matches!(
232 concrete_trait_id.generic_args(db)[0],
233 GenericArgumentId::Type
234 );
235 return format!(
236 "Mismatched types. The type `{:?}` cannot be created from a string \
237 literal.",
238 generic_type.debug(db)
239 );
240 }
241 format!(
242 "Trait has no implementation in context: {:?}.",
243 concrete_trait_id.debug(db)
244 )
245 }
246 InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
247 InferenceError::TypeNotInferred(ty) => {
248 format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
249 }
250 InferenceError::GenericFunctionMismatch { func0, func1 } => {
251 format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
252 }
253 InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 } => {
254 format!(
255 "`{}::{}` type mismatch: `{:?}` and `{:?}`.",
256 impl_id.format(db.upcast()),
257 trait_type_id.name(db.upcast()),
258 ty0.debug(db),
259 ty1.debug(db)
260 )
261 }
262 }
263 }
264}
265
266impl InferenceError {
267 pub fn report(
268 &self,
269 diagnostics: &mut SemanticDiagnostics,
270 stable_ptr: SyntaxStablePtrId,
271 ) -> DiagnosticAdded {
272 match self {
273 InferenceError::Reported(diagnostic_added) => *diagnostic_added,
274 _ => diagnostics
275 .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
276 }
277 }
278}
279
/// A marker stating that an inference error occurred.
///
/// It carries no data: `Inference::set_error` records the actual error on the inference state
/// and hands back this token, which the `consume_*` / `report_on_pending_error` methods of
/// `Inference` take as an argument.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ErrorSet;
286
/// Result of an inference operation; the error side is only the `ErrorSet` marker.
pub type InferenceResult<T> = Result<T, ErrorSet>;
288
/// State of the error recorded on an inference context (the `Err` side of
/// `InferenceData::error_status`).
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum InferenceErrorStatus {
    /// An error was set and is still stored, waiting to be consumed or reported.
    Pending,
    /// The error was already consumed; only the matching `DiagnosticAdded` is kept.
    Consumed,
}
294
/// Trait items that are already known for an impl variable.
///
/// When the variable is assigned an impl, `Inference::assign_local_impl` conforms every entry
/// with the corresponding item of that impl.
#[derive(Debug, Default, PartialEq, Eq, Clone, SemanticObject)]
pub struct ImplVarTraitItemMappings {
    /// Known types, keyed by trait type.
    types: OrderedHashMap<TraitTypeId, TypeId>,
    /// Known constants, keyed by trait constant.
    constants: OrderedHashMap<TraitConstantId, ConstValueId>,
    /// Known inner impls, keyed by trait impl.
    impls: OrderedHashMap<TraitImplId, ImplId>,
}
305impl Hash for ImplVarTraitItemMappings {
306 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
307 self.types.iter().for_each(|(trait_type_id, type_id)| {
308 trait_type_id.hash(state);
309 type_id.hash(state);
310 });
311 self.constants.iter().for_each(|(trait_const_id, const_id)| {
312 trait_const_id.hash(state);
313 const_id.hash(state);
314 });
315 self.impls.iter().for_each(|(trait_impl_id, impl_id)| {
316 trait_impl_id.hash(state);
317 impl_id.hash(state);
318 });
319 }
320}
321
/// The complete state of one inference context: its variables, their assignments, the impl
/// solving bookkeeping and the error state.
#[derive(Debug, DebugWithDb, PartialEq, Eq)]
#[debug_db(dyn SemanticGroup + 'static)]
pub struct InferenceData {
    /// Id of this inference context; stamped on every variable it creates.
    pub inference_id: InferenceId,
    /// Current assignment of type variables.
    pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId>,
    /// Current assignment of const variables.
    pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId>,
    /// Current assignment of impl variables.
    pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId>,
    /// Trait items already known for unassigned impl variables. The entry of a variable is
    /// removed, and conformed with the chosen impl, when the variable is assigned.
    pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings>,
    /// All type variables created so far; a variable's local id is its index here.
    pub type_vars: Vec<TypeVar>,
    /// All const variables created so far; a variable's local id is its index here.
    pub const_vars: Vec<ConstVar>,
    /// All impl variables created so far; a variable's local id is its index here.
    pub impl_vars: Vec<ImplVar>,
    /// Syntax location recorded for a variable at creation (optional); used to place errors.
    pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId>,
    /// Impl variables queued for solving.
    pending: VecDeque<LocalImplVarId>,
    /// Impl variables whose solution set was empty.
    refuted: Vec<LocalImplVarId>,
    /// Impl variables that were solved and assigned.
    solved: Vec<LocalImplVarId>,
    /// Impl variables whose solution set was ambiguous, with the reason. They are moved back to
    /// `pending` whenever solving (re)starts or another variable gets solved.
    ambiguous: Vec<(LocalImplVarId, Ambiguity)>,
    /// Type bounds for impl types; set by `Inference::set_impl_type_bounds` and passed to the
    /// trait solver.
    pub impl_type_bounds: Arc<BTreeMap<ImplTypeById, TypeId>>,

    /// `Ok(())` while no error was set; otherwise tells whether the error is still pending or
    /// was already consumed.
    pub error_status: Result<(), InferenceErrorStatus>,
    /// The pending error, if any (status `Pending`).
    error: Option<InferenceError>,
    /// The diagnostic of the consumed error, if any (status `Consumed`).
    consumed_error: Option<DiagnosticAdded>,
}
363impl InferenceData {
364 pub fn new(inference_id: InferenceId) -> Self {
365 Self {
366 inference_id,
367 type_assignment: OrderedHashMap::default(),
368 impl_assignment: OrderedHashMap::default(),
369 const_assignment: OrderedHashMap::default(),
370 impl_vars_trait_item_mappings: HashMap::new(),
371 type_vars: Vec::new(),
372 impl_vars: Vec::new(),
373 const_vars: Vec::new(),
374 stable_ptrs: HashMap::new(),
375 pending: VecDeque::new(),
376 refuted: Vec::new(),
377 solved: Vec::new(),
378 ambiguous: Vec::new(),
379 impl_type_bounds: Default::default(),
380 error_status: Ok(()),
381 error: None,
382 consumed_error: None,
383 }
384 }
385 pub fn inference<'db, 'b: 'db>(&'db mut self, db: &'b dyn SemanticGroup) -> Inference<'db> {
386 Inference::new(db, self)
387 }
388 pub fn clone_with_inference_id(
389 &self,
390 db: &dyn SemanticGroup,
391 inference_id: InferenceId,
392 ) -> InferenceData {
393 let mut inference_id_replacer =
394 InferenceIdReplacer::new(db, self.inference_id, inference_id);
395 Self {
396 inference_id,
397 type_assignment: self
398 .type_assignment
399 .iter()
400 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
401 .collect(),
402 const_assignment: self
403 .const_assignment
404 .iter()
405 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
406 .collect(),
407 impl_assignment: self
408 .impl_assignment
409 .iter()
410 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
411 .collect(),
412 impl_vars_trait_item_mappings: self
413 .impl_vars_trait_item_mappings
414 .iter()
415 .map(|(k, mappings)| {
416 (
417 *k,
418 ImplVarTraitItemMappings {
419 types: mappings
420 .types
421 .iter()
422 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
423 .collect(),
424 constants: mappings
425 .constants
426 .iter()
427 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
428 .collect(),
429 impls: mappings
430 .impls
431 .iter()
432 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
433 .collect(),
434 },
435 )
436 })
437 .collect(),
438 type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
439 const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
440 impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
441 stable_ptrs: self.stable_ptrs.clone(),
442 pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
443 refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
444 solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
445 ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
446 impl_type_bounds: self.impl_type_bounds.clone(),
448
449 error_status: self.error_status,
450 error: self.error.clone(),
451 consumed_error: self.consumed_error,
452 }
453 }
454 pub fn temporary_clone(&self) -> InferenceData {
455 Self {
456 inference_id: self.inference_id,
457 type_assignment: self.type_assignment.clone(),
458 const_assignment: self.const_assignment.clone(),
459 impl_assignment: self.impl_assignment.clone(),
460 impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
461 type_vars: self.type_vars.clone(),
462 const_vars: self.const_vars.clone(),
463 impl_vars: self.impl_vars.clone(),
464 stable_ptrs: self.stable_ptrs.clone(),
465 pending: self.pending.clone(),
466 refuted: self.refuted.clone(),
467 solved: self.solved.clone(),
468 ambiguous: self.ambiguous.clone(),
469 impl_type_bounds: self.impl_type_bounds.clone(),
470 error_status: self.error_status,
471 error: self.error.clone(),
472 consumed_error: self.consumed_error,
473 }
474 }
475}
476
/// An inference session: a database handle plus a mutable borrow of the [`InferenceData`] it
/// works on. Dereferences to the data, so the data's fields can be accessed directly.
pub struct Inference<'db> {
    /// The semantic database used for interning and queries.
    db: &'db dyn SemanticGroup,
    /// The state this session reads and updates.
    pub data: &'db mut InferenceData,
}
482
483impl Deref for Inference<'_> {
484 type Target = InferenceData;
485
486 fn deref(&self) -> &Self::Target {
487 self.data
488 }
489}
490impl DerefMut for Inference<'_> {
491 fn deref_mut(&mut self) -> &mut Self::Target {
492 self.data
493 }
494}
495
496impl std::fmt::Debug for Inference<'_> {
497 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
498 let x = self.data.debug(self.db.elongate());
499 write!(f, "{x:?}")
500 }
501}
502
503impl<'db> Inference<'db> {
504 fn new(db: &'db dyn SemanticGroup, data: &'db mut InferenceData) -> Self {
505 Self { db, data }
506 }
507
508 fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar {
510 &self.impl_vars[var_id.0]
511 }
512
513 pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId> {
515 self.impl_assignment.get(&var_id).copied()
516 }
517
518 fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId> {
520 self.type_assignment.get(&var_id).copied()
521 }
522
523 pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeId {
526 let var = self.new_type_var_raw(stable_ptr);
527
528 TypeLongId::Var(var).intern(self.db)
529 }
530
531 pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeVar {
534 let var =
535 TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
536 if let Some(stable_ptr) = stable_ptr {
537 self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
538 }
539 self.type_vars.push(var);
540 var
541 }
542
543 pub fn set_impl_type_bounds(&mut self, impl_type_bounds: OrderedHashMap<ImplTypeId, TypeId>) {
546 let impl_type_bounds_finalized = impl_type_bounds
547 .iter()
548 .filter_map(|(impl_type, ty)| {
549 let rewritten_type = self.rewrite(ty.lookup_intern(self.db)).no_err();
550 if !matches!(rewritten_type, TypeLongId::Var(_)) {
551 return Some(((*impl_type).into(), rewritten_type.intern(self.db)));
552 }
553 self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
556 None
557 })
558 .collect();
559
560 self.data.impl_type_bounds = Arc::new(impl_type_bounds_finalized);
561 }
562
563 pub fn new_const_var(
566 &mut self,
567 stable_ptr: Option<SyntaxStablePtrId>,
568 ty: TypeId,
569 ) -> ConstValueId {
570 let var = self.new_const_var_raw(stable_ptr);
571 ConstValue::Var(var, ty).intern(self.db)
572 }
573
574 pub fn new_const_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> ConstVar {
577 let var = ConstVar {
578 inference_id: self.inference_id,
579 id: LocalConstVarId(self.const_vars.len()),
580 };
581 if let Some(stable_ptr) = stable_ptr {
582 self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
583 }
584 self.const_vars.push(var);
585 var
586 }
587
588 pub fn new_impl_var(
591 &mut self,
592 concrete_trait_id: ConcreteTraitId,
593 stable_ptr: Option<SyntaxStablePtrId>,
594 lookup_context: ImplLookupContext,
595 ) -> ImplId {
596 let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
597 ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
598 }
599
600 fn new_impl_var_raw(
603 &mut self,
604 lookup_context: ImplLookupContext,
605 concrete_trait_id: ConcreteTraitId,
606 stable_ptr: Option<SyntaxStablePtrId>,
607 ) -> LocalImplVarId {
608 let mut lookup_context = lookup_context;
609 lookup_context
610 .insert_module(concrete_trait_id.trait_id(self.db).module_file_id(self.db.upcast()).0);
611
612 let id = LocalImplVarId(self.impl_vars.len());
613 if let Some(stable_ptr) = stable_ptr {
614 self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
615 }
616 let var =
617 ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
618 self.impl_vars.push(var);
619 self.pending.push_back(id);
620 id
621 }
622
623 pub fn solve(&mut self) -> InferenceResult<()> {
628 self.solve_ex().map_err(|(err_set, _)| err_set)
629 }
630
631 fn solve_ex(&mut self) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
633 let mut ambiguous = std::mem::take(&mut self.ambiguous);
634 self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
635 while let Some(var) = self.pending.pop_front() {
636 self.solve_single_pending(var).map_err(|err_set| {
638 (err_set, self.stable_ptrs.get(&InferenceVar::Impl(var)).copied())
639 })?;
640 }
641 Ok(())
642 }
643
644 fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
645 if self.impl_assignment.contains_key(&var) {
646 return Ok(());
647 }
648 let solution = match self.impl_var_solution_set(var)? {
649 SolutionSet::None => {
650 self.refuted.push(var);
651 return Ok(());
652 }
653 SolutionSet::Ambiguous(ambiguity) => {
654 self.ambiguous.push((var, ambiguity));
655 return Ok(());
656 }
657 SolutionSet::Unique(solution) => solution,
658 };
659
660 self.assign_local_impl(var, solution)?;
662
663 self.solved.push(var);
665 let mut ambiguous = std::mem::take(&mut self.ambiguous);
666 self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
667
668 Ok(())
669 }
670
671 pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<()>> {
674 self.solve()?;
675 if !self.refuted.is_empty() {
676 return Ok(SolutionSet::None);
677 }
678 if let Some((_, ambiguity)) = self.ambiguous.first() {
679 return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
680 }
681 assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
682 Ok(SolutionSet::Unique(()))
683 }
684
685 pub fn finalize_without_reporting(
688 &mut self,
689 ) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
690 if self.error_status.is_err() {
691 return Err((ErrorSet, None));
693 }
694 let info = self.db.core_info();
695 let numeric_trait_id = info.numeric_literal_trt;
696 let felt_ty = info.felt252;
697
698 loop {
700 let mut changed = false;
701 self.solve_ex()?;
702 for (var, _) in self.ambiguous.clone() {
703 let impl_var = self.impl_var(var).clone();
704 if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
705 continue;
706 }
707 let ty = extract_matches!(
709 impl_var.concrete_trait_id.generic_args(self.db)[0],
710 GenericArgumentId::Type
711 );
712 if self.rewrite(ty).no_err() == felt_ty {
713 continue;
714 }
715 self.conform_ty(ty, felt_ty).map_err(|err_set| {
716 (err_set, self.stable_ptrs.get(&InferenceVar::Impl(impl_var.id)).copied())
717 })?;
718 changed = true;
719 break;
720 }
721 if !changed {
722 break;
723 }
724 }
725 assert!(
726 self.pending.is_empty(),
727 "pending should all be solved by this point. Guaranteed by solve()."
728 );
729
730 let Some((var, err)) = self.first_undetermined_variable() else {
731 return Ok(());
732 };
733 Err((self.set_error(err), self.stable_ptrs.get(&var).copied()))
734 }
735
736 pub fn finalize(
740 &mut self,
741 diagnostics: &mut SemanticDiagnostics,
742 stable_ptr: SyntaxStablePtrId,
743 ) {
744 if let Err((err_set, err_stable_ptr)) = self.finalize_without_reporting() {
745 let diag = self.report_on_pending_error(
746 err_set,
747 diagnostics,
748 err_stable_ptr.unwrap_or(stable_ptr),
749 );
750
751 let ty_missing = TypeId::missing(self.db, diag);
752 for var in &self.data.type_vars {
753 self.data.type_assignment.entry(var.id).or_insert(ty_missing);
754 }
755 }
756 }
757
758 fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError)> {
762 if let Some(var) = self.refuted.first().copied() {
763 let impl_var = self.impl_var(var).clone();
764 let concrete_trait_id = impl_var.concrete_trait_id;
765 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
766 return Some((
767 InferenceVar::Impl(var),
768 InferenceError::NoImplsFound(concrete_trait_id),
769 ));
770 }
771 let mut fallback_ret = None;
772 if let Some((var, ambiguity)) = self.ambiguous.first() {
773 let ret =
775 Some((InferenceVar::Impl(*var), InferenceError::Ambiguity(ambiguity.clone())));
776 if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
777 return ret;
778 } else {
779 fallback_ret = ret;
780 }
781 }
782 for (id, var) in self.type_vars.iter().enumerate() {
783 if self.type_assignment(LocalTypeVarId(id)).is_none() {
784 let ty = TypeLongId::Var(*var).intern(self.db);
785 return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
786 }
787 }
788 fallback_ret
789 }
790
791 fn assign_local_impl(
793 &mut self,
794 var: LocalImplVarId,
795 impl_id: ImplId,
796 ) -> InferenceResult<ImplId> {
797 let concrete_trait = impl_id
798 .concrete_trait(self.db)
799 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
800 self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
801 if let Some(other_impl) = self.impl_assignment(var) {
802 return self.conform_impl(impl_id, other_impl);
803 }
804 if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
805 {
806 return Err(self.set_error(InferenceError::Cycle(InferenceVar::Impl(var))));
807 }
808 self.impl_assignment.insert(var, impl_id);
809 if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
810 for (trait_type_id, ty) in mappings.types {
811 let impl_ty = self
812 .db
813 .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_type_id, self.db))
814 .map_err(|_| ErrorSet)?;
815 if let Err(err_set) = self.conform_ty(ty, impl_ty) {
816 let ty0 = self.rewrite(ty).no_err();
818 let ty1 = self.rewrite(impl_ty).no_err();
819
820 self.error =
821 Some(InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 });
822 return Err(err_set);
823 }
824 }
825 for (trait_constant, constant_id) in mappings.constants {
826 self.conform_const(
827 constant_id,
828 self.db
829 .impl_constant_concrete_implized_value(ImplConstantId::new(
830 impl_id,
831 trait_constant,
832 self.db,
833 ))
834 .map_err(|_| ErrorSet)?,
835 )?;
836 }
837 for (trait_impl, inner_impl_id) in mappings.impls {
838 self.conform_impl(
839 inner_impl_id,
840 self.db
841 .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
842 .map_err(|_| ErrorSet)?,
843 )?;
844 }
845 }
846 Ok(impl_id)
847 }
848
849 fn assign_impl(&mut self, var_id: ImplVarId, impl_id: ImplId) -> InferenceResult<ImplId> {
851 let var = var_id.lookup_intern(self.db);
852 if var.inference_id != self.inference_id {
853 return Err(self.set_error(InferenceError::ImplKindMismatch {
854 impl0: ImplLongId::ImplVar(var_id).intern(self.db),
855 impl1: impl_id,
856 }));
857 }
858 self.assign_local_impl(var.id, impl_id)
859 }
860
861 fn assign_ty(&mut self, var: TypeVar, ty: TypeId) -> InferenceResult<TypeId> {
864 if var.inference_id != self.inference_id {
865 return Err(self.set_error(InferenceError::TypeKindMismatch {
866 ty0: TypeLongId::Var(var).intern(self.db),
867 ty1: ty,
868 }));
869 }
870 assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
871 let inference_var = InferenceVar::Type(var.id);
872 if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
873 return Err(self.set_error(InferenceError::Cycle(inference_var)));
874 }
875 if let TypeLongId::Var(other) = ty.lookup_intern(self.db) {
877 if other.inference_id == self.inference_id && other.id.0 > var.id.0 {
878 let var_ty = TypeLongId::Var(var).intern(self.db);
879 self.type_assignment.insert(other.id, var_ty);
880 return Ok(var_ty);
881 }
882 }
883 self.type_assignment.insert(var.id, ty);
884 Ok(ty)
885 }
886
887 fn assign_const(&mut self, var: ConstVar, id: ConstValueId) -> InferenceResult<ConstValueId> {
890 if var.inference_id != self.inference_id {
891 return Err(self.set_error(InferenceError::ConstKindMismatch {
892 const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
893 .intern(self.db),
894 const1: id,
895 }));
896 }
897
898 self.const_assignment.insert(var.id, id);
899 Ok(id)
900 }
901
902 fn impl_var_solution_set(
904 &mut self,
905 var: LocalImplVarId,
906 ) -> InferenceResult<SolutionSet<ImplId>> {
907 let impl_var = self.impl_var(var).clone();
908 let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
910 self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
911 let impl_var_trait_item_mappings =
912 self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
913 let solution_set = self.trait_solution_set(
914 concrete_trait_id,
915 impl_var_trait_item_mappings,
916 impl_var.lookup_context,
917 )?;
918 Ok(match solution_set {
919 SolutionSet::None => SolutionSet::None,
920 SolutionSet::Unique((canonical_impl, canonicalizer)) => {
921 SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
922 }
923 SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
924 })
925 }
926
927 pub fn trait_solution_set(
929 &mut self,
930 concrete_trait_id: ConcreteTraitId,
931 impl_var_trait_item_mappings: ImplVarTraitItemMappings,
932 mut lookup_context: ImplLookupContext,
933 ) -> InferenceResult<SolutionSet<(CanonicalImpl, CanonicalMapping)>> {
934 let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
935 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
937 enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
938
939 let generic_args = concrete_trait_id.generic_args(self.db);
941 match generic_args.first() {
942 Some(GenericArgumentId::Type(ty)) => {
943 if let TypeLongId::Var(_) = ty.lookup_intern(self.db) {
944 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
946 }
947 }
948 Some(GenericArgumentId::Impl(imp)) => {
949 if let ImplLongId::ImplVar(_) = imp.lookup_intern(self.db) {
951 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
952 }
953 }
954 Some(GenericArgumentId::Constant(const_value)) => {
955 if let ConstValue::Var(_, _) = const_value.lookup_intern(self.db) {
956 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
958 }
959 }
960 _ => {}
961 };
962 let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
963 self.db,
964 self.inference_id,
965 concrete_trait_id,
966 impl_var_trait_item_mappings,
967 );
968 let solution_set = match self.db.canonic_trait_solutions(
971 canonical_trait,
972 lookup_context,
973 (*self.data.impl_type_bounds).clone(),
974 ) {
975 Ok(solution_set) => solution_set,
976 Err(err) => return Err(self.set_error(err)),
977 };
978 match solution_set {
979 SolutionSet::None => Ok(SolutionSet::None),
980 SolutionSet::Unique(canonical_impl) => {
981 Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
982 }
983 SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
984 }
985 }
986
    /// Validates the negative impl generic parameters of `canonical_impl`.
    ///
    /// For a concrete impl, the `NegImpl` generic params of its impl definition are taken after
    /// applying the impl's substitution; for a generated impl, the `NegImpl` params stored on
    /// it are used as they are. Every other kind of impl is accepted without checks.
    ///
    /// Returns `Unique(canonical_impl)` when all negative constraints hold, `None` when one of
    /// them is violated, and `Ambiguous` when that cannot be decided yet.
    fn validate_neg_impls(
        &mut self,
        lookup_context: &ImplLookupContext,
        canonical_impl: CanonicalImpl,
    ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
        /// Checks that none of the given (negative) concrete traits has a solution.
        fn validate_no_solution_set(
            inference: &mut Inference<'_>,
            canonical_impl: CanonicalImpl,
            lookup_context: &ImplLookupContext,
            negative_impls_concrete_traits: impl Iterator<Item = Maybe<ConcreteTraitId>>,
        ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
            for concrete_trait_id in negative_impls_concrete_traits {
                let concrete_trait_id = concrete_trait_id.map_err(|diag_added| {
                    inference.set_error(InferenceError::Reported(diag_added))
                })?;
                for garg in concrete_trait_id.generic_args(inference.db) {
                    // Only type arguments are inspected.
                    let GenericArgumentId::Type(ty) = garg else {
                        continue;
                    };
                    let ty = inference.rewrite(ty).no_err();
                    // A type argument that is not fully concrete (closure types are exempt)
                    // makes the negative constraint undecidable for now.
                    if !matches!(ty.lookup_intern(inference.db), TypeLongId::Closure(_))
                        && !ty.is_fully_concrete(inference.db)
                    {
                        return Ok(SolutionSet::Ambiguous(
                            Ambiguity::NegativeImplWithUnresolvedGenericArgs {
                                impl_id: canonical_impl.0,
                                ty,
                            },
                        ));
                    }
                }

                // Anything but an empty solution set (a unique or an ambiguous solution) for
                // the negative trait rules this impl out.
                if !matches!(
                    inference.trait_solution_set(
                        concrete_trait_id,
                        ImplVarTraitItemMappings::default(),
                        lookup_context.clone()
                    )?,
                    SolutionSet::None
                ) {
                    return Ok(SolutionSet::None);
                }
            }

            Ok(SolutionSet::Unique(canonical_impl))
        }
        match canonical_impl.0.lookup_intern(self.db) {
            ImplLongId::Concrete(concrete_impl) => {
                let substitution = concrete_impl
                    .substitution(self.db)
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                let generic_params = self
                    .db
                    .impl_def_generic_params(concrete_impl.impl_def_id(self.db))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                // Lazily substitute each negative impl param and take its concrete trait.
                let concrete_traits = generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| {
                        substitution
                            .substitute(self.db, generic_param.clone())
                            .and_then(|generic_param| generic_param.concrete_trait)
                    });
                validate_no_solution_set(self, canonical_impl, lookup_context, concrete_traits)
            }
            ImplLongId::GeneratedImpl(generated_impl) => validate_no_solution_set(
                self,
                canonical_impl,
                lookup_context,
                generated_impl
                    .lookup_intern(self.db)
                    .generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| generic_param.concrete_trait),
            ),
            ImplLongId::GenericParameter(_)
            | ImplLongId::ImplVar(_)
            | ImplLongId::ImplImpl(_)
            | ImplLongId::SelfImpl(_) => Ok(SolutionSet::Unique(canonical_impl)),
        }
    }
1088
1089 pub fn set_error(&mut self, err: InferenceError) -> ErrorSet {
1096 if self.error_status.is_err() {
1097 return ErrorSet;
1098 }
1099 self.error_status = if let InferenceError::Reported(diag_added) = err {
1100 self.consumed_error = Some(diag_added);
1101 Err(InferenceErrorStatus::Consumed)
1102 } else {
1103 self.error = Some(err);
1104 Err(InferenceErrorStatus::Pending)
1105 };
1106 ErrorSet
1107 }
1108
1109 pub fn is_error_set(&self) -> InferenceResult<()> {
1111 if self.error_status.is_err() { Err(ErrorSet) } else { Ok(()) }
1112 }
1113
1114 pub fn consume_error_without_reporting(&mut self, err_set: ErrorSet) -> Option<InferenceError> {
1120 self.consume_error_inner(err_set, skip_diagnostic())
1121 }
1122
1123 pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1130 self.consume_error_inner(err_set, diag_added);
1131 }
1132
1133 fn consume_error_inner(
1140 &mut self,
1141 _err_set: ErrorSet,
1142 diag_added: DiagnosticAdded,
1143 ) -> Option<InferenceError> {
1144 if self.error_status != Err(InferenceErrorStatus::Pending) {
1145 return None;
1146 }
1148 self.error_status = Err(InferenceErrorStatus::Consumed);
1149 self.consumed_error = Some(diag_added);
1150 mem::take(&mut self.error)
1151 }
1152
1153 pub fn report_on_pending_error(
1159 &mut self,
1160 _err_set: ErrorSet,
1161 diagnostics: &mut SemanticDiagnostics,
1162 stable_ptr: SyntaxStablePtrId,
1163 ) -> DiagnosticAdded {
1164 let Err(state_error) = self.error_status else {
1165 panic!("report_on_pending_error should be called only on error");
1166 };
1167 match state_error {
1168 InferenceErrorStatus::Consumed => self
1169 .consumed_error
1170 .expect("consumed_error is not set although error_status is Err(Consumed)"),
1171 InferenceErrorStatus::Pending => {
1172 let diag_added = match mem::take(&mut self.error)
1173 .expect("error is not set although error_status is Err(Pending)")
1174 {
1175 InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1176 skip_diagnostic()
1181 }
1182 diag => diag.report(diagnostics, stable_ptr),
1183 };
1184
1185 self.error_status = Err(InferenceErrorStatus::Consumed);
1186 self.consumed_error = Some(diag_added);
1187 diag_added
1188 }
1189 }
1190 }
1191
1192 pub fn report_modified_if_pending(
1195 &mut self,
1196 err_set: ErrorSet,
1197 report: impl FnOnce() -> DiagnosticAdded,
1198 ) {
1199 if self.error_status == Err(InferenceErrorStatus::Pending) {
1200 self.consume_reported_error(err_set, report());
1201 }
1202 }
1203}
1204
1205impl<'a> HasDb<&'a dyn SemanticGroup> for Inference<'a> {
1206 fn get_db(&self) -> &'a dyn SemanticGroup {
1207 self.db
1208 }
1209}
// Rewriter wiring for `Inference`, using `NoError` as the error type.
// NOTE(review): presumably these macros generate the `SemanticRewriter` impls for the
// semantic types; the types listed after `@exclude` are the ones that get hand-written
// `SemanticRewriter` impls right below (TypeId, ImplId, TypeLongId, ConstValue, ImplLongId).
add_basic_rewrites!(<'a>, Inference<'a>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue);
add_expr_rewrites!(<'a>, Inference<'a>, NoError, @exclude);
add_rewrite!(<'a>, Inference<'a>, NoError, Ambiguity);
1213impl SemanticRewriter<TypeId, NoError> for Inference<'_> {
1214 fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
1215 if value.is_var_free(self.db) {
1216 return Ok(RewriteResult::NoChange);
1217 }
1218 value.default_rewrite(self)
1219 }
1220}
1221impl SemanticRewriter<ImplId, NoError> for Inference<'_> {
1222 fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
1223 if value.is_var_free(self.db) {
1224 return Ok(RewriteResult::NoChange);
1225 }
1226 value.default_rewrite(self)
1227 }
1228}
1229impl SemanticRewriter<TypeLongId, NoError> for Inference<'_> {
1230 fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
1231 match value {
1232 TypeLongId::Var(var) => {
1233 if let Some(type_id) = self.type_assignment.get(&var.id) {
1234 let mut long_type_id = type_id.lookup_intern(self.db);
1235 if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
1236 *self.type_assignment.get_mut(&var.id).unwrap() =
1237 long_type_id.clone().intern(self.db);
1238 }
1239 *value = long_type_id;
1240 return Ok(RewriteResult::Modified);
1241 }
1242 }
1243 TypeLongId::ImplType(impl_type_id) => {
1244 if let Some(type_id) = self.impl_type_bounds.get(&((*impl_type_id).into())) {
1245 *value = type_id.lookup_intern(self.db);
1246 self.internal_rewrite(value)?;
1247 return Ok(RewriteResult::Modified);
1248 }
1249 let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
1250 let impl_id = impl_type_id.impl_id();
1251 let trait_ty = impl_type_id.ty();
1252 return Ok(match impl_id.lookup_intern(self.db) {
1253 ImplLongId::GenericParameter(_)
1254 | ImplLongId::SelfImpl(_)
1255 | ImplLongId::ImplImpl(_) => impl_type_id_rewrite_result,
1256 ImplLongId::Concrete(_) => {
1257 if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
1258 impl_id, trait_ty, self.db,
1259 )) {
1260 *value = self.rewrite(ty).no_err().lookup_intern(self.db);
1261 RewriteResult::Modified
1262 } else {
1263 impl_type_id_rewrite_result
1264 }
1265 }
1266 ImplLongId::ImplVar(var) => {
1267 *value = self.rewritten_impl_type(var, trait_ty).lookup_intern(self.db);
1268 return Ok(RewriteResult::Modified);
1269 }
1270 ImplLongId::GeneratedImpl(generated) => {
1271 *value = self
1272 .rewrite(
1273 *generated
1274 .lookup_intern(self.db)
1275 .impl_items
1276 .0
1277 .get(&impl_type_id.ty())
1278 .unwrap(),
1279 )
1280 .no_err()
1281 .lookup_intern(self.db);
1282 RewriteResult::Modified
1283 }
1284 });
1285 }
1286 _ => {}
1287 }
1288 value.default_rewrite(self)
1289 }
1290}
1291impl SemanticRewriter<ConstValue, NoError> for Inference<'_> {
1292 fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
1293 match value {
1294 ConstValue::Var(var, _) => {
1295 return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
1296 let mut const_value = const_value_id.lookup_intern(self.db);
1297 if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
1298 *self.const_assignment.get_mut(&var.id).unwrap() =
1299 const_value.clone().intern(self.db);
1300 }
1301 *value = const_value;
1302 RewriteResult::Modified
1303 } else {
1304 RewriteResult::NoChange
1305 });
1306 }
1307 ConstValue::ImplConstant(impl_constant_id) => {
1308 let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
1309 let impl_id = impl_constant_id.impl_id();
1310 let trait_constant = impl_constant_id.trait_constant_id();
1311 return Ok(match impl_id.lookup_intern(self.db) {
1312 ImplLongId::GenericParameter(_)
1313 | ImplLongId::SelfImpl(_)
1314 | ImplLongId::GeneratedImpl(_)
1315 | ImplLongId::ImplImpl(_) => impl_constant_id_rewrite_result,
1316 ImplLongId::Concrete(_) => {
1317 if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
1318 ImplConstantId::new(impl_id, trait_constant, self.db),
1319 ) {
1320 *value = self.rewrite(constant).no_err().lookup_intern(self.db);
1321 RewriteResult::Modified
1322 } else {
1323 impl_constant_id_rewrite_result
1324 }
1325 }
1326 ImplLongId::ImplVar(var) => {
1327 *value = self
1328 .rewritten_impl_constant(var, trait_constant)
1329 .lookup_intern(self.db);
1330 return Ok(RewriteResult::Modified);
1331 }
1332 });
1333 }
1334 _ => {}
1335 }
1336 value.default_rewrite(self)
1337 }
1338}
1339impl SemanticRewriter<ImplLongId, NoError> for Inference<'_> {
1340 fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
1341 match value {
1342 ImplLongId::ImplVar(var) => {
1343 let long_id = var.lookup_intern(self.db);
1344 let impl_var_id = long_id.id;
1346 if let Some(impl_id) = self.impl_assignment(impl_var_id) {
1347 let mut long_impl_id = impl_id.lookup_intern(self.db);
1348 if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
1349 *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
1350 long_impl_id.clone().intern(self.db);
1351 }
1352 *value = long_impl_id;
1353 return Ok(RewriteResult::Modified);
1354 }
1355 }
1356 ImplLongId::ImplImpl(impl_impl_id) => {
1357 let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
1358 let impl_id = impl_impl_id.impl_id();
1359 return Ok(match impl_id.lookup_intern(self.db) {
1360 ImplLongId::GenericParameter(_)
1361 | ImplLongId::SelfImpl(_)
1362 | ImplLongId::GeneratedImpl(_)
1363 | ImplLongId::ImplImpl(_) => impl_impl_id_rewrite_result,
1364 ImplLongId::Concrete(_) => {
1365 if let Ok(imp) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
1366 *value = self.rewrite(imp).no_err().lookup_intern(self.db);
1367 RewriteResult::Modified
1368 } else {
1369 impl_impl_id_rewrite_result
1370 }
1371 }
1372 ImplLongId::ImplVar(var) => {
1373 if let Ok(concrete_trait_impl) =
1374 impl_impl_id.concrete_trait_impl_id(self.db)
1375 {
1376 *value = self
1377 .rewritten_impl_impl(var, concrete_trait_impl)
1378 .lookup_intern(self.db);
1379 return Ok(RewriteResult::Modified);
1380 } else {
1381 impl_impl_id_rewrite_result
1382 }
1383 }
1384 });
1385 }
1386
1387 _ => {}
1388 }
1389 if value.is_var_free(self.db) {
1390 return Ok(RewriteResult::NoChange);
1391 }
1392 value.default_rewrite(self)
1393 }
1394}
1395
/// A rewriter that replaces every occurrence of one `InferenceId` with another
/// (see its `SemanticRewriter<InferenceId, NoError>` impl).
struct InferenceIdReplacer<'a> {
    // Database handle, exposed through `HasDb`.
    db: &'a dyn SemanticGroup,
    // The inference id to look for.
    from_inference_id: InferenceId,
    // The inference id written in place of `from_inference_id`.
    to_inference_id: InferenceId,
}
1401impl<'a> InferenceIdReplacer<'a> {
1402 fn new(
1403 db: &'a dyn SemanticGroup,
1404 from_inference_id: InferenceId,
1405 to_inference_id: InferenceId,
1406 ) -> Self {
1407 Self { db, from_inference_id, to_inference_id }
1408 }
1409}
1410impl<'a> HasDb<&'a dyn SemanticGroup> for InferenceIdReplacer<'a> {
1411 fn get_db(&self) -> &'a dyn SemanticGroup {
1412 self.db
1413 }
1414}
// Rewriter wiring for `InferenceIdReplacer`, using `NoError` as the error type.
// NOTE(review): presumably these macros generate the `SemanticRewriter` impls for the
// semantic types; `InferenceId` is listed after `@exclude` and gets the hand-written impl
// right below.
add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity);
1418impl SemanticRewriter<InferenceId, NoError> for InferenceIdReplacer<'_> {
1419 fn internal_rewrite(&mut self, value: &mut InferenceId) -> Result<RewriteResult, NoError> {
1420 if value == &self.from_inference_id {
1421 *value = self.to_inference_id;
1422 Ok(RewriteResult::Modified)
1423 } else {
1424 Ok(RewriteResult::NoChange)
1425 }
1426 }
1427}