use crate::optimizer::ApplyOrder;
use crate::utils::merge_schema;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::{
DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
use datafusion_common::{
internal_err, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::{
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
};
use std::cmp::Ordering;
use std::sync::Arc;
/// Optimizer rule that removes a `CAST`/`TRY_CAST` wrapped around one side of
/// a comparison (or around the tested expression of an `IN` list) by casting
/// the literal operand(s) to the un-cast type instead.
///
/// The rule carries no state; all work happens in its `OptimizerRule`
/// implementation.
#[derive(Default)]
pub struct UnwrapCastInComparison {}

impl UnwrapCastInComparison {
    /// Creates the rule. Equivalent to `UnwrapCastInComparison::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
impl OptimizerRule for UnwrapCastInComparison {
    /// Rewrites every expression of `plan` with [`UnwrapCastExprRewriter`]
    /// and returns a copy of the plan built from the rewritten expressions
    /// and the original inputs.
    ///
    /// Expression types are resolved against the schemas of the plan's inputs
    /// merged with the plan's own schema. The result is always `Some(..)`;
    /// any error from rewriting or from rebuilding the plan is propagated.
    fn try_optimize(
        &self,
        plan: &LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Option<LogicalPlan>> {
        // Input schemas first, then the plan's own schema on top.
        let mut merged = merge_schema(plan.inputs());
        merged.merge(plan.schema());

        let mut rewriter = UnwrapCastExprRewriter {
            schema: Arc::new(merged),
        };

        // Stop at the first expression that fails to rewrite.
        let original_exprs = plan.expressions();
        let mut rewritten = Vec::with_capacity(original_exprs.len());
        for original in original_exprs {
            rewritten.push(rewrite_preserving_name(original, &mut rewriter)?);
        }

        let children: Vec<LogicalPlan> = plan.inputs().into_iter().cloned().collect();
        plan.with_new_exprs(rewritten, children.as_slice()).map(Some)
    }

    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "unwrap_cast_in_comparison"
    }

    /// Requests bottom-up application of the rule.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::BottomUp)
    }
}
struct UnwrapCastExprRewriter {
schema: DFSchemaRef,
}
impl TreeNodeRewriter for UnwrapCastExprRewriter {
type N = Expr;
fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
Ok(RewriteRecursion::Continue)
}
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
match &expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let left = left.as_ref().clone();
let right = right.as_ref().clone();
let left_type = left.get_type(&self.schema)?;
let right_type = right.get_type(&self.schema)?;
if is_support_data_type(&left_type)
&& is_support_data_type(&right_type)
&& is_comparison_op(op)
{
match (&left, &right) {
(
Expr::Literal(left_lit_value),
Expr::TryCast(TryCast { expr, .. })
| Expr::Cast(Cast { expr, .. }),
) => {
let expr_type = expr.get_type(&self.schema)?;
let casted_scalar_value =
try_cast_literal_to_type(left_lit_value, &expr_type)?;
if let Some(value) = casted_scalar_value {
return Ok(binary_expr(
lit(value),
*op,
expr.as_ref().clone(),
));
}
}
(
Expr::TryCast(TryCast { expr, .. })
| Expr::Cast(Cast { expr, .. }),
Expr::Literal(right_lit_value),
) => {
let expr_type = expr.get_type(&self.schema)?;
let casted_scalar_value =
try_cast_literal_to_type(right_lit_value, &expr_type)?;
if let Some(value) = casted_scalar_value {
return Ok(binary_expr(
expr.as_ref().clone(),
*op,
lit(value),
));
}
}
(_, _) => {
}
};
}
Ok(binary_expr(left, *op, right))
}
Expr::InList(InList {
expr: left_expr,
list,
negated,
}) => {
if let Some(
Expr::TryCast(TryCast {
expr: internal_left_expr,
..
})
| Expr::Cast(Cast {
expr: internal_left_expr,
..
}),
) = Some(left_expr.as_ref())
{
let internal_left = internal_left_expr.as_ref().clone();
let internal_left_type = internal_left.get_type(&self.schema);
if internal_left_type.is_err() {
return Ok(expr);
}
let internal_left_type = internal_left_type?;
if !is_support_data_type(&internal_left_type) {
return Ok(expr);
}
let right_exprs = list
.iter()
.map(|right| {
let right_type = right.get_type(&self.schema)?;
if !is_support_data_type(&right_type) {
return internal_err!(
"The type of list expr {} not support",
&right_type
);
}
match right {
Expr::Literal(right_lit_value) => {
let casted_scalar_value =
try_cast_literal_to_type(right_lit_value, &internal_left_type)?;
if let Some(value) = casted_scalar_value {
Ok(lit(value))
} else {
internal_err!(
"Can't cast the list expr {:?} to type {:?}",
right_lit_value, &internal_left_type
)
}
}
other_expr => internal_err!(
"Only support literal expr to optimize, but the expr is {:?}",
&other_expr
),
}
})
.collect::<Result<Vec<_>>>();
match right_exprs {
Ok(right_exprs) => {
Ok(in_list(internal_left, right_exprs, *negated))
}
Err(_) => Ok(expr),
}
} else {
Ok(expr)
}
}
_ => Ok(expr),
}
}
}
/// Returns `true` when `op` is one of the six comparison operators
/// (`=`, `!=`, `<`, `<=`, `>`, `>=`) the rewriter handles.
fn is_comparison_op(op: &Operator) -> bool {
    [
        Operator::Eq,
        Operator::NotEq,
        Operator::Lt,
        Operator::LtEq,
        Operator::Gt,
        Operator::GtEq,
    ]
    .contains(op)
}
/// Returns `true` for the data types this rule can reason about: the signed
/// and unsigned 8/16/32/64-bit integers, `Decimal128` and `Timestamp`.
fn is_support_data_type(data_type: &DataType) -> bool {
    let is_integer = matches!(
        data_type,
        DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
    );
    let is_decimal_or_timestamp = matches!(
        data_type,
        DataType::Decimal128(_, _) | DataType::Timestamp(_, _)
    );
    is_integer || is_decimal_or_timestamp
}
/// Tries to convert `lit_value` into a scalar of `target_type` without
/// changing the value it denotes.
///
/// Returns:
/// * `Ok(Some(scalar))` when the literal can be expressed in `target_type`;
/// * `Ok(None)` when either type is unsupported (see
///   [`is_support_data_type`]), when the value does not fit into the target
///   range, when a decimal literal would lose fractional digits, or when a
///   decimal precision/scale is outside what this function can handle
///   (negative scale, `10^scale` not representable in `i128`, precision with
///   no entry in the per-precision bounds tables);
/// * `Err(..)` for a supported literal type whose variant is not handled
///   here, or when building the typed NULL fails.
///
/// A NULL literal is converted to the NULL of `target_type`.
///
/// Timestamp targets go through [`cast_between_timestamp`], so a conversion
/// to a coarser unit follows its truncating division and an overflowing
/// conversion to a finer unit produces a scalar holding `None`.
fn try_cast_literal_to_type(
    lit_value: &ScalarValue,
    target_type: &DataType,
) -> Result<Option<ScalarValue>> {
    let lit_data_type = lit_value.data_type();
    if !is_support_data_type(&lit_data_type) || !is_support_data_type(target_type) {
        return Ok(None);
    }
    if lit_value.is_null() {
        return Ok(Some(ScalarValue::try_from(target_type)?));
    }

    // `mul` scales an integer value into the target's raw representation
    // (10^scale for decimals, 1 otherwise); `target_min..=target_max` is the
    // range of raw values the target type can hold.
    let (mul, target_min, target_max) = match target_type {
        DataType::UInt8 => (1_i128, i128::from(u8::MIN), i128::from(u8::MAX)),
        DataType::UInt16 => (1, i128::from(u16::MIN), i128::from(u16::MAX)),
        DataType::UInt32 => (1, i128::from(u32::MIN), i128::from(u32::MAX)),
        DataType::UInt64 => (1, i128::from(u64::MIN), i128::from(u64::MAX)),
        DataType::Int8 => (1, i128::from(i8::MIN), i128::from(i8::MAX)),
        DataType::Int16 => (1, i128::from(i16::MIN), i128::from(i16::MAX)),
        DataType::Int32 => (1, i128::from(i32::MIN), i128::from(i32::MAX)),
        DataType::Int64 | DataType::Timestamp(_, _) => {
            (1, i128::from(i64::MIN), i128::from(i64::MAX))
        }
        DataType::Decimal128(precision, scale) => {
            match (
                decimal_scale_multiplier(*scale),
                decimal_value_bounds(*precision),
            ) {
                (Some(mul), Some((min, max))) => (mul, min, max),
                // Unusable precision/scale: leave the expression alone
                // rather than panic or compute with a bogus multiplier.
                _ => return Ok(None),
            }
        }
        other_type => {
            return internal_err!("Error target data type {other_type:?}");
        }
    };

    // The literal expressed as a raw value of the target type, or `None`
    // when that is not possible without overflow or loss of digits.
    let lit_value_target_type = match lit_value {
        ScalarValue::Int8(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::Int16(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::Int32(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::UInt8(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::UInt16(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::UInt32(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::UInt64(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::Int64(Some(v))
        | ScalarValue::TimestampSecond(Some(v), _)
        | ScalarValue::TimestampMillisecond(Some(v), _)
        | ScalarValue::TimestampMicrosecond(Some(v), _)
        | ScalarValue::TimestampNanosecond(Some(v), _) => {
            i128::from(*v).checked_mul(mul)
        }
        ScalarValue::Decimal128(Some(v), _, scale) => {
            let v = *v;
            let lit_scale_mul = match decimal_scale_multiplier(*scale) {
                Some(lit_scale_mul) => lit_scale_mul,
                None => return Ok(None),
            };
            if mul >= lit_scale_mul {
                // Target has at least as many fractional digits: scale up.
                v.checked_mul(mul / lit_scale_mul)
            } else if v % (lit_scale_mul / mul) == 0 {
                // Target has fewer fractional digits: only exact when the
                // dropped digits are all zero.
                Some(v / (lit_scale_mul / mul))
            } else {
                None
            }
        }
        other_value => {
            return internal_err!("Invalid literal value {other_value:?}");
        }
    };

    let value = match lit_value_target_type {
        Some(value) if (target_min..=target_max).contains(&value) => value,
        _ => return Ok(None),
    };

    // `value` was range-checked above, so the narrowing `as` casts below
    // cannot truncate.
    let result_scalar = match target_type {
        DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
        DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
        DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
        DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
        DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
        DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
        DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
        DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
        DataType::Timestamp(unit, tz) => {
            let converted =
                cast_between_timestamp(lit_data_type, target_type.clone(), value);
            match unit {
                TimeUnit::Second => ScalarValue::TimestampSecond(converted, tz.clone()),
                TimeUnit::Millisecond => {
                    ScalarValue::TimestampMillisecond(converted, tz.clone())
                }
                TimeUnit::Microsecond => {
                    ScalarValue::TimestampMicrosecond(converted, tz.clone())
                }
                TimeUnit::Nanosecond => {
                    ScalarValue::TimestampNanosecond(converted, tz.clone())
                }
            }
        }
        DataType::Decimal128(p, s) => ScalarValue::Decimal128(Some(value), *p, *s),
        other_type => {
            return internal_err!("Error target data type {other_type:?}");
        }
    };
    Ok(Some(result_scalar))
}

/// Returns `10^scale`, or `None` when `scale` is negative or the power does
/// not fit in an `i128`.
fn decimal_scale_multiplier(scale: i8) -> Option<i128> {
    if scale < 0 {
        None
    } else {
        10_i128.checked_pow(scale as u32)
    }
}

/// Returns the `(min, max)` raw values for a decimal of the given precision,
/// or `None` when the bounds tables have no entry for it (precision 0 or
/// beyond the end of the tables).
fn decimal_value_bounds(precision: u8) -> Option<(i128, i128)> {
    let index = usize::from(precision).checked_sub(1)?;
    let min = MIN_DECIMAL_FOR_EACH_PRECISION.get(index).copied()?;
    let max = MAX_DECIMAL_FOR_EACH_PRECISION.get(index).copied()?;
    Some((min, max))
}
/// Rescales a timestamp `value` from the unit of `from` to the unit of `to`.
///
/// * If either type is not a `Timestamp`, the value is returned as is.
/// * Converting to a finer unit multiplies; `None` is returned when the
///   multiplication overflows `i64`.
/// * Converting to a coarser unit uses integer division, so sub-unit
///   precision is dropped.
///
/// `value` is narrowed to `i64` with `as` before any of the above.
fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option<i64> {
    /// Number of ticks per second for a timestamp type, `None` for any
    /// other data type.
    fn ticks_per_second(data_type: &DataType) -> Option<i64> {
        match data_type {
            DataType::Timestamp(TimeUnit::Second, _) => Some(1),
            DataType::Timestamp(TimeUnit::Millisecond, _) => Some(MILLISECONDS),
            DataType::Timestamp(TimeUnit::Microsecond, _) => Some(MICROSECONDS),
            DataType::Timestamp(TimeUnit::Nanosecond, _) => Some(NANOSECONDS),
            _ => None,
        }
    }

    let value = value as i64;
    match (ticks_per_second(&from), ticks_per_second(&to)) {
        (Some(source), Some(target)) => match source.cmp(&target) {
            Ordering::Less => value.checked_mul(target / source),
            Ordering::Greater => Some(value / (source / target)),
            Ordering::Equal => Some(value),
        },
        _ => Some(value),
    }
}
// Unit tests for the cast-unwrapping expression rewriter and for the literal
// conversion helper `try_cast_literal_to_type`.
#[cfg(test)]
mod tests {
use super::*;
use crate::unwrap_cast_in_comparison::UnwrapCastExprRewriter;
use arrow::compute::{cast_with_options, CastOptions};
use arrow::datatypes::{DataType, Field};
use datafusion_common::tree_node::TreeNode;
use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue};
use datafusion_expr::{cast, col, in_list, lit, try_cast, Expr};
use std::collections::HashMap;
use std::sync::Arc;
// Expressions that must come back unchanged: a cast compared against a
// column, a comparison without any cast, and a cast of the Int32 column `c1`
// compared against a literal that is outside the Int32 range.
#[test]
fn test_not_unwrap_cast_comparison() {
let schema = expr_test_schema();
let c1_gt_c2 = cast(col("c1"), DataType::Int64).gt(col("c2"));
assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2);
let expr_lt = col("c1").lt(lit(16i32));
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64));
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
}
// CAST and TRY_CAST are removed and the literal is retyped to the type of the
// un-cast operand, including NULL literals and a cast applied to a literal.
#[test]
fn test_unwrap_cast_comparison() {
let schema = expr_test_schema();
let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64));
let expected = col("c1").lt(lit(16i32));
assert_eq!(optimize_test(expr_lt, &schema), expected);
let expr_lt = try_cast(col("c1"), DataType::Int64).lt(lit(16i64));
let expected = col("c1").lt(lit(16i32));
assert_eq!(optimize_test(expr_lt, &schema), expected);
let c2_eq_lit = cast(col("c2"), DataType::Int32).eq(lit(16i32));
let expected = col("c2").eq(lit(16i64));
assert_eq!(optimize_test(c2_eq_lit, &schema), expected);
let c1_lt_lit_null = cast(col("c1"), DataType::Int64).lt(null_i64());
let expected = col("c1").lt(null_i32());
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32));
let expected = null_i8().lt(lit(12i8));
assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
}
// Same rewrite for an unsigned column (`c6` is UInt32).
#[test]
fn test_unwrap_cast_comparison_unsigned() {
let schema = expr_test_schema();
let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64));
let expected = col("c6").eq(lit(0u32));
assert_eq!(optimize_test(expr_input, &schema), expected);
}
// Decimal cases that must stay unchanged because the literal cannot be
// expressed in the type of the un-cast column.
#[test]
fn test_not_unwrap_cast_with_decimal_comparison() {
let schema = expr_test_schema();
let expr_eq = cast(col("c3"), DataType::Int64).eq(lit(100000000000000000i64));
assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
let expr_eq = cast(col("c4"), DataType::Int64).eq(lit(1000i64));
assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
let expr_eq =
cast(col("c3"), DataType::Decimal128(20, 4)).eq(lit_decimal(12340, 20, 4));
assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
let expr_eq =
cast(col("c1"), DataType::Decimal128(10, 1)).eq(lit_decimal(123, 10, 1));
assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
let expr_eq =
cast(col("c1"), DataType::Decimal128(10, 2)).eq(lit_decimal(1230, 10, 2));
assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
}
// Decimal cases where the literal is rescaled to the type of the un-cast
// column (`c3` is Decimal128(18, 2), `c1` is Int32).
#[test]
fn test_unwrap_cast_with_decimal_lit_comparison() {
let schema = expr_test_schema();
let expr_lt = try_cast(col("c3"), DataType::Int64).lt(lit(16i64));
let expected = col("c3").lt(lit_decimal(1600, 18, 2));
assert_eq!(optimize_test(expr_lt, &schema), expected);
let c1_lt_lit_null = cast(col("c3"), DataType::Int64).lt(null_i64());
let expected = col("c3").lt(null_decimal(18, 2));
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
let expr_lt =
cast(col("c3"), DataType::Decimal128(10, 0)).lt(lit_decimal(123, 10, 0));
let expected = col("c3").lt(lit_decimal(12300, 18, 2));
assert_eq!(optimize_test(expr_lt, &schema), expected);
let expr_lt =
cast(col("c3"), DataType::Decimal128(10, 3)).lt(lit_decimal(1230, 10, 3));
let expected = col("c3").lt(lit_decimal(123, 18, 2));
assert_eq!(optimize_test(expr_lt, &schema), expected);
let expr_lt =
cast(col("c1"), DataType::Decimal128(10, 2)).lt(lit_decimal(12300, 10, 2));
let expected = col("c1").lt(lit(123i32));
assert_eq!(optimize_test(expr_lt, &schema), expected);
}
// IN-list expressions that must stay unchanged: a Float32 column, Float32
// list items, a list item outside the Int32 range, and decimal list items
// that cannot be rescaled exactly to Decimal128(18, 2).
#[test]
fn test_not_unwrap_list_cast_lit_comparison() {
let schema = expr_test_schema();
let expr_lt =
cast(col("c5"), DataType::Int64).in_list(vec![lit(12i64), lit(12i64)], false);
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
let expr_lt = cast(col("c1"), DataType::Float32)
.in_list(vec![lit(12.0f32), lit(12.0f32), lit(1.23f32)], false);
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
let expr_lt = cast(col("c1"), DataType::Int64)
.in_list(vec![lit(12i32), lit(99999999999i64)], false);
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
let expr_lt = cast(col("c3"), DataType::Decimal128(12, 3)).in_list(
vec![
lit_decimal(12, 12, 3),
lit_decimal(12, 12, 3),
lit_decimal(128, 12, 3),
],
false,
);
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
}
// IN-list expressions where the cast is removed and every list item is
// converted to the type of the un-cast expression.
#[test]
fn test_unwrap_list_cast_comparison() {
let schema = expr_test_schema();
let expr_lt =
cast(col("c1"), DataType::Int64).in_list(vec![lit(12i64), lit(24i64)], false);
let expected = col("c1").in_list(vec![lit(12i32), lit(24i32)], false);
assert_eq!(optimize_test(expr_lt, &schema), expected);
let expr_lt =
cast(col("c2"), DataType::Int32).in_list(vec![null_i32(), lit(14i32)], false);
let expected = col("c2").in_list(vec![null_i64(), lit(14i64)], false);
assert_eq!(optimize_test(expr_lt, &schema), expected);
let expr_lt = cast(col("c3"), DataType::Decimal128(19, 3)).in_list(
vec![
lit_decimal(12000, 19, 3),
lit_decimal(24000, 19, 3),
lit_decimal(1280, 19, 3),
lit_decimal(1240, 19, 3),
],
false,
);
let expected = col("c3").in_list(
vec![
lit_decimal(1200, 18, 2),
lit_decimal(2400, 18, 2),
lit_decimal(128, 18, 2),
lit_decimal(124, 18, 2),
],
false,
);
assert_eq!(optimize_test(expr_lt, &schema), expected);
let expr_lt = cast(lit(12i32), DataType::Int64)
.in_list(vec![lit(13i64), lit(12i64)], false);
let expected = lit(12i32).in_list(vec![lit(13i32), lit(12i32)], false);
assert_eq!(optimize_test(expr_lt, &schema), expected);
}
// The rewrite also applies underneath an alias, and the alias is kept.
#[test]
fn aliased() {
let schema = expr_test_schema();
let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).alias("x");
let expected = col("c1").lt(lit(16i32)).alias("x");
assert_eq!(optimize_test(expr_lt, &schema), expected);
}
// Both comparisons nested under an OR are rewritten.
#[test]
fn nested() {
let schema = expr_test_schema();
let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).or(cast(
col("c1"),
DataType::Int64,
)
.gt(lit(32i64)));
let expected = col("c1").lt(lit(16i32)).or(col("c1").gt(lit(32i32)));
assert_eq!(optimize_test(expr_lt, &schema), expected);
}
// Float64 is not a supported type, so neither the comparison nor the IN list
// is touched.
#[test]
fn test_not_support_data_type() {
let schema = expr_test_schema();
let expr_input = cast(col("c6"), DataType::Float64).eq(lit(0f64));
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
let expr_input =
in_list(cast(col("c6"), DataType::Float64), vec![lit(0f64)], false);
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
}
// A cast that only adds a timezone to a nanosecond timestamp column is
// removed; the expected literal carries no timezone.
#[test]
fn test_unwrap_cast_with_timestamp_nanos() {
let schema = expr_test_schema();
let expr_lt = try_cast(col("ts_nano_none"), timestamp_nano_utc_type())
.lt(lit_timestamp_nano_utc(1666612093000000000));
let expected =
col("ts_nano_none").lt(lit_timestamp_nano_none(1666612093000000000));
assert_eq!(optimize_test(expr_lt, &schema), expected);
}
// Runs the rewriter over `expr` against `schema` and returns the result,
// panicking if the rewrite fails.
fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
let mut expr_rewriter = UnwrapCastExprRewriter {
schema: schema.clone(),
};
expr.rewrite(&mut expr_rewriter).unwrap()
}
// Schema shared by the rewriter tests: one unqualified, non-nullable column
// per data type exercised above.
fn expr_test_schema() -> DFSchemaRef {
Arc::new(
DFSchema::new_with_metadata(
vec![
DFField::new_unqualified("c1", DataType::Int32, false),
DFField::new_unqualified("c2", DataType::Int64, false),
DFField::new_unqualified("c3", DataType::Decimal128(18, 2), false),
DFField::new_unqualified("c4", DataType::Decimal128(38, 37), false),
DFField::new_unqualified("c5", DataType::Float32, false),
DFField::new_unqualified("c6", DataType::UInt32, false),
DFField::new_unqualified(
"ts_nano_none",
timestamp_nano_none_type(),
false,
),
DFField::new_unqualified(
"ts_nano_utf",
timestamp_nano_utc_type(),
false,
),
],
HashMap::new(),
)
.unwrap(),
)
}
// NULL literal of type Int8.
fn null_i8() -> Expr {
lit(ScalarValue::Int8(None))
}
// NULL literal of type Int32.
fn null_i32() -> Expr {
lit(ScalarValue::Int32(None))
}
// NULL literal of type Int64.
fn null_i64() -> Expr {
lit(ScalarValue::Int64(None))
}
// Decimal128 literal with the given raw value, precision and scale.
fn lit_decimal(value: i128, precision: u8, scale: i8) -> Expr {
lit(ScalarValue::Decimal128(Some(value), precision, scale))
}
// Nanosecond timestamp literal without a timezone.
fn lit_timestamp_nano_none(ts: i64) -> Expr {
lit(ScalarValue::TimestampNanosecond(Some(ts), None))
}
// Nanosecond timestamp literal with the "+0:00" timezone.
fn lit_timestamp_nano_utc(ts: i64) -> Expr {
let utc = Some("+0:00".into());
lit(ScalarValue::TimestampNanosecond(Some(ts), utc))
}
// NULL Decimal128 literal with the given precision and scale.
fn null_decimal(precision: u8, scale: i8) -> Expr {
lit(ScalarValue::Decimal128(None, precision, scale))
}
// Nanosecond timestamp type without a timezone.
fn timestamp_nano_none_type() -> DataType {
DataType::Timestamp(TimeUnit::Nanosecond, None)
}
// Nanosecond timestamp type with the "+0:00" timezone.
fn timestamp_nano_utc_type() -> DataType {
let utc = Some("+0:00".into());
DataType::Timestamp(TimeUnit::Nanosecond, utc)
}
// A NULL of any listed type converts to the NULL of any other listed type.
#[test]
fn test_try_cast_to_type_nulls() {
let scalars = vec![
ScalarValue::Int8(None),
ScalarValue::Int16(None),
ScalarValue::Int32(None),
ScalarValue::Int64(None),
ScalarValue::UInt8(None),
ScalarValue::UInt16(None),
ScalarValue::UInt32(None),
ScalarValue::UInt64(None),
ScalarValue::Decimal128(None, 3, 0),
ScalarValue::Decimal128(None, 8, 2),
];
for s1 in &scalars {
for s2 in &scalars {
let expected_value = ExpectedCast::Value(s2.clone());
expect_cast(s1.clone(), s2.data_type(), expected_value);
}
}
}
// The value 123 converts between every pair of listed types, followed by a
// few boundary values that still fit into the target type.
#[test]
fn test_try_cast_to_type_int_in_range() {
let scalars = vec![
ScalarValue::Int8(Some(123)),
ScalarValue::Int16(Some(123)),
ScalarValue::Int32(Some(123)),
ScalarValue::Int64(Some(123)),
ScalarValue::UInt8(Some(123)),
ScalarValue::UInt16(Some(123)),
ScalarValue::UInt32(Some(123)),
ScalarValue::UInt64(Some(123)),
ScalarValue::Decimal128(Some(123), 3, 0),
ScalarValue::Decimal128(Some(12300), 8, 2),
];
for s1 in &scalars {
for s2 in &scalars {
let expected_value = ExpectedCast::Value(s2.clone());
expect_cast(s1.clone(), s2.data_type(), expected_value);
}
}
let max_i32 = ScalarValue::Int32(Some(i32::MAX));
expect_cast(
max_i32,
DataType::UInt64,
ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))),
);
let min_i32 = ScalarValue::Int32(Some(i32::MIN));
expect_cast(
min_i32,
DataType::Int64,
ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))),
);
let max_i64 = ScalarValue::Int64(Some(i64::MAX));
expect_cast(
max_i64,
DataType::UInt64,
ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))),
);
}
// Values that do not fit into the target type yield no converted value.
#[test]
fn test_try_cast_to_type_int_out_of_range() {
let min_i32 = ScalarValue::Int32(Some(i32::MIN));
let min_i64 = ScalarValue::Int64(Some(i64::MIN));
let max_i64 = ScalarValue::Int64(Some(i64::MAX));
let max_u64 = ScalarValue::UInt64(Some(u64::MAX));
expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue);
expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue);
expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue);
expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue);
expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue);
expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue);
expect_cast(
ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
DataType::Int64,
ExpectedCast::NoValue,
);
expect_cast(
ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1),
DataType::Int64,
ExpectedCast::NoValue,
);
}
// Decimal-to-decimal conversions that are exact and fit the target precision.
#[test]
fn test_try_decimal_cast_in_range() {
expect_cast(
ScalarValue::Decimal128(Some(12300), 5, 2),
DataType::Decimal128(3, 0),
ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)),
);
expect_cast(
ScalarValue::Decimal128(Some(12300), 5, 2),
DataType::Decimal128(8, 0),
ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)),
);
expect_cast(
ScalarValue::Decimal128(Some(12300), 5, 2),
DataType::Decimal128(8, 5),
ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)),
);
}
// Decimal-to-decimal conversions that would drop non-zero fractional digits
// or exceed the target precision yield no converted value.
#[test]
fn test_try_decimal_cast_out_of_range() {
expect_cast(
ScalarValue::Decimal128(Some(12345), 5, 2),
DataType::Decimal128(3, 0),
ExpectedCast::NoValue,
);
expect_cast(
ScalarValue::Decimal128(Some(12300), 5, 2),
DataType::Decimal128(2, 0),
ExpectedCast::NoValue,
);
}
// For every time unit: conversions between the timezone-less and "+00:00"
// variants, to and from Int64, and to an unsupported type (LargeUtf8).
#[test]
fn test_try_cast_to_type_timestamps() {
for time_unit in [
TimeUnit::Second,
TimeUnit::Millisecond,
TimeUnit::Microsecond,
TimeUnit::Nanosecond,
] {
let utc = Some("+00:00".into());
let (lit_tz_none, lit_tz_utc) = match time_unit {
TimeUnit::Second => (
ScalarValue::TimestampSecond(Some(12345), None),
ScalarValue::TimestampSecond(Some(12345), utc),
),
TimeUnit::Millisecond => (
ScalarValue::TimestampMillisecond(Some(12345), None),
ScalarValue::TimestampMillisecond(Some(12345), utc),
),
TimeUnit::Microsecond => (
ScalarValue::TimestampMicrosecond(Some(12345), None),
ScalarValue::TimestampMicrosecond(Some(12345), utc),
),
TimeUnit::Nanosecond => (
ScalarValue::TimestampNanosecond(Some(12345), None),
ScalarValue::TimestampNanosecond(Some(12345), utc),
),
};
// The test expects the two scalars to compare equal although their
// timezones differ; `expect_cast` therefore checks unit and timezone
// separately.
assert_eq!(lit_tz_none, lit_tz_utc);
let dt_tz_none = lit_tz_none.data_type();
let dt_tz_utc = lit_tz_utc.data_type();
expect_cast(
lit_tz_none.clone(),
dt_tz_none.clone(),
ExpectedCast::Value(lit_tz_none.clone()),
);
expect_cast(
lit_tz_none.clone(),
dt_tz_utc.clone(),
ExpectedCast::Value(lit_tz_utc.clone()),
);
expect_cast(
lit_tz_utc.clone(),
dt_tz_none.clone(),
ExpectedCast::Value(lit_tz_none.clone()),
);
expect_cast(
lit_tz_utc.clone(),
dt_tz_utc.clone(),
ExpectedCast::Value(lit_tz_utc.clone()),
);
expect_cast(
lit_tz_utc.clone(),
DataType::Int64,
ExpectedCast::Value(ScalarValue::Int64(Some(12345))),
);
expect_cast(
ScalarValue::Int64(Some(12345)),
dt_tz_none.clone(),
ExpectedCast::Value(lit_tz_none.clone()),
);
expect_cast(
ScalarValue::Int64(Some(12345)),
dt_tz_utc.clone(),
ExpectedCast::Value(lit_tz_utc.clone()),
);
expect_cast(
lit_tz_utc.clone(),
DataType::LargeUtf8,
ExpectedCast::NoValue,
);
}
}
// An unsupported target type (List) yields no converted value.
#[test]
fn test_try_cast_to_type_unsupported() {
expect_cast(
ScalarValue::Int64(Some(12345)),
DataType::List(Arc::new(Field::new("f", DataType::Int32, true))),
ExpectedCast::NoValue,
);
}
// Expected outcome of `try_cast_literal_to_type` in `expect_cast`.
#[derive(Debug, Clone)]
enum ExpectedCast {
// The conversion succeeds and produces this scalar.
Value(ScalarValue),
// The conversion succeeds but produces no value.
NoValue,
}
// Converts `literal` to `target_type` with `try_cast_literal_to_type` and
// checks the outcome against `expected_result`. When a value is expected, the
// same conversion is also performed with the arrow cast kernel and must give
// the same one-element array.
fn expect_cast(
literal: ScalarValue,
target_type: DataType,
expected_result: ExpectedCast,
) {
let actual_result = try_cast_literal_to_type(&literal, &target_type);
println!("expect_cast: ");
println!(" {literal:?} --> {target_type:?}");
println!(" expected_result: {expected_result:?}");
println!(" actual_result: {actual_result:?}");
match expected_result {
ExpectedCast::Value(expected_value) => {
let actual_value = actual_result
.expect("Expected success but got error")
.expect("Expected cast value but got None");
assert_eq!(actual_value, expected_value);
// Cross-check against the arrow cast kernel.
let literal_array = literal.to_array_of_size(1);
let expected_array = expected_value.to_array_of_size(1);
let cast_array = cast_with_options(
&literal_array,
&target_type,
&CastOptions::default(),
)
.expect("Expected to be cast array with arrow cast kernel");
assert_eq!(
&expected_array, &cast_array,
"Result of casing {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}"
);
// For timestamps, compare unit and timezone explicitly in addition to
// the scalar equality asserted above.
if let (
DataType::Timestamp(left_unit, left_tz),
DataType::Timestamp(right_unit, right_tz),
) = (actual_value.data_type(), expected_value.data_type())
{
assert_eq!(left_unit, right_unit);
assert_eq!(left_tz, right_tz);
}
}
ExpectedCast::NoValue => {
let actual_value = actual_result.expect("Expected success but got error");
assert!(
actual_value.is_none(),
"Expected no cast value, but got {actual_value:?}"
);
}
}
}
// Timestamp-to-timestamp conversions across units: same unit keeps the value,
// a coarser target divides (dropping the remainder), a finer target
// multiplies, and an overflowing multiplication produces a scalar holding
// `None`.
#[test]
fn test_try_cast_literal_to_timestamp() {
// nanoseconds -> nanoseconds
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
.unwrap()
.unwrap();
assert_eq!(
new_scalar,
ScalarValue::TimestampNanosecond(Some(123456), None)
);
// nanoseconds -> microseconds
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Microsecond, None),
)
.unwrap()
.unwrap();
assert_eq!(
new_scalar,
ScalarValue::TimestampMicrosecond(Some(123), None)
);
// nanoseconds -> milliseconds
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
.unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
// nanoseconds -> seconds
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Second, None),
)
.unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None));
// microseconds -> nanoseconds
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMicrosecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
.unwrap()
.unwrap();
assert_eq!(
new_scalar,
ScalarValue::TimestampNanosecond(Some(123000), None)
);
// microseconds -> milliseconds
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMicrosecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
.unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
// microseconds -> seconds
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMicrosecond(Some(123456789), None),
&DataType::Timestamp(TimeUnit::Second, None),
)
.unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));
// milliseconds -> nanoseconds
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMillisecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
.unwrap()
.unwrap();
assert_eq!(
new_scalar,
ScalarValue::TimestampNanosecond(Some(123000000), None)
);
// milliseconds -> microseconds
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMillisecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Microsecond, None),
)
.unwrap()
.unwrap();
assert_eq!(
new_scalar,
ScalarValue::TimestampMicrosecond(Some(123000), None)
);
// milliseconds -> seconds
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampMillisecond(Some(123456789), None),
&DataType::Timestamp(TimeUnit::Second, None),
)
.unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None));
// seconds -> nanoseconds
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampSecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
.unwrap()
.unwrap();
assert_eq!(
new_scalar,
ScalarValue::TimestampNanosecond(Some(123000000000), None)
);
// seconds -> microseconds
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampSecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Microsecond, None),
)
.unwrap()
.unwrap();
assert_eq!(
new_scalar,
ScalarValue::TimestampMicrosecond(Some(123000000), None)
);
// seconds -> milliseconds
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampSecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
.unwrap()
.unwrap();
assert_eq!(
new_scalar,
ScalarValue::TimestampMillisecond(Some(123000), None)
);
// seconds -> milliseconds with a value whose multiplication overflows i64
let new_scalar = try_cast_literal_to_type(
&ScalarValue::TimestampSecond(Some(i64::MAX), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
.unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None));
}
}