use crate::{
    binding_model::{
        BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
    },
    command::{
        bind::Binder,
        compute_command::ArcComputeCommand,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites, BasePass,
        BindGroupStateChange, CommandBuffer, CommandEncoderError, MapPassErr, PassErrorScope,
        PassTimestampWrites, QueryUseError, StateChange,
    },
    device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::{BufferInitTrackerAction, MemoryInitKind},
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
        MissingBufferUsageError, ParentDevice,
    },
    snatch::SnatchGuard,
    track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope},
    Label,
};

use thiserror::Error;
use wgt::{BufferAddress, DynamicOffset};

use super::{bind::BinderError, memory_init::CommandBufferTextureMemoryActions};
use crate::ray_tracing::TlasAction;
use std::{fmt, mem::size_of, str, sync::Arc};

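/// A compute pass in the process of being recorded.
///
/// Commands are accumulated in `base` and only resolved against the parent
/// command buffer when the pass is ended.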
pub struct ComputePass {
    /// All pass data and records are stored here.
    ///
    /// If this is `None`, the pass has ended and can no longer record commands.
    base: Option<BasePass<ArcComputeCommand>>,

    /// Parent command buffer that this pass records commands into.
    ///
    /// If this is `None`, the pass is invalid and any operation on it will return an error.
    parent: Option<Arc<CommandBuffer>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    /// If the parent command buffer is invalid, the returned pass will be invalid.
    fn new(parent: Option<Arc<CommandBuffer>>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: Some(BasePass::new(label)),
            parent,
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.as_ref().and_then(|base| base.label.as_deref())
    }

    fn base_mut<'a>(
        &'a mut self,
        scope: PassErrorScope,
    ) -> Result<&'a mut BasePass<ArcComputeCommand>, ComputePassError> {
        self.base
            .as_mut()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(scope)
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_buf) => write!(f, "ComputePass {{ parent: {} }}", cmd_buf.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}

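/// Describes a compute pass.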
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<&'a PassTimestampWrites>,
}

struct ArcComputePassDescriptor<'a> {
    pub label: &'a Label<'a>,
    pub timestamp_writes: Option<ArcPassTimestampWrites>,
}

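/// Errors that can occur when validating a dispatch.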
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline,
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

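/// Errors encountered while recording or executing a compute pass.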
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    Encoder(#[from] CommandEncoderError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error("Bind group index {index} is greater than the device's requested `max_bind_groups` limit {max}")]
    BindGroupIndexOutOfRange { index: u32, max: u32 },
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error("Cannot pop debug group because the number of pushed debug groups is zero")]
    InvalidPopDebugGroup,
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4gb of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
}

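/// Error encountered when performing a compute pass, paired with the scope in
/// which it occurred.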
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl<T, E> MapPassErr<T, ComputePassError> for Result<T, E>
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, ComputePassError> {
        self.map_err(|inner| ComputePassError {
            scope,
            inner: inner.into(),
        })
    }
}

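/// Mutable state that `compute_pass_end` threads through each recorded command
/// while replaying the pass into the HAL encoder.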
struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> {
    binder: Binder,
    pipeline: Option<Arc<ComputePipeline>>,
    scope: UsageScope<'scope>,
    debug_scope_depth: u32,

    snatch_guard: SnatchGuard<'snatch_guard>,

    device: &'cmd_buf Arc<Device>,

    raw_encoder: &'raw_encoder mut dyn hal::DynCommandEncoder,

    tracker: &'cmd_buf mut Tracker,
    buffer_memory_init_actions: &'cmd_buf mut Vec<BufferInitTrackerAction>,
    texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions,
    tlas_actions: &'cmd_buf mut Vec<TlasAction>,

    temp_offsets: Vec<u32>,
    dynamic_offset_count: usize,
    string_offset: usize,
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    push_constants: Vec<u32>,

    intermediate_trackers: Tracker,

    /// Immediate texture inits required because of prior discards. Need to be
    /// inserted before texture reads.
    pending_discard_init_fixups: SurfacesInDiscardState,
}

impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
    State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
{
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.binder.check_compatibility(pipeline.as_ref())?;
            self.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline)
        }
    }

    // `indirect_buffer` is the tracker index of the indirect buffer, if any,
    // which also participates in the usage scope.
    fn flush_states(
        &mut self,
        indirect_buffer: Option<TrackerIndex>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        for bind_group in self.binder.list_active() {
            unsafe { self.scope.merge_bind_group(&bind_group.used)? };
            // Note: stateless trackers are not merged; the lifetime reference
            // is held to the bind group itself.
        }

        for bind_group in self.binder.list_active() {
            unsafe {
                self.intermediate_trackers
                    .set_and_remove_from_usage_scope_sparse(&mut self.scope, &bind_group.used)
            }
        }

        // Add the state of the indirect buffer if it hasn't been hit before.
        unsafe {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(&mut self.scope.buffers, indirect_buffer);
        }

        CommandBuffer::drain_barriers(
            self.raw_encoder,
            &mut self.intermediate_trackers,
            &self.snatch_guard,
        );
        Ok(())
    }
}

impl Global {
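    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned. Any operation on an
    /// invalid pass will return an error.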
    pub fn command_encoder_create_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        let hub = &self.hub;

        let mut arc_desc = ArcComputePassDescriptor {
            label: &desc.label,
            timestamp_writes: None, // Handle only once we resolved the encoder.
        };

        let make_err = |e, arc_desc| (ComputePass::new(None, arc_desc), Some(e));

        let cmd_buf = hub.command_buffers.get(encoder_id.into_command_buffer_id());

        match cmd_buf.data.lock().lock_encoder() {
            Ok(_) => {}
            Err(e) => return make_err(e, arc_desc),
        };

        arc_desc.timestamp_writes = match desc
            .timestamp_writes
            .map(|tw| {
                Self::validate_pass_timestamp_writes(&cmd_buf.device, &hub.query_sets.read(), tw)
            })
            .transpose()
        {
            Ok(ok) => ok,
            Err(e) => return make_err(e, arc_desc),
        };

        (ComputePass::new(Some(cmd_buf), arc_desc), None)
    }

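    /// Resolves the ID-based commands of a replayed pass into `Arc`-based
    /// commands and ends the pass. Intended for trace replay.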
    #[doc(hidden)]
    #[cfg(any(feature = "serde", feature = "replay"))]
    pub fn compute_pass_end_with_unresolved_commands(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePass<super::ComputeCommand>,
        timestamp_writes: Option<&PassTimestampWrites>,
    ) -> Result<(), ComputePassError> {
        let pass_scope = PassErrorScope::Pass;

        #[cfg(feature = "trace")]
        {
            let cmd_buf = self
                .hub
                .command_buffers
                .get(encoder_id.into_command_buffer_id());
            let mut cmd_buf_data = cmd_buf.data.lock();
            let cmd_buf_data = cmd_buf_data.get_inner().map_pass_err(pass_scope)?;

            if let Some(ref mut list) = cmd_buf_data.commands {
                list.push(crate::device::trace::Command::RunComputePass {
                    base: BasePass {
                        label: base.label.clone(),
                        commands: base.commands.clone(),
                        dynamic_offsets: base.dynamic_offsets.clone(),
                        string_data: base.string_data.clone(),
                        push_constant_data: base.push_constant_data.clone(),
                    },
                    timestamp_writes: timestamp_writes.cloned(),
                });
            }
        }

        let BasePass {
            label,
            commands,
            dynamic_offsets,
            string_data,
            push_constant_data,
        } = base;

        let (mut compute_pass, encoder_error) = self.command_encoder_create_compute_pass(
            encoder_id,
            &ComputePassDescriptor {
                label: label.as_deref().map(std::borrow::Cow::Borrowed),
                timestamp_writes,
            },
        );
        if let Some(err) = encoder_error {
            return Err(ComputePassError {
                scope: pass_scope,
                inner: err.into(),
            });
        };

        compute_pass.base = Some(BasePass {
            label,
            commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)?,
            dynamic_offsets,
            string_data,
            push_constant_data,
        });

        self.compute_pass_end(&mut compute_pass)
    }

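    /// Ends the pass: validates and replays the recorded commands into the
    /// parent encoder, then prepends the barriers and memory-init fixups the
    /// pass requires.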
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), ComputePassError> {
        profiling::scope!("CommandEncoder::run_compute_pass");
        let pass_scope = PassErrorScope::Pass;

        let cmd_buf = pass
            .parent
            .as_ref()
            .ok_or(ComputePassErrorInner::InvalidParentEncoder)
            .map_pass_err(pass_scope)?;

        let base = pass
            .base
            .take()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(pass_scope)?;

        let device = &cmd_buf.device;
        device.check_is_valid().map_pass_err(pass_scope)?;

        let mut cmd_buf_data = cmd_buf.data.lock();
        let mut cmd_buf_data_guard = cmd_buf_data.unlock_encoder().map_pass_err(pass_scope)?;
        let cmd_buf_data = &mut *cmd_buf_data_guard;

        let encoder = &mut cmd_buf_data.encoder;

        // We want to insert a command buffer _before_ what we're about to
        // record, so make sure the previous one is closed first.
        encoder.close_if_open().map_pass_err(pass_scope)?;
        let raw_encoder = encoder
            .open_pass(base.label.as_deref())
            .map_pass_err(pass_scope)?;

        let mut state = State {
            binder: Binder::new(),
            pipeline: None,
            scope: device.new_usage_scope(),
            debug_scope_depth: 0,

            snatch_guard: device.snatchable_lock.read(),

            device,
            raw_encoder,
            tracker: &mut cmd_buf_data.trackers,
            buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
            texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
            tlas_actions: &mut cmd_buf_data.tlas_actions,

            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            string_offset: 0,
            active_query: None,

            push_constants: Vec::new(),

            intermediate_trackers: Tracker::new(),

            pending_discard_init_fixups: SurfacesInDiscardState::new(),
        };

        let indices = &state.device.tracker_indices;
        state.tracker.buffers.set_size(indices.buffers.size());
        state.tracker.textures.set_size(indices.textures.size());

        let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
            if let Some(tw) = pass.timestamp_writes.take() {
                tw.query_set
                    .same_device_as(cmd_buf.as_ref())
                    .map_pass_err(pass_scope)?;

                let query_set = state.tracker.query_sets.insert_single(tw.query_set);

                // Timestamp queries must be reset before they can be written
                // to, so do that before the pass opens.
                let range = if let (Some(index_a), Some(index_b)) =
                    (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
                {
                    Some(index_a.min(index_b)..index_a.max(index_b) + 1)
                } else {
                    tw.beginning_of_pass_write_index
                        .or(tw.end_of_pass_write_index)
                        .map(|i| i..i + 1)
                };
                // The range is `None` only if both indices are `None`, which
                // leads to a validation error elsewhere.
                if let Some(range) = range {
                    unsafe {
                        state.raw_encoder.reset_queries(query_set.raw(), range);
                    }
                }

                Some(hal::PassTimestampWrites {
                    query_set: query_set.raw(),
                    beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                    end_of_pass_write_index: tw.end_of_pass_write_index,
                })
            } else {
                None
            };

        let hal_desc = hal::ComputePassDescriptor {
            label: hal_label(base.label.as_deref(), device.instance_flags),
            timestamp_writes,
        };

        unsafe {
            state.raw_encoder.begin_compute_pass(&hal_desc);
        }

        for command in base.commands {
            match command {
                ArcComputeCommand::SetBindGroup {
                    index,
                    num_dynamic_offsets,
                    bind_group,
                } => {
                    let scope = PassErrorScope::SetBindGroup;
                    set_bind_group(
                        &mut state,
                        cmd_buf,
                        &base.dynamic_offsets,
                        index,
                        num_dynamic_offsets,
                        bind_group,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::SetPipeline(pipeline) => {
                    let scope = PassErrorScope::SetPipelineCompute;
                    set_pipeline(&mut state, cmd_buf, pipeline).map_pass_err(scope)?;
                }
                ArcComputeCommand::SetPushConstant {
                    offset,
                    size_bytes,
                    values_offset,
                } => {
                    let scope = PassErrorScope::SetPushConstant;
                    set_push_constant(
                        &mut state,
                        &base.push_constant_data,
                        offset,
                        size_bytes,
                        values_offset,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::Dispatch(groups) => {
                    let scope = PassErrorScope::Dispatch { indirect: false };
                    dispatch(&mut state, groups).map_pass_err(scope)?;
                }
                ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                    let scope = PassErrorScope::Dispatch { indirect: true };
                    dispatch_indirect(&mut state, cmd_buf, buffer, offset).map_pass_err(scope)?;
                }
                ArcComputeCommand::PushDebugGroup { color: _, len } => {
                    push_debug_group(&mut state, &base.string_data, len);
                }
                ArcComputeCommand::PopDebugGroup => {
                    let scope = PassErrorScope::PopDebugGroup;
                    pop_debug_group(&mut state).map_pass_err(scope)?;
                }
                ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                    insert_debug_marker(&mut state, &base.string_data, len);
                }
                ArcComputeCommand::WriteTimestamp {
                    query_set,
                    query_index,
                } => {
                    let scope = PassErrorScope::WriteTimestamp;
                    write_timestamp(&mut state, cmd_buf, query_set, query_index)
                        .map_pass_err(scope)?;
                }
                ArcComputeCommand::BeginPipelineStatisticsQuery {
                    query_set,
                    query_index,
                } => {
                    let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                    validate_and_begin_pipeline_statistics_query(
                        query_set,
                        state.raw_encoder,
                        &mut state.tracker.query_sets,
                        cmd_buf,
                        query_index,
                        None,
                        &mut state.active_query,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::EndPipelineStatisticsQuery => {
                    let scope = PassErrorScope::EndPipelineStatisticsQuery;
                    end_pipeline_statistics_query(state.raw_encoder, &mut state.active_query)
                        .map_pass_err(scope)?;
                }
            }
        }

        unsafe {
            state.raw_encoder.end_compute_pass();
        }

        let State {
            snatch_guard,
            tracker,
            intermediate_trackers,
            pending_discard_init_fixups,
            ..
        } = state;

        encoder.close().map_pass_err(pass_scope)?;

        // There can be entries left in the pending discard init fixups if a
        // bind group was set but never used (i.e. no dispatch occurred), so
        // record a "transit" command buffer that is swapped in *before* the
        // pass, carrying those fixups and the pass's barriers.
        let transit = encoder
            .open_pass(Some("(wgpu internal) Pre Pass"))
            .map_pass_err(pass_scope)?;
        fixup_discarded_surfaces(
            pending_discard_init_fixups.into_iter(),
            transit,
            &mut tracker.textures,
            device,
            &snatch_guard,
        );
        CommandBuffer::insert_barriers_from_tracker(
            transit,
            tracker,
            &intermediate_trackers,
            &snatch_guard,
        );
        encoder.close_and_swap().map_pass_err(pass_scope)?;
        cmd_buf_data_guard.mark_successful();

        Ok(())
    }
}

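/// Validates a `SetBindGroup` command and applies it to the HAL encoder,
/// recording dynamic offsets and memory-init actions along the way.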
fn set_bind_group(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    dynamic_offsets: &[DynamicOffset],
    index: u32,
    num_dynamic_offsets: usize,
    bind_group: Option<Arc<BindGroup>>,
) -> Result<(), ComputePassErrorInner> {
    let max_bind_groups = state.device.limits.max_bind_groups;
    if index >= max_bind_groups {
        return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
            index,
            max: max_bind_groups,
        });
    }

    state.temp_offsets.clear();
    state.temp_offsets.extend_from_slice(
        &dynamic_offsets
            [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
    );
    state.dynamic_offset_count += num_dynamic_offsets;

    if bind_group.is_none() {
        return Ok(());
    }

    let bind_group = bind_group.unwrap();
    let bind_group = state.tracker.bind_groups.insert_single(bind_group);

    bind_group.same_device_as(cmd_buf)?;

    bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;

    state
        .buffer_memory_init_actions
        .extend(bind_group.used_buffer_ranges.iter().filter_map(|action| {
            action
                .buffer
                .initialization_status
                .read()
                .check_action(action)
        }));

    for action in bind_group.used_texture_ranges.iter() {
        state
            .pending_discard_init_fixups
            .extend(state.texture_memory_actions.register_init_action(action));
    }

    let used_resource = bind_group
        .used
        .acceleration_structures
        .into_iter()
        .map(|tlas| TlasAction {
            tlas: tlas.clone(),
            kind: crate::ray_tracing::TlasActionKind::Use,
        });

    state.tlas_actions.extend(used_resource);

    let pipeline_layout = state.binder.pipeline_layout.clone();
    let entries = state
        .binder
        .assign_group(index as usize, bind_group, &state.temp_offsets);
    if !entries.is_empty() && pipeline_layout.is_some() {
        let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
        for (i, e) in entries.iter().enumerate() {
            if let Some(group) = e.group.as_ref() {
                let raw_bg = group.try_raw(&state.snatch_guard)?;
                unsafe {
                    state.raw_encoder.set_bind_group(
                        pipeline_layout,
                        index + i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }
    }
    Ok(())
}

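/// Validates a `SetPipeline` command and applies it to the HAL encoder,
/// rebinding bind groups and resetting push constant state when the pipeline
/// layout changes.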
fn set_pipeline(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device_as(cmd_buf)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state.tracker.compute_pipelines.insert_single(pipeline);

    unsafe {
        state.raw_encoder.set_compute_pipeline(pipeline.raw());
    }

    // Rebind resources.
    if state.binder.pipeline_layout.is_none()
        || !state
            .binder
            .pipeline_layout
            .as_ref()
            .unwrap()
            .is_equal(&pipeline.layout)
    {
        let (start_index, entries) = state
            .binder
            .change_pipeline_layout(&pipeline.layout, &pipeline.late_sized_buffer_groups);
        if !entries.is_empty() {
            for (i, e) in entries.iter().enumerate() {
                if let Some(group) = e.group.as_ref() {
                    let raw_bg = group.try_raw(&state.snatch_guard)?;
                    unsafe {
                        state.raw_encoder.set_bind_group(
                            pipeline.layout.raw(),
                            start_index as u32 + i as u32,
                            Some(raw_bg),
                            &e.dynamic_offsets,
                        );
                    }
                }
            }
        }

        // The layout changed, so the push constant shadow state must be rebuilt.
        state.push_constants.clear();
        // There can only be one push constant range per stage.
        if let Some(push_constant_range) =
            pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                pcr.stages
                    .contains(wgt::ShaderStages::COMPUTE)
                    .then_some(pcr.range.clone())
            })
        {
            // Size the shadow buffer in `u32` elements, covering the whole range.
            let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
            state.push_constants.extend(core::iter::repeat(0).take(len));
        }

        // Clear push constant ranges.
        let non_overlapping =
            super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges);
        for range in non_overlapping {
            let offset = range.range.start;
            let size_bytes = range.range.end - offset;
            super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
                state.raw_encoder.set_push_constants(
                    pipeline.layout.raw(),
                    wgt::ShaderStages::COMPUTE,
                    clear_offset,
                    clear_data,
                );
            });
        }
    }
    Ok(())
}

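/// Validates a `SetPushConstant` command, updates the CPU-side shadow copy of
/// the push constant data, and forwards the write to the HAL encoder.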
fn set_push_constant(
    state: &mut State,
    push_constant_data: &[u32],
    offset: u32,
    size_bytes: u32,
    values_offset: u32,
) -> Result<(), ComputePassErrorInner> {
    let end_offset_bytes = offset + size_bytes;
    let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];

    let pipeline_layout = state
        .binder
        .pipeline_layout
        .as_ref()
        .ok_or(ComputePassErrorInner::Dispatch(
            DispatchError::MissingPipeline,
        ))?;

    pipeline_layout.validate_push_constant_ranges(
        wgt::ShaderStages::COMPUTE,
        offset,
        end_offset_bytes,
    )?;

    let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    state.push_constants[offset_in_elements..][..size_in_elements].copy_from_slice(data_slice);

    unsafe {
        state.raw_encoder.set_push_constants(
            pipeline_layout.raw(),
            wgt::ShaderStages::COMPUTE,
            offset,
            data_slice,
        );
    }
    Ok(())
}

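/// Validates a direct dispatch against the current pipeline and workgroup
/// limits, flushes usage-scope state into barriers, and issues the dispatch.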
fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    state.is_ready()?;

    state.flush_states(None)?;

    let groups_size_limit = state.device.limits.max_compute_workgroups_per_dimension;

    if groups[0] > groups_size_limit
        || groups[1] > groups_size_limit
        || groups[2] > groups_size_limit
    {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }

    unsafe {
        state.raw_encoder.dispatch(groups);
    }
    Ok(())
}

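/// Validates an indirect dispatch. With the `indirect-validation` feature, the
/// dispatch arguments are first rewritten by an internal validation pipeline
/// into a staging buffer, and the indirect dispatch reads from that buffer;
/// otherwise the user buffer is dispatched from directly.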
fn dispatch_indirect(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    buffer.same_device_as(cmd_buf)?;

    state.is_ready()?;

    state
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    let stride = 3 * 4; // 3 integers, x/y/z group size
    state
        .buffer_memory_init_actions
        .extend(buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ));

    #[cfg(feature = "indirect-validation")]
    {
        let params = state.device.indirect_validation.as_ref().unwrap().params(
            &state.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state.raw_encoder.set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .raw_indirect_validation_bind_group
                        .get(&state.snatch_guard)
                        .unwrap()
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, hal::BufferUses::STORAGE_READ_ONLY);
        let src_barrier =
            src_transition.map(|transition| transition.into_hal(&buffer, &state.snatch_guard));
        unsafe {
            state.raw_encoder.transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
                buffer: params.dst_buffer,
                usage: hal::StateTransition {
                    from: hal::BufferUses::INDIRECT,
                    to: hal::BufferUses::STORAGE_READ_WRITE,
                },
            }]);
        }

        unsafe {
            state.raw_encoder.dispatch([1, 1, 1]);
        }

        // Restore the state this pass recorded before the validation dispatch.
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state.raw_encoder.set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(&state.snatch_guard)?;
                unsafe {
                    state.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        unsafe {
            state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
                buffer: params.dst_buffer,
                usage: hal::StateTransition {
                    from: hal::BufferUses::STORAGE_READ_WRITE,
                    to: hal::BufferUses::INDIRECT,
                },
            }]);
        }

        state.flush_states(None)?;
        unsafe {
            state.raw_encoder.dispatch_indirect(params.dst_buffer, 0);
        }
    };
    #[cfg(not(feature = "indirect-validation"))]
    {
        state
            .scope
            .buffers
            .merge_single(&buffer, hal::BufferUses::INDIRECT)?;

        use crate::resource::Trackable;
        state.flush_states(Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(&state.snatch_guard)?;
        unsafe {
            state.raw_encoder.dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}

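/// Opens a debug group. Label bytes are read out of the pass's shared
/// `string_data` buffer, advancing `state.string_offset`.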
fn push_debug_group(state: &mut State, string_data: &[u8], len: usize) {
    state.debug_scope_depth += 1;
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
        unsafe {
            state.raw_encoder.begin_debug_marker(label);
        }
    }
    state.string_offset += len;
}

fn pop_debug_group(state: &mut State) -> Result<(), ComputePassErrorInner> {
    if state.debug_scope_depth == 0 {
        return Err(ComputePassErrorInner::InvalidPopDebugGroup);
    }
    state.debug_scope_depth -= 1;
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        unsafe {
            state.raw_encoder.end_debug_marker();
        }
    }
    Ok(())
}

fn insert_debug_marker(state: &mut State, string_data: &[u8], len: usize) {
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
        unsafe { state.raw_encoder.insert_debug_marker(label) }
    }
    state.string_offset += len;
}

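/// Validates a `WriteTimestamp` command (requires the
/// `TIMESTAMP_QUERY_INSIDE_PASSES` feature) and writes the timestamp.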
fn write_timestamp(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    query_set: Arc<resource::QuerySet>,
    query_index: u32,
) -> Result<(), ComputePassErrorInner> {
    query_set.same_device_as(cmd_buf)?;

    state
        .device
        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;

    let query_set = state.tracker.query_sets.insert_single(query_set);

    query_set.validate_and_write_timestamp(state.raw_encoder, query_index, None)?;
    Ok(())
}

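// Recording entry points. These resolve IDs into `Arc`s eagerly and append
// `ArcComputeCommand`s to the pass; full validation is deferred to
// `compute_pass_end`.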
impl Global {
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::SetBindGroup;
        let base = pass
            .base
            .as_mut()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(scope)?; // Can't use base_mut() utility here because of borrow checker.

        let redundant = pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        );

        if redundant {
            return Ok(());
        }

        let mut bind_group = None;
        if bind_group_id.is_some() {
            let bind_group_id = bind_group_id.unwrap();

            let hub = &self.hub;
            let bg = hub
                .bind_groups
                .get(bind_group_id)
                .get()
                .map_pass_err(scope)?;
            bind_group = Some(bg);
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), ComputePassError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        let base = pass.base_mut(scope)?;
        if redundant {
            // Do the redundant early-out **after** checking whether the pass is ended.
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = hub
            .compute_pipelines
            .get(pipeline_id)
            .get()
            .map_pass_err(scope)?;

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    pub fn compute_pass_set_push_constants(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::SetPushConstant;
        let base = pass.base_mut(scope)?;

        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            return Err(ComputePassErrorInner::PushConstantOffsetAlignment).map_pass_err(scope);
        }

        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            return Err(ComputePassErrorInner::PushConstantSizeAlignment).map_pass_err(scope);
        }
        let value_offset = base
            .push_constant_data
            .len()
            .try_into()
            .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
            .map_pass_err(scope)?;

        base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        let base = pass.base_mut(scope)?;
        base.commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), ComputePassError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass.base_mut(scope)?;

        let buffer = hub.buffers.get(buffer_id).get().map_pass_err(scope)?;

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::PushDebugGroup)?;

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::PopDebugGroup)?;

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::InsertDebugMarker)?;

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass.base_mut(scope)?;

        let hub = &self.hub;
        let query_set = hub.query_sets.get(query_set_id).get().map_pass_err(scope)?;

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass.base_mut(scope)?;

        let hub = &self.hub;
        let query_set = hub.query_sets.get(query_set_id).get().map_pass_err(scope)?;

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::EndPipelineStatisticsQuery;
        let base = pass.base_mut(scope)?;
        base.commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}