// cairo_lang_semantic/expr/inference/infers.rs

1use cairo_lang_defs::ids::{ImplAliasId, ImplDefId, TraitFunctionId};
2use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
3use cairo_lang_utils::{Intern, LookupIntern, extract_matches, require};
4use itertools::Itertools;
5
6use super::canonic::ResultNoErrEx;
7use super::conform::InferenceConform;
8use super::{Inference, InferenceError, InferenceResult};
9use crate::items::constant::ImplConstantId;
10use crate::items::functions::{GenericFunctionId, ImplGenericFunctionId};
11use crate::items::generics::GenericParamConst;
12use crate::items::imp::{
13    GeneratedImplLongId, ImplId, ImplImplId, ImplLongId, ImplLookupContext, UninferredImpl,
14};
15use crate::items::trt::{
16    ConcreteTraitConstantId, ConcreteTraitGenericFunctionId, ConcreteTraitImplId,
17    ConcreteTraitTypeId,
18};
19use crate::substitution::{GenericSubstitution, SemanticRewriter};
20use crate::types::ImplTypeId;
21use crate::{
22    ConcreteFunction, ConcreteImplLongId, ConcreteTraitId, ConcreteTraitLongId, FunctionId,
23    FunctionLongId, GenericArgumentId, GenericParam, TypeId, TypeLongId,
24};
25
26/// Functions for embedding generic semantic objects in an existing [Inference] object, by
27/// introducing new variables.
28pub trait InferenceEmbeddings {
29    fn infer_impl(
30        &mut self,
31        uninferred_impl: UninferredImpl,
32        concrete_trait_id: ConcreteTraitId,
33        lookup_context: &ImplLookupContext,
34        stable_ptr: Option<SyntaxStablePtrId>,
35    ) -> InferenceResult<ImplId>;
36    fn infer_impl_def(
37        &mut self,
38        impl_def_id: ImplDefId,
39        concrete_trait_id: ConcreteTraitId,
40        lookup_context: &ImplLookupContext,
41        stable_ptr: Option<SyntaxStablePtrId>,
42    ) -> InferenceResult<ImplId>;
43    fn infer_impl_alias(
44        &mut self,
45        impl_alias_id: ImplAliasId,
46        concrete_trait_id: ConcreteTraitId,
47        lookup_context: &ImplLookupContext,
48        stable_ptr: Option<SyntaxStablePtrId>,
49    ) -> InferenceResult<ImplId>;
50    fn infer_generic_assignment(
51        &mut self,
52        generic_params: &[GenericParam],
53        generic_args: &[GenericArgumentId],
54        expected_generic_args: &[GenericArgumentId],
55        lookup_context: &ImplLookupContext,
56        stable_ptr: Option<SyntaxStablePtrId>,
57    ) -> InferenceResult<Vec<GenericArgumentId>>;
58    fn infer_generic_args(
59        &mut self,
60        generic_params: &[GenericParam],
61        lookup_context: &ImplLookupContext,
62        stable_ptr: Option<SyntaxStablePtrId>,
63    ) -> InferenceResult<Vec<GenericArgumentId>>;
64    fn infer_concrete_trait_by_self(
65        &mut self,
66        trait_function: TraitFunctionId,
67        self_ty: TypeId,
68        lookup_context: &ImplLookupContext,
69        stable_ptr: Option<SyntaxStablePtrId>,
70        inference_error_cb: impl FnOnce(InferenceError),
71    ) -> Option<(ConcreteTraitId, usize)>;
72    fn infer_generic_arg(
73        &mut self,
74        param: &GenericParam,
75        lookup_context: ImplLookupContext,
76        stable_ptr: Option<SyntaxStablePtrId>,
77    ) -> InferenceResult<GenericArgumentId>;
78    fn infer_trait_function(
79        &mut self,
80        concrete_trait_function: ConcreteTraitGenericFunctionId,
81        lookup_context: &ImplLookupContext,
82        stable_ptr: Option<SyntaxStablePtrId>,
83    ) -> InferenceResult<FunctionId>;
84    fn infer_generic_function(
85        &mut self,
86        generic_function: GenericFunctionId,
87        lookup_context: &ImplLookupContext,
88        stable_ptr: Option<SyntaxStablePtrId>,
89    ) -> InferenceResult<FunctionId>;
90    fn infer_trait_generic_function(
91        &mut self,
92        concrete_trait_function: ConcreteTraitGenericFunctionId,
93        lookup_context: &ImplLookupContext,
94        stable_ptr: Option<SyntaxStablePtrId>,
95    ) -> ImplGenericFunctionId;
96    fn infer_trait_type(
97        &mut self,
98        concrete_trait_type: ConcreteTraitTypeId,
99        lookup_context: &ImplLookupContext,
100        stable_ptr: Option<SyntaxStablePtrId>,
101    ) -> TypeId;
102    fn infer_trait_constant(
103        &mut self,
104        concrete_trait_constant: ConcreteTraitConstantId,
105        lookup_context: &ImplLookupContext,
106        stable_ptr: Option<SyntaxStablePtrId>,
107    ) -> ImplConstantId;
108    fn infer_trait_impl(
109        &mut self,
110        concrete_trait_constant: ConcreteTraitImplId,
111        lookup_context: &ImplLookupContext,
112        stable_ptr: Option<SyntaxStablePtrId>,
113    ) -> ImplImplId;
114}
115
116impl InferenceEmbeddings for Inference<'_> {
117    /// Infers all the variables required to make an uninferred impl provide a concrete trait.
118    fn infer_impl(
119        &mut self,
120        uninferred_impl: UninferredImpl,
121        concrete_trait_id: ConcreteTraitId,
122        lookup_context: &ImplLookupContext,
123        stable_ptr: Option<SyntaxStablePtrId>,
124    ) -> InferenceResult<ImplId> {
125        let impl_id = match uninferred_impl {
126            UninferredImpl::Def(impl_def_id) => {
127                self.infer_impl_def(impl_def_id, concrete_trait_id, lookup_context, stable_ptr)?
128            }
129            UninferredImpl::ImplAlias(impl_alias_id) => {
130                self.infer_impl_alias(impl_alias_id, concrete_trait_id, lookup_context, stable_ptr)?
131            }
132            UninferredImpl::ImplImpl(impl_impl_id) => {
133                ImplLongId::ImplImpl(impl_impl_id).intern(self.db)
134            }
135            UninferredImpl::GenericParam(param_id) => {
136                let param = self
137                    .db
138                    .generic_param_semantic(param_id)
139                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
140                let param = extract_matches!(param, GenericParam::Impl);
141                let imp_concrete_trait_id = param.concrete_trait.unwrap();
142                self.conform_traits(concrete_trait_id, imp_concrete_trait_id)?;
143                ImplLongId::GenericParameter(param_id).intern(self.db)
144            }
145            UninferredImpl::GeneratedImpl(generated_impl) => {
146                let long_id = generated_impl.lookup_intern(self.db);
147
148                // Only making sure the args can be inferred - as they are unused later.
149                self.infer_generic_args(&long_id.generic_params[..], lookup_context, stable_ptr)?;
150
151                ImplLongId::GeneratedImpl(
152                    GeneratedImplLongId {
153                        concrete_trait: long_id.concrete_trait,
154                        generic_params: long_id.generic_params,
155                        impl_items: long_id.impl_items,
156                    }
157                    .intern(self.db),
158                )
159                .intern(self.db)
160            }
161        };
162        Ok(impl_id)
163    }
164
165    /// Infers all the variables required to make an impl (possibly with free generic params)
166    /// provide a concrete trait.
167    fn infer_impl_def(
168        &mut self,
169        impl_def_id: ImplDefId,
170        concrete_trait_id: ConcreteTraitId,
171        lookup_context: &ImplLookupContext,
172        stable_ptr: Option<SyntaxStablePtrId>,
173    ) -> InferenceResult<ImplId> {
174        let imp_generic_params = self
175            .db
176            .impl_def_generic_params(impl_def_id)
177            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
178        let imp_concrete_trait = self
179            .db
180            .impl_def_concrete_trait(impl_def_id)
181            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
182        if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
183            return Err(self.set_error(InferenceError::TraitMismatch {
184                trt0: imp_concrete_trait.trait_id(self.db),
185                trt1: concrete_trait_id.trait_id(self.db),
186            }));
187        }
188
189        let long_concrete_trait = concrete_trait_id.lookup_intern(self.db);
190        let long_imp_concrete_trait = imp_concrete_trait.lookup_intern(self.db);
191        let generic_args = self.infer_generic_assignment(
192            &imp_generic_params,
193            &long_imp_concrete_trait.generic_args,
194            &long_concrete_trait.generic_args,
195            lookup_context,
196            stable_ptr,
197        )?;
198        Ok(ImplLongId::Concrete(ConcreteImplLongId { impl_def_id, generic_args }.intern(self.db))
199            .intern(self.db))
200    }
201
202    /// Infers all the variables required to make an impl alias (possibly with free generic params)
203    /// provide a concrete trait.
204    fn infer_impl_alias(
205        &mut self,
206        impl_alias_id: ImplAliasId,
207        concrete_trait_id: ConcreteTraitId,
208        lookup_context: &ImplLookupContext,
209        stable_ptr: Option<SyntaxStablePtrId>,
210    ) -> InferenceResult<ImplId> {
211        let impl_alias_generic_params = self
212            .db
213            .impl_alias_generic_params(impl_alias_id)
214            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
215        let impl_id = self
216            .db
217            .impl_alias_resolved_impl(impl_alias_id)
218            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
219        let imp_concrete_trait = impl_id
220            .concrete_trait(self.db)
221            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
222        if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
223            return Err(self.set_error(InferenceError::TraitMismatch {
224                trt0: imp_concrete_trait.trait_id(self.db),
225                trt1: concrete_trait_id.trait_id(self.db),
226            }));
227        }
228
229        let long_concrete_trait = concrete_trait_id.lookup_intern(self.db);
230        let long_imp_concrete_trait = imp_concrete_trait.lookup_intern(self.db);
231        let generic_args = self.infer_generic_assignment(
232            &impl_alias_generic_params,
233            &long_imp_concrete_trait.generic_args,
234            &long_concrete_trait.generic_args,
235            lookup_context,
236            stable_ptr,
237        )?;
238
239        GenericSubstitution::new(&impl_alias_generic_params, &generic_args)
240            .substitute(self.db, impl_id)
241            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))
242    }
243
244    /// Chooses and assignment to generic_params s.t. generic_args will be substituted to
245    /// expected_generic_args.
246    /// Returns the generic_params assignment.
247    fn infer_generic_assignment(
248        &mut self,
249        generic_params: &[GenericParam],
250        generic_args: &[GenericArgumentId],
251        expected_generic_args: &[GenericArgumentId],
252        lookup_context: &ImplLookupContext,
253        stable_ptr: Option<SyntaxStablePtrId>,
254    ) -> InferenceResult<Vec<GenericArgumentId>> {
255        let new_generic_args =
256            self.infer_generic_args(generic_params, lookup_context, stable_ptr)?;
257        let substitution = GenericSubstitution::new(generic_params, &new_generic_args);
258        let generic_args = substitution
259            .substitute(self.db, generic_args.iter().copied().collect_vec())
260            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
261        self.conform_generic_args(&generic_args, expected_generic_args)?;
262        Ok(self.rewrite(new_generic_args).no_err())
263    }
264
265    /// Infers all generic_arguments given the parameters.
266    fn infer_generic_args(
267        &mut self,
268        generic_params: &[GenericParam],
269        lookup_context: &ImplLookupContext,
270        stable_ptr: Option<SyntaxStablePtrId>,
271    ) -> InferenceResult<Vec<GenericArgumentId>> {
272        let mut generic_args = vec![];
273        let mut substitution = GenericSubstitution::default();
274        for generic_param in generic_params {
275            let generic_param = substitution
276                .substitute(self.db, generic_param.clone())
277                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
278            let generic_arg =
279                self.infer_generic_arg(&generic_param, lookup_context.clone(), stable_ptr)?;
280            generic_args.push(generic_arg);
281            substitution.insert(generic_param.id(), generic_arg);
282        }
283        Ok(generic_args)
284    }
285
286    /// Tries to infer a trait function as a method for `self_ty`.
287    /// Supports snapshot snapshot coercions.
288    ///
289    /// Returns the deduced type and the number of snapshots that need to be added to it.
290    ///
291    /// `inference_error_cb` is called for inference errors, but they are not reported here as
292    /// diagnostics. The caller has to make sure the diagnostics are reported appropriately.
293    fn infer_concrete_trait_by_self(
294        &mut self,
295        trait_function: TraitFunctionId,
296        self_ty: TypeId,
297        lookup_context: &ImplLookupContext,
298        stable_ptr: Option<SyntaxStablePtrId>,
299        inference_error_cb: impl FnOnce(InferenceError),
300    ) -> Option<(ConcreteTraitId, usize)> {
301        let trait_id = trait_function.trait_id(self.db.upcast());
302        let signature = self.db.trait_function_signature(trait_function).ok()?;
303        let first_param = signature.params.into_iter().next()?;
304        require(first_param.name == "self")?;
305
306        let trait_generic_params = self.db.trait_generic_params(trait_id).ok()?;
307        let trait_generic_args =
308            match self.infer_generic_args(&trait_generic_params, lookup_context, stable_ptr) {
309                Ok(generic_args) => generic_args,
310                Err(err_set) => {
311                    if let Some(err) = self.consume_error_without_reporting(err_set) {
312                        inference_error_cb(err);
313                    }
314                    return None;
315                }
316            };
317
318        // TODO(yuval): Try to not temporary clone.
319        let mut tmp_inference_data = self.temporary_clone();
320        let mut tmp_inference = tmp_inference_data.inference(self.db);
321        let function_generic_params =
322            tmp_inference.db.trait_function_generic_params(trait_function).ok()?;
323        let function_generic_args =
324            // TODO(yuval): consider getting the substitution from inside `infer_generic_args`
325            // instead of creating it again here.
326            match tmp_inference.infer_generic_args(&function_generic_params, lookup_context, stable_ptr) {
327                Ok(generic_args) => generic_args,
328                Err(err_set) => {
329                    if let Some(err) = self.consume_error_without_reporting(err_set) {
330                        inference_error_cb(err);
331                    }
332                    return None;
333                }
334            };
335
336        let trait_substitution =
337            GenericSubstitution::new(&trait_generic_params, &trait_generic_args);
338        let function_substitution =
339            GenericSubstitution::new(&function_generic_params, &function_generic_args);
340        let substitution = trait_substitution.concat(function_substitution);
341
342        let fixed_param_ty = substitution.substitute(self.db, first_param.ty).ok()?;
343        let (_, n_snapshots) = match self.conform_ty_ex(self_ty, fixed_param_ty, true) {
344            Ok(conform) => conform,
345            Err(err_set) => {
346                if let Some(err) = self.consume_error_without_reporting(err_set) {
347                    inference_error_cb(err);
348                }
349                return None;
350            }
351        };
352
353        let generic_args = self.rewrite(trait_generic_args).no_err();
354
355        Some((ConcreteTraitLongId { trait_id, generic_args }.intern(self.db), n_snapshots))
356    }
357
358    /// Infers a generic argument to be passed as a generic parameter.
359    /// Allocates a new inference variable of the correct kind, and wraps in a generic argument.
360    fn infer_generic_arg(
361        &mut self,
362        param: &GenericParam,
363        lookup_context: ImplLookupContext,
364        stable_ptr: Option<SyntaxStablePtrId>,
365    ) -> InferenceResult<GenericArgumentId> {
366        match param {
367            GenericParam::Type(_) => Ok(GenericArgumentId::Type(self.new_type_var(stable_ptr))),
368            GenericParam::Impl(param) => {
369                let concrete_trait_id = param
370                    .concrete_trait
371                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
372                let impl_id = self.new_impl_var(concrete_trait_id, stable_ptr, lookup_context);
373                for (trait_ty, ty1) in param.type_constraints.iter() {
374                    let ty0 = self.reduce_impl_ty(ImplTypeId::new(impl_id, *trait_ty, self.db))?;
375                    // Conforming the type will always work as the impl is a new inference variable.
376                    self.conform_ty(ty0, *ty1).ok();
377                }
378                Ok(GenericArgumentId::Impl(impl_id))
379            }
380            GenericParam::Const(GenericParamConst { ty, .. }) => {
381                Ok(GenericArgumentId::Constant(self.new_const_var(stable_ptr, *ty)))
382            }
383            GenericParam::NegImpl(_) => Ok(GenericArgumentId::NegImpl),
384        }
385    }
386
387    /// Infers the impl to be substituted instead of a trait for a given trait function,
388    /// and the generic arguments to be passed to the function.
389    /// Returns the resulting impl function.
390    fn infer_trait_function(
391        &mut self,
392        concrete_trait_function: ConcreteTraitGenericFunctionId,
393        lookup_context: &ImplLookupContext,
394        stable_ptr: Option<SyntaxStablePtrId>,
395    ) -> InferenceResult<FunctionId> {
396        let generic_function = GenericFunctionId::Impl(self.infer_trait_generic_function(
397            concrete_trait_function,
398            lookup_context,
399            stable_ptr,
400        ));
401        self.infer_generic_function(generic_function, lookup_context, stable_ptr)
402    }
403
404    /// Infers generic arguments to be passed to a generic function.
405    /// Returns the resulting specialized function.
406    fn infer_generic_function(
407        &mut self,
408        generic_function: GenericFunctionId,
409        lookup_context: &ImplLookupContext,
410        stable_ptr: Option<SyntaxStablePtrId>,
411    ) -> InferenceResult<FunctionId> {
412        let generic_params = generic_function
413            .generic_params(self.db)
414            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
415        let generic_args = self.infer_generic_args(&generic_params, lookup_context, stable_ptr)?;
416        Ok(FunctionLongId { function: ConcreteFunction { generic_function, generic_args } }
417            .intern(self.db))
418    }
419
420    /// Infers the impl to be substituted instead of a trait for a given trait function.
421    /// Returns the resulting impl generic function.
422    fn infer_trait_generic_function(
423        &mut self,
424        concrete_trait_function: ConcreteTraitGenericFunctionId,
425        lookup_context: &ImplLookupContext,
426        stable_ptr: Option<SyntaxStablePtrId>,
427    ) -> ImplGenericFunctionId {
428        let impl_id = self.new_impl_var(
429            concrete_trait_function.concrete_trait(self.db),
430            stable_ptr,
431            lookup_context.clone(),
432        );
433        ImplGenericFunctionId { impl_id, function: concrete_trait_function.trait_function(self.db) }
434    }
435
436    /// Infers the impl to be substituted instead of a trait for a given trait type.
437    /// Returns the resulting impl type.
438    fn infer_trait_type(
439        &mut self,
440        concrete_trait_type: ConcreteTraitTypeId,
441        lookup_context: &ImplLookupContext,
442        stable_ptr: Option<SyntaxStablePtrId>,
443    ) -> TypeId {
444        let impl_id = self.new_impl_var(
445            concrete_trait_type.concrete_trait(self.db),
446            stable_ptr,
447            lookup_context.clone(),
448        );
449        TypeLongId::ImplType(ImplTypeId::new(
450            impl_id,
451            concrete_trait_type.trait_type(self.db),
452            self.db,
453        ))
454        .intern(self.db)
455    }
456
457    /// Infers the impl to be substituted instead of a trait for a given trait constant.
458    /// Returns the resulting impl constant.
459    fn infer_trait_constant(
460        &mut self,
461        concrete_trait_constant: ConcreteTraitConstantId,
462        lookup_context: &ImplLookupContext,
463        stable_ptr: Option<SyntaxStablePtrId>,
464    ) -> ImplConstantId {
465        let impl_id = self.new_impl_var(
466            concrete_trait_constant.concrete_trait(self.db),
467            stable_ptr,
468            lookup_context.clone(),
469        );
470
471        ImplConstantId::new(impl_id, concrete_trait_constant.trait_constant(self.db), self.db)
472    }
473
474    /// Infers the impl to be substituted instead of a trait for a given trait impl.
475    /// Returns the resulting impl impl.
476    fn infer_trait_impl(
477        &mut self,
478        concrete_trait_impl: ConcreteTraitImplId,
479        lookup_context: &ImplLookupContext,
480        stable_ptr: Option<SyntaxStablePtrId>,
481    ) -> ImplImplId {
482        let impl_id = self.new_impl_var(
483            concrete_trait_impl.concrete_trait(self.db),
484            stable_ptr,
485            lookup_context.clone(),
486        );
487
488        ImplImplId::new(impl_id, concrete_trait_impl.trait_impl(self.db), self.db)
489    }
490}