1use std::cell::RefCell;
2use std::collections::BTreeSet;
3use std::fmt;
4use std::hash;
5use std::ops;
6use std::rc::Rc;
7
8use crate::cdsl::types::{LaneType, ValueType};
9
/// Largest lane count a (fixed or dynamic) vector type set may contain.
const MAX_LANES: u16 = 256;
/// Largest integer width, in bits, a type set may contain.
const MAX_BITS: u16 = 128;
/// Largest floating-point width, in bits, a type set may contain.
const MAX_FLOAT_BITS: u16 = 128;
13
/// The shared state behind a `TypeVar` handle.
#[derive(Debug)]
pub(crate) struct TypeVarContent {
    /// Name of the type variable. Derived variables are named `func(parent)`,
    /// e.g. `half_width(x)`.
    pub name: String,

    /// Documentation string; empty for derived and copied variables.
    pub doc: String,

    /// For a free variable: the set of types it may take.
    /// For a derived variable: the *parent's* effective type set at the time of
    /// derivation, not the derived image — use `TypeVar::get_typeset` for the
    /// effective set.
    type_set: TypeSet,

    /// The parent variable and derivation function; `None` for a free variable.
    pub base: Option<TypeVarParent>,
}
36
/// A type variable: a reference-counted handle to shared `TypeVarContent`.
///
/// Cloning is cheap and yields a handle with the same identity; two free type
/// variables compare equal only when they share the same content allocation.
#[derive(Clone, Debug)]
pub(crate) struct TypeVar {
    content: Rc<RefCell<TypeVarContent>>,
}
41
42impl TypeVar {
43 pub fn new(name: impl Into<String>, doc: impl Into<String>, type_set: TypeSet) -> Self {
44 Self {
45 content: Rc::new(RefCell::new(TypeVarContent {
46 name: name.into(),
47 doc: doc.into(),
48 type_set,
49 base: None,
50 })),
51 }
52 }
53
54 pub fn new_singleton(value_type: ValueType) -> Self {
55 let (name, doc) = (value_type.to_string(), value_type.doc());
56 let mut builder = TypeSetBuilder::new();
57
58 let (scalar_type, num_lanes) = match value_type {
59 ValueType::Lane(lane_type) => (lane_type, 1),
60 ValueType::Vector(vec_type) => {
61 (vec_type.lane_type(), vec_type.lane_count() as RangeBound)
62 }
63 ValueType::DynamicVector(vec_type) => (
64 vec_type.lane_type(),
65 vec_type.minimum_lane_count() as RangeBound,
66 ),
67 };
68
69 builder = builder.simd_lanes(num_lanes..num_lanes);
70
71 if num_lanes > 1 {
73 builder = builder.dynamic_simd_lanes(num_lanes..num_lanes);
74 }
75
76 let builder = match scalar_type {
77 LaneType::Int(int_type) => {
78 let bits = int_type as RangeBound;
79 builder.ints(bits..bits)
80 }
81 LaneType::Float(float_type) => {
82 let bits = float_type as RangeBound;
83 builder.floats(bits..bits)
84 }
85 };
86 TypeVar::new(name, doc, builder.build())
87 }
88
89 pub fn copy_from(other: &TypeVar, name: String) -> TypeVar {
91 assert!(
92 other.base.is_none(),
93 "copy_from() can only be called on non-derived type variables"
94 );
95 TypeVar {
96 content: Rc::new(RefCell::new(TypeVarContent {
97 name,
98 doc: "".into(),
99 type_set: other.type_set.clone(),
100 base: None,
101 })),
102 }
103 }
104
105 pub fn get_typeset(&self) -> TypeSet {
110 match &self.base {
111 Some(base) => base.type_var.get_typeset().image(base.derived_func),
112 None => self.type_set.clone(),
113 }
114 }
115
116 pub fn get_raw_typeset(&self) -> &TypeSet {
118 assert_eq!(self.type_set, self.get_typeset());
119 &self.type_set
120 }
121
122 pub fn singleton_type(&self) -> Option<ValueType> {
124 let type_set = self.get_typeset();
125 if type_set.size() == 1 {
126 Some(type_set.get_singleton())
127 } else {
128 None
129 }
130 }
131
132 pub fn free_typevar(&self) -> Option<TypeVar> {
134 match &self.base {
135 Some(base) => base.type_var.free_typevar(),
136 None => {
137 match self.singleton_type() {
138 Some(_) => None,
140 None => Some(self.clone()),
141 }
142 }
143 }
144 }
145
146 pub fn derived(&self, derived_func: DerivedFunc) -> TypeVar {
148 let ts = self.get_typeset();
149
150 match derived_func {
152 DerivedFunc::HalfWidth => {
153 assert!(
154 ts.ints.is_empty() || *ts.ints.iter().min().unwrap() > 8,
155 "can't halve all integer types"
156 );
157 assert!(
158 ts.floats.is_empty() || *ts.floats.iter().min().unwrap() > 16,
159 "can't halve all float types"
160 );
161 }
162 DerivedFunc::DoubleWidth => {
163 assert!(
164 ts.ints.is_empty() || *ts.ints.iter().max().unwrap() < MAX_BITS,
165 "can't double all integer types"
166 );
167 assert!(
168 ts.floats.is_empty() || *ts.floats.iter().max().unwrap() < MAX_FLOAT_BITS,
169 "can't double all float types"
170 );
171 }
172 DerivedFunc::SplitLanes => {
173 assert!(
174 ts.ints.is_empty() || *ts.ints.iter().min().unwrap() > 8,
175 "can't halve all integer types"
176 );
177 assert!(
178 ts.floats.is_empty() || *ts.floats.iter().min().unwrap() > 16,
179 "can't halve all float types"
180 );
181 assert!(
182 *ts.lanes.iter().max().unwrap() < MAX_LANES,
183 "can't double 256 lanes"
184 );
185 }
186 DerivedFunc::MergeLanes => {
187 assert!(
188 ts.ints.is_empty() || *ts.ints.iter().max().unwrap() < MAX_BITS,
189 "can't double all integer types"
190 );
191 assert!(
192 ts.floats.is_empty() || *ts.floats.iter().max().unwrap() < MAX_FLOAT_BITS,
193 "can't double all float types"
194 );
195 assert!(
196 *ts.lanes.iter().min().unwrap() > 1,
197 "can't halve a scalar type"
198 );
199 }
200 DerivedFunc::Narrower => {
201 assert_eq!(
202 *ts.lanes.iter().max().unwrap(),
203 1,
204 "The `narrower` constraint does not apply to vectors"
205 );
206 assert!(
207 (!ts.ints.is_empty() || !ts.floats.is_empty()) && ts.dynamic_lanes.is_empty(),
208 "The `narrower` constraint only applies to scalar ints or floats"
209 );
210 }
211 DerivedFunc::Wider => {
212 assert_eq!(
213 *ts.lanes.iter().max().unwrap(),
214 1,
215 "The `wider` constraint does not apply to vectors"
216 );
217 assert!(
218 (!ts.ints.is_empty() || !ts.floats.is_empty()) && ts.dynamic_lanes.is_empty(),
219 "The `wider` constraint only applies to scalar ints or floats"
220 );
221 }
222 DerivedFunc::LaneOf | DerivedFunc::AsTruthy | DerivedFunc::DynamicToVector => {
223 }
225 }
226
227 TypeVar {
228 content: Rc::new(RefCell::new(TypeVarContent {
229 name: format!("{}({})", derived_func.name(), self.name),
230 doc: "".into(),
231 type_set: ts,
232 base: Some(TypeVarParent {
233 type_var: self.clone(),
234 derived_func,
235 }),
236 })),
237 }
238 }
239
240 pub fn lane_of(&self) -> TypeVar {
241 self.derived(DerivedFunc::LaneOf)
242 }
243 pub fn as_truthy(&self) -> TypeVar {
244 self.derived(DerivedFunc::AsTruthy)
245 }
246 pub fn half_width(&self) -> TypeVar {
247 self.derived(DerivedFunc::HalfWidth)
248 }
249 pub fn double_width(&self) -> TypeVar {
250 self.derived(DerivedFunc::DoubleWidth)
251 }
252 pub fn split_lanes(&self) -> TypeVar {
253 self.derived(DerivedFunc::SplitLanes)
254 }
255 pub fn merge_lanes(&self) -> TypeVar {
256 self.derived(DerivedFunc::MergeLanes)
257 }
258 pub fn dynamic_to_vector(&self) -> TypeVar {
259 self.derived(DerivedFunc::DynamicToVector)
260 }
261
262 pub fn narrower(&self) -> TypeVar {
264 self.derived(DerivedFunc::Narrower)
265 }
266
267 pub fn wider(&self) -> TypeVar {
269 self.derived(DerivedFunc::Wider)
270 }
271}
272
273impl From<&TypeVar> for TypeVar {
274 fn from(type_var: &TypeVar) -> Self {
275 type_var.clone()
276 }
277}
278impl From<ValueType> for TypeVar {
279 fn from(value_type: ValueType) -> Self {
280 TypeVar::new_singleton(value_type)
281 }
282}
283
284impl hash::Hash for TypeVar {
288 fn hash<H: hash::Hasher>(&self, h: &mut H) {
289 match &self.base {
290 Some(base) => {
291 base.type_var.hash(h);
292 base.derived_func.hash(h);
293 }
294 None => {
295 (&**self as *const TypeVarContent).hash(h);
296 }
297 }
298 }
299}
300
301impl PartialEq for TypeVar {
302 fn eq(&self, other: &TypeVar) -> bool {
303 match (&self.base, &other.base) {
304 (Some(base1), Some(base2)) => {
305 base1.type_var.eq(&base2.type_var) && base1.derived_func == base2.derived_func
306 }
307 (None, None) => Rc::ptr_eq(&self.content, &other.content),
308 _ => false,
309 }
310 }
311}
312
313impl Eq for TypeVar {}
315
316impl ops::Deref for TypeVar {
317 type Target = TypeVarContent;
318 fn deref(&self) -> &Self::Target {
319 unsafe { self.content.as_ptr().as_ref().unwrap() }
320 }
321}
322
/// The function by which a derived type variable is obtained from its parent.
///
/// Each variant has a matching `TypeVar` constructor of the same name
/// (e.g. `TypeVar::lane_of`). `Eq` is derived because `TypeVar`'s own `Eq`
/// impl compares these values.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub(crate) enum DerivedFunc {
    /// The lane type: every lane count collapses to 1.
    LaneOf,
    /// i8 if the only lane count is 1; otherwise float widths become integer
    /// widths of the same size.
    AsTruthy,
    /// Halve every integer/float width.
    HalfWidth,
    /// Double every integer/float width.
    DoubleWidth,
    /// Halve the width and double the lane count.
    SplitLanes,
    /// Double the width and halve the lane count.
    MergeLanes,
    /// Turn dynamic lane counts into fixed lane counts.
    DynamicToVector,
    /// `narrower` constraint on scalar ints/floats; the type set is unchanged.
    Narrower,
    /// `wider` constraint on scalar ints/floats; the type set is unchanged.
    Wider,
}

impl DerivedFunc {
    /// Returns the snake_case name used to label derived type variables, as in
    /// `half_width(x)`.
    pub fn name(self) -> &'static str {
        match self {
            DerivedFunc::LaneOf => "lane_of",
            DerivedFunc::AsTruthy => "as_truthy",
            DerivedFunc::HalfWidth => "half_width",
            DerivedFunc::DoubleWidth => "double_width",
            DerivedFunc::SplitLanes => "split_lanes",
            DerivedFunc::MergeLanes => "merge_lanes",
            DerivedFunc::DynamicToVector => "dynamic_to_vector",
            DerivedFunc::Narrower => "narrower",
            DerivedFunc::Wider => "wider",
        }
    }
}
351
/// Records how a derived type variable was obtained from its parent.
#[derive(Debug, Hash)]
pub(crate) struct TypeVarParent {
    /// The variable this one was derived from.
    pub type_var: TypeVar,
    /// The function applied to the parent's effective type set.
    pub derived_func: DerivedFunc,
}
357
/// Integer type used for lane counts and bit widths.
type RangeBound = u16;
/// A range of `RangeBound`s. `range_to_set` treats both ends as inclusive.
type Range = ops::Range<RangeBound>;
/// An ordered set of lane counts or bit widths.
type NumSet = BTreeSet<RangeBound>;

/// Builds a `NumSet` from a comma-separated list of values, e.g. `num_set![8, 16]`.
/// Duplicates collapse, as in any set.
macro_rules! num_set {
    ($($expr:expr),*) => {
        NumSet::from([$($expr),*])
    };
}
382
/// A set of concrete value types, described as a cross product.
///
/// Every fixed lane count in `lanes` and every dynamic lane count in
/// `dynamic_lanes` is paired with every integer width in `ints` and every float
/// width in `floats`.
#[derive(Clone, PartialEq, Eq, Hash)]
pub(crate) struct TypeSet {
    /// Fixed lane counts; a lane count of 1 denotes scalar types.
    pub lanes: NumSet,
    /// Lane counts of dynamic vector types.
    pub dynamic_lanes: NumSet,
    /// Integer widths, in bits.
    pub ints: NumSet,
    /// Float widths, in bits.
    pub floats: NumSet,
}
390
391impl TypeSet {
392 fn new(lanes: NumSet, dynamic_lanes: NumSet, ints: NumSet, floats: NumSet) -> Self {
393 Self {
394 lanes,
395 dynamic_lanes,
396 ints,
397 floats,
398 }
399 }
400
401 pub fn size(&self) -> usize {
403 self.lanes.len() * (self.ints.len() + self.floats.len())
404 + self.dynamic_lanes.len() * (self.ints.len() + self.floats.len())
405 }
406
407 fn image(&self, derived_func: DerivedFunc) -> TypeSet {
409 match derived_func {
410 DerivedFunc::LaneOf => self.lane_of(),
411 DerivedFunc::AsTruthy => self.as_truthy(),
412 DerivedFunc::HalfWidth => self.half_width(),
413 DerivedFunc::DoubleWidth => self.double_width(),
414 DerivedFunc::SplitLanes => self.half_width().double_vector(),
415 DerivedFunc::MergeLanes => self.double_width().half_vector(),
416 DerivedFunc::DynamicToVector => self.dynamic_to_vector(),
417 DerivedFunc::Narrower => self.clone(),
418 DerivedFunc::Wider => self.clone(),
419 }
420 }
421
422 fn lane_of(&self) -> TypeSet {
424 let mut copy = self.clone();
425 copy.lanes = num_set![1];
426 copy
427 }
428
429 fn as_truthy(&self) -> TypeSet {
431 let mut copy = self.clone();
432
433 if self.lanes.len() == 1 && self.lanes.contains(&1) {
437 copy.ints = NumSet::from([8]);
438 } else {
439 copy.ints.extend(&self.floats)
440 }
441
442 copy.floats = NumSet::new();
443 copy
444 }
445
446 fn half_width(&self) -> TypeSet {
448 let mut copy = self.clone();
449 copy.ints = NumSet::from_iter(self.ints.iter().filter(|&&x| x > 8).map(|&x| x / 2));
450 copy.floats = NumSet::from_iter(self.floats.iter().filter(|&&x| x > 16).map(|&x| x / 2));
451 copy
452 }
453
454 fn double_width(&self) -> TypeSet {
456 let mut copy = self.clone();
457 copy.ints = NumSet::from_iter(self.ints.iter().filter(|&&x| x < MAX_BITS).map(|&x| x * 2));
458 copy.floats = NumSet::from_iter(
459 self.floats
460 .iter()
461 .filter(|&&x| x < MAX_FLOAT_BITS)
462 .map(|&x| x * 2),
463 );
464 copy
465 }
466
467 fn half_vector(&self) -> TypeSet {
469 let mut copy = self.clone();
470 copy.lanes = NumSet::from_iter(self.lanes.iter().filter(|&&x| x > 1).map(|&x| x / 2));
471 copy
472 }
473
474 fn double_vector(&self) -> TypeSet {
476 let mut copy = self.clone();
477 copy.lanes = NumSet::from_iter(
478 self.lanes
479 .iter()
480 .filter(|&&x| x < MAX_LANES)
481 .map(|&x| x * 2),
482 );
483 copy
484 }
485
486 fn dynamic_to_vector(&self) -> TypeSet {
487 let mut copy = self.clone();
488 copy.lanes = NumSet::from_iter(
489 self.dynamic_lanes
490 .iter()
491 .filter(|&&x| x < MAX_LANES)
492 .copied(),
493 );
494 copy.dynamic_lanes = NumSet::new();
495 copy
496 }
497
498 fn concrete_types(&self) -> Vec<ValueType> {
499 let mut ret = Vec::new();
500 for &num_lanes in &self.lanes {
501 for &bits in &self.ints {
502 ret.push(LaneType::int_from_bits(bits).by(num_lanes));
503 }
504 for &bits in &self.floats {
505 ret.push(LaneType::float_from_bits(bits).by(num_lanes));
506 }
507 }
508 for &num_lanes in &self.dynamic_lanes {
509 for &bits in &self.ints {
510 ret.push(LaneType::int_from_bits(bits).to_dynamic(num_lanes));
511 }
512 for &bits in &self.floats {
513 ret.push(LaneType::float_from_bits(bits).to_dynamic(num_lanes));
514 }
515 }
516 ret
517 }
518
519 fn get_singleton(&self) -> ValueType {
521 let mut types = self.concrete_types();
522 assert_eq!(types.len(), 1);
523 types.remove(0)
524 }
525}
526
527impl fmt::Debug for TypeSet {
528 fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
529 write!(fmt, "TypeSet(")?;
530
531 let mut subsets = Vec::new();
532 if !self.lanes.is_empty() {
533 subsets.push(format!(
534 "lanes={{{}}}",
535 Vec::from_iter(self.lanes.iter().map(|x| x.to_string())).join(", ")
536 ));
537 }
538 if !self.dynamic_lanes.is_empty() {
539 subsets.push(format!(
540 "dynamic_lanes={{{}}}",
541 Vec::from_iter(self.dynamic_lanes.iter().map(|x| x.to_string())).join(", ")
542 ));
543 }
544 if !self.ints.is_empty() {
545 subsets.push(format!(
546 "ints={{{}}}",
547 Vec::from_iter(self.ints.iter().map(|x| x.to_string())).join(", ")
548 ));
549 }
550 if !self.floats.is_empty() {
551 subsets.push(format!(
552 "floats={{{}}}",
553 Vec::from_iter(self.floats.iter().map(|x| x.to_string())).join(", ")
554 ));
555 }
556
557 write!(fmt, "{})", subsets.join(", "))?;
558 Ok(())
559 }
560}
561
/// Builder for `TypeSet`s.
///
/// Each dimension is an `Interval`. Dimensions left at `Interval::None` fall
/// back to the defaults applied in `build`.
pub(crate) struct TypeSetBuilder {
    /// Integer widths to include.
    ints: Interval,
    /// Float widths to include.
    floats: Interval,
    /// Lower bound of the permitted fixed-lane range: 1 when true, 2 when false.
    includes_scalars: bool,
    /// Fixed SIMD lane counts to include.
    simd_lanes: Interval,
    /// Dynamic SIMD lane counts to include.
    dynamic_simd_lanes: Interval,
}
569
570impl TypeSetBuilder {
571 pub fn new() -> Self {
572 Self {
573 ints: Interval::None,
574 floats: Interval::None,
575 includes_scalars: true,
576 simd_lanes: Interval::None,
577 dynamic_simd_lanes: Interval::None,
578 }
579 }
580
581 pub fn ints(mut self, interval: impl Into<Interval>) -> Self {
582 assert!(self.ints == Interval::None);
583 self.ints = interval.into();
584 self
585 }
586 pub fn floats(mut self, interval: impl Into<Interval>) -> Self {
587 assert!(self.floats == Interval::None);
588 self.floats = interval.into();
589 self
590 }
591 pub fn includes_scalars(mut self, includes_scalars: bool) -> Self {
592 self.includes_scalars = includes_scalars;
593 self
594 }
595 pub fn simd_lanes(mut self, interval: impl Into<Interval>) -> Self {
596 assert!(self.simd_lanes == Interval::None);
597 self.simd_lanes = interval.into();
598 self
599 }
600 pub fn dynamic_simd_lanes(mut self, interval: impl Into<Interval>) -> Self {
601 assert!(self.dynamic_simd_lanes == Interval::None);
602 self.dynamic_simd_lanes = interval.into();
603 self
604 }
605
606 pub fn build(self) -> TypeSet {
607 let min_lanes = if self.includes_scalars { 1 } else { 2 };
608
609 TypeSet::new(
610 range_to_set(self.simd_lanes.to_range(min_lanes..MAX_LANES, Some(1))),
611 range_to_set(self.dynamic_simd_lanes.to_range(2..MAX_LANES, None)),
612 range_to_set(self.ints.to_range(8..MAX_BITS, None)),
613 range_to_set(self.floats.to_range(16..MAX_FLOAT_BITS, None)),
614 )
615 }
616}
617
618#[derive(PartialEq)]
619pub(crate) enum Interval {
620 None,
621 All,
622 Range(Range),
623}
624
625impl Interval {
626 fn to_range(&self, full_range: Range, default: Option<RangeBound>) -> Option<Range> {
627 match self {
628 Interval::None => default.map(|default_val| default_val..default_val),
629
630 Interval::All => Some(full_range),
631
632 Interval::Range(range) => {
633 let (low, high) = (range.start, range.end);
634 assert!(low.is_power_of_two());
635 assert!(high.is_power_of_two());
636 assert!(low <= high);
637 assert!(low >= full_range.start);
638 assert!(high <= full_range.end);
639 Some(low..high)
640 }
641 }
642 }
643}
644
645impl From<Range> for Interval {
646 fn from(range: Range) -> Self {
647 Interval::Range(range)
648 }
649}
650
651fn range_to_set(range: Option<Range>) -> NumSet {
653 let mut set = NumSet::new();
654
655 let (low, high) = match range {
656 Some(range) => (range.start, range.end),
657 None => return set,
658 };
659
660 assert!(low.is_power_of_two());
661 assert!(high.is_power_of_two());
662 assert!(low <= high);
663
664 for i in low.trailing_zeros()..=high.trailing_zeros() {
665 assert!(1 << i <= RangeBound::max_value());
666 set.insert(1 << i);
667 }
668 set
669}
670
/// Checks how `TypeSetBuilder::build` fills in each dimension. This covers the
/// defaults for dimensions never specified, the effect of `includes_scalars`,
/// and inclusive explicit ranges.
#[test]
fn test_typevar_builder() {
    let all_ints = num_set![8, 16, 32, 64, 128];
    let all_floats = num_set![16, 32, 64, 128];
    let all_vector_lanes = num_set![2, 4, 8, 16, 32, 64, 128, 256];

    // Only ints requested: scalar lanes by default, no floats.
    let type_set = TypeSetBuilder::new().ints(Interval::All).build();
    assert_eq!(type_set.lanes, num_set![1]);
    assert!(type_set.floats.is_empty());
    assert_eq!(type_set.ints, all_ints);

    // Only floats requested.
    let type_set = TypeSetBuilder::new().floats(Interval::All).build();
    assert_eq!(type_set.lanes, num_set![1]);
    assert_eq!(type_set.floats, all_floats);
    assert!(type_set.ints.is_empty());

    // All fixed SIMD lane counts, scalars excluded. This case was previously
    // checked twice with identical builders; the second copy's extra
    // `dynamic_lanes` assertion is folded in here.
    let type_set = TypeSetBuilder::new()
        .floats(Interval::All)
        .simd_lanes(Interval::All)
        .includes_scalars(false)
        .build();
    assert_eq!(type_set.lanes, all_vector_lanes);
    assert_eq!(type_set.floats, all_floats);
    assert!(type_set.dynamic_lanes.is_empty());
    assert!(type_set.ints.is_empty());

    // All fixed SIMD lane counts, scalars included.
    let type_set = TypeSetBuilder::new()
        .floats(Interval::All)
        .simd_lanes(Interval::All)
        .includes_scalars(true)
        .build();
    assert_eq!(type_set.lanes, num_set![1, 2, 4, 8, 16, 32, 64, 128, 256]);
    assert_eq!(type_set.floats, all_floats);
    assert!(type_set.ints.is_empty());

    // Dynamic lanes with ints and floats. Fixed lanes still default to {1}.
    let type_set = TypeSetBuilder::new()
        .ints(Interval::All)
        .floats(Interval::All)
        .dynamic_simd_lanes(Interval::All)
        .includes_scalars(false)
        .build();
    assert_eq!(type_set.dynamic_lanes, all_vector_lanes);
    assert_eq!(type_set.ints, all_ints);
    assert_eq!(type_set.floats, all_floats);
    assert_eq!(type_set.lanes, num_set![1]);

    // Dynamic lanes with floats only.
    let type_set = TypeSetBuilder::new()
        .floats(Interval::All)
        .dynamic_simd_lanes(Interval::All)
        .includes_scalars(false)
        .build();
    assert_eq!(type_set.dynamic_lanes, all_vector_lanes);
    assert_eq!(type_set.floats, all_floats);
    assert_eq!(type_set.lanes, num_set![1]);
    assert!(type_set.ints.is_empty());

    // Explicit ranges include both ends.
    let type_set = TypeSetBuilder::new().ints(16..64).build();
    assert_eq!(type_set.lanes, num_set![1]);
    assert_eq!(type_set.ints, num_set![16, 32, 64]);
    assert!(type_set.floats.is_empty());
}
743
/// `dynamic_to_vector` keeps every dynamic lane count below `MAX_LANES`, so the
/// full dynamic range 2..=256 becomes the fixed range 2..=128. This holds for
/// both ints and floats.
#[test]
fn test_dynamic_to_vector() {
    let add_scalars: [fn(TypeSetBuilder) -> TypeSetBuilder; 2] = [
        |b: TypeSetBuilder| b.ints(Interval::All),
        |b: TypeSetBuilder| b.floats(Interval::All),
    ];

    for add in add_scalars.iter() {
        let dynamic = add(TypeSetBuilder::new().dynamic_simd_lanes(Interval::All)).build();
        let fixed = add(TypeSetBuilder::new().simd_lanes(2..128)).build();
        assert_eq!(dynamic.dynamic_to_vector(), fixed);
    }
}
771
/// An integer range reaching past `MAX_BITS` must be rejected.
#[test]
#[should_panic]
fn test_typevar_builder_too_high_bound_panic() {
    // Still a power of two, so only the upper-bound check can reject it.
    let upper = 2 * MAX_BITS;
    TypeSetBuilder::new().ints(16..upper).build();
}
777
/// A range whose lower bound exceeds its upper bound must be rejected.
#[test]
#[should_panic]
fn test_typevar_builder_inverted_bounds_panic() {
    // Both bounds are valid powers of two; only their order is wrong.
    let (low, high) = (32, 16);
    TypeSetBuilder::new().ints(low..high).build();
}
783
/// Checks `lane_of`, and `as_truthy` for both vector and scalar inputs.
#[test]
fn test_as_truthy() {
    let vectors = TypeSetBuilder::new()
        .simd_lanes(2..8)
        .ints(8..8)
        .floats(32..32)
        .build();

    // lane_of keeps the scalar kinds but collapses the lane counts to {1}.
    let lanes_collapsed = TypeSetBuilder::new().ints(8..8).floats(32..32).build();
    assert_eq!(vectors.lane_of(), lanes_collapsed);

    // For vectors, float widths turn into integer widths of the same size.
    let expected_vector_truthy = TypeSet {
        ints: num_set![8, 32],
        ..TypeSetBuilder::new().simd_lanes(2..8).build()
    };
    assert_eq!(vectors.as_truthy(), expected_vector_truthy);

    // For scalars, the result is always i8.
    let scalars = TypeSetBuilder::new().ints(8..32).floats(32..64).build();
    let expected_scalar_truthy = TypeSetBuilder::new().ints(8..8).build();
    assert_eq!(scalars.as_truthy(), expected_scalar_truthy);
}
804
/// Checks the lane-count and width images of `TypeSet`, including the values
/// dropped at the boundaries.
#[test]
fn test_forward_images() {
    // Shorthands for single-dimension type sets.
    let lanes = |r: Range| TypeSetBuilder::new().simd_lanes(r).build();
    let ints = |r: Range| TypeSetBuilder::new().ints(r).build();
    let floats = |r: Range| TypeSetBuilder::new().floats(r).build();
    let empty_set = TypeSetBuilder::new().build();

    // half_vector drops the scalar count; double_vector drops MAX_LANES.
    assert_eq!(lanes(1..32).half_vector(), lanes(1..16));
    assert_eq!(lanes(1..32).double_vector(), lanes(2..64));
    assert_eq!(lanes(128..256).double_vector(), lanes(256..256));

    // half_width drops i8 and f16.
    assert_eq!(ints(8..32).half_width(), ints(8..16));
    assert_eq!(floats(16..16).half_width(), empty_set);
    assert_eq!(floats(32..128).half_width(), floats(16..64));

    // double_width never exceeds the maximum widths.
    assert_eq!(ints(8..32).double_width(), ints(16..64));
    assert_eq!(ints(32..64).double_width(), ints(64..128));
    assert_eq!(floats(32..32).double_width(), floats(64..64));
    assert_eq!(floats(16..64).double_width(), floats(32..128));
}
866
/// A set holding an integer and a float type is not a singleton.
#[test]
#[should_panic]
fn test_typeset_singleton_panic_nonsingleton_types() {
    let mixed_kinds = TypeSetBuilder::new().ints(8..8).floats(32..32).build();
    mixed_kinds.get_singleton();
}
876
/// A set holding f32 with lane counts {1, 2} is two concrete types, not a singleton.
#[test]
#[should_panic]
fn test_typeset_singleton_panic_nonsingleton_lanes() {
    let two_lane_counts = TypeSetBuilder::new()
        .simd_lanes(1..2)
        .floats(32..32)
        .build();
    two_lane_counts.get_singleton();
}
886
/// `get_singleton` returns the sole type of scalar and vector singleton sets.
#[test]
fn test_typeset_singleton() {
    use crate::shared::types as shared_types;

    let singleton_of = |builder: TypeSetBuilder| builder.build().get_singleton();

    assert_eq!(
        singleton_of(TypeSetBuilder::new().ints(16..16)),
        ValueType::Lane(shared_types::Int::I16.into())
    );
    assert_eq!(
        singleton_of(TypeSetBuilder::new().floats(64..64)),
        ValueType::Lane(shared_types::Float::F64.into())
    );
    assert_eq!(
        singleton_of(TypeSetBuilder::new().simd_lanes(4..4).ints(32..32)),
        LaneType::from(shared_types::Int::I32).by(4)
    );
}
907
/// Derived type variables are named after the chain of functions applied.
#[test]
fn test_typevar_functions() {
    let wide = TypeVar::new(
        "x",
        "i16 and up",
        TypeSetBuilder::new().ints(16..64).build(),
    );
    let halved = wide.half_width();
    assert_eq!(halved.name, "half_width(x)");
    assert_eq!(halved.double_width().name, "double_width(half_width(x))");

    let narrow = TypeVar::new("x", "up to i32", TypeSetBuilder::new().ints(8..32).build());
    assert_eq!(narrow.double_width().name, "double_width(x)");
}
924
/// `new_singleton` names the variable after the value type and records the
/// matching width and lane count.
#[test]
fn test_typevar_singleton() {
    use crate::cdsl::types::VectorType;
    use crate::shared::types as shared_types;

    // Scalar i32.
    let scalar = TypeVar::new_singleton(ValueType::Lane(LaneType::Int(shared_types::Int::I32)));
    assert_eq!(scalar.name, "i32");
    let scalar_set = &scalar.type_set;
    assert_eq!(scalar_set.ints, num_set![32]);
    assert!(scalar_set.floats.is_empty());
    assert_eq!(scalar_set.lanes, num_set![1]);

    // Four-lane f32 vector.
    let f32x4 = VectorType::new(LaneType::Float(shared_types::Float::F32), 4);
    let vector = TypeVar::new_singleton(ValueType::Vector(f32x4));
    assert_eq!(vector.name, "f32x4");
    let vector_set = &vector.type_set;
    assert!(vector_set.ints.is_empty());
    assert_eq!(vector_set.floats, num_set![32]);
    assert_eq!(vector_set.lanes, num_set![4]);
}