use regex::Regex;
use crate::ast::{
CallStyle, Extension, ExtensionFunction, ExtensionOutputValue, ExtensionValue,
ExtensionValueWithArgs, Name, StaticallyTyped, Type, Value,
};
use crate::entities::SchemaType;
use crate::evaluator;
use std::str::FromStr;
use std::sync::Arc;
use thiserror::Error;
/// Number of digits kept after the decimal point. Decimals are stored as
/// integers scaled by `10^NUM_DIGITS`.
const NUM_DIGITS: u32 = 4;
/// A fixed-point decimal number with `NUM_DIGITS` digits after the decimal
/// point.
///
/// Every value shares the same scale, so the derived `PartialEq`/`Ord`
/// comparing `value` directly gives numeric equality and ordering.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
struct Decimal {
    /// The number multiplied by `10^NUM_DIGITS` (e.g. `1.23` is stored as `12300`).
    value: i64,
}
/// Lazily initialised [`Name`]s used by this extension.
///
/// `expect` is allowed because every string parsed here is a hard-coded
/// identifier.
#[allow(clippy::expect_used)]
mod names {
    use super::{Name, EXTENSION_NAME};
    lazy_static::lazy_static! {
        // Name of the constructor function. It is parsed from `EXTENSION_NAME`,
        // so the constructor, the extension and the decimal type name coincide.
        pub static ref DECIMAL_FROM_STR_NAME : Name = Name::parse_unqualified_name(EXTENSION_NAME).expect("should be a valid identifier");
        // Names of the method-style comparison operators.
        pub static ref LESS_THAN : Name = Name::parse_unqualified_name("lessThan").expect("should be a valid identifier");
        pub static ref LESS_THAN_OR_EQUAL : Name = Name::parse_unqualified_name("lessThanOrEqual").expect("should be a valid identifier");
        pub static ref GREATER_THAN : Name = Name::parse_unqualified_name("greaterThan").expect("should be a valid identifier");
        pub static ref GREATER_THAN_OR_EQUAL : Name = Name::parse_unqualified_name("greaterThanOrEqual").expect("should be a valid identifier");
    }
}
/// Reasons a string can fail to convert into a [`Decimal`].
#[derive(Debug, Error)]
enum Error {
    /// The input does not have the shape `-?<digits>.<digits>`.
    #[error("input string is not a well-formed decimal value: {0}")]
    FailedParse(String),
    /// More than `NUM_DIGITS` digits follow the decimal point.
    #[error("too many digits after the decimal, we support at most {NUM_DIGITS}: {0}")]
    TooManyDigits(String),
    /// The value, or an intermediate scaling step, does not fit in an `i64`.
    #[error("overflow when converting to decimal")]
    Overflow,
}
fn checked_mul_pow(x: i64, y: u32) -> Result<i64, Error> {
if let Some(z) = i64::checked_pow(10, y) {
if let Some(w) = i64::checked_mul(x, z) {
return Ok(w);
}
};
Err(Error::Overflow)
}
impl Decimal {
fn typename() -> Name {
names::DECIMAL_FROM_STR_NAME.clone()
}
fn from_str(str: impl AsRef<str>) -> Result<Self, Error> {
#[allow(clippy::unwrap_used)]
let re = Regex::new(r"^(-?\d+)\.(\d+)$").unwrap();
if !re.is_match(str.as_ref()) {
return Err(Error::FailedParse(str.as_ref().to_owned()));
}
let caps = re
.captures(str.as_ref())
.ok_or_else(|| Error::FailedParse(str.as_ref().to_owned()))?;
let l = caps
.get(1)
.ok_or_else(|| Error::FailedParse(str.as_ref().to_owned()))?
.as_str();
let r = caps
.get(2)
.ok_or_else(|| Error::FailedParse(str.as_ref().to_owned()))?
.as_str();
let l = i64::from_str(l).map_err(|_| Error::Overflow)?;
let l = checked_mul_pow(l, NUM_DIGITS)?;
let len: u32 = r.len().try_into().map_err(|_| Error::Overflow)?;
if NUM_DIGITS < len {
return Err(Error::TooManyDigits(str.as_ref().to_string()));
}
let r = i64::from_str(r).map_err(|_| Error::Overflow)?;
let r = checked_mul_pow(r, NUM_DIGITS - len)?;
if l >= 0 {
l.checked_add(r)
} else {
l.checked_sub(r)
}
.map(|value| Self { value })
.ok_or(Error::Overflow)
}
}
impl std::fmt::Display for Decimal {
    /// Renders the decimal as `[-]<integer>.<fraction>`.
    ///
    /// A zero fraction is written as a single `0` (e.g. `123.0`). Any other
    /// fraction is written as exactly `NUM_DIGITS` zero-padded digits (e.g.
    /// `1.0023`, `-0.5000`), so the output always parses back to the same
    /// value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let scale = i64::pow(10, NUM_DIGITS);
        // The sign is emitted separately. For values in (-1, 0) the integer
        // part is 0, which cannot carry a sign of its own.
        let sign = if self.value < 0 { "-" } else { "" };
        // `unsigned_abs` cannot overflow, even for `i64::MIN`.
        let integer = (self.value / scale).unsigned_abs();
        let fraction = (self.value % scale).unsigned_abs();
        if fraction == 0 {
            write!(f, "{sign}{integer}.0")
        } else {
            // Zero-pad, otherwise e.g. 1.0023 would be rendered as "1.23".
            write!(
                f,
                "{sign}{integer}.{fraction:0width$}",
                width = NUM_DIGITS as usize
            )
        }
    }
}
impl ExtensionValue for Decimal {
fn typename(&self) -> Name {
Self::typename()
}
}
/// Name of this extension. It is also parsed into the name of the
/// constructor function (`names::DECIMAL_FROM_STR_NAME`).
const EXTENSION_NAME: &str = "decimal";
/// Builds an extension-failure evaluation error attributed to the `decimal`
/// extension, carrying `msg` as its message.
fn extension_err(msg: impl Into<String>) -> evaluator::EvaluationError {
    let extension_name = names::DECIMAL_FROM_STR_NAME.clone();
    let message = msg.into();
    evaluator::EvaluationError::FailedExtensionFunctionApplication {
        extension_name,
        msg: message,
    }
}
/// Implements the `decimal("...")` constructor.
///
/// Parses the string argument and wraps the result in an extension value that
/// also records the constructor name and its original argument. Non-string
/// arguments propagate the error from `get_as_string`. Malformed strings
/// become an extension error carrying the parse error's message.
fn decimal_from_str(arg: Value) -> evaluator::Result<ExtensionOutputValue> {
    let parsed = {
        let text = arg.get_as_string()?;
        Decimal::from_str(text.as_str()).map_err(|err| extension_err(err.to_string()))?
    };
    let with_args = ExtensionValueWithArgs::new(
        Arc::new(parsed),
        vec![arg.into()],
        names::DECIMAL_FROM_STR_NAME.clone(),
    );
    Ok(Value::ExtensionValue(Arc::new(with_args)).into())
}
/// Borrows the [`Decimal`] stored inside `v`.
///
/// # Errors
///
/// Returns a `TypeError` expecting the `decimal` extension type when `v` is
/// not an extension value whose type name is `decimal`.
fn as_decimal(v: &Value) -> Result<&Decimal, evaluator::EvaluationError> {
    if let Value::ExtensionValue(ev) = v {
        if ev.typename() == Decimal::typename() {
            // The type name matched above, so the concrete type must be `Decimal`.
            #[allow(clippy::expect_used)]
            let decimal = ev
                .value()
                .as_any()
                .downcast_ref::<Decimal>()
                .expect("already typechecked, so this downcast should succeed");
            return Ok(decimal);
        }
    }
    Err(evaluator::EvaluationError::TypeError {
        expected: vec![Type::Extension {
            name: Decimal::typename(),
        }],
        actual: v.type_of(),
    })
}
fn decimal_lt(left: Value, right: Value) -> evaluator::Result<ExtensionOutputValue> {
let left = as_decimal(&left)?;
let right = as_decimal(&right)?;
Ok(Value::Lit((left < right).into()).into())
}
fn decimal_le(left: Value, right: Value) -> evaluator::Result<ExtensionOutputValue> {
let left = as_decimal(&left)?;
let right = as_decimal(&right)?;
Ok(Value::Lit((left <= right).into()).into())
}
fn decimal_gt(left: Value, right: Value) -> evaluator::Result<ExtensionOutputValue> {
let left = as_decimal(&left)?;
let right = as_decimal(&right)?;
Ok(Value::Lit((left > right).into()).into())
}
fn decimal_ge(left: Value, right: Value) -> evaluator::Result<ExtensionOutputValue> {
let left = as_decimal(&left)?;
let right = as_decimal(&right)?;
Ok(Value::Lit((left >= right).into()).into())
}
pub fn extension() -> Extension {
let decimal_type = SchemaType::Extension {
name: Decimal::typename(),
};
Extension::new(
names::DECIMAL_FROM_STR_NAME.clone(),
vec![
ExtensionFunction::unary(
names::DECIMAL_FROM_STR_NAME.clone(),
CallStyle::FunctionStyle,
Box::new(decimal_from_str),
decimal_type.clone(),
Some(SchemaType::String),
),
ExtensionFunction::binary(
names::LESS_THAN.clone(),
CallStyle::MethodStyle,
Box::new(decimal_lt),
SchemaType::Bool,
(Some(decimal_type.clone()), Some(decimal_type.clone())),
),
ExtensionFunction::binary(
names::LESS_THAN_OR_EQUAL.clone(),
CallStyle::MethodStyle,
Box::new(decimal_le),
SchemaType::Bool,
(Some(decimal_type.clone()), Some(decimal_type.clone())),
),
ExtensionFunction::binary(
names::GREATER_THAN.clone(),
CallStyle::MethodStyle,
Box::new(decimal_gt),
SchemaType::Bool,
(Some(decimal_type.clone()), Some(decimal_type.clone())),
),
ExtensionFunction::binary(
names::GREATER_THAN_OR_EQUAL.clone(),
CallStyle::MethodStyle,
Box::new(decimal_ge),
SchemaType::Bool,
(Some(decimal_type.clone()), Some(decimal_type)),
),
],
)
}
/// Tests for the decimal extension: function registration, parsing,
/// equality, the comparison operators, and `Display` round-tripping.
#[cfg(test)]
#[allow(clippy::panic)]
mod tests {
    use super::*;
    use crate::ast::{Expr, Type, Value};
    use crate::evaluator::test::{basic_entities, basic_request};
    use crate::evaluator::Evaluator;
    use crate::extensions::Extensions;
    use crate::parser::parse_expr;

    /// Asserts that `res` is a `FailedExtensionFunctionApplication` error
    /// attributed to the `decimal` extension, printing its message.
    fn assert_decimal_err<T>(res: evaluator::Result<T>) {
        match res {
            Err(evaluator::EvaluationError::FailedExtensionFunctionApplication {
                extension_name,
                msg,
            }) => {
                println!("{msg}");
                assert_eq!(
                    extension_name,
                    Name::parse_unqualified_name("decimal").expect("should be a valid identifier")
                )
            }
            Err(e) => panic!("Expected an decimal ExtensionErr, got {:?}", e),
            Ok(_) => panic!("Expected an decimal ExtensionErr, got Ok"),
        }
    }

    /// Asserts that `res` is an extension value whose type name is `decimal`.
    fn assert_decimal_valid(res: evaluator::Result<Value>) {
        match res {
            Ok(Value::ExtensionValue(ev)) => {
                assert_eq!(ev.typename(), Decimal::typename())
            }
            Ok(v) => panic!("Expected decimal ExtensionValue, got {:?}", v),
            Err(e) => panic!("Expected Ok, got Err: {:?}", e),
        }
    }

    /// Only `decimal` is registered as a constructor. The four comparison
    /// functions exist but are not constructors.
    #[test]
    fn constructors() {
        let ext = extension();
        assert!(ext
            .get_func(
                &Name::parse_unqualified_name("decimal").expect("should be a valid identifier")
            )
            .expect("function should exist")
            .is_constructor());
        assert!(!ext
            .get_func(
                &Name::parse_unqualified_name("lessThan").expect("should be a valid identifier")
            )
            .expect("function should exist")
            .is_constructor());
        assert!(!ext
            .get_func(
                &Name::parse_unqualified_name("lessThanOrEqual")
                    .expect("should be a valid identifier")
            )
            .expect("function should exist")
            .is_constructor());
        assert!(!ext
            .get_func(
                &Name::parse_unqualified_name("greaterThan").expect("should be a valid identifier")
            )
            .expect("function should exist")
            .is_constructor());
        assert!(!ext
            .get_func(
                &Name::parse_unqualified_name("greaterThanOrEqual")
                    .expect("should be a valid identifier")
            )
            .expect("function should exist")
            .is_constructor(),);
    }

    /// Parsing accepts `-?<digits>.<digits>` with at most four fractional
    /// digits, and rejects malformed input and values outside the `i64` range
    /// once scaled.
    #[test]
    fn decimal_creation() {
        let ext_array = [extension()];
        let exts = Extensions::specific_extensions(&ext_array);
        let request = basic_request();
        let entities = basic_entities();
        let eval = Evaluator::new(&request, &entities, &exts).unwrap();
        // Well-formed inputs: signs, leading zeros, one to four fractional digits.
        assert_decimal_valid(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("1.0")"#).expect("parsing error")),
        );
        assert_decimal_valid(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("-1.0")"#).expect("parsing error")),
        );
        assert_decimal_valid(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("123.456")"#).expect("parsing error"),
            ),
        );
        assert_decimal_valid(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("0.1234")"#).expect("parsing error"),
            ),
        );
        assert_decimal_valid(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("-0.0123")"#).expect("parsing error"),
            ),
        );
        assert_decimal_valid(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("55.1")"#).expect("parsing error")),
        );
        // The smallest representable value: i64::MIN scaled by 10^4.
        assert_decimal_valid(eval.interpret_inline_policy(
            &parse_expr(r#"decimal("-922337203685477.5808")"#).expect("parsing error"),
        ));
        assert_decimal_valid(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("00.000")"#).expect("parsing error"),
            ),
        );
        // Malformed shapes: no point, trailing junk, empty parts, non-digits.
        assert_decimal_err(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("1234")"#).expect("parsing error")),
        );
        assert_decimal_err(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("1.0.")"#).expect("parsing error")),
        );
        assert_decimal_err(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("1.")"#).expect("parsing error")),
        );
        assert_decimal_err(
            eval.interpret_inline_policy(&parse_expr(r#"decimal(".1")"#).expect("parsing error")),
        );
        assert_decimal_err(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("1.a")"#).expect("parsing error")),
        );
        assert_decimal_err(
            eval.interpret_inline_policy(&parse_expr(r#"decimal("-.")"#).expect("parsing error")),
        );
        // Out of range once scaled by 10^4 (just past i64::MAX / i64::MIN).
        assert_decimal_err(eval.interpret_inline_policy(
            &parse_expr(r#"decimal("1000000000000000.0")"#).expect("parsing error"),
        ));
        assert_decimal_err(eval.interpret_inline_policy(
            &parse_expr(r#"decimal("922337203685477.5808")"#).expect("parsing error"),
        ));
        assert_decimal_err(eval.interpret_inline_policy(
            &parse_expr(r#"decimal("-922337203685477.5809")"#).expect("parsing error"),
        ));
        assert_decimal_err(eval.interpret_inline_policy(
            &parse_expr(r#"decimal("-922337203685478.0")"#).expect("parsing error"),
        ));
        // More than four fractional digits is rejected, even if they are zeros.
        assert_decimal_err(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("0.12345")"#).expect("parsing error"),
            ),
        );
        assert_decimal_err(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("0.00000")"#).expect("parsing error"),
            ),
        );
        // Method-call syntax for the constructor is rejected by the parser.
        parse_expr(r#" "1.0".decimal() "#).expect_err("should fail");
    }

    /// Equality is numeric, so leading and trailing zeros and the sign of zero
    /// do not matter. Decimals are never equal to strings or longs.
    #[test]
    fn decimal_equality() {
        let ext_array = [extension()];
        let exts = Extensions::specific_extensions(&ext_array);
        let request = basic_request();
        let entities = basic_entities();
        let eval = Evaluator::new(&request, &entities, &exts).unwrap();
        let a = parse_expr(r#"decimal("123.0")"#).expect("parsing error");
        let b = parse_expr(r#"decimal("123.0000")"#).expect("parsing error");
        let c = parse_expr(r#"decimal("0123.0")"#).expect("parsing error");
        let d = parse_expr(r#"decimal("123.456")"#).expect("parsing error");
        let e = parse_expr(r#"decimal("1.23")"#).expect("parsing error");
        let f = parse_expr(r#"decimal("0.0")"#).expect("parsing error");
        let g = parse_expr(r#"decimal("-0.0")"#).expect("parsing error");
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(a.clone(), a.clone())),
            Ok(Value::from(true))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(a.clone(), b.clone())),
            Ok(Value::from(true))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(b.clone(), c.clone())),
            Ok(Value::from(true))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(c, a.clone())),
            Ok(Value::from(true))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(b, d.clone())),
            Ok(Value::from(false))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(a.clone(), e.clone())),
            Ok(Value::from(false))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(d, e)),
            Ok(Value::from(false))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(f, g)),
            Ok(Value::from(true))
        );
        // Comparing against non-decimal values yields false rather than an error.
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(a.clone(), Expr::val("123.0"))),
            Ok(Value::from(false))
        );
        assert_eq!(
            eval.interpret_inline_policy(&Expr::is_eq(a, Expr::val(1))),
            Ok(Value::from(false))
        );
    }

    /// Evaluates the extension call `op(l, r)` for each pair and checks that
    /// it returns the expected boolean.
    fn decimal_ops_helper(op: &str, tests: Vec<((Expr, Expr), bool)>) {
        let ext_array = [extension()];
        let exts = Extensions::specific_extensions(&ext_array);
        let request = basic_request();
        let entities = basic_entities();
        let eval = Evaluator::new(&request, &entities, &exts).unwrap();
        for ((l, r), res) in tests {
            assert_eq!(
                eval.interpret_inline_policy(&Expr::call_extension_fn(
                    Name::parse_unqualified_name(op).expect("should be a valid identifier"),
                    vec![l, r]
                )),
                Ok(Value::from(res))
            );
        }
    }

    /// Exercises the four comparison operators on positive and negative
    /// decimals, plus the ways they must not be used.
    #[test]
    fn decimal_ops() {
        let a = parse_expr(r#"decimal("1.23")"#).expect("parsing error");
        let b = parse_expr(r#"decimal("1.24")"#).expect("parsing error");
        let c = parse_expr(r#"decimal("123.45")"#).expect("parsing error");
        let d = parse_expr(r#"decimal("-1.23")"#).expect("parsing error");
        let e = parse_expr(r#"decimal("-1.24")"#).expect("parsing error");
        let tests = vec![
            ((a.clone(), b.clone()), true),
            ((a.clone(), a.clone()), false),
            ((c.clone(), a.clone()), false),
            ((d.clone(), a.clone()), true),
            ((d.clone(), e.clone()), false),
        ];
        decimal_ops_helper("lessThan", tests);
        let tests = vec![
            ((a.clone(), b.clone()), true),
            ((a.clone(), a.clone()), true),
            ((c.clone(), a.clone()), false),
            ((d.clone(), a.clone()), true),
            ((d.clone(), e.clone()), false),
        ];
        decimal_ops_helper("lessThanOrEqual", tests);
        let tests = vec![
            ((a.clone(), b.clone()), false),
            ((a.clone(), a.clone()), false),
            ((c.clone(), a.clone()), true),
            ((d.clone(), a.clone()), false),
            ((d.clone(), e.clone()), true),
        ];
        decimal_ops_helper("greaterThan", tests);
        let tests = vec![
            ((a.clone(), b), false),
            ((a.clone(), a.clone()), true),
            ((c, a.clone()), true),
            ((d.clone(), a), false),
            ((d, e), true),
        ];
        decimal_ops_helper("greaterThanOrEqual", tests);
        let ext_array = [extension()];
        let exts = Extensions::specific_extensions(&ext_array);
        let request = basic_request();
        let entities = basic_entities();
        let eval = Evaluator::new(&request, &entities, &exts).unwrap();
        // The built-in `<` operator expects Longs and rejects decimals.
        assert_eq!(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("1.23") < decimal("1.24")"#).expect("parsing error")
            ),
            Err(evaluator::EvaluationError::TypeError {
                expected: vec![Type::Long],
                actual: Type::Extension {
                    name: Name::parse_unqualified_name("decimal")
                        .expect("should be a valid identifier")
                },
            })
        );
        // A string argument to a comparison method is a type error.
        assert_eq!(
            eval.interpret_inline_policy(
                &parse_expr(r#"decimal("-1.23").lessThan("1.23")"#).expect("parsing error")
            ),
            Err(evaluator::EvaluationError::TypeError {
                expected: vec![Type::Extension {
                    name: Name::parse_unqualified_name("decimal")
                        .expect("should be a valid identifier")
                }],
                actual: Type::String,
            })
        );
        // Function-call syntax for a method-style comparison is rejected by the parser.
        parse_expr(r#"lessThan(decimal("-1.23"), decimal("1.23"))"#).expect_err("should fail");
    }

    /// Asserts that parsing `s` and formatting it again yields `s` unchanged.
    fn check_round_trip(s: &str) {
        let d = Decimal::from_str(s).expect("should be a valid decimal");
        assert_eq!(s, d.to_string());
    }

    /// `Display` reproduces these inputs exactly.
    #[test]
    fn decimal_display() {
        check_round_trip("123.0");
        check_round_trip("1.2300");
        check_round_trip("123.4560");
        check_round_trip("-123.4560");
        check_round_trip("0.0");
    }
}