use crate::polynomial::DensePolynomial;
use std::{
iter::Sum,
ops::{Add, Div, Mul, Sub},
prelude::v1::*,
};
use zkp_primefield::FieldElement;
#[derive(Clone)]
#[cfg_attr(feature = "std", derive(Debug))]
pub enum RationalExpression {
X,
Constant(FieldElement),
Trace(usize, isize),
Polynomial(DensePolynomial, Box<RationalExpression>),
Add(Box<RationalExpression>, Box<RationalExpression>),
Neg(Box<RationalExpression>),
Mul(Box<RationalExpression>, Box<RationalExpression>),
Inv(Box<RationalExpression>),
Exp(Box<RationalExpression>, usize),
}
impl RationalExpression {
pub fn neg(&self) -> Self {
Self::Neg(Box::new(self.clone()))
}
pub fn inv(&self) -> Self {
Self::Inv(Box::new(self.clone()))
}
pub fn pow(&self, exponent: usize) -> Self {
Self::Exp(Box::new(self.clone()), exponent)
}
}
impl From<i32> for RationalExpression {
fn from(value: i32) -> Self {
Self::Constant(value.into())
}
}
impl From<&FieldElement> for RationalExpression {
fn from(value: &FieldElement) -> Self {
Self::Constant(value.clone())
}
}
impl Add for RationalExpression {
type Output = Self;
fn add(self, other: Self) -> Self {
Self::Add(Box::new(self), Box::new(other))
}
}
#[allow(clippy::suspicious_arithmetic_impl)]
impl Sub for RationalExpression {
type Output = Self;
fn sub(self, other: Self) -> Self {
self + other.neg()
}
}
impl Mul for RationalExpression {
type Output = Self;
fn mul(self, other: Self) -> Self {
Self::Mul(Box::new(self), Box::new(other))
}
}
#[allow(clippy::suspicious_arithmetic_impl)]
impl Div for RationalExpression {
type Output = Self;
fn div(self, other: Self) -> Self {
self * other.inv()
}
}
impl Sum<RationalExpression> for RationalExpression {
fn sum<I>(mut iter: I) -> Self
where
I: Iterator<Item = RationalExpression>,
{
if let Some(expr) = iter.next() {
iter.fold(expr, |a, b| a + b)
} else {
0.into()
}
}
}
impl RationalExpression {
pub fn degree(&self, trace_degree: usize) -> (usize, usize) {
self.degree_impl(1, trace_degree)
}
pub fn trace_degree(&self) -> (usize, usize) {
self.degree_impl(0, 1)
}
fn degree_impl(&self, x_degree: usize, trace_degree: usize) -> (usize, usize) {
use RationalExpression::*;
match self {
X => (x_degree, 0),
Constant(_) => (0, 0),
Trace(..) => (trace_degree, 0),
Polynomial(p, a) => {
let (n, d) = a.degree_impl(x_degree, trace_degree);
(p.degree() * n, p.degree() * d)
}
Add(a, b) => {
let (an, ad) = a.degree_impl(x_degree, trace_degree);
let (bn, bd) = b.degree_impl(x_degree, trace_degree);
assert!(ad == 0);
assert!(bd == 0);
(std::cmp::max(an, bn), 0)
}
Neg(a) => a.degree_impl(x_degree, trace_degree),
Mul(a, b) => {
let (an, ad) = a.degree_impl(x_degree, trace_degree);
let (bn, bd) = b.degree_impl(x_degree, trace_degree);
(an + bn, ad + bd)
}
Inv(a) => {
let (n, d) = a.degree_impl(x_degree, trace_degree);
(d, n)
}
Exp(a, e) => {
let (n, d) = a.degree_impl(x_degree, trace_degree);
(e * n, e * d)
}
}
}
pub fn evaluate(
&self,
x: &FieldElement,
trace: &dyn Fn(usize, isize) -> FieldElement,
) -> FieldElement {
use RationalExpression::*;
match self {
X => x.clone(),
Constant(c) => c.clone(),
&Trace(i, j) => trace(i, j),
Polynomial(p, a) => p.evaluate(&a.evaluate(x, trace)),
Add(a, b) => a.evaluate(x, trace) + b.evaluate(x, trace),
Neg(a) => -&a.evaluate(x, trace),
Mul(a, b) => a.evaluate(x, trace) * b.evaluate(x, trace),
Inv(a) => a.evaluate(x, trace).inv().expect("divided by zero"),
Exp(a, e) => a.evaluate(x, trace).pow(*e),
}
}
}