naga/proc/
mod.rs

1/*!
2[`Module`](super::Module) processing functionality.
3*/
4
5mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod layouter;
9mod namer;
10mod terminator;
11mod typifier;
12
13pub use constant_evaluator::{
14    ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
15};
16pub use emitter::Emitter;
17pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
18pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
19pub use namer::{EntryPointIndex, NameKey, Namer};
20pub use terminator::ensure_block_returns;
21pub use typifier::{ResolveContext, ResolveError, TypeResolution};
22
23impl From<super::StorageFormat> for super::Scalar {
24    fn from(format: super::StorageFormat) -> Self {
25        use super::{ScalarKind as Sk, StorageFormat as Sf};
26        let kind = match format {
27            Sf::R8Unorm => Sk::Float,
28            Sf::R8Snorm => Sk::Float,
29            Sf::R8Uint => Sk::Uint,
30            Sf::R8Sint => Sk::Sint,
31            Sf::R16Uint => Sk::Uint,
32            Sf::R16Sint => Sk::Sint,
33            Sf::R16Float => Sk::Float,
34            Sf::Rg8Unorm => Sk::Float,
35            Sf::Rg8Snorm => Sk::Float,
36            Sf::Rg8Uint => Sk::Uint,
37            Sf::Rg8Sint => Sk::Sint,
38            Sf::R32Uint => Sk::Uint,
39            Sf::R32Sint => Sk::Sint,
40            Sf::R32Float => Sk::Float,
41            Sf::Rg16Uint => Sk::Uint,
42            Sf::Rg16Sint => Sk::Sint,
43            Sf::Rg16Float => Sk::Float,
44            Sf::Rgba8Unorm => Sk::Float,
45            Sf::Rgba8Snorm => Sk::Float,
46            Sf::Rgba8Uint => Sk::Uint,
47            Sf::Rgba8Sint => Sk::Sint,
48            Sf::Bgra8Unorm => Sk::Float,
49            Sf::Rgb10a2Uint => Sk::Uint,
50            Sf::Rgb10a2Unorm => Sk::Float,
51            Sf::Rg11b10Ufloat => Sk::Float,
52            Sf::R64Uint => Sk::Uint,
53            Sf::Rg32Uint => Sk::Uint,
54            Sf::Rg32Sint => Sk::Sint,
55            Sf::Rg32Float => Sk::Float,
56            Sf::Rgba16Uint => Sk::Uint,
57            Sf::Rgba16Sint => Sk::Sint,
58            Sf::Rgba16Float => Sk::Float,
59            Sf::Rgba32Uint => Sk::Uint,
60            Sf::Rgba32Sint => Sk::Sint,
61            Sf::Rgba32Float => Sk::Float,
62            Sf::R16Unorm => Sk::Float,
63            Sf::R16Snorm => Sk::Float,
64            Sf::Rg16Unorm => Sk::Float,
65            Sf::Rg16Snorm => Sk::Float,
66            Sf::Rgba16Unorm => Sk::Float,
67            Sf::Rgba16Snorm => Sk::Float,
68        };
69        let width = match format {
70            Sf::R64Uint => 8,
71            _ => 4,
72        };
73        super::Scalar { kind, width }
74    }
75}
76
77impl super::ScalarKind {
78    pub const fn is_numeric(self) -> bool {
79        match self {
80            crate::ScalarKind::Sint
81            | crate::ScalarKind::Uint
82            | crate::ScalarKind::Float
83            | crate::ScalarKind::AbstractInt
84            | crate::ScalarKind::AbstractFloat => true,
85            crate::ScalarKind::Bool => false,
86        }
87    }
88}
89
90impl super::Scalar {
91    pub const I32: Self = Self {
92        kind: crate::ScalarKind::Sint,
93        width: 4,
94    };
95    pub const U32: Self = Self {
96        kind: crate::ScalarKind::Uint,
97        width: 4,
98    };
99    pub const F32: Self = Self {
100        kind: crate::ScalarKind::Float,
101        width: 4,
102    };
103    pub const F64: Self = Self {
104        kind: crate::ScalarKind::Float,
105        width: 8,
106    };
107    pub const I64: Self = Self {
108        kind: crate::ScalarKind::Sint,
109        width: 8,
110    };
111    pub const U64: Self = Self {
112        kind: crate::ScalarKind::Uint,
113        width: 8,
114    };
115    pub const BOOL: Self = Self {
116        kind: crate::ScalarKind::Bool,
117        width: crate::BOOL_WIDTH,
118    };
119    pub const ABSTRACT_INT: Self = Self {
120        kind: crate::ScalarKind::AbstractInt,
121        width: crate::ABSTRACT_WIDTH,
122    };
123    pub const ABSTRACT_FLOAT: Self = Self {
124        kind: crate::ScalarKind::AbstractFloat,
125        width: crate::ABSTRACT_WIDTH,
126    };
127
128    pub const fn is_abstract(self) -> bool {
129        match self.kind {
130            crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => true,
131            crate::ScalarKind::Sint
132            | crate::ScalarKind::Uint
133            | crate::ScalarKind::Float
134            | crate::ScalarKind::Bool => false,
135        }
136    }
137
138    /// Construct a float `Scalar` with the given width.
139    ///
140    /// This is especially common when dealing with
141    /// `TypeInner::Matrix`, where the scalar kind is implicit.
142    pub const fn float(width: crate::Bytes) -> Self {
143        Self {
144            kind: crate::ScalarKind::Float,
145            width,
146        }
147    }
148
149    pub const fn to_inner_scalar(self) -> crate::TypeInner {
150        crate::TypeInner::Scalar(self)
151    }
152
153    pub const fn to_inner_vector(self, size: crate::VectorSize) -> crate::TypeInner {
154        crate::TypeInner::Vector { size, scalar: self }
155    }
156
157    pub const fn to_inner_atomic(self) -> crate::TypeInner {
158        crate::TypeInner::Atomic(self)
159    }
160}
161
/// A hashable stand-in for [`crate::Literal`].
///
/// There is one variant per `Literal` variant. Floating-point values are stored
/// as their raw bit patterns (`u64`/`u32`), because `f32`/`f64` do not implement
/// `Eq` or `Hash`. The `From<crate::Literal>` impl performs this conversion.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HashableLiteral {
    /// Bit pattern of an `F64` literal.
    F64(u64),
    /// Bit pattern of an `F32` literal.
    F32(u32),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
    Bool(bool),
    AbstractInt(i64),
    /// Bit pattern of an `AbstractFloat` literal.
    AbstractFloat(u64),
}
174
175impl From<crate::Literal> for HashableLiteral {
176    fn from(l: crate::Literal) -> Self {
177        match l {
178            crate::Literal::F64(v) => Self::F64(v.to_bits()),
179            crate::Literal::F32(v) => Self::F32(v.to_bits()),
180            crate::Literal::U32(v) => Self::U32(v),
181            crate::Literal::I32(v) => Self::I32(v),
182            crate::Literal::U64(v) => Self::U64(v),
183            crate::Literal::I64(v) => Self::I64(v),
184            crate::Literal::Bool(v) => Self::Bool(v),
185            crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
186            crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
187        }
188    }
189}
190
191impl crate::Literal {
192    pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
193        match (value, scalar.kind, scalar.width) {
194            (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
195            (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
196            (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
197            (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
198            (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
199            (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
200            (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
201            (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
202            _ => None,
203        }
204    }
205
206    pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
207        Self::new(0, scalar)
208    }
209
210    pub const fn one(scalar: crate::Scalar) -> Option<Self> {
211        Self::new(1, scalar)
212    }
213
214    pub const fn width(&self) -> crate::Bytes {
215        match *self {
216            Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
217            Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
218            Self::Bool(_) => crate::BOOL_WIDTH,
219            Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
220        }
221    }
222    pub const fn scalar(&self) -> crate::Scalar {
223        match *self {
224            Self::F64(_) => crate::Scalar::F64,
225            Self::F32(_) => crate::Scalar::F32,
226            Self::U32(_) => crate::Scalar::U32,
227            Self::I32(_) => crate::Scalar::I32,
228            Self::U64(_) => crate::Scalar::U64,
229            Self::I64(_) => crate::Scalar::I64,
230            Self::Bool(_) => crate::Scalar::BOOL,
231            Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
232            Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
233        }
234    }
235    pub const fn scalar_kind(&self) -> crate::ScalarKind {
236        self.scalar().kind
237    }
238    pub const fn ty_inner(&self) -> crate::TypeInner {
239        crate::TypeInner::Scalar(self.scalar())
240    }
241}
242
/// Size in bytes that `TypeInner::size` reports for `Pointer` and
/// `ValuePointer` types.
pub const POINTER_SPAN: u32 = 4;
244
245impl super::TypeInner {
246    /// Return the scalar type of `self`.
247    ///
248    /// If `inner` is a scalar, vector, or matrix type, return
249    /// its scalar type. Otherwise, return `None`.
250    pub const fn scalar(&self) -> Option<super::Scalar> {
251        use crate::TypeInner as Ti;
252        match *self {
253            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => Some(scalar),
254            Ti::Matrix { scalar, .. } => Some(scalar),
255            _ => None,
256        }
257    }
258
259    pub fn scalar_kind(&self) -> Option<super::ScalarKind> {
260        self.scalar().map(|scalar| scalar.kind)
261    }
262
263    /// Returns the scalar width in bytes
264    pub fn scalar_width(&self) -> Option<u8> {
265        self.scalar().map(|scalar| scalar.width)
266    }
267
268    pub const fn pointer_space(&self) -> Option<crate::AddressSpace> {
269        match *self {
270            Self::Pointer { space, .. } => Some(space),
271            Self::ValuePointer { space, .. } => Some(space),
272            _ => None,
273        }
274    }
275
276    pub fn is_atomic_pointer(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
277        match *self {
278            crate::TypeInner::Pointer { base, .. } => match types[base].inner {
279                crate::TypeInner::Atomic { .. } => true,
280                _ => false,
281            },
282            _ => false,
283        }
284    }
285
286    /// Get the size of this type.
287    pub fn size(&self, _gctx: GlobalCtx) -> u32 {
288        match *self {
289            Self::Scalar(scalar) | Self::Atomic(scalar) => scalar.width as u32,
290            Self::Vector { size, scalar } => size as u32 * scalar.width as u32,
291            // matrices are treated as arrays of aligned columns
292            Self::Matrix {
293                columns,
294                rows,
295                scalar,
296            } => Alignment::from(rows) * scalar.width as u32 * columns as u32,
297            Self::Pointer { .. } | Self::ValuePointer { .. } => POINTER_SPAN,
298            Self::Array {
299                base: _,
300                size,
301                stride,
302            } => {
303                let count = match size {
304                    super::ArraySize::Constant(count) => count.get(),
305                    // any struct member or array element needing a size at pipeline-creation time
306                    // must have a creation-fixed footprint
307                    super::ArraySize::Pending(_) => 0,
308                    // A dynamically-sized array has to have at least one element
309                    super::ArraySize::Dynamic => 1,
310                };
311                count * stride
312            }
313            Self::Struct { span, .. } => span,
314            Self::Image { .. }
315            | Self::Sampler { .. }
316            | Self::AccelerationStructure
317            | Self::RayQuery
318            | Self::BindingArray { .. } => 0,
319        }
320    }
321
322    /// Return the canonical form of `self`, or `None` if it's already in
323    /// canonical form.
324    ///
325    /// Certain types have multiple representations in `TypeInner`. This
326    /// function converts all forms of equivalent types to a single
327    /// representative of their class, so that simply applying `Eq` to the
328    /// result indicates whether the types are equivalent, as far as Naga IR is
329    /// concerned.
330    pub fn canonical_form(
331        &self,
332        types: &crate::UniqueArena<crate::Type>,
333    ) -> Option<crate::TypeInner> {
334        use crate::TypeInner as Ti;
335        match *self {
336            Ti::Pointer { base, space } => match types[base].inner {
337                Ti::Scalar(scalar) => Some(Ti::ValuePointer {
338                    size: None,
339                    scalar,
340                    space,
341                }),
342                Ti::Vector { size, scalar } => Some(Ti::ValuePointer {
343                    size: Some(size),
344                    scalar,
345                    space,
346                }),
347                _ => None,
348            },
349            _ => None,
350        }
351    }
352
353    /// Compare `self` and `rhs` as types.
354    ///
355    /// This is mostly the same as `<TypeInner as Eq>::eq`, but it treats
356    /// `ValuePointer` and `Pointer` types as equivalent.
357    ///
358    /// When you know that one side of the comparison is never a pointer, it's
359    /// fine to not bother with canonicalization, and just compare `TypeInner`
360    /// values with `==`.
361    pub fn equivalent(
362        &self,
363        rhs: &crate::TypeInner,
364        types: &crate::UniqueArena<crate::Type>,
365    ) -> bool {
366        let left = self.canonical_form(types);
367        let right = rhs.canonical_form(types);
368        left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs)
369    }
370
371    pub fn is_dynamically_sized(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
372        use crate::TypeInner as Ti;
373        match *self {
374            Ti::Array { size, .. } => size == crate::ArraySize::Dynamic,
375            Ti::Struct { ref members, .. } => members
376                .last()
377                .map(|last| types[last.ty].inner.is_dynamically_sized(types))
378                .unwrap_or(false),
379            _ => false,
380        }
381    }
382
383    pub fn components(&self) -> Option<u32> {
384        Some(match *self {
385            Self::Vector { size, .. } => size as u32,
386            Self::Matrix { columns, .. } => columns as u32,
387            Self::Array {
388                size: crate::ArraySize::Constant(len),
389                ..
390            } => len.get(),
391            Self::Struct { ref members, .. } => members.len() as u32,
392            _ => return None,
393        })
394    }
395
396    pub fn component_type(&self, index: usize) -> Option<TypeResolution> {
397        Some(match *self {
398            Self::Vector { scalar, .. } => TypeResolution::Value(crate::TypeInner::Scalar(scalar)),
399            Self::Matrix { rows, scalar, .. } => {
400                TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar })
401            }
402            Self::Array {
403                base,
404                size: crate::ArraySize::Constant(_),
405                ..
406            } => TypeResolution::Handle(base),
407            Self::Struct { ref members, .. } => TypeResolution::Handle(members[index].ty),
408            _ => return None,
409        })
410    }
411}
412
413impl super::AddressSpace {
414    pub fn access(self) -> crate::StorageAccess {
415        use crate::StorageAccess as Sa;
416        match self {
417            crate::AddressSpace::Function
418            | crate::AddressSpace::Private
419            | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
420            crate::AddressSpace::Uniform => Sa::LOAD,
421            crate::AddressSpace::Storage { access } => access,
422            crate::AddressSpace::Handle => Sa::LOAD,
423            crate::AddressSpace::PushConstant => Sa::LOAD,
424        }
425    }
426}
427
428impl super::MathFunction {
429    pub const fn argument_count(&self) -> usize {
430        match *self {
431            // comparison
432            Self::Abs => 1,
433            Self::Min => 2,
434            Self::Max => 2,
435            Self::Clamp => 3,
436            Self::Saturate => 1,
437            // trigonometry
438            Self::Cos => 1,
439            Self::Cosh => 1,
440            Self::Sin => 1,
441            Self::Sinh => 1,
442            Self::Tan => 1,
443            Self::Tanh => 1,
444            Self::Acos => 1,
445            Self::Asin => 1,
446            Self::Atan => 1,
447            Self::Atan2 => 2,
448            Self::Asinh => 1,
449            Self::Acosh => 1,
450            Self::Atanh => 1,
451            Self::Radians => 1,
452            Self::Degrees => 1,
453            // decomposition
454            Self::Ceil => 1,
455            Self::Floor => 1,
456            Self::Round => 1,
457            Self::Fract => 1,
458            Self::Trunc => 1,
459            Self::Modf => 1,
460            Self::Frexp => 1,
461            Self::Ldexp => 2,
462            // exponent
463            Self::Exp => 1,
464            Self::Exp2 => 1,
465            Self::Log => 1,
466            Self::Log2 => 1,
467            Self::Pow => 2,
468            // geometry
469            Self::Dot => 2,
470            Self::Outer => 2,
471            Self::Cross => 2,
472            Self::Distance => 2,
473            Self::Length => 1,
474            Self::Normalize => 1,
475            Self::FaceForward => 3,
476            Self::Reflect => 2,
477            Self::Refract => 3,
478            // computational
479            Self::Sign => 1,
480            Self::Fma => 3,
481            Self::Mix => 3,
482            Self::Step => 2,
483            Self::SmoothStep => 3,
484            Self::Sqrt => 1,
485            Self::InverseSqrt => 1,
486            Self::Inverse => 1,
487            Self::Transpose => 1,
488            Self::Determinant => 1,
489            Self::QuantizeToF16 => 1,
490            // bits
491            Self::CountTrailingZeros => 1,
492            Self::CountLeadingZeros => 1,
493            Self::CountOneBits => 1,
494            Self::ReverseBits => 1,
495            Self::ExtractBits => 3,
496            Self::InsertBits => 4,
497            Self::FirstTrailingBit => 1,
498            Self::FirstLeadingBit => 1,
499            // data packing
500            Self::Pack4x8snorm => 1,
501            Self::Pack4x8unorm => 1,
502            Self::Pack2x16snorm => 1,
503            Self::Pack2x16unorm => 1,
504            Self::Pack2x16float => 1,
505            Self::Pack4xI8 => 1,
506            Self::Pack4xU8 => 1,
507            // data unpacking
508            Self::Unpack4x8snorm => 1,
509            Self::Unpack4x8unorm => 1,
510            Self::Unpack2x16snorm => 1,
511            Self::Unpack2x16unorm => 1,
512            Self::Unpack2x16float => 1,
513            Self::Unpack4xI8 => 1,
514            Self::Unpack4xU8 => 1,
515        }
516    }
517}
518
519impl crate::Expression {
520    /// Returns true if the expression is considered emitted at the start of a function.
521    pub const fn needs_pre_emit(&self) -> bool {
522        match *self {
523            Self::Literal(_)
524            | Self::Constant(_)
525            | Self::Override(_)
526            | Self::ZeroValue(_)
527            | Self::FunctionArgument(_)
528            | Self::GlobalVariable(_)
529            | Self::LocalVariable(_) => true,
530            _ => false,
531        }
532    }
533
534    /// Return true if this expression is a dynamic array/vector/matrix index,
535    /// for [`Access`].
536    ///
537    /// This method returns true if this expression is a dynamically computed
538    /// index, and as such can only be used to index matrices when they appear
539    /// behind a pointer. See the documentation for [`Access`] for details.
540    ///
541    /// Note, this does not check the _type_ of the given expression. It's up to
542    /// the caller to establish that the `Access` expression is well-typed
543    /// through other means, like [`ResolveContext`].
544    ///
545    /// [`Access`]: crate::Expression::Access
546    /// [`ResolveContext`]: crate::proc::ResolveContext
547    pub const fn is_dynamic_index(&self) -> bool {
548        match *self {
549            Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
550            _ => true,
551        }
552    }
553}
554
555impl crate::Function {
556    /// Return the global variable being accessed by the expression `pointer`.
557    ///
558    /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
559    /// expressions that ultimately access some part of a `GlobalVariable`,
560    /// return a handle for that global.
561    ///
562    /// If the expression does not ultimately access a global variable, return
563    /// `None`.
564    pub fn originating_global(
565        &self,
566        mut pointer: crate::Handle<crate::Expression>,
567    ) -> Option<crate::Handle<crate::GlobalVariable>> {
568        loop {
569            pointer = match self.expressions[pointer] {
570                crate::Expression::Access { base, .. } => base,
571                crate::Expression::AccessIndex { base, .. } => base,
572                crate::Expression::GlobalVariable(handle) => return Some(handle),
573                crate::Expression::LocalVariable(_) => return None,
574                crate::Expression::FunctionArgument(_) => return None,
575                // There are no other expressions that produce pointer values.
576                _ => unreachable!(),
577            }
578        }
579    }
580}
581
582impl crate::SampleLevel {
583    pub const fn implicit_derivatives(&self) -> bool {
584        match *self {
585            Self::Auto | Self::Bias(_) => true,
586            Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
587        }
588    }
589}
590
591impl crate::Binding {
592    pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
593        match *self {
594            crate::Binding::BuiltIn(built_in) => Some(built_in),
595            Self::Location { .. } => None,
596        }
597    }
598}
599
600impl super::SwizzleComponent {
601    pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
602
603    pub const fn index(&self) -> u32 {
604        match *self {
605            Self::X => 0,
606            Self::Y => 1,
607            Self::Z => 2,
608            Self::W => 3,
609        }
610    }
611    pub const fn from_index(idx: u32) -> Self {
612        match idx {
613            0 => Self::X,
614            1 => Self::Y,
615            2 => Self::Z,
616            _ => Self::W,
617        }
618    }
619}
620
621impl super::ImageClass {
622    pub const fn is_multisampled(self) -> bool {
623        match self {
624            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
625            crate::ImageClass::Storage { .. } => false,
626        }
627    }
628
629    pub const fn is_mipmapped(self) -> bool {
630        match self {
631            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
632            crate::ImageClass::Storage { .. } => false,
633        }
634    }
635
636    pub const fn is_depth(self) -> bool {
637        matches!(self, crate::ImageClass::Depth { .. })
638    }
639}
640
641impl crate::Module {
642    pub const fn to_ctx(&self) -> GlobalCtx<'_> {
643        GlobalCtx {
644            types: &self.types,
645            constants: &self.constants,
646            overrides: &self.overrides,
647            global_expressions: &self.global_expressions,
648        }
649    }
650}
651
/// Reasons [`GlobalCtx::eval_expr_to_u32_from`] can fail to produce a `u32`.
#[derive(Debug)]
pub(super) enum U32EvalError {
    /// The expression did not reduce to a `U32` or `I32` literal.
    NonConst,
    /// The expression reduced to a negative `I32` literal.
    Negative,
}
657
/// Shared borrows of a module's global arenas.
///
/// Produced by `Module::to_ctx`. It is `Copy`, so it can be passed by value.
#[derive(Clone, Copy)]
pub struct GlobalCtx<'a> {
    /// The module's type arena.
    pub types: &'a crate::UniqueArena<crate::Type>,
    /// The module's constants. Their initializers live in `global_expressions`.
    pub constants: &'a crate::Arena<crate::Constant>,
    /// The module's pipeline-overridable constants.
    pub overrides: &'a crate::Arena<crate::Override>,
    /// The module's global (non-function) expression arena.
    pub global_expressions: &'a crate::Arena<crate::Expression>,
}
665
666impl GlobalCtx<'_> {
667    /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`.
668    #[allow(dead_code)]
669    pub(super) fn eval_expr_to_u32(
670        &self,
671        handle: crate::Handle<crate::Expression>,
672    ) -> Result<u32, U32EvalError> {
673        self.eval_expr_to_u32_from(handle, self.global_expressions)
674    }
675
676    /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`.
677    pub(super) fn eval_expr_to_u32_from(
678        &self,
679        handle: crate::Handle<crate::Expression>,
680        arena: &crate::Arena<crate::Expression>,
681    ) -> Result<u32, U32EvalError> {
682        match self.eval_expr_to_literal_from(handle, arena) {
683            Some(crate::Literal::U32(value)) => Ok(value),
684            Some(crate::Literal::I32(value)) => {
685                value.try_into().map_err(|_| U32EvalError::Negative)
686            }
687            _ => Err(U32EvalError::NonConst),
688        }
689    }
690
691    /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `bool`.
692    #[allow(dead_code)]
693    pub(super) fn eval_expr_to_bool_from(
694        &self,
695        handle: crate::Handle<crate::Expression>,
696        arena: &crate::Arena<crate::Expression>,
697    ) -> Option<bool> {
698        match self.eval_expr_to_literal_from(handle, arena) {
699            Some(crate::Literal::Bool(value)) => Some(value),
700            _ => None,
701        }
702    }
703
704    #[allow(dead_code)]
705    pub(crate) fn eval_expr_to_literal(
706        &self,
707        handle: crate::Handle<crate::Expression>,
708    ) -> Option<crate::Literal> {
709        self.eval_expr_to_literal_from(handle, self.global_expressions)
710    }
711
712    fn eval_expr_to_literal_from(
713        &self,
714        handle: crate::Handle<crate::Expression>,
715        arena: &crate::Arena<crate::Expression>,
716    ) -> Option<crate::Literal> {
717        fn get(
718            gctx: GlobalCtx,
719            handle: crate::Handle<crate::Expression>,
720            arena: &crate::Arena<crate::Expression>,
721        ) -> Option<crate::Literal> {
722            match arena[handle] {
723                crate::Expression::Literal(literal) => Some(literal),
724                crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
725                    crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
726                    _ => None,
727                },
728                _ => None,
729            }
730        }
731        match arena[handle] {
732            crate::Expression::Constant(c) => {
733                get(*self, self.constants[c].init, self.global_expressions)
734            }
735            _ => get(*self, handle, arena),
736        }
737    }
738}
739
740/// Return an iterator over the individual components assembled by a
741/// `Compose` expression.
742///
743/// Given `ty` and `components` from an `Expression::Compose`, return an
744/// iterator over the components of the resulting value.
745///
746/// Normally, this would just be an iterator over `components`. However,
747/// `Compose` expressions can concatenate vectors, in which case the i'th
748/// value being composed is not generally the i'th element of `components`.
749/// This function consults `ty` to decide if this concatenation is occurring,
750/// and returns an iterator that produces the components of the result of
751/// the `Compose` expression in either case.
752pub fn flatten_compose<'arenas>(
753    ty: crate::Handle<crate::Type>,
754    components: &'arenas [crate::Handle<crate::Expression>],
755    expressions: &'arenas crate::Arena<crate::Expression>,
756    types: &'arenas crate::UniqueArena<crate::Type>,
757) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
758    // Returning `impl Iterator` is a bit tricky. We may or may not
759    // want to flatten the components, but we have to settle on a
760    // single concrete type to return. This function returns a single
761    // iterator chain that handles both the flattening and
762    // non-flattening cases.
763    let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
764        (size as usize, true)
765    } else {
766        (components.len(), false)
767    };
768
769    /// Flatten `Compose` expressions if `is_vector` is true.
770    fn flatten_compose<'c>(
771        component: &'c crate::Handle<crate::Expression>,
772        is_vector: bool,
773        expressions: &'c crate::Arena<crate::Expression>,
774    ) -> &'c [crate::Handle<crate::Expression>] {
775        if is_vector {
776            if let crate::Expression::Compose {
777                ty: _,
778                components: ref subcomponents,
779            } = expressions[*component]
780            {
781                return subcomponents;
782            }
783        }
784        std::slice::from_ref(component)
785    }
786
787    /// Flatten `Splat` expressions if `is_vector` is true.
788    fn flatten_splat<'c>(
789        component: &'c crate::Handle<crate::Expression>,
790        is_vector: bool,
791        expressions: &'c crate::Arena<crate::Expression>,
792    ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
793        let mut expr = *component;
794        let mut count = 1;
795        if is_vector {
796            if let crate::Expression::Splat { size, value } = expressions[expr] {
797                expr = value;
798                count = size as usize;
799            }
800        }
801        std::iter::repeat(expr).take(count)
802    }
803
804    // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to
805    // flatten up to two levels of `Compose` expressions.
806    //
807    // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten
808    // `Splat` expressions. Fortunately, the operand of a `Splat` must
809    // be a scalar, so we can stop there.
810    components
811        .iter()
812        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
813        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
814        .flat_map(move |component| flatten_splat(component, is_vector, expressions))
815        .take(size)
816}
817
#[test]
fn test_matrix_size() {
    // A mat3x3<f32> is sized as 3 columns, each of `Alignment::from(Tri)`
    // elements of 4 bytes: 3 * 4 * 4 = 48 bytes.
    let module = crate::Module::default();
    let mat3x3 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    let size = mat3x3.size(module.to_ctx());
    assert_eq!(size, 48);
}