quil_rs/program/
memory.rs

1// Copyright 2021 Rigetti Computing
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17use crate::expression::{Expression, FunctionCallExpression, InfixExpression, PrefixExpression};
18use crate::instruction::{
19    Arithmetic, ArithmeticOperand, BinaryLogic, BinaryOperand, CallResolutionError, Capture,
20    CircuitDefinition, Comparison, ComparisonOperand, Convert, Delay, Exchange, ExternSignatureMap,
21    Gate, GateDefinition, GateSpecification, Instruction, JumpUnless, JumpWhen, Load,
22    MeasureCalibrationDefinition, Measurement, MemoryReference, Move, Pulse, RawCapture,
23    SetFrequency, SetPhase, SetScale, Sharing, ShiftFrequency, ShiftPhase, Store, UnaryLogic,
24    Vector, WaveformInvocation,
25};
26
27#[derive(Clone, Debug, Hash, PartialEq)]
28pub struct MemoryRegion {
29    pub size: Vector,
30    pub sharing: Option<Sharing>,
31}
32
33impl MemoryRegion {
34    pub fn new(size: Vector, sharing: Option<Sharing>) -> Self {
35        Self { size, sharing }
36    }
37}
38
39impl Eq for MemoryRegion {}
40
/// A set of memory regions that are all accessed in the same way.
#[derive(Clone, Debug)]
pub struct MemoryAccess {
    /// Names of the memory regions accessed.
    pub regions: HashSet<String>,
    /// How the regions are accessed.
    pub access_type: MemoryAccessType,
}

/// Every memory region touched by an instruction, grouped by kind of access.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct MemoryAccesses {
    /// Regions written by readout (see [`MemoryAccessType::Capture`]).
    pub captures: HashSet<String>,
    /// Regions read from.
    pub reads: HashSet<String>,
    /// Regions written by classical instructions.
    pub writes: HashSet<String>,
}

/// The kind of access an instruction makes to a memory location.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MemoryAccessType {
    /// Read from a memory location
    Read,

    /// Write to a memory location using classical instructions
    Write,

    /// Write to a memory location using readout (`CAPTURE` and `RAW-CAPTURE` instructions)
    Capture,
}
66
/// Merge two owned `HashSet<String>`s into one, consuming both.
///
/// The left set is reused as the accumulator and the right set's strings are moved into it,
/// so, unlike `union(..).cloned()`, no string is cloned and no fresh set is allocated. This
/// matters when the macro runs once per instruction inside a fold.
macro_rules! merge_sets {
    ($left:expr, $right:expr) => {{
        let mut merged: HashSet<String> = $left;
        merged.extend($right);
        merged
    }};
}
72
/// Build a `HashSet<String>` by cloning every `&String` yielded by `$vec`.
///
/// `$vec` may be anything implementing `IntoIterator<Item = &String>` (a `Vec<&String>` or an
/// iterator adaptor, for example); duplicate names collapse into a single entry.
macro_rules! set_from_reference_vec {
    ($vec:expr) => {
        $vec.into_iter().cloned().collect::<HashSet<String>>()
    };
}
81
/// Build a `HashSet<String>` from an `Option<&MemoryReference>`: a one-element set holding the
/// region name for `Some`, an empty set for `None`.
macro_rules! set_from_optional_memory_reference {
    ($reference:expr) => {
        match $reference {
            Some(reference) => set_from_reference_vec![vec![&reference.name]],
            None => HashSet::new(),
        }
    };
}
88
/// Build a `HashSet<String>` holding the region names of a collection of memory references.
///
/// `$references` may be anything with an `.iter()` method yielding (possibly nested)
/// references to values with a `name: String` field — e.g. `[&MemoryReference; N]` or
/// `Vec<&MemoryReference>`. Names are cloned; duplicates collapse into a single entry.
macro_rules! set_from_memory_references {
    ($references:expr) => {
        $references
            .iter()
            .map(|reference| reference.name.clone())
            .collect::<HashSet<String>>()
    };
}
95
96#[derive(thiserror::Error, Debug, PartialEq, Clone)]
97pub enum MemoryAccessesError {
98    #[error(transparent)]
99    CallResolution(#[from] CallResolutionError),
100}
101
102pub type MemoryAccessesResult = Result<MemoryAccesses, MemoryAccessesError>;
103
104impl Instruction {
105    /// Return all memory accesses by the instruction - in expressions, captures, and memory manipulation.
106    ///
107    /// This will fail if the program contains [`Instruction::Call`] instructions that cannot
108    /// be resolved against a signature in the provided [`ExternSignatureMap`] (either because
109    /// they call functions that don't appear in the map or because the types of the parameters
110    /// are wrong).
111    pub fn get_memory_accesses(
112        &self,
113        extern_signature_map: &ExternSignatureMap,
114    ) -> MemoryAccessesResult {
115        Ok(match self {
116            Instruction::Convert(Convert {
117                source,
118                destination,
119            }) => MemoryAccesses {
120                reads: set_from_memory_references![[source]],
121                writes: set_from_memory_references![[destination]],
122                ..Default::default()
123            },
124            Instruction::Call(call) => call.get_memory_accesses(extern_signature_map)?,
125            Instruction::Comparison(Comparison {
126                destination,
127                lhs,
128                rhs,
129                operator: _,
130            }) => {
131                let mut reads = HashSet::from([lhs.name.clone()]);
132                let writes = HashSet::from([destination.name.clone()]);
133                if let ComparisonOperand::MemoryReference(mem) = &rhs {
134                    reads.insert(mem.name.clone());
135                }
136
137                MemoryAccesses {
138                    reads,
139                    writes,
140                    ..Default::default()
141                }
142            }
143            Instruction::BinaryLogic(BinaryLogic {
144                destination,
145                source,
146                operator: _,
147            }) => {
148                let mut reads = HashSet::new();
149                let mut writes = HashSet::new();
150                reads.insert(destination.name.clone());
151                writes.insert(destination.name.clone());
152                if let BinaryOperand::MemoryReference(mem) = &source {
153                    reads.insert(mem.name.clone());
154                }
155
156                MemoryAccesses {
157                    reads,
158                    writes,
159                    ..Default::default()
160                }
161            }
162            Instruction::UnaryLogic(UnaryLogic { operand, .. }) => MemoryAccesses {
163                reads: HashSet::from([operand.name.clone()]),
164                writes: HashSet::from([operand.name.clone()]),
165                ..Default::default()
166            },
167            Instruction::Arithmetic(Arithmetic {
168                destination,
169                source,
170                ..
171            }) => MemoryAccesses {
172                writes: HashSet::from([destination.name.clone()]),
173                reads: set_from_optional_memory_reference![source.get_memory_reference()],
174                ..Default::default()
175            },
176            Instruction::Move(Move {
177                destination,
178                source,
179            }) => MemoryAccesses {
180                writes: set_from_memory_references![[destination]],
181                reads: set_from_optional_memory_reference![source.get_memory_reference()],
182                ..Default::default()
183            },
184            Instruction::CalibrationDefinition(definition) => {
185                let references: Vec<&MemoryReference> = definition
186                    .identifier
187                    .parameters
188                    .iter()
189                    .flat_map(|expr| expr.get_memory_references())
190                    .collect();
191                MemoryAccesses {
192                    reads: set_from_memory_references![references],
193                    ..Default::default()
194                }
195            }
196            Instruction::Capture(Capture {
197                memory_reference,
198                waveform,
199                ..
200            }) => MemoryAccesses {
201                captures: set_from_memory_references!([memory_reference]),
202                reads: set_from_memory_references!(waveform.get_memory_references()),
203                ..Default::default()
204            },
205            Instruction::CircuitDefinition(CircuitDefinition { instructions, .. })
206            | Instruction::MeasureCalibrationDefinition(MeasureCalibrationDefinition {
207                instructions,
208                ..
209            }) => instructions.iter().try_fold(
210                Default::default(),
211                |acc: MemoryAccesses, el| -> MemoryAccessesResult {
212                    let el_accesses = el.get_memory_accesses(extern_signature_map)?;
213                    Ok(MemoryAccesses {
214                        reads: merge_sets!(acc.reads, el_accesses.reads),
215                        writes: merge_sets!(acc.writes, el_accesses.writes),
216                        captures: merge_sets!(acc.captures, el_accesses.captures),
217                    })
218                },
219            )?,
220            Instruction::Delay(Delay { duration, .. }) => MemoryAccesses {
221                reads: set_from_memory_references!(duration.get_memory_references()),
222                ..Default::default()
223            },
224            Instruction::Exchange(Exchange { left, right }) => MemoryAccesses {
225                reads: set_from_memory_references![[left, right]],
226                writes: set_from_memory_references![[left, right]],
227                ..Default::default()
228            },
229            Instruction::Gate(Gate { parameters, .. }) => MemoryAccesses {
230                reads: set_from_memory_references!(parameters
231                    .iter()
232                    .flat_map(|param| param.get_memory_references())
233                    .collect::<Vec<&MemoryReference>>()),
234                ..Default::default()
235            },
236            Instruction::GateDefinition(GateDefinition { specification, .. }) => {
237                if let GateSpecification::Matrix(matrix) = specification {
238                    let references = matrix
239                        .iter()
240                        .flat_map(|row| row.iter().flat_map(|cell| cell.get_memory_references()))
241                        .collect::<Vec<&MemoryReference>>();
242                    MemoryAccesses {
243                        reads: set_from_memory_references!(references),
244                        ..Default::default()
245                    }
246                } else {
247                    Default::default()
248                }
249            }
250            Instruction::JumpWhen(JumpWhen {
251                target: _,
252                condition,
253            })
254            | Instruction::JumpUnless(JumpUnless {
255                target: _,
256                condition,
257            }) => MemoryAccesses {
258                reads: set_from_memory_references!([condition]),
259                ..Default::default()
260            },
261            Instruction::Load(Load {
262                destination,
263                source,
264                offset,
265            }) => MemoryAccesses {
266                writes: set_from_memory_references![[destination]],
267                reads: set_from_reference_vec![vec![source, &offset.name]],
268                ..Default::default()
269            },
270            Instruction::Measurement(Measurement { target, .. }) => MemoryAccesses {
271                captures: set_from_optional_memory_reference!(target.as_ref()),
272                ..Default::default()
273            },
274            Instruction::Pulse(Pulse { waveform, .. }) => MemoryAccesses {
275                reads: set_from_memory_references![waveform.get_memory_references()],
276                ..Default::default()
277            },
278            Instruction::RawCapture(RawCapture {
279                duration,
280                memory_reference,
281                ..
282            }) => MemoryAccesses {
283                reads: set_from_memory_references![duration.get_memory_references()],
284                captures: set_from_memory_references![[memory_reference]],
285                ..Default::default()
286            },
287            Instruction::SetPhase(SetPhase { phase: expr, .. })
288            | Instruction::SetScale(SetScale { scale: expr, .. })
289            | Instruction::ShiftPhase(ShiftPhase { phase: expr, .. }) => MemoryAccesses {
290                reads: set_from_memory_references!(expr.get_memory_references()),
291                ..Default::default()
292            },
293            Instruction::SetFrequency(SetFrequency { frequency, .. })
294            | Instruction::ShiftFrequency(ShiftFrequency { frequency, .. }) => MemoryAccesses {
295                reads: set_from_memory_references!(frequency.get_memory_references()),
296                ..Default::default()
297            },
298            Instruction::Store(Store {
299                destination,
300                offset,
301                source,
302            }) => {
303                let mut reads = vec![&offset.name];
304                if let Some(source) = source.get_memory_reference() {
305                    reads.push(&source.name);
306                }
307                MemoryAccesses {
308                    reads: set_from_reference_vec![reads],
309                    writes: set_from_reference_vec![vec![destination]],
310                    ..Default::default()
311                }
312            }
313            Instruction::Declaration(_)
314            | Instruction::Fence(_)
315            | Instruction::FrameDefinition(_)
316            | Instruction::Halt
317            | Instruction::Wait
318            | Instruction::Include(_)
319            | Instruction::Jump(_)
320            | Instruction::Label(_)
321            | Instruction::Nop
322            | Instruction::Pragma(_)
323            | Instruction::Reset(_)
324            | Instruction::SwapPhases(_)
325            | Instruction::WaveformDefinition(_) => Default::default(),
326        })
327    }
328}
329
330impl ArithmeticOperand {
331    pub fn get_memory_reference(&self) -> Option<&MemoryReference> {
332        match self {
333            ArithmeticOperand::LiteralInteger(_) => None,
334            ArithmeticOperand::LiteralReal(_) => None,
335            ArithmeticOperand::MemoryReference(reference) => Some(reference),
336        }
337    }
338}
339
340impl Expression {
341    /// Return, if any, the memory references contained within this Expression.
342    pub fn get_memory_references(&self) -> Vec<&MemoryReference> {
343        match self {
344            Expression::Address(reference) => vec![reference],
345            Expression::FunctionCall(FunctionCallExpression { expression, .. }) => {
346                expression.get_memory_references()
347            }
348            Expression::Infix(InfixExpression { left, right, .. }) => {
349                let mut result = left.get_memory_references();
350                result.extend(right.get_memory_references());
351                result
352            }
353            Expression::Number(_) => vec![],
354            Expression::PiConstant => vec![],
355            Expression::Prefix(PrefixExpression { expression, .. }) => {
356                expression.get_memory_references()
357            }
358            Expression::Variable(_) => vec![],
359        }
360    }
361}
362
363impl WaveformInvocation {
364    /// Return, if any, the memory references contained within this WaveformInvocation.
365    pub fn get_memory_references(&self) -> Vec<&MemoryReference> {
366        self.parameters
367            .values()
368            .flat_map(Expression::get_memory_references)
369            .collect()
370    }
371}
372
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rstest::rstest;

    use crate::expression::Expression;
    use crate::instruction::{
        ArithmeticOperand, Convert, Exchange, ExternSignatureMap, FrameIdentifier, Instruction,
        MemoryReference, Qubit, SetFrequency, ShiftFrequency, Store,
    };
    use crate::program::MemoryAccesses;

    /// A reference to the named region at the default index.
    fn memory_reference(name: &str) -> MemoryReference {
        MemoryReference {
            name: name.to_string(),
            index: Default::default(),
        }
    }

    /// The frame named `frame` on fixed qubit 0.
    fn frame_on_qubit_zero() -> FrameIdentifier {
        FrameIdentifier {
            name: "frame".to_string(),
            qubits: vec![Qubit::Fixed(0)],
        }
    }

    /// Collect region names into an owned set.
    fn region_names(names: &[&str]) -> HashSet<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    #[rstest]
    #[case(
        Instruction::Store(Store {
            destination: "destination".to_string(),
            offset: memory_reference("offset"),
            source: ArithmeticOperand::MemoryReference(memory_reference("source")),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_names(&["source", "offset"]),
            writes: region_names(&["destination"]),
        }
    )]
    #[case(
        Instruction::Convert(Convert {
            destination: memory_reference("destination"),
            source: memory_reference("source"),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_names(&["source"]),
            writes: region_names(&["destination"]),
        }
    )]
    #[case(
        Instruction::Exchange(Exchange {
            left: memory_reference("left"),
            right: memory_reference("right"),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_names(&["left", "right"]),
            writes: region_names(&["left", "right"]),
        }
    )]
    #[case(
        Instruction::SetFrequency(SetFrequency {
            frequency: Expression::Address(memory_reference("frequency")),
            frame: frame_on_qubit_zero(),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_names(&["frequency"]),
            writes: HashSet::new(),
        }
    )]
    #[case(
        Instruction::ShiftFrequency(ShiftFrequency {
            frequency: Expression::Address(memory_reference("frequency")),
            frame: frame_on_qubit_zero(),
        }),
        MemoryAccesses {
            captures: HashSet::new(),
            reads: region_names(&["frequency"]),
            writes: HashSet::new(),
        }
    )]
    fn test_instruction_accesses(
        #[case] instruction: Instruction,
        #[case] expected: MemoryAccesses,
    ) {
        let actual = instruction
            .get_memory_accesses(&ExternSignatureMap::default())
            .expect("must be able to get memory accesses");
        // Compare field by field so a failure names the set that differs.
        assert_eq!(actual.captures, expected.captures);
        assert_eq!(actual.reads, expected.reads);
        assert_eq!(actual.writes, expected.writes);
    }
}