quil_rs/instruction/
gate.rs

1use crate::{
2    expression::Expression,
3    imag,
4    instruction::{write_expression_parameter_string, write_parameter_string, write_qubits, Qubit},
5    quil::{write_join_quil, Quil, INDENT},
6    real,
7    validation::identifier::{
8        validate_identifier, validate_user_identifier, IdentifierValidationError,
9    },
10};
11use ndarray::{array, linalg::kron, Array2};
12use num_complex::Complex64;
13use once_cell::sync::Lazy;
14use std::{
15    cmp::Ordering,
16    collections::{HashMap, HashSet},
17};
18
/// A struct encapsulating all the properties of a Quil Quantum Gate.
///
/// Qubits consumed by `CONTROLLED` and `FORKED` modifiers are stored at the front of
/// [`Gate::qubits`], in the same order as the corresponding entries of [`Gate::modifiers`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Gate {
    /// The gate's name, e.g. `RX` or `CNOT`.
    pub name: String,
    /// The gate's parameters. Each `FORKED` modifier doubles this list: the first half is used
    /// when the fork qubit is `|0>` and the second half when it is `|1>`.
    pub parameters: Vec<Expression>,
    /// The qubits the gate acts on, including the leading qubits consumed by modifiers.
    pub qubits: Vec<Qubit>,
    /// The modifiers applied to the gate, outermost (leftmost when written as Quil) first.
    pub modifiers: Vec<GateModifier>,
}
27
/// An enum of all the possible modifiers on a quil [`Gate`].
///
/// Modifiers are written before the gate name, e.g. `CONTROLLED DAGGER RX(pi) 0 1`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum GateModifier {
    /// The `CONTROLLED` modifier makes the gate take an extra [`Qubit`] parameter as a control
    /// qubit. The control qubit precedes the qubits of the gate being modified.
    Controlled,
    /// The `DAGGER` modifier does a complex-conjugate transpose on the [`Gate`]. It consumes no
    /// extra qubits or parameters.
    Dagger,
    /// The `FORKED` modifier allows an alternate set of parameters to be used based on the state
    /// of a qubit. It takes an extra leading qubit and doubles the gate's parameter list.
    Forked,
}
40
41#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)]
42pub enum GateError {
43    #[error("invalid name: {0}")]
44    InvalidIdentifier(#[from] IdentifierValidationError),
45
46    #[error("a gate must operate on 1 or more qubits")]
47    EmptyQubits,
48
49    #[error("expected {expected} parameters, but got {actual}")]
50    ForkedParameterLength { expected: usize, actual: usize },
51
52    #[error("expected the number of Pauli term arguments, {actual}, to match the length of the Pauli word, {expected}")]
53    PauliTermArgumentLength { expected: usize, actual: usize },
54
55    #[error("the Pauli term arguments {mismatches:?}, are not in the defined argument list: {expected_arguments:?}")]
56    PauliSumArgumentMismatch {
57        mismatches: Vec<String>,
58        expected_arguments: Vec<String>,
59    },
60
61    #[error("unknown gate `{name}` to turn into {} matrix ",  if *parameterized { "parameterized" } else { "constant" })]
62    UndefinedGate { name: String, parameterized: bool },
63
64    #[error("expected {expected} parameters, was given {actual}")]
65    MatrixArgumentLength { expected: usize, actual: usize },
66
67    #[error(
68        "cannot produce a matrix for a gate `{name}` with non-constant parameters {parameters:?}"
69    )]
70    MatrixNonConstantParams {
71        name: String,
72        parameters: Vec<Expression>,
73    },
74
75    #[error("cannot produce a matrix for gate `{name}` with variable qubit {qubit}", qubit=.qubit.to_quil_or_debug())]
76    MatrixVariableQubit { name: String, qubit: Qubit },
77
78    #[error("forked gate `{name}` has an odd number of parameters: {parameters:?}")]
79    ForkedGateOddNumParams {
80        name: String,
81        parameters: Vec<Expression>,
82    },
83
84    #[error("cannot produce a matrix for a gate `{name}` with unresolved qubit placeholders")]
85    UnresolvedQubitPlaceholder { name: String },
86}
87
/// Matrix version of a gate: a dense two-dimensional array of complex numbers, as produced by
/// [`Gate::to_unitary`].
pub type Matrix = Array2<Complex64>;
90
91impl Gate {
92    /// Build a new gate
93    ///
94    /// # Errors
95    ///
96    /// Returns an error if the given name isn't a valid Quil identifier or if no qubits are given.
97    pub fn new(
98        name: &str,
99        parameters: Vec<Expression>,
100        qubits: Vec<Qubit>,
101        modifiers: Vec<GateModifier>,
102    ) -> Result<Self, GateError> {
103        if qubits.is_empty() {
104            return Err(GateError::EmptyQubits);
105        }
106
107        validate_identifier(name).map_err(GateError::InvalidIdentifier)?;
108
109        Ok(Self {
110            name: name.to_string(),
111            parameters,
112            qubits,
113            modifiers,
114        })
115    }
116
117    /// Apply a DAGGER modifier to the gate
118    pub fn dagger(mut self) -> Self {
119        self.modifiers.insert(0, GateModifier::Dagger);
120        self
121    }
122
123    /// Apply a CONTROLLED modifier to the gate
124    pub fn controlled(mut self, control_qubit: Qubit) -> Self {
125        self.qubits.insert(0, control_qubit);
126        self.modifiers.insert(0, GateModifier::Controlled);
127        self
128    }
129
130    /// Apply a FORKED modifier to the gate
131    ///
132    /// # Errors
133    ///
134    /// Returns an error if the number of provided alternate parameters don't
135    /// equal the number of existing parameters.
136    pub fn forked(
137        mut self,
138        fork_qubit: Qubit,
139        alt_params: Vec<Expression>,
140    ) -> Result<Self, GateError> {
141        if alt_params.len() != self.parameters.len() {
142            return Err(GateError::ForkedParameterLength {
143                expected: self.parameters.len(),
144                actual: alt_params.len(),
145            });
146        }
147        self.modifiers.insert(0, GateModifier::Forked);
148        self.qubits.insert(0, fork_qubit);
149        self.parameters.extend(alt_params);
150        Ok(self)
151    }
152
153    /// Lift a Gate to the full `n_qubits`-qubit Hilbert space.
154    ///
155    /// # Errors
156    ///
157    /// Returns an error if any of the parameters of this gate are non-constant, if any of the
158    /// qubits are variable, if the name of this gate is unknown, or if there are an unexpected
159    /// number of parameters.
160    pub fn to_unitary(&mut self, n_qubits: u64) -> Result<Matrix, GateError> {
161        let qubits = self
162            .qubits
163            .iter()
164            .map(|q| match q {
165                Qubit::Variable(_) => Err(GateError::MatrixVariableQubit {
166                    name: self.name.clone(),
167                    qubit: q.clone(),
168                }),
169                Qubit::Placeholder(_) => Err(GateError::UnresolvedQubitPlaceholder {
170                    name: self.name.clone(),
171                }),
172                Qubit::Fixed(i) => Ok(*i),
173            })
174            .collect::<Result<Vec<_>, _>>()?;
175        Ok(lifted_gate_matrix(&gate_matrix(self)?, &qubits, n_qubits))
176    }
177}
178
179/// Lift a unitary matrix to act on the specified qubits in a full `n_qubits`-qubit Hilbert
180/// space.
181///
182/// For 1-qubit gates, this is easy and can be achieved with appropriate kronning of identity
183/// matrices. For 2-qubit gates acting on adjacent qubit indices, it is also easy. However, for a
184/// multiqubit gate acting on non-adjactent qubit indices, we must first apply a permutation matrix
185/// to make the qubits adjacent and then apply the inverse permutation.
186fn lifted_gate_matrix(matrix: &Matrix, qubits: &[u64], n_qubits: u64) -> Matrix {
187    let (perm, start) = permutation_arbitrary(qubits, n_qubits);
188    let v = qubit_adjacent_lifted_gate(start, matrix, n_qubits);
189    perm.t().mapv(|c| c.conj()).dot(&v.dot(&perm))
190}
191
192/// Recursively handle a gate, with all modifiers.
193///
194/// The main source of complexity is in handling FORKED gates. Given a gate with modifiers, such as
195/// `FORKED CONTROLLED FORKED RX(a,b,c,d) 0 1 2 3`, we get a tree, as in
196///
197/// ```text
198///
199///               FORKED CONTROLLED FORKED RX(a,b,c,d) 0 1 2 3
200///                 /                                      \
201///    CONTROLLED FORKED RX(a,b) 1 2 3       CONTROLLED FORKED RX(c,d) 1 2 3
202///                |                                        |
203///         FORKED RX(a,b) 2 3                      FORKED RX(c,d) 2 3
204///          /          \                            /          \
205///      RX(a) 3      RX(b) 3                    RX(c) 3      RX(d) 3
206/// ```
207fn gate_matrix(gate: &mut Gate) -> Result<Matrix, GateError> {
208    static ZERO: Lazy<Matrix> =
209        Lazy::new(|| array![[real!(1.0), real!(0.0)], [real!(0.0), real!(0.0)]]);
210    static ONE: Lazy<Matrix> =
211        Lazy::new(|| array![[real!(0.0), real!(0.0)], [real!(0.0), real!(1.0)]]);
212    if let Some(modifier) = gate.modifiers.pop() {
213        match modifier {
214            GateModifier::Controlled => {
215                gate.qubits = gate.qubits[1..].to_vec();
216                let matrix = gate_matrix(gate)?;
217                Ok(kron(&ZERO, &Array2::eye(matrix.shape()[0])) + kron(&ONE, &matrix))
218            }
219            GateModifier::Dagger => gate_matrix(gate).map(|g| g.t().mapv(|c| c.conj())),
220            GateModifier::Forked => {
221                let param_index = gate.parameters.len();
222                if param_index & 1 != 0 {
223                    Err(GateError::ForkedGateOddNumParams {
224                        name: gate.name.clone(),
225                        parameters: gate.parameters.clone(),
226                    })
227                } else {
228                    // Some mutability dancing to keep the borrow checker happy
229                    gate.qubits = gate.qubits[1..].to_vec();
230                    let (p0, p1) = gate.parameters[..].split_at(param_index / 2);
231                    let mut child0 = gate.clone();
232                    child0.parameters = p0.to_vec();
233                    let mat0 = gate_matrix(&mut child0)?;
234                    gate.parameters = p1.to_vec();
235                    let mat1 = gate_matrix(gate)?;
236                    Ok(kron(&ZERO, &mat0) + kron(&ONE, &mat1))
237                }
238            }
239        }
240    } else if gate.parameters.is_empty() {
241        CONSTANT_GATE_MATRICES
242            .get(&gate.name)
243            .cloned()
244            .ok_or_else(|| GateError::UndefinedGate {
245                name: gate.name.clone(),
246                parameterized: false,
247            })
248    } else {
249        match gate.parameters.len() {
250            1 => {
251                if let Expression::Number(x) = gate.parameters[0].clone().into_simplified() {
252                    PARAMETERIZED_GATE_MATRICES
253                        .get(&gate.name)
254                        .map(|f| f(x))
255                        .ok_or_else(|| GateError::UndefinedGate {
256                            name: gate.name.clone(),
257                            parameterized: true,
258                        })
259                } else {
260                    Err(GateError::MatrixNonConstantParams {
261                        name: gate.name.clone(),
262                        parameters: gate.parameters.clone(),
263                    })
264                }
265            }
266            actual => Err(GateError::MatrixArgumentLength {
267                expected: 1,
268                actual,
269            }),
270        }
271    }
272}
273
274/// Generate the permutation matrix that permutes an arbitrary number of single-particle Hilbert
275/// spaces into adjacent positions.
276///
277///
278/// Transposes the qubit indices in the order they are passed to a contiguous region in the
279/// complete Hilbert space, in increasing qubit index order (preserving the order they are passed
280/// in).
281///
282/// Gates are usually defined as `GATE 0 1 2`, with such an argument ordering dictating the layout
283/// of the matrix corresponding to GATE. If such an instruction is given, actual qubits (0, 1, 2)
284/// need to be swapped into the positions (2, 1, 0), because the lifting operation taking the 8 x 8
285/// matrix of GATE is done in the little-endian (reverse) addressed qubit space.
286///
287/// For example, suppose I have a Quil command CCNOT 20 15 10. The median of the qubit indices is
288/// 15 - hence, we permute qubits [20, 15, 10] into the final map [16, 15, 14] to minimize the
289/// number of swaps needed, and so we can directly operate with the final CCNOT, when lifted from
290/// indices [16, 15, 14] to the complete Hilbert space.
291///
292/// Notes: assumes qubit indices are unique.
293///
294/// Done in preparation for arbitrary gate application on adjacent qubits.
295fn permutation_arbitrary(qubit_inds: &[u64], n_qubits: u64) -> (Matrix, u64) {
296    // Begin construction of permutation
297    let mut perm = Array2::eye(2usize.pow(n_qubits as u32));
298    // First, sort the list and find the median.
299    let mut sorted_inds = qubit_inds.to_vec();
300    sorted_inds.sort();
301    let med_i = qubit_inds.len() / 2;
302    let med = sorted_inds[med_i];
303    // The starting position of all specified Hilbert spaces begins at the qubit at (median -
304    // med_i)
305    let start = med - med_i as u64;
306    if qubit_inds.len() > 1 {
307        // Array of final indices the arguments are mapped to, from high index to low index, left to
308        // right ordering
309        let final_map = (start..start + qubit_inds.len() as u64)
310            .rev()
311            .collect::<Vec<_>>();
312
313        // Note that the lifting operation takes a k-qubit gate operating on the qubits i+k-1, i+k-2,
314        // ... i (left to right). two_swap_helper can be used to build the permutation matrix by
315        // filling out the final map by sweeping over the qubit_inds from left to right and back again,
316        // swapping qubits into position. we loop over the qubit_inds until the final mapping matches
317        // the argument.
318        let mut qubit_arr = (0..n_qubits).collect::<Vec<_>>(); // Current qubit indexing
319
320        let mut made_it = false;
321        let mut right = true;
322        while !made_it {
323            let array = if right {
324                (0..qubit_inds.len()).collect::<Vec<_>>()
325            } else {
326                (0..qubit_inds.len()).rev().collect()
327            };
328
329            for i in array {
330                let j = qubit_arr
331                    .iter()
332                    .position(|&q| q == qubit_inds[i])
333                    .expect("These arrays cover the same range.");
334                let pmod = two_swap_helper(j as u64, final_map[i], n_qubits, &mut qubit_arr);
335                perm = pmod.dot(&perm);
336                if (final_map[final_map.len() - 1]..final_map[0] + 1)
337                    .rev()
338                    .zip(qubit_inds)
339                    .all(|(f, &q)| qubit_arr[f as usize] == q)
340                {
341                    made_it = true;
342                    break;
343                }
344            }
345            right = !right;
346        }
347    }
348    (perm, start)
349}
350
351/// Generate the permutation matrix that permutes two single-particle Hilbert spaces into adjacent
352/// positions.
353///
354/// ALWAYS swaps j TO k. Recall that Hilbert spaces are ordered in decreasing qubit index order.
355/// Hence, j > k implies that j is to the left of k.
356///
357/// End results:
358///     j == k: nothing happens
359///     j > k: Swap j right to k, until j at ind (k) and k at ind (k+1).
360///     j < k: Swap j left to k, until j at ind (k) and k at ind (k-1).
361///
362/// Done in preparation for arbitrary 2-qubit gate application on ADJACENT qubits.
363fn two_swap_helper(j: u64, k: u64, n_qubits: u64, qubit_map: &mut [u64]) -> Matrix {
364    let mut perm = Array2::eye(2usize.pow(n_qubits as u32));
365    let swap = CONSTANT_GATE_MATRICES
366        .get("SWAP")
367        .expect("Key should exist by design.");
368    match Ord::cmp(&j, &k) {
369        Ordering::Equal => {}
370        Ordering::Greater => {
371            // swap j right to k, until j at ind (k) and k at ind (k+1)
372            for i in (k + 1..=j).rev() {
373                perm = qubit_adjacent_lifted_gate(i - 1, swap, n_qubits).dot(&perm);
374                qubit_map.swap(i as usize, (i - 1) as usize);
375            }
376        }
377        Ordering::Less => {
378            // swap j left to k, until j at ind (k) and k at ind (k-1)
379            for i in j..k {
380                perm = qubit_adjacent_lifted_gate(i, swap, n_qubits).dot(&perm);
381                qubit_map.swap(i as usize, (i + 1) as usize);
382            }
383        }
384    }
385    perm
386}
387
388/// Lifts input k-qubit gate on adjacent qubits starting from qubit i to complete Hilbert space of
389/// dimension 2 ** `num_qubits`.
390///
391/// Ex: 1-qubit gate, lifts from qubit i
392/// Ex: 2-qubit gate, lifts from qubits (i+1, i)
393/// Ex: 3-qubit gate, lifts from qubits (i+2, i+1, i), operating in that order
394///
395/// In general, this takes a k-qubit gate (2D matrix 2^k x 2^k) and lifts it to the complete
396/// Hilbert space of dim 2^num_qubits, as defined by the right-to-left tensor product (1) in
397/// arXiv:1608.03355.
398///
399/// Developer note: Quil and the QVM like qubits to be ordered such that qubit 0 is on the right.
400/// Therefore, in `qubit_adjacent_lifted_gate`, `lifted_pauli`, and `lifted_state_operator`, we
401/// build up the lifted matrix by performing the kronecker product from right to left.
402///
403/// Note that while the qubits are addressed in decreasing order, starting with num_qubit - 1 on
404/// the left and ending with qubit 0 on the right (in a little-endian fashion), gates are still
405/// lifted to apply on qubits in increasing index (right-to-left) order.
406fn qubit_adjacent_lifted_gate(i: u64, matrix: &Matrix, n_qubits: u64) -> Matrix {
407    let bottom_matrix = Array2::eye(2usize.pow(i as u32));
408    let gate_size = (matrix.shape()[0] as f64).log2().floor() as u64;
409    let top_qubits = n_qubits - i - gate_size;
410    let top_matrix = Array2::eye(2usize.pow(top_qubits as u32));
411    kron(&top_matrix, &kron(matrix, &bottom_matrix))
412}
413
414/// Gates matrices that don't use any parameters.
415///
416/// https://github.com/quil-lang/quil/blob/master/spec/Quil.md#standard-gates
417static CONSTANT_GATE_MATRICES: Lazy<HashMap<String, Matrix>> = Lazy::new(|| {
418    let _0 = real!(0.0);
419    let _1 = real!(1.0);
420    let _i = imag!(1.0);
421    let _1_sqrt_2 = real!(std::f64::consts::FRAC_1_SQRT_2);
422    HashMap::from([
423        ("I".to_string(), Array2::eye(2)),
424        ("X".to_string(), array![[_0, _1], [_1, _0]]),
425        ("Y".to_string(), array![[_0, -_i], [_i, _0]]),
426        ("Z".to_string(), array![[_1, _0], [_0, -_1]]),
427        ("H".to_string(), array![[_1, _1], [_1, -_1]] * _1_sqrt_2),
428        (
429            "CNOT".to_string(),
430            array![
431                [_1, _0, _0, _0],
432                [_0, _1, _0, _0],
433                [_0, _0, _0, _1],
434                [_0, _0, _1, _0]
435            ],
436        ),
437        (
438            "CCNOT".to_string(),
439            array![
440                [_1, _0, _0, _0, _0, _0, _0, _0],
441                [_0, _1, _0, _0, _0, _0, _0, _0],
442                [_0, _0, _1, _0, _0, _0, _0, _0],
443                [_0, _0, _0, _1, _0, _0, _0, _0],
444                [_0, _0, _0, _0, _1, _0, _0, _0],
445                [_0, _0, _0, _0, _0, _1, _0, _0],
446                [_0, _0, _0, _0, _0, _0, _0, _1],
447                [_0, _0, _0, _0, _0, _0, _1, _0],
448            ],
449        ),
450        ("S".to_string(), array![[_1, _0], [_0, _i]]),
451        (
452            "T".to_string(),
453            array![[_1, _0], [_0, Complex64::cis(std::f64::consts::FRAC_PI_4)]],
454        ),
455        ("CZ".to_string(), {
456            let mut cz = Array2::eye(4);
457            cz[[3, 3]] = -_1;
458            cz
459        }),
460        (
461            "SWAP".to_string(),
462            array![
463                [_1, _0, _0, _0],
464                [_0, _0, _1, _0],
465                [_0, _1, _0, _0],
466                [_0, _0, _0, _1],
467            ],
468        ),
469        (
470            "CSWAP".to_string(),
471            array![
472                [_1, _0, _0, _0, _0, _0, _0, _0],
473                [_0, _1, _0, _0, _0, _0, _0, _0],
474                [_0, _0, _1, _0, _0, _0, _0, _0],
475                [_0, _0, _0, _1, _0, _0, _0, _0],
476                [_0, _0, _0, _0, _1, _0, _0, _0],
477                [_0, _0, _0, _0, _0, _0, _1, _0],
478                [_0, _0, _0, _0, _0, _1, _0, _0],
479                [_0, _0, _0, _0, _0, _0, _0, _1],
480            ],
481        ),
482        (
483            "ISWAP".to_string(),
484            array![
485                [_1, _0, _0, _0],
486                [_0, _0, _i, _0],
487                [_0, _i, _0, _0],
488                [_0, _0, _0, _1],
489            ],
490        ),
491    ])
492});
493
494type ParameterizedMatrix = fn(Complex64) -> Matrix;
495
496/// Gates matrices that use parameters.
497///
498/// https://github.com/quil-lang/quil/blob/master/spec/Quil.md#standard-gates
499static PARAMETERIZED_GATE_MATRICES: Lazy<HashMap<String, ParameterizedMatrix>> = Lazy::new(|| {
500    // Unfortunately, Complex::cis takes a _float_ argument.
501    HashMap::from([
502        (
503            "RX".to_string(),
504            (|theta: Complex64| {
505                let _i = imag!(1.0);
506                let t = theta / 2.0;
507                array![[t.cos(), -_i * t.sin()], [-_i * t.sin(), t.cos()]]
508            }) as ParameterizedMatrix,
509        ),
510        (
511            "RY".to_string(),
512            (|theta: Complex64| {
513                let t = theta / 2.0;
514                array![[t.cos(), -t.sin()], [t.sin(), t.cos()]]
515            }) as ParameterizedMatrix,
516        ),
517        (
518            "RZ".to_string(),
519            (|theta: Complex64| {
520                let t = theta / 2.0;
521                array![[t.cos(), -t.sin()], [t.sin(), t.cos()]]
522            }) as ParameterizedMatrix,
523        ),
524        (
525            "PHASE".to_string(),
526            (|alpha: Complex64| {
527                let mut p = Array2::eye(2);
528                p[[1, 1]] = alpha.cos() + imag!(1.0) * alpha.sin();
529                p
530            }) as ParameterizedMatrix,
531        ),
532        (
533            "CPHASE00".to_string(),
534            (|alpha: Complex64| {
535                let mut p = Array2::eye(4);
536                p[[0, 0]] = alpha.cos() + imag!(1.0) * alpha.sin();
537                p
538            }) as ParameterizedMatrix,
539        ),
540        (
541            "CPHASE01".to_string(),
542            (|alpha: Complex64| {
543                let mut p = Array2::eye(4);
544                p[[1, 1]] = alpha.cos() + imag!(1.0) * alpha.sin();
545                p
546            }) as ParameterizedMatrix,
547        ),
548        (
549            "CPHASE10".to_string(),
550            (|alpha: Complex64| {
551                let mut p = Array2::eye(4);
552                p[[2, 2]] = alpha.cos() + imag!(1.0) * alpha.sin();
553                p
554            }) as ParameterizedMatrix,
555        ),
556        (
557            "CPHASE".to_string(),
558            (|alpha: Complex64| {
559                let mut p = Array2::eye(4);
560                p[[3, 3]] = alpha.cos() + imag!(1.0) * alpha.sin();
561                p
562            }) as ParameterizedMatrix,
563        ),
564        (
565            "PSWAP".to_string(),
566            (|theta: Complex64| {
567                let (_0, _1, _c) = (real!(0.0), real!(1.0), theta.cos() + theta);
568                array![
569                    [_1, _0, _0, _0],
570                    [_0, _0, _c, _0],
571                    [_0, _c, _0, _0],
572                    [_0, _0, _0, _1],
573                ]
574            }) as ParameterizedMatrix,
575        ),
576    ])
577});
578
579impl Quil for Gate {
580    fn write(
581        &self,
582        f: &mut impl std::fmt::Write,
583        fall_back_to_debug: bool,
584    ) -> crate::quil::ToQuilResult<()> {
585        for modifier in &self.modifiers {
586            modifier.write(f, fall_back_to_debug)?;
587            write!(f, " ")?;
588        }
589
590        write!(f, "{}", self.name)?;
591        write_expression_parameter_string(f, fall_back_to_debug, &self.parameters)?;
592        write_qubits(f, fall_back_to_debug, &self.qubits)
593    }
594}
595
596impl Quil for GateModifier {
597    fn write(
598        &self,
599        f: &mut impl std::fmt::Write,
600        _fall_back_to_debug: bool,
601    ) -> crate::quil::ToQuilResult<()> {
602        match self {
603            Self::Controlled => write!(f, "CONTROLLED"),
604            Self::Dagger => write!(f, "DAGGER"),
605            Self::Forked => write!(f, "FORKED"),
606        }
607        .map_err(Into::into)
608    }
609}
610
611#[cfg(test)]
612mod test_gate_into_matrix {
613    use super::{
614        lifted_gate_matrix, permutation_arbitrary, qubit_adjacent_lifted_gate, two_swap_helper,
615        Expression::Number, Gate, GateModifier::*, Matrix, ParameterizedMatrix, Qubit::Fixed,
616        CONSTANT_GATE_MATRICES, PARAMETERIZED_GATE_MATRICES,
617    };
618    use crate::{imag, real};
619    use approx::assert_abs_diff_eq;
620    use ndarray::{array, linalg::kron, Array2};
621    use num_complex::Complex64;
622    use once_cell::sync::Lazy;
623    use rstest::rstest;
624
625    static _0: Complex64 = real!(0.0);
626    static _1: Complex64 = real!(1.0);
627    static _I: Complex64 = imag!(1.0);
628    static PI: Complex64 = real!(std::f64::consts::PI);
629    static PI_4: Complex64 = real!(std::f64::consts::FRAC_PI_4);
630    static SWAP: Lazy<Matrix> = Lazy::new(|| CONSTANT_GATE_MATRICES.get("SWAP").cloned().unwrap());
631    static X: Lazy<Matrix> = Lazy::new(|| array![[_0, _1], [_1, _0]]);
632    static P0: Lazy<Matrix> = Lazy::new(|| array![[_1, _0], [_0, _0]]);
633    static P1: Lazy<Matrix> = Lazy::new(|| array![[_0, _0], [_0, _1]]);
634    static CNOT: Lazy<Matrix> = Lazy::new(|| CONSTANT_GATE_MATRICES.get("CNOT").cloned().unwrap());
635    static ISWAP: Lazy<Matrix> =
636        Lazy::new(|| CONSTANT_GATE_MATRICES.get("ISWAP").cloned().unwrap());
637    static H: Lazy<Matrix> = Lazy::new(|| CONSTANT_GATE_MATRICES.get("H").cloned().unwrap());
638    static RZ: Lazy<ParameterizedMatrix> =
639        Lazy::new(|| PARAMETERIZED_GATE_MATRICES.get("RZ").cloned().unwrap());
640    static CCNOT: Lazy<Matrix> =
641        Lazy::new(|| CONSTANT_GATE_MATRICES.get("CCNOT").cloned().unwrap());
642    static CZ: Lazy<Matrix> = Lazy::new(|| CONSTANT_GATE_MATRICES.get("CZ").cloned().unwrap());
643
644    #[rstest]
645    #[case(0, 2, &SWAP)]
646    #[case(0, 3, &kron(&Array2::eye(2), &SWAP))]
647    #[case(0, 4, &kron(&Array2::eye(4), &SWAP))]
648    #[case(1, 3, &kron(&SWAP, &Array2::eye(2)))]
649    #[case(1, 4, &kron(&Array2::eye(2), &kron(&SWAP, &Array2::eye(2))))]
650    #[case(2, 4, &kron(&Array2::eye(1), &kron(&SWAP, &Array2::eye(4))))]
651    #[case(8, 10, &kron(&Array2::eye(1), &kron(&SWAP, &Array2::eye(2usize.pow(8)))))]
652    fn test_qubit_adjacent_lifted_gate(
653        #[case] i: u64,
654        #[case] n_qubits: u64,
655        #[case] expected: &Matrix,
656    ) {
657        let result = qubit_adjacent_lifted_gate(i, &SWAP, n_qubits);
658        assert_abs_diff_eq!(result, expected);
659    }
660
661    // test cases via pyquil.simulation.tools.two_swap_helper
662    #[rstest]
663    #[case(0, 0, 2, &mut[0, 1], &[0, 1], &Array2::eye(4))]
664    #[case(0, 1, 2, &mut[0, 1], &[1, 0], &array![[_1, _0, _0, _0],
665                                                 [_0, _0, _1, _0],
666                                                 [_0, _1, _0, _0],
667                                                 [_0, _0, _0, _1]])]
668    #[case(0, 1, 2, &mut[1, 0], &[0, 1], &array![[_1, _0, _0, _0],
669                                                 [_0, _0, _1, _0],
670                                                 [_0, _1, _0, _0],
671                                                 [_0, _0, _0, _1]])]
672    #[case(1, 0, 2, &mut[0, 1], &[1, 0], &array![[_1, _0, _0, _0],
673                                                 [_0, _0, _1, _0],
674                                                 [_0, _1, _0, _0],
675                                                 [_0, _0, _0, _1]])]
676    #[case(1, 0, 2, &mut[1, 0], &[0, 1], &array![[_1, _0, _0, _0],
677                                                 [_0, _0, _1, _0],
678                                                 [_0, _1, _0, _0],
679                                                 [_0, _0, _0, _1]])]
680    #[case(0, 1, 3, &mut[0, 1, 2], &[1, 0, 2], &array![[_1, _0, _0, _0, _0, _0, _0, _0],
681                                                       [_0, _0, _1, _0, _0, _0, _0, _0],
682                                                       [_0, _1, _0, _0, _0, _0, _0, _0],
683                                                       [_0, _0, _0, _1, _0, _0, _0, _0],
684                                                       [_0, _0, _0, _0, _1, _0, _0, _0],
685                                                       [_0, _0, _0, _0, _0, _0, _1, _0],
686                                                       [_0, _0, _0, _0, _0, _1, _0, _0],
687                                                       [_0, _0, _0, _0, _0, _0, _0, _1]])]
688
689    fn test_two_swap_helper(
690        #[case] j: u64,
691        #[case] k: u64,
692        #[case] n_qubits: u64,
693        #[case] qubit_map: &mut [u64],
694        #[case] expected_qubit_map: &[u64],
695        #[case] expected_matrix: &Matrix,
696    ) {
697        let result = two_swap_helper(j, k, n_qubits, qubit_map);
698        assert_eq!(qubit_map, expected_qubit_map);
699        assert_abs_diff_eq!(result, expected_matrix);
700    }
701
702    // test cases via pyquil.simulation.tools.permutation_arbitrary
703    #[rstest]
704    #[case(&[0], 1, 0, &Array2::eye(2))]
705    #[case(&[0, 1], 2, 0, &array![[_1, _0, _0, _0],
706                                  [_0, _0, _1, _0],
707                                  [_0, _1, _0, _0],
708                                  [_0, _0, _0, _1]])]
709    #[case(&[1, 0], 2, 0, &Array2::eye(4))]
710    #[case(&[0, 2], 3, 1, &array![[_1, _0, _0, _0, _0, _0, _0, _0],
711                                  [_0, _0, _1, _0, _0, _0, _0, _0],
712                                  [_0, _0, _0, _0, _1, _0, _0, _0],
713                                  [_0, _0, _0, _0, _0, _0, _1, _0],
714                                  [_0, _1, _0, _0, _0, _0, _0, _0],
715                                  [_0, _0, _0, _1, _0, _0, _0, _0],
716                                  [_0, _0, _0, _0, _0, _1, _0, _0],
717                                  [_0, _0, _0, _0, _0, _0, _0, _1]])]
718    #[case(&[1, 2], 3, 1, &array![[_1, _0, _0, _0, _0, _0, _0, _0],
719                                  [_0, _1, _0, _0, _0, _0, _0, _0],
720                                  [_0, _0, _0, _0, _1, _0, _0, _0],
721                                  [_0, _0, _0, _0, _0, _1, _0, _0],
722                                  [_0, _0, _1, _0, _0, _0, _0, _0],
723                                  [_0, _0, _0, _1, _0, _0, _0, _0],
724                                  [_0, _0, _0, _0, _0, _0, _1, _0],
725                                  [_0, _0, _0, _0, _0, _0, _0, _1]])]
726    #[case(&[0, 1, 2], 3, 0, &array![[_1, _0, _0, _0, _0, _0, _0, _0],
727                                     [_0, _0, _0, _0, _1, _0, _0, _0],
728                                     [_0, _0, _1, _0, _0, _0, _0, _0],
729                                     [_0, _0, _0, _0, _0, _0, _1, _0],
730                                     [_0, _1, _0, _0, _0, _0, _0, _0],
731                                     [_0, _0, _0, _0, _0, _1, _0, _0],
732                                     [_0, _0, _0, _1, _0, _0, _0, _0],
733                                     [_0, _0, _0, _0, _0, _0, _0, _1]])]
734    fn test_permutation_arbitrary(
735        #[case] qubit_inds: &[u64],
736        #[case] n_qubits: u64,
737        #[case] expected_start: u64,
738        #[case] expected_matrix: &Matrix,
739    ) {
740        let (result_matrix, result_start) = permutation_arbitrary(qubit_inds, n_qubits);
741        assert_abs_diff_eq!(result_matrix, expected_matrix);
742        assert_eq!(result_start, expected_start);
743    }
744
745    #[rstest]
746    #[case(&CNOT, &mut [1, 0], 2, &(kron(&P0, &Array2::eye(2)) + kron(&P1, &X)))]
747    #[case(&CNOT, &mut [0, 1], 2, &(kron(&Array2::eye(2), &P0) + kron(&X, &P1)))]
748    #[case(&CNOT, &mut [2, 1], 3, &(kron(&CNOT, &Array2::eye(2))))]
749    #[case(&ISWAP, &mut [0, 1], 3, &kron(&Array2::eye(2), &ISWAP))]
750    #[case(&ISWAP, &mut [1, 0], 3, &kron(&Array2::eye(2), &ISWAP))]
751    #[case(&ISWAP, &mut [1, 2], 4, &kron(&Array2::eye(2), &kron(&ISWAP, &Array2::eye(2))))]
752    #[case(&ISWAP, &mut [3, 2], 4, &kron(&ISWAP, &Array2::eye(4)))]
753    #[case(&ISWAP, &mut [2, 3], 4, &kron(&ISWAP, &Array2::eye(4)))]
754    #[case(&H, &mut [0], 4, &kron(&Array2::eye(8), &H))]
755    #[case(&H, &mut [1], 4, &kron(&Array2::eye(4), &kron(&H, &Array2::eye(2))))]
756    #[case(&H, &mut [2], 4, &kron(&Array2::eye(2), &kron(&H, &Array2::eye(4))))]
757    #[case(&H, &mut [3], 4, &kron(&H, &Array2::eye(8)))]
758    #[case(&H, &mut [0], 5, &kron(&Array2::eye(16), &H))]
759    #[case(&H, &mut [1], 5, &kron(&Array2::eye(8), &kron(&H, &Array2::eye(2))))]
760    #[case(&H, &mut [2], 5, &kron(&Array2::eye(4), &kron(&H, &Array2::eye(4))))]
761    #[case(&H, &mut [3], 5, &kron(&Array2::eye(2), &kron(&H, &Array2::eye(8))))]
762    #[case(&H, &mut [4], 5, &kron(&H, &Array2::eye(16)))]
763    fn test_lifted_gate_matrix(
764        #[case] matrix: &Matrix,
765        #[case] indices: &mut [u64],
766        #[case] n_qubits: u64,
767        #[case] expected: &Matrix,
768    ) {
769        assert_abs_diff_eq!(lifted_gate_matrix(matrix, indices, n_qubits), expected);
770    }
771
772    #[rstest]
773    #[case(&mut Gate::new("H", vec![], vec![Fixed(0)], vec![]).unwrap(), 4, &kron(&Array2::eye(8), &H))]
774    #[case(&mut Gate::new("RZ", vec![Number(PI_4)], vec![Fixed(0)], vec![Dagger]).unwrap(), 1, &RZ(-PI_4))]
775    #[case(&mut Gate::new("X", vec![], vec![Fixed(0)], vec![Dagger]).unwrap().controlled(Fixed(1)), 2, &CNOT)]
776    #[case(
777        &mut Gate::new("X", vec![], vec![Fixed(0)], vec![]).unwrap().dagger().controlled(Fixed(1)).dagger().dagger().controlled(Fixed(2)),
778        3,
779        &CCNOT
780    )]
781    #[case(
782        &mut Gate::new("PHASE", vec![Number(_0)], vec![Fixed(1)], vec![]).unwrap().forked(Fixed(0), vec![Number(PI)]).unwrap(),
783        2,
784        &lifted_gate_matrix(&CZ, &[0, 1], 2)
785    )]
786    fn test_to_unitary(#[case] gate: &mut Gate, #[case] n_qubits: u64, #[case] expected: &Matrix) {
787        let result = gate.to_unitary(n_qubits);
788        assert!(result.is_ok());
789        assert_abs_diff_eq!(result.as_ref().unwrap(), expected);
790    }
791}
792
/// One of the Pauli operators `I`, `X`, `Y`, `Z`, used as a letter in a [`PauliTerm`]'s word.
///
/// `Display` and `FromStr` are derived via `strum` with `serialize_all = "UPPERCASE"`, so each
/// variant is written as, and parsed from, its uppercase name (e.g. `PauliGate::X` <-> `"X"`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, strum::Display, strum::EnumString)]
#[strum(serialize_all = "UPPERCASE")]
pub enum PauliGate {
    /// Pauli `I` (identity).
    I,
    /// Pauli `X`.
    X,
    /// Pauli `Y`.
    Y,
    /// Pauli `Z`.
    Z,
}
801
802#[derive(Clone, Debug, PartialEq, Eq, Hash)]
803pub struct PauliTerm {
804    pub arguments: Vec<(PauliGate, String)>,
805    pub expression: Expression,
806}
807
808impl PauliTerm {
809    pub fn new(arguments: Vec<(PauliGate, String)>, expression: Expression) -> Self {
810        Self {
811            arguments,
812            expression,
813        }
814    }
815
816    pub(crate) fn word(&self) -> impl Iterator<Item = &PauliGate> {
817        self.arguments.iter().map(|(gate, _)| gate)
818    }
819
820    pub(crate) fn arguments(&self) -> impl Iterator<Item = &String> {
821        self.arguments.iter().map(|(_, argument)| argument)
822    }
823}
824
825#[derive(Clone, Debug, PartialEq, Eq, Hash)]
826pub struct PauliSum {
827    pub arguments: Vec<String>,
828    pub terms: Vec<PauliTerm>,
829}
830
831impl PauliSum {
832    pub fn new(arguments: Vec<String>, terms: Vec<PauliTerm>) -> Result<Self, GateError> {
833        let diff = terms
834            .iter()
835            .flat_map(|t| t.arguments())
836            .collect::<HashSet<_>>()
837            .difference(&arguments.iter().collect::<HashSet<_>>())
838            .copied()
839            .collect::<Vec<_>>();
840
841        if !diff.is_empty() {
842            return Err(GateError::PauliSumArgumentMismatch {
843                mismatches: diff.into_iter().cloned().collect(),
844                expected_arguments: arguments,
845            });
846        }
847
848        Ok(Self { arguments, terms })
849    }
850}
851
/// An enum representing the specification of a [`GateDefinition`] for a given [`GateType`].
///
/// Each variant holds the data written as the body of the corresponding `DEFGATE ... AS ...:`
/// block.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum GateSpecification {
    /// A matrix of [`Expression`]s representing a unitary operation for a [`GateType::Matrix`].
    /// Each inner `Vec` is one row of the matrix.
    Matrix(Vec<Vec<Expression>>),
    /// A vector of integers that defines the permutation used for a [`GateType::Permutation`]
    Permutation(Vec<u64>),
    /// A Hermitian operator specified as a Pauli sum, a sum of combinations of Pauli operators,
    /// used for a [`GateType::PauliSum`]
    PauliSum(PauliSum),
}
863
864impl Quil for GateSpecification {
865    fn write(
866        &self,
867        f: &mut impl std::fmt::Write,
868        fall_back_to_debug: bool,
869    ) -> crate::quil::ToQuilResult<()> {
870        match self {
871            GateSpecification::Matrix(matrix) => {
872                for row in matrix {
873                    write!(f, "{INDENT}")?;
874                    write_join_quil(f, fall_back_to_debug, row.iter(), ", ", "")?;
875                    writeln!(f)?;
876                }
877            }
878            GateSpecification::Permutation(permutation) => {
879                write!(f, "{INDENT}")?;
880                if let Some(i) = permutation.first() {
881                    write!(f, "{i}")?;
882                }
883                for i in permutation.iter().skip(1) {
884                    write!(f, ", {i}")?;
885                }
886                writeln!(f)?;
887            }
888            GateSpecification::PauliSum(pauli_sum) => {
889                for term in &pauli_sum.terms {
890                    write!(f, "{INDENT}")?;
891                    for word in term.word() {
892                        write!(f, "{word}")?;
893                    }
894                    write!(f, "(")?;
895                    term.expression.write(f, fall_back_to_debug)?;
896                    write!(f, ")")?;
897                    for argument in term.arguments() {
898                        write!(f, " {argument}")?;
899                    }
900                    writeln!(f)?;
901                }
902            }
903        }
904        Ok(())
905    }
906}
907
908/// A struct encapsulating a quil Gate Definition
909#[derive(Clone, Debug, PartialEq, Eq, Hash)]
910pub struct GateDefinition {
911    pub name: String,
912    pub parameters: Vec<String>,
913    pub specification: GateSpecification,
914}
915
916impl GateDefinition {
917    pub fn new(
918        name: String,
919        parameters: Vec<String>,
920        specification: GateSpecification,
921    ) -> Result<Self, GateError> {
922        validate_user_identifier(&name)?;
923        Ok(Self {
924            name,
925            parameters,
926            specification,
927        })
928    }
929}
930
931impl Quil for GateDefinition {
932    fn write(
933        &self,
934        f: &mut impl std::fmt::Write,
935        fall_back_to_debug: bool,
936    ) -> crate::quil::ToQuilResult<()> {
937        write!(f, "DEFGATE {}", self.name,)?;
938        write_parameter_string(f, &self.parameters)?;
939        match &self.specification {
940            GateSpecification::Matrix(_) => writeln!(f, " AS MATRIX:")?,
941            GateSpecification::Permutation(_) => writeln!(f, " AS PERMUTATION:")?,
942            GateSpecification::PauliSum(sum) => {
943                for arg in &sum.arguments {
944                    write!(f, " {arg}")?;
945                }
946                writeln!(f, " AS PAULI-SUM:")?
947            }
948        }
949        self.specification.write(f, fall_back_to_debug)?;
950        Ok(())
951    }
952}
953
// Snapshot tests for the Quil serialization of `GateDefinition`, with one case per
// `GateSpecification` variant. Each case is rendered with `to_quil_or_debug` and compared
// against an `insta` snapshot whose suffix is the case description.
#[cfg(test)]
mod test_gate_definition {
    use super::{GateDefinition, GateSpecification, PauliGate, PauliSum, PauliTerm};
    use crate::expression::{
        Expression, ExpressionFunction, FunctionCallExpression, InfixExpression, InfixOperator,
        PrefixExpression, PrefixOperator,
    };
    use crate::quil::Quil;
    use crate::{imag, real};
    use insta::assert_snapshot;
    use internment::ArcIntern;
    use rstest::rstest;

    #[rstest]
    // An 8-entry permutation that swaps the last two entries (6 and 7).
    #[case(
        "Permutation GateDefinition",
        GateDefinition{
            name: "PermGate".to_string(),
            parameters: vec![],
            specification: GateSpecification::Permutation(vec![0, 1, 2, 3, 4, 5, 7, 6]),

        }
    )]
    // A 2x2 matrix with one parameter `theta`:
    //   [[cos(theta/2),        (-i) * sin(theta/2)],
    //    [(-i) * sin(theta/2), cos(theta/2)       ]]
    #[case(
        "Parameterized GateDefinition",
        GateDefinition{
            name: "ParamGate".to_string(),
            parameters: vec!["theta".to_string()],
            specification: GateSpecification::Matrix(vec![
                vec![
                    Expression::FunctionCall(FunctionCallExpression {
                        function: crate::expression::ExpressionFunction::Cosine,
                        expression: ArcIntern::new(Expression::Infix(InfixExpression {
                            left: ArcIntern::new(Expression::Variable("theta".to_string())),
                            operator: InfixOperator::Slash,
                            right: ArcIntern::new(Expression::Number(real!(2.0))),
                        })),
                    }),
                    Expression::Infix(InfixExpression {
                        left: ArcIntern::new(Expression::Prefix(PrefixExpression {
                            operator: PrefixOperator::Minus,
                            expression: ArcIntern::new(Expression::Number(imag!(1f64)))
                        })),
                        operator: InfixOperator::Star,
                        right: ArcIntern::new(Expression::FunctionCall(FunctionCallExpression {
                            function: ExpressionFunction::Sine,
                            expression: ArcIntern::new(Expression::Infix(InfixExpression {
                                left: ArcIntern::new(Expression::Variable("theta".to_string())),
                                operator: InfixOperator::Slash,
                                right: ArcIntern::new(Expression::Number(real!(2.0))),
                            })),
                        })),
                    })
                ],
                vec![
                    Expression::Infix(InfixExpression {
                        left: ArcIntern::new(Expression::Prefix(PrefixExpression {
                            operator: PrefixOperator::Minus,
                            expression: ArcIntern::new(Expression::Number(imag!(1f64)))
                        })),
                        operator: InfixOperator::Star,
                        right: ArcIntern::new(Expression::FunctionCall(FunctionCallExpression {
                            function: ExpressionFunction::Sine,
                            expression: ArcIntern::new(Expression::Infix(InfixExpression {
                                left: ArcIntern::new(Expression::Variable("theta".to_string())),
                                operator: InfixOperator::Slash,
                                right: ArcIntern::new(Expression::Number(real!(2.0))),
                            })),
                        })),
                    }),
                    Expression::FunctionCall(FunctionCallExpression {
                        function: crate::expression::ExpressionFunction::Cosine,
                        expression: ArcIntern::new(Expression::Infix(InfixExpression {
                            left: ArcIntern::new(Expression::Variable("theta".to_string())),
                            operator: InfixOperator::Slash,
                            right: ArcIntern::new(Expression::Number(real!(2.0))),
                        })),
                    }),
                ],
            ]),

        }
    )]
    // A Pauli sum over arguments `p` and `q` with terms
    //   ZZ((-theta)/4) p q,   Y(theta/4) p,   X(theta/4) q
    // The `PauliSum` is built with a struct literal, so `PauliSum::new`'s argument check is
    // not exercised here.
    #[case(
        "Pauli Sum GateDefinition",
        GateDefinition{
            name: "PauliSumGate".to_string(),
            parameters: vec!["theta".to_string()],
            specification: GateSpecification::PauliSum(PauliSum{arguments: vec!["p".to_string(), "q".to_string()], terms: vec![
                PauliTerm {
                    arguments: vec![(PauliGate::Z, "p".to_string()), (PauliGate::Z, "q".to_string())],
                    expression: Expression::Infix(InfixExpression {
                        left: ArcIntern::new(Expression::Prefix(PrefixExpression {
                            operator: PrefixOperator::Minus,
                            expression: ArcIntern::new(Expression::Variable("theta".to_string()))
                        })),
                        operator: InfixOperator::Slash,
                        right: ArcIntern::new(Expression::Number(real!(4.0)))
                    }),
                },
                PauliTerm {
                    arguments: vec![(PauliGate::Y, "p".to_string())],
                    expression: Expression::Infix(InfixExpression {
                        left: ArcIntern::new(Expression::Variable("theta".to_string())),
                        operator: InfixOperator::Slash,
                        right: ArcIntern::new(Expression::Number(real!(4.0)))
                    }),
                },
                PauliTerm {
                    arguments: vec![(PauliGate::X, "q".to_string())],
                    expression: Expression::Infix(InfixExpression {
                        left: ArcIntern::new(Expression::Variable("theta".to_string())),
                        operator: InfixOperator::Slash,
                        right: ArcIntern::new(Expression::Number(real!(4.0)))
                    }),
                },
            ]})
        }
    )]
    // Renders `gate_def` to Quil (falling back to debug output where needed) and compares it
    // with the stored snapshot, using `description` as the snapshot suffix.
    fn test_display(#[case] description: &str, #[case] gate_def: GateDefinition) {
        insta::with_settings!({
            snapshot_suffix => description,
        }, {
            assert_snapshot!(gate_def.to_quil_or_debug())
        })
    }
}
1081
/// The type of a [`GateDefinition`], i.e. which kind of [`GateSpecification`] it uses.
///
/// Written in Quil as `MATRIX`, `PERMUTATION`, or `PAULI-SUM`; these are the same keywords that
/// follow `AS` in a `DEFGATE` header.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GateType {
    /// Specified by a matrix ([`GateSpecification::Matrix`]).
    Matrix,
    /// Specified by a permutation ([`GateSpecification::Permutation`]).
    Permutation,
    /// Specified by a Pauli sum ([`GateSpecification::PauliSum`]).
    PauliSum,
}
1089
1090impl Quil for GateType {
1091    fn write(
1092        &self,
1093        f: &mut impl std::fmt::Write,
1094        _fall_back_to_debug: bool,
1095    ) -> crate::quil::ToQuilResult<()> {
1096        match self {
1097            Self::Matrix => write!(f, "MATRIX"),
1098            Self::Permutation => write!(f, "PERMUTATION"),
1099            Self::PauliSum => write!(f, "PAULI-SUM"),
1100        }
1101        .map_err(Into::into)
1102    }
1103}