1mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use crate::{
14 arena::{Handle, HandleSet},
15 proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
16 FastHashSet,
17};
18use bit_set::BitSet;
19use std::ops;
20
21use crate::span::{AddSpan as _, WithSpan};
25pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
26pub use compose::ComposeError;
27pub use expression::{check_literal_value, LiteralError};
28pub use expression::{ConstExpressionError, ExpressionError};
29pub use function::{CallError, FunctionError, LocalVariableError};
30pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
31pub use r#type::{Disalignment, TypeError, TypeFlags, WidthError};
32
33use self::handles::InvalidHandleError;
34
35bitflags::bitflags! {
36 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
50 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
51 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
52 pub struct ValidationFlags: u8 {
53 const EXPRESSIONS = 0x1;
55 const BLOCKS = 0x2;
57 const CONTROL_FLOW_UNIFORMITY = 0x4;
59 const STRUCT_LAYOUTS = 0x8;
61 const CONSTANTS = 0x10;
63 const BINDINGS = 0x20;
65 }
66}
67
68impl Default for ValidationFlags {
69 fn default() -> Self {
70 Self::all()
71 }
72}
73
bitflags::bitflags! {
    /// Optional features the [`Validator`] can be told to accept; passed to
    /// [`Validator::new`].
    ///
    /// [`Capabilities::default`] enables only `MULTISAMPLED_SHADING` and
    /// `CUBE_ARRAY_TEXTURES`.
    #[must_use]
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct Capabilities: u32 {
        /// Push constants.
        const PUSH_CONSTANT = 1 << 0;
        /// 64-bit floating-point values.
        const FLOAT64 = 1 << 1;
        /// The primitive index built-in.
        const PRIMITIVE_INDEX = 1 << 2;
        /// Non-uniform indexing of sampled-texture and storage-buffer arrays.
        const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 3;
        /// Non-uniform indexing of uniform-buffer and storage-texture arrays.
        const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 1 << 4;
        /// Non-uniform indexing of sampler arrays.
        const SAMPLER_NON_UNIFORM_INDEXING = 1 << 5;
        /// The clip distance built-in.
        const CLIP_DISTANCE = 1 << 6;
        /// The cull distance built-in.
        const CULL_DISTANCE = 1 << 7;
        /// 16-bit normalized storage texture formats.
        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 1 << 8;
        /// Multiview.
        const MULTIVIEW = 1 << 9;
        /// Early depth test.
        const EARLY_DEPTH_TEST = 1 << 10;
        /// Multisampled shading. Enabled by default.
        const MULTISAMPLED_SHADING = 1 << 11;
        /// Ray queries.
        const RAY_QUERY = 1 << 12;
        /// Dual-source blending.
        const DUAL_SOURCE_BLENDING = 1 << 13;
        /// Cube array textures. Enabled by default.
        const CUBE_ARRAY_TEXTURES = 1 << 14;
        /// 64-bit integer values.
        const SHADER_INT64 = 1 << 15;
        /// Subgroup operations. When passed to [`Validator::new`], every
        /// [`SubgroupOperationSet`] flag is enabled for the fragment and
        /// compute stages.
        const SUBGROUP = 1 << 16;
        /// Subgroup barriers.
        const SUBGROUP_BARRIER = 1 << 17;
        /// When passed to [`Validator::new`], subgroup operations are also
        /// allowed in the vertex stage.
        const SUBGROUP_VERTEX_STAGE = 1 << 18;
        /// 64-bit integer atomic min/max.
        const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 19;
        /// All 64-bit integer atomic operations.
        const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 20;
        /// 32-bit floating-point atomics.
        const SHADER_FLOAT32_ATOMIC = 1 << 21;
        /// Texture atomics.
        const TEXTURE_ATOMIC = 1 << 22;
        /// 64-bit integer texture atomics.
        const TEXTURE_INT64_ATOMIC = 1 << 23;
    }
}
161
162impl Default for Capabilities {
163 fn default() -> Self {
164 Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
165 }
166}
167
bitflags::bitflags! {
    /// Set of subgroup operation categories a module is permitted to use.
    ///
    /// Each `SubgroupOperation` / `GatherMode` maps to exactly one of these
    /// bits via its `required_operations` method.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        /// Basic subgroup operations. Not required by any mapping in this
        /// file; NOTE(review): presumably elect/barrier — confirm.
        const BASIC = 1 << 0;
        /// Required by `SubgroupOperation::All` and `Any`.
        const VOTE = 1 << 1;
        /// Required by the arithmetic/bitwise reductions
        /// (`Add`, `Mul`, `Min`, `Max`, `And`, `Or`, `Xor`).
        const ARITHMETIC = 1 << 2;
        /// Required by `GatherMode::BroadcastFirst` and `Broadcast`.
        const BALLOT = 1 << 3;
        /// Required by `GatherMode::Shuffle` and `ShuffleXor`.
        const SHUFFLE = 1 << 4;
        /// Required by `GatherMode::ShuffleUp` and `ShuffleDown`.
        const SHUFFLE_RELATIVE = 1 << 5;
    }
}
195
196impl super::SubgroupOperation {
197 const fn required_operations(&self) -> SubgroupOperationSet {
198 use SubgroupOperationSet as S;
199 match *self {
200 Self::All | Self::Any => S::VOTE,
201 Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
202 S::ARITHMETIC
203 }
204 }
205 }
206}
207
208impl super::GatherMode {
209 const fn required_operations(&self) -> SubgroupOperationSet {
210 use SubgroupOperationSet as S;
211 match *self {
212 Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
213 Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
214 Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
215 }
216 }
217}
218
219bitflags::bitflags! {
220 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
222 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
223 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
224 pub struct ShaderStages: u8 {
225 const VERTEX = 0x1;
226 const FRAGMENT = 0x2;
227 const COMPUTE = 0x4;
228 }
229}
230
/// Information produced by validating a module.
///
/// Returned by [`Validator::validate`] / [`Validator::validate_no_overrides`],
/// and indexable by type, function and global-expression handles.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    // Flags of each type, indexed by `Handle<Type>::index()`.
    type_flags: Vec<TypeFlags>,
    // Analysis of each function, in `module.functions` order.
    functions: Vec<FunctionInfo>,
    // Analysis of each entry point, in `module.entry_points` order.
    entry_points: Vec<FunctionInfo>,
    // Resolved type of each global expression, indexed by
    // `Handle<Expression>::index()`.
    const_expression_types: Box<[TypeResolution]>,
}
240
241impl ops::Index<Handle<crate::Type>> for ModuleInfo {
242 type Output = TypeFlags;
243 fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
244 &self.type_flags[handle.index()]
245 }
246}
247
248impl ops::Index<Handle<crate::Function>> for ModuleInfo {
249 type Output = FunctionInfo;
250 fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
251 &self.functions[handle.index()]
252 }
253}
254
255impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
256 type Output = TypeResolution;
257 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
258 &self.const_expression_types[handle.index()]
259 }
260}
261
/// Validates a naga module, producing a [`ModuleInfo`] on success.
///
/// A single instance can be reused across modules; its scratch state is
/// reset at the start of every validation.
#[derive(Debug)]
pub struct Validator {
    // Which categories of checks to run.
    flags: ValidationFlags,
    // Optional features the module is allowed to use.
    capabilities: Capabilities,
    // Stages in which subgroup operations are permitted.
    subgroup_stages: ShaderStages,
    // Subgroup operation categories the module may use.
    subgroup_operations: SubgroupOperationSet,
    // Per-type validation results, indexed by `Handle<Type>::index()`.
    types: Vec<r#type::TypeInfo>,
    // Type layouts, recomputed from the module at the start of validation.
    layouter: Layouter,
    // NOTE(review): presumably tracks used varying locations during
    // entry-point validation — confirm in interface.rs.
    location_mask: BitSet,
    // NOTE(review): presumably resource bindings already seen while
    // validating an entry point — confirm in interface.rs.
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    // Scratch list/set of expressions considered valid so far; used by
    // function validation (see function.rs).
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: HandleSet<crate::Expression>,
    // Override IDs seen so far, used to reject duplicate IDs.
    override_ids: FastHashSet<u16>,
    // Set by `validate` (true) or `validate_no_overrides` (false).
    allow_overrides: bool,

    // NOTE(review): scratch set of expressions that still need visiting
    // during function validation — confirm exact semantics in function.rs.
    needs_visit: HandleSet<crate::Expression>,
}
299
/// Reasons a module-level constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The initializer is not a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's type is not equivalent to the declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The declared type lacks the `CONSTRUCTIBLE` type flag.
    #[error("The type is not constructible")]
    NonConstructibleType,
}
310
/// Reasons a pipeline-overridable constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    /// The override has neither a name nor a numeric ID.
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// Another override in the module already uses the same ID.
    #[error("Override ID must be unique")]
    DuplicateID,
    /// The initializer is not a const- or override-expression.
    /// NOTE(review): not produced by `Validator::validate_override` in this
    /// file — confirm where it is raised.
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's type is not equivalent to the declared type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The declared type lacks the `CONSTRUCTIBLE` type flag.
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The declared type is not one of `bool`, `i32`, `u32`, `f32`, `f64`.
    #[error("The type is not a scalar")]
    TypeNotScalar,
    /// Overrides were found while validating with
    /// `Validator::validate_no_overrides`.
    #[error("Override declarations are not allowed")]
    NotAllowed,
}
329
/// Top-level error returned (wrapped in a `WithSpan`) by module validation.
///
/// Most variants identify the offending module item and carry the
/// item-specific error as `source`.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// Handle validation of the module failed.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing type layouts failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// A type failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// A global expression failed validation.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    /// An array size expression does not evaluate to a positive value.
    /// NOTE(review): not produced in this file — raised elsewhere.
    #[error("Array size expression {handle:?} is not strictly positive")]
    ArraySizeError { handle: Handle<crate::Expression> },
    /// A module-level constant failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    /// A pipeline-overridable constant failed validation.
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    /// A global variable failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// A function failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// An entry point failed validation, or conflicts with another entry
    /// point of the same stage and name.
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    /// NOTE(review): not produced in this file — confirm where it is raised.
    #[error("Module is corrupted")]
    Corrupted,
}
383
384impl crate::TypeInner {
385 const fn is_sized(&self) -> bool {
386 match *self {
387 Self::Scalar { .. }
388 | Self::Vector { .. }
389 | Self::Matrix { .. }
390 | Self::Array {
391 size: crate::ArraySize::Constant(_),
392 ..
393 }
394 | Self::Atomic { .. }
395 | Self::Pointer { .. }
396 | Self::ValuePointer { .. }
397 | Self::Struct { .. } => true,
398 Self::Array { .. }
399 | Self::Image { .. }
400 | Self::Sampler { .. }
401 | Self::AccelerationStructure
402 | Self::RayQuery
403 | Self::BindingArray { .. } => false,
404 }
405 }
406
407 const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
409 match *self {
410 Self::Scalar(crate::Scalar {
411 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
412 ..
413 }) => Some(crate::ImageDimension::D1),
414 Self::Vector {
415 size: crate::VectorSize::Bi,
416 scalar:
417 crate::Scalar {
418 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
419 ..
420 },
421 } => Some(crate::ImageDimension::D2),
422 Self::Vector {
423 size: crate::VectorSize::Tri,
424 scalar:
425 crate::Scalar {
426 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
427 ..
428 },
429 } => Some(crate::ImageDimension::D3),
430 _ => None,
431 }
432 }
433}
434
435impl Validator {
436 pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
438 let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
439 use SubgroupOperationSet as S;
440 S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE
441 } else {
442 SubgroupOperationSet::empty()
443 };
444 let subgroup_stages = {
445 let mut stages = ShaderStages::empty();
446 if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
447 stages |= ShaderStages::VERTEX;
448 }
449 if capabilities.contains(Capabilities::SUBGROUP) {
450 stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE;
451 }
452 stages
453 };
454
455 Validator {
456 flags,
457 capabilities,
458 subgroup_stages,
459 subgroup_operations,
460 types: Vec::new(),
461 layouter: Layouter::default(),
462 location_mask: BitSet::new(),
463 ep_resource_bindings: FastHashSet::default(),
464 switch_values: FastHashSet::default(),
465 valid_expression_list: Vec::new(),
466 valid_expression_set: HandleSet::new(),
467 override_ids: FastHashSet::default(),
468 allow_overrides: true,
469 needs_visit: HandleSet::new(),
470 }
471 }
472
473 pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
474 self.subgroup_stages = stages;
475 self
476 }
477
478 pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
479 self.subgroup_operations = operations;
480 self
481 }
482
483 pub fn reset(&mut self) {
485 self.types.clear();
486 self.layouter.clear();
487 self.location_mask.clear();
488 self.ep_resource_bindings.clear();
489 self.switch_values.clear();
490 self.valid_expression_list.clear();
491 self.valid_expression_set.clear();
492 self.override_ids.clear();
493 }
494
495 fn validate_constant(
496 &self,
497 handle: Handle<crate::Constant>,
498 gctx: crate::proc::GlobalCtx,
499 mod_info: &ModuleInfo,
500 global_expr_kind: &ExpressionKindTracker,
501 ) -> Result<(), ConstantError> {
502 let con = &gctx.constants[handle];
503
504 let type_info = &self.types[con.ty.index()];
505 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
506 return Err(ConstantError::NonConstructibleType);
507 }
508
509 if !global_expr_kind.is_const(con.init) {
510 return Err(ConstantError::InitializerExprType);
511 }
512
513 let decl_ty = &gctx.types[con.ty].inner;
514 let init_ty = mod_info[con.init].inner_with(gctx.types);
515 if !decl_ty.equivalent(init_ty, gctx.types) {
516 return Err(ConstantError::InvalidType);
517 }
518
519 Ok(())
520 }
521
522 fn validate_override(
523 &mut self,
524 handle: Handle<crate::Override>,
525 gctx: crate::proc::GlobalCtx,
526 mod_info: &ModuleInfo,
527 ) -> Result<(), OverrideError> {
528 if !self.allow_overrides {
529 return Err(OverrideError::NotAllowed);
530 }
531
532 let o = &gctx.overrides[handle];
533
534 if o.name.is_none() && o.id.is_none() {
535 return Err(OverrideError::MissingNameAndID);
536 }
537
538 if let Some(id) = o.id {
539 if !self.override_ids.insert(id) {
540 return Err(OverrideError::DuplicateID);
541 }
542 }
543
544 let type_info = &self.types[o.ty.index()];
545 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
546 return Err(OverrideError::NonConstructibleType);
547 }
548
549 let decl_ty = &gctx.types[o.ty].inner;
550 match decl_ty {
551 &crate::TypeInner::Scalar(
552 crate::Scalar::BOOL
553 | crate::Scalar::I32
554 | crate::Scalar::U32
555 | crate::Scalar::F32
556 | crate::Scalar::F64,
557 ) => {}
558 _ => return Err(OverrideError::TypeNotScalar),
559 }
560
561 if let Some(init) = o.init {
562 let init_ty = mod_info[init].inner_with(gctx.types);
563 if !decl_ty.equivalent(init_ty, gctx.types) {
564 return Err(OverrideError::InvalidType);
565 }
566 }
567
568 Ok(())
569 }
570
571 pub fn validate(
573 &mut self,
574 module: &crate::Module,
575 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
576 self.allow_overrides = true;
577 self.validate_impl(module)
578 }
579
580 pub fn validate_no_overrides(
584 &mut self,
585 module: &crate::Module,
586 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
587 self.allow_overrides = false;
588 self.validate_impl(module)
589 }
590
591 fn validate_impl(
592 &mut self,
593 module: &crate::Module,
594 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
595 self.reset();
596 self.reset_types(module.types.len());
597
598 Self::validate_module_handles(module).map_err(|e| e.with_span())?;
599
600 self.layouter.update(module.to_ctx()).map_err(|e| {
601 let handle = e.ty;
602 ValidationError::from(e).with_span_handle(handle, &module.types)
603 })?;
604
605 let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
607 kind: crate::ScalarKind::Bool,
608 width: 0,
609 }));
610
611 let mut mod_info = ModuleInfo {
612 type_flags: Vec::with_capacity(module.types.len()),
613 functions: Vec::with_capacity(module.functions.len()),
614 entry_points: Vec::with_capacity(module.entry_points.len()),
615 const_expression_types: vec![placeholder; module.global_expressions.len()]
616 .into_boxed_slice(),
617 };
618
619 for (handle, ty) in module.types.iter() {
620 let ty_info = self
621 .validate_type(handle, module.to_ctx())
622 .map_err(|source| {
623 ValidationError::Type {
624 handle,
625 name: ty.name.clone().unwrap_or_default(),
626 source,
627 }
628 .with_span_handle(handle, &module.types)
629 })?;
630 if !self.allow_overrides {
631 if let crate::TypeInner::Array {
632 size: crate::ArraySize::Pending(_),
633 ..
634 } = ty.inner
635 {
636 return Err((ValidationError::Type {
637 handle,
638 name: ty.name.clone().unwrap_or_default(),
639 source: TypeError::UnresolvedOverride(handle),
640 })
641 .with_span_handle(handle, &module.types));
642 }
643 }
644 mod_info.type_flags.push(ty_info.flags);
645 self.types[handle.index()] = ty_info;
646 }
647
648 {
649 let t = crate::Arena::new();
650 let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
651 for (handle, _) in module.global_expressions.iter() {
652 mod_info
653 .process_const_expression(handle, &resolve_context, module.to_ctx())
654 .map_err(|source| {
655 ValidationError::ConstExpression { handle, source }
656 .with_span_handle(handle, &module.global_expressions)
657 })?
658 }
659 }
660
661 let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
662
663 if self.flags.contains(ValidationFlags::CONSTANTS) {
664 for (handle, _) in module.global_expressions.iter() {
665 self.validate_const_expression(
666 handle,
667 module.to_ctx(),
668 &mod_info,
669 &global_expr_kind,
670 )
671 .map_err(|source| {
672 ValidationError::ConstExpression { handle, source }
673 .with_span_handle(handle, &module.global_expressions)
674 })?
675 }
676
677 for (handle, constant) in module.constants.iter() {
678 self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
679 .map_err(|source| {
680 ValidationError::Constant {
681 handle,
682 name: constant.name.clone().unwrap_or_default(),
683 source,
684 }
685 .with_span_handle(handle, &module.constants)
686 })?
687 }
688
689 for (handle, override_) in module.overrides.iter() {
690 self.validate_override(handle, module.to_ctx(), &mod_info)
691 .map_err(|source| {
692 ValidationError::Override {
693 handle,
694 name: override_.name.clone().unwrap_or_default(),
695 source,
696 }
697 .with_span_handle(handle, &module.overrides)
698 })?
699 }
700 }
701
702 for (var_handle, var) in module.global_variables.iter() {
703 self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
704 .map_err(|source| {
705 ValidationError::GlobalVariable {
706 handle: var_handle,
707 name: var.name.clone().unwrap_or_default(),
708 source,
709 }
710 .with_span_handle(var_handle, &module.global_variables)
711 })?;
712 }
713
714 for (handle, fun) in module.functions.iter() {
715 match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) {
716 Ok(info) => mod_info.functions.push(info),
717 Err(error) => {
718 return Err(error.and_then(|source| {
719 ValidationError::Function {
720 handle,
721 name: fun.name.clone().unwrap_or_default(),
722 source,
723 }
724 .with_span_handle(handle, &module.functions)
725 }))
726 }
727 }
728 }
729
730 let mut ep_map = FastHashSet::default();
731 for ep in module.entry_points.iter() {
732 if !ep_map.insert((ep.stage, &ep.name)) {
733 return Err(ValidationError::EntryPoint {
734 stage: ep.stage,
735 name: ep.name.clone(),
736 source: EntryPointError::Conflict,
737 }
738 .with_span()); }
740
741 match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) {
742 Ok(info) => mod_info.entry_points.push(info),
743 Err(error) => {
744 return Err(error.and_then(|source| {
745 ValidationError::EntryPoint {
746 stage: ep.stage,
747 name: ep.name.clone(),
748 source,
749 }
750 .with_span()
751 }));
752 }
753 }
754 }
755
756 Ok(mod_info)
757 }
758}
759
760fn validate_atomic_compare_exchange_struct(
761 types: &crate::UniqueArena<crate::Type>,
762 members: &[crate::StructMember],
763 scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
764) -> bool {
765 members.len() == 2
766 && members[0].name.as_deref() == Some("old_value")
767 && scalar_predicate(&types[members[0].ty].inner)
768 && members[1].name.as_deref() == Some("exchanged")
769 && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
770}