cairo_lang_semantic/expr/inference/conform.rs

1use std::hash::Hash;
2
3use cairo_lang_defs::ids::{TraitConstantId, TraitTypeId};
4use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
5use cairo_lang_utils::ordered_hash_map::{Entry, OrderedHashMap};
6use cairo_lang_utils::{Intern, LookupIntern};
7use itertools::zip_eq;
8
9use super::canonic::{NoError, ResultNoErrEx};
10use super::{
11    ErrorSet, ImplVarId, ImplVarTraitItemMappings, Inference, InferenceError, InferenceResult,
12    InferenceVar, LocalTypeVarId, TypeVar,
13};
14use crate::corelib::never_ty;
15use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
16use crate::items::functions::{GenericFunctionId, ImplGenericFunctionId};
17use crate::items::imp::{ImplId, ImplImplId, ImplLongId, ImplLookupContext};
18use crate::items::trt::ConcreteTraitImplId;
19use crate::substitution::SemanticRewriter;
20use crate::types::{ClosureTypeLongId, ImplTypeId, peel_snapshots};
21use crate::{
22    ConcreteFunction, ConcreteImplLongId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId,
23    FunctionId, FunctionLongId, GenericArgumentId, TypeId, TypeLongId,
24};
25
26/// Functions for conforming semantic objects with each other.
27pub trait InferenceConform {
28    fn conform_ty(&mut self, ty0: TypeId, ty1: TypeId) -> InferenceResult<TypeId>;
29    fn conform_ty_ex(
30        &mut self,
31        ty0: TypeId,
32        ty1: TypeId,
33        ty0_is_self: bool,
34    ) -> InferenceResult<(TypeId, usize)>;
35    fn conform_const(
36        &mut self,
37        ty0: ConstValueId,
38        ty1: ConstValueId,
39    ) -> InferenceResult<ConstValueId>;
40    fn maybe_peel_snapshots(&mut self, ty0_is_self: bool, ty1: TypeId) -> (usize, TypeLongId);
41    fn conform_generic_args(
42        &mut self,
43        gargs0: &[GenericArgumentId],
44        gargs1: &[GenericArgumentId],
45    ) -> InferenceResult<Vec<GenericArgumentId>>;
46    fn conform_generic_arg(
47        &mut self,
48        garg0: GenericArgumentId,
49        garg1: GenericArgumentId,
50    ) -> InferenceResult<GenericArgumentId>;
51    fn conform_impl(&mut self, impl0: ImplId, impl1: ImplId) -> InferenceResult<ImplId>;
52    fn conform_traits(
53        &mut self,
54        trt0: ConcreteTraitId,
55        trt1: ConcreteTraitId,
56    ) -> InferenceResult<ConcreteTraitId>;
57    fn conform_generic_function(
58        &mut self,
59        trt0: GenericFunctionId,
60        trt1: GenericFunctionId,
61    ) -> InferenceResult<GenericFunctionId>;
62    fn ty_contains_var(&mut self, ty: TypeId, var: InferenceVar) -> bool;
63    fn generic_args_contain_var(
64        &mut self,
65        generic_args: &[GenericArgumentId],
66        var: InferenceVar,
67    ) -> bool;
68    fn impl_contains_var(&mut self, impl_id: ImplId, var: InferenceVar) -> bool;
69    fn function_contains_var(&mut self, function_id: FunctionId, var: InferenceVar) -> bool;
70}
71
72impl InferenceConform for Inference<'_> {
73    /// Conforms ty0 to ty1. Should be called when ty0 should be coerced to ty1. Not symmetric.
74    /// Returns the reduced type for ty0, or an error if the type is no coercible.
75    fn conform_ty(&mut self, ty0: TypeId, ty1: TypeId) -> InferenceResult<TypeId> {
76        Ok(self.conform_ty_ex(ty0, ty1, false)?.0)
77    }
78
79    /// Same as conform_ty but supports adding snapshots to ty0 if `ty0_is_self` is true.
80    /// Returns the reduced type for ty0 and the number of snapshots that needs to be added
81    /// for the types to conform.
82    fn conform_ty_ex(
83        &mut self,
84        ty0: TypeId,
85        ty1: TypeId,
86        ty0_is_self: bool,
87    ) -> InferenceResult<(TypeId, usize)> {
88        let ty0 = self.rewrite(ty0).no_err();
89        let ty1 = self.rewrite(ty1).no_err();
90        if ty0 == never_ty(self.db) || ty0.is_missing(self.db) {
91            return Ok((ty1, 0));
92        }
93        if ty0 == ty1 {
94            return Ok((ty0, 0));
95        }
96        let long_ty1 = ty1.lookup_intern(self.db);
97        match long_ty1 {
98            TypeLongId::Var(var) => return Ok((self.assign_ty(var, ty0)?, 0)),
99            TypeLongId::Missing(_) => return Ok((ty1, 0)),
100            TypeLongId::Snapshot(inner_ty) => {
101                if ty0_is_self {
102                    if inner_ty == ty0 {
103                        return Ok((ty1, 1));
104                    }
105                    if !matches!(ty0.lookup_intern(self.db), TypeLongId::Snapshot(_)) {
106                        if let TypeLongId::Var(var) = inner_ty.lookup_intern(self.db) {
107                            return Ok((self.assign_ty(var, ty0)?, 1));
108                        }
109                    }
110                }
111            }
112            TypeLongId::ImplType(impl_type) => {
113                if let Some(ty) = self.impl_type_bounds.get(&impl_type.into()) {
114                    return self.conform_ty_ex(ty0, *ty, ty0_is_self);
115                }
116            }
117            _ => {}
118        }
119        let long_ty0 = ty0.lookup_intern(self.db);
120
121        match long_ty0 {
122            TypeLongId::Concrete(concrete0) => {
123                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
124                let TypeLongId::Concrete(concrete1) = long_ty1 else {
125                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
126                };
127                if concrete0.generic_type(self.db) != concrete1.generic_type(self.db) {
128                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
129                }
130                let gargs0 = concrete0.generic_args(self.db);
131                let gargs1 = concrete1.generic_args(self.db);
132                let gargs = self.conform_generic_args(&gargs0, &gargs1)?;
133                let long_ty = TypeLongId::Concrete(ConcreteTypeId::new(
134                    self.db,
135                    concrete0.generic_type(self.db),
136                    gargs,
137                ));
138                Ok((long_ty.intern(self.db), n_snapshots))
139            }
140            TypeLongId::Tuple(tys0) => {
141                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
142                let TypeLongId::Tuple(tys1) = long_ty1 else {
143                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
144                };
145                if tys0.len() != tys1.len() {
146                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
147                }
148                let tys = zip_eq(tys0, tys1)
149                    .map(|(subty0, subty1)| self.conform_ty(subty0, subty1))
150                    .collect::<Result<Vec<_>, _>>()?;
151                Ok((TypeLongId::Tuple(tys).intern(self.db), n_snapshots))
152            }
153            TypeLongId::Closure(closure0) => {
154                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
155                let TypeLongId::Closure(closure1) = long_ty1 else {
156                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
157                };
158                if closure0.wrapper_location != closure1.wrapper_location {
159                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
160                }
161                let param_tys = zip_eq(closure0.param_tys, closure1.param_tys)
162                    .map(|(subty0, subty1)| self.conform_ty(subty0, subty1))
163                    .collect::<Result<Vec<_>, _>>()?;
164                let captured_types = zip_eq(closure0.captured_types, closure1.captured_types)
165                    .map(|(subty0, subty1)| self.conform_ty(subty0, subty1))
166                    .collect::<Result<Vec<_>, _>>()?;
167                let ret_ty = self.conform_ty(closure0.ret_ty, closure1.ret_ty)?;
168                Ok((
169                    TypeLongId::Closure(ClosureTypeLongId {
170                        param_tys,
171                        ret_ty,
172                        captured_types,
173                        wrapper_location: closure0.wrapper_location,
174                        parent_function: closure0.parent_function,
175                    })
176                    .intern(self.db),
177                    n_snapshots,
178                ))
179            }
180            TypeLongId::FixedSizeArray { type_id, size } => {
181                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
182                let TypeLongId::FixedSizeArray { type_id: type_id1, size: size1 } = long_ty1 else {
183                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
184                };
185                let size = self.conform_const(size, size1)?;
186                let ty = self.conform_ty(type_id, type_id1)?;
187                Ok((TypeLongId::FixedSizeArray { type_id: ty, size }.intern(self.db), n_snapshots))
188            }
189            TypeLongId::Snapshot(inner_ty0) => {
190                let TypeLongId::Snapshot(inner_ty1) = long_ty1 else {
191                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
192                };
193                let (ty, n_snapshots) = self.conform_ty_ex(inner_ty0, inner_ty1, ty0_is_self)?;
194                Ok((TypeLongId::Snapshot(ty).intern(self.db), n_snapshots))
195            }
196            TypeLongId::GenericParameter(_) => {
197                Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }))
198            }
199            TypeLongId::TraitType(_) => {
200                // This should never happen as the trait type should be implized when conformed, but
201                // don't panic in case of a bug.
202                Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }))
203            }
204            TypeLongId::Var(var) => Ok((self.assign_ty(var, ty1)?, 0)),
205            TypeLongId::ImplType(impl_type) => {
206                if let Some(ty) = self.impl_type_bounds.get(&impl_type.into()) {
207                    return self.conform_ty_ex(*ty, ty1, ty0_is_self);
208                }
209                Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }))
210            }
211            TypeLongId::Missing(_) => Ok((ty0, 0)),
212            TypeLongId::Coupon(function_id0) => {
213                let TypeLongId::Coupon(function_id1) = long_ty1 else {
214                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
215                };
216
217                let func0 = function_id0.lookup_intern(self.db).function;
218                let func1 = function_id1.lookup_intern(self.db).function;
219
220                let generic_function =
221                    self.conform_generic_function(func0.generic_function, func1.generic_function)?;
222
223                if func0.generic_args.len() != func1.generic_args.len() {
224                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
225                }
226
227                let generic_args =
228                    self.conform_generic_args(&func0.generic_args, &func1.generic_args)?;
229
230                Ok((
231                    TypeLongId::Coupon(
232                        FunctionLongId {
233                            function: ConcreteFunction { generic_function, generic_args },
234                        }
235                        .intern(self.db),
236                    )
237                    .intern(self.db),
238                    0,
239                ))
240            }
241        }
242    }
243
244    /// Conforms id0 to id1. Should be called when id0 should be coerced to id1. Not symmetric.
245    /// Returns the reduced const for id0, or an error if the const is no coercible.
246    fn conform_const(
247        &mut self,
248        id0: ConstValueId,
249        id1: ConstValueId,
250    ) -> InferenceResult<ConstValueId> {
251        let id0 = self.rewrite(id0).no_err();
252        let id1 = self.rewrite(id1).no_err();
253        self.conform_ty(id0.ty(self.db).unwrap(), id1.ty(self.db).unwrap())?;
254        if id0 == id1 {
255            return Ok(id0);
256        }
257        let const_value0 = id0.lookup_intern(self.db);
258        if matches!(const_value0, ConstValue::Missing(_)) {
259            return Ok(id1);
260        }
261        match id1.lookup_intern(self.db) {
262            ConstValue::Missing(_) => return Ok(id1),
263            ConstValue::Var(var, _) => return self.assign_const(var, id0),
264            _ => {}
265        }
266        match const_value0 {
267            ConstValue::Var(var, _) => Ok(self.assign_const(var, id1)?),
268            ConstValue::ImplConstant(_) => {
269                Err(self.set_error(InferenceError::ConstKindMismatch { const0: id0, const1: id1 }))
270            }
271            _ => {
272                Err(self.set_error(InferenceError::ConstKindMismatch { const0: id0, const1: id1 }))
273            }
274        }
275    }
276
277    // Conditionally peels snapshots.
278    fn maybe_peel_snapshots(&mut self, ty0_is_self: bool, ty1: TypeId) -> (usize, TypeLongId) {
279        let (n_snapshots, long_ty1) = if ty0_is_self {
280            peel_snapshots(self.db, ty1)
281        } else {
282            (0, ty1.lookup_intern(self.db))
283        };
284        (n_snapshots, long_ty1)
285    }
286
287    /// Conforms generics args. See `conform_ty()`.
288    fn conform_generic_args(
289        &mut self,
290        gargs0: &[GenericArgumentId],
291        gargs1: &[GenericArgumentId],
292    ) -> InferenceResult<Vec<GenericArgumentId>> {
293        zip_eq(gargs0, gargs1)
294            .map(|(garg0, garg1)| self.conform_generic_arg(*garg0, *garg1))
295            .collect::<Result<Vec<_>, _>>()
296    }
297
298    /// Conforms a generics arg. See `conform_ty()`.
299    fn conform_generic_arg(
300        &mut self,
301        garg0: GenericArgumentId,
302        garg1: GenericArgumentId,
303    ) -> InferenceResult<GenericArgumentId> {
304        if garg0 == garg1 {
305            return Ok(garg0);
306        }
307        match garg0 {
308            GenericArgumentId::Type(gty0) => {
309                let GenericArgumentId::Type(gty1) = garg1 else {
310                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
311                };
312                Ok(GenericArgumentId::Type(self.conform_ty(gty0, gty1)?))
313            }
314            GenericArgumentId::Constant(gc0) => {
315                let GenericArgumentId::Constant(gc1) = garg1 else {
316                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
317                };
318
319                Ok(GenericArgumentId::Constant(self.conform_const(gc0, gc1)?))
320            }
321            GenericArgumentId::Impl(impl0) => {
322                let GenericArgumentId::Impl(impl1) = garg1 else {
323                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
324                };
325                Ok(GenericArgumentId::Impl(self.conform_impl(impl0, impl1)?))
326            }
327            GenericArgumentId::NegImpl => match garg1 {
328                GenericArgumentId::NegImpl => Ok(GenericArgumentId::NegImpl),
329                GenericArgumentId::Constant(_)
330                | GenericArgumentId::Type(_)
331                | GenericArgumentId::Impl(_) => {
332                    Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }))
333                }
334            },
335        }
336    }
337
338    /// Conforms an impl. See `conform_ty()`.
339    fn conform_impl(&mut self, impl0: ImplId, impl1: ImplId) -> InferenceResult<ImplId> {
340        let impl0 = self.rewrite(impl0).no_err();
341        let impl1 = self.rewrite(impl1).no_err();
342        let long_impl1 = impl1.lookup_intern(self.db);
343        if impl0 == impl1 {
344            return Ok(impl0);
345        }
346        if let ImplLongId::ImplVar(var) = long_impl1 {
347            let impl_concrete_trait = self
348                .db
349                .impl_concrete_trait(impl0)
350                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
351            self.conform_traits(var.lookup_intern(self.db).concrete_trait_id, impl_concrete_trait)?;
352            let impl_id = self.rewrite(impl0).no_err();
353            return self.assign_impl(var, impl_id);
354        }
355        match impl0.lookup_intern(self.db) {
356            ImplLongId::ImplVar(var) => {
357                let impl_concrete_trait = self
358                    .db
359                    .impl_concrete_trait(impl1)
360                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
361                self.conform_traits(
362                    var.lookup_intern(self.db).concrete_trait_id,
363                    impl_concrete_trait,
364                )?;
365                let impl_id = self.rewrite(impl1).no_err();
366                self.assign_impl(var, impl_id)
367            }
368            ImplLongId::Concrete(concrete0) => {
369                let ImplLongId::Concrete(concrete1) = long_impl1 else {
370                    return Err(self.set_error(InferenceError::ImplKindMismatch { impl0, impl1 }));
371                };
372                let concrete0 = concrete0.lookup_intern(self.db);
373                let concrete1 = concrete1.lookup_intern(self.db);
374                if concrete0.impl_def_id != concrete1.impl_def_id {
375                    return Err(self.set_error(InferenceError::ImplKindMismatch { impl0, impl1 }));
376                }
377                let gargs0 = concrete0.generic_args;
378                let gargs1 = concrete1.generic_args;
379                let generic_args = self.conform_generic_args(&gargs0, &gargs1)?;
380                Ok(ImplLongId::Concrete(
381                    ConcreteImplLongId { impl_def_id: concrete0.impl_def_id, generic_args }
382                        .intern(self.db),
383                )
384                .intern(self.db))
385            }
386            ImplLongId::GenericParameter(_)
387            | ImplLongId::ImplImpl(_)
388            | ImplLongId::TraitImpl(_)
389            | ImplLongId::GeneratedImpl(_) => {
390                Err(self.set_error(InferenceError::ImplKindMismatch { impl0, impl1 }))
391            }
392        }
393    }
394
395    /// Conforms generics traits. See `conform_ty()`.
396    fn conform_traits(
397        &mut self,
398        trt0: ConcreteTraitId,
399        trt1: ConcreteTraitId,
400    ) -> InferenceResult<ConcreteTraitId> {
401        let trt0 = trt0.lookup_intern(self.db);
402        let trt1 = trt1.lookup_intern(self.db);
403        if trt0.trait_id != trt1.trait_id {
404            return Err(self.set_error(InferenceError::TraitMismatch {
405                trt0: trt0.trait_id,
406                trt1: trt1.trait_id,
407            }));
408        }
409        let generic_args = self.conform_generic_args(&trt0.generic_args, &trt1.generic_args)?;
410        Ok(ConcreteTraitLongId { trait_id: trt0.trait_id, generic_args }.intern(self.db))
411    }
412
413    fn conform_generic_function(
414        &mut self,
415        func0: GenericFunctionId,
416        func1: GenericFunctionId,
417    ) -> InferenceResult<GenericFunctionId> {
418        if let (GenericFunctionId::Impl(id0), GenericFunctionId::Impl(id1)) = (func0, func1) {
419            if id0.function != id1.function {
420                return Err(
421                    self.set_error(InferenceError::GenericFunctionMismatch { func0, func1 })
422                );
423            }
424            let function = id0.function;
425            let impl_id = self.conform_impl(id0.impl_id, id1.impl_id)?;
426            return Ok(GenericFunctionId::Impl(ImplGenericFunctionId { impl_id, function }));
427        }
428
429        if func0 != func1 {
430            return Err(self.set_error(InferenceError::GenericFunctionMismatch { func0, func1 }));
431        }
432        Ok(func0)
433    }
434
435    /// Checks if a type tree contains a certain [InferenceVar] somewhere. Used to avoid inference
436    /// cycles.
437    fn ty_contains_var(&mut self, ty: TypeId, var: InferenceVar) -> bool {
438        let ty = self.rewrite(ty).no_err();
439        self.internal_ty_contains_var(ty, var)
440    }
441
442    /// Checks if a slice of generics arguments contain a certain [InferenceVar] somewhere. Used to
443    /// avoid inference cycles.
444    fn generic_args_contain_var(
445        &mut self,
446        generic_args: &[GenericArgumentId],
447        var: InferenceVar,
448    ) -> bool {
449        for garg in generic_args {
450            if match *garg {
451                GenericArgumentId::Type(ty) => self.internal_ty_contains_var(ty, var),
452                GenericArgumentId::Constant(_) => false,
453                GenericArgumentId::Impl(impl_id) => self.impl_contains_var(impl_id, var),
454                GenericArgumentId::NegImpl => false,
455            } {
456                return true;
457            }
458        }
459        false
460    }
461
462    /// Checks if an impl contains a certain [InferenceVar] somewhere. Used to avoid inference
463    /// cycles.
464    fn impl_contains_var(&mut self, impl_id: ImplId, var: InferenceVar) -> bool {
465        match impl_id.lookup_intern(self.db) {
466            ImplLongId::Concrete(concrete_impl_id) => self.generic_args_contain_var(
467                &concrete_impl_id.lookup_intern(self.db).generic_args,
468                var,
469            ),
470            ImplLongId::GenericParameter(_) | ImplLongId::TraitImpl(_) => false,
471            ImplLongId::ImplVar(new_var) => {
472                let new_var_long_id = new_var.lookup_intern(self.db);
473                let new_var_local_id = new_var_long_id.id;
474                if InferenceVar::Impl(new_var_local_id) == var {
475                    return true;
476                }
477                if let Some(impl_id) = self.impl_assignment(new_var_local_id) {
478                    return self.impl_contains_var(impl_id, var);
479                }
480                self.generic_args_contain_var(
481                    &new_var_long_id.concrete_trait_id.generic_args(self.db),
482                    var,
483                )
484            }
485            ImplLongId::ImplImpl(impl_impl) => self.impl_contains_var(impl_impl.impl_id(), var),
486            ImplLongId::GeneratedImpl(generated_impl) => self.generic_args_contain_var(
487                &generated_impl.concrete_trait(self.db).generic_args(self.db),
488                var,
489            ),
490        }
491    }
492
493    /// Checks if a function contains a certain [InferenceVar] in its generic arguments or in the
494    /// generic arguments of the impl containing the function (in case the function is an impl
495    /// function).
496    ///
497    /// Used to avoid inference cycles.
498    fn function_contains_var(&mut self, function_id: FunctionId, var: InferenceVar) -> bool {
499        let function = function_id.get_concrete(self.db);
500        let generic_args = function.generic_args;
501        // Look in the generic arguments of the function and in the impl generic arguments.
502        self.generic_args_contain_var(&generic_args, var)
503            || matches!(function.generic_function,
504                GenericFunctionId::Impl(impl_generic_function_id)
505                if self.impl_contains_var(impl_generic_function_id.impl_id, var)
506            )
507    }
508}
509
510impl Inference<'_> {
511    /// Reduces an impl type to a concrete type.
512    pub fn reduce_impl_ty(&mut self, impl_type_id: ImplTypeId) -> InferenceResult<TypeId> {
513        let impl_id = impl_type_id.impl_id();
514        let trait_ty = impl_type_id.ty();
515        if let ImplLongId::ImplVar(var) = impl_id.lookup_intern(self.db) {
516            Ok(self.rewritten_impl_type(var, trait_ty))
517        } else if let Ok(ty) =
518            self.db.impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_ty, self.db))
519        {
520            Ok(ty)
521        } else {
522            Err(self.set_impl_reduction_error(impl_id))
523        }
524    }
525
526    /// Reduces an impl constant to a concrete const.
527    pub fn reduce_impl_constant(
528        &mut self,
529        impl_const_id: ImplConstantId,
530    ) -> InferenceResult<ConstValueId> {
531        let impl_id = impl_const_id.impl_id();
532        let trait_constant = impl_const_id.trait_constant_id();
533        if let ImplLongId::ImplVar(var) = impl_id.lookup_intern(self.db) {
534            Ok(self.rewritten_impl_constant(var, trait_constant))
535        } else if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
536            ImplConstantId::new(impl_id, trait_constant, self.db),
537        ) {
538            Ok(constant)
539        } else {
540            Err(self.set_impl_reduction_error(impl_id))
541        }
542    }
543
544    /// Reduces an impl impl to a concrete impl.
545    pub fn reduce_impl_impl(&mut self, impl_impl_id: ImplImplId) -> InferenceResult<ImplId> {
546        let impl_id = impl_impl_id.impl_id();
547        let concrete_trait_impl = impl_impl_id
548            .concrete_trait_impl_id(self.db)
549            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
550
551        if let ImplLongId::ImplVar(var) = impl_id.lookup_intern(self.db) {
552            Ok(self.rewritten_impl_impl(var, concrete_trait_impl))
553        } else if let Ok(imp) = self.db.impl_impl_concrete_implized(ImplImplId::new(
554            impl_id,
555            impl_impl_id.trait_impl_id(),
556            self.db,
557        )) {
558            Ok(imp)
559        } else {
560            Err(self.set_impl_reduction_error(impl_id))
561        }
562    }
563
564    /// Returns the type of an impl var's type item.
565    /// The type may be a variable itself, but it may previously exist, so may be more specific due
566    /// to rewriting.
567    pub fn rewritten_impl_type(&mut self, id: ImplVarId, trait_type_id: TraitTypeId) -> TypeId {
568        self.rewritten_impl_item(
569            id,
570            trait_type_id,
571            |m| &mut m.types,
572            |inference, stable_ptr| inference.new_type_var(stable_ptr),
573        )
574    }
575
576    /// Returns the constant value of an impl var's constant item.
577    /// The constant may be a variable itself, but it may previously exist, so may be more specific
578    /// due to rewriting.
579    pub fn rewritten_impl_constant(
580        &mut self,
581        id: ImplVarId,
582        trait_constant: TraitConstantId,
583    ) -> ConstValueId {
584        self.rewritten_impl_item(
585            id,
586            trait_constant,
587            |m| &mut m.constants,
588            |inference, stable_ptr| {
589                inference.new_const_var(
590                    stable_ptr,
591                    inference.db.trait_constant_type(trait_constant).unwrap(),
592                )
593            },
594        )
595    }
596
597    /// Returns the inner_impl value of an impl var's impl item.
598    /// The inner_impl may be a variable itself, but it may previously exist, so may be more
599    /// specific due to rewriting.
600    pub fn rewritten_impl_impl(
601        &mut self,
602        id: ImplVarId,
603        concrete_trait_impl: ConcreteTraitImplId,
604    ) -> ImplId {
605        self.rewritten_impl_item(
606            id,
607            concrete_trait_impl.trait_impl(self.db),
608            |m| &mut m.impls,
609            |inference, stable_ptr| {
610                inference.new_impl_var(
611                    inference.db.concrete_trait_impl_concrete_trait(concrete_trait_impl).unwrap(),
612                    stable_ptr,
613                    ImplLookupContext::default(),
614                )
615            },
616        )
617    }
618
619    /// Helper function for getting an impl vars item ids.
620    /// These ids are likely to be variables, but may have more specific information due to
621    /// rewriting.
622    fn rewritten_impl_item<K: Hash + PartialEq + Eq, V: Copy>(
623        &mut self,
624        id: ImplVarId,
625        key: K,
626        get_map: impl Fn(&mut ImplVarTraitItemMappings) -> &mut OrderedHashMap<K, V>,
627        new_var: impl FnOnce(&mut Self, Option<SyntaxStablePtrId>) -> V,
628    ) -> V
629    where
630        Self: SemanticRewriter<V, NoError>,
631    {
632        let var_id = id.id(self.db);
633        if let Some(value) = self
634            .data
635            .impl_vars_trait_item_mappings
636            .get_mut(&var_id)
637            .and_then(|mappings| get_map(mappings).get(&key))
638        {
639            // Copy the value to allow usage of `self`.
640            let value = *value;
641            // If the value already exists, rewrite it before returning.
642            self.rewrite(value).no_err()
643        } else {
644            let value =
645                new_var(self, self.data.stable_ptrs.get(&InferenceVar::Impl(var_id)).cloned());
646            get_map(self.data.impl_vars_trait_item_mappings.entry(var_id).or_default())
647                .insert(key, value);
648            value
649        }
650    }
651
652    /// Sets an error for an impl reduction failure.
653    fn set_impl_reduction_error(&mut self, impl_id: ImplId) -> ErrorSet {
654        self.set_error(
655            impl_id
656                .concrete_trait(self.db)
657                .map(InferenceError::NoImplsFound)
658                .unwrap_or_else(InferenceError::Reported),
659        )
660    }
661
662    /// Conforms a type to a type. Returning the reduced types on failure.
663    /// Useful for immediately reporting a diagnostic based on the compared types.
664    pub fn conform_ty_for_diag(
665        &mut self,
666        ty0: TypeId,
667        ty1: TypeId,
668    ) -> Result<(), (ErrorSet, TypeId, TypeId)> {
669        match self.conform_ty(ty0, ty1) {
670            Ok(_ty) => Ok(()),
671            Err(err) => Err((err, self.rewrite(ty0).no_err(), self.rewrite(ty1).no_err())),
672        }
673    }
674
675    /// helper function for ty_contains_var
676    /// Assumes ty was already rewritten.
677    #[doc(hidden)]
678    fn internal_ty_contains_var(&mut self, ty: TypeId, var: InferenceVar) -> bool {
679        match ty.lookup_intern(self.db) {
680            TypeLongId::Concrete(concrete) => {
681                let generic_args = concrete.generic_args(self.db);
682                self.generic_args_contain_var(&generic_args, var)
683            }
684            TypeLongId::Tuple(tys) => {
685                tys.into_iter().any(|ty| self.internal_ty_contains_var(ty, var))
686            }
687            TypeLongId::Snapshot(ty) => self.internal_ty_contains_var(ty, var),
688            TypeLongId::Var(new_var) => {
689                if InferenceVar::Type(new_var.id) == var {
690                    return true;
691                }
692                if let Some(ty) = self.type_assignment.get(&new_var.id) {
693                    return self.internal_ty_contains_var(*ty, var);
694                }
695                false
696            }
697            TypeLongId::ImplType(id) => self.impl_contains_var(id.impl_id(), var),
698            TypeLongId::TraitType(_) | TypeLongId::GenericParameter(_) | TypeLongId::Missing(_) => {
699                false
700            }
701            TypeLongId::Coupon(function_id) => self.function_contains_var(function_id, var),
702            TypeLongId::FixedSizeArray { type_id, .. } => {
703                self.internal_ty_contains_var(type_id, var)
704            }
705            TypeLongId::Closure(closure) => {
706                closure.param_tys.into_iter().any(|ty| self.internal_ty_contains_var(ty, var))
707                    || self.internal_ty_contains_var(closure.ret_ty, var)
708            }
709        }
710    }
711
712    /// Creates a var for each constrained impl_type and conforms the types.
713    pub fn conform_generic_params_type_constraints(&mut self, constraints: &Vec<(TypeId, TypeId)>) {
714        let mut impl_type_bounds = Default::default();
715        for (ty0, ty1) in constraints {
716            let ty0 = if let TypeLongId::ImplType(impl_type) = ty0.lookup_intern(self.db) {
717                self.impl_type_assignment(impl_type, &mut impl_type_bounds)
718            } else {
719                *ty0
720            };
721            let ty1 = if let TypeLongId::ImplType(impl_type) = ty1.lookup_intern(self.db) {
722                self.impl_type_assignment(impl_type, &mut impl_type_bounds)
723            } else {
724                *ty1
725            };
726            self.conform_ty(ty0, ty1).ok();
727        }
728        self.set_impl_type_bounds(impl_type_bounds);
729    }
730
731    /// An helper function for getting for an impl type assignment.
732    /// Creates a new type var if the impl type is not yet assigned.
733    fn impl_type_assignment(
734        &mut self,
735        impl_type: ImplTypeId,
736        impl_type_bounds: &mut OrderedHashMap<ImplTypeId, TypeId>,
737    ) -> TypeId {
738        match impl_type_bounds.entry(impl_type) {
739            Entry::Occupied(entry) => *entry.get(),
740            Entry::Vacant(entry) => {
741                let inference_id = self.data.inference_id;
742                let id = LocalTypeVarId(self.data.type_vars.len());
743                let var = TypeVar { inference_id, id };
744                let ty = TypeLongId::Var(var).intern(self.db);
745                entry.insert(ty);
746                self.type_vars.push(var);
747                ty
748            }
749        }
750    }
751}