1use cairo_lang_defs::ids::{ImplAliasId, ImplDefId, TraitFunctionId};
2use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
3use cairo_lang_utils::{Intern, LookupIntern, extract_matches, require};
4use itertools::Itertools;
5
6use super::canonic::ResultNoErrEx;
7use super::conform::InferenceConform;
8use super::{Inference, InferenceError, InferenceResult};
9use crate::items::constant::ImplConstantId;
10use crate::items::functions::{GenericFunctionId, ImplGenericFunctionId};
11use crate::items::generics::GenericParamConst;
12use crate::items::imp::{
13 GeneratedImplLongId, ImplId, ImplImplId, ImplLongId, ImplLookupContext, UninferredImpl,
14};
15use crate::items::trt::{
16 ConcreteTraitConstantId, ConcreteTraitGenericFunctionId, ConcreteTraitImplId,
17 ConcreteTraitTypeId,
18};
19use crate::substitution::{GenericSubstitution, SemanticRewriter, SubstitutionRewriter};
20use crate::types::ImplTypeId;
21use crate::{
22 ConcreteFunction, ConcreteImplLongId, ConcreteTraitId, ConcreteTraitLongId, FunctionId,
23 FunctionLongId, GenericArgumentId, GenericParam, TypeId, TypeLongId,
24};
25
26pub trait InferenceEmbeddings {
29 fn infer_impl(
30 &mut self,
31 uninferred_impl: UninferredImpl,
32 concrete_trait_id: ConcreteTraitId,
33 lookup_context: &ImplLookupContext,
34 stable_ptr: Option<SyntaxStablePtrId>,
35 ) -> InferenceResult<ImplId>;
36 fn infer_impl_def(
37 &mut self,
38 impl_def_id: ImplDefId,
39 concrete_trait_id: ConcreteTraitId,
40 lookup_context: &ImplLookupContext,
41 stable_ptr: Option<SyntaxStablePtrId>,
42 ) -> InferenceResult<ImplId>;
43 fn infer_impl_alias(
44 &mut self,
45 impl_alias_id: ImplAliasId,
46 concrete_trait_id: ConcreteTraitId,
47 lookup_context: &ImplLookupContext,
48 stable_ptr: Option<SyntaxStablePtrId>,
49 ) -> InferenceResult<ImplId>;
50 fn infer_generic_assignment(
51 &mut self,
52 generic_params: &[GenericParam],
53 generic_args: &[GenericArgumentId],
54 expected_generic_args: &[GenericArgumentId],
55 lookup_context: &ImplLookupContext,
56 stable_ptr: Option<SyntaxStablePtrId>,
57 ) -> InferenceResult<Vec<GenericArgumentId>>;
58 fn infer_generic_args(
59 &mut self,
60 generic_params: &[GenericParam],
61 lookup_context: &ImplLookupContext,
62 stable_ptr: Option<SyntaxStablePtrId>,
63 ) -> InferenceResult<Vec<GenericArgumentId>>;
64 fn infer_concrete_trait_by_self(
65 &mut self,
66 trait_function: TraitFunctionId,
67 self_ty: TypeId,
68 lookup_context: &ImplLookupContext,
69 stable_ptr: Option<SyntaxStablePtrId>,
70 inference_error_cb: impl FnOnce(InferenceError),
71 ) -> Option<(ConcreteTraitId, usize)>;
72 fn infer_generic_arg(
73 &mut self,
74 param: &GenericParam,
75 lookup_context: ImplLookupContext,
76 stable_ptr: Option<SyntaxStablePtrId>,
77 ) -> InferenceResult<GenericArgumentId>;
78 fn infer_trait_function(
79 &mut self,
80 concrete_trait_function: ConcreteTraitGenericFunctionId,
81 lookup_context: &ImplLookupContext,
82 stable_ptr: Option<SyntaxStablePtrId>,
83 ) -> InferenceResult<FunctionId>;
84 fn infer_generic_function(
85 &mut self,
86 generic_function: GenericFunctionId,
87 lookup_context: &ImplLookupContext,
88 stable_ptr: Option<SyntaxStablePtrId>,
89 ) -> InferenceResult<FunctionId>;
90 fn infer_trait_generic_function(
91 &mut self,
92 concrete_trait_function: ConcreteTraitGenericFunctionId,
93 lookup_context: &ImplLookupContext,
94 stable_ptr: Option<SyntaxStablePtrId>,
95 ) -> GenericFunctionId;
96 fn infer_trait_type(
97 &mut self,
98 concrete_trait_type: ConcreteTraitTypeId,
99 lookup_context: &ImplLookupContext,
100 stable_ptr: Option<SyntaxStablePtrId>,
101 ) -> TypeId;
102 fn infer_trait_constant(
103 &mut self,
104 concrete_trait_constant: ConcreteTraitConstantId,
105 lookup_context: &ImplLookupContext,
106 stable_ptr: Option<SyntaxStablePtrId>,
107 ) -> ImplConstantId;
108 fn infer_trait_impl(
109 &mut self,
110 concrete_trait_constant: ConcreteTraitImplId,
111 lookup_context: &ImplLookupContext,
112 stable_ptr: Option<SyntaxStablePtrId>,
113 ) -> ImplImplId;
114}
115
116impl InferenceEmbeddings for Inference<'_> {
117 fn infer_impl(
119 &mut self,
120 uninferred_impl: UninferredImpl,
121 concrete_trait_id: ConcreteTraitId,
122 lookup_context: &ImplLookupContext,
123 stable_ptr: Option<SyntaxStablePtrId>,
124 ) -> InferenceResult<ImplId> {
125 let impl_id = match uninferred_impl {
126 UninferredImpl::Def(impl_def_id) => {
127 self.infer_impl_def(impl_def_id, concrete_trait_id, lookup_context, stable_ptr)?
128 }
129 UninferredImpl::ImplAlias(impl_alias_id) => {
130 self.infer_impl_alias(impl_alias_id, concrete_trait_id, lookup_context, stable_ptr)?
131 }
132 UninferredImpl::ImplImpl(impl_impl_id) => {
133 ImplLongId::ImplImpl(impl_impl_id).intern(self.db)
134 }
135 UninferredImpl::GenericParam(param_id) => {
136 let param = self
137 .db
138 .generic_param_semantic(param_id)
139 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
140 let param = extract_matches!(param, GenericParam::Impl);
141 let imp_concrete_trait_id = param.concrete_trait.unwrap();
142 self.conform_traits(concrete_trait_id, imp_concrete_trait_id)?;
143 ImplLongId::GenericParameter(param_id).intern(self.db)
144 }
145 UninferredImpl::GeneratedImpl(generated_impl) => {
146 let long_id = generated_impl.lookup_intern(self.db);
147
148 self.infer_generic_args(&long_id.generic_params[..], lookup_context, stable_ptr)?;
150
151 ImplLongId::GeneratedImpl(
152 GeneratedImplLongId {
153 concrete_trait: long_id.concrete_trait,
154 generic_params: long_id.generic_params,
155 impl_items: long_id.impl_items,
156 }
157 .intern(self.db),
158 )
159 .intern(self.db)
160 }
161 };
162 Ok(impl_id)
163 }
164
165 fn infer_impl_def(
168 &mut self,
169 impl_def_id: ImplDefId,
170 concrete_trait_id: ConcreteTraitId,
171 lookup_context: &ImplLookupContext,
172 stable_ptr: Option<SyntaxStablePtrId>,
173 ) -> InferenceResult<ImplId> {
174 let imp_generic_params = self
175 .db
176 .impl_def_generic_params(impl_def_id)
177 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
178 let imp_concrete_trait = self
179 .db
180 .impl_def_concrete_trait(impl_def_id)
181 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
182 if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
183 return Err(self.set_error(InferenceError::TraitMismatch {
184 trt0: imp_concrete_trait.trait_id(self.db),
185 trt1: concrete_trait_id.trait_id(self.db),
186 }));
187 }
188
189 let long_concrete_trait = concrete_trait_id.lookup_intern(self.db);
190 let long_imp_concrete_trait = imp_concrete_trait.lookup_intern(self.db);
191 let generic_args = self.infer_generic_assignment(
192 &imp_generic_params,
193 &long_imp_concrete_trait.generic_args,
194 &long_concrete_trait.generic_args,
195 lookup_context,
196 stable_ptr,
197 )?;
198 Ok(ImplLongId::Concrete(ConcreteImplLongId { impl_def_id, generic_args }.intern(self.db))
199 .intern(self.db))
200 }
201
202 fn infer_impl_alias(
205 &mut self,
206 impl_alias_id: ImplAliasId,
207 concrete_trait_id: ConcreteTraitId,
208 lookup_context: &ImplLookupContext,
209 stable_ptr: Option<SyntaxStablePtrId>,
210 ) -> InferenceResult<ImplId> {
211 let impl_alias_generic_params = self
212 .db
213 .impl_alias_generic_params(impl_alias_id)
214 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
215 let impl_id = self
216 .db
217 .impl_alias_resolved_impl(impl_alias_id)
218 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
219 let imp_concrete_trait = impl_id
220 .concrete_trait(self.db)
221 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
222 if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
223 return Err(self.set_error(InferenceError::TraitMismatch {
224 trt0: imp_concrete_trait.trait_id(self.db),
225 trt1: concrete_trait_id.trait_id(self.db),
226 }));
227 }
228
229 let long_concrete_trait = concrete_trait_id.lookup_intern(self.db);
230 let long_imp_concrete_trait = imp_concrete_trait.lookup_intern(self.db);
231 let generic_args = self.infer_generic_assignment(
232 &impl_alias_generic_params,
233 &long_imp_concrete_trait.generic_args,
234 &long_concrete_trait.generic_args,
235 lookup_context,
236 stable_ptr,
237 )?;
238
239 SubstitutionRewriter {
240 db: self.db,
241 substitution: &GenericSubstitution::new(&impl_alias_generic_params, &generic_args),
242 }
243 .rewrite(impl_id)
244 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))
245 }
246
247 fn infer_generic_assignment(
251 &mut self,
252 generic_params: &[GenericParam],
253 generic_args: &[GenericArgumentId],
254 expected_generic_args: &[GenericArgumentId],
255 lookup_context: &ImplLookupContext,
256 stable_ptr: Option<SyntaxStablePtrId>,
257 ) -> InferenceResult<Vec<GenericArgumentId>> {
258 let new_generic_args =
259 self.infer_generic_args(generic_params, lookup_context, stable_ptr)?;
260 let substitution = GenericSubstitution::new(generic_params, &new_generic_args);
261 let mut rewriter = SubstitutionRewriter { db: self.db, substitution: &substitution };
262 let generic_args = rewriter
263 .rewrite(generic_args.iter().copied().collect_vec())
264 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
265 self.conform_generic_args(&generic_args, expected_generic_args)?;
266 Ok(self.rewrite(new_generic_args).no_err())
267 }
268
269 fn infer_generic_args(
271 &mut self,
272 generic_params: &[GenericParam],
273 lookup_context: &ImplLookupContext,
274 stable_ptr: Option<SyntaxStablePtrId>,
275 ) -> InferenceResult<Vec<GenericArgumentId>> {
276 let mut generic_args = vec![];
277 let mut substitution = GenericSubstitution::default();
278 for generic_param in generic_params {
279 let generic_param = SubstitutionRewriter { db: self.db, substitution: &substitution }
280 .rewrite(generic_param.clone())
281 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
282 let generic_arg =
283 self.infer_generic_arg(&generic_param, lookup_context.clone(), stable_ptr)?;
284 generic_args.push(generic_arg);
285 substitution.insert(generic_param.id(), generic_arg);
286 }
287 Ok(generic_args)
288 }
289
290 fn infer_concrete_trait_by_self(
298 &mut self,
299 trait_function: TraitFunctionId,
300 self_ty: TypeId,
301 lookup_context: &ImplLookupContext,
302 stable_ptr: Option<SyntaxStablePtrId>,
303 inference_error_cb: impl FnOnce(InferenceError),
304 ) -> Option<(ConcreteTraitId, usize)> {
305 let trait_id = trait_function.trait_id(self.db.upcast());
306 let signature = self.db.trait_function_signature(trait_function).ok()?;
307 let first_param = signature.params.into_iter().next()?;
308 require(first_param.name == "self")?;
309
310 let trait_generic_params = self.db.trait_generic_params(trait_id).ok()?;
311 let trait_generic_args =
312 match self.infer_generic_args(&trait_generic_params, lookup_context, stable_ptr) {
313 Ok(generic_args) => generic_args,
314 Err(err_set) => {
315 if let Some(err) = self.consume_error_without_reporting(err_set) {
316 inference_error_cb(err);
317 }
318 return None;
319 }
320 };
321
322 let mut tmp_inference_data = self.temporary_clone();
324 let mut tmp_inference = tmp_inference_data.inference(self.db);
325 let function_generic_params =
326 tmp_inference.db.trait_function_generic_params(trait_function).ok()?;
327 let function_generic_args =
328 match tmp_inference.infer_generic_args(&function_generic_params, lookup_context, stable_ptr) {
331 Ok(generic_args) => generic_args,
332 Err(err_set) => {
333 if let Some(err) = self.consume_error_without_reporting(err_set) {
334 inference_error_cb(err);
335 }
336 return None;
337 }
338 };
339
340 let trait_substitution =
341 GenericSubstitution::new(&trait_generic_params, &trait_generic_args);
342 let function_substitution =
343 GenericSubstitution::new(&function_generic_params, &function_generic_args);
344 let substitution = trait_substitution.concat(function_substitution);
345 let mut rewriter = SubstitutionRewriter { db: self.db, substitution: &substitution };
346
347 let fixed_param_ty = rewriter.rewrite(first_param.ty).ok()?;
348 let (_, n_snapshots) = match self.conform_ty_ex(self_ty, fixed_param_ty, true) {
349 Ok(conform) => conform,
350 Err(err_set) => {
351 if let Some(err) = self.consume_error_without_reporting(err_set) {
352 inference_error_cb(err);
353 }
354 return None;
355 }
356 };
357
358 let generic_args = self.rewrite(trait_generic_args).no_err();
359
360 Some((ConcreteTraitLongId { trait_id, generic_args }.intern(self.db), n_snapshots))
361 }
362
363 fn infer_generic_arg(
366 &mut self,
367 param: &GenericParam,
368 lookup_context: ImplLookupContext,
369 stable_ptr: Option<SyntaxStablePtrId>,
370 ) -> InferenceResult<GenericArgumentId> {
371 match param {
372 GenericParam::Type(_) => Ok(GenericArgumentId::Type(self.new_type_var(stable_ptr))),
373 GenericParam::Impl(param) => {
374 let concrete_trait_id = param
375 .concrete_trait
376 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
377 let impl_id = self.new_impl_var(concrete_trait_id, stable_ptr, lookup_context);
378 for (trait_ty, ty1) in param.type_constraints.iter() {
379 let ty0 = self.reduce_impl_ty(ImplTypeId::new(
380 impl_id,
381 trait_ty.trait_type(self.db),
382 self.db,
383 ))?;
384 self.conform_ty(ty0, *ty1).ok();
386 }
387 Ok(GenericArgumentId::Impl(impl_id))
388 }
389 GenericParam::Const(GenericParamConst { ty, .. }) => {
390 Ok(GenericArgumentId::Constant(self.new_const_var(stable_ptr, *ty)))
391 }
392 GenericParam::NegImpl(_) => Ok(GenericArgumentId::NegImpl),
393 }
394 }
395
396 fn infer_trait_function(
400 &mut self,
401 concrete_trait_function: ConcreteTraitGenericFunctionId,
402 lookup_context: &ImplLookupContext,
403 stable_ptr: Option<SyntaxStablePtrId>,
404 ) -> InferenceResult<FunctionId> {
405 let generic_function =
406 self.infer_trait_generic_function(concrete_trait_function, lookup_context, stable_ptr);
407 self.infer_generic_function(generic_function, lookup_context, stable_ptr)
408 }
409
410 fn infer_generic_function(
413 &mut self,
414 generic_function: GenericFunctionId,
415 lookup_context: &ImplLookupContext,
416 stable_ptr: Option<SyntaxStablePtrId>,
417 ) -> InferenceResult<FunctionId> {
418 let generic_params = generic_function
419 .generic_params(self.db)
420 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
421 let generic_args = self.infer_generic_args(&generic_params, lookup_context, stable_ptr)?;
422 Ok(FunctionLongId { function: ConcreteFunction { generic_function, generic_args } }
423 .intern(self.db))
424 }
425
426 fn infer_trait_generic_function(
429 &mut self,
430 concrete_trait_function: ConcreteTraitGenericFunctionId,
431 lookup_context: &ImplLookupContext,
432 stable_ptr: Option<SyntaxStablePtrId>,
433 ) -> GenericFunctionId {
434 let impl_id = self.new_impl_var(
435 concrete_trait_function.concrete_trait(self.db),
436 stable_ptr,
437 lookup_context.clone(),
438 );
439 GenericFunctionId::Impl(ImplGenericFunctionId {
440 impl_id,
441 function: concrete_trait_function.trait_function(self.db),
442 })
443 }
444
445 fn infer_trait_type(
448 &mut self,
449 concrete_trait_type: ConcreteTraitTypeId,
450 lookup_context: &ImplLookupContext,
451 stable_ptr: Option<SyntaxStablePtrId>,
452 ) -> TypeId {
453 let impl_id = self.new_impl_var(
454 concrete_trait_type.concrete_trait(self.db),
455 stable_ptr,
456 lookup_context.clone(),
457 );
458 TypeLongId::ImplType(ImplTypeId::new(
459 impl_id,
460 concrete_trait_type.trait_type(self.db),
461 self.db,
462 ))
463 .intern(self.db)
464 }
465
466 fn infer_trait_constant(
469 &mut self,
470 concrete_trait_constant: ConcreteTraitConstantId,
471 lookup_context: &ImplLookupContext,
472 stable_ptr: Option<SyntaxStablePtrId>,
473 ) -> ImplConstantId {
474 let impl_id = self.new_impl_var(
475 concrete_trait_constant.concrete_trait(self.db),
476 stable_ptr,
477 lookup_context.clone(),
478 );
479
480 ImplConstantId::new(impl_id, concrete_trait_constant.trait_constant(self.db), self.db)
481 }
482
483 fn infer_trait_impl(
486 &mut self,
487 concrete_trait_impl: ConcreteTraitImplId,
488 lookup_context: &ImplLookupContext,
489 stable_ptr: Option<SyntaxStablePtrId>,
490 ) -> ImplImplId {
491 let impl_id = self.new_impl_var(
492 concrete_trait_impl.concrete_trait(self.db),
493 stable_ptr,
494 lookup_context.clone(),
495 );
496
497 ImplImplId::new(impl_id, concrete_trait_impl.trait_impl(self.db), self.db)
498 }
499}