wgpu_core/command/
ray_tracing.rs

1use crate::{
2    device::queue::TempResource,
3    global::Global,
4    hub::Hub,
5    id::CommandEncoderId,
6    init_tracker::MemoryInitKind,
7    ray_tracing::{
8        BlasAction, BlasBuildEntry, BlasGeometries, BlasTriangleGeometry,
9        BuildAccelerationStructureError, TlasAction, TlasBuildEntry, TlasInstance, TlasPackage,
10        TraceBlasBuildEntry, TraceBlasGeometries, TraceBlasTriangleGeometry, TraceTlasInstance,
11        TraceTlasPackage, ValidateBlasActionsError, ValidateTlasActionsError,
12    },
13    resource::{AccelerationStructure, Blas, Buffer, Labeled, StagingBuffer, Tlas, Trackable},
14    scratch::ScratchBuffer,
15    snatch::SnatchGuard,
16    track::PendingTransition,
17    FastHashSet,
18};
19
20use wgt::{math::align_to, BufferUsages, Features};
21
22use super::CommandBufferMutable;
23use hal::BufferUses;
24use std::{
25    cmp::max,
26    num::NonZeroU64,
27    ops::{Deref, Range},
28    sync::{atomic::Ordering, Arc},
29};
30
/// Buffers and pending state transitions collected for one triangle geometry
/// of a BLAS build.
///
/// Instances are stored in the `buf_storage` vector that `iter_blas` fills and
/// `iter_buffers` later consumes.
struct TriangleBufferStore<'a> {
    /// Vertex buffer of the geometry.
    vertex_buffer: Arc<Buffer>,
    /// Transition still to be issued for the vertex buffer, if any.
    vertex_transition: Option<PendingTransition<BufferUses>>,
    /// Index buffer together with its pending transition; `None` when the
    /// geometry has no index buffer.
    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    /// Transform buffer together with its pending transition; `None` when the
    /// geometry has no transform buffer.
    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    /// The geometry description this entry was created from.
    geometry: BlasTriangleGeometry<'a>,
    // NOTE(review): presumably `Some` only on the last geometry of a BLAS so the
    // consumer knows where one BLAS ends — confirm against `iter_blas`.
    ending_blas: Option<Arc<Blas>>,
}
39
/// Everything needed to emit the hal build descriptor for one BLAS.
///
/// Converted into a `hal::BuildAccelerationStructureDescriptor` by `map_blas`.
struct BlasStore<'a> {
    /// The BLAS being built.
    blas: Arc<Blas>,
    /// Geometry entries handed to hal for this build.
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    /// Byte offset of this build's region inside the shared scratch buffer.
    scratch_buffer_offset: u64,
}
45
/// Everything needed to emit the hal build descriptor for one TLAS.
///
/// Used directly by the `unsafe_tlas` build path and wrapped by [`TlasStore`]
/// in the safe path.
struct UnsafeTlasStore<'a> {
    /// The TLAS being built.
    tlas: Arc<Tlas>,
    /// Instance entries (buffer, offset, count) handed to hal for this build.
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    /// Byte offset of this build's region inside the shared scratch buffer.
    scratch_buffer_offset: u64,
}
51
/// TLAS build data for the safe build path.
struct TlasStore<'a> {
    /// The hal-facing build information.
    internal: UnsafeTlasStore<'a>,
    /// Byte range inside the shared instance staging source that holds this
    /// TLAS's serialized instances.
    range: Range<usize>,
}
56
/// Instance buffer collected for one TLAS entry of the `unsafe_tlas` build path.
struct TlasBufferStore {
    /// The user-provided instance buffer.
    buffer: Arc<Buffer>,
    /// Transition returned by the buffer tracker for `buffer`, if any; taken and
    /// turned into an input barrier before the build.
    transition: Option<PendingTransition<BufferUses>>,
    /// The build entry this buffer belongs to.
    entry: TlasBuildEntry,
}
62
63impl Global {
64    // Currently this function is very similar to its safe counterpart, however certain parts of it are very different,
65    // making for the two to be implemented differently, the main difference is this function has separate buffers for each
66    // of the TLAS instances while the other has one large buffer
67    // TODO: reconsider this function's usefulness once blas and tlas `as_hal` are added and some time has passed.
68    pub fn command_encoder_build_acceleration_structures_unsafe_tlas<'a>(
69        &self,
70        command_encoder_id: CommandEncoderId,
71        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
72        tlas_iter: impl Iterator<Item = TlasBuildEntry>,
73    ) -> Result<(), BuildAccelerationStructureError> {
74        profiling::scope!("CommandEncoder::build_acceleration_structures_unsafe_tlas");
75
76        let hub = &self.hub;
77
78        let cmd_buf = hub
79            .command_buffers
80            .get(command_encoder_id.into_command_buffer_id());
81
82        let device = &cmd_buf.device;
83
84        device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
85
86        let build_command_index = NonZeroU64::new(
87            device
88                .last_acceleration_structure_build_command_index
89                .fetch_add(1, Ordering::Relaxed),
90        )
91        .unwrap();
92
93        #[cfg(feature = "trace")]
94        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
95            .map(|blas_entry| {
96                let geometries = match blas_entry.geometries {
97                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
98                        TraceBlasGeometries::TriangleGeometries(
99                            triangle_geometries
100                                .map(|tg| TraceBlasTriangleGeometry {
101                                    size: tg.size.clone(),
102                                    vertex_buffer: tg.vertex_buffer,
103                                    index_buffer: tg.index_buffer,
104                                    transform_buffer: tg.transform_buffer,
105                                    first_vertex: tg.first_vertex,
106                                    vertex_stride: tg.vertex_stride,
107                                    first_index: tg.first_index,
108                                    transform_buffer_offset: tg.transform_buffer_offset,
109                                })
110                                .collect(),
111                        )
112                    }
113                };
114                TraceBlasBuildEntry {
115                    blas_id: blas_entry.blas_id,
116                    geometries,
117                }
118            })
119            .collect();
120
121        #[cfg(feature = "trace")]
122        let trace_tlas: Vec<TlasBuildEntry> = tlas_iter.collect();
123        #[cfg(feature = "trace")]
124        if let Some(ref mut list) = cmd_buf.data.lock().get_inner()?.commands {
125            list.push(
126                crate::device::trace::Command::BuildAccelerationStructuresUnsafeTlas {
127                    blas: trace_blas.clone(),
128                    tlas: trace_tlas.clone(),
129                },
130            );
131            if !trace_tlas.is_empty() {
132                log::warn!("a trace of command_encoder_build_acceleration_structures_unsafe_tlas containing a tlas build is not replayable!");
133            }
134        }
135
136        #[cfg(feature = "trace")]
137        let blas_iter = trace_blas.iter().map(|blas_entry| {
138            let geometries = match &blas_entry.geometries {
139                TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
140                    let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
141                        size: &tg.size,
142                        vertex_buffer: tg.vertex_buffer,
143                        index_buffer: tg.index_buffer,
144                        transform_buffer: tg.transform_buffer,
145                        first_vertex: tg.first_vertex,
146                        vertex_stride: tg.vertex_stride,
147                        first_index: tg.first_index,
148                        transform_buffer_offset: tg.transform_buffer_offset,
149                    });
150                    BlasGeometries::TriangleGeometries(Box::new(iter))
151                }
152            };
153            BlasBuildEntry {
154                blas_id: blas_entry.blas_id,
155                geometries,
156            }
157        });
158
159        #[cfg(feature = "trace")]
160        let tlas_iter = trace_tlas.iter();
161
162        let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
163        let mut buf_storage = Vec::new();
164
165        let mut scratch_buffer_blas_size = 0;
166        let mut blas_storage = Vec::new();
167        let mut cmd_buf_data = cmd_buf.data.lock();
168        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
169        let cmd_buf_data = &mut *cmd_buf_data_guard;
170
171        iter_blas(
172            blas_iter,
173            cmd_buf_data,
174            build_command_index,
175            &mut buf_storage,
176            hub,
177        )?;
178
179        let snatch_guard = device.snatchable_lock.read();
180        iter_buffers(
181            &mut buf_storage,
182            &snatch_guard,
183            &mut input_barriers,
184            cmd_buf_data,
185            &mut scratch_buffer_blas_size,
186            &mut blas_storage,
187            hub,
188            device.alignments.ray_tracing_scratch_buffer_alignment,
189        )?;
190
191        let mut scratch_buffer_tlas_size = 0;
192        let mut tlas_storage = Vec::<UnsafeTlasStore>::new();
193        let mut tlas_buf_storage = Vec::new();
194
195        for entry in tlas_iter {
196            let instance_buffer = hub.buffers.get(entry.instance_buffer_id).get()?;
197            let data = cmd_buf_data.trackers.buffers.set_single(
198                &instance_buffer,
199                BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
200            );
201            tlas_buf_storage.push(TlasBufferStore {
202                buffer: instance_buffer,
203                transition: data,
204                entry: entry.clone(),
205            });
206        }
207
208        for tlas_buf in &mut tlas_buf_storage {
209            let entry = &tlas_buf.entry;
210            let instance_buffer = {
211                let (instance_buffer, instance_pending) =
212                    (&mut tlas_buf.buffer, &mut tlas_buf.transition);
213                let instance_raw = instance_buffer.try_raw(&snatch_guard)?;
214                instance_buffer.check_usage(BufferUsages::TLAS_INPUT)?;
215
216                if let Some(barrier) = instance_pending
217                    .take()
218                    .map(|pending| pending.into_hal(instance_buffer, &snatch_guard))
219                {
220                    input_barriers.push(barrier);
221                }
222                instance_raw
223            };
224
225            let tlas = hub.tlas_s.get(entry.tlas_id).get()?;
226            cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());
227
228            cmd_buf_data.tlas_actions.push(TlasAction {
229                tlas: tlas.clone(),
230                kind: crate::ray_tracing::TlasActionKind::Build {
231                    build_index: build_command_index,
232                    dependencies: Vec::new(),
233                },
234            });
235
236            let scratch_buffer_offset = scratch_buffer_tlas_size;
237            scratch_buffer_tlas_size += align_to(
238                tlas.size_info.build_scratch_size as u32,
239                device.alignments.ray_tracing_scratch_buffer_alignment,
240            ) as u64;
241
242            tlas_storage.push(UnsafeTlasStore {
243                tlas,
244                entries: hal::AccelerationStructureEntries::Instances(
245                    hal::AccelerationStructureInstances {
246                        buffer: Some(instance_buffer),
247                        offset: 0,
248                        count: entry.instance_count,
249                    },
250                ),
251                scratch_buffer_offset,
252            });
253        }
254
255        let scratch_size =
256            match wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size)) {
257                None => {
258                    cmd_buf_data_guard.mark_successful();
259                    return Ok(());
260                }
261                Some(size) => size,
262            };
263
264        let scratch_buffer =
265            ScratchBuffer::new(device, scratch_size).map_err(crate::device::DeviceError::from)?;
266
267        let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
268            buffer: scratch_buffer.raw(),
269            usage: hal::StateTransition {
270                from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
271                to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
272            },
273        };
274
275        let mut tlas_descriptors = Vec::new();
276
277        for UnsafeTlasStore {
278            tlas,
279            entries,
280            scratch_buffer_offset,
281        } in &tlas_storage
282        {
283            if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
284                log::info!("only rebuild implemented")
285            }
286            tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
287                entries,
288                mode: hal::AccelerationStructureBuildMode::Build,
289                flags: tlas.flags,
290                source_acceleration_structure: None,
291                destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
292                scratch_buffer: scratch_buffer.raw(),
293                scratch_buffer_offset: *scratch_buffer_offset,
294            })
295        }
296
297        let blas_present = !blas_storage.is_empty();
298        let tlas_present = !tlas_storage.is_empty();
299
300        let cmd_buf_raw = cmd_buf_data.encoder.open()?;
301
302        let mut descriptors = Vec::new();
303
304        for storage in &blas_storage {
305            descriptors.push(map_blas(storage, scratch_buffer.raw(), &snatch_guard)?);
306        }
307
308        build_blas(
309            cmd_buf_raw,
310            blas_present,
311            tlas_present,
312            input_barriers,
313            &descriptors,
314            scratch_buffer_barrier,
315        );
316
317        if tlas_present {
318            unsafe {
319                cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);
320
321                cmd_buf_raw.place_acceleration_structure_barrier(
322                    hal::AccelerationStructureBarrier {
323                        usage: hal::StateTransition {
324                            from: hal::AccelerationStructureUses::BUILD_OUTPUT,
325                            to: hal::AccelerationStructureUses::SHADER_INPUT,
326                        },
327                    },
328                );
329            }
330        }
331
332        cmd_buf_data
333            .temp_resources
334            .push(TempResource::ScratchBuffer(scratch_buffer));
335
336        cmd_buf_data_guard.mark_successful();
337        Ok(())
338    }
339
340    pub fn command_encoder_build_acceleration_structures<'a>(
341        &self,
342        command_encoder_id: CommandEncoderId,
343        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
344        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
345    ) -> Result<(), BuildAccelerationStructureError> {
346        profiling::scope!("CommandEncoder::build_acceleration_structures");
347
348        let hub = &self.hub;
349
350        let cmd_buf = hub
351            .command_buffers
352            .get(command_encoder_id.into_command_buffer_id());
353
354        let device = &cmd_buf.device;
355
356        device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
357
358        let build_command_index = NonZeroU64::new(
359            device
360                .last_acceleration_structure_build_command_index
361                .fetch_add(1, Ordering::Relaxed),
362        )
363        .unwrap();
364
365        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
366            .map(|blas_entry| {
367                let geometries = match blas_entry.geometries {
368                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
369                        TraceBlasGeometries::TriangleGeometries(
370                            triangle_geometries
371                                .map(|tg| TraceBlasTriangleGeometry {
372                                    size: tg.size.clone(),
373                                    vertex_buffer: tg.vertex_buffer,
374                                    index_buffer: tg.index_buffer,
375                                    transform_buffer: tg.transform_buffer,
376                                    first_vertex: tg.first_vertex,
377                                    vertex_stride: tg.vertex_stride,
378                                    first_index: tg.first_index,
379                                    transform_buffer_offset: tg.transform_buffer_offset,
380                                })
381                                .collect(),
382                        )
383                    }
384                };
385                TraceBlasBuildEntry {
386                    blas_id: blas_entry.blas_id,
387                    geometries,
388                }
389            })
390            .collect();
391
392        let trace_tlas: Vec<TraceTlasPackage> = tlas_iter
393            .map(|package: TlasPackage| {
394                let instances = package
395                    .instances
396                    .map(|instance| {
397                        instance.map(|instance| TraceTlasInstance {
398                            blas_id: instance.blas_id,
399                            transform: *instance.transform,
400                            custom_index: instance.custom_index,
401                            mask: instance.mask,
402                        })
403                    })
404                    .collect();
405                TraceTlasPackage {
406                    tlas_id: package.tlas_id,
407                    instances,
408                    lowest_unmodified: package.lowest_unmodified,
409                }
410            })
411            .collect();
412
413        #[cfg(feature = "trace")]
414        if let Some(ref mut list) = cmd_buf.data.lock().get_inner()?.commands {
415            list.push(crate::device::trace::Command::BuildAccelerationStructures {
416                blas: trace_blas.clone(),
417                tlas: trace_tlas.clone(),
418            });
419        }
420
421        let blas_iter = trace_blas.iter().map(|blas_entry| {
422            let geometries = match &blas_entry.geometries {
423                TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
424                    let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
425                        size: &tg.size,
426                        vertex_buffer: tg.vertex_buffer,
427                        index_buffer: tg.index_buffer,
428                        transform_buffer: tg.transform_buffer,
429                        first_vertex: tg.first_vertex,
430                        vertex_stride: tg.vertex_stride,
431                        first_index: tg.first_index,
432                        transform_buffer_offset: tg.transform_buffer_offset,
433                    });
434                    BlasGeometries::TriangleGeometries(Box::new(iter))
435                }
436            };
437            BlasBuildEntry {
438                blas_id: blas_entry.blas_id,
439                geometries,
440            }
441        });
442
443        let tlas_iter = trace_tlas.iter().map(|tlas_package| {
444            let instances = tlas_package.instances.iter().map(|instance| {
445                instance.as_ref().map(|instance| TlasInstance {
446                    blas_id: instance.blas_id,
447                    transform: &instance.transform,
448                    custom_index: instance.custom_index,
449                    mask: instance.mask,
450                })
451            });
452            TlasPackage {
453                tlas_id: tlas_package.tlas_id,
454                instances: Box::new(instances),
455                lowest_unmodified: tlas_package.lowest_unmodified,
456            }
457        });
458
459        let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
460        let mut buf_storage = Vec::new();
461
462        let mut scratch_buffer_blas_size = 0;
463        let mut blas_storage = Vec::new();
464        let mut cmd_buf_data = cmd_buf.data.lock();
465        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
466        let cmd_buf_data = &mut *cmd_buf_data_guard;
467
468        iter_blas(
469            blas_iter,
470            cmd_buf_data,
471            build_command_index,
472            &mut buf_storage,
473            hub,
474        )?;
475
476        let snatch_guard = device.snatchable_lock.read();
477        iter_buffers(
478            &mut buf_storage,
479            &snatch_guard,
480            &mut input_barriers,
481            cmd_buf_data,
482            &mut scratch_buffer_blas_size,
483            &mut blas_storage,
484            hub,
485            device.alignments.ray_tracing_scratch_buffer_alignment,
486        )?;
487        let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();
488
489        for package in tlas_iter {
490            let tlas = hub.tlas_s.get(package.tlas_id).get()?;
491
492            cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());
493
494            tlas_lock_store.push((Some(package), tlas))
495        }
496
497        let mut scratch_buffer_tlas_size = 0;
498        let mut tlas_storage = Vec::<TlasStore>::new();
499        let mut instance_buffer_staging_source = Vec::<u8>::new();
500
501        for (package, tlas) in &mut tlas_lock_store {
502            let package = package.take().unwrap();
503
504            let scratch_buffer_offset = scratch_buffer_tlas_size;
505            scratch_buffer_tlas_size += align_to(
506                tlas.size_info.build_scratch_size as u32,
507                device.alignments.ray_tracing_scratch_buffer_alignment,
508            ) as u64;
509
510            let first_byte_index = instance_buffer_staging_source.len();
511
512            let mut dependencies = Vec::new();
513
514            let mut instance_count = 0;
515            for instance in package.instances.flatten() {
516                if instance.custom_index >= (1u32 << 24u32) {
517                    return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
518                        tlas.error_ident(),
519                    ));
520                }
521                let blas = hub.blas_s.get(instance.blas_id).get()?;
522
523                cmd_buf_data.trackers.blas_s.insert_single(blas.clone());
524
525                instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes(
526                    hal::TlasInstance {
527                        transform: *instance.transform,
528                        custom_index: instance.custom_index,
529                        mask: instance.mask,
530                        blas_address: blas.handle,
531                    },
532                ));
533
534                instance_count += 1;
535
536                dependencies.push(blas.clone());
537
538                cmd_buf_data.blas_actions.push(BlasAction {
539                    blas,
540                    kind: crate::ray_tracing::BlasActionKind::Use,
541                });
542            }
543
544            cmd_buf_data.tlas_actions.push(TlasAction {
545                tlas: tlas.clone(),
546                kind: crate::ray_tracing::TlasActionKind::Build {
547                    build_index: build_command_index,
548                    dependencies,
549                },
550            });
551
552            if instance_count > tlas.max_instance_count {
553                return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
554                    tlas.error_ident(),
555                    instance_count,
556                    tlas.max_instance_count,
557                ));
558            }
559
560            tlas_storage.push(TlasStore {
561                internal: UnsafeTlasStore {
562                    tlas: tlas.clone(),
563                    entries: hal::AccelerationStructureEntries::Instances(
564                        hal::AccelerationStructureInstances {
565                            buffer: Some(tlas.instance_buffer.as_ref()),
566                            offset: 0,
567                            count: instance_count,
568                        },
569                    ),
570                    scratch_buffer_offset,
571                },
572                range: first_byte_index..instance_buffer_staging_source.len(),
573            });
574        }
575
576        let scratch_size =
577            match wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size)) {
578                // if the size is zero there is nothing to build
579                None => {
580                    cmd_buf_data_guard.mark_successful();
581                    return Ok(());
582                }
583                Some(size) => size,
584            };
585
586        let scratch_buffer =
587            ScratchBuffer::new(device, scratch_size).map_err(crate::device::DeviceError::from)?;
588
589        let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
590            buffer: scratch_buffer.raw(),
591            usage: hal::StateTransition {
592                from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
593                to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
594            },
595        };
596
597        let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());
598
599        for &TlasStore {
600            internal:
601                UnsafeTlasStore {
602                    ref tlas,
603                    ref entries,
604                    ref scratch_buffer_offset,
605                },
606            ..
607        } in &tlas_storage
608        {
609            if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
610                log::info!("only rebuild implemented")
611            }
612            tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
613                entries,
614                mode: hal::AccelerationStructureBuildMode::Build,
615                flags: tlas.flags,
616                source_acceleration_structure: None,
617                destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
618                scratch_buffer: scratch_buffer.raw(),
619                scratch_buffer_offset: *scratch_buffer_offset,
620            })
621        }
622
623        let blas_present = !blas_storage.is_empty();
624        let tlas_present = !tlas_storage.is_empty();
625
626        let cmd_buf_raw = cmd_buf_data.encoder.open()?;
627
628        let mut descriptors = Vec::new();
629
630        for storage in &blas_storage {
631            descriptors.push(map_blas(storage, scratch_buffer.raw(), &snatch_guard)?);
632        }
633
634        build_blas(
635            cmd_buf_raw,
636            blas_present,
637            tlas_present,
638            input_barriers,
639            &descriptors,
640            scratch_buffer_barrier,
641        );
642
643        if tlas_present {
644            let staging_buffer = if !instance_buffer_staging_source.is_empty() {
645                let mut staging_buffer = StagingBuffer::new(
646                    device,
647                    wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
648                )
649                .map_err(crate::device::DeviceError::from)?;
650                staging_buffer.write(&instance_buffer_staging_source);
651                let flushed = staging_buffer.flush();
652                Some(flushed)
653            } else {
654                None
655            };
656
657            unsafe {
658                if let Some(ref staging_buffer) = staging_buffer {
659                    cmd_buf_raw.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
660                        buffer: staging_buffer.raw(),
661                        usage: hal::StateTransition {
662                            from: BufferUses::MAP_WRITE,
663                            to: BufferUses::COPY_SRC,
664                        },
665                    }]);
666                }
667            }
668
669            let mut instance_buffer_barriers = Vec::new();
670            for &TlasStore {
671                internal: UnsafeTlasStore { ref tlas, .. },
672                ref range,
673            } in &tlas_storage
674            {
675                let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
676                    None => continue,
677                    Some(size) => size,
678                };
679                instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
680                    buffer: tlas.instance_buffer.as_ref(),
681                    usage: hal::StateTransition {
682                        from: BufferUses::COPY_DST,
683                        to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
684                    },
685                });
686                unsafe {
687                    cmd_buf_raw.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
688                        buffer: tlas.instance_buffer.as_ref(),
689                        usage: hal::StateTransition {
690                            from: BufferUses::MAP_READ,
691                            to: BufferUses::COPY_DST,
692                        },
693                    }]);
694                    let temp = hal::BufferCopy {
695                        src_offset: range.start as u64,
696                        dst_offset: 0,
697                        size,
698                    };
699                    cmd_buf_raw.copy_buffer_to_buffer(
700                        // the range whose size we just checked end is at (at that point in time) instance_buffer_staging_source.len()
701                        // and since instance_buffer_staging_source doesn't shrink we can un wrap this without a panic
702                        staging_buffer.as_ref().unwrap().raw(),
703                        tlas.instance_buffer.as_ref(),
704                        &[temp],
705                    );
706                }
707            }
708
709            unsafe {
710                cmd_buf_raw.transition_buffers(&instance_buffer_barriers);
711
712                cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);
713
714                cmd_buf_raw.place_acceleration_structure_barrier(
715                    hal::AccelerationStructureBarrier {
716                        usage: hal::StateTransition {
717                            from: hal::AccelerationStructureUses::BUILD_OUTPUT,
718                            to: hal::AccelerationStructureUses::SHADER_INPUT,
719                        },
720                    },
721                );
722            }
723
724            if let Some(staging_buffer) = staging_buffer {
725                cmd_buf_data
726                    .temp_resources
727                    .push(TempResource::StagingBuffer(staging_buffer));
728            }
729        }
730
731        cmd_buf_data
732            .temp_resources
733            .push(TempResource::ScratchBuffer(scratch_buffer));
734
735        cmd_buf_data_guard.mark_successful();
736        Ok(())
737    }
738}
739
740impl CommandBufferMutable {
741    // makes sure a blas is build before it is used
742    pub(crate) fn validate_blas_actions(&self) -> Result<(), ValidateBlasActionsError> {
743        profiling::scope!("CommandEncoder::[submission]::validate_blas_actions");
744        let mut built = FastHashSet::default();
745        for action in &self.blas_actions {
746            match &action.kind {
747                crate::ray_tracing::BlasActionKind::Build(id) => {
748                    built.insert(action.blas.tracker_index());
749                    *action.blas.built_index.write() = Some(*id);
750                }
751                crate::ray_tracing::BlasActionKind::Use => {
752                    if !built.contains(&action.blas.tracker_index())
753                        && (*action.blas.built_index.read()).is_none()
754                    {
755                        return Err(ValidateBlasActionsError::UsedUnbuilt(
756                            action.blas.error_ident(),
757                        ));
758                    }
759                }
760            }
761        }
762        Ok(())
763    }
764
765    // makes sure a tlas is built before it is used
766    pub(crate) fn validate_tlas_actions(
767        &self,
768        snatch_guard: &SnatchGuard,
769    ) -> Result<(), ValidateTlasActionsError> {
770        profiling::scope!("CommandEncoder::[submission]::validate_tlas_actions");
771        for action in &self.tlas_actions {
772            match &action.kind {
773                crate::ray_tracing::TlasActionKind::Build {
774                    build_index,
775                    dependencies,
776                } => {
777                    *action.tlas.built_index.write() = Some(*build_index);
778                    action.tlas.dependencies.write().clone_from(dependencies);
779                }
780                crate::ray_tracing::TlasActionKind::Use => {
781                    let tlas_build_index = action.tlas.built_index.read();
782                    let dependencies = action.tlas.dependencies.read();
783
784                    if (*tlas_build_index).is_none() {
785                        return Err(ValidateTlasActionsError::UsedUnbuilt(
786                            action.tlas.error_ident(),
787                        ));
788                    }
789                    for blas in dependencies.deref() {
790                        let blas_build_index = *blas.built_index.read();
791                        if blas_build_index.is_none() {
792                            return Err(ValidateTlasActionsError::UsedUnbuiltBlas(
793                                action.tlas.error_ident(),
794                                blas.error_ident(),
795                            ));
796                        }
797                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
798                            return Err(ValidateTlasActionsError::BlasNewerThenTlas(
799                                blas.error_ident(),
800                                action.tlas.error_ident(),
801                            ));
802                        }
803                        blas.try_raw(snatch_guard)?;
804                    }
805                }
806            }
807        }
808        Ok(())
809    }
810}
811
812///iterates over the blas iterator, and it's geometry, pushing the buffers into a storage vector (and also some validation).
813fn iter_blas<'a>(
814    blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
815    cmd_buf_data: &mut CommandBufferMutable,
816    build_command_index: NonZeroU64,
817    buf_storage: &mut Vec<TriangleBufferStore<'a>>,
818    hub: &Hub,
819) -> Result<(), BuildAccelerationStructureError> {
820    let mut temp_buffer = Vec::new();
821    for entry in blas_iter {
822        let blas = hub.blas_s.get(entry.blas_id).get()?;
823        cmd_buf_data.trackers.blas_s.insert_single(blas.clone());
824
825        cmd_buf_data.blas_actions.push(BlasAction {
826            blas: blas.clone(),
827            kind: crate::ray_tracing::BlasActionKind::Build(build_command_index),
828        });
829
830        match entry.geometries {
831            BlasGeometries::TriangleGeometries(triangle_geometries) => {
832                for (i, mesh) in triangle_geometries.enumerate() {
833                    let size_desc = match &blas.sizes {
834                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
835                    };
836                    if i >= size_desc.len() {
837                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
838                            blas.error_ident(),
839                        ));
840                    }
841                    let size_desc = &size_desc[i];
842
843                    if size_desc.flags != mesh.size.flags {
844                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
845                            blas.error_ident(),
846                            size_desc.flags,
847                            mesh.size.flags,
848                        ));
849                    }
850
851                    if size_desc.vertex_count < mesh.size.vertex_count {
852                        return Err(
853                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
854                                blas.error_ident(),
855                                size_desc.vertex_count,
856                                mesh.size.vertex_count,
857                            ),
858                        );
859                    }
860
861                    if size_desc.vertex_format != mesh.size.vertex_format {
862                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
863                            blas.error_ident(),
864                            size_desc.vertex_format,
865                            mesh.size.vertex_format,
866                        ));
867                    }
868
869                    match (size_desc.index_count, mesh.size.index_count) {
870                        (Some(_), None) | (None, Some(_)) => {
871                            return Err(
872                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
873                                    blas.error_ident(),
874                                ),
875                            )
876                        }
877                        (Some(create), Some(build)) if create < build => {
878                            return Err(
879                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
880                                    blas.error_ident(),
881                                    create,
882                                    build,
883                                ),
884                            )
885                        }
886                        _ => {}
887                    }
888
889                    if size_desc.index_format != mesh.size.index_format {
890                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
891                            blas.error_ident(),
892                            size_desc.index_format,
893                            mesh.size.index_format,
894                        ));
895                    }
896
897                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
898                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
899                            blas.error_ident(),
900                        ));
901                    }
902                    let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
903                    let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
904                        &vertex_buffer,
905                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
906                    );
907                    let index_data = if let Some(index_id) = mesh.index_buffer {
908                        let index_buffer = hub.buffers.get(index_id).get()?;
909                        if mesh.first_index.is_none()
910                            || mesh.size.index_count.is_none()
911                            || mesh.size.index_count.is_none()
912                        {
913                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
914                                index_buffer.error_ident(),
915                            ));
916                        }
917                        let data = cmd_buf_data.trackers.buffers.set_single(
918                            &index_buffer,
919                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
920                        );
921                        Some((index_buffer, data))
922                    } else {
923                        None
924                    };
925                    let transform_data = if let Some(transform_id) = mesh.transform_buffer {
926                        let transform_buffer = hub.buffers.get(transform_id).get()?;
927                        if mesh.transform_buffer_offset.is_none() {
928                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
929                                transform_buffer.error_ident(),
930                            ));
931                        }
932                        let data = cmd_buf_data.trackers.buffers.set_single(
933                            &transform_buffer,
934                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
935                        );
936                        Some((transform_buffer, data))
937                    } else {
938                        None
939                    };
940                    temp_buffer.push(TriangleBufferStore {
941                        vertex_buffer,
942                        vertex_transition: vertex_pending,
943                        index_buffer_transition: index_data,
944                        transform_buffer_transition: transform_data,
945                        geometry: mesh,
946                        ending_blas: None,
947                    });
948                }
949
950                if let Some(last) = temp_buffer.last_mut() {
951                    last.ending_blas = Some(blas);
952                    buf_storage.append(&mut temp_buffer);
953                }
954            }
955        }
956    }
957    Ok(())
958}
959
960/// Iterates over the buffers generated in [iter_blas], convert the barriers into hal barriers, and the triangles into [hal::AccelerationStructureEntries] (and also some validation).
961fn iter_buffers<'a, 'b>(
962    buf_storage: &'a mut Vec<TriangleBufferStore<'b>>,
963    snatch_guard: &'a SnatchGuard,
964    input_barriers: &mut Vec<hal::BufferBarrier<'a, dyn hal::DynBuffer>>,
965    cmd_buf_data: &mut CommandBufferMutable,
966    scratch_buffer_blas_size: &mut u64,
967    blas_storage: &mut Vec<BlasStore<'a>>,
968    hub: &Hub,
969    ray_tracing_scratch_buffer_alignment: u32,
970) -> Result<(), BuildAccelerationStructureError> {
971    let mut triangle_entries =
972        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
973    for buf in buf_storage {
974        let mesh = &buf.geometry;
975        let vertex_buffer = {
976            let vertex_buffer = buf.vertex_buffer.as_ref();
977            let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
978            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
979
980            if let Some(barrier) = buf
981                .vertex_transition
982                .take()
983                .map(|pending| pending.into_hal(vertex_buffer, snatch_guard))
984            {
985                input_barriers.push(barrier);
986            }
987            if vertex_buffer.size
988                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
989            {
990                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
991                    vertex_buffer.error_ident(),
992                    vertex_buffer.size,
993                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
994                ));
995            }
996            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
997            cmd_buf_data.buffer_memory_init_actions.extend(
998                vertex_buffer.initialization_status.read().create_action(
999                    &hub.buffers.get(mesh.vertex_buffer).get()?,
1000                    vertex_buffer_offset
1001                        ..(vertex_buffer_offset
1002                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
1003                    MemoryInitKind::NeedsInitializedMemory,
1004                ),
1005            );
1006            vertex_raw
1007        };
1008        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
1009            buf.index_buffer_transition
1010        {
1011            let index_raw = index_buffer.try_raw(snatch_guard)?;
1012            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1013
1014            if let Some(barrier) = index_pending
1015                .take()
1016                .map(|pending| pending.into_hal(index_buffer, snatch_guard))
1017            {
1018                input_barriers.push(barrier);
1019            }
1020            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
1021            let offset = mesh.first_index.unwrap() as u64 * index_stride;
1022            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;
1023
1024            if mesh.size.index_count.unwrap() % 3 != 0 {
1025                return Err(BuildAccelerationStructureError::InvalidIndexCount(
1026                    index_buffer.error_ident(),
1027                    mesh.size.index_count.unwrap(),
1028                ));
1029            }
1030            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
1031                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
1032                    index_buffer.error_ident(),
1033                    index_buffer.size,
1034                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
1035                ));
1036            }
1037
1038            cmd_buf_data.buffer_memory_init_actions.extend(
1039                index_buffer.initialization_status.read().create_action(
1040                    index_buffer,
1041                    offset..(offset + index_buffer_size),
1042                    MemoryInitKind::NeedsInitializedMemory,
1043                ),
1044            );
1045            Some(index_raw)
1046        } else {
1047            None
1048        };
1049        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
1050            buf.transform_buffer_transition
1051        {
1052            if mesh.transform_buffer_offset.is_none() {
1053                return Err(BuildAccelerationStructureError::MissingAssociatedData(
1054                    transform_buffer.error_ident(),
1055                ));
1056            }
1057            let transform_raw = transform_buffer.try_raw(snatch_guard)?;
1058            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1059
1060            if let Some(barrier) = transform_pending
1061                .take()
1062                .map(|pending| pending.into_hal(transform_buffer, snatch_guard))
1063            {
1064                input_barriers.push(barrier);
1065            }
1066
1067            let offset = mesh.transform_buffer_offset.unwrap();
1068
1069            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
1070                return Err(
1071                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
1072                        transform_buffer.error_ident(),
1073                    ),
1074                );
1075            }
1076            if transform_buffer.size < 48 + offset {
1077                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
1078                    transform_buffer.error_ident(),
1079                    transform_buffer.size,
1080                    48 + offset,
1081                ));
1082            }
1083            cmd_buf_data.buffer_memory_init_actions.extend(
1084                transform_buffer.initialization_status.read().create_action(
1085                    transform_buffer,
1086                    offset..(offset + 48),
1087                    MemoryInitKind::NeedsInitializedMemory,
1088                ),
1089            );
1090            Some(transform_raw)
1091        } else {
1092            None
1093        };
1094
1095        let triangles = hal::AccelerationStructureTriangles {
1096            vertex_buffer: Some(vertex_buffer),
1097            vertex_format: mesh.size.vertex_format,
1098            first_vertex: mesh.first_vertex,
1099            vertex_count: mesh.size.vertex_count,
1100            vertex_stride: mesh.vertex_stride,
1101            indices: index_buffer.map(|index_buffer| {
1102                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
1103                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
1104                    format: mesh.size.index_format.unwrap(),
1105                    buffer: Some(index_buffer),
1106                    offset: mesh.first_index.unwrap() * index_stride,
1107                    count: mesh.size.index_count.unwrap(),
1108                }
1109            }),
1110            transform: transform_buffer.map(|transform_buffer| {
1111                hal::AccelerationStructureTriangleTransform {
1112                    buffer: transform_buffer,
1113                    offset: mesh.transform_buffer_offset.unwrap() as u32,
1114                }
1115            }),
1116            flags: mesh.size.flags,
1117        };
1118        triangle_entries.push(triangles);
1119        if let Some(blas) = buf.ending_blas.take() {
1120            let scratch_buffer_offset = *scratch_buffer_blas_size;
1121            *scratch_buffer_blas_size += align_to(
1122                blas.size_info.build_scratch_size as u32,
1123                ray_tracing_scratch_buffer_alignment,
1124            ) as u64;
1125
1126            blas_storage.push(BlasStore {
1127                blas,
1128                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
1129                scratch_buffer_offset,
1130            });
1131            triangle_entries = Vec::new();
1132        }
1133    }
1134    Ok(())
1135}
1136
1137fn map_blas<'a>(
1138    storage: &'a BlasStore<'_>,
1139    scratch_buffer: &'a dyn hal::DynBuffer,
1140    snatch_guard: &'a SnatchGuard,
1141) -> Result<
1142    hal::BuildAccelerationStructureDescriptor<
1143        'a,
1144        dyn hal::DynBuffer,
1145        dyn hal::DynAccelerationStructure,
1146    >,
1147    BuildAccelerationStructureError,
1148> {
1149    let BlasStore {
1150        blas,
1151        entries,
1152        scratch_buffer_offset,
1153    } = storage;
1154    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
1155        log::info!("only rebuild implemented")
1156    }
1157    Ok(hal::BuildAccelerationStructureDescriptor {
1158        entries,
1159        mode: hal::AccelerationStructureBuildMode::Build,
1160        flags: blas.flags,
1161        source_acceleration_structure: None,
1162        destination_acceleration_structure: blas.try_raw(snatch_guard)?,
1163        scratch_buffer,
1164        scratch_buffer_offset: *scratch_buffer_offset,
1165    })
1166}
1167
1168fn build_blas<'a>(
1169    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
1170    blas_present: bool,
1171    tlas_present: bool,
1172    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
1173    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
1174        'a,
1175        dyn hal::DynBuffer,
1176        dyn hal::DynAccelerationStructure,
1177    >],
1178    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
1179) {
1180    unsafe {
1181        cmd_buf_raw.transition_buffers(&input_barriers);
1182    }
1183
1184    if blas_present {
1185        unsafe {
1186            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1187                usage: hal::StateTransition {
1188                    from: hal::AccelerationStructureUses::BUILD_INPUT,
1189                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
1190                },
1191            });
1192
1193            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
1194        }
1195    }
1196
1197    if blas_present && tlas_present {
1198        unsafe {
1199            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
1200        }
1201    }
1202
1203    let mut source_usage = hal::AccelerationStructureUses::empty();
1204    let mut destination_usage = hal::AccelerationStructureUses::empty();
1205    if blas_present {
1206        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1207        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
1208    }
1209    if tlas_present {
1210        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
1211        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1212    }
1213    unsafe {
1214        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1215            usage: hal::StateTransition {
1216                from: source_usage,
1217                to: destination_usage,
1218            },
1219        });
1220    }
1221}