1use std::mem::size_of;
2
3use thiserror::Error;
4use wgt::AdapterInfo;
5
/// The length, in bytes, of the header which precedes pipeline cache data.
///
/// This is the in-memory size of `PipelineCacheHeader`; `add_cache_header`
/// requires the region it writes into to be exactly this long.
pub const HEADER_LENGTH: usize = size_of::<PipelineCacheHeader>();
7
8#[derive(Debug, PartialEq, Eq, Clone, Error)]
9#[non_exhaustive]
10pub enum PipelineCacheValidationError {
11 #[error("The pipeline cache data was truncated")]
12 Truncated,
13 #[error("The pipeline cache data was longer than recorded")]
14 Extended,
16 #[error("The pipeline cache data was corrupted (e.g. the hash didn't match)")]
17 Corrupted,
18 #[error("The pipeline cacha data was out of date and so cannot be safely used")]
19 Outdated,
20 #[error("The cache data was created for a different device")]
21 DeviceMismatch,
22 #[error("Pipeline cacha data was created for a future version of wgpu")]
23 Unsupported,
24}
25
26impl PipelineCacheValidationError {
27 pub fn was_avoidable(&self) -> bool {
30 match self {
31 PipelineCacheValidationError::DeviceMismatch => true,
32 PipelineCacheValidationError::Truncated
33 | PipelineCacheValidationError::Unsupported
34 | PipelineCacheValidationError::Extended
35 | PipelineCacheValidationError::Outdated
37 | PipelineCacheValidationError::Corrupted => false,
38 }
39 }
40}
41
42pub fn validate_pipeline_cache<'d>(
44 cache_data: &'d [u8],
45 adapter: &AdapterInfo,
46 validation_key: [u8; 16],
47) -> Result<&'d [u8], PipelineCacheValidationError> {
48 let adapter_key = adapter_key(adapter)?;
49 let Some((header, remaining_data)) = PipelineCacheHeader::read(cache_data) else {
50 return Err(PipelineCacheValidationError::Truncated);
51 };
52 if header.magic != MAGIC {
53 return Err(PipelineCacheValidationError::Corrupted);
54 }
55 if header.header_version != HEADER_VERSION {
56 return Err(PipelineCacheValidationError::Outdated);
57 }
58 if header.cache_abi != ABI {
59 return Err(PipelineCacheValidationError::Outdated);
60 }
61 if header.backend != adapter.backend as u8 {
62 return Err(PipelineCacheValidationError::DeviceMismatch);
63 }
64 if header.adapter_key != adapter_key {
65 return Err(PipelineCacheValidationError::DeviceMismatch);
66 }
67 if header.validation_key != validation_key {
68 return Err(PipelineCacheValidationError::Outdated);
72 }
73 let data_size: usize = header
74 .data_size
75 .try_into()
76 .map_err(|_| PipelineCacheValidationError::Corrupted)?;
79 if remaining_data.len() < data_size {
80 return Err(PipelineCacheValidationError::Truncated);
81 }
82 if remaining_data.len() > data_size {
83 return Err(PipelineCacheValidationError::Extended);
84 }
85 if header.hash_space != HASH_SPACE_VALUE {
86 return Err(PipelineCacheValidationError::Corrupted);
87 }
88 Ok(remaining_data)
89}
90
91pub fn add_cache_header(
92 in_region: &mut [u8],
93 data: &[u8],
94 adapter: &AdapterInfo,
95 validation_key: [u8; 16],
96) {
97 assert_eq!(in_region.len(), HEADER_LENGTH);
98 let header = PipelineCacheHeader {
99 adapter_key: adapter_key(adapter)
100 .expect("Called add_cache_header for an adapter which doesn't support cache data. This is a wgpu internal bug"),
101 backend: adapter.backend as u8,
102 cache_abi: ABI,
103 magic: MAGIC,
104 header_version: HEADER_VERSION,
105 validation_key,
106 hash_space: HASH_SPACE_VALUE,
107 data_size: data
108 .len()
109 .try_into()
110 .expect("Cache larger than u64::MAX bytes"),
111 };
112 header.write(in_region);
113}
114
/// The byte string expected at the very start of a header.
const MAGIC: [u8; 8] = *b"WGPUPLCH";
/// The header version this code writes and accepts.
const HEADER_VERSION: u32 = 1;
/// The size in bytes of a pointer on the current target, recorded in the header.
const ABI: u32 = size_of::<*const ()>() as u32;

/// The fixed value stored in, and expected from, the header's `hash_space` field.
const HASH_SPACE_VALUE: u64 = 0xFEDC_BA98_7654_3210;
129
/// The header stored in front of pipeline cache data.
///
/// The serialised form (see `read`/`write`) stores the fields in declaration
/// order, with integers in big-endian byte order. The size of this struct is
/// used as the serialised header length (`HEADER_LENGTH`).
#[repr(C)]
#[derive(PartialEq, Eq)]
struct PipelineCacheHeader {
    /// Identifies the data as a wgpu pipeline cache; validation requires `MAGIC`.
    magic: [u8; 8],
    /// Version of this header layout; validation requires `HEADER_VERSION`.
    header_version: u32,
    /// Pointer width marker; validation requires `ABI`.
    cache_abi: u32,
    /// The adapter's backend, stored as `adapter.backend as u8`.
    backend: u8,
    /// Identifies the adapter; produced by `adapter_key`.
    adapter_key: [u8; 15],
    /// Caller-supplied key which must match exactly for the data to be accepted.
    validation_key: [u8; 16],
    /// Number of payload bytes which follow the header.
    data_size: u64,
    /// Reserved slot; always written as, and validated against, `HASH_SPACE_VALUE`.
    hash_space: u64,
}
168
169impl PipelineCacheHeader {
170 fn read(data: &[u8]) -> Option<(PipelineCacheHeader, &[u8])> {
171 let mut reader = Reader {
172 data,
173 total_read: 0,
174 };
175 let magic = reader.read_array()?;
176 let header_version = reader.read_u32()?;
177 let cache_abi = reader.read_u32()?;
178 let backend = reader.read_byte()?;
179 let adapter_key = reader.read_array()?;
180 let validation_key = reader.read_array()?;
181 let data_size = reader.read_u64()?;
182 let data_hash = reader.read_u64()?;
183
184 assert_eq!(reader.total_read, size_of::<PipelineCacheHeader>());
185
186 Some((
187 PipelineCacheHeader {
188 magic,
189 header_version,
190 cache_abi,
191 backend,
192 adapter_key,
193 validation_key,
194 data_size,
195 hash_space: data_hash,
196 },
197 reader.data,
198 ))
199 }
200
201 fn write(&self, into: &mut [u8]) -> Option<()> {
202 let mut writer = Writer { data: into };
203 writer.write_array(&self.magic)?;
204 writer.write_u32(self.header_version)?;
205 writer.write_u32(self.cache_abi)?;
206 writer.write_byte(self.backend)?;
207 writer.write_array(&self.adapter_key)?;
208 writer.write_array(&self.validation_key)?;
209 writer.write_u64(self.data_size)?;
210 writer.write_u64(self.hash_space)?;
211
212 assert_eq!(writer.data.len(), 0);
213 Some(())
214 }
215}
216
217fn adapter_key(adapter: &AdapterInfo) -> Result<[u8; 15], PipelineCacheValidationError> {
218 match adapter.backend {
219 wgt::Backend::Vulkan => {
220 let v: [u8; 4] = adapter.vendor.to_be_bytes();
223 let d: [u8; 4] = adapter.device.to_be_bytes();
224 let adapter = [
225 255, 255, 255, v[0], v[1], v[2], v[3], d[0], d[1], d[2], d[3], 255, 255, 255, 255,
226 ];
227 Ok(adapter)
228 }
229 _ => Err(PipelineCacheValidationError::Unsupported),
230 }
231}
232
/// A cursor which consumes bytes from the front of a slice.
struct Reader<'a> {
    /// The bytes which have not been consumed yet.
    data: &'a [u8],
    /// How many bytes have been consumed so far.
    total_read: usize,
}

impl<'a> Reader<'a> {
    /// Consumes one byte, or returns `None` (consuming nothing) if no bytes remain.
    fn read_byte(&mut self) -> Option<u8> {
        let (&first, rest) = self.data.split_first()?;
        self.data = rest;
        self.total_read += 1;
        Some(first)
    }

    /// Consumes `N` bytes, or returns `None` (consuming nothing) if fewer remain.
    fn read_array<const N: usize>(&mut self) -> Option<[u8; N]> {
        let head = self.data.get(..N)?;
        let array = <[u8; N]>::try_from(head).expect("off-by-one-error in array size");
        self.data = &self.data[N..];
        self.total_read += N;
        Some(array)
    }

    /// Consumes four bytes and decodes them as a big-endian `u32`.
    fn read_u32(&mut self) -> Option<u32> {
        Some(u32::from_be_bytes(self.read_array()?))
    }

    /// Consumes eight bytes and decodes them as a big-endian `u64`.
    fn read_u64(&mut self) -> Option<u64> {
        Some(u64::from_be_bytes(self.read_array()?))
    }
}
266
/// A cursor which fills a mutable slice from the front.
struct Writer<'a> {
    /// The part of the destination which has not been written yet.
    data: &'a mut [u8],
}

impl<'a> Writer<'a> {
    /// Writes one byte, or returns `None` if the destination is full.
    fn write_byte(&mut self, byte: u8) -> Option<()> {
        // Taking the slice leaves an empty one behind; on failure the slice was
        // already empty, so nothing observable changes.
        let (slot, rest) = std::mem::take(&mut self.data).split_first_mut()?;
        *slot = byte;
        self.data = rest;
        Some(())
    }

    /// Writes all of `array`, or returns `None` (writing nothing) if it does not fit.
    fn write_array<const N: usize>(&mut self, array: &[u8; N]) -> Option<()> {
        if self.data.len() < N {
            return None;
        }
        // Move the slice out so that the two halves keep the full `'a` lifetime.
        let (head, tail) = std::mem::take(&mut self.data).split_at_mut(N);
        head.copy_from_slice(array);
        self.data = tail;
        Some(())
    }

    /// Writes `value` as four big-endian bytes.
    fn write_u32(&mut self, value: u32) -> Option<()> {
        let bytes = value.to_be_bytes();
        self.write_array(&bytes)
    }

    /// Writes `value` as eight big-endian bytes.
    fn write_u64(&mut self, value: u64) -> Option<()> {
        let bytes = value.to_be_bytes();
        self.write_array(&bytes)
    }
}
297
#[cfg(test)]
mod tests {
    use wgt::AdapterInfo;

    use crate::pipeline_cache::{PipelineCacheValidationError as E, HEADER_LENGTH};

    use super::ABI;

    // Compile-time check: this only type-checks while the header is exactly 64 bytes long.
    const _: [(); HEADER_LENGTH] = [(); 64];

    const ADAPTER: AdapterInfo = AdapterInfo {
        name: String::new(),
        vendor: 0x0002_FEED,
        device: 0xFEFE_FEFE,
        device_type: wgt::DeviceType::Other,
        driver: String::new(),
        driver_info: String::new(),
        backend: wgt::Backend::Vulkan,
    };

    const VALIDATION_KEY: [u8; 16] = u128::to_be_bytes(0xFFFFFFFF_FFFFFFFF_88888888_88888888);

    /// A serialised header, viewed as rows of eight bytes each.
    type Rows = [[u8; 8]; HEADER_LENGTH / 8];

    /// The header expected for `ADAPTER` and `VALIDATION_KEY` with an empty payload.
    ///
    /// Each test below changes a single row (or byte) of this to provoke one error.
    fn valid_rows() -> Rows {
        [
            // Magic
            *b"WGPUPLCH",
            // Header version, then ABI
            [0, 0, 0, 1, 0, 0, 0, ABI as u8],
            // Backend byte, then the start of the adapter key (padding + vendor id)
            [1, 255, 255, 255, 0, 2, 0xFE, 0xED],
            // Rest of the adapter key (device id + padding)
            [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255],
            // Validation key, high then low half
            0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(),
            0x88888888_88888888u64.to_be_bytes(),
            // Payload size
            0x0u64.to_be_bytes(),
            // Reserved hash value
            0xFEDC_BA98_7654_3210u64.to_be_bytes(),
        ]
    }

    /// Validates `rows` followed by `trailing_zeroes` zero bytes of payload.
    ///
    /// The accepted payload is copied out so that the result does not borrow the
    /// temporary buffer.
    fn validate(rows: Rows, trailing_zeroes: usize) -> Result<Vec<u8>, E> {
        let mut cache = rows.concat();
        cache.resize(HEADER_LENGTH + trailing_zeroes, 0u8);
        super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY).map(<[u8]>::to_vec)
    }

    #[test]
    fn written_header() {
        let mut result = [0; HEADER_LENGTH];
        super::add_cache_header(&mut result, &[], &ADAPTER, VALIDATION_KEY);
        assert_eq!(result.as_slice(), valid_rows().concat().as_slice());
    }

    #[test]
    fn valid_data() {
        assert_eq!(validate(valid_rows(), 0), Ok(Vec::new()));
    }

    #[test]
    fn invalid_magic() {
        let mut rows = valid_rows();
        rows[0] = *b"NOT_WGPU";
        assert_eq!(validate(rows, 0), Err(E::Corrupted));
    }

    #[test]
    fn wrong_version() {
        let mut rows = valid_rows();
        rows[1][3] = 2;
        assert_eq!(validate(rows, 0), Err(E::Outdated));
    }

    #[test]
    fn wrong_abi() {
        let mut rows = valid_rows();
        rows[1][7] = 14;
        assert_eq!(validate(rows, 0), Err(E::Outdated));
    }

    #[test]
    fn wrong_backend() {
        let mut rows = valid_rows();
        rows[2][0] = 2;
        assert_eq!(validate(rows, 0), Err(E::DeviceMismatch));
    }

    #[test]
    fn wrong_adapter() {
        let mut rows = valid_rows();
        rows[2][7] = 0x00;
        assert_eq!(validate(rows, 0), Err(E::DeviceMismatch));
    }

    #[test]
    fn wrong_validation() {
        let mut rows = valid_rows();
        rows[5] = 0x88888888_00000000u64.to_be_bytes();
        assert_eq!(validate(rows, 0), Err(E::Outdated));
    }

    #[test]
    fn too_little_data() {
        // The header announces 100 bytes, but none follow.
        let mut rows = valid_rows();
        rows[6] = 0x064u64.to_be_bytes();
        assert_eq!(validate(rows, 0), Err(E::Truncated));
    }

    #[test]
    fn not_no_data() {
        // The header announces 100 bytes, and exactly 100 follow.
        let mut rows = valid_rows();
        rows[6] = 100u64.to_be_bytes();
        assert_eq!(validate(rows, 100), Ok(vec![0u8; 100]));
    }

    #[test]
    fn too_much_data() {
        // The header announces 100 bytes, but 200 follow.
        let mut rows = valid_rows();
        rows[6] = 0x064u64.to_be_bytes();
        assert_eq!(validate(rows, 200), Err(E::Extended));
    }

    #[test]
    fn wrong_hash() {
        let mut rows = valid_rows();
        rows[7] = 0x00000000_00000000u64.to_be_bytes();
        assert_eq!(validate(rows, 0), Err(E::Corrupted));
    }
}