use crate::{
    device::queue::TempResource,
    global::Global,
    hub::Hub,
    id::CommandEncoderId,
    init_tracker::MemoryInitKind,
    ray_tracing::{
        BlasAction, BlasBuildEntry, BlasGeometries, BlasTriangleGeometry,
        BuildAccelerationStructureError, TlasAction, TlasBuildEntry, TlasInstance, TlasPackage,
        TraceBlasBuildEntry, TraceBlasGeometries, TraceBlasTriangleGeometry, TraceTlasInstance,
        TraceTlasPackage, ValidateBlasActionsError, ValidateTlasActionsError,
    },
    resource::{AccelerationStructure, Blas, Buffer, Labeled, StagingBuffer, Tlas, Trackable},
    scratch::ScratchBuffer,
    snatch::SnatchGuard,
    track::PendingTransition,
    FastHashSet,
};

use wgt::{math::align_to, BufferUsages, Features};

use super::CommandBufferMutable;
use hal::BufferUses;
use std::{
    cmp::max,
    num::NonZeroU64,
    ops::{Deref, Range},
    sync::{atomic::Ordering, Arc},
};

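/// Buffers (and their pending state transitions) gathered for one triangle geometry
/// of a BLAS build. `ending_blas` is set on the last geometry of a BLAS so that
/// geometries can later be grouped back into per-BLAS entries.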
struct TriangleBufferStore<'a> {
    vertex_buffer: Arc<Buffer>,
    vertex_transition: Option<PendingTransition<BufferUses>>,
    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    geometry: BlasTriangleGeometry<'a>,
    ending_blas: Option<Arc<Blas>>,
}

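/// A BLAS build resolved to HAL geometry entries, plus its offset into the shared scratch buffer.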
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

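/// A TLAS build resolved to HAL instance entries, plus its offset into the shared scratch buffer.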
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

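/// A TLAS build together with the byte range its instance data occupies in the staging source.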
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    range: Range<usize>,
}

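/// An unvalidated TLAS build entry together with its instance buffer and that buffer's pending transition.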
struct TlasBufferStore {
    buffer: Arc<Buffer>,
    transition: Option<PendingTransition<BufferUses>>,
    entry: TlasBuildEntry,
}

impl Global {
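    /// Builds the given BLASes and TLASes on this encoder. TLAS instance data is read
    /// directly from the caller-provided instance buffers (`TlasBuildEntry`) and is not
    /// validated against the referenced BLASes.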
    pub fn command_encoder_build_acceleration_structures_unsafe_tlas<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasBuildEntry>,
    ) -> Result<(), BuildAccelerationStructureError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures_unsafe_tlas");

        let hub = &self.hub;

        let cmd_buf = hub
            .command_buffers
            .get(command_encoder_id.into_command_buffer_id());

        let device = &cmd_buf.device;

        device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

        let build_command_index = NonZeroU64::new(
            device
                .last_acceleration_structure_build_command_index
                .fetch_add(1, Ordering::Relaxed),
        )
        .unwrap();

        #[cfg(feature = "trace")]
        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
            .map(|blas_entry| {
                let geometries = match blas_entry.geometries {
                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
                        TraceBlasGeometries::TriangleGeometries(
                            triangle_geometries
                                .map(|tg| TraceBlasTriangleGeometry {
                                    size: tg.size.clone(),
                                    vertex_buffer: tg.vertex_buffer,
                                    index_buffer: tg.index_buffer,
                                    transform_buffer: tg.transform_buffer,
                                    first_vertex: tg.first_vertex,
                                    vertex_stride: tg.vertex_stride,
                                    first_index: tg.first_index,
                                    transform_buffer_offset: tg.transform_buffer_offset,
                                })
                                .collect(),
                        )
                    }
                };
                TraceBlasBuildEntry {
                    blas_id: blas_entry.blas_id,
                    geometries,
                }
            })
            .collect();

        #[cfg(feature = "trace")]
        let trace_tlas: Vec<TlasBuildEntry> = tlas_iter.collect();
        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf.data.lock().get_inner()?.commands {
            list.push(
                crate::device::trace::Command::BuildAccelerationStructuresUnsafeTlas {
                    blas: trace_blas.clone(),
                    tlas: trace_tlas.clone(),
                },
            );
            if !trace_tlas.is_empty() {
                log::warn!("a trace of command_encoder_build_acceleration_structures_unsafe_tlas containing a tlas build is not replayable!");
            }
        }

        #[cfg(feature = "trace")]
        let blas_iter = trace_blas.iter().map(|blas_entry| {
            let geometries = match &blas_entry.geometries {
                TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
                    let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
                        size: &tg.size,
                        vertex_buffer: tg.vertex_buffer,
                        index_buffer: tg.index_buffer,
                        transform_buffer: tg.transform_buffer,
                        first_vertex: tg.first_vertex,
                        vertex_stride: tg.vertex_stride,
                        first_index: tg.first_index,
                        transform_buffer_offset: tg.transform_buffer_offset,
                    });
                    BlasGeometries::TriangleGeometries(Box::new(iter))
                }
            };
            BlasBuildEntry {
                blas_id: blas_entry.blas_id,
                geometries,
            }
        });

        #[cfg(feature = "trace")]
        let tlas_iter = trace_tlas.iter();

        let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
        let mut buf_storage = Vec::new();

        let mut scratch_buffer_blas_size = 0;
        let mut blas_storage = Vec::new();
        let mut cmd_buf_data = cmd_buf.data.lock();
        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
        let cmd_buf_data = &mut *cmd_buf_data_guard;

        iter_blas(
            blas_iter,
            cmd_buf_data,
            build_command_index,
            &mut buf_storage,
            hub,
        )?;

        let snatch_guard = device.snatchable_lock.read();
        iter_buffers(
            &mut buf_storage,
            &snatch_guard,
            &mut input_barriers,
            cmd_buf_data,
            &mut scratch_buffer_blas_size,
            &mut blas_storage,
            hub,
            device.alignments.ray_tracing_scratch_buffer_alignment,
        )?;

        let mut scratch_buffer_tlas_size = 0;
        let mut tlas_storage = Vec::<UnsafeTlasStore>::new();
        let mut tlas_buf_storage = Vec::new();

        for entry in tlas_iter {
            let instance_buffer = hub.buffers.get(entry.instance_buffer_id).get()?;
            let data = cmd_buf_data.trackers.buffers.set_single(
                &instance_buffer,
                BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
            );
            tlas_buf_storage.push(TlasBufferStore {
                buffer: instance_buffer,
                transition: data,
                entry: entry.clone(),
            });
        }

        for tlas_buf in &mut tlas_buf_storage {
            let entry = &tlas_buf.entry;
            let instance_buffer = {
                let (instance_buffer, instance_pending) =
                    (&mut tlas_buf.buffer, &mut tlas_buf.transition);
                let instance_raw = instance_buffer.try_raw(&snatch_guard)?;
                instance_buffer.check_usage(BufferUsages::TLAS_INPUT)?;

                if let Some(barrier) = instance_pending
                    .take()
                    .map(|pending| pending.into_hal(instance_buffer, &snatch_guard))
                {
                    input_barriers.push(barrier);
                }
                instance_raw
            };

            let tlas = hub.tlas_s.get(entry.tlas_id).get()?;
            cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());

            cmd_buf_data.tlas_actions.push(TlasAction {
                tlas: tlas.clone(),
                kind: crate::ray_tracing::TlasActionKind::Build {
                    build_index: build_command_index,
                    dependencies: Vec::new(),
                },
            });

            let scratch_buffer_offset = scratch_buffer_tlas_size;
            scratch_buffer_tlas_size += align_to(
                tlas.size_info.build_scratch_size as u32,
                device.alignments.ray_tracing_scratch_buffer_alignment,
            ) as u64;

            tlas_storage.push(UnsafeTlasStore {
                tlas,
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(instance_buffer),
                        offset: 0,
                        count: entry.instance_count,
                    },
                ),
                scratch_buffer_offset,
            });
        }

        let scratch_size =
            match wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size)) {
                None => {
                    cmd_buf_data_guard.mark_successful();
                    return Ok(());
                }
                Some(size) => size,
            };

        let scratch_buffer =
            ScratchBuffer::new(device, scratch_size).map_err(crate::device::DeviceError::from)?;

        let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
            buffer: scratch_buffer.raw(),
            usage: hal::StateTransition {
                from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
                to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            },
        };

        let mut tlas_descriptors = Vec::new();

        for UnsafeTlasStore {
            tlas,
            entries,
            scratch_buffer_offset,
        } in &tlas_storage
        {
            if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
                log::info!("only rebuild implemented")
            }
            tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
                entries,
                mode: hal::AccelerationStructureBuildMode::Build,
                flags: tlas.flags,
                source_acceleration_structure: None,
                destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
                scratch_buffer: scratch_buffer.raw(),
                scratch_buffer_offset: *scratch_buffer_offset,
            })
        }

        let blas_present = !blas_storage.is_empty();
        let tlas_present = !tlas_storage.is_empty();

        let cmd_buf_raw = cmd_buf_data.encoder.open()?;

        let mut descriptors = Vec::new();

        for storage in &blas_storage {
            descriptors.push(map_blas(storage, scratch_buffer.raw(), &snatch_guard)?);
        }

        build_blas(
            cmd_buf_raw,
            blas_present,
            tlas_present,
            input_barriers,
            &descriptors,
            scratch_buffer_barrier,
        );

        if tlas_present {
            unsafe {
                cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);

                cmd_buf_raw.place_acceleration_structure_barrier(
                    hal::AccelerationStructureBarrier {
                        usage: hal::StateTransition {
                            from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                            to: hal::AccelerationStructureUses::SHADER_INPUT,
                        },
                    },
                );
            }
        }

        cmd_buf_data
            .temp_resources
            .push(TempResource::ScratchBuffer(scratch_buffer));

        cmd_buf_data_guard.mark_successful();
        Ok(())
    }

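    /// Builds the given BLASes and TLASes on this encoder. TLAS instances are validated,
    /// encoded with `tlas_instance_to_bytes`, written to a staging buffer, and copied into
    /// each TLAS's internal instance buffer before the TLAS builds are recorded.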
    pub fn command_encoder_build_acceleration_structures<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
    ) -> Result<(), BuildAccelerationStructureError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures");

        let hub = &self.hub;

        let cmd_buf = hub
            .command_buffers
            .get(command_encoder_id.into_command_buffer_id());

        let device = &cmd_buf.device;

        device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

        let build_command_index = NonZeroU64::new(
            device
                .last_acceleration_structure_build_command_index
                .fetch_add(1, Ordering::Relaxed),
        )
        .unwrap();

        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
            .map(|blas_entry| {
                let geometries = match blas_entry.geometries {
                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
                        TraceBlasGeometries::TriangleGeometries(
                            triangle_geometries
                                .map(|tg| TraceBlasTriangleGeometry {
                                    size: tg.size.clone(),
                                    vertex_buffer: tg.vertex_buffer,
                                    index_buffer: tg.index_buffer,
                                    transform_buffer: tg.transform_buffer,
                                    first_vertex: tg.first_vertex,
                                    vertex_stride: tg.vertex_stride,
                                    first_index: tg.first_index,
                                    transform_buffer_offset: tg.transform_buffer_offset,
                                })
                                .collect(),
                        )
                    }
                };
                TraceBlasBuildEntry {
                    blas_id: blas_entry.blas_id,
                    geometries,
                }
            })
            .collect();

        let trace_tlas: Vec<TraceTlasPackage> = tlas_iter
            .map(|package: TlasPackage| {
                let instances = package
                    .instances
                    .map(|instance| {
                        instance.map(|instance| TraceTlasInstance {
                            blas_id: instance.blas_id,
                            transform: *instance.transform,
                            custom_index: instance.custom_index,
                            mask: instance.mask,
                        })
                    })
                    .collect();
                TraceTlasPackage {
                    tlas_id: package.tlas_id,
                    instances,
                    lowest_unmodified: package.lowest_unmodified,
                }
            })
            .collect();

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf.data.lock().get_inner()?.commands {
            list.push(crate::device::trace::Command::BuildAccelerationStructures {
                blas: trace_blas.clone(),
                tlas: trace_tlas.clone(),
            });
        }

        let blas_iter = trace_blas.iter().map(|blas_entry| {
            let geometries = match &blas_entry.geometries {
                TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
                    let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
                        size: &tg.size,
                        vertex_buffer: tg.vertex_buffer,
                        index_buffer: tg.index_buffer,
                        transform_buffer: tg.transform_buffer,
                        first_vertex: tg.first_vertex,
                        vertex_stride: tg.vertex_stride,
                        first_index: tg.first_index,
                        transform_buffer_offset: tg.transform_buffer_offset,
                    });
                    BlasGeometries::TriangleGeometries(Box::new(iter))
                }
            };
            BlasBuildEntry {
                blas_id: blas_entry.blas_id,
                geometries,
            }
        });

        let tlas_iter = trace_tlas.iter().map(|tlas_package| {
            let instances = tlas_package.instances.iter().map(|instance| {
                instance.as_ref().map(|instance| TlasInstance {
                    blas_id: instance.blas_id,
                    transform: &instance.transform,
                    custom_index: instance.custom_index,
                    mask: instance.mask,
                })
            });
            TlasPackage {
                tlas_id: tlas_package.tlas_id,
                instances: Box::new(instances),
                lowest_unmodified: tlas_package.lowest_unmodified,
            }
        });

        let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
        let mut buf_storage = Vec::new();

        let mut scratch_buffer_blas_size = 0;
        let mut blas_storage = Vec::new();
        let mut cmd_buf_data = cmd_buf.data.lock();
        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
        let cmd_buf_data = &mut *cmd_buf_data_guard;

        iter_blas(
            blas_iter,
            cmd_buf_data,
            build_command_index,
            &mut buf_storage,
            hub,
        )?;

        let snatch_guard = device.snatchable_lock.read();
        iter_buffers(
            &mut buf_storage,
            &snatch_guard,
            &mut input_barriers,
            cmd_buf_data,
            &mut scratch_buffer_blas_size,
            &mut blas_storage,
            hub,
            device.alignments.ray_tracing_scratch_buffer_alignment,
        )?;
        let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();

        for package in tlas_iter {
            let tlas = hub.tlas_s.get(package.tlas_id).get()?;

            cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());

            tlas_lock_store.push((Some(package), tlas))
        }

        let mut scratch_buffer_tlas_size = 0;
        let mut tlas_storage = Vec::<TlasStore>::new();
        let mut instance_buffer_staging_source = Vec::<u8>::new();

        for (package, tlas) in &mut tlas_lock_store {
            let package = package.take().unwrap();

            let scratch_buffer_offset = scratch_buffer_tlas_size;
            scratch_buffer_tlas_size += align_to(
                tlas.size_info.build_scratch_size as u32,
                device.alignments.ray_tracing_scratch_buffer_alignment,
            ) as u64;

            let first_byte_index = instance_buffer_staging_source.len();

            let mut dependencies = Vec::new();

            let mut instance_count = 0;
            for instance in package.instances.flatten() {
                if instance.custom_index >= (1u32 << 24u32) {
                    return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                        tlas.error_ident(),
                    ));
                }
                let blas = hub.blas_s.get(instance.blas_id).get()?;

                cmd_buf_data.trackers.blas_s.insert_single(blas.clone());

                instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes(
                    hal::TlasInstance {
                        transform: *instance.transform,
                        custom_index: instance.custom_index,
                        mask: instance.mask,
                        blas_address: blas.handle,
                    },
                ));

                instance_count += 1;

                dependencies.push(blas.clone());

                cmd_buf_data.blas_actions.push(BlasAction {
                    blas,
                    kind: crate::ray_tracing::BlasActionKind::Use,
                });
            }

            cmd_buf_data.tlas_actions.push(TlasAction {
                tlas: tlas.clone(),
                kind: crate::ray_tracing::TlasActionKind::Build {
                    build_index: build_command_index,
                    dependencies,
                },
            });

            if instance_count > tlas.max_instance_count {
                return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                    tlas.error_ident(),
                    instance_count,
                    tlas.max_instance_count,
                ));
            }

            tlas_storage.push(TlasStore {
                internal: UnsafeTlasStore {
                    tlas: tlas.clone(),
                    entries: hal::AccelerationStructureEntries::Instances(
                        hal::AccelerationStructureInstances {
                            buffer: Some(tlas.instance_buffer.as_ref()),
                            offset: 0,
                            count: instance_count,
                        },
                    ),
                    scratch_buffer_offset,
                },
                range: first_byte_index..instance_buffer_staging_source.len(),
            });
        }

        let scratch_size =
            match wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size)) {
                None => {
                    cmd_buf_data_guard.mark_successful();
                    return Ok(());
                }
                Some(size) => size,
            };

        let scratch_buffer =
            ScratchBuffer::new(device, scratch_size).map_err(crate::device::DeviceError::from)?;

        let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
            buffer: scratch_buffer.raw(),
            usage: hal::StateTransition {
                from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
                to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            },
        };

        let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

        for &TlasStore {
            internal:
                UnsafeTlasStore {
                    ref tlas,
                    ref entries,
                    ref scratch_buffer_offset,
                },
            ..
        } in &tlas_storage
        {
            if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
                log::info!("only rebuild implemented")
            }
            tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
                entries,
                mode: hal::AccelerationStructureBuildMode::Build,
                flags: tlas.flags,
                source_acceleration_structure: None,
                destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
                scratch_buffer: scratch_buffer.raw(),
                scratch_buffer_offset: *scratch_buffer_offset,
            })
        }

        let blas_present = !blas_storage.is_empty();
        let tlas_present = !tlas_storage.is_empty();

        let cmd_buf_raw = cmd_buf_data.encoder.open()?;

        let mut descriptors = Vec::new();

        for storage in &blas_storage {
            descriptors.push(map_blas(storage, scratch_buffer.raw(), &snatch_guard)?);
        }

        build_blas(
            cmd_buf_raw,
            blas_present,
            tlas_present,
            input_barriers,
            &descriptors,
            scratch_buffer_barrier,
        );
        if tlas_present {
            let staging_buffer = if !instance_buffer_staging_source.is_empty() {
                let mut staging_buffer = StagingBuffer::new(
                    device,
                    wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
                )
                .map_err(crate::device::DeviceError::from)?;
                staging_buffer.write(&instance_buffer_staging_source);
                let flushed = staging_buffer.flush();
                Some(flushed)
            } else {
                None
            };

            unsafe {
                if let Some(ref staging_buffer) = staging_buffer {
                    cmd_buf_raw.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                        buffer: staging_buffer.raw(),
                        usage: hal::StateTransition {
                            from: BufferUses::MAP_WRITE,
                            to: BufferUses::COPY_SRC,
                        },
                    }]);
                }
            }

            let mut instance_buffer_barriers = Vec::new();
            for &TlasStore {
                internal: UnsafeTlasStore { ref tlas, .. },
                ref range,
            } in &tlas_storage
            {
                let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                    None => continue,
                    Some(size) => size,
                };
                instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::COPY_DST,
                        to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                    },
                });
                unsafe {
                    cmd_buf_raw.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                        buffer: tlas.instance_buffer.as_ref(),
                        usage: hal::StateTransition {
                            from: BufferUses::MAP_READ,
                            to: BufferUses::COPY_DST,
                        },
                    }]);
                    let temp = hal::BufferCopy {
                        src_offset: range.start as u64,
                        dst_offset: 0,
                        size,
                    };
                    cmd_buf_raw.copy_buffer_to_buffer(
                        // `staging_buffer` is `Some` here: a non-empty copy range means
                        // instance data was appended to `instance_buffer_staging_source`.
                        staging_buffer.as_ref().unwrap().raw(),
                        tlas.instance_buffer.as_ref(),
                        &[temp],
                    );
                }
            }

            unsafe {
                cmd_buf_raw.transition_buffers(&instance_buffer_barriers);

                cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);

                cmd_buf_raw.place_acceleration_structure_barrier(
                    hal::AccelerationStructureBarrier {
                        usage: hal::StateTransition {
                            from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                            to: hal::AccelerationStructureUses::SHADER_INPUT,
                        },
                    },
                );
            }

            if let Some(staging_buffer) = staging_buffer {
                cmd_buf_data
                    .temp_resources
                    .push(TempResource::StagingBuffer(staging_buffer));
            }
        }

        cmd_buf_data
            .temp_resources
            .push(TempResource::ScratchBuffer(scratch_buffer));

        cmd_buf_data_guard.mark_successful();
        Ok(())
    }
}

impl CommandBufferMutable {
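    /// Walks the recorded BLAS actions at submission time: records the build index of
    /// every BLAS built in this command buffer and errors if a BLAS is used before it
    /// has ever been built.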
    pub(crate) fn validate_blas_actions(&self) -> Result<(), ValidateBlasActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_blas_actions");
        let mut built = FastHashSet::default();
        for action in &self.blas_actions {
            match &action.kind {
                crate::ray_tracing::BlasActionKind::Build(id) => {
                    built.insert(action.blas.tracker_index());
                    *action.blas.built_index.write() = Some(*id);
                }
                crate::ray_tracing::BlasActionKind::Use => {
                    if !built.contains(&action.blas.tracker_index())
                        && (*action.blas.built_index.read()).is_none()
                    {
                        return Err(ValidateBlasActionsError::UsedUnbuilt(
                            action.blas.error_ident(),
                        ));
                    }
                }
            }
        }
        Ok(())
    }

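    /// Walks the recorded TLAS actions at submission time: records build indices and BLAS
    /// dependencies for built TLASes, and errors if a used TLAS (or any BLAS it depends on)
    /// was never built, or if a dependency BLAS was rebuilt after the TLAS.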
    pub(crate) fn validate_tlas_actions(
        &self,
        snatch_guard: &SnatchGuard,
    ) -> Result<(), ValidateTlasActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_tlas_actions");
        for action in &self.tlas_actions {
            match &action.kind {
                crate::ray_tracing::TlasActionKind::Build {
                    build_index,
                    dependencies,
                } => {
                    *action.tlas.built_index.write() = Some(*build_index);
                    action.tlas.dependencies.write().clone_from(dependencies);
                }
                crate::ray_tracing::TlasActionKind::Use => {
                    let tlas_build_index = action.tlas.built_index.read();
                    let dependencies = action.tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateTlasActionsError::UsedUnbuilt(
                            action.tlas.error_ident(),
                        ));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            return Err(ValidateTlasActionsError::UsedUnbuiltBlas(
                                action.tlas.error_ident(),
                                blas.error_ident(),
                            ));
                        }
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateTlasActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                action.tlas.error_ident(),
                            ));
                        }
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }
}

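/// Validates each BLAS build entry against the sizes the BLAS was created with and collects
/// its vertex/index/transform buffers, along with their pending transitions, into `buf_storage`.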
fn iter_blas<'a>(
    blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
    cmd_buf_data: &mut CommandBufferMutable,
    build_command_index: NonZeroU64,
    buf_storage: &mut Vec<TriangleBufferStore<'a>>,
    hub: &Hub,
) -> Result<(), BuildAccelerationStructureError> {
    let mut temp_buffer = Vec::new();
    for entry in blas_iter {
        let blas = hub.blas_s.get(entry.blas_id).get()?;
        cmd_buf_data.trackers.blas_s.insert_single(blas.clone());

        cmd_buf_data.blas_actions.push(BlasAction {
            blas: blas.clone(),
            kind: crate::ray_tracing::BlasActionKind::Build(build_command_index),
        });

        match entry.geometries {
            BlasGeometries::TriangleGeometries(triangle_geometries) => {
                for (i, mesh) in triangle_geometries.enumerate() {
                    let size_desc = match &blas.sizes {
                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
                    };
                    if i >= size_desc.len() {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
                            blas.error_ident(),
                        ));
                    }
                    let size_desc = &size_desc[i];

                    if size_desc.flags != mesh.size.flags {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
                            blas.error_ident(),
                            size_desc.flags,
                            mesh.size.flags,
                        ));
                    }

                    if size_desc.vertex_count < mesh.size.vertex_count {
                        return Err(
                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
                                blas.error_ident(),
                                size_desc.vertex_count,
                                mesh.size.vertex_count,
                            ),
                        );
                    }

                    if size_desc.vertex_format != mesh.size.vertex_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
                            blas.error_ident(),
                            size_desc.vertex_format,
                            mesh.size.vertex_format,
                        ));
                    }

                    match (size_desc.index_count, mesh.size.index_count) {
                        (Some(_), None) | (None, Some(_)) => {
                            return Err(
                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
                                    blas.error_ident(),
                                ),
                            )
                        }
                        (Some(create), Some(build)) if create < build => {
                            return Err(
                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
                                    blas.error_ident(),
                                    create,
                                    build,
                                ),
                            )
                        }
                        _ => {}
                    }

                    if size_desc.index_format != mesh.size.index_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
                            blas.error_ident(),
                            size_desc.index_format,
                            mesh.size.index_format,
                        ));
                    }

                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
                            blas.error_ident(),
                        ));
                    }
                    let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
                    let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
                        &vertex_buffer,
                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                    );
                    let index_data = if let Some(index_id) = mesh.index_buffer {
                        let index_buffer = hub.buffers.get(index_id).get()?;
                        // An index buffer also needs a first index, an index count, and an
                        // index format to be usable.
                        if mesh.first_index.is_none()
                            || mesh.size.index_count.is_none()
                            || mesh.size.index_format.is_none()
                        {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                index_buffer.error_ident(),
                            ));
                        }
                        let data = cmd_buf_data.trackers.buffers.set_single(
                            &index_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        Some((index_buffer, data))
                    } else {
                        None
                    };
                    let transform_data = if let Some(transform_id) = mesh.transform_buffer {
                        let transform_buffer = hub.buffers.get(transform_id).get()?;
                        if mesh.transform_buffer_offset.is_none() {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                transform_buffer.error_ident(),
                            ));
                        }
                        let data = cmd_buf_data.trackers.buffers.set_single(
                            &transform_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        Some((transform_buffer, data))
                    } else {
                        None
                    };
                    temp_buffer.push(TriangleBufferStore {
                        vertex_buffer,
                        vertex_transition: vertex_pending,
                        index_buffer_transition: index_data,
                        transform_buffer_transition: transform_data,
                        geometry: mesh,
                        ending_blas: None,
                    });
                }

                if let Some(last) = temp_buffer.last_mut() {
                    last.ending_blas = Some(blas);
                    buf_storage.append(&mut temp_buffer);
                }
            }
        }
    }
    Ok(())
}

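/// Consumes `buf_storage`: checks buffer usages and sizes, queues input barriers and
/// memory-init actions, converts each geometry into HAL triangle entries, and assigns
/// each finished BLAS its offset into the shared scratch buffer.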
fn iter_buffers<'a, 'b>(
    buf_storage: &'a mut Vec<TriangleBufferStore<'b>>,
    snatch_guard: &'a SnatchGuard,
    input_barriers: &mut Vec<hal::BufferBarrier<'a, dyn hal::DynBuffer>>,
    cmd_buf_data: &mut CommandBufferMutable,
    scratch_buffer_blas_size: &mut u64,
    blas_storage: &mut Vec<BlasStore<'a>>,
    hub: &Hub,
    ray_tracing_scratch_buffer_alignment: u32,
) -> Result<(), BuildAccelerationStructureError> {
    let mut triangle_entries =
        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
    for buf in buf_storage {
        let mesh = &buf.geometry;
        let vertex_buffer = {
            let vertex_buffer = buf.vertex_buffer.as_ref();
            let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = buf
                .vertex_transition
                .take()
                .map(|pending| pending.into_hal(vertex_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            if vertex_buffer.size
                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
            {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    vertex_buffer.error_ident(),
                    vertex_buffer.size,
                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
                ));
            }
            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
            cmd_buf_data.buffer_memory_init_actions.extend(
                vertex_buffer.initialization_status.read().create_action(
                    &hub.buffers.get(mesh.vertex_buffer).get()?,
                    vertex_buffer_offset
                        ..(vertex_buffer_offset
                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            vertex_raw
        };
        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
            buf.index_buffer_transition
        {
            let index_raw = index_buffer.try_raw(snatch_guard)?;
            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = index_pending
                .take()
                .map(|pending| pending.into_hal(index_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
            let offset = mesh.first_index.unwrap() as u64 * index_stride;
            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;

            if mesh.size.index_count.unwrap() % 3 != 0 {
                return Err(BuildAccelerationStructureError::InvalidIndexCount(
                    index_buffer.error_ident(),
                    mesh.size.index_count.unwrap(),
                ));
            }
            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    index_buffer.error_ident(),
                    index_buffer.size,
                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
                ));
            }

            cmd_buf_data.buffer_memory_init_actions.extend(
                index_buffer.initialization_status.read().create_action(
                    index_buffer,
                    offset..(offset + index_buffer_size),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(index_raw)
        } else {
            None
        };
        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
            buf.transform_buffer_transition
        {
            if mesh.transform_buffer_offset.is_none() {
                return Err(BuildAccelerationStructureError::MissingAssociatedData(
                    transform_buffer.error_ident(),
                ));
            }
            let transform_raw = transform_buffer.try_raw(snatch_guard)?;
            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = transform_pending
                .take()
                .map(|pending| pending.into_hal(transform_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }

            let offset = mesh.transform_buffer_offset.unwrap();

            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
                return Err(
                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
                        transform_buffer.error_ident(),
                    ),
                );
            }
            if transform_buffer.size < 48 + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    transform_buffer.error_ident(),
                    transform_buffer.size,
                    48 + offset,
                ));
            }
            cmd_buf_data.buffer_memory_init_actions.extend(
                transform_buffer.initialization_status.read().create_action(
                    transform_buffer,
                    offset..(offset + 48),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(transform_raw)
        } else {
            None
        };

        let triangles = hal::AccelerationStructureTriangles {
            vertex_buffer: Some(vertex_buffer),
            vertex_format: mesh.size.vertex_format,
            first_vertex: mesh.first_vertex,
            vertex_count: mesh.size.vertex_count,
            vertex_stride: mesh.vertex_stride,
            indices: index_buffer.map(|index_buffer| {
                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
                    format: mesh.size.index_format.unwrap(),
                    buffer: Some(index_buffer),
                    offset: mesh.first_index.unwrap() * index_stride,
                    count: mesh.size.index_count.unwrap(),
                }
            }),
            transform: transform_buffer.map(|transform_buffer| {
                hal::AccelerationStructureTriangleTransform {
                    buffer: transform_buffer,
                    offset: mesh.transform_buffer_offset.unwrap() as u32,
                }
            }),
            flags: mesh.size.flags,
        };
        triangle_entries.push(triangles);
        if let Some(blas) = buf.ending_blas.take() {
            let scratch_buffer_offset = *scratch_buffer_blas_size;
            *scratch_buffer_blas_size += align_to(
                blas.size_info.build_scratch_size as u32,
                ray_tracing_scratch_buffer_alignment,
            ) as u64;

            blas_storage.push(BlasStore {
                blas,
                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
                scratch_buffer_offset,
            });
            triangle_entries = Vec::new();
        }
    }
    Ok(())
}

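/// Converts a `BlasStore` into a HAL build descriptor targeting the shared scratch buffer.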
fn map_blas<'a>(
    storage: &'a BlasStore<'_>,
    scratch_buffer: &'a dyn hal::DynBuffer,
    snatch_guard: &'a SnatchGuard,
) -> Result<
    hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >,
    BuildAccelerationStructureError,
> {
    let BlasStore {
        blas,
        entries,
        scratch_buffer_offset,
    } = storage;
    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
        log::info!("only rebuild implemented")
    }
    Ok(hal::BuildAccelerationStructureDescriptor {
        entries,
        mode: hal::AccelerationStructureBuildMode::Build,
        flags: blas.flags,
        source_acceleration_structure: None,
        destination_acceleration_structure: blas.try_raw(snatch_guard)?,
        scratch_buffer,
        scratch_buffer_offset: *scratch_buffer_offset,
    })
}

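/// Records the input barriers, the BLAS builds, and the acceleration-structure barriers
/// that must precede any TLAS build or shader use of the results.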
fn build_blas<'a>(
    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
    blas_present: bool,
    tlas_present: bool,
    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >],
    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
) {
    unsafe {
        cmd_buf_raw.transition_buffers(&input_barriers);
    }

    if blas_present {
        unsafe {
            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_INPUT,
                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
                },
            });

            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
        }
    }

    if blas_present && tlas_present {
        unsafe {
            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
        }
    }

    let mut source_usage = hal::AccelerationStructureUses::empty();
    let mut destination_usage = hal::AccelerationStructureUses::empty();
    if blas_present {
        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT;
    }
    if tlas_present {
        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
    }
    unsafe {
        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
            usage: hal::StateTransition {
                from: source_usage,
                to: destination_usage,
            },
        });
    }
}
1221}