spirv_reflect/
lib.rs

1#[macro_use]
2extern crate bitflags;
3extern crate num_traits;
4extern crate spirv_headers;
5#[macro_use]
6extern crate serde_derive;
7
8use num_traits::cast::FromPrimitive;
9
10pub mod convert;
11pub mod ffi;
12pub mod types;
13
/// Copies a NUL-terminated C string into an owned Rust `String`.
///
/// A null pointer yields an empty string. Bytes that are not valid UTF-8 are replaced with
/// U+FFFD instead of producing an error.
pub(crate) fn ffi_to_string(ffi: *const ::std::os::raw::c_char) -> String {
    if ffi.is_null() {
        return String::new();
    }
    // SAFETY: assumes callers pass a pointer to a NUL-terminated string that stays valid for
    // the duration of this call (null was handled above).
    let borrowed = unsafe { std::ffi::CStr::from_ptr(ffi) };
    String::from(borrowed.to_string_lossy())
}
21
22impl Default for ffi::SpvReflectShaderModule {
23    fn default() -> Self {
24        unsafe { std::mem::zeroed() }
25    }
26}
27
28impl Default for ffi::SpvReflectDescriptorSet {
29    fn default() -> Self {
30        unsafe { std::mem::zeroed() }
31    }
32}
33
34#[derive(Default, Clone)]
35pub struct ShaderModule {
36    module: Option<ffi::SpvReflectShaderModule>,
37}
38
39impl ShaderModule {
40    pub fn load_u8_data(spv_data: &[u8]) -> Result<ShaderModule, &'static str> {
41        Ok(create_shader_module(spv_data)?)
42    }
43
44    pub fn load_u32_data(spv_data: &[u32]) -> Result<ShaderModule, &'static str> {
45        let u8_data: &[u8] = unsafe {
46            std::slice::from_raw_parts(
47                spv_data.as_ptr() as *const u8,
48                spv_data.len() * std::mem::size_of::<u32>(),
49            )
50        };
51        Ok(create_shader_module(u8_data)?)
52    }
53
54    pub fn get_code(&self) -> Vec<u32> {
55        match self.module {
56            Some(ref module) => {
57                let code_size = unsafe { ffi::spvReflectGetCodeSize(module) as usize };
58                let code_slice = unsafe {
59                    std::slice::from_raw_parts(ffi::spvReflectGetCode(module), code_size / 4)
60                };
61                code_slice.to_owned()
62            }
63            None => Vec::new(),
64        }
65    }
66
67    pub fn get_generator(&self) -> types::ReflectGenerator {
68        match self.module {
69            Some(ref module) => convert::ffi_to_generator(module.generator),
70            None => types::ReflectGenerator::Unknown,
71        }
72    }
73
74    pub fn get_shader_stage(&self) -> types::ReflectShaderStageFlags {
75        match self.module {
76            Some(ref module) => convert::ffi_to_shader_stage_flags(module.shader_stage),
77            None => types::ReflectShaderStageFlags::UNDEFINED,
78        }
79    }
80
81    pub fn get_source_language(&self) -> spirv_headers::SourceLanguage {
82        match self.module {
83            Some(ref module) => {
84                match spirv_headers::SourceLanguage::from_u32(module.source_language as u32) {
85                    Some(language) => language,
86                    None => spirv_headers::SourceLanguage::Unknown,
87                }
88            }
89            None => spirv_headers::SourceLanguage::Unknown,
90        }
91    }
92
93    pub fn get_source_language_version(&self) -> u32 {
94        match self.module {
95            Some(ref module) => module.source_language_version,
96            None => 0,
97        }
98    }
99
100    pub fn get_source_file(&self) -> String {
101        match self.module {
102            Some(ref module) => ffi_to_string(module.source_file),
103            None => String::new(),
104        }
105    }
106
107    pub fn get_source_text(&self) -> String {
108        match self.module {
109            Some(ref module) => ffi_to_string(module.source_source),
110            None => String::new(),
111        }
112    }
113
114    pub fn get_spirv_execution_model(&self) -> spirv_headers::ExecutionModel {
115        match self.module {
116            Some(ref module) => {
117                match spirv_headers::ExecutionModel::from_u32(module.spirv_execution_model as u32) {
118                    Some(model) => model,
119                    None => spirv_headers::ExecutionModel::Vertex,
120                }
121            }
122            None => spirv_headers::ExecutionModel::Vertex,
123        }
124    }
125
126    pub fn enumerate_input_variables(
127        &self,
128        entry_point: Option<&str>,
129    ) -> Result<Vec<types::ReflectInterfaceVariable>, &'static str> {
130        if let Some(ref module) = self.module {
131            let mut count: u32 = 0;
132            let result = unsafe {
133                match entry_point {
134                    Some(entry_point) => {
135                        let entry_point_cstr = std::ffi::CString::new(entry_point).unwrap();
136                        ffi::spvReflectEnumerateEntryPointInputVariables(
137                            module,
138                            entry_point_cstr.as_ptr(),
139                            &mut count,
140                            ::std::ptr::null_mut(),
141                        )
142                    }
143                    None => ffi::spvReflectEnumerateInputVariables(
144                        module,
145                        &mut count,
146                        ::std::ptr::null_mut(),
147                    ),
148                }
149            };
150            if result == ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS && count > 0 {
151                let mut ffi_vars: Vec<*mut ffi::SpvReflectInterfaceVariable> =
152                    vec![::std::ptr::null_mut(); count as usize];
153                let result = unsafe {
154                    let mut out_count: u32 = count;
155                    match entry_point {
156                        Some(entry_point) => {
157                            let entry_point_cstr = std::ffi::CString::new(entry_point).unwrap();
158                            ffi::spvReflectEnumerateEntryPointInputVariables(
159                                module,
160                                entry_point_cstr.as_ptr(),
161                                &mut out_count,
162                                ffi_vars.as_mut_ptr(),
163                            )
164                        }
165                        None => ffi::spvReflectEnumerateInputVariables(
166                            module,
167                            &mut out_count,
168                            ffi_vars.as_mut_ptr(),
169                        ),
170                    }
171                };
172                match result {
173                    ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS => {
174                        let vars: Vec<types::ReflectInterfaceVariable> = ffi_vars
175                            .iter()
176                            .map(|&var| convert::ffi_to_interface_variable(var))
177                            .collect();
178                        Ok(vars)
179                    }
180                    _ => Err(convert::result_to_string(result)),
181                }
182            } else {
183                Ok(Vec::new())
184            }
185        } else {
186            Ok(Vec::new())
187        }
188    }
189
190    pub fn enumerate_output_variables(
191        &self,
192        entry_point: Option<&str>,
193    ) -> Result<Vec<types::ReflectInterfaceVariable>, &'static str> {
194        if let Some(ref module) = self.module {
195            let mut count: u32 = 0;
196            let result = unsafe {
197                match entry_point {
198                    Some(entry_point) => {
199                        let entry_point_cstr = std::ffi::CString::new(entry_point).unwrap();
200                        ffi::spvReflectEnumerateEntryPointOutputVariables(
201                            module,
202                            entry_point_cstr.as_ptr(),
203                            &mut count,
204                            ::std::ptr::null_mut(),
205                        )
206                    }
207                    None => ffi::spvReflectEnumerateOutputVariables(
208                        module,
209                        &mut count,
210                        ::std::ptr::null_mut(),
211                    ),
212                }
213            };
214            if result == ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS && count > 0 {
215                let mut ffi_vars: Vec<*mut ffi::SpvReflectInterfaceVariable> =
216                    vec![::std::ptr::null_mut(); count as usize];
217                let result = unsafe {
218                    let mut out_count: u32 = count;
219                    match entry_point {
220                        Some(entry_point) => {
221                            let entry_point_cstr = std::ffi::CString::new(entry_point).unwrap();
222                            ffi::spvReflectEnumerateEntryPointOutputVariables(
223                                module,
224                                entry_point_cstr.as_ptr(),
225                                &mut out_count,
226                                ffi_vars.as_mut_ptr(),
227                            )
228                        }
229                        None => ffi::spvReflectEnumerateOutputVariables(
230                            module,
231                            &mut out_count,
232                            ffi_vars.as_mut_ptr(),
233                        ),
234                    }
235                };
236                match result {
237                    ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS => {
238                        let vars: Vec<types::ReflectInterfaceVariable> = ffi_vars
239                            .iter()
240                            .map(|&var| convert::ffi_to_interface_variable(var))
241                            .collect();
242                        Ok(vars)
243                    }
244                    _ => Err(convert::result_to_string(result)),
245                }
246            } else {
247                Ok(Vec::new())
248            }
249        } else {
250            Ok(Vec::new())
251        }
252    }
253
254    pub fn enumerate_descriptor_bindings(
255        &self,
256        entry_point: Option<&str>,
257    ) -> Result<Vec<types::ReflectDescriptorBinding>, &'static str> {
258        if let Some(ref module) = self.module {
259            let mut count: u32 = 0;
260            let result = unsafe {
261                match entry_point {
262                    Some(entry_point) => {
263                        let entry_point_cstr = std::ffi::CString::new(entry_point).unwrap();
264                        ffi::spvReflectEnumerateEntryPointDescriptorBindings(
265                            module,
266                            entry_point_cstr.as_ptr(),
267                            &mut count,
268                            ::std::ptr::null_mut(),
269                        )
270                    }
271                    None => ffi::spvReflectEnumerateDescriptorBindings(
272                        module,
273                        &mut count,
274                        ::std::ptr::null_mut(),
275                    ),
276                }
277            };
278            if result == ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS && count > 0 {
279                let mut ffi_bindings: Vec<*mut ffi::SpvReflectDescriptorBinding> =
280                    vec![::std::ptr::null_mut(); count as usize];
281                let result = unsafe {
282                    let mut out_count: u32 = count;
283                    match entry_point {
284                        Some(entry_point) => {
285                            let entry_point_cstr = std::ffi::CString::new(entry_point).unwrap();
286                            ffi::spvReflectEnumerateEntryPointDescriptorBindings(
287                                module,
288                                entry_point_cstr.as_ptr(),
289                                &mut out_count,
290                                ffi_bindings.as_mut_ptr(),
291                            )
292                        }
293                        None => ffi::spvReflectEnumerateDescriptorBindings(
294                            module,
295                            &mut out_count,
296                            ffi_bindings.as_mut_ptr(),
297                        ),
298                    }
299                };
300                match result {
301                    ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS => {
302                        let bindings: Vec<types::ReflectDescriptorBinding> = ffi_bindings
303                            .iter()
304                            .map(|&binding| convert::ffi_to_descriptor_binding(binding))
305                            .collect();
306                        Ok(bindings)
307                    }
308                    _ => Err(convert::result_to_string(result)),
309                }
310            } else {
311                Ok(Vec::new())
312            }
313        } else {
314            Ok(Vec::new())
315        }
316    }
317
318    pub fn enumerate_descriptor_sets(
319        &self,
320        entry_point: Option<&str>,
321    ) -> Result<Vec<types::ReflectDescriptorSet>, &'static str> {
322        if let Some(ref module) = self.module {
323            let mut count: u32 = 0;
324            let result = unsafe {
325                match entry_point {
326                    Some(entry_point) => {
327                        let entry_point_cstr = std::ffi::CString::new(entry_point).unwrap();
328                        ffi::spvReflectEnumerateEntryPointDescriptorSets(
329                            module,
330                            entry_point_cstr.as_ptr(),
331                            &mut count,
332                            ::std::ptr::null_mut(),
333                        )
334                    }
335                    None => ffi::spvReflectEnumerateDescriptorSets(
336                        module,
337                        &mut count,
338                        ::std::ptr::null_mut(),
339                    ),
340                }
341            };
342            if result == ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS && count > 0 {
343                let mut ffi_sets: Vec<*mut ffi::SpvReflectDescriptorSet> =
344                    vec![::std::ptr::null_mut(); count as usize];
345                let result = unsafe {
346                    let mut out_count: u32 = count;
347                    match entry_point {
348                        Some(entry_point) => {
349                            let entry_point_cstr = std::ffi::CString::new(entry_point).unwrap();
350                            ffi::spvReflectEnumerateEntryPointDescriptorSets(
351                                module,
352                                entry_point_cstr.as_ptr(),
353                                &mut out_count,
354                                ffi_sets.as_mut_ptr(),
355                            )
356                        }
357                        None => ffi::spvReflectEnumerateDescriptorSets(
358                            module,
359                            &mut out_count,
360                            ffi_sets.as_mut_ptr(),
361                        ),
362                    }
363                };
364                match result {
365                    ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS => Ok(ffi_sets
366                        .iter()
367                        .map(|&set| convert::ffi_to_descriptor_set(set))
368                        .collect()),
369                    _ => Err(convert::result_to_string(result)),
370                }
371            } else {
372                Ok(Vec::new())
373            }
374        } else {
375            Ok(Vec::new())
376        }
377    }
378
379    pub fn enumerate_push_constant_blocks(
380        &self,
381        entry_point: Option<&str>,
382    ) -> Result<Vec<types::ReflectBlockVariable>, &'static str> {
383        if let Some(ref module) = self.module {
384            let mut count: u32 = 0;
385            let result = unsafe {
386                match entry_point {
387                    Some(entry_point) => {
388                        let entry_point_cstr = std::ffi::CString::new(entry_point).unwrap();
389                        ffi::spvReflectEnumerateEntryPointPushConstantBlocks(
390                            module,
391                            entry_point_cstr.as_ptr(),
392                            &mut count,
393                            ::std::ptr::null_mut(),
394                        )
395                    }
396                    None => ffi::spvReflectEnumeratePushConstantBlocks(
397                        module,
398                        &mut count,
399                        ::std::ptr::null_mut(),
400                    ),
401                }
402            };
403            if result == ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS && count > 0 {
404                let mut ffi_blocks: Vec<*mut ffi::SpvReflectBlockVariable> =
405                    vec![::std::ptr::null_mut(); count as usize];
406                let result = unsafe {
407                    let mut out_count: u32 = count;
408                    match entry_point {
409                        Some(entry_point) => {
410                            let entry_point_cstr = std::ffi::CString::new(entry_point).unwrap();
411                            ffi::spvReflectEnumerateEntryPointPushConstantBlocks(
412                                module,
413                                entry_point_cstr.as_ptr(),
414                                &mut out_count,
415                                ffi_blocks.as_mut_ptr(),
416                            )
417                        }
418                        None => ffi::spvReflectEnumeratePushConstantBlocks(
419                            module,
420                            &mut out_count,
421                            ffi_blocks.as_mut_ptr(),
422                        ),
423                    }
424                };
425                match result {
426                    ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS => {
427                        let blocks: Vec<types::ReflectBlockVariable> = ffi_blocks
428                            .iter()
429                            .map(|&block| convert::ffi_to_block_variable(unsafe { &*block }))
430                            .collect();
431                        Ok(blocks)
432                    }
433                    _ => Err(convert::result_to_string(result)),
434                }
435            } else {
436                Ok(Vec::new())
437            }
438        } else {
439            Ok(Vec::new())
440        }
441    }
442
443    pub fn enumerate_entry_points(&self) -> Result<Vec<types::ReflectEntryPoint>, &'static str> {
444        if let Some(ref module) = self.module {
445            let ffi_entry_points = unsafe {
446                std::slice::from_raw_parts(module.entry_points, module.entry_point_count as usize)
447            };
448            let entry_points: Vec<types::ReflectEntryPoint> = ffi_entry_points
449                .iter()
450                .map(|entry_point| convert::ffi_to_entry_point(entry_point))
451                .collect();
452            Ok(entry_points)
453        } else {
454            Ok(Vec::new())
455        }
456    }
457
458    pub fn get_entry_point_name(&self) -> String {
459        match self.module {
460            Some(ref module) => ffi_to_string(module.entry_point_name),
461            None => String::new(),
462        }
463    }
464
465    pub fn change_descriptor_binding_numbers(
466        &mut self,
467        binding: &types::descriptor::ReflectDescriptorBinding,
468        new_binding: u32,
469        new_set: Option<u32>,
470    ) -> Result<(), &'static str> {
471        match self.module {
472            Some(ref mut module) => {
473                let new_set = new_set.unwrap_or(ffi::SPV_REFLECT_SET_NUMBER_DONT_CHANGE as u32);
474                let result = unsafe {
475                    ffi::spvReflectChangeDescriptorBindingNumbers(
476                        module as *mut ffi::SpvReflectShaderModule,
477                        binding.internal_data,
478                        new_binding,
479                        new_set,
480                    )
481                };
482                match result {
483                    ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS => Ok(()),
484                    _ => Err(convert::result_to_string(result)),
485                }
486            }
487            None => Ok(()),
488        }
489    }
490
491    pub fn change_descriptor_set_number(
492        &mut self,
493        set: &types::descriptor::ReflectDescriptorSet,
494        new_set: u32,
495    ) -> Result<(), &'static str> {
496        match self.module {
497            Some(ref mut module) => {
498                let result = unsafe {
499                    ffi::spvReflectChangeDescriptorSetNumber(
500                        module as *mut ffi::SpvReflectShaderModule,
501                        set.internal_data,
502                        new_set,
503                    )
504                };
505                match result {
506                    ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS => Ok(()),
507                    _ => Err(convert::result_to_string(result)),
508                }
509            }
510            None => Ok(()),
511        }
512    }
513
514    pub fn change_input_variable_location(
515        &mut self,
516        variable: &types::variable::ReflectInterfaceVariable,
517        new_location: u32,
518    ) -> Result<(), &'static str> {
519        match self.module {
520            Some(ref mut module) => {
521                let result = unsafe {
522                    ffi::spvReflectChangeInputVariableLocation(
523                        module as *mut ffi::SpvReflectShaderModule,
524                        variable.internal_data,
525                        new_location,
526                    )
527                };
528                match result {
529                    ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS => Ok(()),
530                    _ => Err(convert::result_to_string(result)),
531                }
532            }
533            None => Ok(()),
534        }
535    }
536
537    pub fn change_output_variable_location(
538        &mut self,
539        variable: &types::variable::ReflectInterfaceVariable,
540        new_location: u32,
541    ) -> Result<(), &'static str> {
542        match self.module {
543            Some(ref mut module) => {
544                let result = unsafe {
545                    ffi::spvReflectChangeOutputVariableLocation(
546                        module as *mut ffi::SpvReflectShaderModule,
547                        variable.internal_data,
548                        new_location,
549                    )
550                };
551                match result {
552                    ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS => Ok(()),
553                    _ => Err(convert::result_to_string(result)),
554                }
555            }
556            None => Ok(()),
557        }
558    }
559}
560
561impl Drop for ShaderModule {
562    fn drop(&mut self) {
563        if let Some(ref mut module) = self.module {
564            unsafe {
565                ffi::spvReflectDestroyShaderModule(module);
566            }
567        }
568    }
569}
570
571/*
572impl From<&[u8]> for ShaderModule {
573    fn from(spv_data: &[u8]) -> Result<ShaderModule, &str> {
574        create_shader_module(spv_data)?
575    }
576}
577*/
578
579/*impl<'a, T: AsRef<[u8]>> From<T> for ShaderModule {
580    fn from(v: T) -> Result<ShaderModule, &'static str> {
581        Ok(create_shader_module(v.as_ref())?)
582    }
583}*/
584
585pub fn create_shader_module(spv_data: &[u8]) -> Result<ShaderModule, &'static str> {
586    let mut module: ffi::SpvReflectShaderModule = unsafe { std::mem::zeroed() };
587    let result: ffi::SpvReflectResult = unsafe {
588        ffi::spvReflectCreateShaderModule(
589            spv_data.len(),
590            spv_data.as_ptr() as *const std::os::raw::c_void,
591            &mut module,
592        )
593    };
594    match result {
595        ffi::SpvReflectResult_SPV_REFLECT_RESULT_SUCCESS => Ok(ShaderModule {
596            module: Some(module),
597        }),
598        _ => Err(convert::result_to_string(result)),
599    }
600}