cairo_lang_semantic/expr/inference/
solver.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use cairo_lang_debug::DebugWithDb;
5use cairo_lang_defs::ids::LanguageElementId;
6use cairo_lang_proc_macros::SemanticObject;
7use cairo_lang_utils::LookupIntern;
8use itertools::Itertools;
9
10use super::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, MapperError, ResultNoErrEx};
11use super::conform::InferenceConform;
12use super::infers::InferenceEmbeddings;
13use super::{
14    ImplVarTraitItemMappings, InferenceData, InferenceError, InferenceId, InferenceResult,
15    InferenceVar, LocalImplVarId,
16};
17use crate::db::SemanticGroup;
18use crate::items::constant::ImplConstantId;
19use crate::items::imp::{
20    ImplId, ImplImplId, ImplLookupContext, UninferredImpl, find_candidates_at_context,
21    find_closure_generated_candidate,
22};
23use crate::substitution::SemanticRewriter;
24use crate::types::{ImplTypeById, ImplTypeId};
25use crate::{ConcreteTraitId, GenericArgumentId, TypeId, TypeLongId};
26
27/// A generic solution set for an inference constraint system.
28#[derive(Clone, PartialEq, Eq, Debug)]
29pub enum SolutionSet<T> {
30    None,
31    Unique(T),
32    Ambiguous(Ambiguity),
33}
34
35/// Describes the kinds of inference ambiguities.
36#[derive(Clone, Debug, Eq, Hash, PartialEq, SemanticObject)]
37pub enum Ambiguity {
38    MultipleImplsFound {
39        concrete_trait_id: ConcreteTraitId,
40        impls: Vec<ImplId>,
41    },
42    FreeVariable {
43        impl_id: ImplId,
44        #[dont_rewrite]
45        var: InferenceVar,
46    },
47    WillNotInfer(ConcreteTraitId),
48    NegativeImplWithUnresolvedGenericArgs {
49        impl_id: ImplId,
50        ty: TypeId,
51    },
52}
53impl Ambiguity {
54    pub fn format(&self, db: &(dyn SemanticGroup + 'static)) -> String {
55        match self {
56            Ambiguity::MultipleImplsFound { concrete_trait_id, impls } => {
57                let impls_str =
58                    impls.iter().map(|imp| format!("`{}`", imp.format(db.upcast()))).join(", ");
59                format!(
60                    "Trait `{:?}` has multiple implementations, in: {impls_str}",
61                    concrete_trait_id.debug(db)
62                )
63            }
64            Ambiguity::FreeVariable { impl_id, var: _ } => {
65                format!("Candidate impl {:?} has an unused generic parameter.", impl_id.debug(db),)
66            }
67            Ambiguity::WillNotInfer(concrete_trait_id) => {
68                format!(
69                    "Cannot infer trait {:?}. First generic argument must be known.",
70                    concrete_trait_id.debug(db)
71                )
72            }
73            Ambiguity::NegativeImplWithUnresolvedGenericArgs { impl_id, ty } => format!(
74                "Cannot infer negative impl in `{}` as it contains the unresolved type `{}`",
75                impl_id.format(db),
76                ty.format(db)
77            ),
78        }
79    }
80}
81
82/// Query implementation of [SemanticGroup::canonic_trait_solutions].
83/// Assumes the lookup context is already enriched by [enrich_lookup_context].
84pub fn canonic_trait_solutions(
85    db: &dyn SemanticGroup,
86    canonical_trait: CanonicalTrait,
87    lookup_context: ImplLookupContext,
88    impl_type_bounds: BTreeMap<ImplTypeById, TypeId>,
89) -> Result<SolutionSet<CanonicalImpl>, InferenceError> {
90    let mut concrete_trait_id = canonical_trait.id;
91    let impl_type_bounds = Arc::new(impl_type_bounds);
92    // If the trait is not fully concrete, we might be able to use the trait's items to find a
93    // more concrete trait.
94    if !concrete_trait_id.is_fully_concrete(db) {
95        let mut solver =
96            Solver::new(db, canonical_trait, lookup_context.clone(), impl_type_bounds.clone());
97        match solver.solution_set(db) {
98            SolutionSet::None => {}
99            SolutionSet::Unique(imp) => {
100                concrete_trait_id =
101                    imp.0.concrete_trait(db).expect("A solved impl must have a concrete trait");
102            }
103            SolutionSet::Ambiguous(ambiguity) => {
104                return Ok(SolutionSet::Ambiguous(ambiguity));
105            }
106        }
107    }
108    // Solve the trait without the trait items, so we'd be able to find conflicting impls.
109    let mut solver = Solver::new(
110        db,
111        CanonicalTrait { id: concrete_trait_id, mappings: ImplVarTraitItemMappings::default() },
112        lookup_context,
113        impl_type_bounds,
114    );
115
116    Ok(solver.solution_set(db))
117}
118
119/// Cycle handling for [canonic_trait_solutions].
120pub fn canonic_trait_solutions_cycle(
121    _db: &dyn SemanticGroup,
122    _cycle: &salsa::Cycle,
123    _canonical_trait: &CanonicalTrait,
124    _lookup_context: &ImplLookupContext,
125    _impl_type_bounds: &BTreeMap<ImplTypeById, TypeId>,
126) -> Result<SolutionSet<CanonicalImpl>, InferenceError> {
127    Err(InferenceError::Cycle(InferenceVar::Impl(LocalImplVarId(0))))
128}
129
130/// Adds the defining module of the trait and the generic arguments to the lookup context.
131pub fn enrich_lookup_context(
132    db: &dyn SemanticGroup,
133    concrete_trait_id: ConcreteTraitId,
134    lookup_context: &mut ImplLookupContext,
135) {
136    lookup_context.insert_module(concrete_trait_id.trait_id(db).module_file_id(db.upcast()).0);
137    let generic_args = concrete_trait_id.generic_args(db);
138    // Add the defining module of the generic args to the lookup.
139    for generic_arg in &generic_args {
140        if let GenericArgumentId::Type(ty) = generic_arg {
141            match ty.lookup_intern(db) {
142                TypeLongId::Concrete(concrete) => {
143                    lookup_context
144                        .insert_module(concrete.generic_type(db).module_file_id(db.upcast()).0);
145                }
146                TypeLongId::Coupon(function_id) => {
147                    if let Some(module_file_id) =
148                        function_id.get_concrete(db).generic_function.module_file_id(db)
149                    {
150                        lookup_context.insert_module(module_file_id.0);
151                    }
152                }
153                TypeLongId::ImplType(impl_type_id) => {
154                    lookup_context.insert_impl(impl_type_id.impl_id());
155                }
156                _ => (),
157            }
158        }
159    }
160}
161
162/// A canonical trait solver.
163#[derive(Debug)]
164pub struct Solver {
165    pub canonical_trait: CanonicalTrait,
166    pub lookup_context: ImplLookupContext,
167    candidate_solvers: Vec<CandidateSolver>,
168}
169impl Solver {
170    fn new(
171        db: &dyn SemanticGroup,
172        canonical_trait: CanonicalTrait,
173        lookup_context: ImplLookupContext,
174        impl_type_bounds: Arc<BTreeMap<ImplTypeById, TypeId>>,
175    ) -> Self {
176        let filter = canonical_trait.id.filter(db);
177        let mut candidates =
178            find_candidates_at_context(db, &lookup_context, &filter).unwrap_or_default();
179        find_closure_generated_candidate(db, canonical_trait.id)
180            .map(|candidate| candidates.insert(candidate));
181        let candidate_solvers = candidates
182            .into_iter()
183            .filter_map(|candidate| {
184                CandidateSolver::new(
185                    db,
186                    &canonical_trait,
187                    candidate,
188                    &lookup_context,
189                    impl_type_bounds.clone(),
190                )
191                .ok()
192            })
193            .collect();
194
195        Self { canonical_trait, lookup_context, candidate_solvers }
196    }
197
198    pub fn solution_set(&mut self, db: &dyn SemanticGroup) -> SolutionSet<CanonicalImpl> {
199        let mut unique_solution: Option<CanonicalImpl> = None;
200        for candidate_solver in &mut self.candidate_solvers {
201            let Ok(candidate_solution_set) = candidate_solver.solution_set(db) else {
202                continue;
203            };
204
205            let candidate_solution = match candidate_solution_set {
206                SolutionSet::None => continue,
207                SolutionSet::Unique(candidate_solution) => candidate_solution,
208                SolutionSet::Ambiguous(ambiguity) => return SolutionSet::Ambiguous(ambiguity),
209            };
210            if let Some(unique_solution) = unique_solution {
211                // There might be multiple unique solutions from different candidates that are
212                // solved to the same impl id (e.g. finding it near the trait, and
213                // through an impl alias). This is valid.
214                if unique_solution.0 != candidate_solution.0 {
215                    return SolutionSet::Ambiguous(Ambiguity::MultipleImplsFound {
216                        concrete_trait_id: self.canonical_trait.id,
217                        impls: vec![unique_solution.0, candidate_solution.0],
218                    });
219                }
220            }
221            unique_solution = Some(candidate_solution);
222        }
223        unique_solution.map(SolutionSet::Unique).unwrap_or(SolutionSet::None)
224    }
225}
226
227/// A solver for a candidate to a canonical trait.
228#[derive(Debug)]
229pub struct CandidateSolver {
230    pub candidate: UninferredImpl,
231    inference_data: InferenceData,
232    canonical_embedding: CanonicalMapping,
233    candidate_impl: ImplId,
234    pub lookup_context: ImplLookupContext,
235}
236impl CandidateSolver {
237    fn new(
238        db: &dyn SemanticGroup,
239        canonical_trait: &CanonicalTrait,
240        candidate: UninferredImpl,
241        lookup_context: &ImplLookupContext,
242        impl_type_bounds: Arc<BTreeMap<ImplTypeById, TypeId>>,
243    ) -> InferenceResult<CandidateSolver> {
244        let mut inference_data: InferenceData = InferenceData::new(InferenceId::Canonical);
245        let mut inference = inference_data.inference(db);
246        inference.data.impl_type_bounds = impl_type_bounds.clone();
247        let (canonical_trait, canonical_embedding) = canonical_trait.embed(&mut inference);
248
249        // If the closure params are not var free, we cannot infer the negative impl.
250        // We use the canonical trait concretize the closure params.
251        if let UninferredImpl::GeneratedImpl(imp) = candidate {
252            inference.conform_traits(imp.lookup_intern(db).concrete_trait, canonical_trait.id)?;
253        }
254
255        // Add the defining module of the candidate to the lookup.
256        let mut lookup_context = lookup_context.clone();
257        lookup_context.insert_lookup_scope(db, &candidate);
258        // Instantiate the candidate in the inference table.
259        let candidate_impl =
260            inference.infer_impl(candidate, canonical_trait.id, &lookup_context, None)?;
261        for (trait_type, ty) in canonical_trait.mappings.types.iter() {
262            let mapped_ty =
263                inference.reduce_impl_ty(ImplTypeId::new(candidate_impl, *trait_type, db))?;
264
265            // Conform the candidate's type to the trait's type.
266            inference.conform_ty(mapped_ty, *ty)?;
267        }
268        for (trait_const, const_id) in canonical_trait.mappings.constants.iter() {
269            let mapped_const_id = inference.reduce_impl_constant(ImplConstantId::new(
270                candidate_impl,
271                *trait_const,
272                db,
273            ))?;
274            // Conform the candidate's constant to the trait's constant.
275            inference.conform_const(mapped_const_id, *const_id)?;
276        }
277
278        for (trait_impl, impl_id) in canonical_trait.mappings.impls.iter() {
279            let mapped_impl_id =
280                inference.reduce_impl_impl(ImplImplId::new(candidate_impl, *trait_impl, db))?;
281            // Conform the candidate's impl to the trait's impl.
282            inference.conform_impl(mapped_impl_id, *impl_id)?;
283        }
284
285        Ok(CandidateSolver {
286            candidate,
287            inference_data,
288            canonical_embedding,
289            candidate_impl,
290            lookup_context,
291        })
292    }
293    fn solution_set(
294        &mut self,
295        db: &dyn SemanticGroup,
296    ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
297        let mut inference = self.inference_data.inference(db);
298        let solution_set = inference.solution_set()?;
299        Ok(match solution_set {
300            SolutionSet::None => SolutionSet::None,
301            SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
302            SolutionSet::Unique(_) => {
303                let candidate_impl = inference.rewrite(self.candidate_impl).no_err();
304                match CanonicalImpl::canonicalize(db, candidate_impl, &self.canonical_embedding) {
305                    Ok(canonical_impl) => {
306                        inference.validate_neg_impls(&self.lookup_context, canonical_impl)?
307                    }
308                    Err(MapperError(var)) => {
309                        return Ok(SolutionSet::Ambiguous(Ambiguity::FreeVariable {
310                            impl_id: candidate_impl,
311                            var,
312                        }));
313                    }
314                }
315            }
316        })
317    }
318}