wgpu_core/command/compute.rs

use crate::{
    binding_model::{
        BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
    },
    command::{
        bind::Binder,
        compute_command::ArcComputeCommand,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites, BasePass,
        BindGroupStateChange, CommandBuffer, CommandEncoderError, MapPassErr, PassErrorScope,
        PassTimestampWrites, QueryUseError, StateChange,
    },
    device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::{BufferInitTrackerAction, MemoryInitKind},
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
        MissingBufferUsageError, ParentDevice,
    },
    snatch::SnatchGuard,
    track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope},
    Label,
};

use thiserror::Error;
use wgt::{BufferAddress, DynamicOffset};

use super::{bind::BinderError, memory_init::CommandBufferTextureMemoryActions};
use crate::ray_tracing::TlasAction;
use std::{fmt, mem::size_of, str, sync::Arc};

pub struct ComputePass {
    /// All pass data & recorded commands are stored here.
    ///
    /// If this is `None`, the pass is in the 'ended' state and can no longer be used.
    /// Any attempt to record more commands will result in a validation error.
    base: Option<BasePass<ArcComputeCommand>>,

    /// Parent command buffer that this pass records commands into.
    ///
    /// If this is `None`, the pass is invalid and any operation on it will return an error.
    parent: Option<Arc<CommandBuffer>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    /// If the parent command buffer is invalid, the returned pass will be invalid.
    fn new(parent: Option<Arc<CommandBuffer>>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: Some(BasePass::new(label)),
            parent,
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.as_ref().and_then(|base| base.label.as_deref())
    }

    fn base_mut<'a>(
        &'a mut self,
        scope: PassErrorScope,
    ) -> Result<&'a mut BasePass<ArcComputeCommand>, ComputePassError> {
        self.base
            .as_mut()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(scope)
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_buf) => write!(f, "ComputePass {{ parent: {} }}", cmd_buf.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<&'a PassTimestampWrites>,
}

struct ArcComputePassDescriptor<'a> {
    pub label: &'a Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<ArcPassTimestampWrites>,
}

#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline,
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    Encoder(#[from] CommandEncoderError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error("Bind group index {index} is greater than the device's requested `max_bind_groups` limit {max}")]
    BindGroupIndexOutOfRange { index: u32, max: u32 },
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error("Cannot pop debug group, because the number of pushed debug groups is zero")]
    InvalidPopDebugGroup,
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4 GiB of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl<T, E> MapPassErr<T, ComputePassError> for Result<T, E>
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, ComputePassError> {
        self.map_err(|inner| ComputePassError {
            scope,
            inner: inner.into(),
        })
    }
}

struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> {
    binder: Binder,
    pipeline: Option<Arc<ComputePipeline>>,
    scope: UsageScope<'scope>,
    debug_scope_depth: u32,

    snatch_guard: SnatchGuard<'snatch_guard>,

    device: &'cmd_buf Arc<Device>,

    raw_encoder: &'raw_encoder mut dyn hal::DynCommandEncoder,

    tracker: &'cmd_buf mut Tracker,
    buffer_memory_init_actions: &'cmd_buf mut Vec<BufferInitTrackerAction>,
    texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions,
    tlas_actions: &'cmd_buf mut Vec<TlasAction>,

    temp_offsets: Vec<u32>,
    dynamic_offset_count: usize,
    string_offset: usize,
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    push_constants: Vec<u32>,

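    /// Tracker that accumulates the resource-state transitions this pass
    /// requires. Barriers are drained from it before each dispatch, and the
    /// remainder is resolved against the parent command buffer's tracker when
    /// the pass ends.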
    intermediate_trackers: Tracker,

    /// Immediate texture inits required because of prior discards. Need to
    /// be inserted before texture reads.
    pending_discard_init_fixups: SurfacesInDiscardState,
}

impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
    State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
{
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.binder.check_compatibility(pipeline.as_ref())?;
            self.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline)
        }
    }

    // `indirect_buffer` is there to represent the indirect buffer that is
    // also part of the usage scope.
    fn flush_states(
        &mut self,
        indirect_buffer: Option<TrackerIndex>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        for bind_group in self.binder.list_active() {
            unsafe { self.scope.merge_bind_group(&bind_group.used)? };
            // Note: stateless trackers are not merged: the lifetime reference
            // is held to the bind group itself.
        }

        for bind_group in self.binder.list_active() {
            unsafe {
                self.intermediate_trackers
                    .set_and_remove_from_usage_scope_sparse(&mut self.scope, &bind_group.used)
            }
        }

        // Add the state of the indirect buffer if it hasn't been hit before.
        unsafe {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(&mut self.scope.buffers, indirect_buffer);
        }

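        // Record the barriers accumulated in `intermediate_trackers` into the
        // raw encoder before the dispatch is encoded.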
        CommandBuffer::drain_barriers(
            self.raw_encoder,
            &mut self.intermediate_trackers,
            &self.snatch_guard,
        );
        Ok(())
    }
}

// Running the compute pass.

impl Global {
    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned.
    /// Any operation on an invalid pass will return an error.
    ///
    /// If successful, puts the encoder into the [`Locked`] state.
    ///
    /// [`Locked`]: crate::command::CommandEncoderStatus::Locked
    pub fn command_encoder_create_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        let hub = &self.hub;

        let mut arc_desc = ArcComputePassDescriptor {
            label: &desc.label,
            timestamp_writes: None, // Handled later, once we've resolved the encoder.
        };

        let make_err = |e, arc_desc| (ComputePass::new(None, arc_desc), Some(e));

        let cmd_buf = hub.command_buffers.get(encoder_id.into_command_buffer_id());

        match cmd_buf.data.lock().lock_encoder() {
            Ok(_) => {}
            Err(e) => return make_err(e, arc_desc),
        };

        arc_desc.timestamp_writes = match desc
            .timestamp_writes
            .map(|tw| {
                Self::validate_pass_timestamp_writes(&cmd_buf.device, &hub.query_sets.read(), tw)
            })
            .transpose()
        {
            Ok(ok) => ok,
            Err(e) => return make_err(e, arc_desc),
        };

        (ComputePass::new(Some(cmd_buf), arc_desc), None)
    }

    /// Note that this differs from [`Self::compute_pass_end`]: it will
    /// create a new pass, replay the commands, and end the pass.
    #[doc(hidden)]
    #[cfg(any(feature = "serde", feature = "replay"))]
    pub fn compute_pass_end_with_unresolved_commands(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePass<super::ComputeCommand>,
        timestamp_writes: Option<&PassTimestampWrites>,
    ) -> Result<(), ComputePassError> {
        let pass_scope = PassErrorScope::Pass;

        #[cfg(feature = "trace")]
        {
            let cmd_buf = self
                .hub
                .command_buffers
                .get(encoder_id.into_command_buffer_id());
            let mut cmd_buf_data = cmd_buf.data.lock();
            let cmd_buf_data = cmd_buf_data.get_inner().map_pass_err(pass_scope)?;

            if let Some(ref mut list) = cmd_buf_data.commands {
                list.push(crate::device::trace::Command::RunComputePass {
                    base: BasePass {
                        label: base.label.clone(),
                        commands: base.commands.clone(),
                        dynamic_offsets: base.dynamic_offsets.clone(),
                        string_data: base.string_data.clone(),
                        push_constant_data: base.push_constant_data.clone(),
                    },
                    timestamp_writes: timestamp_writes.cloned(),
                });
            }
        }

        let BasePass {
            label,
            commands,
            dynamic_offsets,
            string_data,
            push_constant_data,
        } = base;

        let (mut compute_pass, encoder_error) = self.command_encoder_create_compute_pass(
            encoder_id,
            &ComputePassDescriptor {
                label: label.as_deref().map(std::borrow::Cow::Borrowed),
                timestamp_writes,
            },
        );
        if let Some(err) = encoder_error {
            return Err(ComputePassError {
                scope: pass_scope,
                inner: err.into(),
            });
        };

        compute_pass.base = Some(BasePass {
            label,
            commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)?,
            dynamic_offsets,
            string_data,
            push_constant_data,
        });

        self.compute_pass_end(&mut compute_pass)
    }

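    /// Ends the pass, replaying its recorded commands into the parent
    /// command encoder and unlocking the encoder again.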
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), ComputePassError> {
        profiling::scope!("CommandEncoder::run_compute_pass");
        let pass_scope = PassErrorScope::Pass;

        let cmd_buf = pass
            .parent
            .as_ref()
            .ok_or(ComputePassErrorInner::InvalidParentEncoder)
            .map_pass_err(pass_scope)?;

        let base = pass
            .base
            .take()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(pass_scope)?;

        let device = &cmd_buf.device;
        device.check_is_valid().map_pass_err(pass_scope)?;

        let mut cmd_buf_data = cmd_buf.data.lock();
        let mut cmd_buf_data_guard = cmd_buf_data.unlock_encoder().map_pass_err(pass_scope)?;
        let cmd_buf_data = &mut *cmd_buf_data_guard;

        let encoder = &mut cmd_buf_data.encoder;

        // We automatically keep extending command buffers over time, and because
        // we want to insert a command buffer _before_ what we're about to record,
        // we need to make sure to close the previous one.
        encoder.close_if_open().map_pass_err(pass_scope)?;
        let raw_encoder = encoder
            .open_pass(base.label.as_deref())
            .map_pass_err(pass_scope)?;

        let mut state = State {
            binder: Binder::new(),
            pipeline: None,
            scope: device.new_usage_scope(),
            debug_scope_depth: 0,

            snatch_guard: device.snatchable_lock.read(),

            device,
            raw_encoder,
            tracker: &mut cmd_buf_data.trackers,
            buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
            texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
            tlas_actions: &mut cmd_buf_data.tlas_actions,

            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            string_offset: 0,
            active_query: None,

            push_constants: Vec::new(),

            intermediate_trackers: Tracker::new(),

            pending_discard_init_fixups: SurfacesInDiscardState::new(),
        };

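        // Pre-size the buffer and texture trackers to the device's current
        // resource counts so inserts during replay don't reallocate.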
        let indices = &state.device.tracker_indices;
        state.tracker.buffers.set_size(indices.buffers.size());
        state.tracker.textures.set_size(indices.textures.size());

        let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
            if let Some(tw) = pass.timestamp_writes.take() {
                tw.query_set
                    .same_device_as(cmd_buf.as_ref())
                    .map_pass_err(pass_scope)?;

                let query_set = state.tracker.query_sets.insert_single(tw.query_set);

                // Unlike in render passes we can't delay resetting the query sets since
                // there is no auxiliary pass.
                let range = if let (Some(index_a), Some(index_b)) =
                    (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
                {
                    Some(index_a.min(index_b)..index_a.max(index_b) + 1)
                } else {
                    tw.beginning_of_pass_write_index
                        .or(tw.end_of_pass_write_index)
                        .map(|i| i..i + 1)
                };
                // The range should always be `Some`: both indices being `None` should
                // have already produced a validation error. But no point in erroring
                // over that nuance here!
                if let Some(range) = range {
                    unsafe {
                        state.raw_encoder.reset_queries(query_set.raw(), range);
                    }
                }

                Some(hal::PassTimestampWrites {
                    query_set: query_set.raw(),
                    beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                    end_of_pass_write_index: tw.end_of_pass_write_index,
                })
            } else {
                None
            };

        let hal_desc = hal::ComputePassDescriptor {
            label: hal_label(base.label.as_deref(), device.instance_flags),
            timestamp_writes,
        };

        unsafe {
            state.raw_encoder.begin_compute_pass(&hal_desc);
        }

        for command in base.commands {
            match command {
                ArcComputeCommand::SetBindGroup {
                    index,
                    num_dynamic_offsets,
                    bind_group,
                } => {
                    let scope = PassErrorScope::SetBindGroup;
                    set_bind_group(
                        &mut state,
                        cmd_buf,
                        &base.dynamic_offsets,
                        index,
                        num_dynamic_offsets,
                        bind_group,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::SetPipeline(pipeline) => {
                    let scope = PassErrorScope::SetPipelineCompute;
                    set_pipeline(&mut state, cmd_buf, pipeline).map_pass_err(scope)?;
                }
                ArcComputeCommand::SetPushConstant {
                    offset,
                    size_bytes,
                    values_offset,
                } => {
                    let scope = PassErrorScope::SetPushConstant;
                    set_push_constant(
                        &mut state,
                        &base.push_constant_data,
                        offset,
                        size_bytes,
                        values_offset,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::Dispatch(groups) => {
                    let scope = PassErrorScope::Dispatch { indirect: false };
                    dispatch(&mut state, groups).map_pass_err(scope)?;
                }
                ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                    let scope = PassErrorScope::Dispatch { indirect: true };
                    dispatch_indirect(&mut state, cmd_buf, buffer, offset).map_pass_err(scope)?;
                }
                ArcComputeCommand::PushDebugGroup { color: _, len } => {
                    push_debug_group(&mut state, &base.string_data, len);
                }
                ArcComputeCommand::PopDebugGroup => {
                    let scope = PassErrorScope::PopDebugGroup;
                    pop_debug_group(&mut state).map_pass_err(scope)?;
                }
                ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                    insert_debug_marker(&mut state, &base.string_data, len);
                }
                ArcComputeCommand::WriteTimestamp {
                    query_set,
                    query_index,
                } => {
                    let scope = PassErrorScope::WriteTimestamp;
                    write_timestamp(&mut state, cmd_buf, query_set, query_index)
                        .map_pass_err(scope)?;
                }
                ArcComputeCommand::BeginPipelineStatisticsQuery {
                    query_set,
                    query_index,
                } => {
                    let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                    validate_and_begin_pipeline_statistics_query(
                        query_set,
                        state.raw_encoder,
                        &mut state.tracker.query_sets,
                        cmd_buf,
                        query_index,
                        None,
                        &mut state.active_query,
                    )
                    .map_pass_err(scope)?;
                }
                ArcComputeCommand::EndPipelineStatisticsQuery => {
                    let scope = PassErrorScope::EndPipelineStatisticsQuery;
                    end_pipeline_statistics_query(state.raw_encoder, &mut state.active_query)
                        .map_pass_err(scope)?;
                }
            }
        }

        unsafe {
            state.raw_encoder.end_compute_pass();
        }

        let State {
            snatch_guard,
            tracker,
            intermediate_trackers,
            pending_discard_init_fixups,
            ..
        } = state;

        // Stop the current command buffer.
        encoder.close().map_pass_err(pass_scope)?;

        // Create a new command buffer, which we will insert _before_ the body of the compute pass.
        //
        // Use that buffer to insert barriers and clear discarded images.
        let transit = encoder
            .open_pass(Some("(wgpu internal) Pre Pass"))
            .map_pass_err(pass_scope)?;
        fixup_discarded_surfaces(
            pending_discard_init_fixups.into_iter(),
            transit,
            &mut tracker.textures,
            device,
            &snatch_guard,
        );
        CommandBuffer::insert_barriers_from_tracker(
            transit,
            tracker,
            &intermediate_trackers,
            &snatch_guard,
        );
        // Close the command buffer, and swap it with the previous.
        encoder.close_and_swap().map_pass_err(pass_scope)?;
        cmd_buf_data_guard.mark_successful();

        Ok(())
    }
}

fn set_bind_group(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    dynamic_offsets: &[DynamicOffset],
    index: u32,
    num_dynamic_offsets: usize,
    bind_group: Option<Arc<BindGroup>>,
) -> Result<(), ComputePassErrorInner> {
    let max_bind_groups = state.device.limits.max_bind_groups;
    if index >= max_bind_groups {
        return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
            index,
            max: max_bind_groups,
        });
    }

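    // Pull this command's dynamic offsets out of the pass-wide offset list.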
    state.temp_offsets.clear();
    state.temp_offsets.extend_from_slice(
        &dynamic_offsets
            [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
    );
    state.dynamic_offset_count += num_dynamic_offsets;

    if bind_group.is_none() {
        // TODO: Handle bind_group None.
        return Ok(());
    }

    let bind_group = bind_group.unwrap();
    let bind_group = state.tracker.bind_groups.insert_single(bind_group);

    bind_group.same_device_as(cmd_buf)?;

    bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;

    state
        .buffer_memory_init_actions
        .extend(bind_group.used_buffer_ranges.iter().filter_map(|action| {
            action
                .buffer
                .initialization_status
                .read()
                .check_action(action)
        }));

    for action in bind_group.used_texture_ranges.iter() {
        state
            .pending_discard_init_fixups
            .extend(state.texture_memory_actions.register_init_action(action));
    }

    let used_resource = bind_group
        .used
        .acceleration_structures
        .into_iter()
        .map(|tlas| TlasAction {
            tlas: tlas.clone(),
            kind: crate::ray_tracing::TlasActionKind::Use,
        });

    state.tlas_actions.extend(used_resource);

    let pipeline_layout = state.binder.pipeline_layout.clone();
    let entries = state
        .binder
        .assign_group(index as usize, bind_group, &state.temp_offsets);
    if !entries.is_empty() && pipeline_layout.is_some() {
        let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
        for (i, e) in entries.iter().enumerate() {
            if let Some(group) = e.group.as_ref() {
                let raw_bg = group.try_raw(&state.snatch_guard)?;
                unsafe {
                    state.raw_encoder.set_bind_group(
                        pipeline_layout,
                        index + i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }
    }
    Ok(())
}

fn set_pipeline(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device_as(cmd_buf)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state.tracker.compute_pipelines.insert_single(pipeline);

    unsafe {
        state.raw_encoder.set_compute_pipeline(pipeline.raw());
    }

    // Rebind resources
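    // If the new pipeline's layout differs from the current one, the binder
    // reports which bind groups need to be re-set on the raw encoder.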
    if state.binder.pipeline_layout.is_none()
        || !state
            .binder
            .pipeline_layout
            .as_ref()
            .unwrap()
            .is_equal(&pipeline.layout)
    {
        let (start_index, entries) = state
            .binder
            .change_pipeline_layout(&pipeline.layout, &pipeline.late_sized_buffer_groups);
        if !entries.is_empty() {
            for (i, e) in entries.iter().enumerate() {
                if let Some(group) = e.group.as_ref() {
                    let raw_bg = group.try_raw(&state.snatch_guard)?;
                    unsafe {
                        state.raw_encoder.set_bind_group(
                            pipeline.layout.raw(),
                            start_index as u32 + i as u32,
                            Some(raw_bg),
                            &e.dynamic_offsets,
                        );
                    }
                }
            }
        }

        // TODO: integrate this in the code below once we simplify push constants
        state.push_constants.clear();
        // Note that there can only be one range for each stage. See the
        // `MoreThanOnePushConstantRangePerStage` error.
        if let Some(push_constant_range) =
            pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                pcr.stages
                    .contains(wgt::ShaderStages::COMPUTE)
                    .then_some(pcr.range.clone())
            })
        {
            // Note that a non-zero range start doesn't work anyway:
            // https://github.com/gfx-rs/wgpu/issues/4502
            let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
            state.push_constants.extend(core::iter::repeat(0).take(len));
        }

        // Clear push constant ranges
        let non_overlapping =
            super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges);
        for range in non_overlapping {
            let offset = range.range.start;
            let size_bytes = range.range.end - offset;
            super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
                state.raw_encoder.set_push_constants(
                    pipeline.layout.raw(),
                    wgt::ShaderStages::COMPUTE,
                    clear_offset,
                    clear_data,
                );
            });
        }
    }
    Ok(())
}

fn set_push_constant(
    state: &mut State,
    push_constant_data: &[u32],
    offset: u32,
    size_bytes: u32,
    values_offset: u32,
) -> Result<(), ComputePassErrorInner> {
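    // `offset` and `size_bytes` were checked for 4-byte alignment when the
    // command was recorded (see `compute_pass_set_push_constants`).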
    let end_offset_bytes = offset + size_bytes;
    let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];

    let pipeline_layout = state
        .binder
        .pipeline_layout
        .as_ref()
        // TODO: don't error here, lazily update the push constants using `state.push_constants`
        .ok_or(ComputePassErrorInner::Dispatch(
            DispatchError::MissingPipeline,
        ))?;

    pipeline_layout.validate_push_constant_ranges(
        wgt::ShaderStages::COMPUTE,
        offset,
        end_offset_bytes,
    )?;

    let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    state.push_constants[offset_in_elements..][..size_in_elements].copy_from_slice(data_slice);

    unsafe {
        state.raw_encoder.set_push_constants(
            pipeline_layout.raw(),
            wgt::ShaderStages::COMPUTE,
            offset,
            data_slice,
        );
    }
    Ok(())
}

fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    state.is_ready()?;

    state.flush_states(None)?;

    let groups_size_limit = state.device.limits.max_compute_workgroups_per_dimension;

    if groups[0] > groups_size_limit
        || groups[1] > groups_size_limit
        || groups[2] > groups_size_limit
    {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }

    unsafe {
        state.raw_encoder.dispatch(groups);
    }
    Ok(())
}

fn dispatch_indirect(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    buffer.same_device_as(cmd_buf)?;

    state.is_ready()?;

    state
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    let stride = 3 * 4; // 3 integers, x/y/z group size
    state
        .buffer_memory_init_actions
        .extend(buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ));

    #[cfg(feature = "indirect-validation")]
    {
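        // Validate the indirect args on the GPU: a built-in validation
        // pipeline reads the user's buffer and writes checked dispatch
        // arguments into the internal `params.dst_buffer`; the real indirect
        // dispatch below then reads from that buffer instead.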
        let params = state.device.indirect_validation.as_ref().unwrap().params(
            &state.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state.raw_encoder.set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .raw_indirect_validation_bind_group
                        .get(&state.snatch_guard)
                        .unwrap()
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, hal::BufferUses::STORAGE_READ_ONLY);
        let src_barrier =
            src_transition.map(|transition| transition.into_hal(&buffer, &state.snatch_guard));
        unsafe {
            state.raw_encoder.transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
                buffer: params.dst_buffer,
                usage: hal::StateTransition {
                    from: hal::BufferUses::INDIRECT,
                    to: hal::BufferUses::STORAGE_READ_WRITE,
                },
            }]);
        }

        unsafe {
            state.raw_encoder.dispatch([1, 1, 1]);
        }

        // reset state
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state.raw_encoder.set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(&state.snatch_guard)?;
                unsafe {
                    state.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        unsafe {
            state.raw_encoder.transition_buffers(&[hal::BufferBarrier {
                buffer: params.dst_buffer,
                usage: hal::StateTransition {
                    from: hal::BufferUses::STORAGE_READ_WRITE,
                    to: hal::BufferUses::INDIRECT,
                },
            }]);
        }

        state.flush_states(None)?;
        unsafe {
            state.raw_encoder.dispatch_indirect(params.dst_buffer, 0);
        }
    };
    #[cfg(not(feature = "indirect-validation"))]
    {
        state
            .scope
            .buffers
            .merge_single(&buffer, hal::BufferUses::INDIRECT)?;

        use crate::resource::Trackable;
        state.flush_states(Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(&state.snatch_guard)?;
        unsafe {
            state.raw_encoder.dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}

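// Debug-group and marker labels for the whole pass are stored back-to-back in
// `string_data`; `state.string_offset` tracks how much of it has already been
// consumed during replay.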
fn push_debug_group(state: &mut State, string_data: &[u8], len: usize) {
    state.debug_scope_depth += 1;
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
        unsafe {
            state.raw_encoder.begin_debug_marker(label);
        }
    }
    state.string_offset += len;
}

fn pop_debug_group(state: &mut State) -> Result<(), ComputePassErrorInner> {
    if state.debug_scope_depth == 0 {
        return Err(ComputePassErrorInner::InvalidPopDebugGroup);
    }
    state.debug_scope_depth -= 1;
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        unsafe {
            state.raw_encoder.end_debug_marker();
        }
    }
    Ok(())
}

fn insert_debug_marker(state: &mut State, string_data: &[u8], len: usize) {
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
        unsafe { state.raw_encoder.insert_debug_marker(label) }
    }
    state.string_offset += len;
}

fn write_timestamp(
    state: &mut State,
    cmd_buf: &CommandBuffer,
    query_set: Arc<resource::QuerySet>,
    query_index: u32,
) -> Result<(), ComputePassErrorInner> {
    query_set.same_device_as(cmd_buf)?;

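    // `TIMESTAMP_QUERY` alone only permits timestamps at pass boundaries;
    // writing one from inside a pass requires `TIMESTAMP_QUERY_INSIDE_PASSES`.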
    state
        .device
        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;

    let query_set = state.tracker.query_sets.insert_single(query_set);

    query_set.validate_and_write_timestamp(state.raw_encoder, query_index, None)?;
    Ok(())
}

// Recording a compute pass.
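//
// These methods only record commands into the pass's `BasePass`; validation
// against the device and HAL encoding are deferred until `compute_pass_end`
// replays them.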
impl Global {
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::SetBindGroup;
        let base = pass
            .base
            .as_mut()
            .ok_or(ComputePassErrorInner::PassEnded)
            .map_pass_err(scope)?; // Can't use base_mut() utility here because of borrow checker.

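        // Skip recording the command if it would not change the currently
        // bound group or its dynamic offsets.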
        let redundant = pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        );

        if redundant {
            return Ok(());
        }

        let mut bind_group = None;
        if bind_group_id.is_some() {
            let bind_group_id = bind_group_id.unwrap();

            let hub = &self.hub;
            let bg = hub
                .bind_groups
                .get(bind_group_id)
                .get()
                .map_pass_err(scope)?;
            bind_group = Some(bg);
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), ComputePassError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        let base = pass.base_mut(scope)?;
        if redundant {
            // Do redundant early-out **after** checking whether the pass is ended or not.
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = hub
            .compute_pipelines
            .get(pipeline_id)
            .get()
            .map_pass_err(scope)?;

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    pub fn compute_pass_set_push_constants(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::SetPushConstant;
        let base = pass.base_mut(scope)?;

        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            return Err(ComputePassErrorInner::PushConstantOffsetAlignment).map_pass_err(scope);
        }

        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            return Err(ComputePassErrorInner::PushConstantSizeAlignment).map_pass_err(scope);
        }
        let value_offset = base
            .push_constant_data
            .len()
            .try_into()
            .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
            .map_pass_err(scope)?;

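        // Pack the bytes into native-endian u32 words; `value_offset` records
        // where this command's words start in the pass-wide buffer.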
        base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        let base = pass.base_mut(scope)?;
        base.commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), ComputePassError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass.base_mut(scope)?;

        let buffer = hub.buffers.get(buffer_id).get().map_pass_err(scope)?;

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::PushDebugGroup)?;

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::PopDebugGroup)?;

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), ComputePassError> {
        let base = pass.base_mut(PassErrorScope::InsertDebugMarker)?;

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass.base_mut(scope)?;

        let hub = &self.hub;
        let query_set = hub.query_sets.get(query_set_id).get().map_pass_err(scope)?;

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass.base_mut(scope)?;

        let hub = &self.hub;
        let query_set = hub.query_sets.get(query_set_id).get().map_pass_err(scope)?;

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), ComputePassError> {
        let scope = PassErrorScope::EndPipelineStatisticsQuery;
        let base = pass.base_mut(scope)?;
        base.commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}