1use crate::ast::{
20 CallStyle, Extension, ExtensionFunction, ExtensionOutputValue, ExtensionValue, Literal, Name,
21 RepresentableExtensionValue, Type, Value, ValueKind,
22};
23use crate::entities::SchemaType;
24use crate::evaluator;
25use miette::Diagnostic;
26use std::str::FromStr;
27use std::sync::Arc;
28use thiserror::Error;
29
/// Number of digits kept after the decimal point.
const NUM_DIGITS: u32 = 4;

/// A Cedar `decimal` value.
///
/// Stored as a fixed-point integer: the mathematical value multiplied by
/// `10^NUM_DIGITS`. Because the only field is that scaled integer, the derived
/// equality and ordering compare decimal values numerically.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
struct Decimal {
    /// The decimal value scaled by `10^NUM_DIGITS` (e.g. `1.5` is stored as `15000`).
    value: i64,
}
39
40#[allow(clippy::expect_used, clippy::unwrap_used)]
42mod constants {
43 use super::EXTENSION_NAME;
44 use crate::ast::Name;
45 use regex::Regex;
46
47 lazy_static::lazy_static! {
49 pub static ref DECIMAL_FROM_STR_NAME : Name = Name::parse_unqualified_name(EXTENSION_NAME).expect("should be a valid identifier");
50 pub static ref LESS_THAN : Name = Name::parse_unqualified_name("lessThan").expect("should be a valid identifier");
51 pub static ref LESS_THAN_OR_EQUAL : Name = Name::parse_unqualified_name("lessThanOrEqual").expect("should be a valid identifier");
52 pub static ref GREATER_THAN : Name = Name::parse_unqualified_name("greaterThan").expect("should be a valid identifier");
53 pub static ref GREATER_THAN_OR_EQUAL : Name = Name::parse_unqualified_name("greaterThanOrEqual").expect("should be a valid identifier");
54 }
55
56 lazy_static::lazy_static! {
59 pub static ref DECIMAL_REGEX : Regex = Regex::new(r"^(-?\d+)\.(\d+)$").unwrap();
60 }
61}
62
/// Hint attached to the type error raised when a plain string is supplied
/// where a decimal value is expected.
const ADVICE_MSG: &str = "maybe you forgot to apply the `decimal` constructor?";
66
67#[derive(Debug, Diagnostic, Error)]
71enum Error {
72 #[error("`{0}` is not a well-formed decimal value")]
74 FailedParse(String),
75
76 #[error("too many digits after the decimal in `{0}`")]
78 #[diagnostic(help("at most {NUM_DIGITS} digits are supported"))]
79 TooManyDigits(String),
80
81 #[error("overflow when converting to decimal")]
83 Overflow,
84}
85
86fn checked_mul_pow(x: i64, y: u32) -> Result<i64, Error> {
88 if let Some(z) = i64::checked_pow(10, y) {
89 if let Some(w) = i64::checked_mul(x, z) {
90 return Ok(w);
91 }
92 };
93 Err(Error::Overflow)
94}
95
96impl Decimal {
97 fn typename() -> Name {
99 constants::DECIMAL_FROM_STR_NAME.clone()
100 }
101
102 fn from_str(str: impl AsRef<str>) -> Result<Self, Error> {
111 if !constants::DECIMAL_REGEX.is_match(str.as_ref()) {
113 return Err(Error::FailedParse(str.as_ref().to_owned()));
114 }
115
116 let caps = constants::DECIMAL_REGEX
120 .captures(str.as_ref())
121 .ok_or_else(|| Error::FailedParse(str.as_ref().to_owned()))?;
122 let l = caps
123 .get(1)
124 .ok_or_else(|| Error::FailedParse(str.as_ref().to_owned()))?
125 .as_str();
126 let r = caps
127 .get(2)
128 .ok_or_else(|| Error::FailedParse(str.as_ref().to_owned()))?
129 .as_str();
130
131 let l = i64::from_str(l).map_err(|_| Error::Overflow)?;
133 let l = checked_mul_pow(l, NUM_DIGITS)?;
134
135 let len: u32 = r.len().try_into().map_err(|_| Error::Overflow)?;
137 if NUM_DIGITS < len {
138 return Err(Error::TooManyDigits(str.as_ref().to_string()));
139 }
140 let r = i64::from_str(r).map_err(|_| Error::Overflow)?;
141 let r = checked_mul_pow(r, NUM_DIGITS - len)?;
142
143 if l >= 0 {
145 l.checked_add(r)
146 } else {
147 l.checked_sub(r)
148 }
149 .map(|value| Self { value })
150 .ok_or(Error::Overflow)
151 }
152}
153
154impl std::fmt::Display for Decimal {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 write!(
157 f,
158 "{}.{}",
159 self.value / i64::pow(10, NUM_DIGITS),
160 (self.value % i64::pow(10, NUM_DIGITS)).abs()
161 )
162 }
163}
164
165impl ExtensionValue for Decimal {
166 fn typename(&self) -> Name {
167 Self::typename()
168 }
169 fn supports_operator_overloading(&self) -> bool {
170 false
171 }
172}
173
/// Identifier of this extension; also the name of its constructor function.
const EXTENSION_NAME: &str = "decimal";
175
176fn extension_err(msg: impl Into<String>, advice: Option<String>) -> evaluator::EvaluationError {
177 evaluator::EvaluationError::failed_extension_function_application(
178 constants::DECIMAL_FROM_STR_NAME.clone(),
179 msg.into(),
180 None,
181 advice.map(Into::into), )
183}
184
185fn decimal_from_str(arg: &Value) -> evaluator::Result<ExtensionOutputValue> {
188 let str = arg.get_as_string()?;
189 let decimal =
190 Decimal::from_str(str.as_str()).map_err(|e| extension_err(e.to_string(), None))?;
191 let arg_source_loc = arg.source_loc().cloned();
192 let e = RepresentableExtensionValue::new(
193 Arc::new(decimal),
194 constants::DECIMAL_FROM_STR_NAME.clone(),
195 vec![arg.clone().into()],
196 );
197 Ok(Value {
198 value: ValueKind::ExtensionValue(Arc::new(e)),
199 loc: arg_source_loc, }
201 .into())
202}
203
204fn as_decimal(v: &Value) -> Result<&Decimal, evaluator::EvaluationError> {
206 match &v.value {
207 ValueKind::ExtensionValue(ev) if ev.typename() == Decimal::typename() => {
208 #[allow(clippy::expect_used)]
210 let d = ev
211 .value()
212 .as_any()
213 .downcast_ref::<Decimal>()
214 .expect("already typechecked, so this downcast should succeed");
215 Ok(d)
216 }
217 ValueKind::Lit(Literal::String(_)) => {
218 Err(evaluator::EvaluationError::type_error_with_advice_single(
219 Type::Extension {
220 name: Decimal::typename(),
221 },
222 v,
223 ADVICE_MSG.into(),
224 ))
225 }
226 _ => Err(evaluator::EvaluationError::type_error_single(
227 Type::Extension {
228 name: Decimal::typename(),
229 },
230 v,
231 )),
232 }
233}
234
235fn decimal_lt(left: &Value, right: &Value) -> evaluator::Result<ExtensionOutputValue> {
238 let left = as_decimal(left)?;
239 let right = as_decimal(right)?;
240 Ok(Value::from(left < right).into())
241}
242
243fn decimal_le(left: &Value, right: &Value) -> evaluator::Result<ExtensionOutputValue> {
246 let left = as_decimal(left)?;
247 let right = as_decimal(right)?;
248 Ok(Value::from(left <= right).into())
249}
250
251fn decimal_gt(left: &Value, right: &Value) -> evaluator::Result<ExtensionOutputValue> {
254 let left = as_decimal(left)?;
255 let right = as_decimal(right)?;
256 Ok(Value::from(left > right).into())
257}
258
259fn decimal_ge(left: &Value, right: &Value) -> evaluator::Result<ExtensionOutputValue> {
262 let left = as_decimal(left)?;
263 let right = as_decimal(right)?;
264 Ok(Value::from(left >= right).into())
265}
266
267pub fn extension() -> Extension {
269 let decimal_type = SchemaType::Extension {
270 name: Decimal::typename(),
271 };
272 Extension::new(
273 constants::DECIMAL_FROM_STR_NAME.clone(),
274 vec![
275 ExtensionFunction::unary(
276 constants::DECIMAL_FROM_STR_NAME.clone(),
277 CallStyle::FunctionStyle,
278 Box::new(decimal_from_str),
279 decimal_type.clone(),
280 SchemaType::String,
281 ),
282 ExtensionFunction::binary(
283 constants::LESS_THAN.clone(),
284 CallStyle::MethodStyle,
285 Box::new(decimal_lt),
286 SchemaType::Bool,
287 (decimal_type.clone(), decimal_type.clone()),
288 ),
289 ExtensionFunction::binary(
290 constants::LESS_THAN_OR_EQUAL.clone(),
291 CallStyle::MethodStyle,
292 Box::new(decimal_le),
293 SchemaType::Bool,
294 (decimal_type.clone(), decimal_type.clone()),
295 ),
296 ExtensionFunction::binary(
297 constants::GREATER_THAN.clone(),
298 CallStyle::MethodStyle,
299 Box::new(decimal_gt),
300 SchemaType::Bool,
301 (decimal_type.clone(), decimal_type.clone()),
302 ),
303 ExtensionFunction::binary(
304 constants::GREATER_THAN_OR_EQUAL.clone(),
305 CallStyle::MethodStyle,
306 Box::new(decimal_ge),
307 SchemaType::Bool,
308 (decimal_type.clone(), decimal_type),
309 ),
310 ],
311 std::iter::empty(),
312 )
313}
314
#[cfg(test)]
#[allow(clippy::panic)]
mod tests {
    use super::*;
    use crate::ast::{Expr, Type, Value};
    use crate::evaluator::test::{basic_entities, basic_request};
    use crate::evaluator::{evaluation_errors, EvaluationError, Evaluator};
    use crate::extensions::Extensions;
    use crate::parser::parse_expr;
    use cool_asserts::assert_matches;
    use nonempty::nonempty;

    /// Parse `src` as a Cedar expression, panicking if it does not parse.
    fn expr(src: &str) -> Expr {
        parse_expr(src).expect("parsing error")
    }

    /// Build a `Name` from an identifier known to be valid.
    fn ident(id: &str) -> Name {
        Name::parse_unqualified_name(id).expect("should be a valid identifier")
    }

    /// Assert that `res` failed inside a function of the `decimal` extension.
    #[track_caller]
    fn assert_decimal_err<T: std::fmt::Debug>(res: evaluator::Result<T>) {
        assert_matches!(res, Err(evaluator::EvaluationError::FailedExtensionFunctionExecution(
            evaluation_errors::ExtensionFunctionExecutionError { extension_name, msg, .. }
        )) => {
            println!("{msg}");
            assert_eq!(extension_name, ident("decimal"));
        });
    }

    /// Assert that `res` evaluated to a value of the `decimal` extension type.
    #[track_caller]
    fn assert_decimal_valid(res: evaluator::Result<Value>) {
        assert_matches!(res, Ok(Value { value: ValueKind::ExtensionValue(ev), .. }) => {
            assert_eq!(ev.typename(), Decimal::typename());
        });
    }

    #[test]
    fn constructors() {
        let ext = extension();
        let is_constructor = |id: &str| {
            ext.get_func(&ident(id))
                .expect("function should exist")
                .is_constructor()
        };
        assert!(is_constructor("decimal"));
        for id in ["lessThan", "lessThanOrEqual", "greaterThan", "greaterThanOrEqual"] {
            assert!(!is_constructor(id));
        }
    }

    #[test]
    fn decimal_creation() {
        let ext_array = [extension()];
        let exts = Extensions::specific_extensions(&ext_array).unwrap();
        let request = basic_request();
        let entities = basic_entities();
        let eval = Evaluator::new(request, &entities, &exts);
        let run = |src: &str| eval.interpret_inline_policy(&expr(src));

        let well_formed = [
            r#"decimal("1.0")"#,
            r#"decimal("-1.0")"#,
            r#"decimal("123.456")"#,
            r#"decimal("0.1234")"#,
            r#"decimal("-0.0123")"#,
            r#"decimal("55.1")"#,
            // i64::MIN once scaled by 10^NUM_DIGITS
            r#"decimal("-922337203685477.5808")"#,
            // redundant leading/trailing zeros are accepted
            r#"decimal("00.000")"#,
        ];
        for src in well_formed {
            assert_decimal_valid(run(src));
        }

        let rejected = [
            // malformed
            r#"decimal("1234")"#,
            r#"decimal("1.0.")"#,
            r#"decimal("1.")"#,
            r#"decimal(".1")"#,
            r#"decimal("1.a")"#,
            r#"decimal("-.")"#,
            // out of range
            r#"decimal("1000000000000000.0")"#,
            r#"decimal("922337203685477.5808")"#,
            r#"decimal("-922337203685477.5809")"#,
            r#"decimal("-922337203685478.0")"#,
            // too many fractional digits
            r#"decimal("0.12345")"#,
            r#"decimal("0.00000")"#,
        ];
        for src in rejected {
            assert_decimal_err(run(src));
        }

        // `decimal` is function-style only: method-call syntax does not parse
        parse_expr(r#" "1.0".decimal() "#).expect_err("should fail");
    }

    #[test]
    fn decimal_equality() {
        let ext_array = [extension()];
        let exts = Extensions::specific_extensions(&ext_array).unwrap();
        let request = basic_request();
        let entities = basic_entities();
        let eval = Evaluator::new(request, &entities, &exts);

        let a = expr(r#"decimal("123.0")"#);
        let b = expr(r#"decimal("123.0000")"#);
        let c = expr(r#"decimal("0123.0")"#);
        let d = expr(r#"decimal("123.456")"#);
        let e = expr(r#"decimal("1.23")"#);
        let f = expr(r#"decimal("0.0")"#);
        let g = expr(r#"decimal("-0.0")"#);

        let is_eq = |l: &Expr, r: &Expr| {
            eval.interpret_inline_policy(&Expr::is_eq(l.clone(), r.clone()))
        };

        // the same value, spelled differently
        for (l, r) in [(&a, &a), (&a, &b), (&b, &c), (&c, &a), (&f, &g)] {
            assert_eq!(is_eq(l, r), Ok(Value::from(true)));
        }
        // different values
        for (l, r) in [(&b, &d), (&a, &e), (&d, &e)] {
            assert_eq!(is_eq(l, r), Ok(Value::from(false)));
        }
        // a decimal is never equal to a string or a long
        assert_eq!(is_eq(&a, &Expr::val("123.0")), Ok(Value::from(false)));
        assert_eq!(is_eq(&a, &Expr::val(1)), Ok(Value::from(false)));
    }

    /// Call the extension function `op` on each operand pair and check the boolean result.
    fn decimal_ops_helper(op: &str, tests: Vec<((Expr, Expr), bool)>) {
        let ext_array = [extension()];
        let exts = Extensions::specific_extensions(&ext_array).unwrap();
        let request = basic_request();
        let entities = basic_entities();
        let eval = Evaluator::new(request, &entities, &exts);

        let op_name = ident(op);
        for ((l, r), expected) in tests {
            let call = Expr::call_extension_fn(op_name.clone(), vec![l, r]);
            assert_eq!(eval.interpret_inline_policy(&call), Ok(Value::from(expected)));
        }
    }

    #[test]
    fn decimal_ops() {
        let a = expr(r#"decimal("1.23")"#);
        let b = expr(r#"decimal("1.24")"#);
        let c = expr(r#"decimal("123.45")"#);
        let d = expr(r#"decimal("-1.23")"#);
        let e = expr(r#"decimal("-1.24")"#);

        // operand pairs; each row below lists the expected result per pair, in order
        let pairs = [(&a, &b), (&a, &a), (&c, &a), (&d, &a), (&d, &e)];
        let expectations = [
            ("lessThan", [true, false, false, true, false]),
            ("lessThanOrEqual", [true, true, false, true, false]),
            ("greaterThan", [false, false, true, false, true]),
            ("greaterThanOrEqual", [false, true, true, false, true]),
        ];
        for (op, results) in expectations {
            let tests = pairs
                .iter()
                .zip(results)
                .map(|(&(l, r), res)| ((l.clone(), r.clone()), res))
                .collect();
            decimal_ops_helper(op, tests);
        }

        let ext_array = [extension()];
        let exts = Extensions::specific_extensions(&ext_array).unwrap();
        let request = basic_request();
        let entities = basic_entities();
        let eval = Evaluator::new(request, &entities, &exts);

        // the builtin `<` only accepts longs, so decimals are rejected
        assert_matches!(
            eval.interpret_inline_policy(&expr(r#"decimal("1.23") < decimal("1.24")"#)),
            Err(EvaluationError::TypeError(evaluation_errors::TypeError { expected, actual, advice, .. })) => {
                assert_eq!(expected, nonempty![Type::Long]);
                assert_eq!(actual, Type::Extension { name: ident("decimal") });
                assert_eq!(advice, Some("Only types long support comparison".into()));
            }
        );
        // a plain string operand gets the "forgot the constructor" hint
        assert_matches!(
            eval.interpret_inline_policy(&expr(r#"decimal("-1.23").lessThan("1.23")"#)),
            Err(EvaluationError::TypeError(evaluation_errors::TypeError { expected, actual, advice, .. })) => {
                assert_eq!(expected, nonempty![Type::Extension { name: ident("decimal") }]);
                assert_eq!(actual, Type::String);
                assert_matches!(advice, Some(a) => assert_eq!(a, ADVICE_MSG));
            }
        );
        // the comparisons are method-style only: function-call syntax does not parse
        parse_expr(r#"lessThan(decimal("-1.23"), decimal("1.23"))"#).expect_err("should fail");
    }

    /// Parse `s` and check that displaying the result reproduces `s` exactly.
    fn check_round_trip(s: &str) {
        let parsed = Decimal::from_str(s).expect("should be a valid decimal");
        assert_eq!(s, parsed.to_string());
    }

    #[test]
    fn decimal_display() {
        for s in ["123.0", "1.2300", "123.4560", "-123.4560", "0.0"] {
            check_round_trip(s);
        }
    }
}