1use std::mem::ManuallyDrop;
2use std::sync::Arc;
3
4use crate::api_log;
5#[cfg(feature = "trace")]
6use crate::device::trace;
7use crate::lock::rank;
8use crate::resource::{Fallible, TrackingData};
9use crate::snatch::Snatchable;
10use crate::{
11 device::{Device, DeviceError},
12 global::Global,
13 id::{self, BlasId, TlasId},
14 lock::RwLock,
15 ray_tracing::{CreateBlasError, CreateTlasError},
16 resource, LabelHelpers,
17};
18use hal::AccelerationStructureTriangleIndices;
19use wgt::Features;
20
21impl Device {
22 fn create_blas(
23 self: &Arc<Self>,
24 blas_desc: &resource::BlasDescriptor,
25 sizes: wgt::BlasGeometrySizeDescriptors,
26 ) -> Result<Arc<resource::Blas>, CreateBlasError> {
27 self.check_is_valid()?;
28 self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
29
30 let size_info = match &sizes {
31 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
32 let mut entries =
33 Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
34 descriptors.len(),
35 );
36 for desc in descriptors {
37 if desc.index_count.is_some() != desc.index_format.is_some() {
38 return Err(CreateBlasError::MissingIndexData);
39 }
40 let indices =
41 desc.index_count
42 .map(|count| AccelerationStructureTriangleIndices::<
43 dyn hal::DynBuffer,
44 > {
45 format: desc.index_format.unwrap(),
46 buffer: None,
47 offset: 0,
48 count,
49 });
50 if !self
51 .features
52 .allowed_vertex_formats_for_blas()
53 .contains(&desc.vertex_format)
54 {
55 return Err(CreateBlasError::InvalidVertexFormat(
56 desc.vertex_format,
57 self.features.allowed_vertex_formats_for_blas(),
58 ));
59 }
60 entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
61 vertex_buffer: None,
62 vertex_format: desc.vertex_format,
63 first_vertex: 0,
64 vertex_count: desc.vertex_count,
65 vertex_stride: 0,
66 indices,
67 transform: None,
68 flags: desc.flags,
69 });
70 }
71 unsafe {
72 self.raw().get_acceleration_structure_build_sizes(
73 &hal::GetAccelerationStructureBuildSizesDescriptor {
74 entries: &hal::AccelerationStructureEntries::Triangles(entries),
75 flags: blas_desc.flags,
76 },
77 )
78 }
79 }
80 };
81
82 let raw = unsafe {
83 self.raw()
84 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
85 label: blas_desc.label.as_deref(),
86 size: size_info.acceleration_structure_size,
87 format: hal::AccelerationStructureFormat::BottomLevel,
88 })
89 }
90 .map_err(DeviceError::from_hal)?;
91
92 let handle = unsafe {
93 self.raw()
94 .get_acceleration_structure_device_address(raw.as_ref())
95 };
96
97 Ok(Arc::new(resource::Blas {
98 raw: Snatchable::new(raw),
99 device: self.clone(),
100 size_info,
101 sizes,
102 flags: blas_desc.flags,
103 update_mode: blas_desc.update_mode,
104 handle,
105 label: blas_desc.label.to_string(),
106 built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
107 tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
108 }))
109 }
110
111 fn create_tlas(
112 self: &Arc<Self>,
113 desc: &resource::TlasDescriptor,
114 ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
115 self.check_is_valid()?;
116 self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
117
118 let size_info = unsafe {
119 self.raw().get_acceleration_structure_build_sizes(
120 &hal::GetAccelerationStructureBuildSizesDescriptor {
121 entries: &hal::AccelerationStructureEntries::Instances(
122 hal::AccelerationStructureInstances {
123 buffer: None,
124 offset: 0,
125 count: desc.max_instances,
126 },
127 ),
128 flags: desc.flags,
129 },
130 )
131 };
132
133 let raw = unsafe {
134 self.raw()
135 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
136 label: desc.label.as_deref(),
137 size: size_info.acceleration_structure_size,
138 format: hal::AccelerationStructureFormat::TopLevel,
139 })
140 }
141 .map_err(DeviceError::from_hal)?;
142
143 let instance_buffer_size =
144 self.alignments.raw_tlas_instance_size * desc.max_instances.max(1) as usize;
145 let instance_buffer = unsafe {
146 self.raw().create_buffer(&hal::BufferDescriptor {
147 label: Some("(wgpu-core) instances_buffer"),
148 size: instance_buffer_size as u64,
149 usage: hal::BufferUses::COPY_DST
150 | hal::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
151 memory_flags: hal::MemoryFlags::PREFER_COHERENT,
152 })
153 }
154 .map_err(DeviceError::from_hal)?;
155
156 Ok(Arc::new(resource::Tlas {
157 raw: Snatchable::new(raw),
158 device: self.clone(),
159 size_info,
160 flags: desc.flags,
161 update_mode: desc.update_mode,
162 built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
163 dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
164 instance_buffer: ManuallyDrop::new(instance_buffer),
165 label: desc.label.to_string(),
166 max_instance_count: desc.max_instances,
167 tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
168 }))
169 }
170}
171
172impl Global {
173 pub fn device_create_blas(
174 &self,
175 device_id: id::DeviceId,
176 desc: &resource::BlasDescriptor,
177 sizes: wgt::BlasGeometrySizeDescriptors,
178 id_in: Option<BlasId>,
179 ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
180 profiling::scope!("Device::create_blas");
181
182 let fid = self.hub.blas_s.prepare(id_in);
183
184 let error = 'error: {
185 let device = self.hub.devices.get(device_id);
186
187 #[cfg(feature = "trace")]
188 if let Some(trace) = device.trace.lock().as_mut() {
189 trace.add(trace::Action::CreateBlas {
190 id: fid.id(),
191 desc: desc.clone(),
192 sizes: sizes.clone(),
193 });
194 }
195
196 let blas = match device.create_blas(desc, sizes) {
197 Ok(blas) => blas,
198 Err(e) => break 'error e,
199 };
200 let handle = blas.handle;
201
202 let id = fid.assign(Fallible::Valid(blas));
203 api_log!("Device::create_blas -> {id:?}");
204
205 return (id, Some(handle), None);
206 };
207
208 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
209 (id, None, Some(error))
210 }
211
212 pub fn device_create_tlas(
213 &self,
214 device_id: id::DeviceId,
215 desc: &resource::TlasDescriptor,
216 id_in: Option<TlasId>,
217 ) -> (TlasId, Option<CreateTlasError>) {
218 profiling::scope!("Device::create_tlas");
219
220 let fid = self.hub.tlas_s.prepare(id_in);
221
222 let error = 'error: {
223 let device = self.hub.devices.get(device_id);
224
225 #[cfg(feature = "trace")]
226 if let Some(trace) = device.trace.lock().as_mut() {
227 trace.add(trace::Action::CreateTlas {
228 id: fid.id(),
229 desc: desc.clone(),
230 });
231 }
232
233 let tlas = match device.create_tlas(desc) {
234 Ok(tlas) => tlas,
235 Err(e) => break 'error e,
236 };
237
238 let id = fid.assign(Fallible::Valid(tlas));
239 api_log!("Device::create_tlas -> {id:?}");
240
241 return (id, None);
242 };
243
244 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
245 (id, Some(error))
246 }
247
248 pub fn blas_drop(&self, blas_id: BlasId) {
249 profiling::scope!("Blas::drop");
250 api_log!("Blas::drop {blas_id:?}");
251
252 let _blas = self.hub.blas_s.remove(blas_id);
253
254 #[cfg(feature = "trace")]
255 if let Ok(blas) = _blas.get() {
256 if let Some(t) = blas.device.trace.lock().as_mut() {
257 t.add(trace::Action::DestroyBlas(blas_id));
258 }
259 }
260 }
261
262 pub fn tlas_drop(&self, tlas_id: TlasId) {
263 profiling::scope!("Tlas::drop");
264 api_log!("Tlas::drop {tlas_id:?}");
265
266 let _tlas = self.hub.tlas_s.remove(tlas_id);
267
268 #[cfg(feature = "trace")]
269 if let Ok(tlas) = _tlas.get() {
270 if let Some(t) = tlas.device.trace.lock().as_mut() {
271 t.add(trace::Action::DestroyTlas(tlas_id));
272 }
273 }
274 }
275}