wgpu_core/command/compute_command.rs

1use std::sync::Arc;
2
3use crate::{
4    binding_model::BindGroup,
5    id,
6    pipeline::ComputePipeline,
7    resource::{Buffer, QuerySet},
8};
9
10#[derive(Clone, Copy, Debug)]
11#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
12pub enum ComputeCommand {
13    SetBindGroup {
14        index: u32,
15        num_dynamic_offsets: usize,
16        bind_group_id: Option<id::BindGroupId>,
17    },
18
19    SetPipeline(id::ComputePipelineId),
20
21    /// Set a range of push constants to values stored in `push_constant_data`.
22    SetPushConstant {
23        /// The byte offset within the push constant storage to write to. This
24        /// must be a multiple of four.
25        offset: u32,
26
27        /// The number of bytes to write. This must be a multiple of four.
28        size_bytes: u32,
29
30        /// Index in `push_constant_data` of the start of the data
31        /// to be written.
32        ///
33        /// Note: this is not a byte offset like `offset`. Rather, it is the
34        /// index of the first `u32` element in `push_constant_data` to read.
35        values_offset: u32,
36    },
37
38    Dispatch([u32; 3]),
39
40    DispatchIndirect {
41        buffer_id: id::BufferId,
42        offset: wgt::BufferAddress,
43    },
44
45    PushDebugGroup {
46        color: u32,
47        len: usize,
48    },
49
50    PopDebugGroup,
51
52    InsertDebugMarker {
53        color: u32,
54        len: usize,
55    },
56
57    WriteTimestamp {
58        query_set_id: id::QuerySetId,
59        query_index: u32,
60    },
61
62    BeginPipelineStatisticsQuery {
63        query_set_id: id::QuerySetId,
64        query_index: u32,
65    },
66
67    EndPipelineStatisticsQuery,
68}
69
70impl ComputeCommand {
71    /// Resolves all ids in a list of commands into the corresponding resource Arc.
72    #[cfg(any(feature = "serde", feature = "replay"))]
73    pub fn resolve_compute_command_ids(
74        hub: &crate::hub::Hub,
75        commands: &[ComputeCommand],
76    ) -> Result<Vec<ArcComputeCommand>, super::ComputePassError> {
77        use super::{ComputePassError, PassErrorScope};
78
79        let buffers_guard = hub.buffers.read();
80        let bind_group_guard = hub.bind_groups.read();
81        let query_set_guard = hub.query_sets.read();
82        let pipelines_guard = hub.compute_pipelines.read();
83
84        let resolved_commands: Vec<ArcComputeCommand> = commands
85            .iter()
86            .map(|c| -> Result<ArcComputeCommand, ComputePassError> {
87                Ok(match *c {
88                    ComputeCommand::SetBindGroup {
89                        index,
90                        num_dynamic_offsets,
91                        bind_group_id,
92                    } => {
93                        if bind_group_id.is_none() {
94                            return Ok(ArcComputeCommand::SetBindGroup {
95                                index,
96                                num_dynamic_offsets,
97                                bind_group: None,
98                            });
99                        }
100
101                        let bind_group_id = bind_group_id.unwrap();
102                        let bg = bind_group_guard.get(bind_group_id).get().map_err(|e| {
103                            ComputePassError {
104                                scope: PassErrorScope::SetBindGroup,
105                                inner: e.into(),
106                            }
107                        })?;
108
109                        ArcComputeCommand::SetBindGroup {
110                            index,
111                            num_dynamic_offsets,
112                            bind_group: Some(bg),
113                        }
114                    }
115                    ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline(
116                        pipelines_guard
117                            .get(pipeline_id)
118                            .get()
119                            .map_err(|e| ComputePassError {
120                                scope: PassErrorScope::SetPipelineCompute,
121                                inner: e.into(),
122                            })?,
123                    ),
124
125                    ComputeCommand::SetPushConstant {
126                        offset,
127                        size_bytes,
128                        values_offset,
129                    } => ArcComputeCommand::SetPushConstant {
130                        offset,
131                        size_bytes,
132                        values_offset,
133                    },
134
135                    ComputeCommand::Dispatch(dim) => ArcComputeCommand::Dispatch(dim),
136
137                    ComputeCommand::DispatchIndirect { buffer_id, offset } => {
138                        ArcComputeCommand::DispatchIndirect {
139                            buffer: buffers_guard.get(buffer_id).get().map_err(|e| {
140                                ComputePassError {
141                                    scope: PassErrorScope::Dispatch { indirect: true },
142                                    inner: e.into(),
143                                }
144                            })?,
145                            offset,
146                        }
147                    }
148
149                    ComputeCommand::PushDebugGroup { color, len } => {
150                        ArcComputeCommand::PushDebugGroup { color, len }
151                    }
152
153                    ComputeCommand::PopDebugGroup => ArcComputeCommand::PopDebugGroup,
154
155                    ComputeCommand::InsertDebugMarker { color, len } => {
156                        ArcComputeCommand::InsertDebugMarker { color, len }
157                    }
158
159                    ComputeCommand::WriteTimestamp {
160                        query_set_id,
161                        query_index,
162                    } => ArcComputeCommand::WriteTimestamp {
163                        query_set: query_set_guard.get(query_set_id).get().map_err(|e| {
164                            ComputePassError {
165                                scope: PassErrorScope::WriteTimestamp,
166                                inner: e.into(),
167                            }
168                        })?,
169                        query_index,
170                    },
171
172                    ComputeCommand::BeginPipelineStatisticsQuery {
173                        query_set_id,
174                        query_index,
175                    } => ArcComputeCommand::BeginPipelineStatisticsQuery {
176                        query_set: query_set_guard.get(query_set_id).get().map_err(|e| {
177                            ComputePassError {
178                                scope: PassErrorScope::BeginPipelineStatisticsQuery,
179                                inner: e.into(),
180                            }
181                        })?,
182                        query_index,
183                    },
184
185                    ComputeCommand::EndPipelineStatisticsQuery => {
186                        ArcComputeCommand::EndPipelineStatisticsQuery
187                    }
188                })
189            })
190            .collect::<Result<Vec<_>, ComputePassError>>()?;
191        Ok(resolved_commands)
192    }
193}
194
195/// Equivalent to `ComputeCommand` but the Ids resolved into resource Arcs.
196#[derive(Clone, Debug)]
197pub enum ArcComputeCommand {
198    SetBindGroup {
199        index: u32,
200        num_dynamic_offsets: usize,
201        bind_group: Option<Arc<BindGroup>>,
202    },
203
204    SetPipeline(Arc<ComputePipeline>),
205
206    /// Set a range of push constants to values stored in `push_constant_data`.
207    SetPushConstant {
208        /// The byte offset within the push constant storage to write to. This
209        /// must be a multiple of four.
210        offset: u32,
211
212        /// The number of bytes to write. This must be a multiple of four.
213        size_bytes: u32,
214
215        /// Index in `push_constant_data` of the start of the data
216        /// to be written.
217        ///
218        /// Note: this is not a byte offset like `offset`. Rather, it is the
219        /// index of the first `u32` element in `push_constant_data` to read.
220        values_offset: u32,
221    },
222
223    Dispatch([u32; 3]),
224
225    DispatchIndirect {
226        buffer: Arc<Buffer>,
227        offset: wgt::BufferAddress,
228    },
229
230    PushDebugGroup {
231        #[cfg_attr(target_os = "emscripten", allow(dead_code))]
232        color: u32,
233        len: usize,
234    },
235
236    PopDebugGroup,
237
238    InsertDebugMarker {
239        #[cfg_attr(target_os = "emscripten", allow(dead_code))]
240        color: u32,
241        len: usize,
242    },
243
244    WriteTimestamp {
245        query_set: Arc<QuerySet>,
246        query_index: u32,
247    },
248
249    BeginPipelineStatisticsQuery {
250        query_set: Arc<QuerySet>,
251        query_index: u32,
252    },
253
254    EndPipelineStatisticsQuery,
255}