1use std::collections::HashSet;
16
17use crate::expression::{Expression, FunctionCallExpression, InfixExpression, PrefixExpression};
18use crate::instruction::{
19 Arithmetic, ArithmeticOperand, BinaryLogic, BinaryOperand, CallResolutionError, Capture,
20 CircuitDefinition, Comparison, ComparisonOperand, Convert, Delay, Exchange, ExternSignatureMap,
21 Gate, GateDefinition, GateSpecification, Instruction, JumpUnless, JumpWhen, Load,
22 MeasureCalibrationDefinition, Measurement, MemoryReference, Move, Pulse, RawCapture,
23 SetFrequency, SetPhase, SetScale, Sharing, ShiftFrequency, ShiftPhase, Store, UnaryLogic,
24 Vector, WaveformInvocation,
25};
26
27#[derive(Clone, Debug, Hash, PartialEq)]
28pub struct MemoryRegion {
29 pub size: Vector,
30 pub sharing: Option<Sharing>,
31}
32
33impl MemoryRegion {
34 pub fn new(size: Vector, sharing: Option<Sharing>) -> Self {
35 Self { size, sharing }
36 }
37}
38
39impl Eq for MemoryRegion {}
40
41#[derive(Clone, Debug)]
42pub struct MemoryAccess {
43 pub regions: HashSet<String>,
44 pub access_type: MemoryAccessType,
45}
46
/// Names of the memory regions an instruction touches, grouped by the kind of access.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct MemoryAccesses {
    /// Regions that receive captured or measured results.
    pub captures: HashSet<String>,
    /// Regions whose contents are read.
    pub reads: HashSet<String>,
    /// Regions whose contents are written.
    pub writes: HashSet<String>,
}
53
/// The kind of access an instruction makes to a memory region.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MemoryAccessType {
    /// The region's contents are read.
    Read,

    /// The region's contents are written.
    Write,

    /// The region receives a captured or measured result.
    Capture,
}
66
/// Produce a new owned set holding every name present in either input set.
macro_rules! merge_sets {
    ($left:expr, $right:expr) => {
        $left
            .union(&$right)
            .map(String::clone)
            .collect::<HashSet<String>>()
    };
}
72
/// Clone every borrowed name yielded by an iterable into an owned set.
macro_rules! set_from_reference_vec {
    ($vec:expr) => {
        $vec.into_iter().cloned().collect::<HashSet<String>>()
    };
}

/// Build a set containing the region name of an optional memory reference
/// (empty when the reference is absent).
macro_rules! set_from_optional_memory_reference {
    ($reference:expr) => {
        set_from_reference_vec![$reference
            .into_iter()
            .map(|present| &present.name)]
    };
}

/// Build a set of the region names of every memory reference in a collection.
macro_rules! set_from_memory_references {
    ($references:expr) => {
        set_from_reference_vec![$references
            .iter()
            .map(|memory_reference| &memory_reference.name)]
    };
}
95
96#[derive(thiserror::Error, Debug, PartialEq, Clone)]
97pub enum MemoryAccessesError {
98 #[error(transparent)]
99 CallResolution(#[from] CallResolutionError),
100}
101
102pub type MemoryAccessesResult = Result<MemoryAccesses, MemoryAccessesError>;
103
104impl Instruction {
105 pub fn get_memory_accesses(
112 &self,
113 extern_signature_map: &ExternSignatureMap,
114 ) -> MemoryAccessesResult {
115 Ok(match self {
116 Instruction::Convert(Convert {
117 source,
118 destination,
119 }) => MemoryAccesses {
120 reads: set_from_memory_references![[source]],
121 writes: set_from_memory_references![[destination]],
122 ..Default::default()
123 },
124 Instruction::Call(call) => call.get_memory_accesses(extern_signature_map)?,
125 Instruction::Comparison(Comparison {
126 destination,
127 lhs,
128 rhs,
129 operator: _,
130 }) => {
131 let mut reads = HashSet::from([lhs.name.clone()]);
132 let writes = HashSet::from([destination.name.clone()]);
133 if let ComparisonOperand::MemoryReference(mem) = &rhs {
134 reads.insert(mem.name.clone());
135 }
136
137 MemoryAccesses {
138 reads,
139 writes,
140 ..Default::default()
141 }
142 }
143 Instruction::BinaryLogic(BinaryLogic {
144 destination,
145 source,
146 operator: _,
147 }) => {
148 let mut reads = HashSet::new();
149 let mut writes = HashSet::new();
150 reads.insert(destination.name.clone());
151 writes.insert(destination.name.clone());
152 if let BinaryOperand::MemoryReference(mem) = &source {
153 reads.insert(mem.name.clone());
154 }
155
156 MemoryAccesses {
157 reads,
158 writes,
159 ..Default::default()
160 }
161 }
162 Instruction::UnaryLogic(UnaryLogic { operand, .. }) => MemoryAccesses {
163 reads: HashSet::from([operand.name.clone()]),
164 writes: HashSet::from([operand.name.clone()]),
165 ..Default::default()
166 },
167 Instruction::Arithmetic(Arithmetic {
168 destination,
169 source,
170 ..
171 }) => MemoryAccesses {
172 writes: HashSet::from([destination.name.clone()]),
173 reads: set_from_optional_memory_reference![source.get_memory_reference()],
174 ..Default::default()
175 },
176 Instruction::Move(Move {
177 destination,
178 source,
179 }) => MemoryAccesses {
180 writes: set_from_memory_references![[destination]],
181 reads: set_from_optional_memory_reference![source.get_memory_reference()],
182 ..Default::default()
183 },
184 Instruction::CalibrationDefinition(definition) => {
185 let references: Vec<&MemoryReference> = definition
186 .identifier
187 .parameters
188 .iter()
189 .flat_map(|expr| expr.get_memory_references())
190 .collect();
191 MemoryAccesses {
192 reads: set_from_memory_references![references],
193 ..Default::default()
194 }
195 }
196 Instruction::Capture(Capture {
197 memory_reference,
198 waveform,
199 ..
200 }) => MemoryAccesses {
201 captures: set_from_memory_references!([memory_reference]),
202 reads: set_from_memory_references!(waveform.get_memory_references()),
203 ..Default::default()
204 },
205 Instruction::CircuitDefinition(CircuitDefinition { instructions, .. })
206 | Instruction::MeasureCalibrationDefinition(MeasureCalibrationDefinition {
207 instructions,
208 ..
209 }) => instructions.iter().try_fold(
210 Default::default(),
211 |acc: MemoryAccesses, el| -> MemoryAccessesResult {
212 let el_accesses = el.get_memory_accesses(extern_signature_map)?;
213 Ok(MemoryAccesses {
214 reads: merge_sets!(acc.reads, el_accesses.reads),
215 writes: merge_sets!(acc.writes, el_accesses.writes),
216 captures: merge_sets!(acc.captures, el_accesses.captures),
217 })
218 },
219 )?,
220 Instruction::Delay(Delay { duration, .. }) => MemoryAccesses {
221 reads: set_from_memory_references!(duration.get_memory_references()),
222 ..Default::default()
223 },
224 Instruction::Exchange(Exchange { left, right }) => MemoryAccesses {
225 reads: set_from_memory_references![[left, right]],
226 writes: set_from_memory_references![[left, right]],
227 ..Default::default()
228 },
229 Instruction::Gate(Gate { parameters, .. }) => MemoryAccesses {
230 reads: set_from_memory_references!(parameters
231 .iter()
232 .flat_map(|param| param.get_memory_references())
233 .collect::<Vec<&MemoryReference>>()),
234 ..Default::default()
235 },
236 Instruction::GateDefinition(GateDefinition { specification, .. }) => {
237 if let GateSpecification::Matrix(matrix) = specification {
238 let references = matrix
239 .iter()
240 .flat_map(|row| row.iter().flat_map(|cell| cell.get_memory_references()))
241 .collect::<Vec<&MemoryReference>>();
242 MemoryAccesses {
243 reads: set_from_memory_references!(references),
244 ..Default::default()
245 }
246 } else {
247 Default::default()
248 }
249 }
250 Instruction::JumpWhen(JumpWhen {
251 target: _,
252 condition,
253 })
254 | Instruction::JumpUnless(JumpUnless {
255 target: _,
256 condition,
257 }) => MemoryAccesses {
258 reads: set_from_memory_references!([condition]),
259 ..Default::default()
260 },
261 Instruction::Load(Load {
262 destination,
263 source,
264 offset,
265 }) => MemoryAccesses {
266 writes: set_from_memory_references![[destination]],
267 reads: set_from_reference_vec![vec![source, &offset.name]],
268 ..Default::default()
269 },
270 Instruction::Measurement(Measurement { target, .. }) => MemoryAccesses {
271 captures: set_from_optional_memory_reference!(target.as_ref()),
272 ..Default::default()
273 },
274 Instruction::Pulse(Pulse { waveform, .. }) => MemoryAccesses {
275 reads: set_from_memory_references![waveform.get_memory_references()],
276 ..Default::default()
277 },
278 Instruction::RawCapture(RawCapture {
279 duration,
280 memory_reference,
281 ..
282 }) => MemoryAccesses {
283 reads: set_from_memory_references![duration.get_memory_references()],
284 captures: set_from_memory_references![[memory_reference]],
285 ..Default::default()
286 },
287 Instruction::SetPhase(SetPhase { phase: expr, .. })
288 | Instruction::SetScale(SetScale { scale: expr, .. })
289 | Instruction::ShiftPhase(ShiftPhase { phase: expr, .. }) => MemoryAccesses {
290 reads: set_from_memory_references!(expr.get_memory_references()),
291 ..Default::default()
292 },
293 Instruction::SetFrequency(SetFrequency { frequency, .. })
294 | Instruction::ShiftFrequency(ShiftFrequency { frequency, .. }) => MemoryAccesses {
295 reads: set_from_memory_references!(frequency.get_memory_references()),
296 ..Default::default()
297 },
298 Instruction::Store(Store {
299 destination,
300 offset,
301 source,
302 }) => {
303 let mut reads = vec![&offset.name];
304 if let Some(source) = source.get_memory_reference() {
305 reads.push(&source.name);
306 }
307 MemoryAccesses {
308 reads: set_from_reference_vec![reads],
309 writes: set_from_reference_vec![vec![destination]],
310 ..Default::default()
311 }
312 }
313 Instruction::Declaration(_)
314 | Instruction::Fence(_)
315 | Instruction::FrameDefinition(_)
316 | Instruction::Halt
317 | Instruction::Wait
318 | Instruction::Include(_)
319 | Instruction::Jump(_)
320 | Instruction::Label(_)
321 | Instruction::Nop
322 | Instruction::Pragma(_)
323 | Instruction::Reset(_)
324 | Instruction::SwapPhases(_)
325 | Instruction::WaveformDefinition(_) => Default::default(),
326 })
327 }
328}
329
330impl ArithmeticOperand {
331 pub fn get_memory_reference(&self) -> Option<&MemoryReference> {
332 match self {
333 ArithmeticOperand::LiteralInteger(_) => None,
334 ArithmeticOperand::LiteralReal(_) => None,
335 ArithmeticOperand::MemoryReference(reference) => Some(reference),
336 }
337 }
338}
339
340impl Expression {
341 pub fn get_memory_references(&self) -> Vec<&MemoryReference> {
343 match self {
344 Expression::Address(reference) => vec![reference],
345 Expression::FunctionCall(FunctionCallExpression { expression, .. }) => {
346 expression.get_memory_references()
347 }
348 Expression::Infix(InfixExpression { left, right, .. }) => {
349 let mut result = left.get_memory_references();
350 result.extend(right.get_memory_references());
351 result
352 }
353 Expression::Number(_) => vec![],
354 Expression::PiConstant => vec![],
355 Expression::Prefix(PrefixExpression { expression, .. }) => {
356 expression.get_memory_references()
357 }
358 Expression::Variable(_) => vec![],
359 }
360 }
361}
362
363impl WaveformInvocation {
364 pub fn get_memory_references(&self) -> Vec<&MemoryReference> {
366 self.parameters
367 .values()
368 .flat_map(Expression::get_memory_references)
369 .collect()
370 }
371}
372
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rstest::rstest;

    use crate::expression::Expression;
    use crate::instruction::{
        ArithmeticOperand, Convert, Exchange, ExternSignatureMap, FrameIdentifier, Instruction,
        MemoryReference, Qubit, SetFrequency, ShiftFrequency, Store,
    };
    use crate::program::MemoryAccesses;

    /// Build an owned set of region names from string slices.
    fn regions(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// A reference to the named region, using the default index.
    fn mem_ref(name: &str) -> MemoryReference {
        MemoryReference {
            name: name.to_string(),
            index: Default::default(),
        }
    }

    /// The frame named `frame` on fixed qubit 0, shared by the frequency cases.
    fn frame_on_qubit_zero() -> FrameIdentifier {
        FrameIdentifier {
            name: "frame".to_string(),
            qubits: vec![Qubit::Fixed(0)],
        }
    }

    #[rstest]
    #[case(
        Instruction::Store(Store {
            destination: "destination".to_string(),
            offset: mem_ref("offset"),
            source: ArithmeticOperand::MemoryReference(mem_ref("source")),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: regions(&["source", "offset"]),
            writes: regions(&["destination"]),
        }
    )]
    #[case(
        Instruction::Convert(Convert {
            destination: mem_ref("destination"),
            source: mem_ref("source"),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: regions(&["source"]),
            writes: regions(&["destination"]),
        }
    )]
    #[case(
        Instruction::Exchange(Exchange {
            left: mem_ref("left"),
            right: mem_ref("right"),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: regions(&["left", "right"]),
            writes: regions(&["left", "right"]),
        }
    )]
    #[case(
        Instruction::SetFrequency(SetFrequency {
            frequency: Expression::Address(mem_ref("frequency")),
            frame: frame_on_qubit_zero(),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: regions(&["frequency"]),
            writes: HashSet::new(),
        }
    )]
    #[case(
        Instruction::ShiftFrequency(ShiftFrequency {
            frequency: Expression::Address(mem_ref("frequency")),
            frame: frame_on_qubit_zero(),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: regions(&["frequency"]),
            writes: HashSet::new(),
        }
    )]
    fn test_instruction_accesses(
        #[case] instruction: Instruction,
        #[case] expected: MemoryAccesses,
    ) {
        let actual = instruction
            .get_memory_accesses(&ExternSignatureMap::default())
            .expect("must be able to get memory accesses");
        // Comparing the whole struct checks captures, reads, and writes together.
        assert_eq!(actual, expected);
    }
}