1#[cfg(feature = "spirv_cross")]
2use spirv_cross::spirv;
3use std::{io, slice};
4
5pub type FastHashMap<K, V> =
7 std::collections::HashMap<K, V, std::hash::BuildHasherDefault<fxhash::FxHasher>>;
8pub type FastHashSet<K> =
9 std::collections::HashSet<K, std::hash::BuildHasherDefault<fxhash::FxHasher>>;
10
/// A shader pipeline stage.
///
/// The enum is `#[repr(u8)]`. Its discriminants are written out explicitly so
/// the numeric value of each stage (0 through 7, in declaration order) is
/// visible at a glance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(u8)]
pub enum ShaderStage {
    Vertex = 0,
    Hull = 1,
    Domain = 2,
    Geometry = 3,
    Fragment = 4,
    Compute = 5,
    Task = 6,
    Mesh = 7,
}
24
25impl ShaderStage {
26 pub fn to_flag(self) -> hal::pso::ShaderStageFlags {
27 use hal::pso::ShaderStageFlags as Ssf;
28 match self {
29 ShaderStage::Vertex => Ssf::VERTEX,
30 ShaderStage::Hull => Ssf::HULL,
31 ShaderStage::Domain => Ssf::DOMAIN,
32 ShaderStage::Geometry => Ssf::GEOMETRY,
33 ShaderStage::Fragment => Ssf::FRAGMENT,
34 ShaderStage::Compute => Ssf::COMPUTE,
35 ShaderStage::Task => Ssf::TASK,
36 ShaderStage::Mesh => Ssf::MESH,
37 }
38 }
39}
40
/// Reads a complete SPIR-V binary from `x` and returns it as 32-bit words.
///
/// The function seeks to the end of the stream to learn its size. It then seeks
/// back to offset 0 and reads everything, so the reader's current position is
/// ignored. If the first word is the SPIR-V magic number with its bytes swapped,
/// every word is byte-swapped so the returned words are in host byte order.
///
/// # Errors
///
/// Returns an error of kind [`io::ErrorKind::InvalidData`] if:
/// - the input length is not a multiple of 4;
/// - the length does not fit in a `usize`;
/// - the input does not start with the SPIR-V magic number in either byte order
///   (this includes empty input).
///
/// Errors from seeking or reading are propagated unchanged.
pub fn read_spirv<R: io::Read + io::Seek>(mut x: R) -> io::Result<Vec<u32>> {
    const MAGIC_NUMBER: u32 = 0x0723_0203;

    let size = x.seek(io::SeekFrom::End(0))?;
    if size % 4 != 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "input length not divisible by 4",
        ));
    }
    if size > usize::MAX as u64 {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "input too long"));
    }
    let words = (size / 4) as usize;

    // Zero-initialise the buffer: `Read::read_exact` must never be handed
    // uninitialised memory, because `Read` implementations are allowed to read it.
    let mut result = vec![0u32; words];
    x.seek(io::SeekFrom::Start(0))?;
    x.read_exact(
        // SAFETY: `result` holds `words` initialised `u32`s, i.e. exactly
        // `words * 4` initialised bytes. `u8` has alignment 1 and every bit
        // pattern is a valid `u32`, so viewing and overwriting that storage as
        // `&mut [u8]` is sound. The slice is not used after this call.
        unsafe { slice::from_raw_parts_mut(result.as_mut_ptr() as *mut u8, words * 4) },
    )?;

    // A byte-swapped magic number means the module was written with the
    // opposite endianness; normalise every word to host order.
    if result.first() == Some(&MAGIC_NUMBER.swap_bytes()) {
        for word in &mut result {
            *word = word.swap_bytes();
        }
    }
    if result.first() != Some(&MAGIC_NUMBER) {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "input missing SPIR-V magic number",
        ));
    }
    Ok(result)
}
97
/// Applies the values in `specialization` to the matching specialization
/// constants of a SPIRV-Cross AST.
///
/// For each specialization constant reported by `ast`, the entry in
/// `specialization.constants` with the same constant id (if any) selects a byte
/// range of `specialization.data`. Those bytes are combined, first byte least
/// significant, into a `u64` and written with `set_scalar_constant`. Constants
/// with no matching entry are left unchanged.
///
/// # Errors
///
/// Returns an error in two cases:
/// - SPIRV-Cross fails to list or set a constant. The error carries SPIRV-Cross's
///   compilation error message, or a generic message for unhandled errors.
/// - A constant's data range is not a valid range inside `specialization.data`.
///   Previously this panicked.
#[cfg(feature = "spirv_cross")]
pub fn spirv_cross_specialize_ast<T>(
    ast: &mut spirv::Ast<T>,
    specialization: &hal::pso::Specialization,
) -> Result<(), String>
where
    T: spirv::Target,
    spirv::Ast<T>: spirv::Compile<T> + spirv::Parse<T>,
{
    // Both SPIRV-Cross calls below report failures the same way; convert once.
    let to_message = |err: spirv_cross::ErrorCode| -> String {
        match err {
            spirv_cross::ErrorCode::CompilationError(msg) => msg,
            spirv_cross::ErrorCode::Unhandled => "Unexpected specialization constant error".into(),
        }
    };

    let spec_constants = ast.get_specialization_constants().map_err(to_message)?;

    for spec_constant in spec_constants {
        let constant = match specialization
            .constants
            .iter()
            .find(|c| c.id == spec_constant.constant_id)
        {
            Some(constant) => constant,
            None => continue,
        };

        let start = constant.range.start as usize;
        let end = constant.range.end as usize;
        // Report a malformed or out-of-bounds range instead of panicking on the slice.
        let bytes = specialization.data.get(start..end).ok_or_else(|| {
            format!(
                "Specialization constant {} has data range {}..{}, which is not valid for {} bytes of specialization data",
                spec_constant.constant_id,
                start,
                end,
                specialization.data.len()
            )
        })?;

        // Bytes are folded in reverse so the first byte ends up least significant.
        let value = bytes
            .iter()
            .rev()
            .fold(0u64, |u, &b| (u << 8) + b as u64);

        ast.set_scalar_constant(spec_constant.id, value)
            .map_err(to_message)?;
    }

    Ok(())
}
138}