gfx_auxil/
lib.rs

1#[cfg(feature = "spirv_cross")]
2use spirv_cross::spirv;
3use std::{io, slice};
4
5/// Fast hash map used internally.
6pub type FastHashMap<K, V> =
7    std::collections::HashMap<K, V, std::hash::BuildHasherDefault<fxhash::FxHasher>>;
8pub type FastHashSet<K> =
9    std::collections::HashSet<K, std::hash::BuildHasherDefault<fxhash::FxHasher>>;
10
/// A pipeline stage that a shader can target.
///
/// The enum is `#[repr(u8)]`; the discriminants are spelled out so that `stage as u8`
/// visibly matches declaration order.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(u8)]
pub enum ShaderStage {
    /// Vertex shader stage.
    Vertex = 0,
    /// Hull shader stage.
    Hull = 1,
    /// Domain shader stage.
    Domain = 2,
    /// Geometry shader stage.
    Geometry = 3,
    /// Fragment shader stage.
    Fragment = 4,
    /// Compute shader stage.
    Compute = 5,
    /// Task shader stage.
    Task = 6,
    /// Mesh shader stage.
    Mesh = 7,
}
24
25impl ShaderStage {
26    pub fn to_flag(self) -> hal::pso::ShaderStageFlags {
27        use hal::pso::ShaderStageFlags as Ssf;
28        match self {
29            ShaderStage::Vertex => Ssf::VERTEX,
30            ShaderStage::Hull => Ssf::HULL,
31            ShaderStage::Domain => Ssf::DOMAIN,
32            ShaderStage::Geometry => Ssf::GEOMETRY,
33            ShaderStage::Fragment => Ssf::FRAGMENT,
34            ShaderStage::Compute => Ssf::COMPUTE,
35            ShaderStage::Task => Ssf::TASK,
36            ShaderStage::Mesh => Ssf::MESH,
37        }
38    }
39}
40
/// Safely read SPIR-V
///
/// Converts to native endianness and returns correctly aligned storage without unnecessary
/// copying. Returns an `InvalidData` error if the input is trivially not SPIR-V.
///
/// The whole stream is read starting from offset 0, regardless of the reader's current
/// position. Input in either byte order is accepted: if the first word equals the byte-swapped
/// SPIR-V magic number, every word is byte-swapped.
///
/// This function can also be used to convert an already in-memory `&[u8]` to a valid `Vec<u32>`,
/// but prefer working with `&[u32]` from the start whenever possible.
///
/// # Errors
///
/// Returns an error of kind [`io::ErrorKind::InvalidData`] if the input length is not a
/// multiple of 4, does not fit in a `usize`, or the input does not start with the SPIR-V magic
/// number (this includes empty input). Errors from seeking or reading are propagated unchanged.
///
/// # Examples
/// ```no_run
/// let mut file = std::fs::File::open("/path/to/shader.spv").unwrap();
/// let words = gfx_auxil::read_spirv(&mut file).unwrap();
/// ```
/// ```
/// const SPIRV: &[u8] = &[
///     0x03, 0x02, 0x23, 0x07, // ...
/// ];
/// let words = gfx_auxil::read_spirv(std::io::Cursor::new(&SPIRV[..])).unwrap();
/// ```
pub fn read_spirv<R: io::Read + io::Seek>(mut x: R) -> io::Result<Vec<u32>> {
    // First word of every SPIR-V module, in the module's own byte order.
    const MAGIC_NUMBER: u32 = 0x07230203;

    let size = x.seek(io::SeekFrom::End(0))?;
    if size % 4 != 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "input length not divisible by 4",
        ));
    }
    if size > usize::MAX as u64 {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "input too long"));
    }
    let words = (size / 4) as usize;

    // Zero-initialize the storage. Handing `read_exact` a byte view over the uninitialized
    // spare capacity of `Vec::with_capacity` is undefined behavior, because `Read`
    // implementations are allowed to read from the buffer they are given.
    let mut result = vec![0u32; words];
    x.seek(io::SeekFrom::Start(0))?;
    {
        // SAFETY: `result` holds `words` initialized `u32`s, i.e. exactly `words * 4` bytes
        // (no overflow: this equals `size`, which was checked to fit in `usize`). `u8` has
        // weaker alignment than `u32` and every bit pattern is a valid `u32`, so writing
        // arbitrary bytes through this view is sound. The view ends with this scope, before
        // `result` is accessed again.
        let bytes =
            unsafe { slice::from_raw_parts_mut(result.as_mut_ptr() as *mut u8, words * 4) };
        x.read_exact(bytes)?;
    }

    // A byte-swapped magic number means the module was written in the opposite endianness.
    if result.first() == Some(&MAGIC_NUMBER.swap_bytes()) {
        for word in &mut result {
            *word = word.swap_bytes();
        }
    }
    if result.first() != Some(&MAGIC_NUMBER) {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "input missing SPIR-V magic number",
        ));
    }
    Ok(result)
}
97
/// Applies the values in `specialization` to the matching specialization constants of `ast`.
///
/// For each specialization constant reported by `ast`, the entry of `specialization.constants`
/// whose `id` equals the constant's `constant_id` (if any) selects a byte range of
/// `specialization.data`. Those bytes are read as a little-endian integer (the first byte is
/// the least significant) and written with `set_scalar_constant`. Constants without a matching
/// entry are left untouched.
///
/// # Errors
///
/// Returns a message if SPIRV-Cross fails to list or set a constant, or if a constant's byte
/// range does not lie within `specialization.data`.
#[cfg(feature = "spirv_cross")]
pub fn spirv_cross_specialize_ast<T>(
    ast: &mut spirv::Ast<T>,
    specialization: &hal::pso::Specialization,
) -> Result<(), String>
where
    T: spirv::Target,
    spirv::Ast<T>: spirv::Compile<T> + spirv::Parse<T>,
{
    // Both SPIRV-Cross calls below report failures the same way.
    fn to_message(err: spirv_cross::ErrorCode) -> String {
        match err {
            spirv_cross::ErrorCode::CompilationError(msg) => msg,
            spirv_cross::ErrorCode::Unhandled => "Unexpected specialization constant error".into(),
        }
    }

    let spec_constants = ast.get_specialization_constants().map_err(to_message)?;

    for spec_constant in spec_constants {
        let constant = match specialization
            .constants
            .iter()
            .find(|c| c.id == spec_constant.constant_id)
        {
            Some(constant) => constant,
            None => continue,
        };

        // The range comes from caller-provided data; report a bad range instead of panicking.
        let (start, end) = (constant.range.start as usize, constant.range.end as usize);
        let bytes = specialization.data.get(start..end).ok_or_else(|| {
            format!(
                "Specialization constant {} data range {}..{} is out of bounds (data length {})",
                constant.id,
                start,
                end,
                specialization.data.len()
            )
        })?;

        // Fold from the last byte so the first byte ends up least significant. Only the first
        // 8 bytes of a longer range survive; the rest are shifted out of the u64.
        let value = bytes
            .iter()
            .rev()
            .fold(0u64, |acc, &b| (acc << 8) | u64::from(b));

        // Override the specialization constant's value.
        ast.set_scalar_constant(spec_constant.id, value)
            .map_err(to_message)?;
    }

    Ok(())
}