cairo_lang_semantic/expr/inference/
infers.rs

1use cairo_lang_defs::ids::{ImplAliasId, ImplDefId, TraitFunctionId};
2use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
3use cairo_lang_utils::{Intern, LookupIntern, extract_matches, require};
4use itertools::Itertools;
5
6use super::canonic::ResultNoErrEx;
7use super::conform::InferenceConform;
8use super::{Inference, InferenceError, InferenceResult};
9use crate::items::constant::ImplConstantId;
10use crate::items::functions::{GenericFunctionId, ImplGenericFunctionId};
11use crate::items::generics::GenericParamConst;
12use crate::items::imp::{
13    GeneratedImplLongId, ImplId, ImplImplId, ImplLongId, ImplLookupContext, UninferredImpl,
14};
15use crate::items::trt::{
16    ConcreteTraitConstantId, ConcreteTraitGenericFunctionId, ConcreteTraitImplId,
17    ConcreteTraitTypeId,
18};
19use crate::substitution::{GenericSubstitution, SemanticRewriter, SubstitutionRewriter};
20use crate::types::ImplTypeId;
21use crate::{
22    ConcreteFunction, ConcreteImplLongId, ConcreteTraitId, ConcreteTraitLongId, FunctionId,
23    FunctionLongId, GenericArgumentId, GenericParam, TypeId, TypeLongId,
24};
25
26/// Functions for embedding generic semantic objects in an existing [Inference] object, by
27/// introducing new variables.
28pub trait InferenceEmbeddings {
29    fn infer_impl(
30        &mut self,
31        uninferred_impl: UninferredImpl,
32        concrete_trait_id: ConcreteTraitId,
33        lookup_context: &ImplLookupContext,
34        stable_ptr: Option<SyntaxStablePtrId>,
35    ) -> InferenceResult<ImplId>;
36    fn infer_impl_def(
37        &mut self,
38        impl_def_id: ImplDefId,
39        concrete_trait_id: ConcreteTraitId,
40        lookup_context: &ImplLookupContext,
41        stable_ptr: Option<SyntaxStablePtrId>,
42    ) -> InferenceResult<ImplId>;
43    fn infer_impl_alias(
44        &mut self,
45        impl_alias_id: ImplAliasId,
46        concrete_trait_id: ConcreteTraitId,
47        lookup_context: &ImplLookupContext,
48        stable_ptr: Option<SyntaxStablePtrId>,
49    ) -> InferenceResult<ImplId>;
50    fn infer_generic_assignment(
51        &mut self,
52        generic_params: &[GenericParam],
53        generic_args: &[GenericArgumentId],
54        expected_generic_args: &[GenericArgumentId],
55        lookup_context: &ImplLookupContext,
56        stable_ptr: Option<SyntaxStablePtrId>,
57    ) -> InferenceResult<Vec<GenericArgumentId>>;
58    fn infer_generic_args(
59        &mut self,
60        generic_params: &[GenericParam],
61        lookup_context: &ImplLookupContext,
62        stable_ptr: Option<SyntaxStablePtrId>,
63    ) -> InferenceResult<Vec<GenericArgumentId>>;
64    fn infer_concrete_trait_by_self(
65        &mut self,
66        trait_function: TraitFunctionId,
67        self_ty: TypeId,
68        lookup_context: &ImplLookupContext,
69        stable_ptr: Option<SyntaxStablePtrId>,
70        inference_error_cb: impl FnOnce(InferenceError),
71    ) -> Option<(ConcreteTraitId, usize)>;
72    fn infer_generic_arg(
73        &mut self,
74        param: &GenericParam,
75        lookup_context: ImplLookupContext,
76        stable_ptr: Option<SyntaxStablePtrId>,
77    ) -> InferenceResult<GenericArgumentId>;
78    fn infer_trait_function(
79        &mut self,
80        concrete_trait_function: ConcreteTraitGenericFunctionId,
81        lookup_context: &ImplLookupContext,
82        stable_ptr: Option<SyntaxStablePtrId>,
83    ) -> InferenceResult<FunctionId>;
84    fn infer_generic_function(
85        &mut self,
86        generic_function: GenericFunctionId,
87        lookup_context: &ImplLookupContext,
88        stable_ptr: Option<SyntaxStablePtrId>,
89    ) -> InferenceResult<FunctionId>;
90    fn infer_trait_generic_function(
91        &mut self,
92        concrete_trait_function: ConcreteTraitGenericFunctionId,
93        lookup_context: &ImplLookupContext,
94        stable_ptr: Option<SyntaxStablePtrId>,
95    ) -> GenericFunctionId;
96    fn infer_trait_type(
97        &mut self,
98        concrete_trait_type: ConcreteTraitTypeId,
99        lookup_context: &ImplLookupContext,
100        stable_ptr: Option<SyntaxStablePtrId>,
101    ) -> TypeId;
102    fn infer_trait_constant(
103        &mut self,
104        concrete_trait_constant: ConcreteTraitConstantId,
105        lookup_context: &ImplLookupContext,
106        stable_ptr: Option<SyntaxStablePtrId>,
107    ) -> ImplConstantId;
108    fn infer_trait_impl(
109        &mut self,
110        concrete_trait_constant: ConcreteTraitImplId,
111        lookup_context: &ImplLookupContext,
112        stable_ptr: Option<SyntaxStablePtrId>,
113    ) -> ImplImplId;
114}
115
116impl InferenceEmbeddings for Inference<'_> {
117    /// Infers all the variables required to make an uninferred impl provide a concrete trait.
118    fn infer_impl(
119        &mut self,
120        uninferred_impl: UninferredImpl,
121        concrete_trait_id: ConcreteTraitId,
122        lookup_context: &ImplLookupContext,
123        stable_ptr: Option<SyntaxStablePtrId>,
124    ) -> InferenceResult<ImplId> {
125        let impl_id = match uninferred_impl {
126            UninferredImpl::Def(impl_def_id) => {
127                self.infer_impl_def(impl_def_id, concrete_trait_id, lookup_context, stable_ptr)?
128            }
129            UninferredImpl::ImplAlias(impl_alias_id) => {
130                self.infer_impl_alias(impl_alias_id, concrete_trait_id, lookup_context, stable_ptr)?
131            }
132            UninferredImpl::ImplImpl(impl_impl_id) => {
133                ImplLongId::ImplImpl(impl_impl_id).intern(self.db)
134            }
135            UninferredImpl::GenericParam(param_id) => {
136                let param = self
137                    .db
138                    .generic_param_semantic(param_id)
139                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
140                let param = extract_matches!(param, GenericParam::Impl);
141                let imp_concrete_trait_id = param.concrete_trait.unwrap();
142                self.conform_traits(concrete_trait_id, imp_concrete_trait_id)?;
143                ImplLongId::GenericParameter(param_id).intern(self.db)
144            }
145            UninferredImpl::GeneratedImpl(generated_impl) => {
146                let long_id = generated_impl.lookup_intern(self.db);
147
148                // Only making sure the args can be inferred - as they are unused later.
149                self.infer_generic_args(&long_id.generic_params[..], lookup_context, stable_ptr)?;
150
151                ImplLongId::GeneratedImpl(
152                    GeneratedImplLongId {
153                        concrete_trait: long_id.concrete_trait,
154                        generic_params: long_id.generic_params,
155                        impl_items: long_id.impl_items,
156                    }
157                    .intern(self.db),
158                )
159                .intern(self.db)
160            }
161        };
162        Ok(impl_id)
163    }
164
165    /// Infers all the variables required to make an impl (possibly with free generic params)
166    /// provide a concrete trait.
167    fn infer_impl_def(
168        &mut self,
169        impl_def_id: ImplDefId,
170        concrete_trait_id: ConcreteTraitId,
171        lookup_context: &ImplLookupContext,
172        stable_ptr: Option<SyntaxStablePtrId>,
173    ) -> InferenceResult<ImplId> {
174        let imp_generic_params = self
175            .db
176            .impl_def_generic_params(impl_def_id)
177            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
178        let imp_concrete_trait = self
179            .db
180            .impl_def_concrete_trait(impl_def_id)
181            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
182        if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
183            return Err(self.set_error(InferenceError::TraitMismatch {
184                trt0: imp_concrete_trait.trait_id(self.db),
185                trt1: concrete_trait_id.trait_id(self.db),
186            }));
187        }
188
189        let long_concrete_trait = concrete_trait_id.lookup_intern(self.db);
190        let long_imp_concrete_trait = imp_concrete_trait.lookup_intern(self.db);
191        let generic_args = self.infer_generic_assignment(
192            &imp_generic_params,
193            &long_imp_concrete_trait.generic_args,
194            &long_concrete_trait.generic_args,
195            lookup_context,
196            stable_ptr,
197        )?;
198        Ok(ImplLongId::Concrete(ConcreteImplLongId { impl_def_id, generic_args }.intern(self.db))
199            .intern(self.db))
200    }
201
202    /// Infers all the variables required to make an impl alias (possibly with free generic params)
203    /// provide a concrete trait.
204    fn infer_impl_alias(
205        &mut self,
206        impl_alias_id: ImplAliasId,
207        concrete_trait_id: ConcreteTraitId,
208        lookup_context: &ImplLookupContext,
209        stable_ptr: Option<SyntaxStablePtrId>,
210    ) -> InferenceResult<ImplId> {
211        let impl_alias_generic_params = self
212            .db
213            .impl_alias_generic_params(impl_alias_id)
214            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
215        let impl_id = self
216            .db
217            .impl_alias_resolved_impl(impl_alias_id)
218            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
219        let imp_concrete_trait = impl_id
220            .concrete_trait(self.db)
221            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
222        if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
223            return Err(self.set_error(InferenceError::TraitMismatch {
224                trt0: imp_concrete_trait.trait_id(self.db),
225                trt1: concrete_trait_id.trait_id(self.db),
226            }));
227        }
228
229        let long_concrete_trait = concrete_trait_id.lookup_intern(self.db);
230        let long_imp_concrete_trait = imp_concrete_trait.lookup_intern(self.db);
231        let generic_args = self.infer_generic_assignment(
232            &impl_alias_generic_params,
233            &long_imp_concrete_trait.generic_args,
234            &long_concrete_trait.generic_args,
235            lookup_context,
236            stable_ptr,
237        )?;
238
239        SubstitutionRewriter {
240            db: self.db,
241            substitution: &GenericSubstitution::new(&impl_alias_generic_params, &generic_args),
242        }
243        .rewrite(impl_id)
244        .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))
245    }
246
247    /// Chooses and assignment to generic_params s.t. generic_args will be substituted to
248    /// expected_generic_args.
249    /// Returns the generic_params assignment.
250    fn infer_generic_assignment(
251        &mut self,
252        generic_params: &[GenericParam],
253        generic_args: &[GenericArgumentId],
254        expected_generic_args: &[GenericArgumentId],
255        lookup_context: &ImplLookupContext,
256        stable_ptr: Option<SyntaxStablePtrId>,
257    ) -> InferenceResult<Vec<GenericArgumentId>> {
258        let new_generic_args =
259            self.infer_generic_args(generic_params, lookup_context, stable_ptr)?;
260        let substitution = GenericSubstitution::new(generic_params, &new_generic_args);
261        let mut rewriter = SubstitutionRewriter { db: self.db, substitution: &substitution };
262        let generic_args = rewriter
263            .rewrite(generic_args.iter().copied().collect_vec())
264            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
265        self.conform_generic_args(&generic_args, expected_generic_args)?;
266        Ok(self.rewrite(new_generic_args).no_err())
267    }
268
269    /// Infers all generic_arguments given the parameters.
270    fn infer_generic_args(
271        &mut self,
272        generic_params: &[GenericParam],
273        lookup_context: &ImplLookupContext,
274        stable_ptr: Option<SyntaxStablePtrId>,
275    ) -> InferenceResult<Vec<GenericArgumentId>> {
276        let mut generic_args = vec![];
277        let mut substitution = GenericSubstitution::default();
278        for generic_param in generic_params {
279            let generic_param = SubstitutionRewriter { db: self.db, substitution: &substitution }
280                .rewrite(generic_param.clone())
281                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
282            let generic_arg =
283                self.infer_generic_arg(&generic_param, lookup_context.clone(), stable_ptr)?;
284            generic_args.push(generic_arg);
285            substitution.insert(generic_param.id(), generic_arg);
286        }
287        Ok(generic_args)
288    }
289
290    /// Tries to infer a trait function as a method for `self_ty`.
291    /// Supports snapshot snapshot coercions.
292    ///
293    /// Returns the deduced type and the number of snapshots that need to be added to it.
294    ///
295    /// `inference_error_cb` is called for inference errors, but they are not reported here as
296    /// diagnostics. The caller has to make sure the diagnostics are reported appropriately.
297    fn infer_concrete_trait_by_self(
298        &mut self,
299        trait_function: TraitFunctionId,
300        self_ty: TypeId,
301        lookup_context: &ImplLookupContext,
302        stable_ptr: Option<SyntaxStablePtrId>,
303        inference_error_cb: impl FnOnce(InferenceError),
304    ) -> Option<(ConcreteTraitId, usize)> {
305        let trait_id = trait_function.trait_id(self.db.upcast());
306        let signature = self.db.trait_function_signature(trait_function).ok()?;
307        let first_param = signature.params.into_iter().next()?;
308        require(first_param.name == "self")?;
309
310        let trait_generic_params = self.db.trait_generic_params(trait_id).ok()?;
311        let trait_generic_args =
312            match self.infer_generic_args(&trait_generic_params, lookup_context, stable_ptr) {
313                Ok(generic_args) => generic_args,
314                Err(err_set) => {
315                    if let Some(err) = self.consume_error_without_reporting(err_set) {
316                        inference_error_cb(err);
317                    }
318                    return None;
319                }
320            };
321
322        // TODO(yuval): Try to not temporary clone.
323        let mut tmp_inference_data = self.temporary_clone();
324        let mut tmp_inference = tmp_inference_data.inference(self.db);
325        let function_generic_params =
326            tmp_inference.db.trait_function_generic_params(trait_function).ok()?;
327        let function_generic_args =
328            // TODO(yuval): consider getting the substitution from inside `infer_generic_args`
329            // instead of creating it again here.
330            match tmp_inference.infer_generic_args(&function_generic_params, lookup_context, stable_ptr) {
331                Ok(generic_args) => generic_args,
332                Err(err_set) => {
333                    if let Some(err) = self.consume_error_without_reporting(err_set) {
334                        inference_error_cb(err);
335                    }
336                    return None;
337                }
338            };
339
340        let trait_substitution =
341            GenericSubstitution::new(&trait_generic_params, &trait_generic_args);
342        let function_substitution =
343            GenericSubstitution::new(&function_generic_params, &function_generic_args);
344        let substitution = trait_substitution.concat(function_substitution);
345        let mut rewriter = SubstitutionRewriter { db: self.db, substitution: &substitution };
346
347        let fixed_param_ty = rewriter.rewrite(first_param.ty).ok()?;
348        let (_, n_snapshots) = match self.conform_ty_ex(self_ty, fixed_param_ty, true) {
349            Ok(conform) => conform,
350            Err(err_set) => {
351                if let Some(err) = self.consume_error_without_reporting(err_set) {
352                    inference_error_cb(err);
353                }
354                return None;
355            }
356        };
357
358        let generic_args = self.rewrite(trait_generic_args).no_err();
359
360        Some((ConcreteTraitLongId { trait_id, generic_args }.intern(self.db), n_snapshots))
361    }
362
363    /// Infers a generic argument to be passed as a generic parameter.
364    /// Allocates a new inference variable of the correct kind, and wraps in a generic argument.
365    fn infer_generic_arg(
366        &mut self,
367        param: &GenericParam,
368        lookup_context: ImplLookupContext,
369        stable_ptr: Option<SyntaxStablePtrId>,
370    ) -> InferenceResult<GenericArgumentId> {
371        match param {
372            GenericParam::Type(_) => Ok(GenericArgumentId::Type(self.new_type_var(stable_ptr))),
373            GenericParam::Impl(param) => {
374                let concrete_trait_id = param
375                    .concrete_trait
376                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
377                let impl_id = self.new_impl_var(concrete_trait_id, stable_ptr, lookup_context);
378                for (trait_ty, ty1) in param.type_constraints.iter() {
379                    let ty0 = self.reduce_impl_ty(ImplTypeId::new(
380                        impl_id,
381                        trait_ty.trait_type(self.db),
382                        self.db,
383                    ))?;
384                    // Conforming the type will always work as the impl is a new inference variable.
385                    self.conform_ty(ty0, *ty1).ok();
386                }
387                Ok(GenericArgumentId::Impl(impl_id))
388            }
389            GenericParam::Const(GenericParamConst { ty, .. }) => {
390                Ok(GenericArgumentId::Constant(self.new_const_var(stable_ptr, *ty)))
391            }
392            GenericParam::NegImpl(_) => Ok(GenericArgumentId::NegImpl),
393        }
394    }
395
396    /// Infers the impl to be substituted instead of a trait for a given trait function,
397    /// and the generic arguments to be passed to the function.
398    /// Returns the resulting impl function.
399    fn infer_trait_function(
400        &mut self,
401        concrete_trait_function: ConcreteTraitGenericFunctionId,
402        lookup_context: &ImplLookupContext,
403        stable_ptr: Option<SyntaxStablePtrId>,
404    ) -> InferenceResult<FunctionId> {
405        let generic_function =
406            self.infer_trait_generic_function(concrete_trait_function, lookup_context, stable_ptr);
407        self.infer_generic_function(generic_function, lookup_context, stable_ptr)
408    }
409
410    /// Infers generic arguments to be passed to a generic function.
411    /// Returns the resulting specialized function.
412    fn infer_generic_function(
413        &mut self,
414        generic_function: GenericFunctionId,
415        lookup_context: &ImplLookupContext,
416        stable_ptr: Option<SyntaxStablePtrId>,
417    ) -> InferenceResult<FunctionId> {
418        let generic_params = generic_function
419            .generic_params(self.db)
420            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
421        let generic_args = self.infer_generic_args(&generic_params, lookup_context, stable_ptr)?;
422        Ok(FunctionLongId { function: ConcreteFunction { generic_function, generic_args } }
423            .intern(self.db))
424    }
425
426    /// Infers the impl to be substituted instead of a trait for a given trait function.
427    /// Returns the resulting impl generic function.
428    fn infer_trait_generic_function(
429        &mut self,
430        concrete_trait_function: ConcreteTraitGenericFunctionId,
431        lookup_context: &ImplLookupContext,
432        stable_ptr: Option<SyntaxStablePtrId>,
433    ) -> GenericFunctionId {
434        let impl_id = self.new_impl_var(
435            concrete_trait_function.concrete_trait(self.db),
436            stable_ptr,
437            lookup_context.clone(),
438        );
439        GenericFunctionId::Impl(ImplGenericFunctionId {
440            impl_id,
441            function: concrete_trait_function.trait_function(self.db),
442        })
443    }
444
445    /// Infers the impl to be substituted instead of a trait for a given trait type.
446    /// Returns the resulting impl type.
447    fn infer_trait_type(
448        &mut self,
449        concrete_trait_type: ConcreteTraitTypeId,
450        lookup_context: &ImplLookupContext,
451        stable_ptr: Option<SyntaxStablePtrId>,
452    ) -> TypeId {
453        let impl_id = self.new_impl_var(
454            concrete_trait_type.concrete_trait(self.db),
455            stable_ptr,
456            lookup_context.clone(),
457        );
458        TypeLongId::ImplType(ImplTypeId::new(
459            impl_id,
460            concrete_trait_type.trait_type(self.db),
461            self.db,
462        ))
463        .intern(self.db)
464    }
465
466    /// Infers the impl to be substituted instead of a trait for a given trait constant.
467    /// Returns the resulting impl constant.
468    fn infer_trait_constant(
469        &mut self,
470        concrete_trait_constant: ConcreteTraitConstantId,
471        lookup_context: &ImplLookupContext,
472        stable_ptr: Option<SyntaxStablePtrId>,
473    ) -> ImplConstantId {
474        let impl_id = self.new_impl_var(
475            concrete_trait_constant.concrete_trait(self.db),
476            stable_ptr,
477            lookup_context.clone(),
478        );
479
480        ImplConstantId::new(impl_id, concrete_trait_constant.trait_constant(self.db), self.db)
481    }
482
483    /// Infers the impl to be substituted instead of a trait for a given trait impl.
484    /// Returns the resulting impl impl.
485    fn infer_trait_impl(
486        &mut self,
487        concrete_trait_impl: ConcreteTraitImplId,
488        lookup_context: &ImplLookupContext,
489        stable_ptr: Option<SyntaxStablePtrId>,
490    ) -> ImplImplId {
491        let impl_id = self.new_impl_var(
492            concrete_trait_impl.concrete_trait(self.db),
493            stable_ptr,
494            lookup_context.clone(),
495        );
496
497        ImplImplId::new(impl_id, concrete_trait_impl.trait_impl(self.db), self.db)
498    }
499}