use std::sync::Arc;
use super::{Interval, IntervalBound};
use crate::{
expressions::{BinaryExpr, CastExpr, Column, Literal, NegativeExpr},
PhysicalExpr,
};
use arrow_schema::{DataType, SchemaRef};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::Operator;
/// Mask selecting the 32-bit day field (bits 64..96) of a packed
/// `IntervalMonthDayNano` `i128`.
const MDN_DAY_MASK: i128 = (u32::MAX as i128) << 64;
/// Mask selecting the 64-bit nanosecond field (bits 0..64) of a packed
/// `IntervalMonthDayNano` `i128`.
const MDN_NS_MASK: i128 = u64::MAX as i128;
/// Mask selecting the 32-bit millisecond field (bits 0..32) of a packed
/// `IntervalDayTime` `i64`.
const DT_MS_MASK: i64 = u32::MAX as i64;
/// Returns `true` when every node of `expr` can be handled by interval
/// arithmetic.
///
/// The rules per expression kind are:
/// - Binary expressions need a supported operator and two supported children.
/// - Columns and literals need a supported data type, resolved against `schema`.
/// - Casts and negations are supported when their inner expression is.
///
/// Any other expression kind is reported as unsupported. So is a column or
/// literal whose data type cannot be resolved.
pub fn check_support(expr: &Arc<dyn PhysicalExpr>, schema: &SchemaRef) -> bool {
    let any = expr.as_any();

    if let Some(binary) = any.downcast_ref::<BinaryExpr>() {
        return is_operator_supported(binary.op())
            && check_support(binary.left(), schema)
            && check_support(binary.right(), schema);
    }

    if let Some(column) = any.downcast_ref::<Column>() {
        // A column missing from the schema cannot be reasoned about.
        return match schema.field_with_name(column.name()) {
            Ok(field) => is_datatype_supported(field.data_type()),
            Err(_) => false,
        };
    }

    if let Some(literal) = any.downcast_ref::<Literal>() {
        return match literal.data_type(schema) {
            Ok(data_type) => is_datatype_supported(&data_type),
            Err(_) => false,
        };
    }

    if let Some(cast) = any.downcast_ref::<CastExpr>() {
        return check_support(cast.expr(), schema);
    }

    if let Some(negative) = any.downcast_ref::<NegativeExpr>() {
        return check_support(negative.arg(), schema);
    }

    false
}
/// Returns the inverse of an additive operator: `-` for `+` and `+` for `-`.
///
/// # Panics
///
/// Panics (via `unreachable!`) for any operator other than `Plus` or `Minus`.
/// Callers are presumably expected to pass only those two.
pub fn get_inverse_op(op: Operator) -> Operator {
    match op {
        Operator::Minus => Operator::Plus,
        Operator::Plus => Operator::Minus,
        _ => unreachable!(),
    }
}
/// Returns `true` if `op` is one of the binary operators interval arithmetic
/// can evaluate.
///
/// These are the arithmetic operators `+` and `-`, logical `AND`, and the
/// comparisons `>`, `>=`, `<`, `<=` and `=`.
pub fn is_operator_supported(op: &Operator) -> bool {
    let arithmetic = matches!(op, Operator::Plus | Operator::Minus);
    let logical = matches!(op, Operator::And);
    let comparison = matches!(
        op,
        Operator::Gt | Operator::GtEq | Operator::Lt | Operator::LtEq | Operator::Eq
    );
    arithmetic || logical || comparison
}
/// Returns `true` for the primitive numeric types interval arithmetic handles.
///
/// These are signed and unsigned integers of 8, 16, 32 and 64 bits, plus
/// 32- and 64-bit floats. Every other type returns `false`.
pub fn is_datatype_supported(data_type: &DataType) -> bool {
    let signed_int = matches!(
        data_type,
        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
    );
    let unsigned_int = matches!(
        data_type,
        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64
    );
    let float = matches!(data_type, DataType::Float32 | DataType::Float64);
    signed_int || unsigned_int || float
}
/// Converts an interval whose bounds are interval-typed scalars into the
/// equivalent duration-typed interval.
///
/// Returns `None` if either bound cannot be converted by
/// `convert_interval_bound_to_duration`.
pub fn convert_interval_type_to_duration(interval: &Interval) -> Option<Interval> {
    let lower = convert_interval_bound_to_duration(&interval.lower)?;
    let upper = convert_interval_bound_to_duration(&interval.upper)?;
    Some(Interval::new(lower, upper))
}
/// Converts one interval-typed bound into a duration-typed bound and keeps
/// its open/closed flag.
///
/// The mappings are:
/// - `IntervalMonthDayNano` becomes `DurationNanosecond`.
/// - `IntervalDayTime` becomes `DurationMillisecond`.
///
/// Returns `None` in either of these cases:
/// - The value is of any other kind, or is null.
/// - The underlying `interval_*_to_duration_*` helper reports an error.
fn convert_interval_bound_to_duration(
    interval_bound: &IntervalBound,
) -> Option<IntervalBound> {
    let open = interval_bound.open;
    let duration = match interval_bound.value {
        ScalarValue::IntervalMonthDayNano(Some(mdn)) => {
            let nanos = interval_mdn_to_duration_ns(&mdn).ok()?;
            ScalarValue::DurationNanosecond(Some(nanos))
        }
        ScalarValue::IntervalDayTime(Some(dt)) => {
            let millis = interval_dt_to_duration_ms(&dt).ok()?;
            ScalarValue::DurationMillisecond(Some(millis))
        }
        _ => return None,
    };
    Some(IntervalBound::new(duration, open))
}
/// Converts an interval whose bounds are duration-typed scalars into the
/// equivalent interval-typed interval.
///
/// Returns `None` if either bound cannot be converted by
/// `convert_duration_bound_to_interval`.
pub fn convert_duration_type_to_interval(interval: &Interval) -> Option<Interval> {
    let lower = convert_duration_bound_to_interval(&interval.lower)?;
    let upper = convert_duration_bound_to_interval(&interval.upper)?;
    Some(Interval::new(lower, upper))
}
/// Converts one duration-typed bound into an interval-typed bound and keeps
/// its open/closed flag.
///
/// The mappings are:
/// - `DurationNanosecond` becomes an `IntervalMonthDayNano` with zero months
///   and days.
/// - `DurationMillisecond` becomes an `IntervalDayTime` with zero days.
///
/// Returns `None` in either of these cases:
/// - The value is of any other kind, or is null.
/// - A millisecond duration does not fit in the `i32` millisecond argument
///   of `ScalarValue::new_interval_dt`.
fn convert_duration_bound_to_interval(
    interval_bound: &IntervalBound,
) -> Option<IntervalBound> {
    let interval = match interval_bound.value {
        ScalarValue::DurationNanosecond(Some(duration)) => {
            ScalarValue::new_interval_mdn(0, 0, duration)
        }
        ScalarValue::DurationMillisecond(Some(duration)) => {
            // A plain `as i32` silently wraps durations above i32::MAX ms
            // (about 24.8 days) into wrong, possibly negative, bounds.
            // Refuse the conversion instead of producing a bogus interval.
            let millis = i32::try_from(duration).ok()?;
            ScalarValue::new_interval_dt(0, millis)
        }
        _ => return None,
    };
    Some(IntervalBound::new(interval, interval_bound.open))
}
/// Extracts the nanosecond component of a packed `IntervalMonthDayNano` value
/// as a duration. The month and day components must be zero.
///
/// The packed layout read here is:
/// - months in bits 96..128,
/// - days in bits 64..96,
/// - nanoseconds in bits 0..64, as a two's-complement `i64`.
///
/// This matches arrow's `IntervalMonthDayNanoType::make_value` encoding.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` if the month or day component is
/// non-zero.
fn interval_mdn_to_duration_ns(mdn: &i128) -> Result<i64> {
    let months = mdn >> 96;
    let days = (mdn & MDN_DAY_MASK) >> 64;
    // Masking yields the low 64 bits zero-extended into the i128. Reinterpret
    // them with a truncating cast to recover the signed nanosecond value.
    // A checked `try_into` would reject every negative duration.
    let nanoseconds = (mdn & MDN_NS_MASK) as i64;
    if months == 0 && days == 0 {
        Ok(nanoseconds)
    } else {
        Err(DataFusionError::Internal(
            "The interval cannot have a non-zero month or day value for duration convertibility"
                .to_string(),
        ))
    }
}
/// Extracts the millisecond component of a packed `IntervalDayTime` value as
/// a duration. The day component must be zero.
///
/// The packed layout read here is:
/// - days in bits 32..64,
/// - milliseconds in bits 0..32, as a two's-complement `i32`.
///
/// This matches arrow's `IntervalDayTimeType::make_value` encoding.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` if the day component is non-zero.
fn interval_dt_to_duration_ms(dt: &i64) -> Result<i64> {
    let days = dt >> 32;
    // Masking yields the low 32 bits zero-extended. Reinterpret them as i32
    // before widening so negative milliseconds sign-extend. Without this,
    // -5 ms would come out as 4294967291.
    let milliseconds = i64::from((dt & DT_MS_MASK) as i32);
    if days == 0 {
        Ok(milliseconds)
    } else {
        Err(DataFusionError::Internal(
            "The interval cannot have a non-zero day value for duration convertibility"
                .to_string(),
        ))
    }
}