#![warn(
missing_debug_implementations,
missing_copy_implementations,
missing_docs,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
unused_qualifications
)]
#[cfg(feature = "shader-compiler")]
mod shaderc;
#[cfg(feature = "spirv-reflection")]
#[allow(dead_code)]
mod reflect;
#[cfg(feature = "shader-compiler")]
pub use self::shaderc::*;
#[cfg(feature = "spirv-reflection")]
pub use self::reflect::SpirvReflection;
use gfx_hal::{pso::ShaderStageFlags, Backend};
use std::collections::HashMap;
/// Interface for anything that can provide SPIR-V code for a single shader stage.
pub trait Shader {
    /// Returns the SPIR-V words of this shader, borrowed or owned.
    ///
    /// # Errors
    ///
    /// Implementations report failure to produce the code as a `failure::Error`.
    fn spirv(&self) -> Result<std::borrow::Cow<'_, [u32]>, failure::Error>;

    /// Returns the name of the shader's entry point.
    fn entry(&self) -> &str;

    /// Returns the stage flags of this shader.
    fn stage(&self) -> ShaderStageFlags;

    /// Creates a backend shader module from the words returned by `spirv()`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `spirv()`, and converts any error returned by
    /// `create_shader_module` into a `failure::Error`.
    ///
    /// # Safety
    ///
    /// The words are forwarded to `gfx_hal::Device::create_shader_module` unchanged.
    /// NOTE(review): presumably the caller must guarantee they form valid SPIR-V —
    /// confirm against the gfx-hal documentation.
    unsafe fn module<B>(
        &self,
        factory: &rendy_factory::Factory<B>,
    ) -> Result<B::ShaderModule, failure::Error>
    where
        B: Backend,
    {
        let code = self.spirv()?;
        let module = gfx_hal::Device::create_shader_module(factory.device().raw(), &code)?;
        Ok(module)
    }
}
/// A shader stored as SPIR-V words together with its stage flags and entry point name.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SpirvShader {
// SPIR-V words; with the `serde` feature they are (de)serialized through `serde_spirv`.
#[cfg_attr(feature = "serde", serde(with = "serde_spirv"))]
spirv: Vec<u32>,
// Stage flags reported by `Shader::stage`.
stage: ShaderStageFlags,
// Entry point name reported by `Shader::entry`.
entry: String,
}
/// Serde adapter for SPIR-V word vectors, referenced via `#[serde(with = "serde_spirv")]`.
#[cfg(feature = "serde")]
mod serde_spirv {
    /// Serializes the words as a single byte string.
    pub fn serialize<S>(data: &Vec<u32>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let bytes = rendy_util::cast_slice(data.as_slice());
        serializer.serialize_bytes(bytes)
    }

    /// Deserializes a byte string and decodes it into words with `gfx_hal::pso::read_spirv`.
    ///
    /// Decoding failures are reported through `serde::de::Error::custom`.
    /// NOTE(review): the bytes are requested as a borrowed `&[u8]`; formats that cannot
    /// hand out borrowed bytes may reject this — confirm for the formats in use.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u32>, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let raw: &[u8] = serde::Deserialize::deserialize(deserializer)?;
        let decoded = gfx_hal::pso::read_spirv(std::io::Cursor::new(raw));
        decoded.map_err(serde::de::Error::custom)
    }
}
impl SpirvShader {
    /// Wraps already decoded SPIR-V words with their stage and entry point name.
    ///
    /// # Panics
    ///
    /// Panics if `spirv` is empty.
    pub fn new(spirv: Vec<u32>, stage: ShaderStageFlags, entrypoint: &str) -> Self {
        assert!(!spirv.is_empty());
        let entry = String::from(entrypoint);
        Self {
            spirv,
            stage,
            entry,
        }
    }

    /// Decodes `spirv` bytes with `gfx_hal::pso::read_spirv` and wraps the result
    /// using [`SpirvShader::new`].
    ///
    /// # Errors
    ///
    /// Returns the I/O error produced by `read_spirv`.
    ///
    /// # Panics
    ///
    /// Panics, like `new`, if the decoded word list is empty.
    pub fn from_bytes(
        spirv: &[u8],
        stage: ShaderStageFlags,
        entrypoint: &str,
    ) -> std::io::Result<Self> {
        let reader = std::io::Cursor::new(spirv);
        let words = gfx_hal::pso::read_spirv(reader)?;
        Ok(Self::new(words, stage, entrypoint))
    }
}
impl Shader for SpirvShader {
fn spirv(&self) -> Result<std::borrow::Cow<'_, [u32]>, failure::Error> {
Ok(std::borrow::Cow::Borrowed(&self.spirv))
}
fn entry(&self) -> &str {
&self.entry
}
fn stage(&self) -> ShaderStageFlags {
self.stage
}
}
/// Collection of per-stage shader storages for backend `B`.
///
/// `Default` is derived without extra bounds on `B` and yields an empty set.
#[derive(derivative::Derivative, Debug)]
#[derivative(Default(bound = ""))]
pub struct ShaderSet<B: Backend> {
// Storages keyed by the stage flag they were inserted under.
shaders: HashMap<ShaderStageFlags, ShaderStorage<B>>,
}
impl<B: Backend> ShaderSet<B> {
pub fn load(
&mut self,
factory: &rendy_factory::Factory<B>,
) -> Result<&mut Self, failure::Error> {
for (_, v) in self.shaders.iter_mut() {
unsafe { v.compile(factory)? }
}
Ok(self)
}
pub fn raw<'a>(&'a self) -> Result<(gfx_hal::pso::GraphicsShaderSet<'a, B>), failure::Error> {
Ok(gfx_hal::pso::GraphicsShaderSet {
vertex: self
.shaders
.get(&ShaderStageFlags::VERTEX)
.expect("ShaderSet doesn't contain vertex shader")
.get_entry_point()?
.unwrap(),
fragment: match self.shaders.get(&ShaderStageFlags::FRAGMENT) {
Some(fragment) => fragment.get_entry_point()?,
None => None,
},
domain: match self.shaders.get(&ShaderStageFlags::DOMAIN) {
Some(domain) => domain.get_entry_point()?,
None => None,
},
hull: match self.shaders.get(&ShaderStageFlags::HULL) {
Some(hull) => hull.get_entry_point()?,
None => None,
},
geometry: match self.shaders.get(&ShaderStageFlags::GEOMETRY) {
Some(geometry) => geometry.get_entry_point()?,
None => None,
},
})
}
pub fn dispose(&mut self, factory: &rendy_factory::Factory<B>) {
for (_, shader) in self.shaders.iter_mut() {
shader.dispose(factory);
}
}
}
/// Optional specialization constants for each shader stage.
///
/// `ShaderSetBuilder::build` hands each entry to the storage of the matching stage.
#[derive(Debug, Default, Clone)]
// NOTE(review): the reason for silencing this lint is not visible here — confirm it is
// still required.
#[allow(missing_copy_implementations)]
pub struct SpecConstantSet {
/// Specialization for the vertex stage, if any.
pub vertex: Option<gfx_hal::pso::Specialization<'static>>,
/// Specialization for the fragment stage, if any.
pub fragment: Option<gfx_hal::pso::Specialization<'static>>,
/// Specialization for the geometry stage, if any.
pub geometry: Option<gfx_hal::pso::Specialization<'static>>,
/// Specialization for the hull stage, if any.
pub hull: Option<gfx_hal::pso::Specialization<'static>>,
/// Specialization for the domain stage, if any.
pub domain: Option<gfx_hal::pso::Specialization<'static>>,
/// Specialization for the compute stage, if any.
pub compute: Option<gfx_hal::pso::Specialization<'static>>,
}
/// Builder that collects SPIR-V code per stage and turns it into a `ShaderSet`.
///
/// Each field holds the SPIR-V words and the entry point name of one stage, or `None`
/// when no shader was supplied for it.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ShaderSetBuilder {
// Set by `with_vertex`.
vertex: Option<(Vec<u32>, String)>,
// Set by `with_fragment`.
fragment: Option<(Vec<u32>, String)>,
// Set by `with_geometry`.
geometry: Option<(Vec<u32>, String)>,
// Set by `with_hull`.
hull: Option<(Vec<u32>, String)>,
// Set by `with_domain`.
domain: Option<(Vec<u32>, String)>,
// Set by `with_compute`.
compute: Option<(Vec<u32>, String)>,
}
impl ShaderSetBuilder {
pub fn build<B: Backend>(
&self,
factory: &rendy_factory::Factory<B>,
spec_constants: SpecConstantSet,
) -> Result<ShaderSet<B>, failure::Error> {
let mut set = ShaderSet::<B>::default();
if self.vertex.is_none() && self.compute.is_none() {
failure::bail!("A vertex or compute shader must be provided");
}
type ShaderTy = (
Vec<u32>,
String,
Option<gfx_hal::pso::Specialization<'static>>,
);
let create_storage =
move |stage, shader: ShaderTy, factory| -> Result<ShaderStorage<B>, failure::Error> {
let mut storage = ShaderStorage {
stage: stage,
spirv: shader.0,
module: None,
entrypoint: shader.1.clone(),
specialization: shader.2,
};
unsafe {
storage.compile(factory)?;
}
Ok(storage)
};
if let Some(shader) = self.vertex.clone() {
set.shaders.insert(
ShaderStageFlags::VERTEX,
create_storage(
ShaderStageFlags::VERTEX,
(shader.0, shader.1, spec_constants.vertex),
factory,
)?,
);
}
if let Some(shader) = self.fragment.clone() {
set.shaders.insert(
ShaderStageFlags::FRAGMENT,
create_storage(
ShaderStageFlags::FRAGMENT,
(shader.0, shader.1, spec_constants.fragment),
factory,
)?,
);
}
if let Some(shader) = self.compute.clone() {
set.shaders.insert(
ShaderStageFlags::COMPUTE,
create_storage(
ShaderStageFlags::COMPUTE,
(shader.0, shader.1, spec_constants.compute),
factory,
)?,
);
}
if let Some(shader) = self.domain.clone() {
set.shaders.insert(
ShaderStageFlags::DOMAIN,
create_storage(
ShaderStageFlags::DOMAIN,
(shader.0, shader.1, spec_constants.domain),
factory,
)?,
);
}
if let Some(shader) = self.hull.clone() {
set.shaders.insert(
ShaderStageFlags::HULL,
create_storage(
ShaderStageFlags::HULL,
(shader.0, shader.1, spec_constants.hull),
factory,
)?,
);
}
if let Some(shader) = self.geometry.clone() {
set.shaders.insert(
ShaderStageFlags::GEOMETRY,
create_storage(
ShaderStageFlags::GEOMETRY,
(shader.0, shader.1, spec_constants.geometry),
factory,
)?,
);
}
Ok(set)
}
#[inline(always)]
pub fn with_vertex<S: Shader>(mut self, shader: &S) -> Result<Self, failure::Error> {
let data = shader.spirv()?;
self.vertex = Some((data.to_vec(), shader.entry().to_string()));
Ok(self)
}
#[inline(always)]
pub fn with_fragment<S: Shader>(mut self, shader: &S) -> Result<Self, failure::Error> {
let data = shader.spirv()?;
self.fragment = Some((data.to_vec(), shader.entry().to_string()));
Ok(self)
}
#[inline(always)]
pub fn with_geometry<S: Shader>(mut self, shader: &S) -> Result<Self, failure::Error> {
let data = shader.spirv()?;
self.geometry = Some((data.to_vec(), shader.entry().to_string()));
Ok(self)
}
#[inline(always)]
pub fn with_hull<S: Shader>(mut self, shader: &S) -> Result<Self, failure::Error> {
let data = shader.spirv()?;
self.hull = Some((data.to_vec(), shader.entry().to_string()));
Ok(self)
}
#[inline(always)]
pub fn with_domain<S: Shader>(mut self, shader: &S) -> Result<Self, failure::Error> {
let data = shader.spirv()?;
self.domain = Some((data.to_vec(), shader.entry().to_string()));
Ok(self)
}
#[inline(always)]
pub fn with_compute<S: Shader>(mut self, shader: &S) -> Result<Self, failure::Error> {
let data = shader.spirv()?;
self.compute = Some((data.to_vec(), shader.entry().to_string()));
Ok(self)
}
#[cfg(feature = "spirv-reflection")]
pub fn reflect(&self) -> Result<SpirvReflection, failure::Error> {
if self.vertex.is_none() && self.compute.is_none() {
failure::bail!("A vertex or compute shader must be provided");
}
let mut reflections = Vec::new();
if let Some(vertex) = self.vertex.as_ref() {
reflections.push(SpirvReflection::reflect(&vertex.0, None)?);
}
if let Some(fragment) = self.fragment.as_ref() {
reflections.push(SpirvReflection::reflect(&fragment.0, None)?);
}
if let Some(hull) = self.hull.as_ref() {
reflections.push(SpirvReflection::reflect(&hull.0, None)?);
}
if let Some(domain) = self.domain.as_ref() {
reflections.push(SpirvReflection::reflect(&domain.0, None)?);
}
if let Some(compute) = self.compute.as_ref() {
reflections.push(SpirvReflection::reflect(&compute.0, None)?);
}
if let Some(geometry) = self.geometry.as_ref() {
reflections.push(SpirvReflection::reflect(&geometry.0, None)?);
}
reflect::merge(&reflections)?.compile_cache()
}
}
/// SPIR-V code of one shader stage together with the backend module compiled from it.
///
/// The module must be released with `dispose` before the storage is dropped; the `Drop`
/// implementation panics otherwise.
#[derive(Debug)]
pub struct ShaderStorage<B: Backend> {
// Stage this storage was created for.
stage: ShaderStageFlags,
// SPIR-V words passed to `create_shader_module` by `compile`.
spirv: Vec<u32>,
// Compiled module; set by `compile` and cleared by `dispose`.
module: Option<B::ShaderModule>,
// Entry point name exposed through `get_entry_point`.
entrypoint: String,
// Specialization constants; `get_entry_point` uses the default value when this is `None`.
specialization: Option<gfx_hal::pso::Specialization<'static>>,
}
impl<B: Backend> ShaderStorage<B> {
    /// Builds the entry point description that borrows the compiled module.
    ///
    /// The specialization is a clone of the stored one, or the default value when none
    /// was supplied. On success the result is always `Some`.
    ///
    /// # Errors
    ///
    /// Returns an error if no module is currently held, i.e. `compile` has not succeeded
    /// yet or the storage was already disposed.
    pub fn get_entry_point<'a>(
        &'a self,
    ) -> Result<Option<gfx_hal::pso::EntryPoint<'a, B>>, failure::Error> {
        let module = self.module.as_ref().ok_or_else(|| {
            failure::err_msg("Shader module is not compiled; call compile() before use")
        })?;
        Ok(Some(gfx_hal::pso::EntryPoint {
            entry: &self.entrypoint,
            module,
            specialization: self.specialization.clone().unwrap_or_default(),
        }))
    }

    /// Creates a backend module from the stored SPIR-V words.
    ///
    /// When a module from an earlier call is still held, it is destroyed after the new
    /// one has been created, so recompiling does not leak. On failure the previous
    /// module, if any, is kept.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `create_shader_module`.
    ///
    /// # Safety
    ///
    /// The words are forwarded to `gfx_hal::Device::create_shader_module` unchanged.
    /// NOTE(review): presumably they must be valid SPIR-V, and a previously compiled
    /// module must no longer be needed by the device — confirm against gfx-hal docs.
    pub unsafe fn compile(
        &mut self,
        factory: &rendy_factory::Factory<B>,
    ) -> Result<(), failure::Error> {
        let module =
            gfx_hal::Device::create_shader_module(factory.device().raw(), &self.spirv)?;
        if let Some(stale) = self.module.take() {
            gfx_hal::Device::destroy_shader_module(factory.device().raw(), stale);
        }
        self.module = Some(module);
        Ok(())
    }

    /// Destroys the compiled module, if any, leaving the storage safe to drop.
    fn dispose(&mut self, factory: &rendy_factory::Factory<B>) {
        use gfx_hal::device::Device;
        if let Some(module) = self.module.take() {
            // SAFETY: NOTE(review) — the module was created by `compile`; presumably it
            // must come from this factory's device and no longer be in use. Confirm.
            unsafe { factory.destroy_shader_module(module) };
        }
    }
}
impl<B: Backend> Drop for ShaderStorage<B> {
    /// Panics if the storage still owns a compiled module, because the module can only
    /// be released through `dispose`, which needs the factory.
    fn drop(&mut self) {
        // Panicking while the thread is already unwinding would abort the process and
        // hide the original panic, so the check is skipped in that case.
        if self.module.is_some() && !std::thread::panicking() {
            panic!("This shader storage class needs to be manually dropped with dispose() first");
        }
    }
}