// cranelift_codegen_meta/cdsl/typevar.rs

1use std::cell::RefCell;
2use std::collections::BTreeSet;
3use std::fmt;
4use std::hash;
5use std::ops;
6use std::rc::Rc;
7
8use crate::cdsl::types::{LaneType, ValueType};
9
/// Largest fixed lane count a SIMD type set may contain (inclusive).
const MAX_LANES: u16 = 256;
/// Widest integer lane type, in bits (inclusive).
const MAX_BITS: u16 = 128;
/// Widest floating-point lane type, in bits (inclusive).
const MAX_FLOAT_BITS: u16 = 128;
13
14/// Type variables can be used in place of concrete types when defining
15/// instructions. This makes the instructions *polymorphic*.
16///
17/// A type variable is restricted to vary over a subset of the value types.
18/// This subset is specified by a set of flags that control the permitted base
19/// types and whether the type variable can assume scalar or vector types, or
20/// both.
21#[derive(Debug)]
22pub(crate) struct TypeVarContent {
23    /// Short name of type variable used in instruction descriptions.
24    pub name: String,
25
26    /// Documentation string.
27    pub doc: String,
28
29    /// Type set associated to the type variable.
30    /// This field must remain private; use `get_typeset()` or `get_raw_typeset()` to get the
31    /// information you want.
32    type_set: TypeSet,
33
34    pub base: Option<TypeVarParent>,
35}
36
37#[derive(Clone, Debug)]
38pub(crate) struct TypeVar {
39    content: Rc<RefCell<TypeVarContent>>,
40}
41
42impl TypeVar {
43    pub fn new(name: impl Into<String>, doc: impl Into<String>, type_set: TypeSet) -> Self {
44        Self {
45            content: Rc::new(RefCell::new(TypeVarContent {
46                name: name.into(),
47                doc: doc.into(),
48                type_set,
49                base: None,
50            })),
51        }
52    }
53
54    pub fn new_singleton(value_type: ValueType) -> Self {
55        let (name, doc) = (value_type.to_string(), value_type.doc());
56        let mut builder = TypeSetBuilder::new();
57
58        let (scalar_type, num_lanes) = match value_type {
59            ValueType::Lane(lane_type) => (lane_type, 1),
60            ValueType::Vector(vec_type) => {
61                (vec_type.lane_type(), vec_type.lane_count() as RangeBound)
62            }
63            ValueType::DynamicVector(vec_type) => (
64                vec_type.lane_type(),
65                vec_type.minimum_lane_count() as RangeBound,
66            ),
67        };
68
69        builder = builder.simd_lanes(num_lanes..num_lanes);
70
71        // Only generate dynamic types for multiple lanes.
72        if num_lanes > 1 {
73            builder = builder.dynamic_simd_lanes(num_lanes..num_lanes);
74        }
75
76        let builder = match scalar_type {
77            LaneType::Int(int_type) => {
78                let bits = int_type as RangeBound;
79                builder.ints(bits..bits)
80            }
81            LaneType::Float(float_type) => {
82                let bits = float_type as RangeBound;
83                builder.floats(bits..bits)
84            }
85        };
86        TypeVar::new(name, doc, builder.build())
87    }
88
89    /// Get a fresh copy of self, named after `name`. Can only be called on non-derived typevars.
90    pub fn copy_from(other: &TypeVar, name: String) -> TypeVar {
91        assert!(
92            other.base.is_none(),
93            "copy_from() can only be called on non-derived type variables"
94        );
95        TypeVar {
96            content: Rc::new(RefCell::new(TypeVarContent {
97                name,
98                doc: "".into(),
99                type_set: other.type_set.clone(),
100                base: None,
101            })),
102        }
103    }
104
105    /// Returns the typeset for this TV. If the TV is derived, computes it recursively from the
106    /// derived function and the base's typeset.
107    /// Note this can't be done non-lazily in the constructor, because the TypeSet of the base may
108    /// change over time.
109    pub fn get_typeset(&self) -> TypeSet {
110        match &self.base {
111            Some(base) => base.type_var.get_typeset().image(base.derived_func),
112            None => self.type_set.clone(),
113        }
114    }
115
116    /// Returns this typevar's type set, assuming this type var has no parent.
117    pub fn get_raw_typeset(&self) -> &TypeSet {
118        assert_eq!(self.type_set, self.get_typeset());
119        &self.type_set
120    }
121
122    /// If the associated typeset has a single type return it. Otherwise return None.
123    pub fn singleton_type(&self) -> Option<ValueType> {
124        let type_set = self.get_typeset();
125        if type_set.size() == 1 {
126            Some(type_set.get_singleton())
127        } else {
128            None
129        }
130    }
131
132    /// Get the free type variable controlling this one.
133    pub fn free_typevar(&self) -> Option<TypeVar> {
134        match &self.base {
135            Some(base) => base.type_var.free_typevar(),
136            None => {
137                match self.singleton_type() {
138                    // A singleton type isn't a proper free variable.
139                    Some(_) => None,
140                    None => Some(self.clone()),
141                }
142            }
143        }
144    }
145
146    /// Create a type variable that is a function of another.
147    pub fn derived(&self, derived_func: DerivedFunc) -> TypeVar {
148        let ts = self.get_typeset();
149
150        // Safety checks to avoid over/underflows.
151        match derived_func {
152            DerivedFunc::HalfWidth => {
153                assert!(
154                    ts.ints.is_empty() || *ts.ints.iter().min().unwrap() > 8,
155                    "can't halve all integer types"
156                );
157                assert!(
158                    ts.floats.is_empty() || *ts.floats.iter().min().unwrap() > 16,
159                    "can't halve all float types"
160                );
161            }
162            DerivedFunc::DoubleWidth => {
163                assert!(
164                    ts.ints.is_empty() || *ts.ints.iter().max().unwrap() < MAX_BITS,
165                    "can't double all integer types"
166                );
167                assert!(
168                    ts.floats.is_empty() || *ts.floats.iter().max().unwrap() < MAX_FLOAT_BITS,
169                    "can't double all float types"
170                );
171            }
172            DerivedFunc::SplitLanes => {
173                assert!(
174                    ts.ints.is_empty() || *ts.ints.iter().min().unwrap() > 8,
175                    "can't halve all integer types"
176                );
177                assert!(
178                    ts.floats.is_empty() || *ts.floats.iter().min().unwrap() > 16,
179                    "can't halve all float types"
180                );
181                assert!(
182                    *ts.lanes.iter().max().unwrap() < MAX_LANES,
183                    "can't double 256 lanes"
184                );
185            }
186            DerivedFunc::MergeLanes => {
187                assert!(
188                    ts.ints.is_empty() || *ts.ints.iter().max().unwrap() < MAX_BITS,
189                    "can't double all integer types"
190                );
191                assert!(
192                    ts.floats.is_empty() || *ts.floats.iter().max().unwrap() < MAX_FLOAT_BITS,
193                    "can't double all float types"
194                );
195                assert!(
196                    *ts.lanes.iter().min().unwrap() > 1,
197                    "can't halve a scalar type"
198                );
199            }
200            DerivedFunc::Narrower => {
201                assert_eq!(
202                    *ts.lanes.iter().max().unwrap(),
203                    1,
204                    "The `narrower` constraint does not apply to vectors"
205                );
206                assert!(
207                    (!ts.ints.is_empty() || !ts.floats.is_empty()) && ts.dynamic_lanes.is_empty(),
208                    "The `narrower` constraint only applies to scalar ints or floats"
209                );
210            }
211            DerivedFunc::Wider => {
212                assert_eq!(
213                    *ts.lanes.iter().max().unwrap(),
214                    1,
215                    "The `wider` constraint does not apply to vectors"
216                );
217                assert!(
218                    (!ts.ints.is_empty() || !ts.floats.is_empty()) && ts.dynamic_lanes.is_empty(),
219                    "The `wider` constraint only applies to scalar ints or floats"
220                );
221            }
222            DerivedFunc::LaneOf | DerivedFunc::AsTruthy | DerivedFunc::DynamicToVector => {
223                /* no particular assertions */
224            }
225        }
226
227        TypeVar {
228            content: Rc::new(RefCell::new(TypeVarContent {
229                name: format!("{}({})", derived_func.name(), self.name),
230                doc: "".into(),
231                type_set: ts,
232                base: Some(TypeVarParent {
233                    type_var: self.clone(),
234                    derived_func,
235                }),
236            })),
237        }
238    }
239
240    pub fn lane_of(&self) -> TypeVar {
241        self.derived(DerivedFunc::LaneOf)
242    }
243    pub fn as_truthy(&self) -> TypeVar {
244        self.derived(DerivedFunc::AsTruthy)
245    }
246    pub fn half_width(&self) -> TypeVar {
247        self.derived(DerivedFunc::HalfWidth)
248    }
249    pub fn double_width(&self) -> TypeVar {
250        self.derived(DerivedFunc::DoubleWidth)
251    }
252    pub fn split_lanes(&self) -> TypeVar {
253        self.derived(DerivedFunc::SplitLanes)
254    }
255    pub fn merge_lanes(&self) -> TypeVar {
256        self.derived(DerivedFunc::MergeLanes)
257    }
258    pub fn dynamic_to_vector(&self) -> TypeVar {
259        self.derived(DerivedFunc::DynamicToVector)
260    }
261
262    /// Make a new [TypeVar] that includes all types narrower than self.
263    pub fn narrower(&self) -> TypeVar {
264        self.derived(DerivedFunc::Narrower)
265    }
266
267    /// Make a new [TypeVar] that includes all types wider than self.
268    pub fn wider(&self) -> TypeVar {
269        self.derived(DerivedFunc::Wider)
270    }
271}
272
273impl From<&TypeVar> for TypeVar {
274    fn from(type_var: &TypeVar) -> Self {
275        type_var.clone()
276    }
277}
278impl From<ValueType> for TypeVar {
279    fn from(value_type: ValueType) -> Self {
280        TypeVar::new_singleton(value_type)
281    }
282}
283
284// Hash TypeVars by pointers.
285// There might be a better way to do this, but since TypeVar's content (namely TypeSet) can be
286// mutated, it makes sense to use pointer equality/hashing here.
287impl hash::Hash for TypeVar {
288    fn hash<H: hash::Hasher>(&self, h: &mut H) {
289        match &self.base {
290            Some(base) => {
291                base.type_var.hash(h);
292                base.derived_func.hash(h);
293            }
294            None => {
295                (&**self as *const TypeVarContent).hash(h);
296            }
297        }
298    }
299}
300
301impl PartialEq for TypeVar {
302    fn eq(&self, other: &TypeVar) -> bool {
303        match (&self.base, &other.base) {
304            (Some(base1), Some(base2)) => {
305                base1.type_var.eq(&base2.type_var) && base1.derived_func == base2.derived_func
306            }
307            (None, None) => Rc::ptr_eq(&self.content, &other.content),
308            _ => false,
309        }
310    }
311}
312
313// Allow TypeVar as map keys, based on pointer equality (see also above PartialEq impl).
314impl Eq for TypeVar {}
315
316impl ops::Deref for TypeVar {
317    type Target = TypeVarContent;
318    fn deref(&self) -> &Self::Target {
319        unsafe { self.content.as_ptr().as_ref().unwrap() }
320    }
321}
322
/// A function that derives one type variable's type set from another's.
///
/// `Eq` is derived alongside `PartialEq` because equality on this fieldless enum is total.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub(crate) enum DerivedFunc {
    /// The scalar lane type: every lane count collapses to 1.
    LaneOf,
    /// The integer "truthy" type: `i8` for scalars, same-width integer lanes for vectors.
    AsTruthy,
    /// Halves every lane width.
    HalfWidth,
    /// Doubles every lane width.
    DoubleWidth,
    /// Halves every lane width and doubles every lane count.
    SplitLanes,
    /// Doubles every lane width and halves every lane count.
    MergeLanes,
    /// Turns dynamic vector lane counts into fixed vector lane counts.
    DynamicToVector,
    /// Constrains the derived variable to types narrower than the parent (same type set).
    Narrower,
    /// Constrains the derived variable to types wider than the parent (same type set).
    Wider,
}
335
336impl DerivedFunc {
337    pub fn name(self) -> &'static str {
338        match self {
339            DerivedFunc::LaneOf => "lane_of",
340            DerivedFunc::AsTruthy => "as_truthy",
341            DerivedFunc::HalfWidth => "half_width",
342            DerivedFunc::DoubleWidth => "double_width",
343            DerivedFunc::SplitLanes => "split_lanes",
344            DerivedFunc::MergeLanes => "merge_lanes",
345            DerivedFunc::DynamicToVector => "dynamic_to_vector",
346            DerivedFunc::Narrower => "narrower",
347            DerivedFunc::Wider => "wider",
348        }
349    }
350}
351
352#[derive(Debug, Hash)]
353pub(crate) struct TypeVarParent {
354    pub type_var: TypeVar,
355    pub derived_func: DerivedFunc,
356}
357
/// Integer type used for lane counts and bit widths.
type RangeBound = u16;
/// A range of [`RangeBound`] values. Throughout this file both `start` and `end` are treated
/// as *inclusive* bounds (see `range_to_set`).
type Range = ops::Range<RangeBound>;
/// An ordered set of lane counts or bit widths.
type NumSet = BTreeSet<RangeBound>;

/// Builds a [`NumSet`] from a comma-separated list of values, e.g. `num_set![8, 16]`.
macro_rules! num_set {
    ($($expr:expr),*) => {
        vec![$($expr),*].into_iter().collect::<NumSet>()
    };
}

/// A set of types.
///
/// We don't allow arbitrary subsets of types, but use a parametrized approach
/// instead. A type set is described by four sets of powers of two:
/// - `lanes`: the permitted fixed vector lane counts, where 1 indicates a scalar type;
/// - `dynamic_lanes`: the permitted lane counts of dynamic vector types;
/// - `ints`: the permitted integer lane widths, in bits;
/// - `floats`: the permitted floating-point lane widths, in bits.
///
/// The set contains every combination of a lane count with a lane type. For example,
/// `lanes = {1, 4}` with `ints = {32}` describes `i32` and `i32x4`.
///
/// Type sets implement `Eq` and `Hash`, so they can be used as map keys.
#[derive(Clone, PartialEq, Eq, Hash)]
pub(crate) struct TypeSet {
    pub lanes: NumSet,
    pub dynamic_lanes: NumSet,
    pub ints: NumSet,
    pub floats: NumSet,
}
390
391impl TypeSet {
392    fn new(lanes: NumSet, dynamic_lanes: NumSet, ints: NumSet, floats: NumSet) -> Self {
393        Self {
394            lanes,
395            dynamic_lanes,
396            ints,
397            floats,
398        }
399    }
400
401    /// Return the number of concrete types represented by this typeset.
402    pub fn size(&self) -> usize {
403        self.lanes.len() * (self.ints.len() + self.floats.len())
404            + self.dynamic_lanes.len() * (self.ints.len() + self.floats.len())
405    }
406
407    /// Return the image of self across the derived function func.
408    fn image(&self, derived_func: DerivedFunc) -> TypeSet {
409        match derived_func {
410            DerivedFunc::LaneOf => self.lane_of(),
411            DerivedFunc::AsTruthy => self.as_truthy(),
412            DerivedFunc::HalfWidth => self.half_width(),
413            DerivedFunc::DoubleWidth => self.double_width(),
414            DerivedFunc::SplitLanes => self.half_width().double_vector(),
415            DerivedFunc::MergeLanes => self.double_width().half_vector(),
416            DerivedFunc::DynamicToVector => self.dynamic_to_vector(),
417            DerivedFunc::Narrower => self.clone(),
418            DerivedFunc::Wider => self.clone(),
419        }
420    }
421
422    /// Return a TypeSet describing the image of self across lane_of.
423    fn lane_of(&self) -> TypeSet {
424        let mut copy = self.clone();
425        copy.lanes = num_set![1];
426        copy
427    }
428
429    /// Return a TypeSet describing the image of self across as_truthy.
430    fn as_truthy(&self) -> TypeSet {
431        let mut copy = self.clone();
432
433        // If this type set represents a scalar, `as_truthy` produces an I8, otherwise it returns a
434        // vector of the same number of lanes, whose elements are integers of the same width. For
435        // example, F32X4 gets turned into I32X4, while I32 gets turned into I8.
436        if self.lanes.len() == 1 && self.lanes.contains(&1) {
437            copy.ints = NumSet::from([8]);
438        } else {
439            copy.ints.extend(&self.floats)
440        }
441
442        copy.floats = NumSet::new();
443        copy
444    }
445
446    /// Return a TypeSet describing the image of self across halfwidth.
447    fn half_width(&self) -> TypeSet {
448        let mut copy = self.clone();
449        copy.ints = NumSet::from_iter(self.ints.iter().filter(|&&x| x > 8).map(|&x| x / 2));
450        copy.floats = NumSet::from_iter(self.floats.iter().filter(|&&x| x > 16).map(|&x| x / 2));
451        copy
452    }
453
454    /// Return a TypeSet describing the image of self across doublewidth.
455    fn double_width(&self) -> TypeSet {
456        let mut copy = self.clone();
457        copy.ints = NumSet::from_iter(self.ints.iter().filter(|&&x| x < MAX_BITS).map(|&x| x * 2));
458        copy.floats = NumSet::from_iter(
459            self.floats
460                .iter()
461                .filter(|&&x| x < MAX_FLOAT_BITS)
462                .map(|&x| x * 2),
463        );
464        copy
465    }
466
467    /// Return a TypeSet describing the image of self across halfvector.
468    fn half_vector(&self) -> TypeSet {
469        let mut copy = self.clone();
470        copy.lanes = NumSet::from_iter(self.lanes.iter().filter(|&&x| x > 1).map(|&x| x / 2));
471        copy
472    }
473
474    /// Return a TypeSet describing the image of self across doublevector.
475    fn double_vector(&self) -> TypeSet {
476        let mut copy = self.clone();
477        copy.lanes = NumSet::from_iter(
478            self.lanes
479                .iter()
480                .filter(|&&x| x < MAX_LANES)
481                .map(|&x| x * 2),
482        );
483        copy
484    }
485
486    fn dynamic_to_vector(&self) -> TypeSet {
487        let mut copy = self.clone();
488        copy.lanes = NumSet::from_iter(
489            self.dynamic_lanes
490                .iter()
491                .filter(|&&x| x < MAX_LANES)
492                .copied(),
493        );
494        copy.dynamic_lanes = NumSet::new();
495        copy
496    }
497
498    fn concrete_types(&self) -> Vec<ValueType> {
499        let mut ret = Vec::new();
500        for &num_lanes in &self.lanes {
501            for &bits in &self.ints {
502                ret.push(LaneType::int_from_bits(bits).by(num_lanes));
503            }
504            for &bits in &self.floats {
505                ret.push(LaneType::float_from_bits(bits).by(num_lanes));
506            }
507        }
508        for &num_lanes in &self.dynamic_lanes {
509            for &bits in &self.ints {
510                ret.push(LaneType::int_from_bits(bits).to_dynamic(num_lanes));
511            }
512            for &bits in &self.floats {
513                ret.push(LaneType::float_from_bits(bits).to_dynamic(num_lanes));
514            }
515        }
516        ret
517    }
518
519    /// Return the singleton type represented by self. Can only call on typesets containing 1 type.
520    fn get_singleton(&self) -> ValueType {
521        let mut types = self.concrete_types();
522        assert_eq!(types.len(), 1);
523        types.remove(0)
524    }
525}
526
527impl fmt::Debug for TypeSet {
528    fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
529        write!(fmt, "TypeSet(")?;
530
531        let mut subsets = Vec::new();
532        if !self.lanes.is_empty() {
533            subsets.push(format!(
534                "lanes={{{}}}",
535                Vec::from_iter(self.lanes.iter().map(|x| x.to_string())).join(", ")
536            ));
537        }
538        if !self.dynamic_lanes.is_empty() {
539            subsets.push(format!(
540                "dynamic_lanes={{{}}}",
541                Vec::from_iter(self.dynamic_lanes.iter().map(|x| x.to_string())).join(", ")
542            ));
543        }
544        if !self.ints.is_empty() {
545            subsets.push(format!(
546                "ints={{{}}}",
547                Vec::from_iter(self.ints.iter().map(|x| x.to_string())).join(", ")
548            ));
549        }
550        if !self.floats.is_empty() {
551            subsets.push(format!(
552                "floats={{{}}}",
553                Vec::from_iter(self.floats.iter().map(|x| x.to_string())).join(", ")
554            ));
555        }
556
557        write!(fmt, "{})", subsets.join(", "))?;
558        Ok(())
559    }
560}
561
562pub(crate) struct TypeSetBuilder {
563    ints: Interval,
564    floats: Interval,
565    includes_scalars: bool,
566    simd_lanes: Interval,
567    dynamic_simd_lanes: Interval,
568}
569
570impl TypeSetBuilder {
571    pub fn new() -> Self {
572        Self {
573            ints: Interval::None,
574            floats: Interval::None,
575            includes_scalars: true,
576            simd_lanes: Interval::None,
577            dynamic_simd_lanes: Interval::None,
578        }
579    }
580
581    pub fn ints(mut self, interval: impl Into<Interval>) -> Self {
582        assert!(self.ints == Interval::None);
583        self.ints = interval.into();
584        self
585    }
586    pub fn floats(mut self, interval: impl Into<Interval>) -> Self {
587        assert!(self.floats == Interval::None);
588        self.floats = interval.into();
589        self
590    }
591    pub fn includes_scalars(mut self, includes_scalars: bool) -> Self {
592        self.includes_scalars = includes_scalars;
593        self
594    }
595    pub fn simd_lanes(mut self, interval: impl Into<Interval>) -> Self {
596        assert!(self.simd_lanes == Interval::None);
597        self.simd_lanes = interval.into();
598        self
599    }
600    pub fn dynamic_simd_lanes(mut self, interval: impl Into<Interval>) -> Self {
601        assert!(self.dynamic_simd_lanes == Interval::None);
602        self.dynamic_simd_lanes = interval.into();
603        self
604    }
605
606    pub fn build(self) -> TypeSet {
607        let min_lanes = if self.includes_scalars { 1 } else { 2 };
608
609        TypeSet::new(
610            range_to_set(self.simd_lanes.to_range(min_lanes..MAX_LANES, Some(1))),
611            range_to_set(self.dynamic_simd_lanes.to_range(2..MAX_LANES, None)),
612            range_to_set(self.ints.to_range(8..MAX_BITS, None)),
613            range_to_set(self.floats.to_range(16..MAX_FLOAT_BITS, None)),
614        )
615    }
616}
617
618#[derive(PartialEq)]
619pub(crate) enum Interval {
620    None,
621    All,
622    Range(Range),
623}
624
625impl Interval {
626    fn to_range(&self, full_range: Range, default: Option<RangeBound>) -> Option<Range> {
627        match self {
628            Interval::None => default.map(|default_val| default_val..default_val),
629
630            Interval::All => Some(full_range),
631
632            Interval::Range(range) => {
633                let (low, high) = (range.start, range.end);
634                assert!(low.is_power_of_two());
635                assert!(high.is_power_of_two());
636                assert!(low <= high);
637                assert!(low >= full_range.start);
638                assert!(high <= full_range.end);
639                Some(low..high)
640            }
641        }
642    }
643}
644
645impl From<Range> for Interval {
646    fn from(range: Range) -> Self {
647        Interval::Range(range)
648    }
649}
650
651/// Generates a set with all the powers of two included in the range.
652fn range_to_set(range: Option<Range>) -> NumSet {
653    let mut set = NumSet::new();
654
655    let (low, high) = match range {
656        Some(range) => (range.start, range.end),
657        None => return set,
658    };
659
660    assert!(low.is_power_of_two());
661    assert!(high.is_power_of_two());
662    assert!(low <= high);
663
664    for i in low.trailing_zeros()..=high.trailing_zeros() {
665        assert!(1 << i <= RangeBound::max_value());
666        set.insert(1 << i);
667    }
668    set
669}
670
#[test]
fn test_typevar_builder() {
    // Scalar-only integer set.
    let type_set = TypeSetBuilder::new().ints(Interval::All).build();
    assert_eq!(type_set.lanes, num_set![1]);
    assert!(type_set.floats.is_empty());
    assert_eq!(type_set.ints, num_set![8, 16, 32, 64, 128]);

    // Scalar-only float set.
    let type_set = TypeSetBuilder::new().floats(Interval::All).build();
    assert_eq!(type_set.lanes, num_set![1]);
    assert_eq!(type_set.floats, num_set![16, 32, 64, 128]);
    assert!(type_set.ints.is_empty());

    // Vectors only: `includes_scalars(false)` drops the single-lane case.
    let type_set = TypeSetBuilder::new()
        .floats(Interval::All)
        .simd_lanes(Interval::All)
        .includes_scalars(false)
        .build();
    assert_eq!(type_set.lanes, num_set![2, 4, 8, 16, 32, 64, 128, 256]);
    assert_eq!(type_set.floats, num_set![16, 32, 64, 128]);
    assert!(type_set.dynamic_lanes.is_empty());
    assert!(type_set.ints.is_empty());

    // Scalars and vectors.
    let type_set = TypeSetBuilder::new()
        .floats(Interval::All)
        .simd_lanes(Interval::All)
        .includes_scalars(true)
        .build();
    assert_eq!(type_set.lanes, num_set![1, 2, 4, 8, 16, 32, 64, 128, 256]);
    assert_eq!(type_set.floats, num_set![16, 32, 64, 128]);
    assert!(type_set.ints.is_empty());

    // Dynamic lanes. `includes_scalars(false)` only affects an `Interval::All` request for
    // fixed lanes, so the default fixed lanes stay `{1}`.
    let type_set = TypeSetBuilder::new()
        .ints(Interval::All)
        .floats(Interval::All)
        .dynamic_simd_lanes(Interval::All)
        .includes_scalars(false)
        .build();
    assert_eq!(
        type_set.dynamic_lanes,
        num_set![2, 4, 8, 16, 32, 64, 128, 256]
    );
    assert_eq!(type_set.ints, num_set![8, 16, 32, 64, 128]);
    assert_eq!(type_set.floats, num_set![16, 32, 64, 128]);
    assert_eq!(type_set.lanes, num_set![1]);

    let type_set = TypeSetBuilder::new()
        .floats(Interval::All)
        .dynamic_simd_lanes(Interval::All)
        .includes_scalars(false)
        .build();
    assert_eq!(
        type_set.dynamic_lanes,
        num_set![2, 4, 8, 16, 32, 64, 128, 256]
    );
    assert_eq!(type_set.floats, num_set![16, 32, 64, 128]);
    assert_eq!(type_set.lanes, num_set![1]);
    assert!(type_set.ints.is_empty());

    // Explicit inclusive range.
    let type_set = TypeSetBuilder::new().ints(16..64).build();
    assert_eq!(type_set.lanes, num_set![1]);
    assert_eq!(type_set.ints, num_set![16, 32, 64]);
    assert!(type_set.floats.is_empty());
}
743
#[test]
fn test_dynamic_to_vector() {
    // Dynamic lane counts start at 2 (single-lane dynamic types are never generated), and
    // `dynamic_to_vector` drops any count equal to MAX_LANES (256). The resulting fixed lane
    // counts are therefore 2 through 128.
    assert_eq!(
        TypeSetBuilder::new()
            .dynamic_simd_lanes(Interval::All)
            .ints(Interval::All)
            .build()
            .dynamic_to_vector(),
        TypeSetBuilder::new()
            .simd_lanes(2..128)
            .ints(Interval::All)
            .build()
    );
    assert_eq!(
        TypeSetBuilder::new()
            .dynamic_simd_lanes(Interval::All)
            .floats(Interval::All)
            .build()
            .dynamic_to_vector(),
        TypeSetBuilder::new()
            .simd_lanes(2..128)
            .floats(Interval::All)
            .build()
    );
}
771
#[test]
#[should_panic]
fn test_typevar_builder_too_high_bound_panic() {
    // An upper bound above MAX_BITS must be rejected.
    let builder = TypeSetBuilder::new().ints(16..MAX_BITS * 2);
    builder.build();
}
777
#[test]
#[should_panic]
fn test_typevar_builder_inverted_bounds_panic() {
    // A lower bound above the upper bound must be rejected.
    let builder = TypeSetBuilder::new().ints(32..16);
    builder.build();
}
783
#[test]
fn test_as_truthy() {
    // Vectors of 2..=8 lanes over i8 and f32.
    let vectors = TypeSetBuilder::new()
        .simd_lanes(2..8)
        .ints(8..8)
        .floats(32..32)
        .build();

    let expected_lanes = TypeSetBuilder::new().ints(8..8).floats(32..32).build();
    assert_eq!(vectors.lane_of(), expected_lanes);

    // For vectors, each lane type maps to an integer of the same width.
    let mut expected_truthy = TypeSetBuilder::new().simd_lanes(2..8).build();
    expected_truthy.ints = num_set![8, 32];
    assert_eq!(vectors.as_truthy(), expected_truthy);

    // For scalars, the result is always i8.
    let scalars = TypeSetBuilder::new().ints(8..32).floats(32..64).build();
    let expected_scalar_truthy = TypeSetBuilder::new().ints(8..8).build();
    assert_eq!(scalars.as_truthy(), expected_scalar_truthy);
}
804
#[test]
fn test_forward_images() {
    let lanes = |lo: RangeBound, hi: RangeBound| {
        TypeSetBuilder::new().simd_lanes(lo..hi).build()
    };
    let ints = |lo: RangeBound, hi: RangeBound| TypeSetBuilder::new().ints(lo..hi).build();
    let floats = |lo: RangeBound, hi: RangeBound| TypeSetBuilder::new().floats(lo..hi).build();
    let empty_set = TypeSetBuilder::new().build();

    // Half vector.
    assert_eq!(lanes(1, 32).half_vector(), lanes(1, 16));

    // Double vector; 256 lanes cannot be doubled and is dropped.
    assert_eq!(lanes(1, 32).double_vector(), lanes(2, 64));
    assert_eq!(lanes(128, 256).double_vector(), lanes(256, 256));

    // Half width; f16 cannot be halved.
    assert_eq!(ints(8, 32).half_width(), ints(8, 16));
    assert_eq!(floats(16, 16).half_width(), empty_set);
    assert_eq!(floats(32, 128).half_width(), floats(16, 64));

    // Double width.
    assert_eq!(ints(8, 32).double_width(), ints(16, 64));
    assert_eq!(ints(32, 64).double_width(), ints(64, 128));
    assert_eq!(floats(32, 32).double_width(), floats(64, 64));
    assert_eq!(floats(16, 64).double_width(), floats(32, 128));
}
866
#[test]
#[should_panic]
fn test_typeset_singleton_panic_nonsingleton_types() {
    // i8 and f32 are two distinct types, so there is no singleton.
    let two_types = TypeSetBuilder::new().ints(8..8).floats(32..32).build();
    two_types.get_singleton();
}
876
#[test]
#[should_panic]
fn test_typeset_singleton_panic_nonsingleton_lanes() {
    // f32 and f32x2 are two distinct types, so there is no singleton.
    let two_types = TypeSetBuilder::new()
        .simd_lanes(1..2)
        .floats(32..32)
        .build();
    two_types.get_singleton();
}
886
#[test]
fn test_typeset_singleton() {
    use crate::shared::types::{Float, Int};

    let i16_set = TypeSetBuilder::new().ints(16..16).build();
    assert_eq!(i16_set.get_singleton(), ValueType::Lane(Int::I16.into()));

    let f64_set = TypeSetBuilder::new().floats(64..64).build();
    assert_eq!(f64_set.get_singleton(), ValueType::Lane(Float::F64.into()));

    let i32x4_set = TypeSetBuilder::new()
        .simd_lanes(4..4)
        .ints(32..32)
        .build();
    assert_eq!(i32x4_set.get_singleton(), LaneType::from(Int::I32).by(4));
}
907
#[test]
fn test_typevar_functions() {
    let wide_ints = TypeVar::new(
        "x",
        "i16 and up",
        TypeSetBuilder::new().ints(16..64).build(),
    );
    let halved = wide_ints.half_width();
    assert_eq!(halved.name, "half_width(x)");
    assert_eq!(halved.double_width().name, "double_width(half_width(x))");

    let narrow_ints = TypeVar::new("x", "up to i32", TypeSetBuilder::new().ints(8..32).build());
    assert_eq!(narrow_ints.double_width().name, "double_width(x)");
}
924
#[test]
fn test_typevar_singleton() {
    use crate::cdsl::types::VectorType;
    use crate::shared::types::{Float, Int};

    // Scalar i32.
    let i32_var = TypeVar::new_singleton(ValueType::Lane(LaneType::Int(Int::I32)));
    assert_eq!(i32_var.name, "i32");
    let i32_set = &i32_var.type_set;
    assert_eq!(i32_set.ints, num_set![32]);
    assert!(i32_set.floats.is_empty());
    assert_eq!(i32_set.lanes, num_set![1]);

    // Vector f32x4.
    let f32x4 = VectorType::new(LaneType::Float(Float::F32), 4);
    let f32x4_var = TypeVar::new_singleton(ValueType::Vector(f32x4));
    assert_eq!(f32x4_var.name, "f32x4");
    let f32x4_set = &f32x4_var.type_set;
    assert!(f32x4_set.ints.is_empty());
    assert_eq!(f32x4_set.floats, num_set![32]);
    assert_eq!(f32x4_set.lanes, num_set![4]);
}