1#![warn(
4 missing_debug_implementations,
5 missing_copy_implementations,
6 missing_docs,
7 trivial_casts,
8 trivial_numeric_casts,
9 unused_extern_crates,
10 unused_import_braces,
11 unused_qualifications
12)]
13
14#[cfg(feature = "shader-compiler")]
15mod shaderc;
16
17#[cfg(feature = "spirv-reflection")]
18#[allow(dead_code)]
19mod reflect;
20
21#[cfg(feature = "shader-compiler")]
22pub use self::shaderc::*;
23
24#[cfg(feature = "spirv-reflection")]
25pub use self::reflect::{ReflectError, ReflectTypeError, RetrievalKind, SpirvReflection};
26
27use rendy_core::hal::{pso::ShaderStageFlags, Backend};
28use std::collections::HashMap;
29
/// Error type used by [`SpirvShader`]'s implementation of [`Shader`].
///
/// The enum declares no variants, so no value of this type can ever exist;
/// any `Result<T, ShaderError>` is therefore always `Ok`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShaderError {}

impl std::fmt::Display for ShaderError {
    fn fmt(&self, _formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // There are no variants to format; the empty match proves to the
        // compiler that this function can never actually be entered.
        match *self {}
    }
}

impl std::error::Error for ShaderError {}
40
41pub trait Shader {
45 type Error: std::fmt::Debug;
47
48 fn spirv(&self) -> Result<std::borrow::Cow<'_, [u32]>, <Self as Shader>::Error>;
50
51 fn entry(&self) -> &str;
53
54 fn stage(&self) -> ShaderStageFlags;
56
57 unsafe fn module<B>(
62 &self,
63 factory: &rendy_factory::Factory<B>,
64 ) -> Result<B::ShaderModule, rendy_core::hal::device::ShaderError>
65 where
66 B: Backend,
67 {
68 rendy_core::hal::device::Device::create_shader_module(
69 factory.device().raw(),
70 &self.spirv().map_err(|e| {
71 rendy_core::hal::device::ShaderError::CompilationFailed(format!("{:?}", e))
72 })?,
73 )
74 }
75}
76
77#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
79#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
80pub struct SpirvShader {
81 #[cfg_attr(feature = "serde", serde(with = "serde_spirv"))]
82 spirv: Vec<u32>,
83 stage: ShaderStageFlags,
84 entry: String,
85}
86
/// Serde helpers that store SPIR-V words as a single byte string.
#[cfg(feature = "serde")]
mod serde_spirv {
    /// Serializes the SPIR-V words as bytes.
    ///
    /// The `&Vec<u32>` parameter type is kept because this function is used
    /// through `#[serde(with = "serde_spirv")]` on a `Vec<u32>` field.
    pub fn serialize<S>(data: &Vec<u32>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let bytes: &[u8] = rendy_core::cast_slice(data.as_slice());
        serializer.serialize_bytes(bytes)
    }

    /// Deserializes a byte string and decodes it into SPIR-V words with
    /// `read_spirv`, mapping a decoding failure to a custom serde error.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u32>, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // NOTE(review): this borrows `&[u8]` from the deserializer, which
        // presumably only works for formats that can hand out borrowed bytes —
        // confirm against the formats actually used.
        let raw: &[u8] = serde::Deserialize::deserialize(deserializer)?;
        let reader = std::io::Cursor::new(raw);

        match rendy_core::hal::pso::read_spirv(reader) {
            Ok(words) => Ok(words),
            Err(err) => Err(serde::de::Error::custom(err)),
        }
    }
}
106
107impl SpirvShader {
108 pub fn new(spirv: Vec<u32>, stage: ShaderStageFlags, entrypoint: &str) -> Self {
110 assert!(!spirv.is_empty());
111 Self {
112 spirv,
113 stage,
114 entry: entrypoint.to_string(),
115 }
116 }
117
118 pub fn from_bytes(
121 spirv: &[u8],
122 stage: ShaderStageFlags,
123 entrypoint: &str,
124 ) -> std::io::Result<Self> {
125 Ok(Self::new(
126 rendy_core::hal::pso::read_spirv(std::io::Cursor::new(spirv))?,
127 stage,
128 entrypoint,
129 ))
130 }
131}
132
133impl Shader for SpirvShader {
134 type Error = ShaderError;
135
136 fn spirv(&self) -> Result<std::borrow::Cow<'_, [u32]>, ShaderError> {
137 Ok(std::borrow::Cow::Borrowed(&self.spirv))
138 }
139
140 fn entry(&self) -> &str {
141 &self.entry
142 }
143
144 fn stage(&self) -> ShaderStageFlags {
145 self.stage
146 }
147}
148
149#[derive(Debug)]
151pub struct ShaderSet<B: Backend> {
152 shaders: HashMap<ShaderStageFlags, ShaderStorage<B>>,
153}
154
155impl<B> Default for ShaderSet<B>
156where
157 B: Backend,
158{
159 fn default() -> Self {
160 ShaderSet {
161 shaders: HashMap::default(),
162 }
163 }
164}
165
166impl<B: Backend> ShaderSet<B> {
167 pub fn load(
169 &mut self,
170 factory: &rendy_factory::Factory<B>,
171 ) -> Result<&mut Self, rendy_core::hal::device::ShaderError> {
172 for (_, v) in self.shaders.iter_mut() {
173 unsafe { v.compile(factory)? }
174 }
175
176 Ok(self)
177 }
178
179 pub fn raw<'a>(
181 &'a self,
182 ) -> Result<rendy_core::hal::pso::GraphicsShaderSet<'a, B>, ShaderError> {
183 Ok(rendy_core::hal::pso::GraphicsShaderSet {
184 vertex: self
185 .shaders
186 .get(&ShaderStageFlags::VERTEX)
187 .expect("ShaderSet doesn't contain vertex shader")
188 .get_entry_point()?
189 .unwrap(),
190 fragment: match self.shaders.get(&ShaderStageFlags::FRAGMENT) {
191 Some(fragment) => fragment.get_entry_point()?,
192 None => None,
193 },
194 domain: match self.shaders.get(&ShaderStageFlags::DOMAIN) {
195 Some(domain) => domain.get_entry_point()?,
196 None => None,
197 },
198 hull: match self.shaders.get(&ShaderStageFlags::HULL) {
199 Some(hull) => hull.get_entry_point()?,
200 None => None,
201 },
202 geometry: match self.shaders.get(&ShaderStageFlags::GEOMETRY) {
203 Some(geometry) => geometry.get_entry_point()?,
204 None => None,
205 },
206 })
207 }
208
209 pub fn dispose(&mut self, factory: &rendy_factory::Factory<B>) {
211 for (_, shader) in self.shaders.iter_mut() {
212 shader.dispose(factory);
213 }
214 }
215}
216
217#[derive(Debug, Default, Clone)]
219#[allow(missing_copy_implementations)]
220pub struct SpecConstantSet {
221 pub vertex: Option<rendy_core::hal::pso::Specialization<'static>>,
223 pub fragment: Option<rendy_core::hal::pso::Specialization<'static>>,
225 pub geometry: Option<rendy_core::hal::pso::Specialization<'static>>,
227 pub hull: Option<rendy_core::hal::pso::Specialization<'static>>,
229 pub domain: Option<rendy_core::hal::pso::Specialization<'static>>,
231 pub compute: Option<rendy_core::hal::pso::Specialization<'static>>,
233}
234
/// Backend-independent description of a shader set.
///
/// Every slot stores the SPIR-V words and the entry point name of one stage;
/// `None` means the stage is not provided. The slots are filled by the
/// `with_*` methods and turned into a `ShaderSet` by `build`.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ShaderSetBuilder {
    /// `(SPIR-V words, entry point)` of the vertex stage.
    vertex: Option<(Vec<u32>, String)>,
    /// `(SPIR-V words, entry point)` of the fragment stage.
    fragment: Option<(Vec<u32>, String)>,
    /// `(SPIR-V words, entry point)` of the geometry stage.
    geometry: Option<(Vec<u32>, String)>,
    /// `(SPIR-V words, entry point)` of the hull stage.
    hull: Option<(Vec<u32>, String)>,
    /// `(SPIR-V words, entry point)` of the domain stage.
    domain: Option<(Vec<u32>, String)>,
    /// `(SPIR-V words, entry point)` of the compute stage.
    compute: Option<(Vec<u32>, String)>,
}
247
248impl ShaderSetBuilder {
249 pub fn build<B: Backend>(
257 &self,
258 factory: &rendy_factory::Factory<B>,
259 spec_constants: SpecConstantSet,
260 ) -> Result<ShaderSet<B>, rendy_core::hal::device::ShaderError> {
261 let mut set = ShaderSet::<B>::default();
262
263 if self.vertex.is_none() && self.compute.is_none() {
264 let msg = "A vertex or compute shader must be provided".to_string();
265 return Err(rendy_core::hal::device::ShaderError::InterfaceMismatch(msg));
266 }
267 type ShaderTy = (
268 Vec<u32>,
269 String,
270 Option<rendy_core::hal::pso::Specialization<'static>>,
271 );
272
273 let create_storage =
274 move |stage,
275 shader: ShaderTy,
276 factory|
277 -> Result<ShaderStorage<B>, rendy_core::hal::device::ShaderError> {
278 let mut storage = ShaderStorage {
279 stage: stage,
280 spirv: shader.0,
281 module: None,
282 entrypoint: shader.1.clone(),
283 specialization: shader.2,
284 };
285 unsafe {
286 storage.compile(factory)?;
287 }
288 Ok(storage)
289 };
290
291 if let Some(shader) = self.vertex.clone() {
292 set.shaders.insert(
293 ShaderStageFlags::VERTEX,
294 create_storage(
295 ShaderStageFlags::VERTEX,
296 (shader.0, shader.1, spec_constants.vertex),
297 factory,
298 )?,
299 );
300 }
301
302 if let Some(shader) = self.fragment.clone() {
303 set.shaders.insert(
304 ShaderStageFlags::FRAGMENT,
305 create_storage(
306 ShaderStageFlags::FRAGMENT,
307 (shader.0, shader.1, spec_constants.fragment),
308 factory,
309 )?,
310 );
311 }
312
313 if let Some(shader) = self.compute.clone() {
314 set.shaders.insert(
315 ShaderStageFlags::COMPUTE,
316 create_storage(
317 ShaderStageFlags::COMPUTE,
318 (shader.0, shader.1, spec_constants.compute),
319 factory,
320 )?,
321 );
322 }
323
324 if let Some(shader) = self.domain.clone() {
325 set.shaders.insert(
326 ShaderStageFlags::DOMAIN,
327 create_storage(
328 ShaderStageFlags::DOMAIN,
329 (shader.0, shader.1, spec_constants.domain),
330 factory,
331 )?,
332 );
333 }
334
335 if let Some(shader) = self.hull.clone() {
336 set.shaders.insert(
337 ShaderStageFlags::HULL,
338 create_storage(
339 ShaderStageFlags::HULL,
340 (shader.0, shader.1, spec_constants.hull),
341 factory,
342 )?,
343 );
344 }
345
346 if let Some(shader) = self.geometry.clone() {
347 set.shaders.insert(
348 ShaderStageFlags::GEOMETRY,
349 create_storage(
350 ShaderStageFlags::GEOMETRY,
351 (shader.0, shader.1, spec_constants.geometry),
352 factory,
353 )?,
354 );
355 }
356
357 Ok(set)
358 }
359
360 #[inline(always)]
362 pub fn with_vertex<S: Shader>(mut self, shader: &S) -> Result<Self, S::Error> {
363 let data = shader.spirv()?;
364 self.vertex = Some((data.to_vec(), shader.entry().to_string()));
365 Ok(self)
366 }
367
368 #[inline(always)]
370 pub fn with_fragment<S: Shader>(mut self, shader: &S) -> Result<Self, S::Error> {
371 let data = shader.spirv()?;
372 self.fragment = Some((data.to_vec(), shader.entry().to_string()));
373 Ok(self)
374 }
375
376 #[inline(always)]
378 pub fn with_geometry<S: Shader>(mut self, shader: &S) -> Result<Self, S::Error> {
379 let data = shader.spirv()?;
380 self.geometry = Some((data.to_vec(), shader.entry().to_string()));
381 Ok(self)
382 }
383
384 #[inline(always)]
386 pub fn with_hull<S: Shader>(mut self, shader: &S) -> Result<Self, S::Error> {
387 let data = shader.spirv()?;
388 self.hull = Some((data.to_vec(), shader.entry().to_string()));
389 Ok(self)
390 }
391
392 #[inline(always)]
394 pub fn with_domain<S: Shader>(mut self, shader: &S) -> Result<Self, S::Error> {
395 let data = shader.spirv()?;
396 self.domain = Some((data.to_vec(), shader.entry().to_string()));
397 Ok(self)
398 }
399
400 #[inline(always)]
403 pub fn with_compute<S: Shader>(mut self, shader: &S) -> Result<Self, S::Error> {
404 let data = shader.spirv()?;
405 self.compute = Some((data.to_vec(), shader.entry().to_string()));
406 Ok(self)
407 }
408
409 #[cfg(feature = "spirv-reflection")]
410 pub fn reflect(&self) -> Result<SpirvReflection, ReflectError> {
413 if self.vertex.is_none() && self.compute.is_none() {
414 return Err(ReflectError::NoVertComputeProvided);
415 }
416
417 let mut reflections = Vec::new();
419 if let Some(vertex) = self.vertex.as_ref() {
420 reflections.push(SpirvReflection::reflect(&vertex.0, None)?);
421 }
422 if let Some(fragment) = self.fragment.as_ref() {
423 reflections.push(SpirvReflection::reflect(&fragment.0, None)?);
424 }
425 if let Some(hull) = self.hull.as_ref() {
426 reflections.push(SpirvReflection::reflect(&hull.0, None)?);
427 }
428 if let Some(domain) = self.domain.as_ref() {
429 reflections.push(SpirvReflection::reflect(&domain.0, None)?);
430 }
431 if let Some(compute) = self.compute.as_ref() {
432 reflections.push(SpirvReflection::reflect(&compute.0, None)?);
433 }
434 if let Some(geometry) = self.geometry.as_ref() {
435 reflections.push(SpirvReflection::reflect(&geometry.0, None)?);
436 }
437
438 reflect::merge(&reflections)?.compile_cache()
439 }
440}
441
442#[derive(Debug)]
444pub struct ShaderStorage<B: Backend> {
445 stage: ShaderStageFlags,
446 spirv: Vec<u32>,
447 module: Option<B::ShaderModule>,
448 entrypoint: String,
449 specialization: Option<rendy_core::hal::pso::Specialization<'static>>,
450}
451impl<B: Backend> ShaderStorage<B> {
452 pub fn get_entry_point<'a>(
454 &'a self,
455 ) -> Result<Option<rendy_core::hal::pso::EntryPoint<'a, B>>, ShaderError> {
456 Ok(Some(rendy_core::hal::pso::EntryPoint {
457 entry: &self.entrypoint,
458 module: self.module.as_ref().unwrap(),
459 specialization: self
460 .specialization
461 .clone()
462 .unwrap_or(rendy_core::hal::pso::Specialization::default()),
463 }))
464 }
465
466 pub unsafe fn compile(
468 &mut self,
469 factory: &rendy_factory::Factory<B>,
470 ) -> Result<(), rendy_core::hal::device::ShaderError> {
471 self.module = Some(rendy_core::hal::device::Device::create_shader_module(
472 factory.device().raw(),
473 &self.spirv,
474 )?);
475
476 Ok(())
477 }
478
479 fn dispose(&mut self, factory: &rendy_factory::Factory<B>) {
480 use rendy_core::hal::device::Device;
481
482 if let Some(module) = self.module.take() {
483 unsafe { factory.destroy_shader_module(module) };
484 }
485 self.module = None;
486 }
487}
488
489impl<B: Backend> Drop for ShaderStorage<B> {
490 fn drop(&mut self) {
491 if self.module.is_some() {
492 panic!("This shader storage class needs to be manually dropped with dispose() first");
493 }
494 }
495}