wgpu_core/device/
ray_tracing.rs

1use std::mem::ManuallyDrop;
2use std::sync::Arc;
3
4use crate::api_log;
5#[cfg(feature = "trace")]
6use crate::device::trace;
7use crate::lock::rank;
8use crate::resource::{Fallible, TrackingData};
9use crate::snatch::Snatchable;
10use crate::{
11    device::{Device, DeviceError},
12    global::Global,
13    id::{self, BlasId, TlasId},
14    lock::RwLock,
15    ray_tracing::{CreateBlasError, CreateTlasError},
16    resource, LabelHelpers,
17};
18use hal::AccelerationStructureTriangleIndices;
19use wgt::Features;
20
21impl Device {
22    fn create_blas(
23        self: &Arc<Self>,
24        blas_desc: &resource::BlasDescriptor,
25        sizes: wgt::BlasGeometrySizeDescriptors,
26    ) -> Result<Arc<resource::Blas>, CreateBlasError> {
27        self.check_is_valid()?;
28        self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
29
30        let size_info = match &sizes {
31            wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
32                let mut entries =
33                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
34                        descriptors.len(),
35                    );
36                for desc in descriptors {
37                    if desc.index_count.is_some() != desc.index_format.is_some() {
38                        return Err(CreateBlasError::MissingIndexData);
39                    }
40                    let indices =
41                        desc.index_count
42                            .map(|count| AccelerationStructureTriangleIndices::<
43                                dyn hal::DynBuffer,
44                            > {
45                                format: desc.index_format.unwrap(),
46                                buffer: None,
47                                offset: 0,
48                                count,
49                            });
50                    if !self
51                        .features
52                        .allowed_vertex_formats_for_blas()
53                        .contains(&desc.vertex_format)
54                    {
55                        return Err(CreateBlasError::InvalidVertexFormat(
56                            desc.vertex_format,
57                            self.features.allowed_vertex_formats_for_blas(),
58                        ));
59                    }
60                    entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
61                        vertex_buffer: None,
62                        vertex_format: desc.vertex_format,
63                        first_vertex: 0,
64                        vertex_count: desc.vertex_count,
65                        vertex_stride: 0,
66                        indices,
67                        transform: None,
68                        flags: desc.flags,
69                    });
70                }
71                unsafe {
72                    self.raw().get_acceleration_structure_build_sizes(
73                        &hal::GetAccelerationStructureBuildSizesDescriptor {
74                            entries: &hal::AccelerationStructureEntries::Triangles(entries),
75                            flags: blas_desc.flags,
76                        },
77                    )
78                }
79            }
80        };
81
82        let raw = unsafe {
83            self.raw()
84                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
85                    label: blas_desc.label.as_deref(),
86                    size: size_info.acceleration_structure_size,
87                    format: hal::AccelerationStructureFormat::BottomLevel,
88                })
89        }
90        .map_err(DeviceError::from_hal)?;
91
92        let handle = unsafe {
93            self.raw()
94                .get_acceleration_structure_device_address(raw.as_ref())
95        };
96
97        Ok(Arc::new(resource::Blas {
98            raw: Snatchable::new(raw),
99            device: self.clone(),
100            size_info,
101            sizes,
102            flags: blas_desc.flags,
103            update_mode: blas_desc.update_mode,
104            handle,
105            label: blas_desc.label.to_string(),
106            built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
107            tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
108        }))
109    }
110
111    fn create_tlas(
112        self: &Arc<Self>,
113        desc: &resource::TlasDescriptor,
114    ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
115        self.check_is_valid()?;
116        self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
117
118        let size_info = unsafe {
119            self.raw().get_acceleration_structure_build_sizes(
120                &hal::GetAccelerationStructureBuildSizesDescriptor {
121                    entries: &hal::AccelerationStructureEntries::Instances(
122                        hal::AccelerationStructureInstances {
123                            buffer: None,
124                            offset: 0,
125                            count: desc.max_instances,
126                        },
127                    ),
128                    flags: desc.flags,
129                },
130            )
131        };
132
133        let raw = unsafe {
134            self.raw()
135                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
136                    label: desc.label.as_deref(),
137                    size: size_info.acceleration_structure_size,
138                    format: hal::AccelerationStructureFormat::TopLevel,
139                })
140        }
141        .map_err(DeviceError::from_hal)?;
142
143        let instance_buffer_size =
144            self.alignments.raw_tlas_instance_size * desc.max_instances.max(1) as usize;
145        let instance_buffer = unsafe {
146            self.raw().create_buffer(&hal::BufferDescriptor {
147                label: Some("(wgpu-core) instances_buffer"),
148                size: instance_buffer_size as u64,
149                usage: hal::BufferUses::COPY_DST
150                    | hal::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
151                memory_flags: hal::MemoryFlags::PREFER_COHERENT,
152            })
153        }
154        .map_err(DeviceError::from_hal)?;
155
156        Ok(Arc::new(resource::Tlas {
157            raw: Snatchable::new(raw),
158            device: self.clone(),
159            size_info,
160            flags: desc.flags,
161            update_mode: desc.update_mode,
162            built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
163            dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
164            instance_buffer: ManuallyDrop::new(instance_buffer),
165            label: desc.label.to_string(),
166            max_instance_count: desc.max_instances,
167            tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
168        }))
169    }
170}
171
172impl Global {
173    pub fn device_create_blas(
174        &self,
175        device_id: id::DeviceId,
176        desc: &resource::BlasDescriptor,
177        sizes: wgt::BlasGeometrySizeDescriptors,
178        id_in: Option<BlasId>,
179    ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
180        profiling::scope!("Device::create_blas");
181
182        let fid = self.hub.blas_s.prepare(id_in);
183
184        let error = 'error: {
185            let device = self.hub.devices.get(device_id);
186
187            #[cfg(feature = "trace")]
188            if let Some(trace) = device.trace.lock().as_mut() {
189                trace.add(trace::Action::CreateBlas {
190                    id: fid.id(),
191                    desc: desc.clone(),
192                    sizes: sizes.clone(),
193                });
194            }
195
196            let blas = match device.create_blas(desc, sizes) {
197                Ok(blas) => blas,
198                Err(e) => break 'error e,
199            };
200            let handle = blas.handle;
201
202            let id = fid.assign(Fallible::Valid(blas));
203            api_log!("Device::create_blas -> {id:?}");
204
205            return (id, Some(handle), None);
206        };
207
208        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
209        (id, None, Some(error))
210    }
211
212    pub fn device_create_tlas(
213        &self,
214        device_id: id::DeviceId,
215        desc: &resource::TlasDescriptor,
216        id_in: Option<TlasId>,
217    ) -> (TlasId, Option<CreateTlasError>) {
218        profiling::scope!("Device::create_tlas");
219
220        let fid = self.hub.tlas_s.prepare(id_in);
221
222        let error = 'error: {
223            let device = self.hub.devices.get(device_id);
224
225            #[cfg(feature = "trace")]
226            if let Some(trace) = device.trace.lock().as_mut() {
227                trace.add(trace::Action::CreateTlas {
228                    id: fid.id(),
229                    desc: desc.clone(),
230                });
231            }
232
233            let tlas = match device.create_tlas(desc) {
234                Ok(tlas) => tlas,
235                Err(e) => break 'error e,
236            };
237
238            let id = fid.assign(Fallible::Valid(tlas));
239            api_log!("Device::create_tlas -> {id:?}");
240
241            return (id, None);
242        };
243
244        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
245        (id, Some(error))
246    }
247
248    pub fn blas_drop(&self, blas_id: BlasId) {
249        profiling::scope!("Blas::drop");
250        api_log!("Blas::drop {blas_id:?}");
251
252        let _blas = self.hub.blas_s.remove(blas_id);
253
254        #[cfg(feature = "trace")]
255        if let Ok(blas) = _blas.get() {
256            if let Some(t) = blas.device.trace.lock().as_mut() {
257                t.add(trace::Action::DestroyBlas(blas_id));
258            }
259        }
260    }
261
262    pub fn tlas_drop(&self, tlas_id: TlasId) {
263        profiling::scope!("Tlas::drop");
264        api_log!("Tlas::drop {tlas_id:?}");
265
266        let _tlas = self.hub.tlas_s.remove(tlas_id);
267
268        #[cfg(feature = "trace")]
269        if let Ok(tlas) = _tlas.get() {
270            if let Some(t) = tlas.device.trace.lock().as_mut() {
271                t.add(trace::Action::DestroyTlas(tlas_id));
272            }
273        }
274    }
275}