cairo_lang_semantic/expr/inference/
solver.rs1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use cairo_lang_debug::DebugWithDb;
5use cairo_lang_defs::ids::LanguageElementId;
6use cairo_lang_proc_macros::SemanticObject;
7use cairo_lang_utils::LookupIntern;
8use itertools::Itertools;
9
10use super::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, MapperError, ResultNoErrEx};
11use super::conform::InferenceConform;
12use super::infers::InferenceEmbeddings;
13use super::{
14 ImplVarTraitItemMappings, InferenceData, InferenceError, InferenceId, InferenceResult,
15 InferenceVar, LocalImplVarId,
16};
17use crate::db::SemanticGroup;
18use crate::items::constant::ImplConstantId;
19use crate::items::imp::{
20 ImplId, ImplImplId, ImplLookupContext, UninferredImpl, find_candidates_at_context,
21 find_closure_generated_candidate,
22};
23use crate::substitution::SemanticRewriter;
24use crate::types::{ImplTypeById, ImplTypeId};
25use crate::{ConcreteTraitId, GenericArgumentId, TypeId, TypeLongId};
26
27#[derive(Clone, PartialEq, Eq, Debug)]
29pub enum SolutionSet<T> {
30 None,
31 Unique(T),
32 Ambiguous(Ambiguity),
33}
34
35#[derive(Clone, Debug, Eq, Hash, PartialEq, SemanticObject)]
37pub enum Ambiguity {
38 MultipleImplsFound {
39 concrete_trait_id: ConcreteTraitId,
40 impls: Vec<ImplId>,
41 },
42 FreeVariable {
43 impl_id: ImplId,
44 #[dont_rewrite]
45 var: InferenceVar,
46 },
47 WillNotInfer(ConcreteTraitId),
48 NegativeImplWithUnresolvedGenericArgs {
49 impl_id: ImplId,
50 ty: TypeId,
51 },
52}
53impl Ambiguity {
54 pub fn format(&self, db: &(dyn SemanticGroup + 'static)) -> String {
55 match self {
56 Ambiguity::MultipleImplsFound { concrete_trait_id, impls } => {
57 let impls_str =
58 impls.iter().map(|imp| format!("`{}`", imp.format(db.upcast()))).join(", ");
59 format!(
60 "Trait `{:?}` has multiple implementations, in: {impls_str}",
61 concrete_trait_id.debug(db)
62 )
63 }
64 Ambiguity::FreeVariable { impl_id, var: _ } => {
65 format!("Candidate impl {:?} has an unused generic parameter.", impl_id.debug(db),)
66 }
67 Ambiguity::WillNotInfer(concrete_trait_id) => {
68 format!(
69 "Cannot infer trait {:?}. First generic argument must be known.",
70 concrete_trait_id.debug(db)
71 )
72 }
73 Ambiguity::NegativeImplWithUnresolvedGenericArgs { impl_id, ty } => format!(
74 "Cannot infer negative impl in `{}` as it contains the unresolved type `{}`",
75 impl_id.format(db),
76 ty.format(db)
77 ),
78 }
79 }
80}
81
82pub fn canonic_trait_solutions(
85 db: &dyn SemanticGroup,
86 canonical_trait: CanonicalTrait,
87 lookup_context: ImplLookupContext,
88 impl_type_bounds: BTreeMap<ImplTypeById, TypeId>,
89) -> Result<SolutionSet<CanonicalImpl>, InferenceError> {
90 let mut concrete_trait_id = canonical_trait.id;
91 let impl_type_bounds = Arc::new(impl_type_bounds);
92 if !concrete_trait_id.is_fully_concrete(db) {
95 let mut solver =
96 Solver::new(db, canonical_trait, lookup_context.clone(), impl_type_bounds.clone());
97 match solver.solution_set(db) {
98 SolutionSet::None => {}
99 SolutionSet::Unique(imp) => {
100 concrete_trait_id =
101 imp.0.concrete_trait(db).expect("A solved impl must have a concrete trait");
102 }
103 SolutionSet::Ambiguous(ambiguity) => {
104 return Ok(SolutionSet::Ambiguous(ambiguity));
105 }
106 }
107 }
108 let mut solver = Solver::new(
110 db,
111 CanonicalTrait { id: concrete_trait_id, mappings: ImplVarTraitItemMappings::default() },
112 lookup_context,
113 impl_type_bounds,
114 );
115
116 Ok(solver.solution_set(db))
117}
118
119pub fn canonic_trait_solutions_cycle(
121 _db: &dyn SemanticGroup,
122 _cycle: &salsa::Cycle,
123 _canonical_trait: &CanonicalTrait,
124 _lookup_context: &ImplLookupContext,
125 _impl_type_bounds: &BTreeMap<ImplTypeById, TypeId>,
126) -> Result<SolutionSet<CanonicalImpl>, InferenceError> {
127 Err(InferenceError::Cycle(InferenceVar::Impl(LocalImplVarId(0))))
128}
129
130pub fn enrich_lookup_context(
132 db: &dyn SemanticGroup,
133 concrete_trait_id: ConcreteTraitId,
134 lookup_context: &mut ImplLookupContext,
135) {
136 lookup_context.insert_module(concrete_trait_id.trait_id(db).module_file_id(db.upcast()).0);
137 let generic_args = concrete_trait_id.generic_args(db);
138 for generic_arg in &generic_args {
140 if let GenericArgumentId::Type(ty) = generic_arg {
141 match ty.lookup_intern(db) {
142 TypeLongId::Concrete(concrete) => {
143 lookup_context
144 .insert_module(concrete.generic_type(db).module_file_id(db.upcast()).0);
145 }
146 TypeLongId::Coupon(function_id) => {
147 if let Some(module_file_id) =
148 function_id.get_concrete(db).generic_function.module_file_id(db)
149 {
150 lookup_context.insert_module(module_file_id.0);
151 }
152 }
153 TypeLongId::ImplType(impl_type_id) => {
154 lookup_context.insert_impl(impl_type_id.impl_id());
155 }
156 _ => (),
157 }
158 }
159 }
160}
161
162#[derive(Debug)]
164pub struct Solver {
165 pub canonical_trait: CanonicalTrait,
166 pub lookup_context: ImplLookupContext,
167 candidate_solvers: Vec<CandidateSolver>,
168}
169impl Solver {
170 fn new(
171 db: &dyn SemanticGroup,
172 canonical_trait: CanonicalTrait,
173 lookup_context: ImplLookupContext,
174 impl_type_bounds: Arc<BTreeMap<ImplTypeById, TypeId>>,
175 ) -> Self {
176 let filter = canonical_trait.id.filter(db);
177 let mut candidates =
178 find_candidates_at_context(db, &lookup_context, &filter).unwrap_or_default();
179 find_closure_generated_candidate(db, canonical_trait.id)
180 .map(|candidate| candidates.insert(candidate));
181 let candidate_solvers = candidates
182 .into_iter()
183 .filter_map(|candidate| {
184 CandidateSolver::new(
185 db,
186 &canonical_trait,
187 candidate,
188 &lookup_context,
189 impl_type_bounds.clone(),
190 )
191 .ok()
192 })
193 .collect();
194
195 Self { canonical_trait, lookup_context, candidate_solvers }
196 }
197
198 pub fn solution_set(&mut self, db: &dyn SemanticGroup) -> SolutionSet<CanonicalImpl> {
199 let mut unique_solution: Option<CanonicalImpl> = None;
200 for candidate_solver in &mut self.candidate_solvers {
201 let Ok(candidate_solution_set) = candidate_solver.solution_set(db) else {
202 continue;
203 };
204
205 let candidate_solution = match candidate_solution_set {
206 SolutionSet::None => continue,
207 SolutionSet::Unique(candidate_solution) => candidate_solution,
208 SolutionSet::Ambiguous(ambiguity) => return SolutionSet::Ambiguous(ambiguity),
209 };
210 if let Some(unique_solution) = unique_solution {
211 if unique_solution.0 != candidate_solution.0 {
215 return SolutionSet::Ambiguous(Ambiguity::MultipleImplsFound {
216 concrete_trait_id: self.canonical_trait.id,
217 impls: vec![unique_solution.0, candidate_solution.0],
218 });
219 }
220 }
221 unique_solution = Some(candidate_solution);
222 }
223 unique_solution.map(SolutionSet::Unique).unwrap_or(SolutionSet::None)
224 }
225}
226
227#[derive(Debug)]
229pub struct CandidateSolver {
230 pub candidate: UninferredImpl,
231 inference_data: InferenceData,
232 canonical_embedding: CanonicalMapping,
233 candidate_impl: ImplId,
234 pub lookup_context: ImplLookupContext,
235}
236impl CandidateSolver {
237 fn new(
238 db: &dyn SemanticGroup,
239 canonical_trait: &CanonicalTrait,
240 candidate: UninferredImpl,
241 lookup_context: &ImplLookupContext,
242 impl_type_bounds: Arc<BTreeMap<ImplTypeById, TypeId>>,
243 ) -> InferenceResult<CandidateSolver> {
244 let mut inference_data: InferenceData = InferenceData::new(InferenceId::Canonical);
245 let mut inference = inference_data.inference(db);
246 inference.data.impl_type_bounds = impl_type_bounds.clone();
247 let (canonical_trait, canonical_embedding) = canonical_trait.embed(&mut inference);
248
249 if let UninferredImpl::GeneratedImpl(imp) = candidate {
252 inference.conform_traits(imp.lookup_intern(db).concrete_trait, canonical_trait.id)?;
253 }
254
255 let mut lookup_context = lookup_context.clone();
257 lookup_context.insert_lookup_scope(db, &candidate);
258 let candidate_impl =
260 inference.infer_impl(candidate, canonical_trait.id, &lookup_context, None)?;
261 for (trait_type, ty) in canonical_trait.mappings.types.iter() {
262 let mapped_ty =
263 inference.reduce_impl_ty(ImplTypeId::new(candidate_impl, *trait_type, db))?;
264
265 inference.conform_ty(mapped_ty, *ty)?;
267 }
268 for (trait_const, const_id) in canonical_trait.mappings.constants.iter() {
269 let mapped_const_id = inference.reduce_impl_constant(ImplConstantId::new(
270 candidate_impl,
271 *trait_const,
272 db,
273 ))?;
274 inference.conform_const(mapped_const_id, *const_id)?;
276 }
277
278 for (trait_impl, impl_id) in canonical_trait.mappings.impls.iter() {
279 let mapped_impl_id =
280 inference.reduce_impl_impl(ImplImplId::new(candidate_impl, *trait_impl, db))?;
281 inference.conform_impl(mapped_impl_id, *impl_id)?;
283 }
284
285 Ok(CandidateSolver {
286 candidate,
287 inference_data,
288 canonical_embedding,
289 candidate_impl,
290 lookup_context,
291 })
292 }
293 fn solution_set(
294 &mut self,
295 db: &dyn SemanticGroup,
296 ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
297 let mut inference = self.inference_data.inference(db);
298 let solution_set = inference.solution_set()?;
299 Ok(match solution_set {
300 SolutionSet::None => SolutionSet::None,
301 SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
302 SolutionSet::Unique(_) => {
303 let candidate_impl = inference.rewrite(self.candidate_impl).no_err();
304 match CanonicalImpl::canonicalize(db, candidate_impl, &self.canonical_embedding) {
305 Ok(canonical_impl) => {
306 inference.validate_neg_impls(&self.lookup_context, canonical_impl)?
307 }
308 Err(MapperError(var)) => {
309 return Ok(SolutionSet::Ambiguous(Ambiguity::FreeVariable {
310 impl_id: candidate_impl,
311 var,
312 }));
313 }
314 }
315 }
316 })
317 }
318}