1use cairo_lang_defs::ids::{ImplAliasId, ImplDefId, TraitFunctionId};
2use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
3use cairo_lang_utils::{Intern, LookupIntern, extract_matches, require};
4use itertools::Itertools;
5
6use super::canonic::ResultNoErrEx;
7use super::conform::InferenceConform;
8use super::{Inference, InferenceError, InferenceResult};
9use crate::items::constant::ImplConstantId;
10use crate::items::functions::{GenericFunctionId, ImplGenericFunctionId};
11use crate::items::generics::GenericParamConst;
12use crate::items::imp::{
13 GeneratedImplLongId, ImplId, ImplImplId, ImplLongId, ImplLookupContext, UninferredImpl,
14};
15use crate::items::trt::{
16 ConcreteTraitConstantId, ConcreteTraitGenericFunctionId, ConcreteTraitImplId,
17 ConcreteTraitTypeId,
18};
19use crate::substitution::{GenericSubstitution, SemanticRewriter};
20use crate::types::ImplTypeId;
21use crate::{
22 ConcreteFunction, ConcreteImplLongId, ConcreteTraitId, ConcreteTraitLongId, FunctionId,
23 FunctionLongId, GenericArgumentId, GenericParam, TypeId, TypeLongId,
24};
25
26pub trait InferenceEmbeddings {
29 fn infer_impl(
30 &mut self,
31 uninferred_impl: UninferredImpl,
32 concrete_trait_id: ConcreteTraitId,
33 lookup_context: &ImplLookupContext,
34 stable_ptr: Option<SyntaxStablePtrId>,
35 ) -> InferenceResult<ImplId>;
36 fn infer_impl_def(
37 &mut self,
38 impl_def_id: ImplDefId,
39 concrete_trait_id: ConcreteTraitId,
40 lookup_context: &ImplLookupContext,
41 stable_ptr: Option<SyntaxStablePtrId>,
42 ) -> InferenceResult<ImplId>;
43 fn infer_impl_alias(
44 &mut self,
45 impl_alias_id: ImplAliasId,
46 concrete_trait_id: ConcreteTraitId,
47 lookup_context: &ImplLookupContext,
48 stable_ptr: Option<SyntaxStablePtrId>,
49 ) -> InferenceResult<ImplId>;
50 fn infer_generic_assignment(
51 &mut self,
52 generic_params: &[GenericParam],
53 generic_args: &[GenericArgumentId],
54 expected_generic_args: &[GenericArgumentId],
55 lookup_context: &ImplLookupContext,
56 stable_ptr: Option<SyntaxStablePtrId>,
57 ) -> InferenceResult<Vec<GenericArgumentId>>;
58 fn infer_generic_args(
59 &mut self,
60 generic_params: &[GenericParam],
61 lookup_context: &ImplLookupContext,
62 stable_ptr: Option<SyntaxStablePtrId>,
63 ) -> InferenceResult<Vec<GenericArgumentId>>;
64 fn infer_concrete_trait_by_self(
65 &mut self,
66 trait_function: TraitFunctionId,
67 self_ty: TypeId,
68 lookup_context: &ImplLookupContext,
69 stable_ptr: Option<SyntaxStablePtrId>,
70 inference_error_cb: impl FnOnce(InferenceError),
71 ) -> Option<(ConcreteTraitId, usize)>;
72 fn infer_generic_arg(
73 &mut self,
74 param: &GenericParam,
75 lookup_context: ImplLookupContext,
76 stable_ptr: Option<SyntaxStablePtrId>,
77 ) -> InferenceResult<GenericArgumentId>;
78 fn infer_trait_function(
79 &mut self,
80 concrete_trait_function: ConcreteTraitGenericFunctionId,
81 lookup_context: &ImplLookupContext,
82 stable_ptr: Option<SyntaxStablePtrId>,
83 ) -> InferenceResult<FunctionId>;
84 fn infer_generic_function(
85 &mut self,
86 generic_function: GenericFunctionId,
87 lookup_context: &ImplLookupContext,
88 stable_ptr: Option<SyntaxStablePtrId>,
89 ) -> InferenceResult<FunctionId>;
90 fn infer_trait_generic_function(
91 &mut self,
92 concrete_trait_function: ConcreteTraitGenericFunctionId,
93 lookup_context: &ImplLookupContext,
94 stable_ptr: Option<SyntaxStablePtrId>,
95 ) -> ImplGenericFunctionId;
96 fn infer_trait_type(
97 &mut self,
98 concrete_trait_type: ConcreteTraitTypeId,
99 lookup_context: &ImplLookupContext,
100 stable_ptr: Option<SyntaxStablePtrId>,
101 ) -> TypeId;
102 fn infer_trait_constant(
103 &mut self,
104 concrete_trait_constant: ConcreteTraitConstantId,
105 lookup_context: &ImplLookupContext,
106 stable_ptr: Option<SyntaxStablePtrId>,
107 ) -> ImplConstantId;
108 fn infer_trait_impl(
109 &mut self,
110 concrete_trait_constant: ConcreteTraitImplId,
111 lookup_context: &ImplLookupContext,
112 stable_ptr: Option<SyntaxStablePtrId>,
113 ) -> ImplImplId;
114}
115
116impl InferenceEmbeddings for Inference<'_> {
117 fn infer_impl(
119 &mut self,
120 uninferred_impl: UninferredImpl,
121 concrete_trait_id: ConcreteTraitId,
122 lookup_context: &ImplLookupContext,
123 stable_ptr: Option<SyntaxStablePtrId>,
124 ) -> InferenceResult<ImplId> {
125 let impl_id = match uninferred_impl {
126 UninferredImpl::Def(impl_def_id) => {
127 self.infer_impl_def(impl_def_id, concrete_trait_id, lookup_context, stable_ptr)?
128 }
129 UninferredImpl::ImplAlias(impl_alias_id) => {
130 self.infer_impl_alias(impl_alias_id, concrete_trait_id, lookup_context, stable_ptr)?
131 }
132 UninferredImpl::ImplImpl(impl_impl_id) => {
133 ImplLongId::ImplImpl(impl_impl_id).intern(self.db)
134 }
135 UninferredImpl::GenericParam(param_id) => {
136 let param = self
137 .db
138 .generic_param_semantic(param_id)
139 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
140 let param = extract_matches!(param, GenericParam::Impl);
141 let imp_concrete_trait_id = param.concrete_trait.unwrap();
142 self.conform_traits(concrete_trait_id, imp_concrete_trait_id)?;
143 ImplLongId::GenericParameter(param_id).intern(self.db)
144 }
145 UninferredImpl::GeneratedImpl(generated_impl) => {
146 let long_id = generated_impl.lookup_intern(self.db);
147
148 self.infer_generic_args(&long_id.generic_params[..], lookup_context, stable_ptr)?;
150
151 ImplLongId::GeneratedImpl(
152 GeneratedImplLongId {
153 concrete_trait: long_id.concrete_trait,
154 generic_params: long_id.generic_params,
155 impl_items: long_id.impl_items,
156 }
157 .intern(self.db),
158 )
159 .intern(self.db)
160 }
161 };
162 Ok(impl_id)
163 }
164
165 fn infer_impl_def(
168 &mut self,
169 impl_def_id: ImplDefId,
170 concrete_trait_id: ConcreteTraitId,
171 lookup_context: &ImplLookupContext,
172 stable_ptr: Option<SyntaxStablePtrId>,
173 ) -> InferenceResult<ImplId> {
174 let imp_generic_params = self
175 .db
176 .impl_def_generic_params(impl_def_id)
177 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
178 let imp_concrete_trait = self
179 .db
180 .impl_def_concrete_trait(impl_def_id)
181 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
182 if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
183 return Err(self.set_error(InferenceError::TraitMismatch {
184 trt0: imp_concrete_trait.trait_id(self.db),
185 trt1: concrete_trait_id.trait_id(self.db),
186 }));
187 }
188
189 let long_concrete_trait = concrete_trait_id.lookup_intern(self.db);
190 let long_imp_concrete_trait = imp_concrete_trait.lookup_intern(self.db);
191 let generic_args = self.infer_generic_assignment(
192 &imp_generic_params,
193 &long_imp_concrete_trait.generic_args,
194 &long_concrete_trait.generic_args,
195 lookup_context,
196 stable_ptr,
197 )?;
198 Ok(ImplLongId::Concrete(ConcreteImplLongId { impl_def_id, generic_args }.intern(self.db))
199 .intern(self.db))
200 }
201
202 fn infer_impl_alias(
205 &mut self,
206 impl_alias_id: ImplAliasId,
207 concrete_trait_id: ConcreteTraitId,
208 lookup_context: &ImplLookupContext,
209 stable_ptr: Option<SyntaxStablePtrId>,
210 ) -> InferenceResult<ImplId> {
211 let impl_alias_generic_params = self
212 .db
213 .impl_alias_generic_params(impl_alias_id)
214 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
215 let impl_id = self
216 .db
217 .impl_alias_resolved_impl(impl_alias_id)
218 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
219 let imp_concrete_trait = impl_id
220 .concrete_trait(self.db)
221 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
222 if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
223 return Err(self.set_error(InferenceError::TraitMismatch {
224 trt0: imp_concrete_trait.trait_id(self.db),
225 trt1: concrete_trait_id.trait_id(self.db),
226 }));
227 }
228
229 let long_concrete_trait = concrete_trait_id.lookup_intern(self.db);
230 let long_imp_concrete_trait = imp_concrete_trait.lookup_intern(self.db);
231 let generic_args = self.infer_generic_assignment(
232 &impl_alias_generic_params,
233 &long_imp_concrete_trait.generic_args,
234 &long_concrete_trait.generic_args,
235 lookup_context,
236 stable_ptr,
237 )?;
238
239 GenericSubstitution::new(&impl_alias_generic_params, &generic_args)
240 .substitute(self.db, impl_id)
241 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))
242 }
243
244 fn infer_generic_assignment(
248 &mut self,
249 generic_params: &[GenericParam],
250 generic_args: &[GenericArgumentId],
251 expected_generic_args: &[GenericArgumentId],
252 lookup_context: &ImplLookupContext,
253 stable_ptr: Option<SyntaxStablePtrId>,
254 ) -> InferenceResult<Vec<GenericArgumentId>> {
255 let new_generic_args =
256 self.infer_generic_args(generic_params, lookup_context, stable_ptr)?;
257 let substitution = GenericSubstitution::new(generic_params, &new_generic_args);
258 let generic_args = substitution
259 .substitute(self.db, generic_args.iter().copied().collect_vec())
260 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
261 self.conform_generic_args(&generic_args, expected_generic_args)?;
262 Ok(self.rewrite(new_generic_args).no_err())
263 }
264
265 fn infer_generic_args(
267 &mut self,
268 generic_params: &[GenericParam],
269 lookup_context: &ImplLookupContext,
270 stable_ptr: Option<SyntaxStablePtrId>,
271 ) -> InferenceResult<Vec<GenericArgumentId>> {
272 let mut generic_args = vec![];
273 let mut substitution = GenericSubstitution::default();
274 for generic_param in generic_params {
275 let generic_param = substitution
276 .substitute(self.db, generic_param.clone())
277 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
278 let generic_arg =
279 self.infer_generic_arg(&generic_param, lookup_context.clone(), stable_ptr)?;
280 generic_args.push(generic_arg);
281 substitution.insert(generic_param.id(), generic_arg);
282 }
283 Ok(generic_args)
284 }
285
286 fn infer_concrete_trait_by_self(
294 &mut self,
295 trait_function: TraitFunctionId,
296 self_ty: TypeId,
297 lookup_context: &ImplLookupContext,
298 stable_ptr: Option<SyntaxStablePtrId>,
299 inference_error_cb: impl FnOnce(InferenceError),
300 ) -> Option<(ConcreteTraitId, usize)> {
301 let trait_id = trait_function.trait_id(self.db.upcast());
302 let signature = self.db.trait_function_signature(trait_function).ok()?;
303 let first_param = signature.params.into_iter().next()?;
304 require(first_param.name == "self")?;
305
306 let trait_generic_params = self.db.trait_generic_params(trait_id).ok()?;
307 let trait_generic_args =
308 match self.infer_generic_args(&trait_generic_params, lookup_context, stable_ptr) {
309 Ok(generic_args) => generic_args,
310 Err(err_set) => {
311 if let Some(err) = self.consume_error_without_reporting(err_set) {
312 inference_error_cb(err);
313 }
314 return None;
315 }
316 };
317
318 let mut tmp_inference_data = self.temporary_clone();
320 let mut tmp_inference = tmp_inference_data.inference(self.db);
321 let function_generic_params =
322 tmp_inference.db.trait_function_generic_params(trait_function).ok()?;
323 let function_generic_args =
324 match tmp_inference.infer_generic_args(&function_generic_params, lookup_context, stable_ptr) {
327 Ok(generic_args) => generic_args,
328 Err(err_set) => {
329 if let Some(err) = self.consume_error_without_reporting(err_set) {
330 inference_error_cb(err);
331 }
332 return None;
333 }
334 };
335
336 let trait_substitution =
337 GenericSubstitution::new(&trait_generic_params, &trait_generic_args);
338 let function_substitution =
339 GenericSubstitution::new(&function_generic_params, &function_generic_args);
340 let substitution = trait_substitution.concat(function_substitution);
341
342 let fixed_param_ty = substitution.substitute(self.db, first_param.ty).ok()?;
343 let (_, n_snapshots) = match self.conform_ty_ex(self_ty, fixed_param_ty, true) {
344 Ok(conform) => conform,
345 Err(err_set) => {
346 if let Some(err) = self.consume_error_without_reporting(err_set) {
347 inference_error_cb(err);
348 }
349 return None;
350 }
351 };
352
353 let generic_args = self.rewrite(trait_generic_args).no_err();
354
355 Some((ConcreteTraitLongId { trait_id, generic_args }.intern(self.db), n_snapshots))
356 }
357
358 fn infer_generic_arg(
361 &mut self,
362 param: &GenericParam,
363 lookup_context: ImplLookupContext,
364 stable_ptr: Option<SyntaxStablePtrId>,
365 ) -> InferenceResult<GenericArgumentId> {
366 match param {
367 GenericParam::Type(_) => Ok(GenericArgumentId::Type(self.new_type_var(stable_ptr))),
368 GenericParam::Impl(param) => {
369 let concrete_trait_id = param
370 .concrete_trait
371 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
372 let impl_id = self.new_impl_var(concrete_trait_id, stable_ptr, lookup_context);
373 for (trait_ty, ty1) in param.type_constraints.iter() {
374 let ty0 = self.reduce_impl_ty(ImplTypeId::new(impl_id, *trait_ty, self.db))?;
375 self.conform_ty(ty0, *ty1).ok();
377 }
378 Ok(GenericArgumentId::Impl(impl_id))
379 }
380 GenericParam::Const(GenericParamConst { ty, .. }) => {
381 Ok(GenericArgumentId::Constant(self.new_const_var(stable_ptr, *ty)))
382 }
383 GenericParam::NegImpl(_) => Ok(GenericArgumentId::NegImpl),
384 }
385 }
386
387 fn infer_trait_function(
391 &mut self,
392 concrete_trait_function: ConcreteTraitGenericFunctionId,
393 lookup_context: &ImplLookupContext,
394 stable_ptr: Option<SyntaxStablePtrId>,
395 ) -> InferenceResult<FunctionId> {
396 let generic_function = GenericFunctionId::Impl(self.infer_trait_generic_function(
397 concrete_trait_function,
398 lookup_context,
399 stable_ptr,
400 ));
401 self.infer_generic_function(generic_function, lookup_context, stable_ptr)
402 }
403
404 fn infer_generic_function(
407 &mut self,
408 generic_function: GenericFunctionId,
409 lookup_context: &ImplLookupContext,
410 stable_ptr: Option<SyntaxStablePtrId>,
411 ) -> InferenceResult<FunctionId> {
412 let generic_params = generic_function
413 .generic_params(self.db)
414 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
415 let generic_args = self.infer_generic_args(&generic_params, lookup_context, stable_ptr)?;
416 Ok(FunctionLongId { function: ConcreteFunction { generic_function, generic_args } }
417 .intern(self.db))
418 }
419
420 fn infer_trait_generic_function(
423 &mut self,
424 concrete_trait_function: ConcreteTraitGenericFunctionId,
425 lookup_context: &ImplLookupContext,
426 stable_ptr: Option<SyntaxStablePtrId>,
427 ) -> ImplGenericFunctionId {
428 let impl_id = self.new_impl_var(
429 concrete_trait_function.concrete_trait(self.db),
430 stable_ptr,
431 lookup_context.clone(),
432 );
433 ImplGenericFunctionId { impl_id, function: concrete_trait_function.trait_function(self.db) }
434 }
435
436 fn infer_trait_type(
439 &mut self,
440 concrete_trait_type: ConcreteTraitTypeId,
441 lookup_context: &ImplLookupContext,
442 stable_ptr: Option<SyntaxStablePtrId>,
443 ) -> TypeId {
444 let impl_id = self.new_impl_var(
445 concrete_trait_type.concrete_trait(self.db),
446 stable_ptr,
447 lookup_context.clone(),
448 );
449 TypeLongId::ImplType(ImplTypeId::new(
450 impl_id,
451 concrete_trait_type.trait_type(self.db),
452 self.db,
453 ))
454 .intern(self.db)
455 }
456
457 fn infer_trait_constant(
460 &mut self,
461 concrete_trait_constant: ConcreteTraitConstantId,
462 lookup_context: &ImplLookupContext,
463 stable_ptr: Option<SyntaxStablePtrId>,
464 ) -> ImplConstantId {
465 let impl_id = self.new_impl_var(
466 concrete_trait_constant.concrete_trait(self.db),
467 stable_ptr,
468 lookup_context.clone(),
469 );
470
471 ImplConstantId::new(impl_id, concrete_trait_constant.trait_constant(self.db), self.db)
472 }
473
474 fn infer_trait_impl(
477 &mut self,
478 concrete_trait_impl: ConcreteTraitImplId,
479 lookup_context: &ImplLookupContext,
480 stable_ptr: Option<SyntaxStablePtrId>,
481 ) -> ImplImplId {
482 let impl_id = self.new_impl_var(
483 concrete_trait_impl.concrete_trait(self.db),
484 stable_ptr,
485 lookup_context.clone(),
486 );
487
488 ImplImplId::new(impl_id, concrete_trait_impl.trait_impl(self.db), self.db)
489 }
490}