quil_rs/expression/
mod.rs

1// Copyright 2021 Rigetti Computing
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::{
16    hash::hash_f64,
17    imag,
18    instruction::MemoryReference,
19    parser::{lex, parse_expression, ParseError},
20    program::{disallow_leftover, ParseProgramError},
21    quil::Quil,
22    real,
23};
24use internment::ArcIntern;
25use lexical::{format, to_string_with_options, WriteFloatOptions};
26use nom_locate::LocatedSpan;
27use num_complex::Complex64;
28use once_cell::sync::Lazy;
29use std::{
30    collections::HashMap,
31    f64::consts::PI,
32    fmt,
33    hash::{Hash, Hasher},
34    num::NonZeroI32,
35    ops::{
36        Add, AddAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign,
37    },
38    str::FromStr,
39};
40
41#[cfg(test)]
42use proptest_derive::Arbitrary;
43
44mod simplification;
45
/// The different possible types of errors that could occur during expression evaluation.
#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
pub enum EvaluationError {
    /// Some value needed for evaluation was not available. For example,
    /// [`Expression::evaluate`] returns this when a variable or memory reference has no
    /// supplied value.
    #[error("There wasn't enough information to completely evaluate the expression.")]
    Incomplete,
    /// A real number was required, but the number has a non-negligible imaginary part (see
    /// [`Expression::to_real`]).
    #[error("The operation expected a real number but received a complex one.")]
    NumberNotReal,
    /// A number was required, but the expression is neither a numeric literal nor `pi` (see
    /// [`Expression::to_real`]).
    #[error("The operation expected a number but received a different type of expression.")]
    NotANumber,
}
56
/// The type of Quil expressions.
///
/// Quil expressions take advantage of *structural sharing*; if a Quil expression contains the same
/// subexpression twice, such as `x + y` in `(x + y) * (x + y)`, the two children of the `*` node
/// will be the *same pointer*.  This is implemented through *interning*, also known as
/// *hash-consing*; the recursive references to child nodes are done via [`ArcIntern<Expression>`]s.
/// Creating an [`ArcIntern`] from an `Expression` will always return the same pointer for the same
/// expression.
///
/// The structural sharing means that equality, cloning, and hashing on Quil expressions are all
/// very cheap, as they do not need to be recursive: equality of [`ArcIntern`]s is a single-pointer
/// comparison, cloning of [`ArcIntern`]s is a pointer copy and an atomic increment, and hashing of
/// [`ArcIntern`]s hashes a single pointer.  It is also very cheap to key [`HashMap`]s by an
/// [`ArcIntern<Expression>`], which can allow for cheap memoization of operations on `Expression`s.
///
/// The structural sharing also means that Quil expressions are fundamentally *immutable*; it is
/// *impossible* to get an owned or `&mut` reference to the child `Expression`s of any `Expression`,
/// as the use of interning means there may be multiple references to that `Expression` at any time.
///
/// Note that when comparing Quil expressions, any embedded NaNs are treated as *equal* to other
/// NaNs, not unequal, in contravention of the IEEE 754 spec.
#[derive(Clone, Debug)]
pub enum Expression {
    /// A reference into classical memory, e.g. `theta[0]`. When evaluated, it is looked up in the
    /// `memory_references` map and read as a real number.
    Address(MemoryReference),
    /// A call to a built-in function, e.g. `sin(e)`.
    FunctionCall(FunctionCallExpression),
    /// A binary operation, e.g. `e1 + e2`.
    Infix(InfixExpression),
    /// A complex numeric literal.
    Number(Complex64),
    /// The constant `pi`.
    PiConstant,
    /// A unary operation, e.g. `-e`.
    Prefix(PrefixExpression),
    /// A named variable, written `%name` in Quil.
    Variable(String),
}
88
/// The type of function call Quil expressions, e.g. `sin(e)`.
///
/// Quil expressions take advantage of *structural sharing*, which is why the `expression` here is
/// wrapped in an [`ArcIntern`]; for more details, see the documentation for [`Expression`].
///
/// Note that when comparing Quil expressions, any embedded NaNs are treated as *equal* to other
/// NaNs, not unequal, in contravention of the IEEE 754 spec.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct FunctionCallExpression {
    /// The function being applied.
    pub function: ExpressionFunction,
    /// The (interned) argument the function is applied to.
    pub expression: ArcIntern<Expression>,
}
101
102impl FunctionCallExpression {
103    pub fn new(function: ExpressionFunction, expression: ArcIntern<Expression>) -> Self {
104        Self {
105            function,
106            expression,
107        }
108    }
109}
110
/// The type of infix Quil expressions, e.g. `e1 + e2`.
///
/// Quil expressions take advantage of *structural sharing*, which is why the `left` and `right`
/// expressions here are wrapped in [`ArcIntern`]s; for more details, see the documentation for
/// [`Expression`].
///
/// Note that when comparing Quil expressions, any embedded NaNs are treated as *equal* to other
/// NaNs, not unequal, in contravention of the IEEE 754 spec.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct InfixExpression {
    /// The (interned) left-hand operand.
    pub left: ArcIntern<Expression>,
    /// The binary operator joining the two operands.
    pub operator: InfixOperator,
    /// The (interned) right-hand operand.
    pub right: ArcIntern<Expression>,
}
125
126impl InfixExpression {
127    pub fn new(
128        left: ArcIntern<Expression>,
129        operator: InfixOperator,
130        right: ArcIntern<Expression>,
131    ) -> Self {
132        Self {
133            left,
134            operator,
135            right,
136        }
137    }
138}
139
/// The type of prefix Quil expressions, e.g. `-e`.
///
/// Quil expressions take advantage of *structural sharing*, which is why the `expression` here is
/// wrapped in an [`ArcIntern`]; for more details, see the documentation for [`Expression`].
///
/// Note that when comparing Quil expressions, any embedded NaNs are treated as *equal* to other
/// NaNs, not unequal, in contravention of the IEEE 754 spec.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PrefixExpression {
    /// The unary operator applied to `expression`.
    pub operator: PrefixOperator,
    /// The (interned) operand.
    pub expression: ArcIntern<Expression>,
}
152
153impl PrefixExpression {
154    pub fn new(operator: PrefixOperator, expression: ArcIntern<Expression>) -> Self {
155        Self {
156            operator,
157            expression,
158        }
159    }
160}
161
162impl PartialEq for Expression {
163    // Implemented by hand since we can't derive with f64s hidden inside.
164    fn eq(&self, other: &Self) -> bool {
165        match (self, other) {
166            (Self::Address(left), Self::Address(right)) => left == right,
167            (Self::Infix(left), Self::Infix(right)) => left == right,
168            (Self::Number(left), Self::Number(right)) => {
169                (left.re == right.re || left.re.is_nan() && right.re.is_nan())
170                    && (left.im == right.im || left.im.is_nan() && right.im.is_nan())
171            }
172            (Self::Prefix(left), Self::Prefix(right)) => left == right,
173            (Self::FunctionCall(left), Self::FunctionCall(right)) => left == right,
174            (Self::Variable(left), Self::Variable(right)) => left == right,
175            (Self::PiConstant, Self::PiConstant) => true,
176            _ => false,
177        }
178    }
179}
180
181// Implemented by hand since we can't derive with f64s hidden inside.
182impl Eq for Expression {}
183
184impl Hash for Expression {
185    // Implemented by hand since we can't derive with f64s hidden inside.
186    fn hash<H: Hasher>(&self, state: &mut H) {
187        match self {
188            Self::Address(m) => {
189                "Address".hash(state);
190                m.hash(state);
191            }
192            Self::FunctionCall(FunctionCallExpression {
193                function,
194                expression,
195            }) => {
196                "FunctionCall".hash(state);
197                function.hash(state);
198                expression.hash(state);
199            }
200            Self::Infix(InfixExpression {
201                left,
202                operator,
203                right,
204            }) => {
205                "Infix".hash(state);
206                operator.hash(state);
207                left.hash(state);
208                right.hash(state);
209            }
210            Self::Number(n) => {
211                "Number".hash(state);
212                // Skip zero values (akin to `format_complex`).
213                if n.re.abs() > 0f64 {
214                    hash_f64(n.re, state)
215                }
216                if n.im.abs() > 0f64 {
217                    hash_f64(n.im, state)
218                }
219            }
220            Self::PiConstant => {
221                "PiConstant".hash(state);
222            }
223            Self::Prefix(p) => {
224                "Prefix".hash(state);
225                p.operator.hash(state);
226                p.expression.hash(state);
227            }
228            Self::Variable(v) => {
229                "Variable".hash(state);
230                v.hash(state);
231            }
232        }
233    }
234}
235
236macro_rules! impl_expr_op {
237    ($name:ident, $name_assign:ident, $function:ident, $function_assign:ident, $operator:ident) => {
238        impl $name for Expression {
239            type Output = Self;
240            fn $function(self, other: Self) -> Self {
241                Self::Infix(InfixExpression {
242                    left: ArcIntern::new(self),
243                    operator: InfixOperator::$operator,
244                    right: ArcIntern::new(other),
245                })
246            }
247        }
248
249        impl $name<ArcIntern<Expression>> for Expression {
250            type Output = Self;
251            fn $function(self, other: ArcIntern<Expression>) -> Self {
252                Self::Infix(InfixExpression {
253                    left: ArcIntern::new(self),
254                    operator: InfixOperator::$operator,
255                    right: other,
256                })
257            }
258        }
259
260        impl $name<Expression> for ArcIntern<Expression> {
261            type Output = Expression;
262            fn $function(self, other: Expression) -> Expression {
263                Expression::Infix(InfixExpression {
264                    left: self,
265                    operator: InfixOperator::$operator,
266                    right: ArcIntern::new(other),
267                })
268            }
269        }
270
271        impl $name_assign for Expression {
272            fn $function_assign(&mut self, other: Self) {
273                // Move out of self to avoid potentially cloning a large value
274                let temp = ::std::mem::replace(self, Self::PiConstant);
275                *self = temp.$function(other);
276            }
277        }
278
279        impl $name_assign<ArcIntern<Expression>> for Expression {
280            fn $function_assign(&mut self, other: ArcIntern<Expression>) {
281                // Move out of self to avoid potentially cloning a large value
282                let temp = ::std::mem::replace(self, Self::PiConstant);
283                *self = temp.$function(other);
284            }
285        }
286    };
287}
288
289impl_expr_op!(BitXor, BitXorAssign, bitxor, bitxor_assign, Caret);
290impl_expr_op!(Add, AddAssign, add, add_assign, Plus);
291impl_expr_op!(Sub, SubAssign, sub, sub_assign, Minus);
292impl_expr_op!(Mul, MulAssign, mul, mul_assign, Star);
293impl_expr_op!(Div, DivAssign, div, div_assign, Slash);
294
295impl Neg for Expression {
296    type Output = Self;
297
298    fn neg(self) -> Self {
299        Expression::Prefix(PrefixExpression {
300            operator: PrefixOperator::Minus,
301            expression: ArcIntern::new(self),
302        })
303    }
304}
305
306/// Compute the result of an infix expression where both operands are complex.
307#[inline]
308pub(crate) fn calculate_infix(
309    left: Complex64,
310    operator: InfixOperator,
311    right: Complex64,
312) -> Complex64 {
313    use InfixOperator::*;
314    match operator {
315        Caret => left.powc(right),
316        Plus => left + right,
317        Minus => left - right,
318        Slash => left / right,
319        Star => left * right,
320    }
321}
322
323/// Compute the result of a Quil-defined expression function where the operand is complex.
324#[inline]
325pub(crate) fn calculate_function(function: ExpressionFunction, argument: Complex64) -> Complex64 {
326    use ExpressionFunction::*;
327    match function {
328        Sine => argument.sin(),
329        Cis => argument.cos() + imag!(1f64) * argument.sin(),
330        Cosine => argument.cos(),
331        Exponent => argument.exp(),
332        SquareRoot => argument.sqrt(),
333    }
334}
335
/// Returns `true` when `x` lies strictly between `-1e-16` and `1e-16`, i.e. is small enough to be
/// treated as zero. NaN and infinities are never small.
#[inline(always)]
fn is_small(x: f64) -> bool {
    const THRESHOLD: f64 = 1e-16;
    -THRESHOLD < x && x < THRESHOLD
}
341
342impl Expression {
343    /// Simplify the expression as much as possible, in-place.
344    ///
345    /// # Example
346    ///
347    /// ```rust
348    /// use quil_rs::expression::Expression;
349    /// use std::str::FromStr;
350    /// use num_complex::Complex64;
351    ///
352    /// let mut expression = Expression::from_str("cos(2 * pi) + 2").unwrap();
353    /// expression.simplify();
354    ///
355    /// assert_eq!(expression, Expression::Number(Complex64::from(3.0)));
356    /// ```
357    pub fn simplify(&mut self) {
358        match self {
359            Expression::Address(_) | Expression::Number(_) | Expression::Variable(_) => {}
360            Expression::PiConstant => {
361                *self = Expression::Number(Complex64::from(PI));
362            }
363            _ => *self = simplification::run(self),
364        }
365    }
366
367    /// Consume the expression, simplifying it as much as possible.
368    ///
369    /// # Example
370    ///
371    /// ```rust
372    /// use quil_rs::expression::Expression;
373    /// use std::str::FromStr;
374    /// use num_complex::Complex64;
375    ///
376    /// let simplified = Expression::from_str("cos(2 * pi) + 2").unwrap().into_simplified();
377    ///
378    /// assert_eq!(simplified, Expression::Number(Complex64::from(3.0)));
379    /// ```
380    pub fn into_simplified(mut self) -> Self {
381        self.simplify();
382        self
383    }
384
385    /// Evaluate an expression, expecting that it may be fully reduced to a single complex number.
386    /// If it cannot be reduced to a complex number, return an error.
387    ///
388    /// # Example
389    ///
390    /// ```rust
391    /// use quil_rs::expression::Expression;
392    /// use std::str::FromStr;
393    /// use std::collections::HashMap;
394    /// use num_complex::Complex64;
395    ///
396    /// let expression = Expression::from_str("%beta + theta[0]").unwrap();
397    ///
398    /// let mut variables = HashMap::with_capacity(1);
399    /// variables.insert(String::from("beta"), Complex64::from(1.0));
400    ///
401    /// let mut memory_references = HashMap::with_capacity(1);
402    /// memory_references.insert("theta", vec![2.0]);
403    ///
404    /// let evaluated = expression.evaluate(&variables, &memory_references).unwrap();
405    ///
406    /// assert_eq!(evaluated, Complex64::from(3.0))
407    /// ```
408    pub fn evaluate(
409        &self,
410        variables: &HashMap<String, Complex64>,
411        memory_references: &HashMap<&str, Vec<f64>>,
412    ) -> Result<Complex64, EvaluationError> {
413        use Expression::*;
414
415        match self {
416            FunctionCall(FunctionCallExpression {
417                function,
418                expression,
419            }) => {
420                let evaluated = expression.evaluate(variables, memory_references)?;
421                Ok(calculate_function(*function, evaluated))
422            }
423            Infix(InfixExpression {
424                left,
425                operator,
426                right,
427            }) => {
428                let left_evaluated = left.evaluate(variables, memory_references)?;
429                let right_evaluated = right.evaluate(variables, memory_references)?;
430                Ok(calculate_infix(left_evaluated, *operator, right_evaluated))
431            }
432            Prefix(PrefixExpression {
433                operator,
434                expression,
435            }) => {
436                use PrefixOperator::*;
437                let value = expression.evaluate(variables, memory_references)?;
438                if matches!(operator, Minus) {
439                    Ok(-value)
440                } else {
441                    Ok(value)
442                }
443            }
444            Variable(identifier) => match variables.get(identifier.as_str()) {
445                Some(value) => Ok(*value),
446                None => Err(EvaluationError::Incomplete),
447            },
448            Address(memory_reference) => memory_references
449                .get(memory_reference.name.as_str())
450                .and_then(|values| {
451                    let value = values.get(memory_reference.index as usize)?;
452                    Some(real!(*value))
453                })
454                .ok_or(EvaluationError::Incomplete),
455            PiConstant => Ok(real!(PI)),
456            Number(number) => Ok(*number),
457        }
458    }
459
460    /// Substitute an expression in the place of each matching variable.
461    ///
462    /// # Example
463    ///
464    /// ```rust
465    /// use quil_rs::expression::Expression;
466    /// use std::str::FromStr;
467    /// use std::collections::HashMap;
468    /// use num_complex::Complex64;
469    ///
470    /// let expression = Expression::from_str("%x + %y").unwrap();
471    ///
472    /// let mut variables = HashMap::with_capacity(1);
473    /// variables.insert(String::from("x"), Expression::Number(Complex64::from(1.0)));
474    ///
475    /// let evaluated = expression.substitute_variables(&variables);
476    ///
477    /// assert_eq!(evaluated, Expression::from_str("1.0 + %y").unwrap())
478    /// ```
479    pub fn substitute_variables(&self, variable_values: &HashMap<String, Expression>) -> Self {
480        use Expression::*;
481
482        match self {
483            FunctionCall(FunctionCallExpression {
484                function,
485                expression,
486            }) => Expression::FunctionCall(FunctionCallExpression {
487                function: *function,
488                expression: expression.substitute_variables(variable_values).into(),
489            }),
490            Infix(InfixExpression {
491                left,
492                operator,
493                right,
494            }) => {
495                let left = left.substitute_variables(variable_values).into();
496                let right = right.substitute_variables(variable_values).into();
497                Infix(InfixExpression {
498                    left,
499                    operator: *operator,
500                    right,
501                })
502            }
503            Prefix(PrefixExpression {
504                operator,
505                expression,
506            }) => Prefix(PrefixExpression {
507                operator: *operator,
508                expression: expression.substitute_variables(variable_values).into(),
509            }),
510            Variable(identifier) => match variable_values.get(identifier.as_str()) {
511                Some(value) => value.clone(),
512                None => Variable(identifier.clone()),
513            },
514            other => other.clone(),
515        }
516    }
517
518    /// If this is a number with imaginary part "equal to" zero (of _small_ absolute value), return
519    /// that number. Otherwise, error with an evaluation error of a descriptive type.
520    pub fn to_real(&self) -> Result<f64, EvaluationError> {
521        match self {
522            Expression::PiConstant => Ok(PI),
523            Expression::Number(x) if is_small(x.im) => Ok(x.re),
524            Expression::Number(_) => Err(EvaluationError::NumberNotReal),
525            _ => Err(EvaluationError::NotANumber),
526        }
527    }
528}
529
530impl FromStr for Expression {
531    type Err = ParseProgramError<Self>;
532
533    fn from_str(s: &str) -> Result<Self, Self::Err> {
534        let input = LocatedSpan::new(s);
535        let tokens = lex(input)?;
536        disallow_leftover(parse_expression(&tokens).map_err(ParseError::from_nom_internal_err))
537    }
538}
539
540static FORMAT_REAL_OPTIONS: Lazy<WriteFloatOptions> = Lazy::new(|| {
541    WriteFloatOptions::builder()
542        .negative_exponent_break(NonZeroI32::new(-5))
543        .positive_exponent_break(NonZeroI32::new(15))
544        .trim_floats(true)
545        .build()
546        .expect("options are valid")
547});
548
549static FORMAT_IMAGINARY_OPTIONS: Lazy<WriteFloatOptions> = Lazy::new(|| {
550    WriteFloatOptions::builder()
551        .negative_exponent_break(NonZeroI32::new(-5))
552        .positive_exponent_break(NonZeroI32::new(15))
553        .trim_floats(false) // Per the quil spec, the imaginary part of a complex number is always a floating point number
554        .build()
555        .expect("options are valid")
556});
557
558/// Format a num_complex::Complex64 value in a way that omits the real or imaginary part when
559/// reasonable. That is:
560///
561/// - When imaginary is set but real is 0, show only imaginary
562/// - When imaginary is 0, show real only
563/// - When both are non-zero, show with the correct operator in between
564#[inline(always)]
565pub(crate) fn format_complex(value: &Complex64) -> String {
566    const FORMAT: u128 = format::STANDARD;
567    if value.re == 0f64 && value.im == 0f64 {
568        "0".to_owned()
569    } else if value.im == 0f64 {
570        to_string_with_options::<_, FORMAT>(value.re, &FORMAT_REAL_OPTIONS)
571    } else if value.re == 0f64 {
572        to_string_with_options::<_, FORMAT>(value.im, &FORMAT_IMAGINARY_OPTIONS) + "i"
573    } else {
574        let mut out = to_string_with_options::<_, FORMAT>(value.re, &FORMAT_REAL_OPTIONS);
575        if value.im > 0f64 {
576            out.push('+')
577        }
578        out.push_str(&to_string_with_options::<_, FORMAT>(
579            value.im,
580            &FORMAT_IMAGINARY_OPTIONS,
581        ));
582        out.push('i');
583        out
584    }
585}
586
587impl Quil for Expression {
588    fn write(
589        &self,
590        f: &mut impl std::fmt::Write,
591        fall_back_to_debug: bool,
592    ) -> Result<(), crate::quil::ToQuilError> {
593        use Expression::*;
594        match self {
595            Address(memory_reference) => memory_reference.write(f, fall_back_to_debug),
596            FunctionCall(FunctionCallExpression {
597                function,
598                expression,
599            }) => {
600                write!(f, "{function}(")?;
601                expression.write(f, fall_back_to_debug)?;
602                write!(f, ")")?;
603                Ok(())
604            }
605            Infix(InfixExpression {
606                left,
607                operator,
608                right,
609            }) => {
610                format_inner_expression(f, fall_back_to_debug, left)?;
611                write!(f, "{}", operator)?;
612                format_inner_expression(f, fall_back_to_debug, right)
613            }
614            Number(value) => write!(f, "{}", format_complex(value)).map_err(Into::into),
615            PiConstant => write!(f, "pi").map_err(Into::into),
616            Prefix(PrefixExpression {
617                operator,
618                expression,
619            }) => {
620                write!(f, "{}", operator)?;
621                format_inner_expression(f, fall_back_to_debug, expression)
622            }
623            Variable(identifier) => write!(f, "%{}", identifier).map_err(Into::into),
624        }
625    }
626}
627
628/// Utility function to wrap infix expressions that are part of an expression in parentheses, so
629/// that correct precedence rules are enforced.
630fn format_inner_expression(
631    f: &mut impl std::fmt::Write,
632    fall_back_to_debug: bool,
633    expression: &Expression,
634) -> crate::quil::ToQuilResult<()> {
635    match expression {
636        Expression::Infix(InfixExpression {
637            left,
638            operator,
639            right,
640        }) => {
641            write!(f, "(")?;
642            format_inner_expression(f, fall_back_to_debug, left)?;
643            write!(f, "{operator}")?;
644            format_inner_expression(f, fall_back_to_debug, right)?;
645            write!(f, ")")?;
646            Ok(())
647        }
648        _ => expression.write(f, fall_back_to_debug),
649    }
650}
651
#[cfg(test)]
mod test {
    use crate::{
        expression::{
            Expression, InfixExpression, InfixOperator, PrefixExpression, PrefixOperator,
        },
        quil::Quil,
        real,
    };

    use internment::ArcIntern;

    #[test]
    fn formats_nested_expression() {
        // (-3) * (pi / 2): the infix right operand must be parenthesized, the prefix left one not.
        let negative_three = Expression::Prefix(PrefixExpression::new(
            PrefixOperator::Minus,
            ArcIntern::new(Expression::Number(real!(3f64))),
        ));
        let half_pi = Expression::Infix(InfixExpression::new(
            ArcIntern::new(Expression::PiConstant),
            InfixOperator::Slash,
            ArcIntern::new(Expression::Number(real!(2f64))),
        ));
        let product = Expression::Infix(InfixExpression::new(
            ArcIntern::new(negative_three),
            InfixOperator::Star,
            ArcIntern::new(half_pi),
        ));

        assert_eq!(product.to_quil_or_debug(), "-3*(pi/2)");
    }
}
682
/// A function defined within Quil syntax.
///
/// The evaluation of each function on a complex argument is implemented by `calculate_function`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum ExpressionFunction {
    /// `cis(x)`, computed as `cos(x) + i*sin(x)`.
    Cis,
    /// `cos(x)`.
    Cosine,
    /// `exp(x)`.
    Exponent,
    /// `sin(x)`.
    Sine,
    /// `sqrt(x)`.
    SquareRoot,
}
693
694impl fmt::Display for ExpressionFunction {
695    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
696        use ExpressionFunction::*;
697        write!(
698            f,
699            "{}",
700            match self {
701                Cis => "cis",
702                Cosine => "cos",
703                Exponent => "exp",
704                Sine => "sin",
705                SquareRoot => "sqrt",
706            }
707        )
708    }
709}
710
/// A unary operator in a [`PrefixExpression`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum PrefixOperator {
    /// Unary plus. Evaluation returns the operand unchanged, and it is written as nothing.
    Plus,
    /// Negation, written `-`.
    Minus,
}
717
718impl fmt::Display for PrefixOperator {
719    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
720        use PrefixOperator::*;
721        write!(
722            f,
723            "{}",
724            match self {
725                // NOTE: prefix Plus does nothing but cause parsing issues
726                Plus => "",
727                Minus => "-",
728            }
729        )
730    }
731}
732
/// A binary operator in an [`InfixExpression`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum InfixOperator {
    /// Exponentiation, written `^`; evaluated with `Complex64::powc`.
    Caret,
    /// Addition, written `+`.
    Plus,
    /// Subtraction, written ` - ` (with surrounding spaces).
    Minus,
    /// Division, written `/`.
    Slash,
    /// Multiplication, written `*`.
    Star,
}
742
743impl fmt::Display for InfixOperator {
744    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
745        use InfixOperator::*;
746        write!(
747            f,
748            "{}",
749            match self {
750                Caret => "^",
751                Plus => "+",
752                // NOTE: spaces included to distinguish from hyphenated identifiers
753                Minus => " - ",
754                Slash => "/",
755                Star => "*",
756            }
757        )
758    }
759}
760
761/// Convenience constructors for creating [`ArcIntern<Expression>`]s out of other
762/// [`ArcIntern<Expression>`]s.
763pub mod interned {
764    use super::*;
765
766    macro_rules! atoms {
767        ($($func:ident: $ctor:ident$(($typ:ty))?),+ $(,)?) => {
768            $(
769                #[doc = concat!(
770                    "A wrapper around [`Expression::",
771                    stringify!($ctor),
772                    "`] that returns an [`ArcIntern<Expression>`]."
773                )]
774                #[inline(always)]
775                pub fn $func($(value: $typ)?) -> ArcIntern<Expression> {
776                    // Using `std::convert::identity` lets us refer to `$typ` and thus make the
777                    // constructor argument optional
778                    ArcIntern::new(Expression::$ctor$((std::convert::identity::<$typ>(value)))?)
779                }
780            )+
781        };
782    }
783
784    macro_rules! expression_wrappers {
785        ($($func:ident: $atom:ident($ctor:ident { $($field:ident: $field_ty:ty),*$(,)? })),+ $(,)?) => {
786            paste::paste! { $(
787                #[doc = concat!(
788                    "A wrapper around [`Expression::", stringify!([<$func:camel>]), "`] ",
789                    "that takes the contents of the inner expression type as arguments directly ",
790                    "and returns an [`ArcIntern<Expression>`].",
791                    "\n\n",
792                    "See also [`", stringify!($atom), "`].",
793                )]
794                #[inline(always)]
795                pub fn $func($($field: $field_ty),*) -> ArcIntern<Expression> {
796                    $atom($ctor { $($field),* })
797                }
798            )+ }
799        };
800    }
801
802    macro_rules! function_wrappers {
803        ($($func:ident: $ctor:ident),+ $(,)?) => {
804            $(
805                #[doc = concat!(
806                    "Create an <code>[ArcIntern]&lt;[Expression]&gt;</code> representing ",
807                    "`", stringify!($func), "(expression)`.",
808                    "\n\n",
809                    "A wrapper around [`Expression::FunctionCall`] with ",
810                    "[`ExpressionFunction::", stringify!($ctor), "`].",
811                )]
812                #[inline(always)]
813                pub fn $func(expression: ArcIntern<Expression>) -> ArcIntern<Expression> {
814                    function_call(ExpressionFunction::$ctor, expression)
815                }
816            )+
817        };
818    }
819
820    macro_rules! infix_wrappers {
821        ($($func:ident: $ctor:ident ($op:tt)),+ $(,)?) => {
822            $(
823                #[doc = concat!(
824                    "Create an <code>[ArcIntern]&lt;[Expression]&gt;</code> representing ",
825                    "`left ", stringify!($op), " right`.",
826                    "\n\n",
827                    "A wrapper around [`Expression::Infix`] with ",
828                    "[`InfixOperator::", stringify!($ctor), "`].",
829                )]
830                #[inline(always)]
831                pub fn $func(
832                    left: ArcIntern<Expression>,
833                    right: ArcIntern<Expression>,
834                ) -> ArcIntern<Expression> {
835                    infix(left, InfixOperator::$ctor, right)
836                }
837            )+
838        };
839    }
840
841    macro_rules! prefix_wrappers {
842        ($($func:ident: $ctor:ident ($op:tt)),+ $(,)?) => {
843            $(
844                #[doc = concat!(
845                    "Create an <code>[ArcIntern]&lt;[Expression]&gt;</code> representing ",
846                    "`", stringify!($op), "expression`.",
847                    "\n\n",
848                    "A wrapper around [`Expression::Prefix`] with ",
849                    "[`PrefixOperator::", stringify!($ctor), "`].",
850                )]
851                #[inline(always)]
852                pub fn $func(expression: ArcIntern<Expression>) -> ArcIntern<Expression> {
853                    prefix(PrefixOperator::$ctor, expression)
854                }
855            )+
856        };
857    }
858
    // Constructors for each `Expression` variant. Each entry has the form `name: Variant` or
    // `name: Variant(Payload)`. Presumably each one generates a `pub fn name(..)` returning
    // an `ArcIntern<Expression>` that wraps `Expression::Variant`. The `atoms!` definition is
    // outside this excerpt, so confirm there.
    // The `*_expr` constructors are the targets named by the `expression_wrappers!`
    // invocation that follows.
    //
    // NOTE(review): there is no entry for `Expression::Address`. Confirm whether an
    // `address` constructor was intentionally omitted.
    atoms! {
        function_call_expr: FunctionCall(FunctionCallExpression),
        infix_expr: Infix(InfixExpression),
        pi: PiConstant,
        number: Number(Complex64),
        prefix_expr: Prefix(PrefixExpression),
        variable: Variable(String),
    }
867
    // Field-wise constructors for the struct-payload variants. Each entry has the form
    // `name: atom_ctor(Struct { field: Type, .. })`. Presumably each one generates a
    // `pub fn name(field, ..)` that builds `Struct` and hands it to `atom_ctor`. The
    // `expression_wrappers!` definition is outside this excerpt, so confirm there.
    // The generated `function_call`, `infix`, and `prefix` are what the
    // `function_wrappers!`, `infix_wrappers!`, and `prefix_wrappers!` expansions call.
    // Their arguments are passed positionally in the field order listed here.
    expression_wrappers! {
        function_call: function_call_expr(FunctionCallExpression {
            function: ExpressionFunction,
            expression: ArcIntern<Expression>,
        }),

        infix: infix_expr(InfixExpression {
            left: ArcIntern<Expression>,
            operator: InfixOperator,
            right: ArcIntern<Expression>,
        }),

        prefix: prefix_expr(PrefixExpression {
            operator: PrefixOperator,
            expression: ArcIntern<Expression>,
        }),
    }
885
    // Unary function-call constructors. Each `name: Variant` entry expands to
    // `pub fn name(expression)`, which calls
    // `function_call(ExpressionFunction::Variant, expression)`.
    // For example, `sin(e)` represents `sin(e)`.
    function_wrappers! {
        cis: Cis,
        cos: Cosine,
        exp: Exponent,
        sin: Sine,
        sqrt: SquareRoot,
    }
893
894    infix_wrappers! {
895        add: Plus (+),
896        sub: Minus (-),
897        mul: Star (*),
898        div: Slash (/),
899        pow: Caret (^),
900    }
901
902    prefix_wrappers! {
903        unary_plus: Plus (+),
904        neg: Minus (-),
905    }
906}
907
908#[cfg(test)]
909// This lint should be re-enabled once this proptest issue is resolved
910// https://github.com/proptest-rs/proptest/issues/364
911#[allow(clippy::arc_with_non_send_sync)]
912mod tests {
913    use super::*;
914    use crate::reserved::ReservedToken;
915    use proptest::prelude::*;
916    use std::collections::hash_map::DefaultHasher;
917    use std::collections::HashSet;
918
919    /// Hash value helper: turn a hashable thing into a u64.
920    #[inline]
921    fn hash_to_u64<T: Hash>(t: &T) -> u64 {
922        let mut s = DefaultHasher::new();
923        t.hash(&mut s);
924        s.finish()
925    }
926
927    #[test]
928    fn simplify_and_evaluate() {
929        use Expression::*;
930
931        let one = real!(1.0);
932        let empty_variables = HashMap::new();
933
934        let mut variables = HashMap::new();
935        variables.insert("foo".to_owned(), real!(10f64));
936        variables.insert("bar".to_owned(), real!(100f64));
937
938        let empty_memory = HashMap::new();
939
940        let mut memory_references = HashMap::new();
941        memory_references.insert("theta", vec![1.0, 2.0]);
942        memory_references.insert("beta", vec![3.0, 4.0]);
943
944        struct TestCase<'a> {
945            expression: Expression,
946            variables: &'a HashMap<String, Complex64>,
947            memory_references: &'a HashMap<&'a str, Vec<f64>>,
948            simplified: Expression,
949            evaluated: Result<Complex64, EvaluationError>,
950        }
951
952        let cases: Vec<TestCase> = vec![
953            TestCase {
954                expression: Number(one),
955                variables: &empty_variables,
956                memory_references: &empty_memory,
957                simplified: Number(one),
958                evaluated: Ok(one),
959            },
960            TestCase {
961                expression: Expression::Prefix(PrefixExpression {
962                    operator: PrefixOperator::Minus,
963                    expression: ArcIntern::new(Number(real!(1f64))),
964                }),
965                variables: &empty_variables,
966                memory_references: &empty_memory,
967                simplified: Number(real!(-1f64)),
968                evaluated: Ok(real!(-1f64)),
969            },
970            TestCase {
971                expression: Expression::Variable("foo".to_owned()),
972                variables: &variables,
973                memory_references: &empty_memory,
974                simplified: Expression::Variable("foo".to_owned()),
975                evaluated: Ok(real!(10f64)),
976            },
977            TestCase {
978                expression: Expression::from_str("%foo + %bar").unwrap(),
979                variables: &variables,
980                memory_references: &empty_memory,
981                simplified: Expression::from_str("%foo + %bar").unwrap(),
982                evaluated: Ok(real!(110f64)),
983            },
984            TestCase {
985                expression: Expression::FunctionCall(FunctionCallExpression {
986                    function: ExpressionFunction::Sine,
987                    expression: ArcIntern::new(Expression::Number(real!(PI / 2f64))),
988                }),
989                variables: &variables,
990                memory_references: &empty_memory,
991                simplified: Number(real!(1f64)),
992                evaluated: Ok(real!(1f64)),
993            },
994            TestCase {
995                expression: Expression::from_str("theta[1] * beta[0]").unwrap(),
996                variables: &empty_variables,
997                memory_references: &memory_references,
998                simplified: Expression::from_str("theta[1] * beta[0]").unwrap(),
999                evaluated: Ok(real!(6.0)),
1000            },
1001        ];
1002
1003        for mut case in cases {
1004            let evaluated = case
1005                .expression
1006                .evaluate(case.variables, case.memory_references);
1007            assert_eq!(evaluated, case.evaluated);
1008
1009            case.expression.simplify();
1010            assert_eq!(case.expression, case.simplified);
1011        }
1012    }
1013
1014    /// Parenthesized version of [`Expression::to_string()`]
1015    fn parenthesized(expression: &Expression) -> String {
1016        use Expression::*;
1017        match expression {
1018            Address(memory_reference) => memory_reference.to_quil_or_debug(),
1019            FunctionCall(FunctionCallExpression {
1020                function,
1021                expression,
1022            }) => format!("({function}({}))", parenthesized(expression)),
1023            Infix(InfixExpression {
1024                left,
1025                operator,
1026                right,
1027            }) => format!(
1028                "({}{}{})",
1029                parenthesized(left),
1030                operator,
1031                parenthesized(right)
1032            ),
1033            Number(value) => format!("({})", format_complex(value)),
1034            PiConstant => "pi".to_string(),
1035            Prefix(PrefixExpression {
1036                operator,
1037                expression,
1038            }) => format!("({}{})", operator, parenthesized(expression)),
1039            Variable(identifier) => format!("(%{identifier})"),
1040        }
1041    }
1042
1043    // Better behaved than the auto-derived version for names
1044    fn arb_name() -> impl Strategy<Value = String> {
1045        r"[a-z][a-zA-Z0-9]{1,10}".prop_filter("Exclude reserved tokens", |t| {
1046            ReservedToken::from_str(t).is_err() && !t.to_lowercase().starts_with("nan")
1047        })
1048    }
1049
1050    // Better behaved than the auto-derived version re: names & indices
1051    fn arb_memory_reference() -> impl Strategy<Value = MemoryReference> {
1052        (arb_name(), (u64::MIN..u32::MAX as u64))
1053            .prop_map(|(name, index)| MemoryReference { name, index })
1054    }
1055
1056    // Better behaved than the auto-derived version via arbitrary floats
1057    fn arb_complex64() -> impl Strategy<Value = Complex64> {
1058        let tau = std::f64::consts::TAU;
1059        ((-tau..tau), (-tau..tau)).prop_map(|(re, im)| Complex64 { re, im })
1060    }
1061
1062    /// Filter an Expression to not be constantly zero.
1063    fn nonzero(strat: impl Strategy<Value = Expression>) -> impl Strategy<Value = Expression> {
1064        strat.prop_filter("Exclude constantly-zero expressions", |expr| {
1065            expr.clone().into_simplified() != Expression::Number(Complex64::new(0.0, 0.0))
1066        })
1067    }
1068
1069    /// Generate an arbitrary Expression for a property test.
1070    /// See https://docs.rs/proptest/1.0.0/proptest/prelude/trait.Strategy.html#method.prop_recursive
1071    fn arb_expr() -> impl Strategy<Value = Expression> {
1072        use Expression::*;
1073        let leaf = prop_oneof![
1074            arb_memory_reference().prop_map(Address),
1075            arb_complex64().prop_map(Number),
1076            Just(PiConstant),
1077            arb_name().prop_map(Variable),
1078        ];
1079        leaf.prop_recursive(
1080            4,  // No more than 4 branch levels deep
1081            64, // Target around 64 total nodes
1082            16, // Each "collection" is up to 16 elements
1083            |expr| {
1084                let inner = expr.clone();
1085                prop_oneof![
1086                    (any::<ExpressionFunction>(), expr.clone()).prop_map(|(function, e)| {
1087                        Expression::FunctionCall(FunctionCallExpression {
1088                            function,
1089                            expression: ArcIntern::new(e),
1090                        })
1091                    }),
1092                    (expr.clone(), any::<InfixOperator>())
1093                        .prop_flat_map(move |(left, operator)| (
1094                            Just(left),
1095                            Just(operator),
1096                            // Avoid division by 0 so that we can reliably assert equality
1097                            if let InfixOperator::Slash = operator {
1098                                nonzero(inner.clone()).boxed()
1099                            } else {
1100                                inner.clone().boxed()
1101                            }
1102                        ))
1103                        .prop_map(|(l, operator, r)| Infix(InfixExpression {
1104                            left: ArcIntern::new(l),
1105                            operator,
1106                            right: ArcIntern::new(r)
1107                        })),
1108                    expr.prop_map(|e| Prefix(PrefixExpression {
1109                        operator: PrefixOperator::Minus,
1110                        expression: ArcIntern::new(e)
1111                    }))
1112                ]
1113            },
1114        )
1115    }
1116
1117    proptest! {
1118
1119        #[test]
1120        fn eq(a in any::<f64>(), b in any::<f64>()) {
1121            let first = Expression::Infix (InfixExpression {
1122                left: ArcIntern::new(Expression::Number(real!(a))),
1123                operator: InfixOperator::Plus,
1124                right: ArcIntern::new(Expression::Number(real!(b))),
1125            } );
1126            let differing = Expression::Number(real!(a + b));
1127            prop_assert_eq!(&first, &first);
1128            prop_assert_ne!(&first, &differing);
1129        }
1130
1131        #[test]
1132        fn hash(a in any::<f64>(), b in any::<f64>()) {
1133            let first = Expression::Infix (InfixExpression {
1134                left: ArcIntern::new(Expression::Number(real!(a))),
1135                operator: InfixOperator::Plus,
1136                right: ArcIntern::new(Expression::Number(real!(b))),
1137            });
1138            let matching = first.clone();
1139            let differing = Expression::Number(real!(a + b));
1140            let mut set = HashSet::new();
1141            set.insert(first);
1142            assert!(set.contains(&matching));
1143            assert!(!set.contains(&differing))
1144        }
1145
1146        #[test]
1147        fn eq_iff_hash_eq(x in arb_expr(), y in arb_expr()) {
1148            prop_assert_eq!(x == y, hash_to_u64(&x) == hash_to_u64(&y));
1149        }
1150
1151        #[test]
1152        fn reals_are_real(x in any::<f64>()) {
1153            prop_assert_eq!(Expression::Number(real!(x)).to_real(), Ok(x))
1154        }
1155
1156        #[test]
1157        fn some_nums_are_real(re in any::<f64>(), im in any::<f64>()) {
1158            let result = Expression::Number(Complex64{re, im}).to_real();
1159            if is_small(im) {
1160                prop_assert_eq!(result, Ok(re))
1161            } else {
1162                prop_assert_eq!(result, Err(EvaluationError::NumberNotReal))
1163            }
1164        }
1165
1166        #[test]
1167        fn no_other_exps_are_real(expr in arb_expr().prop_filter("Not numbers", |e| !matches!(e, Expression::Number(_) | Expression::PiConstant))) {
1168            prop_assert_eq!(expr.to_real(), Err(EvaluationError::NotANumber))
1169        }
1170
1171        #[test]
1172        fn complexes_are_parseable_as_expressions(value in arb_complex64()) {
1173            let parsed = Expression::from_str(&format_complex(&value));
1174            assert!(parsed.is_ok());
1175            let simple = parsed.unwrap().into_simplified();
1176            assert_eq!(Expression::Number(value), simple);
1177        }
1178
1179        #[test]
1180        fn exponentiation_works_as_expected(left in arb_expr(), right in arb_expr()) {
1181            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Caret, right: ArcIntern::new(right.clone()) } );
1182            prop_assert_eq!(left ^ right, expected);
1183        }
1184
1185        #[test]
1186        fn in_place_exponentiation_works_as_expected(left in arb_expr(), right in arb_expr()) {
1187            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Caret, right: ArcIntern::new(right.clone()) } );
1188            let mut x = left;
1189            x ^= right;
1190            prop_assert_eq!(x, expected);
1191        }
1192
1193        #[test]
1194        fn addition_works_as_expected(left in arb_expr(), right in arb_expr()) {
1195            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Plus, right: ArcIntern::new(right.clone()) } );
1196            prop_assert_eq!(left + right, expected);
1197        }
1198
1199        #[test]
1200        fn in_place_addition_works_as_expected(left in arb_expr(), right in arb_expr()) {
1201            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Plus, right: ArcIntern::new(right.clone()) } );
1202            let mut x = left;
1203            x += right;
1204            prop_assert_eq!(x, expected);
1205        }
1206
1207        #[test]
1208        fn subtraction_works_as_expected(left in arb_expr(), right in arb_expr()) {
1209            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Minus, right: ArcIntern::new(right.clone()) } );
1210            prop_assert_eq!(left - right, expected);
1211        }
1212
1213        #[test]
1214        fn in_place_subtraction_works_as_expected(left in arb_expr(), right in arb_expr()) {
1215            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Minus, right: ArcIntern::new(right.clone()) } );
1216            let mut x = left;
1217            x -= right;
1218            prop_assert_eq!(x, expected);
1219        }
1220
1221        #[test]
1222        fn multiplication_works_as_expected(left in arb_expr(), right in arb_expr()) {
1223            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Star, right: ArcIntern::new(right.clone()) } );
1224            prop_assert_eq!(left * right, expected);
1225        }
1226
1227        #[test]
1228        fn in_place_multiplication_works_as_expected(left in arb_expr(), right in arb_expr()) {
1229            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Star, right: ArcIntern::new(right.clone()) } );
1230            let mut x = left;
1231            x *= right;
1232            prop_assert_eq!(x, expected);
1233        }
1234
1235
1236        // Avoid division by 0 so that we can reliably assert equality
1237        #[test]
1238        fn division_works_as_expected(left in arb_expr(), right in nonzero(arb_expr())) {
1239            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Slash, right: ArcIntern::new(right.clone()) } );
1240            prop_assert_eq!(left / right, expected);
1241        }
1242
1243        // Avoid division by 0 so that we can reliably assert equality
1244        #[test]
1245        fn in_place_division_works_as_expected(left in arb_expr(), right in nonzero(arb_expr())) {
1246            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Slash, right: ArcIntern::new(right.clone()) } );
1247            let mut x = left;
1248            x /= right;
1249            prop_assert_eq!(x, expected);
1250        }
1251
1252        // Redundant clone: clippy does not correctly introspect the prop_assert_eq! macro
1253        #[allow(clippy::redundant_clone)]
1254        #[test]
1255        fn round_trip(e in arb_expr()) {
1256            let simple_e = e.clone().into_simplified();
1257            let s = parenthesized(&e);
1258            let p = Expression::from_str(&s);
1259            prop_assert!(p.is_ok());
1260            let p = p.unwrap();
1261            let simple_p = p.clone().into_simplified();
1262
1263            prop_assert_eq!(
1264                &simple_p,
1265                &simple_e,
1266                "Simplified expressions should be equal:\nparenthesized {p} ({p:?}) extracted from {s} simplified to {simple_p}\nvs original {e} ({e:?}) simplified to {simple_e}",
1267                p=p.to_quil_or_debug(),
1268                s=s,
1269                e=e.to_quil_or_debug(),
1270                simple_p=simple_p.to_quil_or_debug(),
1271                simple_e=simple_e.to_quil_or_debug()
1272            );
1273        }
1274
1275    }
1276
1277    /// Assert that certain selected expressions are parsed and re-written to string
1278    /// in exactly the same way.
1279    #[test]
1280    fn specific_round_trip_tests() {
1281        for input in &[
1282            "-1*(phases[0]+phases[1])",
1283            "(-1*(phases[0]+phases[1]))+(-1*(phases[0]+phases[1]))",
1284        ] {
1285            let parsed = Expression::from_str(input);
1286            let parsed = parsed.unwrap();
1287            let restring = parsed.to_quil_or_debug();
1288            assert_eq!(input, &restring);
1289        }
1290    }
1291
1292    #[test]
1293    fn test_nan_is_equal() {
1294        let left = Expression::Number(f64::NAN.into());
1295        let right = left.clone();
1296        assert_eq!(left, right);
1297    }
1298
1299    #[test]
1300    fn specific_simplification_tests() {
1301        for (input, expected) in [
1302            ("pi", Expression::Number(PI.into())),
1303            ("pi/2", Expression::Number((PI / 2.0).into())),
1304            ("pi * pi", Expression::Number((PI.powi(2)).into())),
1305            ("1.0/(1.0-1.0)", Expression::Number(f64::NAN.into())),
1306            (
1307                "(a[0]*2*pi)/6.283185307179586",
1308                Expression::Address(MemoryReference {
1309                    name: String::from("a"),
1310                    index: 0,
1311                }),
1312            ),
1313        ] {
1314            assert_eq!(
1315                Expression::from_str(input).unwrap().into_simplified(),
1316                expected
1317            )
1318        }
1319    }
1320
1321    #[test]
1322    fn specific_to_real_tests() {
1323        for (input, expected) in [
1324            (Expression::PiConstant, Ok(PI)),
1325            (Expression::Number(Complex64 { re: 1.0, im: 0.0 }), Ok(1.0)),
1326            (
1327                Expression::Number(Complex64 { re: 1.0, im: 1.0 }),
1328                Err(EvaluationError::NumberNotReal),
1329            ),
1330            (
1331                Expression::Variable("Not a number".into()),
1332                Err(EvaluationError::NotANumber),
1333            ),
1334        ] {
1335            assert_eq!(input.to_real(), expected)
1336        }
1337    }
1338
1339    #[test]
1340    fn specific_format_complex_tests() {
1341        for (x, s) in &[
1342            (Complex64::new(0.0, 0.0), "0"),
1343            (Complex64::new(-0.0, 0.0), "0"),
1344            (Complex64::new(-0.0, -0.0), "0"),
1345            (Complex64::new(0.0, 1.0), "1.0i"),
1346            (Complex64::new(1.0, -1.0), "1-1.0i"),
1347            (Complex64::new(1.234, 0.0), "1.234"),
1348            (Complex64::new(0.0, 1.234), "1.234i"),
1349            (Complex64::new(-1.234, 0.0), "-1.234"),
1350            (Complex64::new(0.0, -1.234), "-1.234i"),
1351            (Complex64::new(1.234, 5.678), "1.234+5.678i"),
1352            (Complex64::new(-1.234, 5.678), "-1.234+5.678i"),
1353            (Complex64::new(1.234, -5.678), "1.234-5.678i"),
1354            (Complex64::new(-1.234, -5.678), "-1.234-5.678i"),
1355            (Complex64::new(1e100, 2e-100), "1e100+2.0e-100i"),
1356        ] {
1357            assert_eq!(format_complex(x), *s);
1358        }
1359    }
1360}