quil_rs/expression/
mod.rs

1// Copyright 2021 Rigetti Computing
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::{
16    hash::hash_f64,
17    imag,
18    instruction::MemoryReference,
19    parser::{lex, parse_expression, ParseError},
20    program::{disallow_leftover, ParseProgramError},
21    quil::Quil,
22    real,
23};
24use lexical::{format, to_string_with_options, WriteFloatOptions};
25use nom_locate::LocatedSpan;
26use num_complex::Complex64;
27use once_cell::sync::Lazy;
28use std::{
29    collections::HashMap,
30    f64::consts::PI,
31    fmt,
32    hash::{Hash, Hasher},
33    num::NonZeroI32,
34    ops::{Add, AddAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign},
35    str::FromStr,
36};
37
38#[cfg(test)]
39use proptest_derive::Arbitrary;
40
41mod simplification;
42
43/// The different possible types of errors that could occur during expression evaluation.
44#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
45pub enum EvaluationError {
46    #[error("There wasn't enough information to completely evaluate the expression.")]
47    Incomplete,
48    #[error("The operation expected a real number but received a complex one.")]
49    NumberNotReal,
50    #[error("The operation expected a number but received a different type of expression.")]
51    NotANumber,
52}
53
54/// The type of Quil expressions.
55///
56/// Note that when comparing Quil expressions, any embedded NaNs are treated as *equal* to other
57/// NaNs, not unequal, in contravention of the IEEE 754 spec.
58#[derive(Clone, Debug)]
59pub enum Expression {
60    Address(MemoryReference),
61    FunctionCall(FunctionCallExpression),
62    Infix(InfixExpression),
63    Number(Complex64),
64    PiConstant,
65    Prefix(PrefixExpression),
66    Variable(String),
67}
68
69/// The type of function call Quil expressions, e.g. `sin(e)`.
70///
71/// Note that when comparing Quil expressions, any embedded NaNs are treated as *equal* to other
72/// NaNs, not unequal, in contravention of the IEEE 754 spec.
73#[derive(Clone, Debug, PartialEq, Eq, Hash)]
74pub struct FunctionCallExpression {
75    pub function: ExpressionFunction,
76    pub expression: Box<Expression>,
77}
78
79impl FunctionCallExpression {
80    pub fn new(function: ExpressionFunction, expression: Box<Expression>) -> Self {
81        Self {
82            function,
83            expression,
84        }
85    }
86}
87
88/// The type of infix Quil expressions, e.g. `e1 + e2`.
89///
90/// Note that when comparing Quil expressions, any embedded NaNs are treated as *equal* to other
91/// NaNs, not unequal, in contravention of the IEEE 754 spec.
92#[derive(Clone, Debug, PartialEq, Eq, Hash)]
93pub struct InfixExpression {
94    pub left: Box<Expression>,
95    pub operator: InfixOperator,
96    pub right: Box<Expression>,
97}
98
99impl InfixExpression {
100    pub fn new(left: Box<Expression>, operator: InfixOperator, right: Box<Expression>) -> Self {
101        Self {
102            left,
103            operator,
104            right,
105        }
106    }
107}
108
109/// The type of prefix Quil expressions, e.g. `-e`.
110///
111/// Note that when comparing Quil expressions, any embedded NaNs are treated as *equal* to other
112/// NaNs, not unequal, in contravention of the IEEE 754 spec.
113#[derive(Clone, Debug, PartialEq, Eq, Hash)]
114pub struct PrefixExpression {
115    pub operator: PrefixOperator,
116    pub expression: Box<Expression>,
117}
118
119impl PrefixExpression {
120    pub fn new(operator: PrefixOperator, expression: Box<Expression>) -> Self {
121        Self {
122            operator,
123            expression,
124        }
125    }
126}
127
128impl PartialEq for Expression {
129    // Implemented by hand since we can't derive with f64s hidden inside.
130    fn eq(&self, other: &Self) -> bool {
131        match (self, other) {
132            (Self::Address(left), Self::Address(right)) => left == right,
133            (Self::Infix(left), Self::Infix(right)) => left == right,
134            (Self::Number(left), Self::Number(right)) => {
135                (left.re == right.re || left.re.is_nan() && right.re.is_nan())
136                    && (left.im == right.im || left.im.is_nan() && right.im.is_nan())
137            }
138            (Self::Prefix(left), Self::Prefix(right)) => left == right,
139            (Self::FunctionCall(left), Self::FunctionCall(right)) => left == right,
140            (Self::Variable(left), Self::Variable(right)) => left == right,
141            (Self::PiConstant, Self::PiConstant) => true,
142            _ => false,
143        }
144    }
145}
146
147// Implemented by hand since we can't derive with f64s hidden inside.
148impl Eq for Expression {}
149
150impl Hash for Expression {
151    // Implemented by hand since we can't derive with f64s hidden inside.
152    fn hash<H: Hasher>(&self, state: &mut H) {
153        match self {
154            Self::Address(m) => {
155                "Address".hash(state);
156                m.hash(state);
157            }
158            Self::FunctionCall(FunctionCallExpression {
159                function,
160                expression,
161            }) => {
162                "FunctionCall".hash(state);
163                function.hash(state);
164                expression.hash(state);
165            }
166            Self::Infix(InfixExpression {
167                left,
168                operator,
169                right,
170            }) => {
171                "Infix".hash(state);
172                operator.hash(state);
173                left.hash(state);
174                right.hash(state);
175            }
176            Self::Number(n) => {
177                "Number".hash(state);
178                // Skip zero values (akin to `format_complex`).
179                if n.re.abs() > 0f64 {
180                    hash_f64(n.re, state)
181                }
182                if n.im.abs() > 0f64 {
183                    hash_f64(n.im, state)
184                }
185            }
186            Self::PiConstant => {
187                "PiConstant".hash(state);
188            }
189            Self::Prefix(p) => {
190                "Prefix".hash(state);
191                p.operator.hash(state);
192                p.expression.hash(state);
193            }
194            Self::Variable(v) => {
195                "Variable".hash(state);
196                v.hash(state);
197            }
198        }
199    }
200}
201
202macro_rules! impl_expr_op {
203    ($name:ident, $name_assign:ident, $function:ident, $function_assign:ident, $operator:ident) => {
204        impl $name for Expression {
205            type Output = Self;
206            fn $function(self, other: Self) -> Self {
207                Expression::Infix(InfixExpression {
208                    left: Box::new(self),
209                    operator: InfixOperator::$operator,
210                    right: Box::new(other),
211                })
212            }
213        }
214        impl $name_assign for Expression {
215            fn $function_assign(&mut self, other: Self) {
216                // Move out of self to avoid potentially cloning a large value
217                let temp = ::std::mem::replace(self, Self::PiConstant);
218                *self = temp.$function(other);
219            }
220        }
221    };
222}
223
224impl_expr_op!(BitXor, BitXorAssign, bitxor, bitxor_assign, Caret);
225impl_expr_op!(Add, AddAssign, add, add_assign, Plus);
226impl_expr_op!(Sub, SubAssign, sub, sub_assign, Minus);
227impl_expr_op!(Mul, MulAssign, mul, mul_assign, Star);
228impl_expr_op!(Div, DivAssign, div, div_assign, Slash);
229
230/// Compute the result of an infix expression where both operands are complex.
231fn calculate_infix(left: &Complex64, operator: &InfixOperator, right: &Complex64) -> Complex64 {
232    use InfixOperator::*;
233    match operator {
234        Caret => left.powc(*right),
235        Plus => left + right,
236        Minus => left - right,
237        Slash => left / right,
238        Star => left * right,
239    }
240}
241
242/// Compute the result of a Quil-defined expression function where the operand is complex.
243fn calculate_function(function: &ExpressionFunction, argument: &Complex64) -> Complex64 {
244    use ExpressionFunction::*;
245    match function {
246        Sine => argument.sin(),
247        Cis => argument.cos() + imag!(1f64) * argument.sin(),
248        Cosine => argument.cos(),
249        Exponent => argument.exp(),
250        SquareRoot => argument.sqrt(),
251    }
252}
253
/// Report whether the magnitude of `x` lies strictly below `1e-16`.
///
/// NaN is never small.
#[inline(always)]
fn is_small(x: f64) -> bool {
    const THRESHOLD: f64 = 1e-16;
    -THRESHOLD < x && x < THRESHOLD
}
259
260impl Expression {
261    /// Simplify the expression as much as possible, in-place.
262    ///
263    /// # Example
264    ///
265    /// ```rust
266    /// use quil_rs::expression::Expression;
267    /// use std::str::FromStr;
268    /// use num_complex::Complex64;
269    ///
270    /// let mut expression = Expression::from_str("cos(2 * pi) + 2").unwrap();
271    /// expression.simplify();
272    ///
273    /// assert_eq!(expression, Expression::Number(Complex64::from(3.0)));
274    /// ```
275    pub fn simplify(&mut self) {
276        match self {
277            Expression::Address(_) | Expression::Number(_) | Expression::Variable(_) => {}
278            Expression::PiConstant => {
279                *self = Expression::Number(Complex64::from(PI));
280            }
281            _ => *self = simplification::run(self),
282        }
283    }
284
285    /// Consume the expression, simplifying it as much as possible.
286    ///
287    /// # Example
288    ///
289    /// ```rust
290    /// use quil_rs::expression::Expression;
291    /// use std::str::FromStr;
292    /// use num_complex::Complex64;
293    ///
294    /// let simplified = Expression::from_str("cos(2 * pi) + 2").unwrap().into_simplified();
295    ///
296    /// assert_eq!(simplified, Expression::Number(Complex64::from(3.0)));
297    /// ```
298    pub fn into_simplified(mut self) -> Self {
299        self.simplify();
300        self
301    }
302
303    /// Evaluate an expression, expecting that it may be fully reduced to a single complex number.
304    /// If it cannot be reduced to a complex number, return an error.
305    ///
306    /// # Example
307    ///
308    /// ```rust
309    /// use quil_rs::expression::Expression;
310    /// use std::str::FromStr;
311    /// use std::collections::HashMap;
312    /// use num_complex::Complex64;
313    ///
314    /// let expression = Expression::from_str("%beta + theta[0]").unwrap();
315    ///
316    /// let mut variables = HashMap::with_capacity(1);
317    /// variables.insert(String::from("beta"), Complex64::from(1.0));
318    ///
319    /// let mut memory_references = HashMap::with_capacity(1);
320    /// memory_references.insert("theta", vec![2.0]);
321    ///
322    /// let evaluated = expression.evaluate(&variables, &memory_references).unwrap();
323    ///
324    /// assert_eq!(evaluated, Complex64::from(3.0))
325    /// ```
326    pub fn evaluate(
327        &self,
328        variables: &HashMap<String, Complex64>,
329        memory_references: &HashMap<&str, Vec<f64>>,
330    ) -> Result<Complex64, EvaluationError> {
331        use Expression::*;
332
333        match self {
334            FunctionCall(FunctionCallExpression {
335                function,
336                expression,
337            }) => {
338                let evaluated = expression.evaluate(variables, memory_references)?;
339                Ok(calculate_function(function, &evaluated))
340            }
341            Infix(InfixExpression {
342                left,
343                operator,
344                right,
345            }) => {
346                let left_evaluated = left.evaluate(variables, memory_references)?;
347                let right_evaluated = right.evaluate(variables, memory_references)?;
348                Ok(calculate_infix(&left_evaluated, operator, &right_evaluated))
349            }
350            Prefix(PrefixExpression {
351                operator,
352                expression,
353            }) => {
354                use PrefixOperator::*;
355                let value = expression.evaluate(variables, memory_references)?;
356                if matches!(operator, Minus) {
357                    Ok(-value)
358                } else {
359                    Ok(value)
360                }
361            }
362            Variable(identifier) => match variables.get(identifier.as_str()) {
363                Some(value) => Ok(*value),
364                None => Err(EvaluationError::Incomplete),
365            },
366            Address(memory_reference) => memory_references
367                .get(memory_reference.name.as_str())
368                .and_then(|values| {
369                    let value = values.get(memory_reference.index as usize)?;
370                    Some(real!(*value))
371                })
372                .ok_or(EvaluationError::Incomplete),
373            PiConstant => Ok(real!(PI)),
374            Number(number) => Ok(*number),
375        }
376    }
377
378    /// Substitute an expression in the place of each matching variable.
379    /// Consumes the expression and returns a new one.
380    ///
381    /// # Example
382    ///
383    /// ```rust
384    /// use quil_rs::expression::Expression;
385    /// use std::str::FromStr;
386    /// use std::collections::HashMap;
387    /// use num_complex::Complex64;
388    ///
389    /// let expression = Expression::from_str("%x + %y").unwrap();
390    ///
391    /// let mut variables = HashMap::with_capacity(1);
392    /// variables.insert(String::from("x"), Expression::Number(Complex64::from(1.0)));
393    ///
394    /// let evaluated = expression.substitute_variables(&variables);
395    ///
396    /// assert_eq!(evaluated, Expression::from_str("1.0 + %y").unwrap())
397    /// ```
398    pub fn substitute_variables(self, variable_values: &HashMap<String, Expression>) -> Self {
399        use Expression::*;
400
401        match self {
402            FunctionCall(FunctionCallExpression {
403                function,
404                expression,
405            }) => Expression::FunctionCall(FunctionCallExpression {
406                function,
407                expression: expression.substitute_variables(variable_values).into(),
408            }),
409            Infix(InfixExpression {
410                left,
411                operator,
412                right,
413            }) => {
414                let left = left.substitute_variables(variable_values).into();
415                let right = right.substitute_variables(variable_values).into();
416                Infix(InfixExpression {
417                    left,
418                    operator,
419                    right,
420                })
421            }
422            Prefix(PrefixExpression {
423                operator,
424                expression,
425            }) => Prefix(PrefixExpression {
426                operator,
427                expression: expression.substitute_variables(variable_values).into(),
428            }),
429            Variable(identifier) => match variable_values.get(identifier.as_str()) {
430                Some(value) => value.clone(),
431                None => Variable(identifier),
432            },
433            other => other,
434        }
435    }
436
437    /// If this is a number with imaginary part "equal to" zero (of _small_ absolute value), return
438    /// that number. Otherwise, error with an evaluation error of a descriptive type.
439    pub fn to_real(&self) -> Result<f64, EvaluationError> {
440        match self {
441            Expression::PiConstant => Ok(PI),
442            Expression::Number(x) if is_small(x.im) => Ok(x.re),
443            Expression::Number(_) => Err(EvaluationError::NumberNotReal),
444            _ => Err(EvaluationError::NotANumber),
445        }
446    }
447}
448
449impl FromStr for Expression {
450    type Err = ParseProgramError<Self>;
451
452    fn from_str(s: &str) -> Result<Self, Self::Err> {
453        let input = LocatedSpan::new(s);
454        let tokens = lex(input)?;
455        disallow_leftover(parse_expression(&tokens).map_err(ParseError::from_nom_internal_err))
456    }
457}
458
459static FORMAT_REAL_OPTIONS: Lazy<WriteFloatOptions> = Lazy::new(|| {
460    WriteFloatOptions::builder()
461        .negative_exponent_break(NonZeroI32::new(-5))
462        .positive_exponent_break(NonZeroI32::new(15))
463        .trim_floats(true)
464        .build()
465        .expect("options are valid")
466});
467
468static FORMAT_IMAGINARY_OPTIONS: Lazy<WriteFloatOptions> = Lazy::new(|| {
469    WriteFloatOptions::builder()
470        .negative_exponent_break(NonZeroI32::new(-5))
471        .positive_exponent_break(NonZeroI32::new(15))
472        .trim_floats(false) // Per the quil spec, the imaginary part of a complex number is always a floating point number
473        .build()
474        .expect("options are valid")
475});
476
477/// Format a num_complex::Complex64 value in a way that omits the real or imaginary part when
478/// reasonable. That is:
479///
480/// - When imaginary is set but real is 0, show only imaginary
481/// - When imaginary is 0, show real only
482/// - When both are non-zero, show with the correct operator in between
483#[inline(always)]
484pub(crate) fn format_complex(value: &Complex64) -> String {
485    const FORMAT: u128 = format::STANDARD;
486    if value.re == 0f64 && value.im == 0f64 {
487        "0".to_owned()
488    } else if value.im == 0f64 {
489        to_string_with_options::<_, FORMAT>(value.re, &FORMAT_REAL_OPTIONS)
490    } else if value.re == 0f64 {
491        to_string_with_options::<_, FORMAT>(value.im, &FORMAT_IMAGINARY_OPTIONS) + "i"
492    } else {
493        let mut out = to_string_with_options::<_, FORMAT>(value.re, &FORMAT_REAL_OPTIONS);
494        if value.im > 0f64 {
495            out.push('+')
496        }
497        out.push_str(&to_string_with_options::<_, FORMAT>(
498            value.im,
499            &FORMAT_IMAGINARY_OPTIONS,
500        ));
501        out.push('i');
502        out
503    }
504}
505
506impl Quil for Expression {
507    fn write(
508        &self,
509        f: &mut impl std::fmt::Write,
510        fall_back_to_debug: bool,
511    ) -> Result<(), crate::quil::ToQuilError> {
512        use Expression::*;
513        match self {
514            Address(memory_reference) => memory_reference.write(f, fall_back_to_debug),
515            FunctionCall(FunctionCallExpression {
516                function,
517                expression,
518            }) => {
519                write!(f, "{function}(")?;
520                expression.write(f, fall_back_to_debug)?;
521                write!(f, ")")?;
522                Ok(())
523            }
524            Infix(InfixExpression {
525                left,
526                operator,
527                right,
528            }) => {
529                format_inner_expression(f, fall_back_to_debug, left)?;
530                write!(f, "{}", operator)?;
531                format_inner_expression(f, fall_back_to_debug, right)
532            }
533            Number(value) => write!(f, "{}", format_complex(value)).map_err(Into::into),
534            PiConstant => write!(f, "pi").map_err(Into::into),
535            Prefix(PrefixExpression {
536                operator,
537                expression,
538            }) => {
539                write!(f, "{}", operator)?;
540                format_inner_expression(f, fall_back_to_debug, expression)
541            }
542            Variable(identifier) => write!(f, "%{}", identifier).map_err(Into::into),
543        }
544    }
545}
546
547/// Utility function to wrap infix expressions that are part of an expression in parentheses, so
548/// that correct precedence rules are enforced.
549fn format_inner_expression(
550    f: &mut impl std::fmt::Write,
551    fall_back_to_debug: bool,
552    expression: &Expression,
553) -> crate::quil::ToQuilResult<()> {
554    match expression {
555        Expression::Infix(InfixExpression {
556            left,
557            operator,
558            right,
559        }) => {
560            write!(f, "(")?;
561            format_inner_expression(f, fall_back_to_debug, left)?;
562            write!(f, "{operator}")?;
563            format_inner_expression(f, fall_back_to_debug, right)?;
564            write!(f, ")")?;
565            Ok(())
566        }
567        _ => expression.write(f, fall_back_to_debug),
568    }
569}
570
#[cfg(test)]
mod test {
    use crate::{
        expression::{
            Expression, InfixExpression, InfixOperator, PrefixExpression, PrefixOperator,
        },
        quil::Quil,
        real,
    };

    /// A nested infix operand is parenthesized, while a prefix operand is not.
    #[test]
    fn formats_nested_expression() {
        let negative_three = Expression::Prefix(PrefixExpression::new(
            PrefixOperator::Minus,
            Box::new(Expression::Number(real!(3f64))),
        ));
        let half_pi = Expression::Infix(InfixExpression::new(
            Box::new(Expression::PiConstant),
            InfixOperator::Slash,
            Box::new(Expression::Number(real!(2f64))),
        ));
        let expression = Expression::Infix(InfixExpression::new(
            Box::new(negative_three),
            InfixOperator::Star,
            Box::new(half_pi),
        ));

        assert_eq!(expression.to_quil_or_debug(), "-3*(pi/2)");
    }
}
599
/// A function defined within Quil syntax.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum ExpressionFunction {
    /// Written `cis`.
    Cis,
    /// Written `cos`.
    Cosine,
    /// Written `exp`.
    Exponent,
    /// Written `sin`.
    Sine,
    /// Written `sqrt`.
    SquareRoot,
}

impl fmt::Display for ExpressionFunction {
    /// Write the function's Quil name.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match self {
            Self::Cis => "cis",
            Self::Cosine => "cos",
            Self::Exponent => "exp",
            Self::Sine => "sin",
            Self::SquareRoot => "sqrt",
        };
        f.write_str(name)
    }
}
627
/// A unary operator that can precede an expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum PrefixOperator {
    /// Unary plus; written as an empty string.
    Plus,
    /// Unary minus; written `-`.
    Minus,
}

impl fmt::Display for PrefixOperator {
    /// Write the operator's Quil text.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let symbol = match self {
            // NOTE: prefix Plus does nothing but cause parsing issues
            Self::Plus => "",
            Self::Minus => "-",
        };
        f.write_str(symbol)
    }
}
649
/// A binary operator that joins two expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum InfixOperator {
    /// Exponentiation; written `^`.
    Caret,
    /// Addition; written `+`.
    Plus,
    /// Subtraction; written ` - ` (with surrounding spaces).
    Minus,
    /// Division; written `/`.
    Slash,
    /// Multiplication; written `*`.
    Star,
}

impl fmt::Display for InfixOperator {
    /// Write the operator's Quil text.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let symbol = match self {
            Self::Caret => "^",
            Self::Plus => "+",
            // NOTE: spaces included to distinguish from hyphenated identifiers
            Self::Minus => " - ",
            Self::Slash => "/",
            Self::Star => "*",
        };
        f.write_str(symbol)
    }
}
677
678#[cfg(test)]
679// This lint should be re-enabled once this proptest issue is resolved
680// https://github.com/proptest-rs/proptest/issues/364
681#[allow(clippy::arc_with_non_send_sync)]
682mod tests {
683    use super::*;
684    use crate::reserved::ReservedToken;
685    use proptest::prelude::*;
686    use std::collections::hash_map::DefaultHasher;
687    use std::collections::HashSet;
688
689    /// Hash value helper: turn a hashable thing into a u64.
690    #[inline]
691    fn hash_to_u64<T: Hash>(t: &T) -> u64 {
692        let mut s = DefaultHasher::new();
693        t.hash(&mut s);
694        s.finish()
695    }
696
697    #[test]
698    fn simplify_and_evaluate() {
699        use Expression::*;
700
701        let one = real!(1.0);
702        let empty_variables = HashMap::new();
703
704        let mut variables = HashMap::new();
705        variables.insert("foo".to_owned(), real!(10f64));
706        variables.insert("bar".to_owned(), real!(100f64));
707
708        let empty_memory = HashMap::new();
709
710        let mut memory_references = HashMap::new();
711        memory_references.insert("theta", vec![1.0, 2.0]);
712        memory_references.insert("beta", vec![3.0, 4.0]);
713
714        struct TestCase<'a> {
715            expression: Expression,
716            variables: &'a HashMap<String, Complex64>,
717            memory_references: &'a HashMap<&'a str, Vec<f64>>,
718            simplified: Expression,
719            evaluated: Result<Complex64, EvaluationError>,
720        }
721
722        let cases: Vec<TestCase> = vec![
723            TestCase {
724                expression: Number(one),
725                variables: &empty_variables,
726                memory_references: &empty_memory,
727                simplified: Number(one),
728                evaluated: Ok(one),
729            },
730            TestCase {
731                expression: Expression::Prefix(PrefixExpression {
732                    operator: PrefixOperator::Minus,
733                    expression: Box::new(Number(real!(1f64))),
734                }),
735                variables: &empty_variables,
736                memory_references: &empty_memory,
737                simplified: Number(real!(-1f64)),
738                evaluated: Ok(real!(-1f64)),
739            },
740            TestCase {
741                expression: Expression::Variable("foo".to_owned()),
742                variables: &variables,
743                memory_references: &empty_memory,
744                simplified: Expression::Variable("foo".to_owned()),
745                evaluated: Ok(real!(10f64)),
746            },
747            TestCase {
748                expression: Expression::from_str("%foo + %bar").unwrap(),
749                variables: &variables,
750                memory_references: &empty_memory,
751                simplified: Expression::from_str("%foo + %bar").unwrap(),
752                evaluated: Ok(real!(110f64)),
753            },
754            TestCase {
755                expression: Expression::FunctionCall(FunctionCallExpression {
756                    function: ExpressionFunction::Sine,
757                    expression: Box::new(Expression::Number(real!(PI / 2f64))),
758                }),
759                variables: &variables,
760                memory_references: &empty_memory,
761                simplified: Number(real!(1f64)),
762                evaluated: Ok(real!(1f64)),
763            },
764            TestCase {
765                expression: Expression::from_str("theta[1] * beta[0]").unwrap(),
766                variables: &empty_variables,
767                memory_references: &memory_references,
768                simplified: Expression::from_str("theta[1] * beta[0]").unwrap(),
769                evaluated: Ok(real!(6.0)),
770            },
771        ];
772
773        for mut case in cases {
774            let evaluated = case
775                .expression
776                .evaluate(case.variables, case.memory_references);
777            assert_eq!(evaluated, case.evaluated);
778
779            case.expression.simplify();
780            assert_eq!(case.expression, case.simplified);
781        }
782    }
783
784    /// Parenthesized version of [`Expression::to_string()`]
785    fn parenthesized(expression: &Expression) -> String {
786        use Expression::*;
787        match expression {
788            Address(memory_reference) => memory_reference.to_quil_or_debug(),
789            FunctionCall(FunctionCallExpression {
790                function,
791                expression,
792            }) => format!("({function}({}))", parenthesized(expression)),
793            Infix(InfixExpression {
794                left,
795                operator,
796                right,
797            }) => format!(
798                "({}{}{})",
799                parenthesized(left),
800                operator,
801                parenthesized(right)
802            ),
803            Number(value) => format!("({})", format_complex(value)),
804            PiConstant => "pi".to_string(),
805            Prefix(PrefixExpression {
806                operator,
807                expression,
808            }) => format!("({}{})", operator, parenthesized(expression)),
809            Variable(identifier) => format!("(%{identifier})"),
810        }
811    }
812
813    // Better behaved than the auto-derived version for names
814    fn arb_name() -> impl Strategy<Value = String> {
815        r"[a-z][a-zA-Z0-9]{1,10}".prop_filter("Exclude reserved tokens", |t| {
816            ReservedToken::from_str(t).is_err() && !t.to_lowercase().starts_with("nan")
817        })
818    }
819
820    // Better behaved than the auto-derived version re: names & indices
821    fn arb_memory_reference() -> impl Strategy<Value = MemoryReference> {
822        (arb_name(), (u64::MIN..u32::MAX as u64))
823            .prop_map(|(name, index)| MemoryReference { name, index })
824    }
825
826    // Better behaved than the auto-derived version via arbitrary floats
827    fn arb_complex64() -> impl Strategy<Value = Complex64> {
828        let tau = std::f64::consts::TAU;
829        ((-tau..tau), (-tau..tau)).prop_map(|(re, im)| Complex64 { re, im })
830    }
831
832    /// Filter an Expression to not be constantly zero.
833    fn nonzero(strat: impl Strategy<Value = Expression>) -> impl Strategy<Value = Expression> {
834        strat.prop_filter("Exclude constantly-zero expressions", |expr| {
835            expr.clone().into_simplified() != Expression::Number(Complex64::new(0.0, 0.0))
836        })
837    }
838
839    /// Generate an arbitrary Expression for a property test.
840    /// See https://docs.rs/proptest/1.0.0/proptest/prelude/trait.Strategy.html#method.prop_recursive
841    fn arb_expr() -> impl Strategy<Value = Expression> {
842        use Expression::*;
843        let leaf = prop_oneof![
844            arb_memory_reference().prop_map(Address),
845            arb_complex64().prop_map(Number),
846            Just(PiConstant),
847            arb_name().prop_map(Variable),
848        ];
849        leaf.prop_recursive(
850            4,  // No more than 4 branch levels deep
851            64, // Target around 64 total nodes
852            16, // Each "collection" is up to 16 elements
853            |expr| {
854                let inner = expr.clone();
855                prop_oneof![
856                    (any::<ExpressionFunction>(), expr.clone()).prop_map(|(function, e)| {
857                        Expression::FunctionCall(FunctionCallExpression {
858                            function,
859                            expression: Box::new(e),
860                        })
861                    }),
862                    (expr.clone(), any::<InfixOperator>())
863                        .prop_flat_map(move |(left, operator)| (
864                            Just(left),
865                            Just(operator),
866                            // Avoid division by 0 so that we can reliably assert equality
867                            if let InfixOperator::Slash = operator {
868                                nonzero(inner.clone()).boxed()
869                            } else {
870                                inner.clone().boxed()
871                            }
872                        ))
873                        .prop_map(|(l, operator, r)| Infix(InfixExpression {
874                            left: Box::new(l),
875                            operator,
876                            right: Box::new(r)
877                        })),
878                    expr.prop_map(|e| Prefix(PrefixExpression {
879                        operator: PrefixOperator::Minus,
880                        expression: Box::new(e)
881                    }))
882                ]
883            },
884        )
885    }
886
887    proptest! {
888
889        #[test]
890        fn eq(a in any::<f64>(), b in any::<f64>()) {
891            let first = Expression::Infix (InfixExpression {
892                left: Box::new(Expression::Number(real!(a))),
893                operator: InfixOperator::Plus,
894                right: Box::new(Expression::Number(real!(b))),
895            } );
896            let differing = Expression::Number(real!(a + b));
897            prop_assert_eq!(&first, &first);
898            prop_assert_ne!(&first, &differing);
899        }
900
901        #[test]
902        fn hash(a in any::<f64>(), b in any::<f64>()) {
903            let first = Expression::Infix (InfixExpression {
904                left: Box::new(Expression::Number(real!(a))),
905                operator: InfixOperator::Plus,
906                right: Box::new(Expression::Number(real!(b))),
907            });
908            let matching = first.clone();
909            let differing = Expression::Number(real!(a + b));
910            let mut set = HashSet::new();
911            set.insert(first);
912            assert!(set.contains(&matching));
913            assert!(!set.contains(&differing))
914        }
915
916        #[test]
917        fn eq_iff_hash_eq(x in arb_expr(), y in arb_expr()) {
918            prop_assert_eq!(x == y, hash_to_u64(&x) == hash_to_u64(&y));
919        }
920
921        #[test]
922        fn reals_are_real(x in any::<f64>()) {
923            prop_assert_eq!(Expression::Number(real!(x)).to_real(), Ok(x))
924        }
925
926        #[test]
927        fn some_nums_are_real(re in any::<f64>(), im in any::<f64>()) {
928            let result = Expression::Number(Complex64{re, im}).to_real();
929            if is_small(im) {
930                prop_assert_eq!(result, Ok(re))
931            } else {
932                prop_assert_eq!(result, Err(EvaluationError::NumberNotReal))
933            }
934        }
935
936        #[test]
937        fn no_other_exps_are_real(expr in arb_expr().prop_filter("Not numbers", |e| !matches!(e, Expression::Number(_) | Expression::PiConstant))) {
938            prop_assert_eq!(expr.to_real(), Err(EvaluationError::NotANumber))
939        }
940
941        #[test]
942        fn complexes_are_parseable_as_expressions(value in arb_complex64()) {
943            let parsed = Expression::from_str(&format_complex(&value));
944            assert!(parsed.is_ok());
945            let simple = parsed.unwrap().into_simplified();
946            assert_eq!(Expression::Number(value), simple);
947        }
948
949        #[test]
950        fn exponentiation_works_as_expected(left in arb_expr(), right in arb_expr()) {
951            let expected = Expression::Infix (InfixExpression { left: Box::new(left.clone()), operator: InfixOperator::Caret, right: Box::new(right.clone()) } );
952            prop_assert_eq!(left ^ right, expected);
953        }
954
955        #[test]
956        fn in_place_exponentiation_works_as_expected(left in arb_expr(), right in arb_expr()) {
957            let expected = Expression::Infix (InfixExpression { left: Box::new(left.clone()), operator: InfixOperator::Caret, right: Box::new(right.clone()) } );
958            let mut x = left;
959            x ^= right;
960            prop_assert_eq!(x, expected);
961        }
962
963        #[test]
964        fn addition_works_as_expected(left in arb_expr(), right in arb_expr()) {
965            let expected = Expression::Infix (InfixExpression { left: Box::new(left.clone()), operator: InfixOperator::Plus, right: Box::new(right.clone()) } );
966            prop_assert_eq!(left + right, expected);
967        }
968
969        #[test]
970        fn in_place_addition_works_as_expected(left in arb_expr(), right in arb_expr()) {
971            let expected = Expression::Infix (InfixExpression { left: Box::new(left.clone()), operator: InfixOperator::Plus, right: Box::new(right.clone()) } );
972            let mut x = left;
973            x += right;
974            prop_assert_eq!(x, expected);
975        }
976
977        #[test]
978        fn subtraction_works_as_expected(left in arb_expr(), right in arb_expr()) {
979            let expected = Expression::Infix (InfixExpression { left: Box::new(left.clone()), operator: InfixOperator::Minus, right: Box::new(right.clone()) } );
980            prop_assert_eq!(left - right, expected);
981        }
982
983        #[test]
984        fn in_place_subtraction_works_as_expected(left in arb_expr(), right in arb_expr()) {
985            let expected = Expression::Infix (InfixExpression { left: Box::new(left.clone()), operator: InfixOperator::Minus, right: Box::new(right.clone()) } );
986            let mut x = left;
987            x -= right;
988            prop_assert_eq!(x, expected);
989        }
990
991        #[test]
992        fn multiplication_works_as_expected(left in arb_expr(), right in arb_expr()) {
993            let expected = Expression::Infix (InfixExpression { left: Box::new(left.clone()), operator: InfixOperator::Star, right: Box::new(right.clone()) } );
994            prop_assert_eq!(left * right, expected);
995        }
996
997        #[test]
998        fn in_place_multiplication_works_as_expected(left in arb_expr(), right in arb_expr()) {
999            let expected = Expression::Infix (InfixExpression { left: Box::new(left.clone()), operator: InfixOperator::Star, right: Box::new(right.clone()) } );
1000            let mut x = left;
1001            x *= right;
1002            prop_assert_eq!(x, expected);
1003        }
1004
1005
1006        // Avoid division by 0 so that we can reliably assert equality
1007        #[test]
1008        fn division_works_as_expected(left in arb_expr(), right in nonzero(arb_expr())) {
1009            let expected = Expression::Infix (InfixExpression { left: Box::new(left.clone()), operator: InfixOperator::Slash, right: Box::new(right.clone()) } );
1010            prop_assert_eq!(left / right, expected);
1011        }
1012
1013        // Avoid division by 0 so that we can reliably assert equality
1014        #[test]
1015        fn in_place_division_works_as_expected(left in arb_expr(), right in nonzero(arb_expr())) {
1016            let expected = Expression::Infix (InfixExpression { left: Box::new(left.clone()), operator: InfixOperator::Slash, right: Box::new(right.clone()) } );
1017            let mut x = left;
1018            x /= right;
1019            prop_assert_eq!(x, expected);
1020        }
1021
1022        // Redundant clone: clippy does not correctly introspect the prop_assert_eq! macro
1023        #[allow(clippy::redundant_clone)]
1024        #[test]
1025        fn round_trip(e in arb_expr()) {
1026            let simple_e = e.clone().into_simplified();
1027            let s = parenthesized(&e);
1028            let p = Expression::from_str(&s);
1029            prop_assert!(p.is_ok());
1030            let p = p.unwrap();
1031            let simple_p = p.clone().into_simplified();
1032
1033            prop_assert_eq!(
1034                simple_p.clone(),
1035                simple_e.clone(),
1036                "Simplified expressions should be equal:\nparenthesized {p} ({p:?}) extracted from {s} simplified to {simple_p}\nvs original {e} ({e:?}) simplified to {simple_e}",
1037                p=p.to_quil_or_debug(),
1038                s=s,
1039                e=e.to_quil_or_debug(),
1040                simple_p=simple_p.to_quil_or_debug(),
1041                simple_e=simple_e.to_quil_or_debug()
1042            );
1043        }
1044
1045    }
1046
1047    /// Assert that certain selected expressions are parsed and re-written to string
1048    /// in exactly the same way.
1049    #[test]
1050    fn specific_round_trip_tests() {
1051        for input in &[
1052            "-1*(phases[0]+phases[1])",
1053            "(-1*(phases[0]+phases[1]))+(-1*(phases[0]+phases[1]))",
1054        ] {
1055            let parsed = Expression::from_str(input);
1056            let parsed = parsed.unwrap();
1057            let restring = parsed.to_quil_or_debug();
1058            assert_eq!(input, &restring);
1059        }
1060    }
1061
1062    #[test]
1063    fn test_nan_is_equal() {
1064        let left = Expression::Number(f64::NAN.into());
1065        let right = left.clone();
1066        assert_eq!(left, right);
1067    }
1068
1069    #[test]
1070    fn specific_simplification_tests() {
1071        for (input, expected) in [
1072            ("pi", Expression::Number(PI.into())),
1073            ("pi/2", Expression::Number((PI / 2.0).into())),
1074            ("pi * pi", Expression::Number((PI.powi(2)).into())),
1075            ("1.0/(1.0-1.0)", Expression::Number(f64::NAN.into())),
1076            (
1077                "(a[0]*2*pi)/6.283185307179586",
1078                Expression::Address(MemoryReference {
1079                    name: String::from("a"),
1080                    index: 0,
1081                }),
1082            ),
1083        ] {
1084            assert_eq!(
1085                Expression::from_str(input).unwrap().into_simplified(),
1086                expected
1087            )
1088        }
1089    }
1090
1091    #[test]
1092    fn specific_to_real_tests() {
1093        for (input, expected) in [
1094            (Expression::PiConstant, Ok(PI)),
1095            (Expression::Number(Complex64 { re: 1.0, im: 0.0 }), Ok(1.0)),
1096            (
1097                Expression::Number(Complex64 { re: 1.0, im: 1.0 }),
1098                Err(EvaluationError::NumberNotReal),
1099            ),
1100            (
1101                Expression::Variable("Not a number".into()),
1102                Err(EvaluationError::NotANumber),
1103            ),
1104        ] {
1105            assert_eq!(input.to_real(), expected)
1106        }
1107    }
1108
1109    #[test]
1110    fn specific_format_complex_tests() {
1111        for (x, s) in &[
1112            (Complex64::new(0.0, 0.0), "0"),
1113            (Complex64::new(-0.0, 0.0), "0"),
1114            (Complex64::new(-0.0, -0.0), "0"),
1115            (Complex64::new(0.0, 1.0), "1.0i"),
1116            (Complex64::new(1.0, -1.0), "1-1.0i"),
1117            (Complex64::new(1.234, 0.0), "1.234"),
1118            (Complex64::new(0.0, 1.234), "1.234i"),
1119            (Complex64::new(-1.234, 0.0), "-1.234"),
1120            (Complex64::new(0.0, -1.234), "-1.234i"),
1121            (Complex64::new(1.234, 5.678), "1.234+5.678i"),
1122            (Complex64::new(-1.234, 5.678), "-1.234+5.678i"),
1123            (Complex64::new(1.234, -5.678), "1.234-5.678i"),
1124            (Complex64::new(-1.234, -5.678), "-1.234-5.678i"),
1125            (Complex64::new(1e100, 2e-100), "1e100+2.0e-100i"),
1126        ] {
1127            assert_eq!(format_complex(x), *s);
1128        }
1129    }
1130}