cairo_lang_semantic/expr/
inference.rs

1//! Bidirectional type inference.
2
3use std::collections::{BTreeMap, HashMap, VecDeque};
4use std::hash::Hash;
5use std::mem;
6use std::ops::{Deref, DerefMut};
7use std::sync::Arc;
8
9use cairo_lang_debug::DebugWithDb;
10use cairo_lang_defs::ids::{
11    ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId,
12    GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LanguageElementId,
13    LocalVarId, LookupItemId, MemberId, NamedLanguageElementId, ParamId, StructId, TraitConstantId,
14    TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
15};
16use cairo_lang_diagnostics::{DiagnosticAdded, Maybe, skip_diagnostic};
17use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
18use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
19use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
20use cairo_lang_utils::{
21    Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches,
22};
23
24use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
25use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
26use crate::db::SemanticGroup;
27use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
28use crate::expr::inference::canonic::ResultNoErrEx;
29use crate::expr::inference::conform::InferenceConform;
30use crate::expr::objects::*;
31use crate::expr::pattern::*;
32use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
33use crate::items::functions::{
34    ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
35    GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
36    ImplGenericFunctionWithBodyId,
37};
38use crate::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType};
39use crate::items::imp::{
40    GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
41    ImplLookupContext, UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
42};
43use crate::items::trt::{
44    ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId, ConcreteTraitTypeId,
45    ConcreteTraitTypeLongId,
46};
47use crate::substitution::{HasDb, RewriteResult, SemanticRewriter};
48use crate::types::{
49    ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
50    ImplTypeById, ImplTypeId,
51};
52use crate::{
53    ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
54    ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
55    FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
56    Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
57    add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
58};
59
60pub mod canonic;
61pub mod conform;
62pub mod infers;
63pub mod solver;
64
65/// A type variable, created when a generic type argument is not passed, and thus is not known
66/// yet and needs to be inferred.
67#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
68pub struct TypeVar {
69    pub inference_id: InferenceId,
70    pub id: LocalTypeVarId,
71}
72
73/// A const variable, created when a generic const argument is not passed, and thus is not known
74/// yet and needs to be inferred.
75#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
76pub struct ConstVar {
77    pub inference_id: InferenceId,
78    pub id: LocalConstVarId,
79}
80
81/// An id for an inference context. Each inference variable is associated with an inference id.
82#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
83#[debug_db(dyn SemanticGroup + 'static)]
84pub enum InferenceId {
85    LookupItemDeclaration(LookupItemId),
86    LookupItemGenerics(LookupItemId),
87    LookupItemDefinition(LookupItemId),
88    ImplDefTrait(ImplDefId),
89    ImplAliasImplDef(ImplAliasId),
90    GenericParam(GenericParamId),
91    GenericImplParamTrait(GenericParamId),
92    GlobalUseStar(GlobalUseId),
93    Canonical,
94    /// For resolving that will not be used anywhere in the semantic model.
95    NoContext,
96}
97
98/// An impl variable, created when a generic type argument is not passed, and thus is not known
99/// yet and needs to be inferred.
100#[derive(Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
101#[debug_db(dyn SemanticGroup + 'static)]
102pub struct ImplVar {
103    pub inference_id: InferenceId,
104    #[dont_rewrite]
105    pub id: LocalImplVarId,
106    pub concrete_trait_id: ConcreteTraitId,
107    #[dont_rewrite]
108    pub lookup_context: ImplLookupContext,
109}
110impl ImplVar {
111    pub fn intern(&self, db: &dyn SemanticGroup) -> ImplVarId {
112        self.clone().intern(db)
113    }
114}
115
116#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
117pub struct LocalTypeVarId(pub usize);
118#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
119pub struct LocalImplVarId(pub usize);
120
121#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
122pub struct LocalConstVarId(pub usize);
123
124define_short_id!(ImplVarId, ImplVar, SemanticGroup, lookup_intern_impl_var, intern_impl_var);
125impl ImplVarId {
126    pub fn id(&self, db: &dyn SemanticGroup) -> LocalImplVarId {
127        self.lookup_intern(db).id
128    }
129    pub fn concrete_trait_id(&self, db: &dyn SemanticGroup) -> ConcreteTraitId {
130        self.lookup_intern(db).concrete_trait_id
131    }
132    pub fn lookup_context(&self, db: &dyn SemanticGroup) -> ImplLookupContext {
133        self.lookup_intern(db).lookup_context
134    }
135}
136semantic_object_for_id!(ImplVarId, lookup_intern_impl_var, intern_impl_var, ImplVar);
137
138#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject)]
139pub enum InferenceVar {
140    Type(LocalTypeVarId),
141    Const(LocalConstVarId),
142    Impl(LocalImplVarId),
143}
144
145// TODO(spapini): Add to diagnostics.
146#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb)]
147#[debug_db(dyn SemanticGroup + 'static)]
148pub enum InferenceError {
149    /// An inference error wrapping a previously reported error.
150    Reported(DiagnosticAdded),
151    Cycle(InferenceVar),
152    TypeKindMismatch {
153        ty0: TypeId,
154        ty1: TypeId,
155    },
156    ConstKindMismatch {
157        const0: ConstValueId,
158        const1: ConstValueId,
159    },
160    ImplKindMismatch {
161        impl0: ImplId,
162        impl1: ImplId,
163    },
164    GenericArgMismatch {
165        garg0: GenericArgumentId,
166        garg1: GenericArgumentId,
167    },
168    TraitMismatch {
169        trt0: TraitId,
170        trt1: TraitId,
171    },
172    ImplTypeMismatch {
173        impl_id: ImplId,
174        trait_type_id: TraitTypeId,
175        ty0: TypeId,
176        ty1: TypeId,
177    },
178    GenericFunctionMismatch {
179        func0: GenericFunctionId,
180        func1: GenericFunctionId,
181    },
182    ConstInferenceNotSupported,
183
184    // TODO(spapini): These are only used for external interface. Separate them along with the
185    // finalize() function to a wrapper.
186    NoImplsFound(ConcreteTraitId),
187    Ambiguity(Ambiguity),
188    TypeNotInferred(TypeId),
189}
190impl InferenceError {
191    pub fn format(&self, db: &(dyn SemanticGroup + 'static)) -> String {
192        match self {
193            InferenceError::Reported(_) => "Inference error occurred.".into(),
194            InferenceError::Cycle(_var) => "Inference cycle detected".into(),
195            InferenceError::TypeKindMismatch { ty0, ty1 } => {
196                format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
197            }
198            InferenceError::ConstKindMismatch { const0, const1 } => {
199                format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
200            }
201            InferenceError::ImplKindMismatch { impl0, impl1 } => {
202                format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
203            }
204            InferenceError::GenericArgMismatch { garg0, garg1 } => {
205                format!(
206                    "Generic arg mismatch: `{:?}` and `{:?}`.",
207                    garg0.debug(db),
208                    garg1.debug(db)
209                )
210            }
211            InferenceError::TraitMismatch { trt0, trt1 } => {
212                format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
213            }
214            InferenceError::ConstInferenceNotSupported => {
215                "Const generic inference not yet supported.".into()
216            }
217            InferenceError::NoImplsFound(concrete_trait_id) => {
218                let info = db.core_info();
219                let trait_id = concrete_trait_id.trait_id(db);
220                if trait_id == info.numeric_literal_trt {
221                    let generic_type = extract_matches!(
222                        concrete_trait_id.generic_args(db)[0],
223                        GenericArgumentId::Type
224                    );
225                    return format!(
226                        "Mismatched types. The type `{:?}` cannot be created from a numeric \
227                         literal.",
228                        generic_type.debug(db)
229                    );
230                } else if trait_id == info.string_literal_trt {
231                    let generic_type = extract_matches!(
232                        concrete_trait_id.generic_args(db)[0],
233                        GenericArgumentId::Type
234                    );
235                    return format!(
236                        "Mismatched types. The type `{:?}` cannot be created from a string \
237                         literal.",
238                        generic_type.debug(db)
239                    );
240                }
241                format!(
242                    "Trait has no implementation in context: {:?}.",
243                    concrete_trait_id.debug(db)
244                )
245            }
246            InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
247            InferenceError::TypeNotInferred(ty) => {
248                format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
249            }
250            InferenceError::GenericFunctionMismatch { func0, func1 } => {
251                format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
252            }
253            InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 } => {
254                format!(
255                    "`{}::{}` type mismatch: `{:?}` and `{:?}`.",
256                    impl_id.format(db.upcast()),
257                    trait_type_id.name(db.upcast()),
258                    ty0.debug(db),
259                    ty1.debug(db)
260                )
261            }
262        }
263    }
264}
265
266impl InferenceError {
267    pub fn report(
268        &self,
269        diagnostics: &mut SemanticDiagnostics,
270        stable_ptr: SyntaxStablePtrId,
271    ) -> DiagnosticAdded {
272        match self {
273            InferenceError::Reported(diagnostic_added) => *diagnostic_added,
274            _ => diagnostics
275                .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
276        }
277    }
278}
279
/// Proof that an inference error was recorded in the `Inference` object.
///
/// It guarantees that every inference error is first set on the inference state and later
/// consumed from it. Do not construct it directly; it is returned by [Inference::set_error].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ErrorSet;
286
287pub type InferenceResult<T> = Result<T, ErrorSet>;
288
/// Status of an inference error that has been set.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum InferenceErrorStatus {
    /// The error was set and has not been consumed yet.
    Pending,
    /// The error was already consumed.
    Consumed,
}
294
295/// A mapping of an impl var's trait items to concrete items
296#[derive(Debug, Default, PartialEq, Eq, Clone, SemanticObject)]
297pub struct ImplVarTraitItemMappings {
298    /// The trait types of the impl var.
299    types: OrderedHashMap<TraitTypeId, TypeId>,
300    /// The trait constants of the impl var.
301    constants: OrderedHashMap<TraitConstantId, ConstValueId>,
302    /// The trait impls of the impl var.
303    impls: OrderedHashMap<TraitImplId, ImplId>,
304}
305impl Hash for ImplVarTraitItemMappings {
306    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
307        self.types.iter().for_each(|(trait_type_id, type_id)| {
308            trait_type_id.hash(state);
309            type_id.hash(state);
310        });
311        self.constants.iter().for_each(|(trait_const_id, const_id)| {
312            trait_const_id.hash(state);
313            const_id.hash(state);
314        });
315        self.impls.iter().for_each(|(trait_impl_id, impl_id)| {
316            trait_impl_id.hash(state);
317            impl_id.hash(state);
318        });
319    }
320}
321
322/// State of inference.
323#[derive(Debug, DebugWithDb, PartialEq, Eq)]
324#[debug_db(dyn SemanticGroup + 'static)]
325pub struct InferenceData {
326    pub inference_id: InferenceId,
327    /// Current inferred assignment for type variables.
328    pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId>,
329    /// Current inferred assignment for const variables.
330    pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId>,
331    /// Current inferred assignment for impl variables.
332    pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId>,
333    /// Unsolved impl variables mapping to a maps of trait items to a corresponding item variable.
334    /// Upon solution of the trait conforms the fully known item to the variable.
335    pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings>,
336    /// Type variables.
337    pub type_vars: Vec<TypeVar>,
338    /// Const variables.
339    pub const_vars: Vec<ConstVar>,
340    /// Impl variables.
341    pub impl_vars: Vec<ImplVar>,
342    /// Mapping from variables to stable pointers, if exist.
343    pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId>,
344    /// Inference variables that are pending to be solved.
345    pending: VecDeque<LocalImplVarId>,
346    /// Inference variables that have been refuted - no solutions exist.
347    refuted: Vec<LocalImplVarId>,
348    /// Inference variables that have been solved.
349    solved: Vec<LocalImplVarId>,
350    /// Inference variables that are currently ambiguous. May be solved later.
351    ambiguous: Vec<(LocalImplVarId, Ambiguity)>,
352    /// Mapping from impl types to type variables.
353    pub impl_type_bounds: Arc<BTreeMap<ImplTypeById, TypeId>>,
354
355    // Error handling members.
356    /// The current error status.
357    pub error_status: Result<(), InferenceErrorStatus>,
358    /// `Some` only when error_state is Err(Pending).
359    error: Option<InferenceError>,
360    /// `Some` only when error_state is Err(Consumed).
361    consumed_error: Option<DiagnosticAdded>,
362}
363impl InferenceData {
364    pub fn new(inference_id: InferenceId) -> Self {
365        Self {
366            inference_id,
367            type_assignment: OrderedHashMap::default(),
368            impl_assignment: OrderedHashMap::default(),
369            const_assignment: OrderedHashMap::default(),
370            impl_vars_trait_item_mappings: HashMap::new(),
371            type_vars: Vec::new(),
372            impl_vars: Vec::new(),
373            const_vars: Vec::new(),
374            stable_ptrs: HashMap::new(),
375            pending: VecDeque::new(),
376            refuted: Vec::new(),
377            solved: Vec::new(),
378            ambiguous: Vec::new(),
379            impl_type_bounds: Default::default(),
380            error_status: Ok(()),
381            error: None,
382            consumed_error: None,
383        }
384    }
385    pub fn inference<'db, 'b: 'db>(&'db mut self, db: &'b dyn SemanticGroup) -> Inference<'db> {
386        Inference::new(db, self)
387    }
388    pub fn clone_with_inference_id(
389        &self,
390        db: &dyn SemanticGroup,
391        inference_id: InferenceId,
392    ) -> InferenceData {
393        let mut inference_id_replacer =
394            InferenceIdReplacer::new(db, self.inference_id, inference_id);
395        Self {
396            inference_id,
397            type_assignment: self
398                .type_assignment
399                .iter()
400                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
401                .collect(),
402            const_assignment: self
403                .const_assignment
404                .iter()
405                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
406                .collect(),
407            impl_assignment: self
408                .impl_assignment
409                .iter()
410                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
411                .collect(),
412            impl_vars_trait_item_mappings: self
413                .impl_vars_trait_item_mappings
414                .iter()
415                .map(|(k, mappings)| {
416                    (
417                        *k,
418                        ImplVarTraitItemMappings {
419                            types: mappings
420                                .types
421                                .iter()
422                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
423                                .collect(),
424                            constants: mappings
425                                .constants
426                                .iter()
427                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
428                                .collect(),
429                            impls: mappings
430                                .impls
431                                .iter()
432                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
433                                .collect(),
434                        },
435                    )
436                })
437                .collect(),
438            type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
439            const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
440            impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
441            stable_ptrs: self.stable_ptrs.clone(),
442            pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
443            refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
444            solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
445            ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
446            // we do not need to rewrite the impl type bounds, as they all should be var free.
447            impl_type_bounds: self.impl_type_bounds.clone(),
448
449            error_status: self.error_status,
450            error: self.error.clone(),
451            consumed_error: self.consumed_error,
452        }
453    }
454    pub fn temporary_clone(&self) -> InferenceData {
455        Self {
456            inference_id: self.inference_id,
457            type_assignment: self.type_assignment.clone(),
458            const_assignment: self.const_assignment.clone(),
459            impl_assignment: self.impl_assignment.clone(),
460            impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
461            type_vars: self.type_vars.clone(),
462            const_vars: self.const_vars.clone(),
463            impl_vars: self.impl_vars.clone(),
464            stable_ptrs: self.stable_ptrs.clone(),
465            pending: self.pending.clone(),
466            refuted: self.refuted.clone(),
467            solved: self.solved.clone(),
468            ambiguous: self.ambiguous.clone(),
469            impl_type_bounds: self.impl_type_bounds.clone(),
470            error_status: self.error_status,
471            error: self.error.clone(),
472            consumed_error: self.consumed_error,
473        }
474    }
475}
476
477/// State of inference. A system of inference constraints.
478pub struct Inference<'db> {
479    db: &'db dyn SemanticGroup,
480    pub data: &'db mut InferenceData,
481}
482
483impl Deref for Inference<'_> {
484    type Target = InferenceData;
485
486    fn deref(&self) -> &Self::Target {
487        self.data
488    }
489}
490impl DerefMut for Inference<'_> {
491    fn deref_mut(&mut self) -> &mut Self::Target {
492        self.data
493    }
494}
495
496impl std::fmt::Debug for Inference<'_> {
497    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
498        let x = self.data.debug(self.db.elongate());
499        write!(f, "{x:?}")
500    }
501}
502
503impl<'db> Inference<'db> {
504    fn new(db: &'db dyn SemanticGroup, data: &'db mut InferenceData) -> Self {
505        Self { db, data }
506    }
507
508    /// Getter for an [ImplVar].
509    fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar {
510        &self.impl_vars[var_id.0]
511    }
512
513    /// Getter for an impl var assignment.
514    pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId> {
515        self.impl_assignment.get(&var_id).copied()
516    }
517
518    /// Getter for a type var assignment.
519    fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId> {
520        self.type_assignment.get(&var_id).copied()
521    }
522
523    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
524    /// Returns a wrapping TypeId.
525    pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeId {
526        let var = self.new_type_var_raw(stable_ptr);
527
528        TypeLongId::Var(var).intern(self.db)
529    }
530
531    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
532    /// Returns the variable id.
533    pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeVar {
534        let var =
535            TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
536        if let Some(stable_ptr) = stable_ptr {
537            self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
538        }
539        self.type_vars.push(var);
540        var
541    }
542
543    /// Sets the infrence's impl type bounds to the given map, and rewrittes the types so all the
544    /// types are var free.
545    pub fn set_impl_type_bounds(&mut self, impl_type_bounds: OrderedHashMap<ImplTypeId, TypeId>) {
546        let impl_type_bounds_finalized = impl_type_bounds
547            .iter()
548            .filter_map(|(impl_type, ty)| {
549                let rewritten_type = self.rewrite(ty.lookup_intern(self.db)).no_err();
550                if !matches!(rewritten_type, TypeLongId::Var(_)) {
551                    return Some(((*impl_type).into(), rewritten_type.intern(self.db)));
552                }
553                // conformed the var type to the original impl type to remove it from the pending
554                // list.
555                self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
556                None
557            })
558            .collect();
559
560        self.data.impl_type_bounds = Arc::new(impl_type_bounds_finalized);
561    }
562
563    /// Allocates a new [ConstVar] for an unknown consts that needs to be inferred.
564    /// Returns a wrapping [ConstValueId].
565    pub fn new_const_var(
566        &mut self,
567        stable_ptr: Option<SyntaxStablePtrId>,
568        ty: TypeId,
569    ) -> ConstValueId {
570        let var = self.new_const_var_raw(stable_ptr);
571        ConstValue::Var(var, ty).intern(self.db)
572    }
573
574    /// Allocates a new [ConstVar] for an unknown type that needs to be inferred.
575    /// Returns the variable id.
576    pub fn new_const_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> ConstVar {
577        let var = ConstVar {
578            inference_id: self.inference_id,
579            id: LocalConstVarId(self.const_vars.len()),
580        };
581        if let Some(stable_ptr) = stable_ptr {
582            self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
583        }
584        self.const_vars.push(var);
585        var
586    }
587
588    /// Allocates a new [ImplVar] for an unknown type that needs to be inferred.
589    /// Returns a wrapping ImplId.
590    pub fn new_impl_var(
591        &mut self,
592        concrete_trait_id: ConcreteTraitId,
593        stable_ptr: Option<SyntaxStablePtrId>,
594        lookup_context: ImplLookupContext,
595    ) -> ImplId {
596        let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
597        ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
598    }
599
600    /// Allocates a new [ImplVar] for an unknown type that needs to be inferred.
601    /// Returns the variable id.
602    fn new_impl_var_raw(
603        &mut self,
604        lookup_context: ImplLookupContext,
605        concrete_trait_id: ConcreteTraitId,
606        stable_ptr: Option<SyntaxStablePtrId>,
607    ) -> LocalImplVarId {
608        let mut lookup_context = lookup_context;
609        lookup_context
610            .insert_module(concrete_trait_id.trait_id(self.db).module_file_id(self.db.upcast()).0);
611
612        let id = LocalImplVarId(self.impl_vars.len());
613        if let Some(stable_ptr) = stable_ptr {
614            self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
615        }
616        let var =
617            ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
618        self.impl_vars.push(var);
619        self.pending.push_back(id);
620        id
621    }
622
623    /// Solves the inference system. After a successful solve, there are no more pending impl
624    /// inferences.
625    /// Returns whether the inference was successful. If not, the error may be found by
626    /// `.error_state()`.
627    pub fn solve(&mut self) -> InferenceResult<()> {
628        self.solve_ex().map_err(|(err_set, _)| err_set)
629    }
630
631    /// Same as `solve`, but returns the error stable pointer if an error occurred.
632    fn solve_ex(&mut self) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
633        let mut ambiguous = std::mem::take(&mut self.ambiguous);
634        self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
635        while let Some(var) = self.pending.pop_front() {
636            // First inference error stops inference.
637            self.solve_single_pending(var).map_err(|err_set| {
638                (err_set, self.stable_ptrs.get(&InferenceVar::Impl(var)).copied())
639            })?;
640        }
641        Ok(())
642    }
643
644    fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
645        if self.impl_assignment.contains_key(&var) {
646            return Ok(());
647        }
648        let solution = match self.impl_var_solution_set(var)? {
649            SolutionSet::None => {
650                self.refuted.push(var);
651                return Ok(());
652            }
653            SolutionSet::Ambiguous(ambiguity) => {
654                self.ambiguous.push((var, ambiguity));
655                return Ok(());
656            }
657            SolutionSet::Unique(solution) => solution,
658        };
659
660        // Solution found. Assign it.
661        self.assign_local_impl(var, solution)?;
662
663        // Something changed.
664        self.solved.push(var);
665        let mut ambiguous = std::mem::take(&mut self.ambiguous);
666        self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
667
668        Ok(())
669    }
670
671    /// Returns the solution set status for the inference:
672    /// Whether there is a unique solution, multiple solutions, no solutions or an error.
673    pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<()>> {
674        self.solve()?;
675        if !self.refuted.is_empty() {
676            return Ok(SolutionSet::None);
677        }
678        if let Some((_, ambiguity)) = self.ambiguous.first() {
679            return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
680        }
681        assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
682        Ok(SolutionSet::Unique(()))
683    }
684
685    /// Finalizes the inference by inferring uninferred numeric literals as felt252.
686    /// Returns an error and does not report it.
687    pub fn finalize_without_reporting(
688        &mut self,
689    ) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
690        if self.error_status.is_err() {
691            // TODO(yuval): consider adding error location to the set error.
692            return Err((ErrorSet, None));
693        }
694        let info = self.db.core_info();
695        let numeric_trait_id = info.numeric_literal_trt;
696        let felt_ty = info.felt252;
697
698        // Conform all uninferred numeric literals to felt252.
699        loop {
700            let mut changed = false;
701            self.solve_ex()?;
702            for (var, _) in self.ambiguous.clone() {
703                let impl_var = self.impl_var(var).clone();
704                if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
705                    continue;
706                }
707                // Uninferred numeric trait. Resolve as felt252.
708                let ty = extract_matches!(
709                    impl_var.concrete_trait_id.generic_args(self.db)[0],
710                    GenericArgumentId::Type
711                );
712                if self.rewrite(ty).no_err() == felt_ty {
713                    continue;
714                }
715                self.conform_ty(ty, felt_ty).map_err(|err_set| {
716                    (err_set, self.stable_ptrs.get(&InferenceVar::Impl(impl_var.id)).copied())
717                })?;
718                changed = true;
719                break;
720            }
721            if !changed {
722                break;
723            }
724        }
725        assert!(
726            self.pending.is_empty(),
727            "pending should all be solved by this point. Guaranteed by solve()."
728        );
729
730        let Some((var, err)) = self.first_undetermined_variable() else {
731            return Ok(());
732        };
733        Err((self.set_error(err), self.stable_ptrs.get(&var).copied()))
734    }
735
736    /// Finalizes the inference and report diagnostics if there are any errors.
737    /// All the remaining type vars are mapped to the `missing` type, to prevent additional
738    /// diagnostics.
739    pub fn finalize(
740        &mut self,
741        diagnostics: &mut SemanticDiagnostics,
742        stable_ptr: SyntaxStablePtrId,
743    ) {
744        if let Err((err_set, err_stable_ptr)) = self.finalize_without_reporting() {
745            let diag = self.report_on_pending_error(
746                err_set,
747                diagnostics,
748                err_stable_ptr.unwrap_or(stable_ptr),
749            );
750
751            let ty_missing = TypeId::missing(self.db, diag);
752            for var in &self.data.type_vars {
753                self.data.type_assignment.entry(var.id).or_insert(ty_missing);
754            }
755        }
756    }
757
758    /// Retrieves the first variable that is still not inferred, or None, if everything is
759    /// inferred.
760    /// Does not set the error but return it, which is ok as this is a private helper function.
761    fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError)> {
762        if let Some(var) = self.refuted.first().copied() {
763            let impl_var = self.impl_var(var).clone();
764            let concrete_trait_id = impl_var.concrete_trait_id;
765            let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
766            return Some((
767                InferenceVar::Impl(var),
768                InferenceError::NoImplsFound(concrete_trait_id),
769            ));
770        }
771        let mut fallback_ret = None;
772        if let Some((var, ambiguity)) = self.ambiguous.first() {
773            // Note: do not rewrite `ambiguity`, since it is expressed in canonical variables.
774            let ret =
775                Some((InferenceVar::Impl(*var), InferenceError::Ambiguity(ambiguity.clone())));
776            if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
777                return ret;
778            } else {
779                fallback_ret = ret;
780            }
781        }
782        for (id, var) in self.type_vars.iter().enumerate() {
783            if self.type_assignment(LocalTypeVarId(id)).is_none() {
784                let ty = TypeLongId::Var(*var).intern(self.db);
785                return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
786            }
787        }
788        fallback_ret
789    }
790
791    /// Assigns a value to a local impl variable id. See assign_impl().
792    fn assign_local_impl(
793        &mut self,
794        var: LocalImplVarId,
795        impl_id: ImplId,
796    ) -> InferenceResult<ImplId> {
797        let concrete_trait = impl_id
798            .concrete_trait(self.db)
799            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
800        self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
801        if let Some(other_impl) = self.impl_assignment(var) {
802            return self.conform_impl(impl_id, other_impl);
803        }
804        if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
805        {
806            return Err(self.set_error(InferenceError::Cycle(InferenceVar::Impl(var))));
807        }
808        self.impl_assignment.insert(var, impl_id);
809        if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
810            for (trait_type_id, ty) in mappings.types {
811                let impl_ty = self
812                    .db
813                    .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_type_id, self.db))
814                    .map_err(|_| ErrorSet)?;
815                if let Err(err_set) = self.conform_ty(ty, impl_ty) {
816                    // Override the error with ImplTypeMismatch.
817                    let ty0 = self.rewrite(ty).no_err();
818                    let ty1 = self.rewrite(impl_ty).no_err();
819
820                    self.error =
821                        Some(InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 });
822                    return Err(err_set);
823                }
824            }
825            for (trait_constant, constant_id) in mappings.constants {
826                self.conform_const(
827                    constant_id,
828                    self.db
829                        .impl_constant_concrete_implized_value(ImplConstantId::new(
830                            impl_id,
831                            trait_constant,
832                            self.db,
833                        ))
834                        .map_err(|_| ErrorSet)?,
835                )?;
836            }
837            for (trait_impl, inner_impl_id) in mappings.impls {
838                self.conform_impl(
839                    inner_impl_id,
840                    self.db
841                        .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
842                        .map_err(|_| ErrorSet)?,
843                )?;
844            }
845        }
846        Ok(impl_id)
847    }
848
849    /// Tries to assigns value to an [ImplVarId]. Return the assigned impl, or an error.
850    fn assign_impl(&mut self, var_id: ImplVarId, impl_id: ImplId) -> InferenceResult<ImplId> {
851        let var = var_id.lookup_intern(self.db);
852        if var.inference_id != self.inference_id {
853            return Err(self.set_error(InferenceError::ImplKindMismatch {
854                impl0: ImplLongId::ImplVar(var_id).intern(self.db),
855                impl1: impl_id,
856            }));
857        }
858        self.assign_local_impl(var.id, impl_id)
859    }
860
861    /// Assigns a value to a [TypeVar]. Return the assigned type, or an error.
862    /// Assumes the variable is not already assigned.
863    fn assign_ty(&mut self, var: TypeVar, ty: TypeId) -> InferenceResult<TypeId> {
864        if var.inference_id != self.inference_id {
865            return Err(self.set_error(InferenceError::TypeKindMismatch {
866                ty0: TypeLongId::Var(var).intern(self.db),
867                ty1: ty,
868            }));
869        }
870        assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
871        let inference_var = InferenceVar::Type(var.id);
872        if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
873            return Err(self.set_error(InferenceError::Cycle(inference_var)));
874        }
875        // If assigning var to var - making sure assigning to the lower id for proper canonization.
876        if let TypeLongId::Var(other) = ty.lookup_intern(self.db) {
877            if other.inference_id == self.inference_id && other.id.0 > var.id.0 {
878                let var_ty = TypeLongId::Var(var).intern(self.db);
879                self.type_assignment.insert(other.id, var_ty);
880                return Ok(var_ty);
881            }
882        }
883        self.type_assignment.insert(var.id, ty);
884        Ok(ty)
885    }
886
887    /// Assigns a value to a [ConstVar]. Return the assigned const, or an error.
888    /// Assumes the variable is not already assigned.
889    fn assign_const(&mut self, var: ConstVar, id: ConstValueId) -> InferenceResult<ConstValueId> {
890        if var.inference_id != self.inference_id {
891            return Err(self.set_error(InferenceError::ConstKindMismatch {
892                const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
893                    .intern(self.db),
894                const1: id,
895            }));
896        }
897
898        self.const_assignment.insert(var.id, id);
899        Ok(id)
900    }
901
902    /// Computes the solution set for an impl variable with a recursive query.
903    fn impl_var_solution_set(
904        &mut self,
905        var: LocalImplVarId,
906    ) -> InferenceResult<SolutionSet<ImplId>> {
907        let impl_var = self.impl_var(var).clone();
908        // Update the concrete trait of the impl var.
909        let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
910        self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
911        let impl_var_trait_item_mappings =
912            self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
913        let solution_set = self.trait_solution_set(
914            concrete_trait_id,
915            impl_var_trait_item_mappings,
916            impl_var.lookup_context,
917        )?;
918        Ok(match solution_set {
919            SolutionSet::None => SolutionSet::None,
920            SolutionSet::Unique((canonical_impl, canonicalizer)) => {
921                SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
922            }
923            SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
924        })
925    }
926
927    /// Computes the solution set for a trait with a recursive query.
928    pub fn trait_solution_set(
929        &mut self,
930        concrete_trait_id: ConcreteTraitId,
931        impl_var_trait_item_mappings: ImplVarTraitItemMappings,
932        mut lookup_context: ImplLookupContext,
933    ) -> InferenceResult<SolutionSet<(CanonicalImpl, CanonicalMapping)>> {
934        let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
935        // TODO(spapini): This is done twice. Consider doing it only here.
936        let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
937        enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
938
939        // Don't try to resolve impls if the first generic param is a variable.
940        let generic_args = concrete_trait_id.generic_args(self.db);
941        match generic_args.first() {
942            Some(GenericArgumentId::Type(ty)) => {
943                if let TypeLongId::Var(_) = ty.lookup_intern(self.db) {
944                    // Don't try to infer such impls.
945                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
946                }
947            }
948            Some(GenericArgumentId::Impl(imp)) => {
949                // Don't try to infer such impls.
950                if let ImplLongId::ImplVar(_) = imp.lookup_intern(self.db) {
951                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
952                }
953            }
954            Some(GenericArgumentId::Constant(const_value)) => {
955                if let ConstValue::Var(_, _) = const_value.lookup_intern(self.db) {
956                    // Don't try to infer such impls.
957                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
958                }
959            }
960            _ => {}
961        };
962        let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
963            self.db,
964            self.inference_id,
965            concrete_trait_id,
966            impl_var_trait_item_mappings,
967        );
968        // impl_type_bounds order is deterimend by the generic params of the function and therefore
969        // is consistent.
970        let solution_set = match self.db.canonic_trait_solutions(
971            canonical_trait,
972            lookup_context,
973            (*self.data.impl_type_bounds).clone(),
974        ) {
975            Ok(solution_set) => solution_set,
976            Err(err) => return Err(self.set_error(err)),
977        };
978        match solution_set {
979            SolutionSet::None => Ok(SolutionSet::None),
980            SolutionSet::Unique(canonical_impl) => {
981                Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
982            }
983            SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
984        }
985    }
986
    /// Validates the given candidate impl against the negative-impl generic params of its
    /// definition (for concrete impls) or of the generated impl.
    ///
    /// Returns:
    /// - `SolutionSet::Unique(canonical_impl)` if no negative-impl trait has a solution (or the
    ///   impl kind has no negative-impl params to check).
    /// - `SolutionSet::Ambiguous(NegativeImplWithUnresolvedGenericArgs)` if a negative-impl trait
    ///   has a type argument that is not yet fully concrete (closures excepted).
    /// - `SolutionSet::None` if some negative-impl trait does have a solution, i.e. the candidate
    ///   is ruled out.
    fn validate_neg_impls(
        &mut self,
        lookup_context: &ImplLookupContext,
        canonical_impl: CanonicalImpl,
    ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
        /// Checks that none of the given negative-impl concrete traits has a solution.
        /// See `validate_neg_impls` for the meaning of the returned solution set.
        fn validate_no_solution_set(
            inference: &mut Inference<'_>,
            canonical_impl: CanonicalImpl,
            lookup_context: &ImplLookupContext,
            negative_impls_concrete_traits: impl Iterator<Item = Maybe<ConcreteTraitId>>,
        ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
            for concrete_trait_id in negative_impls_concrete_traits {
                let concrete_trait_id = concrete_trait_id.map_err(|diag_added| {
                    inference.set_error(InferenceError::Reported(diag_added))
                })?;
                for garg in concrete_trait_id.generic_args(inference.db) {
                    let GenericArgumentId::Type(ty) = garg else {
                        continue;
                    };
                    let ty = inference.rewrite(ty).no_err();
                    // If the negative impl has a generic argument that is not fully
                    // concrete we can't tell if we should rule out the candidate impl.
                    // For example if we have -TypeEqual<S, T> we can't tell if S and
                    // T are going to be assigned the same concrete type.
                    // We return `SolutionSet::Ambiguous` here to indicate that more
                    // information is needed.
                    // A closure can only have one type, even if it's not fully concrete, so it
                    // can be used without causing ambiguity.
                    if !matches!(ty.lookup_intern(inference.db), TypeLongId::Closure(_))
                        && !ty.is_fully_concrete(inference.db)
                    {
                        // TODO(ilya): Try to detect the ambiguity earlier in the
                        // inference process.
                        return Ok(SolutionSet::Ambiguous(
                            Ambiguity::NegativeImplWithUnresolvedGenericArgs {
                                impl_id: canonical_impl.0,
                                ty,
                            },
                        ));
                    }
                }

                if !matches!(
                    inference.trait_solution_set(
                        concrete_trait_id,
                        ImplVarTraitItemMappings::default(),
                        lookup_context.clone()
                    )?,
                    SolutionSet::None
                ) {
                    // The negative-impl trait has an impl (or may have one), so the candidate
                    // impl is ruled out.
                    return Ok(SolutionSet::None);
                }
            }

            Ok(SolutionSet::Unique(canonical_impl))
        }
        match canonical_impl.0.lookup_intern(self.db) {
            ImplLongId::Concrete(concrete_impl) => {
                // Negative-impl params come from the impl definition and are substituted with
                // the concrete impl's generic args before being checked.
                let substitution = concrete_impl
                    .substitution(self.db)
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                let generic_params = self
                    .db
                    .impl_def_generic_params(concrete_impl.impl_def_id(self.db))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                let concrete_traits = generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| {
                        substitution
                            .substitute(self.db, generic_param.clone())
                            .and_then(|generic_param| generic_param.concrete_trait)
                    });
                validate_no_solution_set(self, canonical_impl, lookup_context, concrete_traits)
            }
            // Generated impls carry their own generic params; their negative impls are used as is.
            ImplLongId::GeneratedImpl(generated_impl) => validate_no_solution_set(
                self,
                canonical_impl,
                lookup_context,
                generated_impl
                    .lookup_intern(self.db)
                    .generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| generic_param.concrete_trait),
            ),
            // Other impl kinds have no negative-impl params to validate.
            ImplLongId::GenericParameter(_)
            | ImplLongId::ImplVar(_)
            | ImplLongId::ImplImpl(_)
            | ImplLongId::SelfImpl(_) => Ok(SolutionSet::Unique(canonical_impl)),
        }
    }
1088
1089    // Error handling methods
1090    // ======================
1091
1092    /// Sets an error in the inference state.
1093    /// Does nothing if an error is already set.
1094    /// Returns an `ErrorSet` that can be used in reporting the error.
1095    pub fn set_error(&mut self, err: InferenceError) -> ErrorSet {
1096        if self.error_status.is_err() {
1097            return ErrorSet;
1098        }
1099        self.error_status = if let InferenceError::Reported(diag_added) = err {
1100            self.consumed_error = Some(diag_added);
1101            Err(InferenceErrorStatus::Consumed)
1102        } else {
1103            self.error = Some(err);
1104            Err(InferenceErrorStatus::Pending)
1105        };
1106        ErrorSet
1107    }
1108
1109    /// Returns whether an error is set (either pending or consumed).
1110    pub fn is_error_set(&self) -> InferenceResult<()> {
1111        if self.error_status.is_err() { Err(ErrorSet) } else { Ok(()) }
1112    }
1113
1114    /// Consumes the error but doesn't report it. If there is no error, or the error is consumed,
1115    /// returns None. This should be used with caution. Always prefer to use
1116    /// (1) `report_on_pending_error` if possible, or (2) `consume_reported_error` which is safer.
1117    ///
1118    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1119    pub fn consume_error_without_reporting(&mut self, err_set: ErrorSet) -> Option<InferenceError> {
1120        self.consume_error_inner(err_set, skip_diagnostic())
1121    }
1122
1123    /// Consumes the error that is already reported. If there is no error, or the error is consumed,
1124    /// does nothing. This should be used with caution. Always prefer to use
1125    /// `report_on_pending_error` if possible.
1126    ///
1127    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1128    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1129    pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1130        self.consume_error_inner(err_set, diag_added);
1131    }
1132
1133    /// Consumes the error and returns it, but doesn't report it. If there is no error, or the error
1134    /// is already consumed, returns None. This should be used with caution. Always prefer to use
1135    /// `report_on_pending_error` if possible.
1136    ///
1137    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1138    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1139    fn consume_error_inner(
1140        &mut self,
1141        _err_set: ErrorSet,
1142        diag_added: DiagnosticAdded,
1143    ) -> Option<InferenceError> {
1144        if self.error_status != Err(InferenceErrorStatus::Pending) {
1145            return None;
1146            // panic!("consume_error when there is no pending error");
1147        }
1148        self.error_status = Err(InferenceErrorStatus::Consumed);
1149        self.consumed_error = Some(diag_added);
1150        mem::take(&mut self.error)
1151    }
1152
1153    /// Consumes the pending error, if any, and reports it.
1154    /// Should only be called when an error is set, otherwise it panics.
1155    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1156    /// If an error was set but it's already consumed, it doesn't report it again but returns the
1157    /// stored `DiagnosticAdded`.
1158    pub fn report_on_pending_error(
1159        &mut self,
1160        _err_set: ErrorSet,
1161        diagnostics: &mut SemanticDiagnostics,
1162        stable_ptr: SyntaxStablePtrId,
1163    ) -> DiagnosticAdded {
1164        let Err(state_error) = self.error_status else {
1165            panic!("report_on_pending_error should be called only on error");
1166        };
1167        match state_error {
1168            InferenceErrorStatus::Consumed => self
1169                .consumed_error
1170                .expect("consumed_error is not set although error_status is Err(Consumed)"),
1171            InferenceErrorStatus::Pending => {
1172                let diag_added = match mem::take(&mut self.error)
1173                    .expect("error is not set although error_status is Err(Pending)")
1174                {
1175                    InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1176                        // If we have other diagnostics, there is no need to TypeNotInferred.
1177
1178                        // Note that `diagnostics` is not empty, so it is safe to return
1179                        // 'DiagnosticAdded' here.
1180                        skip_diagnostic()
1181                    }
1182                    diag => diag.report(diagnostics, stable_ptr),
1183                };
1184
1185                self.error_status = Err(InferenceErrorStatus::Consumed);
1186                self.consumed_error = Some(diag_added);
1187                diag_added
1188            }
1189        }
1190    }
1191
1192    /// If the current status is of a pending error, reports an alternative diagnostic, by calling
1193    /// `report`, and consumes the error. Otherwise, does nothing.
1194    pub fn report_modified_if_pending(
1195        &mut self,
1196        err_set: ErrorSet,
1197        report: impl FnOnce() -> DiagnosticAdded,
1198    ) {
1199        if self.error_status == Err(InferenceErrorStatus::Pending) {
1200            self.consume_reported_error(err_set, report());
1201        }
1202    }
1203}
1204
/// Exposes the database used by this inference through the `HasDb` trait.
impl<'a> HasDb<&'a dyn SemanticGroup> for Inference<'a> {
    fn get_db(&self) -> &'a dyn SemanticGroup {
        self.db
    }
}
// Rewriter impls for `Inference`, presumably generated for the remaining semantic types.
// The types after `@exclude` (`TypeLongId`, `TypeId`, `ImplLongId`, `ImplId`, `ConstValue`) have
// hand-written `SemanticRewriter` impls below, which substitute this inference's assignments.
add_basic_rewrites!(<'a>, Inference<'a>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue);
add_expr_rewrites!(<'a>, Inference<'a>, NoError, @exclude);
add_rewrite!(<'a>, Inference<'a>, NoError, Ambiguity);
1213impl SemanticRewriter<TypeId, NoError> for Inference<'_> {
1214    fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
1215        if value.is_var_free(self.db) {
1216            return Ok(RewriteResult::NoChange);
1217        }
1218        value.default_rewrite(self)
1219    }
1220}
1221impl SemanticRewriter<ImplId, NoError> for Inference<'_> {
1222    fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
1223        if value.is_var_free(self.db) {
1224            return Ok(RewriteResult::NoChange);
1225        }
1226        value.default_rewrite(self)
1227    }
1228}
/// Rewrites a [TypeLongId] by substituting everything this inference knows about it:
/// assigned type variables, bounded impl types, and impl types whose impl can be resolved.
/// Anything else falls back to the default structural rewrite.
impl SemanticRewriter<TypeLongId, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
        match value {
            TypeLongId::Var(var) => {
                // An assigned variable is replaced by its (recursively rewritten) assignment.
                // Unassigned variables fall through to the default rewrite below.
                if let Some(type_id) = self.type_assignment.get(&var.id) {
                    let mut long_type_id = type_id.lookup_intern(self.db);
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
                        // Store the fully rewritten type back, so later lookups of this variable
                        // don't need to follow the chain again.
                        *self.type_assignment.get_mut(&var.id).unwrap() =
                            long_type_id.clone().intern(self.db);
                    }
                    *value = long_type_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            TypeLongId::ImplType(impl_type_id) => {
                // An impl type with a known bound is replaced by the (rewritten) bound type.
                if let Some(type_id) = self.impl_type_bounds.get(&((*impl_type_id).into())) {
                    *value = type_id.lookup_intern(self.db);
                    self.internal_rewrite(value)?;
                    return Ok(RewriteResult::Modified);
                }
                let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
                let impl_id = impl_type_id.impl_id();
                let trait_ty = impl_type_id.ty();
                // Resolve the type according to the kind of the (already rewritten) impl.
                return Ok(match impl_id.lookup_intern(self.db) {
                    // Cannot be resolved further; keep the rewritten impl type.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_type_id_rewrite_result,
                    // Use the concrete impl's type, if the query succeeds.
                    ImplLongId::Concrete(_) => {
                        if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
                            impl_id, trait_ty, self.db,
                        )) {
                            *value = self.rewrite(ty).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_type_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        *value = self.rewritten_impl_type(var, trait_ty).lookup_intern(self.db);
                        return Ok(RewriteResult::Modified);
                    }
                    // Take the type from the generated impl's items.
                    // NOTE(review): the `unwrap` assumes a generated impl always defines every
                    // trait type — confirm this invariant.
                    ImplLongId::GeneratedImpl(generated) => {
                        *value = self
                            .rewrite(
                                *generated
                                    .lookup_intern(self.db)
                                    .impl_items
                                    .0
                                    .get(&impl_type_id.ty())
                                    .unwrap(),
                            )
                            .no_err()
                            .lookup_intern(self.db);
                        RewriteResult::Modified
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites a [ConstValue] by substituting assigned const variables and impl constants whose
/// impl can be resolved. Anything else falls back to the default structural rewrite.
impl SemanticRewriter<ConstValue, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
        match value {
            // NOTE(review): an unassigned const variable returns `NoChange` here without reaching
            // `default_rewrite`, so the type stored in the variable is not rewritten — confirm
            // this is intended.
            ConstValue::Var(var, _) => {
                return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
                    let mut const_value = const_value_id.lookup_intern(self.db);
                    if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
                        // Store the fully rewritten value back, so later lookups of this variable
                        // don't need to follow the chain again.
                        *self.const_assignment.get_mut(&var.id).unwrap() =
                            const_value.clone().intern(self.db);
                    }
                    *value = const_value;
                    RewriteResult::Modified
                } else {
                    RewriteResult::NoChange
                });
            }
            ConstValue::ImplConstant(impl_constant_id) => {
                let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
                let impl_id = impl_constant_id.impl_id();
                let trait_constant = impl_constant_id.trait_constant_id();
                // Resolve the constant according to the kind of the (already rewritten) impl.
                return Ok(match impl_id.lookup_intern(self.db) {
                    // Not resolved further; keep the rewritten impl constant.
                    // NOTE(review): unlike the `TypeLongId` rewriter, `GeneratedImpl` is not
                    // resolved here — confirm generated impls never carry constants.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_constant_id_rewrite_result,
                    // Use the concrete impl's constant value, if the query succeeds.
                    ImplLongId::Concrete(_) => {
                        if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
                            ImplConstantId::new(impl_id, trait_constant, self.db),
                        ) {
                            *value = self.rewrite(constant).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_constant_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        *value = self
                            .rewritten_impl_constant(var, trait_constant)
                            .lookup_intern(self.db);
                        return Ok(RewriteResult::Modified);
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites an [ImplLongId] by substituting assigned impl variables and impl-impls whose parent
/// impl can be resolved. Anything else that still contains variables falls back to the default
/// structural rewrite.
impl SemanticRewriter<ImplLongId, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
        match value {
            ImplLongId::ImplVar(var) => {
                let long_id = var.lookup_intern(self.db);
                // An assigned impl variable is replaced by its (recursively rewritten)
                // assignment. Unassigned variables fall through to the checks below.
                let impl_var_id = long_id.id;
                if let Some(impl_id) = self.impl_assignment(impl_var_id) {
                    let mut long_impl_id = impl_id.lookup_intern(self.db);
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
                        // Store the fully rewritten impl back, so later lookups of this variable
                        // don't need to follow the chain again.
                        *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
                            long_impl_id.clone().intern(self.db);
                    }
                    *value = long_impl_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            ImplLongId::ImplImpl(impl_impl_id) => {
                let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
                let impl_id = impl_impl_id.impl_id();
                // Resolve the impl-impl according to the kind of its (already rewritten) impl.
                return Ok(match impl_id.lookup_intern(self.db) {
                    // Not resolved further; keep the rewritten impl-impl.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_impl_id_rewrite_result,
                    // Use the concrete impl's inner impl, if the query succeeds.
                    ImplLongId::Concrete(_) => {
                        if let Ok(imp) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
                            *value = self.rewrite(imp).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        if let Ok(concrete_trait_impl) =
                            impl_impl_id.concrete_trait_impl_id(self.db)
                        {
                            *value = self
                                .rewritten_impl_impl(var, concrete_trait_impl)
                                .lookup_intern(self.db);
                            return Ok(RewriteResult::Modified);
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                });
            }

            _ => {}
        }
        // Nothing to substitute in a variable-free impl.
        if value.is_var_free(self.db) {
            return Ok(RewriteResult::NoChange);
        }
        value.default_rewrite(self)
    }
}
1395
/// A semantic rewriter that replaces every occurrence of `from_inference_id` with
/// `to_inference_id`, leaving other inference ids untouched.
struct InferenceIdReplacer<'a> {
    db: &'a dyn SemanticGroup,
    /// The inference id to be replaced.
    from_inference_id: InferenceId,
    /// The inference id written in its place.
    to_inference_id: InferenceId,
}
1401impl<'a> InferenceIdReplacer<'a> {
1402    fn new(
1403        db: &'a dyn SemanticGroup,
1404        from_inference_id: InferenceId,
1405        to_inference_id: InferenceId,
1406    ) -> Self {
1407        Self { db, from_inference_id, to_inference_id }
1408    }
1409}
/// Exposes the replacer's database through the `HasDb` trait.
impl<'a> HasDb<&'a dyn SemanticGroup> for InferenceIdReplacer<'a> {
    fn get_db(&self) -> &'a dyn SemanticGroup {
        self.db
    }
}
// Rewriter impls for `InferenceIdReplacer`, presumably generated for the semantic types.
// `InferenceId` is excluded from the basic rewrites because it has the hand-written
// `SemanticRewriter` impl below, which performs the actual replacement.
add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity);
1418impl SemanticRewriter<InferenceId, NoError> for InferenceIdReplacer<'_> {
1419    fn internal_rewrite(&mut self, value: &mut InferenceId) -> Result<RewriteResult, NoError> {
1420        if value == &self.from_inference_id {
1421            *value = self.to_inference_id;
1422            Ok(RewriteResult::Modified)
1423        } else {
1424            Ok(RewriteResult::NoChange)
1425        }
1426    }
1427}