1use std::sync::Arc;
2
3use crate::{
4 binding_model::BindGroup,
5 id,
6 pipeline::ComputePipeline,
7 resource::{Buffer, QuerySet},
8};
9
/// One command recorded into a compute pass, referring to resources by id.
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`. `ComputeCommand::resolve_compute_command_ids` converts a
/// slice of these into [`ArcComputeCommand`]s by looking every id up in the
/// hub's registries.
///
/// NOTE(review): with `serde`, variant and field order may be part of the
/// encoding in index-based formats — keep the declaration order stable.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ComputeCommand {
    /// Sets the bind group at slot `index`.
    SetBindGroup {
        /// Index of the bind group slot being set.
        index: u32,
        /// Number of dynamic offsets that go with this bind group.
        /// NOTE(review): the offset values themselves are not stored in this
        /// variant; presumably they live in a side buffer of the pass — confirm.
        num_dynamic_offsets: usize,
        /// Id of the bind group to use. A `None` id is carried over as
        /// `bind_group: None` when resolved into an `ArcComputeCommand`.
        bind_group_id: Option<id::BindGroupId>,
    },

    /// Sets the compute pipeline, identified by its id.
    SetPipeline(id::ComputePipelineId),

    /// Writes a range of push constants.
    SetPushConstant {
        /// Offset within push constant storage to write to.
        /// NOTE(review): presumably a byte offset that must be a multiple of
        /// four — confirm.
        offset: u32,

        /// Number of bytes to write.
        /// NOTE(review): presumably required to be a multiple of four — confirm.
        size_bytes: u32,

        /// Start of the values to write within the pass's push constant data.
        /// NOTE(review): presumably an index of the first `u32` element in a
        /// side buffer, not a byte offset like `offset` — confirm against the
        /// pass recorder.
        values_offset: u32,
    },

    /// Direct dispatch with three `u32` dimensions
    /// (presumably the x, y, z workgroup counts).
    Dispatch([u32; 3]),

    /// Indirect dispatch whose arguments presumably come from `buffer_id`.
    DispatchIndirect {
        /// Id of the buffer holding the dispatch arguments (presumably).
        buffer_id: id::BufferId,
        /// Position within `buffer_id`, expressed as a `wgt::BufferAddress`.
        offset: wgt::BufferAddress,
    },

    /// Pushes a debug group.
    PushDebugGroup {
        /// Color value associated with the group.
        color: u32,
        /// NOTE(review): presumably the byte length of the group label, which
        /// is stored outside this command — confirm.
        len: usize,
    },

    /// Pops a debug group (presumably the innermost one pushed).
    PopDebugGroup,

    /// Inserts a debug marker.
    InsertDebugMarker {
        /// Color value associated with the marker.
        color: u32,
        /// NOTE(review): presumably the byte length of the marker label, which
        /// is stored outside this command — confirm.
        len: usize,
    },

    /// Writes a timestamp into query `query_index` of query set `query_set_id`.
    WriteTimestamp {
        /// Id of the query set to write into.
        query_set_id: id::QuerySetId,
        /// Index of the query within the set.
        query_index: u32,
    },

    /// Begins a pipeline statistics query at `query_index` of `query_set_id`.
    BeginPipelineStatisticsQuery {
        /// Id of the query set that receives the statistics.
        query_set_id: id::QuerySetId,
        /// Index of the query within the set.
        query_index: u32,
    },

    /// Ends a pipeline statistics query. It carries no ids, so it presumably
    /// ends the query most recently begun.
    EndPipelineStatisticsQuery,
}
69
70impl ComputeCommand {
71 #[cfg(any(feature = "serde", feature = "replay"))]
73 pub fn resolve_compute_command_ids(
74 hub: &crate::hub::Hub,
75 commands: &[ComputeCommand],
76 ) -> Result<Vec<ArcComputeCommand>, super::ComputePassError> {
77 use super::{ComputePassError, PassErrorScope};
78
79 let buffers_guard = hub.buffers.read();
80 let bind_group_guard = hub.bind_groups.read();
81 let query_set_guard = hub.query_sets.read();
82 let pipelines_guard = hub.compute_pipelines.read();
83
84 let resolved_commands: Vec<ArcComputeCommand> = commands
85 .iter()
86 .map(|c| -> Result<ArcComputeCommand, ComputePassError> {
87 Ok(match *c {
88 ComputeCommand::SetBindGroup {
89 index,
90 num_dynamic_offsets,
91 bind_group_id,
92 } => {
93 if bind_group_id.is_none() {
94 return Ok(ArcComputeCommand::SetBindGroup {
95 index,
96 num_dynamic_offsets,
97 bind_group: None,
98 });
99 }
100
101 let bind_group_id = bind_group_id.unwrap();
102 let bg = bind_group_guard.get(bind_group_id).get().map_err(|e| {
103 ComputePassError {
104 scope: PassErrorScope::SetBindGroup,
105 inner: e.into(),
106 }
107 })?;
108
109 ArcComputeCommand::SetBindGroup {
110 index,
111 num_dynamic_offsets,
112 bind_group: Some(bg),
113 }
114 }
115 ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline(
116 pipelines_guard
117 .get(pipeline_id)
118 .get()
119 .map_err(|e| ComputePassError {
120 scope: PassErrorScope::SetPipelineCompute,
121 inner: e.into(),
122 })?,
123 ),
124
125 ComputeCommand::SetPushConstant {
126 offset,
127 size_bytes,
128 values_offset,
129 } => ArcComputeCommand::SetPushConstant {
130 offset,
131 size_bytes,
132 values_offset,
133 },
134
135 ComputeCommand::Dispatch(dim) => ArcComputeCommand::Dispatch(dim),
136
137 ComputeCommand::DispatchIndirect { buffer_id, offset } => {
138 ArcComputeCommand::DispatchIndirect {
139 buffer: buffers_guard.get(buffer_id).get().map_err(|e| {
140 ComputePassError {
141 scope: PassErrorScope::Dispatch { indirect: true },
142 inner: e.into(),
143 }
144 })?,
145 offset,
146 }
147 }
148
149 ComputeCommand::PushDebugGroup { color, len } => {
150 ArcComputeCommand::PushDebugGroup { color, len }
151 }
152
153 ComputeCommand::PopDebugGroup => ArcComputeCommand::PopDebugGroup,
154
155 ComputeCommand::InsertDebugMarker { color, len } => {
156 ArcComputeCommand::InsertDebugMarker { color, len }
157 }
158
159 ComputeCommand::WriteTimestamp {
160 query_set_id,
161 query_index,
162 } => ArcComputeCommand::WriteTimestamp {
163 query_set: query_set_guard.get(query_set_id).get().map_err(|e| {
164 ComputePassError {
165 scope: PassErrorScope::WriteTimestamp,
166 inner: e.into(),
167 }
168 })?,
169 query_index,
170 },
171
172 ComputeCommand::BeginPipelineStatisticsQuery {
173 query_set_id,
174 query_index,
175 } => ArcComputeCommand::BeginPipelineStatisticsQuery {
176 query_set: query_set_guard.get(query_set_id).get().map_err(|e| {
177 ComputePassError {
178 scope: PassErrorScope::BeginPipelineStatisticsQuery,
179 inner: e.into(),
180 }
181 })?,
182 query_index,
183 },
184
185 ComputeCommand::EndPipelineStatisticsQuery => {
186 ArcComputeCommand::EndPipelineStatisticsQuery
187 }
188 })
189 })
190 .collect::<Result<Vec<_>, ComputePassError>>()?;
191 Ok(resolved_commands)
192 }
193}
194
/// Counterpart of [`ComputeCommand`] in which every resource id has been
/// replaced by an `Arc` to the resource itself.
///
/// Variants and fields mirror `ComputeCommand` one-to-one;
/// `ComputeCommand::resolve_compute_command_ids` performs the conversion.
/// Unlike `ComputeCommand`, this type derives only `Clone` and `Debug`;
/// it cannot be `Copy` because it holds `Arc`s.
#[derive(Clone, Debug)]
pub enum ArcComputeCommand {
    /// Sets the bind group at slot `index`.
    SetBindGroup {
        /// Index of the bind group slot being set.
        index: u32,
        /// Number of dynamic offsets that go with this bind group, copied
        /// unchanged from the id-based command.
        num_dynamic_offsets: usize,
        /// The resolved bind group; `None` when the recorded id was `None`.
        bind_group: Option<Arc<BindGroup>>,
    },

    /// Sets the compute pipeline.
    SetPipeline(Arc<ComputePipeline>),

    /// Writes a range of push constants; fields are copied unchanged from
    /// `ComputeCommand::SetPushConstant`.
    SetPushConstant {
        /// Offset within push constant storage to write to.
        /// NOTE(review): presumably a byte offset, multiple of four — confirm.
        offset: u32,

        /// Number of bytes to write.
        size_bytes: u32,

        /// Start of the values to write within the pass's push constant data.
        /// NOTE(review): presumably a `u32` element index, not a byte offset —
        /// confirm.
        values_offset: u32,
    },

    /// Direct dispatch with three `u32` dimensions
    /// (presumably the x, y, z workgroup counts).
    Dispatch([u32; 3]),

    /// Indirect dispatch whose arguments presumably come from `buffer`.
    DispatchIndirect {
        /// The resolved buffer (presumably holding the dispatch arguments).
        buffer: Arc<Buffer>,
        /// Position within `buffer`, expressed as a `wgt::BufferAddress`.
        offset: wgt::BufferAddress,
    },

    /// Pushes a debug group.
    PushDebugGroup {
        /// Color value associated with the group. The `dead_code` lint is
        /// allowed on emscripten targets, presumably because the field is not
        /// read there.
        #[cfg_attr(target_os = "emscripten", allow(dead_code))]
        color: u32,
        /// NOTE(review): presumably the byte length of the label stored
        /// outside this command — confirm.
        len: usize,
    },

    /// Pops a debug group (presumably the innermost one pushed).
    PopDebugGroup,

    /// Inserts a debug marker.
    InsertDebugMarker {
        /// Color value associated with the marker. The `dead_code` lint is
        /// allowed on emscripten targets, presumably because the field is not
        /// read there.
        #[cfg_attr(target_os = "emscripten", allow(dead_code))]
        color: u32,
        /// NOTE(review): presumably the byte length of the label stored
        /// outside this command — confirm.
        len: usize,
    },

    /// Writes a timestamp into query `query_index` of `query_set`.
    WriteTimestamp {
        /// The resolved query set to write into.
        query_set: Arc<QuerySet>,
        /// Index of the query within the set.
        query_index: u32,
    },

    /// Begins a pipeline statistics query at `query_index` of `query_set`.
    BeginPipelineStatisticsQuery {
        /// The resolved query set that receives the statistics.
        query_set: Arc<QuerySet>,
        /// Index of the query within the set.
        query_index: u32,
    },

    /// Ends a pipeline statistics query (presumably the one most recently
    /// begun, since it carries no query set).
    EndPipelineStatisticsQuery,
}