cairo_lang_semantic/expr/
inference.rs

1//! Bidirectional type inference.
2
3use std::collections::{BTreeMap, HashMap, VecDeque};
4use std::hash::Hash;
5use std::mem;
6use std::ops::{Deref, DerefMut};
7use std::sync::Arc;
8
9use cairo_lang_debug::DebugWithDb;
10use cairo_lang_defs::ids::{
11    ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId,
12    GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LanguageElementId,
13    LocalVarId, LookupItemId, MemberId, ParamId, StructId, TraitConstantId, TraitFunctionId,
14    TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
15};
16use cairo_lang_diagnostics::{DiagnosticAdded, Maybe, skip_diagnostic};
17use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
18use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
19use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
20use cairo_lang_utils::{
21    Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches,
22};
23
24use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
25use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
26use crate::corelib::{CoreTraitContext, get_core_trait, numeric_literal_trait};
27use crate::db::SemanticGroup;
28use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
29use crate::expr::inference::canonic::ResultNoErrEx;
30use crate::expr::inference::conform::InferenceConform;
31use crate::expr::objects::*;
32use crate::expr::pattern::*;
33use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
34use crate::items::functions::{
35    ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
36    GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
37    ImplGenericFunctionWithBodyId,
38};
39use crate::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType};
40use crate::items::imp::{
41    GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
42    ImplLookupContext, UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
43};
44use crate::items::trt::{ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId};
45use crate::substitution::{HasDb, RewriteResult, SemanticRewriter, SubstitutionRewriter};
46use crate::types::{
47    ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
48    ImplTypeById, ImplTypeId,
49};
50use crate::{
51    ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
52    ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
53    FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
54    Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
55    add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
56};
57
58pub mod canonic;
59pub mod conform;
60pub mod infers;
61pub mod solver;
62
63/// A type variable, created when a generic type argument is not passed, and thus is not known
64/// yet and needs to be inferred.
65#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
66pub struct TypeVar {
67    pub inference_id: InferenceId,
68    pub id: LocalTypeVarId,
69}
70
71/// A const variable, created when a generic const argument is not passed, and thus is not known
72/// yet and needs to be inferred.
73#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
74pub struct ConstVar {
75    pub inference_id: InferenceId,
76    pub id: LocalConstVarId,
77}
78
79/// An id for an inference context. Each inference variable is associated with an inference id.
80#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
81#[debug_db(dyn SemanticGroup + 'static)]
82pub enum InferenceId {
83    LookupItemDeclaration(LookupItemId),
84    LookupItemGenerics(LookupItemId),
85    LookupItemDefinition(LookupItemId),
86    ImplDefTrait(ImplDefId),
87    ImplAliasImplDef(ImplAliasId),
88    GenericParam(GenericParamId),
89    GenericImplParamTrait(GenericParamId),
90    GlobalUseStar(GlobalUseId),
91    Canonical,
92    /// For resolving that will not be used anywhere in the semantic model.
93    NoContext,
94}
95
96/// An impl variable, created when a generic type argument is not passed, and thus is not known
97/// yet and needs to be inferred.
98#[derive(Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
99#[debug_db(dyn SemanticGroup + 'static)]
100pub struct ImplVar {
101    pub inference_id: InferenceId,
102    #[dont_rewrite]
103    pub id: LocalImplVarId,
104    pub concrete_trait_id: ConcreteTraitId,
105    #[dont_rewrite]
106    pub lookup_context: ImplLookupContext,
107}
108impl ImplVar {
109    pub fn intern(&self, db: &dyn SemanticGroup) -> ImplVarId {
110        self.clone().intern(db)
111    }
112}
113
114#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
115pub struct LocalTypeVarId(pub usize);
116#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
117pub struct LocalImplVarId(pub usize);
118
119#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
120pub struct LocalConstVarId(pub usize);
121
122define_short_id!(ImplVarId, ImplVar, SemanticGroup, lookup_intern_impl_var, intern_impl_var);
123impl ImplVarId {
124    pub fn id(&self, db: &dyn SemanticGroup) -> LocalImplVarId {
125        self.lookup_intern(db).id
126    }
127    pub fn concrete_trait_id(&self, db: &dyn SemanticGroup) -> ConcreteTraitId {
128        self.lookup_intern(db).concrete_trait_id
129    }
130    pub fn lookup_context(&self, db: &dyn SemanticGroup) -> ImplLookupContext {
131        self.lookup_intern(db).lookup_context
132    }
133}
134semantic_object_for_id!(ImplVarId, lookup_intern_impl_var, intern_impl_var, ImplVar);
135
136#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject)]
137pub enum InferenceVar {
138    Type(LocalTypeVarId),
139    Const(LocalConstVarId),
140    Impl(LocalImplVarId),
141}
142
143// TODO(spapini): Add to diagnostics.
144#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb)]
145#[debug_db(dyn SemanticGroup + 'static)]
146pub enum InferenceError {
147    /// An inference error wrapping a previously reported error.
148    Reported(DiagnosticAdded),
149    Cycle(InferenceVar),
150    TypeKindMismatch {
151        ty0: TypeId,
152        ty1: TypeId,
153    },
154    ConstKindMismatch {
155        const0: ConstValueId,
156        const1: ConstValueId,
157    },
158    ImplKindMismatch {
159        impl0: ImplId,
160        impl1: ImplId,
161    },
162    GenericArgMismatch {
163        garg0: GenericArgumentId,
164        garg1: GenericArgumentId,
165    },
166    TraitMismatch {
167        trt0: TraitId,
168        trt1: TraitId,
169    },
170    GenericFunctionMismatch {
171        func0: GenericFunctionId,
172        func1: GenericFunctionId,
173    },
174    ConstInferenceNotSupported,
175
176    // TODO(spapini): These are only used for external interface. Separate them along with the
177    // finalize() function to a wrapper.
178    NoImplsFound(ConcreteTraitId),
179    Ambiguity(Ambiguity),
180    TypeNotInferred(TypeId),
181}
182impl InferenceError {
183    pub fn format(&self, db: &(dyn SemanticGroup + 'static)) -> String {
184        match self {
185            InferenceError::Reported(_) => "Inference error occurred.".into(),
186            InferenceError::Cycle(_var) => "Inference cycle detected".into(),
187            InferenceError::TypeKindMismatch { ty0, ty1 } => {
188                format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
189            }
190            InferenceError::ConstKindMismatch { const0, const1 } => {
191                format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
192            }
193            InferenceError::ImplKindMismatch { impl0, impl1 } => {
194                format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
195            }
196            InferenceError::GenericArgMismatch { garg0, garg1 } => {
197                format!(
198                    "Generic arg mismatch: `{:?}` and `{:?}`.",
199                    garg0.debug(db),
200                    garg1.debug(db)
201                )
202            }
203            InferenceError::TraitMismatch { trt0, trt1 } => {
204                format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
205            }
206            InferenceError::ConstInferenceNotSupported => {
207                "Const generic inference not yet supported.".into()
208            }
209            InferenceError::NoImplsFound(concrete_trait_id) => {
210                let trait_id = concrete_trait_id.trait_id(db);
211                if trait_id == numeric_literal_trait(db) {
212                    let generic_type = extract_matches!(
213                        concrete_trait_id.generic_args(db)[0],
214                        GenericArgumentId::Type
215                    );
216                    return format!(
217                        "Mismatched types. The type `{:?}` cannot be created from a numeric \
218                         literal.",
219                        generic_type.debug(db)
220                    );
221                } else if trait_id
222                    == get_core_trait(db, CoreTraitContext::TopLevel, "StringLiteral".into())
223                {
224                    let generic_type = extract_matches!(
225                        concrete_trait_id.generic_args(db)[0],
226                        GenericArgumentId::Type
227                    );
228                    return format!(
229                        "Mismatched types. The type `{:?}` cannot be created from a string \
230                         literal.",
231                        generic_type.debug(db)
232                    );
233                }
234                format!(
235                    "Trait has no implementation in context: {:?}.",
236                    concrete_trait_id.debug(db)
237                )
238            }
239            InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
240            InferenceError::TypeNotInferred(ty) => {
241                format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
242            }
243            InferenceError::GenericFunctionMismatch { func0, func1 } => {
244                format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
245            }
246        }
247    }
248}
249
250impl InferenceError {
251    pub fn report(
252        &self,
253        diagnostics: &mut SemanticDiagnostics,
254        stable_ptr: SyntaxStablePtrId,
255    ) -> DiagnosticAdded {
256        match self {
257            InferenceError::Reported(diagnostic_added) => *diagnostic_added,
258            _ => diagnostics
259                .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
260        }
261    }
262}
263
/// A marker stating that an inference error has been set on the `Inference` object and still has
/// to be consumed.
///
/// It is not meant to be constructed directly; obtain it from [Inference::set_error] instead.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ErrorSet;
270
271pub type InferenceResult<T> = Result<T, ErrorSet>;
272
273#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
274pub enum InferenceErrorStatus {
275    Pending,
276    Consumed,
277}
278
279/// A mapping of an impl var's trait items to concrete items
280#[derive(Debug, Default, PartialEq, Eq, Clone, SemanticObject)]
281pub struct ImplVarTraitItemMappings {
282    /// The trait types of the impl var.
283    types: OrderedHashMap<TraitTypeId, TypeId>,
284    /// The trait constants of the impl var.
285    constants: OrderedHashMap<TraitConstantId, ConstValueId>,
286    /// The trait impls of the impl var.
287    impls: OrderedHashMap<TraitImplId, ImplId>,
288}
289impl Hash for ImplVarTraitItemMappings {
290    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
291        self.types.iter().for_each(|(trait_type_id, type_id)| {
292            trait_type_id.hash(state);
293            type_id.hash(state);
294        });
295        self.constants.iter().for_each(|(trait_const_id, const_id)| {
296            trait_const_id.hash(state);
297            const_id.hash(state);
298        });
299        self.impls.iter().for_each(|(trait_impl_id, impl_id)| {
300            trait_impl_id.hash(state);
301            impl_id.hash(state);
302        });
303    }
304}
305
306/// State of inference.
307#[derive(Debug, DebugWithDb, PartialEq, Eq)]
308#[debug_db(dyn SemanticGroup + 'static)]
309pub struct InferenceData {
310    pub inference_id: InferenceId,
311    /// Current inferred assignment for type variables.
312    pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId>,
313    /// Current inferred assignment for const variables.
314    pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId>,
315    /// Current inferred assignment for impl variables.
316    pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId>,
317    /// Unsolved impl variables mapping to a maps of trait items to a corresponding item variable.
318    /// Upon solution of the trait conforms the fully known item to the variable.
319    pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings>,
320    /// Type variables.
321    pub type_vars: Vec<TypeVar>,
322    /// Const variables.
323    pub const_vars: Vec<ConstVar>,
324    /// Impl variables.
325    pub impl_vars: Vec<ImplVar>,
326    /// Mapping from variables to stable pointers, if exist.
327    pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId>,
328    /// Inference variables that are pending to be solved.
329    pending: VecDeque<LocalImplVarId>,
330    /// Inference variables that have been refuted - no solutions exist.
331    refuted: Vec<LocalImplVarId>,
332    /// Inference variables that have been solved.
333    solved: Vec<LocalImplVarId>,
334    /// Inference variables that are currently ambiguous. May be solved later.
335    ambiguous: Vec<(LocalImplVarId, Ambiguity)>,
336    /// Mapping from impl types to type variables.
337    pub impl_type_bounds: Arc<BTreeMap<ImplTypeById, TypeId>>,
338
339    // Error handling members.
340    /// The current error status.
341    pub error_status: Result<(), InferenceErrorStatus>,
342    /// `Some` only when error_state is Err(Pending).
343    error: Option<InferenceError>,
344    /// `Some` only when error_state is Err(Consumed).
345    consumed_error: Option<DiagnosticAdded>,
346}
347impl InferenceData {
348    pub fn new(inference_id: InferenceId) -> Self {
349        Self {
350            inference_id,
351            type_assignment: OrderedHashMap::default(),
352            impl_assignment: OrderedHashMap::default(),
353            const_assignment: OrderedHashMap::default(),
354            impl_vars_trait_item_mappings: HashMap::new(),
355            type_vars: Vec::new(),
356            impl_vars: Vec::new(),
357            const_vars: Vec::new(),
358            stable_ptrs: HashMap::new(),
359            pending: VecDeque::new(),
360            refuted: Vec::new(),
361            solved: Vec::new(),
362            ambiguous: Vec::new(),
363            impl_type_bounds: Default::default(),
364            error_status: Ok(()),
365            error: None,
366            consumed_error: None,
367        }
368    }
369    pub fn inference<'db, 'b: 'db>(&'db mut self, db: &'b dyn SemanticGroup) -> Inference<'db> {
370        Inference::new(db, self)
371    }
372    pub fn clone_with_inference_id(
373        &self,
374        db: &dyn SemanticGroup,
375        inference_id: InferenceId,
376    ) -> InferenceData {
377        let mut inference_id_replacer =
378            InferenceIdReplacer::new(db, self.inference_id, inference_id);
379        Self {
380            inference_id,
381            type_assignment: self
382                .type_assignment
383                .iter()
384                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
385                .collect(),
386            const_assignment: self
387                .const_assignment
388                .iter()
389                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
390                .collect(),
391            impl_assignment: self
392                .impl_assignment
393                .iter()
394                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
395                .collect(),
396            impl_vars_trait_item_mappings: self
397                .impl_vars_trait_item_mappings
398                .iter()
399                .map(|(k, mappings)| {
400                    (*k, ImplVarTraitItemMappings {
401                        types: mappings
402                            .types
403                            .iter()
404                            .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
405                            .collect(),
406                        constants: mappings
407                            .constants
408                            .iter()
409                            .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
410                            .collect(),
411                        impls: mappings
412                            .impls
413                            .iter()
414                            .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
415                            .collect(),
416                    })
417                })
418                .collect(),
419            type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
420            const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
421            impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
422            stable_ptrs: self.stable_ptrs.clone(),
423            pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
424            refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
425            solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
426            ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
427            // we do not need to rewrite the impl type bounds, as they all should be var free.
428            impl_type_bounds: self.impl_type_bounds.clone(),
429
430            error_status: self.error_status,
431            error: self.error.clone(),
432            consumed_error: self.consumed_error,
433        }
434    }
435    pub fn temporary_clone(&self) -> InferenceData {
436        Self {
437            inference_id: self.inference_id,
438            type_assignment: self.type_assignment.clone(),
439            const_assignment: self.const_assignment.clone(),
440            impl_assignment: self.impl_assignment.clone(),
441            impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
442            type_vars: self.type_vars.clone(),
443            const_vars: self.const_vars.clone(),
444            impl_vars: self.impl_vars.clone(),
445            stable_ptrs: self.stable_ptrs.clone(),
446            pending: self.pending.clone(),
447            refuted: self.refuted.clone(),
448            solved: self.solved.clone(),
449            ambiguous: self.ambiguous.clone(),
450            impl_type_bounds: self.impl_type_bounds.clone(),
451            error_status: self.error_status,
452            error: self.error.clone(),
453            consumed_error: self.consumed_error,
454        }
455    }
456}
457
458/// State of inference. A system of inference constraints.
459pub struct Inference<'db> {
460    db: &'db dyn SemanticGroup,
461    pub data: &'db mut InferenceData,
462}
463
464impl Deref for Inference<'_> {
465    type Target = InferenceData;
466
467    fn deref(&self) -> &Self::Target {
468        self.data
469    }
470}
471impl DerefMut for Inference<'_> {
472    fn deref_mut(&mut self) -> &mut Self::Target {
473        self.data
474    }
475}
476
477impl std::fmt::Debug for Inference<'_> {
478    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
479        let x = self.data.debug(self.db.elongate());
480        write!(f, "{x:?}")
481    }
482}
483
484impl<'db> Inference<'db> {
485    fn new(db: &'db dyn SemanticGroup, data: &'db mut InferenceData) -> Self {
486        Self { db, data }
487    }
488
489    /// Getter for an [ImplVar].
490    fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar {
491        &self.impl_vars[var_id.0]
492    }
493
494    /// Getter for an impl var assignment.
495    pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId> {
496        self.impl_assignment.get(&var_id).copied()
497    }
498
499    /// Getter for a type var assignment.
500    fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId> {
501        self.type_assignment.get(&var_id).copied()
502    }
503
504    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
505    /// Returns a wrapping TypeId.
506    pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeId {
507        let var = self.new_type_var_raw(stable_ptr);
508
509        TypeLongId::Var(var).intern(self.db)
510    }
511
512    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
513    /// Returns the variable id.
514    pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeVar {
515        let var =
516            TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
517        if let Some(stable_ptr) = stable_ptr {
518            self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
519        }
520        self.type_vars.push(var);
521        var
522    }
523
524    /// Sets the infrence's impl type bounds to the given map, and rewrittes the types so all the
525    /// types are var free.
526    pub fn set_impl_type_bounds(&mut self, impl_type_bounds: OrderedHashMap<ImplTypeId, TypeId>) {
527        let impl_type_bounds_finalized = impl_type_bounds
528            .iter()
529            .filter_map(|(impl_type, ty)| {
530                let rewritten_type = self.rewrite(ty.lookup_intern(self.db)).no_err();
531                if !matches!(rewritten_type, TypeLongId::Var(_)) {
532                    return Some(((*impl_type).into(), rewritten_type.intern(self.db)));
533                }
534                // conformed the var type to the original impl type to remove it from the pending
535                // list.
536                self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
537                None
538            })
539            .collect();
540
541        self.data.impl_type_bounds = Arc::new(impl_type_bounds_finalized);
542    }
543
544    /// Allocates a new [ConstVar] for an unknown consts that needs to be inferred.
545    /// Returns a wrapping [ConstValueId].
546    pub fn new_const_var(
547        &mut self,
548        stable_ptr: Option<SyntaxStablePtrId>,
549        ty: TypeId,
550    ) -> ConstValueId {
551        let var = self.new_const_var_raw(stable_ptr);
552        ConstValue::Var(var, ty).intern(self.db)
553    }
554
555    /// Allocates a new [ConstVar] for an unknown type that needs to be inferred.
556    /// Returns the variable id.
557    pub fn new_const_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> ConstVar {
558        let var = ConstVar {
559            inference_id: self.inference_id,
560            id: LocalConstVarId(self.const_vars.len()),
561        };
562        if let Some(stable_ptr) = stable_ptr {
563            self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
564        }
565        self.const_vars.push(var);
566        var
567    }
568
569    /// Allocates a new [ImplVar] for an unknown type that needs to be inferred.
570    /// Returns a wrapping ImplId.
571    pub fn new_impl_var(
572        &mut self,
573        concrete_trait_id: ConcreteTraitId,
574        stable_ptr: Option<SyntaxStablePtrId>,
575        lookup_context: ImplLookupContext,
576    ) -> ImplId {
577        let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
578        ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
579    }
580
581    /// Allocates a new [ImplVar] for an unknown type that needs to be inferred.
582    /// Returns the variable id.
583    fn new_impl_var_raw(
584        &mut self,
585        lookup_context: ImplLookupContext,
586        concrete_trait_id: ConcreteTraitId,
587        stable_ptr: Option<SyntaxStablePtrId>,
588    ) -> LocalImplVarId {
589        let mut lookup_context = lookup_context;
590        lookup_context
591            .insert_module(concrete_trait_id.trait_id(self.db).module_file_id(self.db.upcast()).0);
592
593        let id = LocalImplVarId(self.impl_vars.len());
594        if let Some(stable_ptr) = stable_ptr {
595            self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
596        }
597        let var =
598            ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
599        self.impl_vars.push(var);
600        self.pending.push_back(id);
601        id
602    }
603
604    /// Solves the inference system. After a successful solve, there are no more pending impl
605    /// inferences.
606    /// Returns whether the inference was successful. If not, the error may be found by
607    /// `.error_state()`.
608    pub fn solve(&mut self) -> InferenceResult<()> {
609        self.solve_ex().map_err(|(err_set, _)| err_set)
610    }
611
612    /// Same as `solve`, but returns the error stable pointer if an error occurred.
613    fn solve_ex(&mut self) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
614        let mut ambiguous = std::mem::take(&mut self.ambiguous);
615        self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
616        while let Some(var) = self.pending.pop_front() {
617            // First inference error stops inference.
618            self.solve_single_pending(var).map_err(|err_set| {
619                (err_set, self.stable_ptrs.get(&InferenceVar::Impl(var)).copied())
620            })?;
621        }
622        Ok(())
623    }
624
625    fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
626        if self.impl_assignment.contains_key(&var) {
627            return Ok(());
628        }
629        let solution = match self.impl_var_solution_set(var)? {
630            SolutionSet::None => {
631                self.refuted.push(var);
632                return Ok(());
633            }
634            SolutionSet::Ambiguous(ambiguity) => {
635                self.ambiguous.push((var, ambiguity));
636                return Ok(());
637            }
638            SolutionSet::Unique(solution) => solution,
639        };
640
641        // Solution found. Assign it.
642        self.assign_local_impl(var, solution)?;
643
644        // Something changed.
645        self.solved.push(var);
646        let mut ambiguous = std::mem::take(&mut self.ambiguous);
647        self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
648
649        Ok(())
650    }
651
652    /// Returns the solution set status for the inference:
653    /// Whether there is a unique solution, multiple solutions, no solutions or an error.
654    pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<()>> {
655        self.solve()?;
656        if !self.refuted.is_empty() {
657            return Ok(SolutionSet::None);
658        }
659        if let Some((_, ambiguity)) = self.ambiguous.first() {
660            return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
661        }
662        assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
663        Ok(SolutionSet::Unique(()))
664    }
665
666    /// Finalizes the inference by inferring uninferred numeric literals as felt252.
667    /// Returns an error and does not report it.
668    pub fn finalize_without_reporting(
669        &mut self,
670    ) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
671        if self.error_status.is_err() {
672            // TODO(yuval): consider adding error location to the set error.
673            return Err((ErrorSet, None));
674        }
675
676        let numeric_trait_id = numeric_literal_trait(self.db);
677        let felt_ty = self.db.core_felt252_ty();
678
679        // Conform all uninferred numeric literals to felt252.
680        loop {
681            let mut changed = false;
682            self.solve_ex()?;
683            for (var, _) in self.ambiguous.clone() {
684                let impl_var = self.impl_var(var).clone();
685                if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
686                    continue;
687                }
688                // Uninferred numeric trait. Resolve as felt252.
689                let ty = extract_matches!(
690                    impl_var.concrete_trait_id.generic_args(self.db)[0],
691                    GenericArgumentId::Type
692                );
693                if self.rewrite(ty).no_err() == felt_ty {
694                    continue;
695                }
696                self.conform_ty(ty, felt_ty).map_err(|err_set| {
697                    (err_set, self.stable_ptrs.get(&InferenceVar::Impl(impl_var.id)).copied())
698                })?;
699                changed = true;
700                break;
701            }
702            if !changed {
703                break;
704            }
705        }
706        assert!(
707            self.pending.is_empty(),
708            "pending should all be solved by this point. Guaranteed by solve()."
709        );
710
711        let Some((var, err)) = self.first_undetermined_variable() else {
712            return Ok(());
713        };
714        Err((self.set_error(err), self.stable_ptrs.get(&var).copied()))
715    }
716
717    /// Finalizes the inference and report diagnostics if there are any errors.
718    /// All the remaining type vars are mapped to the `missing` type, to prevent additional
719    /// diagnostics.
720    pub fn finalize(
721        &mut self,
722        diagnostics: &mut SemanticDiagnostics,
723        stable_ptr: SyntaxStablePtrId,
724    ) {
725        if let Err((err_set, err_stable_ptr)) = self.finalize_without_reporting() {
726            let diag = self.report_on_pending_error(
727                err_set,
728                diagnostics,
729                err_stable_ptr.unwrap_or(stable_ptr),
730            );
731
732            let ty_missing = TypeId::missing(self.db, diag);
733            for var in &self.data.type_vars {
734                self.data.type_assignment.entry(var.id).or_insert(ty_missing);
735            }
736        }
737    }
738
739    /// Retrieves the first variable that is still not inferred, or None, if everything is
740    /// inferred.
741    /// Does not set the error but return it, which is ok as this is a private helper function.
742    fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError)> {
743        if let Some(var) = self.refuted.first().copied() {
744            let impl_var = self.impl_var(var).clone();
745            let concrete_trait_id = impl_var.concrete_trait_id;
746            let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
747            return Some((
748                InferenceVar::Impl(var),
749                InferenceError::NoImplsFound(concrete_trait_id),
750            ));
751        }
752        let mut fallback_ret = None;
753        if let Some((var, ambiguity)) = self.ambiguous.first() {
754            // Note: do not rewrite `ambiguity`, since it is expressed in canonical variables.
755            let ret =
756                Some((InferenceVar::Impl(*var), InferenceError::Ambiguity(ambiguity.clone())));
757            if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
758                return ret;
759            } else {
760                fallback_ret = ret;
761            }
762        }
763        for (id, var) in self.type_vars.iter().enumerate() {
764            if self.type_assignment(LocalTypeVarId(id)).is_none() {
765                let ty = TypeLongId::Var(*var).intern(self.db);
766                return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
767            }
768        }
769        fallback_ret
770    }
771
772    /// Assigns a value to a local impl variable id. See assign_impl().
773    fn assign_local_impl(
774        &mut self,
775        var: LocalImplVarId,
776        impl_id: ImplId,
777    ) -> InferenceResult<ImplId> {
778        let concrete_trait = impl_id
779            .concrete_trait(self.db)
780            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
781        self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
782        if let Some(other_impl) = self.impl_assignment(var) {
783            return self.conform_impl(impl_id, other_impl);
784        }
785        if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
786        {
787            return Err(self.set_error(InferenceError::Cycle(InferenceVar::Impl(var))));
788        }
789        self.impl_assignment.insert(var, impl_id);
790        if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
791            for (trait_ty, ty) in mappings.types {
792                self.conform_ty(
793                    ty,
794                    self.db
795                        .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_ty, self.db))
796                        .map_err(|_| ErrorSet)?,
797                )?;
798            }
799            for (trait_constant, constant_id) in mappings.constants {
800                self.conform_const(
801                    constant_id,
802                    self.db
803                        .impl_constant_concrete_implized_value(ImplConstantId::new(
804                            impl_id,
805                            trait_constant,
806                            self.db,
807                        ))
808                        .map_err(|_| ErrorSet)?,
809                )?;
810            }
811            for (trait_impl, inner_impl_id) in mappings.impls {
812                self.conform_impl(
813                    inner_impl_id,
814                    self.db
815                        .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
816                        .map_err(|_| ErrorSet)?,
817                )?;
818            }
819        }
820        Ok(impl_id)
821    }
822
823    /// Tries to assigns value to an [ImplVarId]. Return the assigned impl, or an error.
824    fn assign_impl(&mut self, var_id: ImplVarId, impl_id: ImplId) -> InferenceResult<ImplId> {
825        let var = var_id.lookup_intern(self.db);
826        if var.inference_id != self.inference_id {
827            return Err(self.set_error(InferenceError::ImplKindMismatch {
828                impl0: ImplLongId::ImplVar(var_id).intern(self.db),
829                impl1: impl_id,
830            }));
831        }
832        self.assign_local_impl(var.id, impl_id)
833    }
834
835    /// Assigns a value to a [TypeVar]. Return the assigned type, or an error.
836    /// Assumes the variable is not already assigned.
837    fn assign_ty(&mut self, var: TypeVar, ty: TypeId) -> InferenceResult<TypeId> {
838        if var.inference_id != self.inference_id {
839            return Err(self.set_error(InferenceError::TypeKindMismatch {
840                ty0: TypeLongId::Var(var).intern(self.db),
841                ty1: ty,
842            }));
843        }
844        assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
845        let inference_var = InferenceVar::Type(var.id);
846        if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
847            return Err(self.set_error(InferenceError::Cycle(inference_var)));
848        }
849        // If assigning var to var - making sure assigning to the lower id for proper canonization.
850        if let TypeLongId::Var(other) = ty.lookup_intern(self.db) {
851            if other.inference_id == self.inference_id && other.id.0 > var.id.0 {
852                let var_ty = TypeLongId::Var(var).intern(self.db);
853                self.type_assignment.insert(other.id, var_ty);
854                return Ok(var_ty);
855            }
856        }
857        self.type_assignment.insert(var.id, ty);
858        Ok(ty)
859    }
860
861    /// Assigns a value to a [ConstVar]. Return the assigned const, or an error.
862    /// Assumes the variable is not already assigned.
863    fn assign_const(&mut self, var: ConstVar, id: ConstValueId) -> InferenceResult<ConstValueId> {
864        if var.inference_id != self.inference_id {
865            return Err(self.set_error(InferenceError::ConstKindMismatch {
866                const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
867                    .intern(self.db),
868                const1: id,
869            }));
870        }
871
872        self.const_assignment.insert(var.id, id);
873        Ok(id)
874    }
875
876    /// Computes the solution set for an impl variable with a recursive query.
877    fn impl_var_solution_set(
878        &mut self,
879        var: LocalImplVarId,
880    ) -> InferenceResult<SolutionSet<ImplId>> {
881        let impl_var = self.impl_var(var).clone();
882        // Update the concrete trait of the impl var.
883        let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
884        self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
885        let impl_var_trait_item_mappings =
886            self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
887        let solution_set = self.trait_solution_set(
888            concrete_trait_id,
889            impl_var_trait_item_mappings,
890            impl_var.lookup_context,
891        )?;
892        Ok(match solution_set {
893            SolutionSet::None => SolutionSet::None,
894            SolutionSet::Unique((canonical_impl, canonicalizer)) => {
895                SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
896            }
897            SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
898        })
899    }
900
901    /// Computes the solution set for a trait with a recursive query.
902    pub fn trait_solution_set(
903        &mut self,
904        concrete_trait_id: ConcreteTraitId,
905        impl_var_trait_item_mappings: ImplVarTraitItemMappings,
906        mut lookup_context: ImplLookupContext,
907    ) -> InferenceResult<SolutionSet<(CanonicalImpl, CanonicalMapping)>> {
908        let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
909        // TODO(spapini): This is done twice. Consider doing it only here.
910        let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
911        enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
912
913        // Don't try to resolve impls if the first generic param is a variable.
914        let generic_args = concrete_trait_id.generic_args(self.db);
915        match generic_args.first() {
916            Some(GenericArgumentId::Type(ty)) => {
917                if let TypeLongId::Var(_) = ty.lookup_intern(self.db) {
918                    // Don't try to infer such impls.
919                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
920                }
921            }
922            Some(GenericArgumentId::Impl(imp)) => {
923                // Don't try to infer such impls.
924                if let ImplLongId::ImplVar(_) = imp.lookup_intern(self.db) {
925                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
926                }
927            }
928            Some(GenericArgumentId::Constant(const_value)) => {
929                if let ConstValue::Var(_, _) = const_value.lookup_intern(self.db) {
930                    // Don't try to infer such impls.
931                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
932                }
933            }
934            _ => {}
935        };
936
937        let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
938            self.db,
939            self.inference_id,
940            concrete_trait_id,
941            impl_var_trait_item_mappings,
942        );
943
944        // impl_type_bounds order is deterimend by the generic params of the function and therefore
945        // is consistent.
946        let solution_set = match self.db.canonic_trait_solutions(
947            canonical_trait,
948            lookup_context,
949            (*self.data.impl_type_bounds).clone(),
950        ) {
951            Ok(solution_set) => solution_set,
952            Err(err) => return Err(self.set_error(err)),
953        };
954        match solution_set {
955            SolutionSet::None => Ok(SolutionSet::None),
956            SolutionSet::Unique(canonical_impl) => {
957                Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
958            }
959            SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
960        }
961    }
962
    /// Validate that the given impl is valid based on its negative impls arguments.
    ///
    /// Returns:
    /// * `SolutionSet::Unique(canonical_impl)` if no negative impl of the candidate has a
    ///   solution (or if the impl kind has no negative impls to check).
    /// * `SolutionSet::None` if some negative impl of the candidate does have a solution (unique
    ///   or ambiguous), which rules the candidate out.
    /// * `SolutionSet::Ambiguous(...)` if a negative impl has a type argument that is not fully
    ///   concrete (and is not a closure), so it can't be decided yet.
    fn validate_neg_impls(
        &mut self,
        lookup_context: &ImplLookupContext,
        canonical_impl: CanonicalImpl,
    ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
        /// Validates that no solution set is found for the negative impls.
        /// Each item of `negative_impls_concrete_traits` is the concrete trait of one negative
        /// impl of the candidate; an `Err` item is recorded as a reported inference error.
        fn validate_no_solution_set(
            inference: &mut Inference<'_>,
            canonical_impl: CanonicalImpl,
            lookup_context: &ImplLookupContext,
            negative_impls_concrete_traits: impl Iterator<Item = Maybe<ConcreteTraitId>>,
        ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
            for concrete_trait_id in negative_impls_concrete_traits {
                let concrete_trait_id = concrete_trait_id.map_err(|diag_added| {
                    inference.set_error(InferenceError::Reported(diag_added))
                })?;
                // Only type arguments are inspected; other kinds of generic arguments are
                // skipped.
                for garg in concrete_trait_id.generic_args(inference.db) {
                    let GenericArgumentId::Type(ty) = garg else {
                        continue;
                    };
                    let ty = inference.rewrite(ty).no_err();
                    // If the negative impl has a generic argument that is not fully
                    // concrete we can't tell if we should rule out the candidate impl.
                    // For example if we have -TypeEqual<S, T> we can't tell if S and
                    // T are going to be assigned the same concrete type.
                    // We return `SolutionSet::Ambiguous` here to indicate that more
                    // information is needed.
                    // Closure can only have one type, even if it's not fully concrete, so can use
                    // it and not get ambiguity.
                    if !matches!(ty.lookup_intern(inference.db), TypeLongId::Closure(_))
                        && !ty.is_fully_concrete(inference.db)
                    {
                        // TODO(ilya): Try to detect the ambiguity earlier in the
                        // inference process.
                        return Ok(SolutionSet::Ambiguous(
                            Ambiguity::NegativeImplWithUnresolvedGenericArgs {
                                impl_id: canonical_impl.0,
                                ty,
                            },
                        ));
                    }
                }

                // Anything other than `SolutionSet::None` (including an ambiguous result) counts
                // as the negated trait having an impl.
                if !matches!(
                    inference.trait_solution_set(
                        concrete_trait_id,
                        ImplVarTraitItemMappings::default(),
                        lookup_context.clone()
                    )?,
                    SolutionSet::None
                ) {
                    // If a negative impl has an impl, then we should skip it.
                    return Ok(SolutionSet::None);
                }
            }

            Ok(SolutionSet::Unique(canonical_impl))
        }
        match canonical_impl.0.lookup_intern(self.db) {
            ImplLongId::Concrete(concrete_impl) => {
                // The negative impl params are declared on the impl definition, so they are
                // rewritten with the substitution of the concrete impl before being checked.
                let mut rewriter = SubstitutionRewriter {
                    db: self.db,
                    substitution: &concrete_impl.substitution(self.db).map_err(|diag_added| {
                        self.set_error(InferenceError::Reported(diag_added))
                    })?,
                };
                let generic_params = self
                    .db
                    .impl_def_generic_params(concrete_impl.impl_def_id(self.db))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                let concrete_traits = generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| {
                        rewriter
                            .rewrite(generic_param.clone())
                            .and_then(|generic_param| generic_param.concrete_trait)
                    });
                validate_no_solution_set(self, canonical_impl, lookup_context, concrete_traits)
            }
            // For a generated impl the negative impl params are used as they are, with no
            // substitution applied.
            ImplLongId::GeneratedImpl(generated_impl) => validate_no_solution_set(
                self,
                canonical_impl,
                lookup_context,
                generated_impl
                    .lookup_intern(self.db)
                    .generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| generic_param.concrete_trait),
            ),
            // No negative impls are checked for these kinds; the candidate is accepted as is.
            ImplLongId::GenericParameter(_)
            | ImplLongId::ImplVar(_)
            | ImplLongId::ImplImpl(_)
            | ImplLongId::TraitImpl(_) => Ok(SolutionSet::Unique(canonical_impl)),
        }
    }
1067
1068    // Error handling methods
1069    // ======================
1070
1071    /// Sets an error in the inference state.
1072    /// Does nothing if an error is already set.
1073    /// Returns an `ErrorSet` that can be used in reporting the error.
1074    pub fn set_error(&mut self, err: InferenceError) -> ErrorSet {
1075        if self.error_status.is_err() {
1076            return ErrorSet;
1077        }
1078        self.error_status = if let InferenceError::Reported(diag_added) = err {
1079            self.consumed_error = Some(diag_added);
1080            Err(InferenceErrorStatus::Consumed)
1081        } else {
1082            self.error = Some(err);
1083            Err(InferenceErrorStatus::Pending)
1084        };
1085        ErrorSet
1086    }
1087
1088    /// Returns whether an error is set (either pending or consumed).
1089    pub fn is_error_set(&self) -> InferenceResult<()> {
1090        if self.error_status.is_err() { Err(ErrorSet) } else { Ok(()) }
1091    }
1092
1093    /// Consumes the error but doesn't report it. If there is no error, or the error is consumed,
1094    /// returns None. This should be used with caution. Always prefer to use
1095    /// (1) `report_on_pending_error` if possible, or (2) `consume_reported_error` which is safer.
1096    ///
1097    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1098    pub fn consume_error_without_reporting(&mut self, err_set: ErrorSet) -> Option<InferenceError> {
1099        self.consume_error_inner(err_set, skip_diagnostic())
1100    }
1101
1102    /// Consumes the error that is already reported. If there is no error, or the error is consumed,
1103    /// does nothing. This should be used with caution. Always prefer to use
1104    /// `report_on_pending_error` if possible.
1105    ///
1106    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1107    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1108    pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1109        self.consume_error_inner(err_set, diag_added);
1110    }
1111
1112    /// Consumes the error and returns it, but doesn't report it. If there is no error, or the error
1113    /// is already consumed, returns None. This should be used with caution. Always prefer to use
1114    /// `report_on_pending_error` if possible.
1115    ///
1116    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1117    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1118    fn consume_error_inner(
1119        &mut self,
1120        _err_set: ErrorSet,
1121        diag_added: DiagnosticAdded,
1122    ) -> Option<InferenceError> {
1123        if self.error_status != Err(InferenceErrorStatus::Pending) {
1124            return None;
1125            // panic!("consume_error when there is no pending error");
1126        }
1127        self.error_status = Err(InferenceErrorStatus::Consumed);
1128        self.consumed_error = Some(diag_added);
1129        mem::take(&mut self.error)
1130    }
1131
1132    /// Consumes the pending error, if any, and reports it.
1133    /// Should only be called when an error is set, otherwise it panics.
1134    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1135    /// If an error was set but it's already consumed, it doesn't report it again but returns the
1136    /// stored `DiagnosticAdded`.
1137    pub fn report_on_pending_error(
1138        &mut self,
1139        _err_set: ErrorSet,
1140        diagnostics: &mut SemanticDiagnostics,
1141        stable_ptr: SyntaxStablePtrId,
1142    ) -> DiagnosticAdded {
1143        let Err(state_error) = self.error_status else {
1144            panic!("report_on_pending_error should be called only on error");
1145        };
1146        match state_error {
1147            InferenceErrorStatus::Consumed => self
1148                .consumed_error
1149                .expect("consumed_error is not set although error_status is Err(Consumed)"),
1150            InferenceErrorStatus::Pending => {
1151                let diag_added = match mem::take(&mut self.error)
1152                    .expect("error is not set although error_status is Err(Pending)")
1153                {
1154                    InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1155                        // If we have other diagnostics, there is no need to TypeNotInferred.
1156
1157                        // Note that `diagnostics` is not empty, so it is safe to return
1158                        // 'DiagnosticAdded' here.
1159                        skip_diagnostic()
1160                    }
1161                    diag => diag.report(diagnostics, stable_ptr),
1162                };
1163
1164                self.error_status = Err(InferenceErrorStatus::Consumed);
1165                self.consumed_error = Some(diag_added);
1166                diag_added
1167            }
1168        }
1169    }
1170
1171    /// If the current status is of a pending error, reports an alternative diagnostic, by calling
1172    /// `report`, and consumes the error. Otherwise, does nothing.
1173    pub fn report_modified_if_pending(
1174        &mut self,
1175        err_set: ErrorSet,
1176        report: impl FnOnce() -> DiagnosticAdded,
1177    ) {
1178        if self.error_status == Err(InferenceErrorStatus::Pending) {
1179            self.consume_reported_error(err_set, report());
1180        }
1181    }
1182}
1183
1184impl<'a> HasDb<&'a dyn SemanticGroup> for Inference<'a> {
1185    fn get_db(&self) -> &'a dyn SemanticGroup {
1186        self.db
1187    }
1188}
// Rewrite support for `Inference`, with `NoError` as the error type.
// The types listed after `@exclude` (`TypeLongId`, `TypeId`, `ImplLongId`, `ImplId` and
// `ConstValue`) get handwritten `SemanticRewriter` impls below.
// NOTE(review): presumably these macros generate the `SemanticRewriter` impls for all the other
// rewritable types - confirm against the macro definitions.
add_basic_rewrites!(<'a>, Inference<'a>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue);
add_expr_rewrites!(<'a>, Inference<'a>, NoError, @exclude);
add_rewrite!(<'a>, Inference<'a>, NoError, Ambiguity);
1192impl SemanticRewriter<TypeId, NoError> for Inference<'_> {
1193    fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
1194        if value.is_var_free(self.db) {
1195            return Ok(RewriteResult::NoChange);
1196        }
1197        value.default_rewrite(self)
1198    }
1199}
1200impl SemanticRewriter<ImplId, NoError> for Inference<'_> {
1201    fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
1202        if value.is_var_free(self.db) {
1203            return Ok(RewriteResult::NoChange);
1204        }
1205        value.default_rewrite(self)
1206    }
1207}
/// Rewrites a [TypeLongId] using the current state of the inference: assigned type variables are
/// replaced by their assignments, and impl types are resolved through the impl type bounds or
/// through their impl when possible. Any other type goes through the default rewrite.
impl SemanticRewriter<TypeLongId, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
        match value {
            TypeLongId::Var(var) => {
                // An unassigned variable falls through to the default rewrite below.
                if let Some(type_id) = self.type_assignment.get(&var.id) {
                    let mut long_type_id = type_id.lookup_intern(self.db);
                    // If rewriting the assigned type changed it, store the rewritten type back
                    // as the assignment of the variable.
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
                        *self.type_assignment.get_mut(&var.id).unwrap() =
                            long_type_id.clone().intern(self.db);
                    }
                    *value = long_type_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            TypeLongId::ImplType(impl_type_id) => {
                // A bound registered for this impl type takes precedence: the type is replaced
                // by the bound, which is then rewritten itself.
                if let Some(type_id) = self.impl_type_bounds.get(&((*impl_type_id).into())) {
                    *value = type_id.lookup_intern(self.db);
                    self.internal_rewrite(value)?;
                    return Ok(RewriteResult::Modified);
                }
                let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
                let impl_id = impl_type_id.impl_id();
                let trait_ty = impl_type_id.ty();
                // The outcome depends on the kind of the (already rewritten) impl.
                return Ok(match impl_id.lookup_intern(self.db) {
                    ImplLongId::GenericParameter(_) | ImplLongId::TraitImpl(_) => {
                        impl_type_id_rewrite_result
                    }
                    ImplLongId::ImplImpl(impl_impl) => {
                        // The grand parent impl must be var free since we are rewriting the parent,
                        // and the parent is not var.
                        assert!(impl_impl.impl_id().is_var_free(self.db));
                        impl_type_id_rewrite_result
                    }
                    ImplLongId::Concrete(_) => {
                        // Replace the impl type with the type it is implemented as. If that lookup
                        // fails, only the rewrite of the impl type id itself is kept.
                        if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
                            impl_id, trait_ty, self.db,
                        )) {
                            *value = self.rewrite(ty).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_type_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        // The impl is still a variable: use the type given by
                        // `rewritten_impl_type` for this trait type.
                        *value = self.rewritten_impl_type(var, trait_ty).lookup_intern(self.db);
                        return Ok(RewriteResult::Modified);
                    }
                    ImplLongId::GeneratedImpl(generated) => {
                        // The type is taken from the items of the generated impl. Panics if the
                        // generated impl has no item for this trait type.
                        *value = self
                            .rewrite(
                                *generated
                                    .lookup_intern(self.db)
                                    .impl_items
                                    .0
                                    .get(&impl_type_id.ty())
                                    .unwrap(),
                            )
                            .no_err()
                            .lookup_intern(self.db);
                        RewriteResult::Modified
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites a [ConstValue] using the current state of the inference: assigned const variables
/// are replaced by their assignments, and impl constants are resolved through their impl when
/// possible. Any other value goes through the default rewrite.
impl SemanticRewriter<ConstValue, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
        match value {
            ConstValue::Var(var, _) => {
                // Unlike the other cases, an unassigned variable returns `NoChange` right away
                // and does not reach the default rewrite.
                return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
                    let mut const_value = const_value_id.lookup_intern(self.db);
                    // If rewriting the assigned value changed it, store the rewritten value back
                    // as the assignment of the variable.
                    if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
                        *self.const_assignment.get_mut(&var.id).unwrap() =
                            const_value.clone().intern(self.db);
                    }
                    *value = const_value;
                    RewriteResult::Modified
                } else {
                    RewriteResult::NoChange
                });
            }
            ConstValue::ImplConstant(impl_constant_id) => {
                let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
                let impl_id = impl_constant_id.impl_id();
                let trait_constant = impl_constant_id.trait_constant_id();
                // The outcome depends on the kind of the (already rewritten) impl.
                return Ok(match impl_id.lookup_intern(self.db) {
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::TraitImpl(_)
                    | ImplLongId::GeneratedImpl(_) => impl_constant_id_rewrite_result,
                    ImplLongId::ImplImpl(impl_impl) => {
                        // The grand parent impl must be var free since we are rewriting the parent,
                        // and the parent is not var.
                        assert!(impl_impl.impl_id().is_var_free(self.db));
                        impl_constant_id_rewrite_result
                    }
                    ImplLongId::Concrete(_) => {
                        // Replace the impl constant with the value it is implemented as. If that
                        // lookup fails, only the rewrite of the impl constant id itself is kept.
                        if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
                            ImplConstantId::new(impl_id, trait_constant, self.db),
                        ) {
                            *value = self.rewrite(constant).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_constant_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        // The impl is still a variable: use the constant given by
                        // `rewritten_impl_constant` for this trait constant.
                        *value = self
                            .rewritten_impl_constant(var, trait_constant)
                            .lookup_intern(self.db);
                        return Ok(RewriteResult::Modified);
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites an [ImplLongId] using the current state of the inference: assigned impl variables
/// are replaced by their assignments, and impl impls are resolved through their parent impl when
/// possible. Any other impl is left unchanged if it is var free, and goes through the default
/// rewrite otherwise.
impl SemanticRewriter<ImplLongId, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
        match value {
            ImplLongId::ImplVar(var) => {
                let long_id = var.lookup_intern(self.db);
                // Replace the variable with its current assignment, if it has one. An unassigned
                // variable falls through to the checks at the end of the function.
                let impl_var_id = long_id.id;
                if let Some(impl_id) = self.impl_assignment(impl_var_id) {
                    let mut long_impl_id = impl_id.lookup_intern(self.db);
                    // If rewriting the assigned impl changed it, store the rewritten impl back
                    // as the assignment of the variable.
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
                        *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
                            long_impl_id.clone().intern(self.db);
                    }
                    *value = long_impl_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            ImplLongId::ImplImpl(impl_impl_id) => {
                let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
                let impl_id = impl_impl_id.impl_id();
                // The outcome depends on the kind of the (already rewritten) parent impl.
                return Ok(match impl_id.lookup_intern(self.db) {
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::TraitImpl(_)
                    | ImplLongId::GeneratedImpl(_) => impl_impl_id_rewrite_result,
                    ImplLongId::ImplImpl(impl_impl) => {
                        // The grand parent impl must be var free since we are rewriting the parent,
                        // and the parent is not var.
                        assert!(impl_impl.impl_id().is_var_free(self.db));
                        impl_impl_id_rewrite_result
                    }
                    ImplLongId::Concrete(_) => {
                        // Replace the impl impl with the impl it is implemented as. If that lookup
                        // fails, only the rewrite of the impl impl id itself is kept.
                        if let Ok(ty) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
                            *value = self.rewrite(ty).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        // The parent impl is still a variable: use the impl given by
                        // `rewritten_impl_impl`, if the concrete trait impl can be computed.
                        if let Ok(concrete_trait_impl) =
                            impl_impl_id.concrete_trait_impl_id(self.db)
                        {
                            *value = self
                                .rewritten_impl_impl(var, concrete_trait_impl)
                                .lookup_intern(self.db);
                            return Ok(RewriteResult::Modified);
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                });
            }

            _ => {}
        }
        // Nothing to rewrite in an impl with no variables.
        if value.is_var_free(self.db) {
            return Ok(RewriteResult::NoChange);
        }
        value.default_rewrite(self)
    }
}
1390
/// A rewriter that replaces every occurrence of one [InferenceId] with another.
struct InferenceIdReplacer<'a> {
    /// The semantic database, returned by the `HasDb` impl of this type.
    db: &'a dyn SemanticGroup,
    /// The inference id to look for.
    from_inference_id: InferenceId,
    /// The inference id that replaces `from_inference_id`.
    to_inference_id: InferenceId,
}
1396impl<'a> InferenceIdReplacer<'a> {
1397    fn new(
1398        db: &'a dyn SemanticGroup,
1399        from_inference_id: InferenceId,
1400        to_inference_id: InferenceId,
1401    ) -> Self {
1402        Self { db, from_inference_id, to_inference_id }
1403    }
1404}
1405impl<'a> HasDb<&'a dyn SemanticGroup> for InferenceIdReplacer<'a> {
1406    fn get_db(&self) -> &'a dyn SemanticGroup {
1407        self.db
1408    }
1409}
// Rewrite support for `InferenceIdReplacer`, with `NoError` as the error type.
// `InferenceId` is listed after `@exclude` and gets a handwritten `SemanticRewriter` impl below.
// NOTE(review): presumably these macros generate the `SemanticRewriter` impls for all the other
// rewritable types - confirm against the macro definitions.
add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity);
1413impl SemanticRewriter<InferenceId, NoError> for InferenceIdReplacer<'_> {
1414    fn internal_rewrite(&mut self, value: &mut InferenceId) -> Result<RewriteResult, NoError> {
1415        if value == &self.from_inference_id {
1416            *value = self.to_inference_id;
1417            Ok(RewriteResult::Modified)
1418        } else {
1419            Ok(RewriteResult::NoChange)
1420        }
1421    }
1422}