rendy_shader/
lib.rs

1//! Shader compilation.
2
3#![warn(
4    missing_debug_implementations,
5    missing_copy_implementations,
6    missing_docs,
7    trivial_casts,
8    trivial_numeric_casts,
9    unused_extern_crates,
10    unused_import_braces,
11    unused_qualifications
12)]
13
14#[cfg(feature = "shader-compiler")]
15mod shaderc;
16
17#[cfg(feature = "spirv-reflection")]
18#[allow(dead_code)]
19mod reflect;
20
21#[cfg(feature = "shader-compiler")]
22pub use self::shaderc::*;
23
24#[cfg(feature = "spirv-reflection")]
25pub use self::reflect::{ReflectError, ReflectTypeError, RetrievalKind, SpirvReflection};
26
27use rendy_core::hal::{pso::ShaderStageFlags, Backend};
28use std::collections::HashMap;
29
/// Error type returned by this module.
///
/// The enum has no variants, so no value of this type can ever be constructed:
/// an API returning `Result<_, ShaderError>` can never produce its `Err` arm.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ShaderError {}

impl std::fmt::Display for ShaderError {
    fn fmt(&self, _formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Uninhabited: an empty match is exhaustive and nothing is ever formatted.
        match *self {}
    }
}

impl std::error::Error for ShaderError {}
40
41/// Interface to create shader modules from shaders.
42/// Implemented for static shaders via [`compile_to_spirv!`] macro.
43///
44pub trait Shader {
45    /// The error type returned by the spirv function of this shader.
46    type Error: std::fmt::Debug;
47
48    /// Get spirv bytecode.
49    fn spirv(&self) -> Result<std::borrow::Cow<'_, [u32]>, <Self as Shader>::Error>;
50
51    /// Get the entry point of the shader.
52    fn entry(&self) -> &str;
53
54    /// Get the rendy_core::hal representation of this shaders kind/stage.
55    fn stage(&self) -> ShaderStageFlags;
56
57    /// Create shader module.
58    ///
59    /// Spir-V bytecode must adhere valid usage on this Vulkan spec page:
60    /// https://www.khronos.org/registry/vulkan/specs/1.1-extensions/man/html/VkShaderModuleCreateInfo.html
61    unsafe fn module<B>(
62        &self,
63        factory: &rendy_factory::Factory<B>,
64    ) -> Result<B::ShaderModule, rendy_core::hal::device::ShaderError>
65    where
66        B: Backend,
67    {
68        rendy_core::hal::device::Device::create_shader_module(
69            factory.device().raw(),
70            &self.spirv().map_err(|e| {
71                rendy_core::hal::device::ShaderError::CompilationFailed(format!("{:?}", e))
72            })?,
73        )
74    }
75}
76
77/// Spir-V shader.
78#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
79#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
80pub struct SpirvShader {
81    #[cfg_attr(feature = "serde", serde(with = "serde_spirv"))]
82    spirv: Vec<u32>,
83    stage: ShaderStageFlags,
84    entry: String,
85}
86
#[cfg(feature = "serde")]
mod serde_spirv {
    //! Serde adapter that stores SPIR-V words as a byte array.

    /// Upper bound on bytes preallocated from a (possibly untrusted) length hint.
    const MAX_PREALLOCATION: usize = 1 << 16;

    /// Serializes the SPIR-V words as their in-memory bytes.
    pub fn serialize<S>(data: &Vec<u32>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_bytes(rendy_core::cast_slice(data.as_slice()))
    }

    /// Deserializes SPIR-V words from a byte array.
    ///
    /// Borrowed bytes, owned bytes and sequences of `u8` are all accepted.
    /// Deserializers that cannot lend borrowed data (e.g. streaming readers) or
    /// that encode bytes as a sequence therefore still round-trip what
    /// `serialize` wrote. The bytes are validated by `read_spirv`.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u32>, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserializer.deserialize_bytes(SpirvVisitor)
    }

    /// Visitor collecting serialized bytes and converting them into SPIR-V words.
    struct SpirvVisitor;

    impl<'de> serde::de::Visitor<'de> for SpirvVisitor {
        type Value = Vec<u32>;

        fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            formatter.write_str("SPIR-V bytecode as a byte array")
        }

        // `visit_borrowed_bytes` and `visit_byte_buf` forward here by default.
        fn visit_bytes<E>(self, bytes: &[u8]) -> Result<Vec<u32>, E>
        where
            E: serde::de::Error,
        {
            words_from_bytes(bytes)
        }

        fn visit_seq<A>(self, mut seq: A) -> Result<Vec<u32>, A::Error>
        where
            A: serde::de::SeqAccess<'de>,
        {
            // Cap the hint so a bogus length cannot force a huge allocation.
            let capacity = seq.size_hint().unwrap_or(0).min(MAX_PREALLOCATION);
            let mut bytes = Vec::with_capacity(capacity);
            while let Some(byte) = seq.next_element::<u8>()? {
                bytes.push(byte);
            }
            words_from_bytes(&bytes)
        }
    }

    /// Parses SPIR-V bytes into words, mapping I/O errors to deserializer errors.
    fn words_from_bytes<E>(bytes: &[u8]) -> Result<Vec<u32>, E>
    where
        E: serde::de::Error,
    {
        rendy_core::hal::pso::read_spirv(std::io::Cursor::new(bytes)).map_err(E::custom)
    }
}
106
107impl SpirvShader {
108    /// Create Spir-V shader from bytes.
109    pub fn new(spirv: Vec<u32>, stage: ShaderStageFlags, entrypoint: &str) -> Self {
110        assert!(!spirv.is_empty());
111        Self {
112            spirv,
113            stage,
114            entry: entrypoint.to_string(),
115        }
116    }
117
118    /// Create Spir-V shader from bytecode stored as bytes.
119    /// Errors when passed byte array length is not a multiple of 4.
120    pub fn from_bytes(
121        spirv: &[u8],
122        stage: ShaderStageFlags,
123        entrypoint: &str,
124    ) -> std::io::Result<Self> {
125        Ok(Self::new(
126            rendy_core::hal::pso::read_spirv(std::io::Cursor::new(spirv))?,
127            stage,
128            entrypoint,
129        ))
130    }
131}
132
133impl Shader for SpirvShader {
134    type Error = ShaderError;
135
136    fn spirv(&self) -> Result<std::borrow::Cow<'_, [u32]>, ShaderError> {
137        Ok(std::borrow::Cow::Borrowed(&self.spirv))
138    }
139
140    fn entry(&self) -> &str {
141        &self.entry
142    }
143
144    fn stage(&self) -> ShaderStageFlags {
145        self.stage
146    }
147}
148
149/// A `ShaderSet` object represents a merged collection of `ShaderStorage` structures, which reflects merged information for all shaders in the set.
150#[derive(Debug)]
151pub struct ShaderSet<B: Backend> {
152    shaders: HashMap<ShaderStageFlags, ShaderStorage<B>>,
153}
154
155impl<B> Default for ShaderSet<B>
156where
157    B: Backend,
158{
159    fn default() -> Self {
160        ShaderSet {
161            shaders: HashMap::default(),
162        }
163    }
164}
165
166impl<B: Backend> ShaderSet<B> {
167    /// This function compiles and loads all shaders into B::ShaderModule objects which must be dropped later with `dispose`
168    pub fn load(
169        &mut self,
170        factory: &rendy_factory::Factory<B>,
171    ) -> Result<&mut Self, rendy_core::hal::device::ShaderError> {
172        for (_, v) in self.shaders.iter_mut() {
173            unsafe { v.compile(factory)? }
174        }
175
176        Ok(self)
177    }
178
179    /// Returns the `GraphicsShaderSet` structure to provide all the runtime information needed to use the shaders in this set in rendy_core::hal.
180    pub fn raw<'a>(
181        &'a self,
182    ) -> Result<rendy_core::hal::pso::GraphicsShaderSet<'a, B>, ShaderError> {
183        Ok(rendy_core::hal::pso::GraphicsShaderSet {
184            vertex: self
185                .shaders
186                .get(&ShaderStageFlags::VERTEX)
187                .expect("ShaderSet doesn't contain vertex shader")
188                .get_entry_point()?
189                .unwrap(),
190            fragment: match self.shaders.get(&ShaderStageFlags::FRAGMENT) {
191                Some(fragment) => fragment.get_entry_point()?,
192                None => None,
193            },
194            domain: match self.shaders.get(&ShaderStageFlags::DOMAIN) {
195                Some(domain) => domain.get_entry_point()?,
196                None => None,
197            },
198            hull: match self.shaders.get(&ShaderStageFlags::HULL) {
199                Some(hull) => hull.get_entry_point()?,
200                None => None,
201            },
202            geometry: match self.shaders.get(&ShaderStageFlags::GEOMETRY) {
203                Some(geometry) => geometry.get_entry_point()?,
204                None => None,
205            },
206        })
207    }
208
209    /// Must be called to perform a drop of the Backend ShaderModule object otherwise the shader will never be destroyed in memory.
210    pub fn dispose(&mut self, factory: &rendy_factory::Factory<B>) {
211        for (_, shader) in self.shaders.iter_mut() {
212            shader.dispose(factory);
213        }
214    }
215}
216
217/// A set of Specialization constants for a certain shader set.
218#[derive(Debug, Default, Clone)]
219#[allow(missing_copy_implementations)]
220pub struct SpecConstantSet {
221    /// Vertex specialization
222    pub vertex: Option<rendy_core::hal::pso::Specialization<'static>>,
223    /// Fragment specialization
224    pub fragment: Option<rendy_core::hal::pso::Specialization<'static>>,
225    /// Geometry specialization
226    pub geometry: Option<rendy_core::hal::pso::Specialization<'static>>,
227    /// Hull specialization
228    pub hull: Option<rendy_core::hal::pso::Specialization<'static>>,
229    /// Domain specialization
230    pub domain: Option<rendy_core::hal::pso::Specialization<'static>>,
231    /// Compute specialization
232    pub compute: Option<rendy_core::hal::pso::Specialization<'static>>,
233}
234
/// Collects SPIR-V bytecode and entry-point names for each pipeline stage.
///
/// `build` compiles the collected shaders into a `ShaderSet`. With the
/// `spirv-reflection` feature, `reflect` merges reflection data for all of them.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Debug, Default)]
pub struct ShaderSetBuilder {
    /// Vertex stage: SPIR-V words and entry-point name.
    vertex: Option<(Vec<u32>, String)>,
    /// Fragment stage: SPIR-V words and entry-point name.
    fragment: Option<(Vec<u32>, String)>,
    /// Geometry stage: SPIR-V words and entry-point name.
    geometry: Option<(Vec<u32>, String)>,
    /// Hull stage: SPIR-V words and entry-point name.
    hull: Option<(Vec<u32>, String)>,
    /// Domain stage: SPIR-V words and entry-point name.
    domain: Option<(Vec<u32>, String)>,
    /// Compute stage: SPIR-V words and entry-point name.
    compute: Option<(Vec<u32>, String)>,
}
247
248impl ShaderSetBuilder {
249    /// Builds the Backend-specific shader modules using the provided SPIRV code provided to the builder.
250    /// This function is called during the creation of a render pass.
251    ///
252    /// # Parameters
253    ///
254    /// `factory`   - factory to create shader modules.
255    ///
256    pub fn build<B: Backend>(
257        &self,
258        factory: &rendy_factory::Factory<B>,
259        spec_constants: SpecConstantSet,
260    ) -> Result<ShaderSet<B>, rendy_core::hal::device::ShaderError> {
261        let mut set = ShaderSet::<B>::default();
262
263        if self.vertex.is_none() && self.compute.is_none() {
264            let msg = "A vertex or compute shader must be provided".to_string();
265            return Err(rendy_core::hal::device::ShaderError::InterfaceMismatch(msg));
266        }
267        type ShaderTy = (
268            Vec<u32>,
269            String,
270            Option<rendy_core::hal::pso::Specialization<'static>>,
271        );
272
273        let create_storage =
274            move |stage,
275                  shader: ShaderTy,
276                  factory|
277                  -> Result<ShaderStorage<B>, rendy_core::hal::device::ShaderError> {
278                let mut storage = ShaderStorage {
279                    stage: stage,
280                    spirv: shader.0,
281                    module: None,
282                    entrypoint: shader.1.clone(),
283                    specialization: shader.2,
284                };
285                unsafe {
286                    storage.compile(factory)?;
287                }
288                Ok(storage)
289            };
290
291        if let Some(shader) = self.vertex.clone() {
292            set.shaders.insert(
293                ShaderStageFlags::VERTEX,
294                create_storage(
295                    ShaderStageFlags::VERTEX,
296                    (shader.0, shader.1, spec_constants.vertex),
297                    factory,
298                )?,
299            );
300        }
301
302        if let Some(shader) = self.fragment.clone() {
303            set.shaders.insert(
304                ShaderStageFlags::FRAGMENT,
305                create_storage(
306                    ShaderStageFlags::FRAGMENT,
307                    (shader.0, shader.1, spec_constants.fragment),
308                    factory,
309                )?,
310            );
311        }
312
313        if let Some(shader) = self.compute.clone() {
314            set.shaders.insert(
315                ShaderStageFlags::COMPUTE,
316                create_storage(
317                    ShaderStageFlags::COMPUTE,
318                    (shader.0, shader.1, spec_constants.compute),
319                    factory,
320                )?,
321            );
322        }
323
324        if let Some(shader) = self.domain.clone() {
325            set.shaders.insert(
326                ShaderStageFlags::DOMAIN,
327                create_storage(
328                    ShaderStageFlags::DOMAIN,
329                    (shader.0, shader.1, spec_constants.domain),
330                    factory,
331                )?,
332            );
333        }
334
335        if let Some(shader) = self.hull.clone() {
336            set.shaders.insert(
337                ShaderStageFlags::HULL,
338                create_storage(
339                    ShaderStageFlags::HULL,
340                    (shader.0, shader.1, spec_constants.hull),
341                    factory,
342                )?,
343            );
344        }
345
346        if let Some(shader) = self.geometry.clone() {
347            set.shaders.insert(
348                ShaderStageFlags::GEOMETRY,
349                create_storage(
350                    ShaderStageFlags::GEOMETRY,
351                    (shader.0, shader.1, spec_constants.geometry),
352                    factory,
353                )?,
354            );
355        }
356
357        Ok(set)
358    }
359
360    /// Add a vertex shader to this shader set
361    #[inline(always)]
362    pub fn with_vertex<S: Shader>(mut self, shader: &S) -> Result<Self, S::Error> {
363        let data = shader.spirv()?;
364        self.vertex = Some((data.to_vec(), shader.entry().to_string()));
365        Ok(self)
366    }
367
368    /// Add a fragment shader to this shader set
369    #[inline(always)]
370    pub fn with_fragment<S: Shader>(mut self, shader: &S) -> Result<Self, S::Error> {
371        let data = shader.spirv()?;
372        self.fragment = Some((data.to_vec(), shader.entry().to_string()));
373        Ok(self)
374    }
375
376    /// Add a geometry shader to this shader set
377    #[inline(always)]
378    pub fn with_geometry<S: Shader>(mut self, shader: &S) -> Result<Self, S::Error> {
379        let data = shader.spirv()?;
380        self.geometry = Some((data.to_vec(), shader.entry().to_string()));
381        Ok(self)
382    }
383
384    /// Add a hull shader to this shader set
385    #[inline(always)]
386    pub fn with_hull<S: Shader>(mut self, shader: &S) -> Result<Self, S::Error> {
387        let data = shader.spirv()?;
388        self.hull = Some((data.to_vec(), shader.entry().to_string()));
389        Ok(self)
390    }
391
392    /// Add a domain shader to this shader set
393    #[inline(always)]
394    pub fn with_domain<S: Shader>(mut self, shader: &S) -> Result<Self, S::Error> {
395        let data = shader.spirv()?;
396        self.domain = Some((data.to_vec(), shader.entry().to_string()));
397        Ok(self)
398    }
399
400    /// Add a compute shader to this shader set.
401    /// Note a compute or vertex shader must always exist in a shader set.
402    #[inline(always)]
403    pub fn with_compute<S: Shader>(mut self, shader: &S) -> Result<Self, S::Error> {
404        let data = shader.spirv()?;
405        self.compute = Some((data.to_vec(), shader.entry().to_string()));
406        Ok(self)
407    }
408
409    #[cfg(feature = "spirv-reflection")]
410    /// This function processes all shaders provided to the builder and computes and stores full reflection information on the shader.
411    /// This includes names, attributes, descriptor sets and push constants used by the shaders, as well as compiling local caches for performance.
412    pub fn reflect(&self) -> Result<SpirvReflection, ReflectError> {
413        if self.vertex.is_none() && self.compute.is_none() {
414            return Err(ReflectError::NoVertComputeProvided);
415        }
416
417        // We need to combine and merge all the reflections into a single SpirvReflection instance
418        let mut reflections = Vec::new();
419        if let Some(vertex) = self.vertex.as_ref() {
420            reflections.push(SpirvReflection::reflect(&vertex.0, None)?);
421        }
422        if let Some(fragment) = self.fragment.as_ref() {
423            reflections.push(SpirvReflection::reflect(&fragment.0, None)?);
424        }
425        if let Some(hull) = self.hull.as_ref() {
426            reflections.push(SpirvReflection::reflect(&hull.0, None)?);
427        }
428        if let Some(domain) = self.domain.as_ref() {
429            reflections.push(SpirvReflection::reflect(&domain.0, None)?);
430        }
431        if let Some(compute) = self.compute.as_ref() {
432            reflections.push(SpirvReflection::reflect(&compute.0, None)?);
433        }
434        if let Some(geometry) = self.geometry.as_ref() {
435            reflections.push(SpirvReflection::reflect(&geometry.0, None)?);
436        }
437
438        reflect::merge(&reflections)?.compile_cache()
439    }
440}
441
442/// Contains reflection and runtime nformation for a given compiled Shader Module.
443#[derive(Debug)]
444pub struct ShaderStorage<B: Backend> {
445    stage: ShaderStageFlags,
446    spirv: Vec<u32>,
447    module: Option<B::ShaderModule>,
448    entrypoint: String,
449    specialization: Option<rendy_core::hal::pso::Specialization<'static>>,
450}
451impl<B: Backend> ShaderStorage<B> {
452    /// Builds the `EntryPoint` structure used by rendy_core::hal to determine how to utilize this shader
453    pub fn get_entry_point<'a>(
454        &'a self,
455    ) -> Result<Option<rendy_core::hal::pso::EntryPoint<'a, B>>, ShaderError> {
456        Ok(Some(rendy_core::hal::pso::EntryPoint {
457            entry: &self.entrypoint,
458            module: self.module.as_ref().unwrap(),
459            specialization: self
460                .specialization
461                .clone()
462                .unwrap_or(rendy_core::hal::pso::Specialization::default()),
463        }))
464    }
465
466    /// Compile the SPIRV code with the backend and store the reference to the module inside this structure.
467    pub unsafe fn compile(
468        &mut self,
469        factory: &rendy_factory::Factory<B>,
470    ) -> Result<(), rendy_core::hal::device::ShaderError> {
471        self.module = Some(rendy_core::hal::device::Device::create_shader_module(
472            factory.device().raw(),
473            &self.spirv,
474        )?);
475
476        Ok(())
477    }
478
479    fn dispose(&mut self, factory: &rendy_factory::Factory<B>) {
480        use rendy_core::hal::device::Device;
481
482        if let Some(module) = self.module.take() {
483            unsafe { factory.destroy_shader_module(module) };
484        }
485        self.module = None;
486    }
487}
488
489impl<B: Backend> Drop for ShaderStorage<B> {
490    fn drop(&mut self) {
491        if self.module.is_some() {
492            panic!("This shader storage class needs to be manually dropped with dispose() first");
493        }
494    }
495}