use crate::TypeSignature;
use arrow::datatypes::{
DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};
use datafusion_common::{internal_err, plan_err, Result};
/// UTF-8 string types: `Utf8` and `LargeUtf8`.
pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];
/// Signed integer types, 8 through 64 bits wide.
pub static SIGNED_INTEGERS: &[DataType] = &[
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
];
/// Unsigned integer types, 8 through 64 bits wide.
pub static UNSIGNED_INTEGERS: &[DataType] = &[
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
];
/// All integer types: the signed ones followed by the unsigned ones.
pub static INTEGERS: &[DataType] = &[
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
];
/// Integer types plus `Float32` and `Float64`.
///
/// Decimal types are not part of this list; functions in this file that
/// accept decimals match on them separately.
pub static NUMERICS: &[DataType] = &[
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
DataType::Float32,
DataType::Float64,
];
/// Timestamp types at every `TimeUnit`, each with no timezone (`None`).
pub static TIMESTAMPS: &[DataType] = &[
DataType::Timestamp(TimeUnit::Second, None),
DataType::Timestamp(TimeUnit::Millisecond, None),
DataType::Timestamp(TimeUnit::Microsecond, None),
DataType::Timestamp(TimeUnit::Nanosecond, None),
];
/// Date types: `Date32` and `Date64`.
pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];
/// Binary types: `Binary` and `LargeBinary`.
pub static BINARYS: &[DataType] = &[DataType::Binary, DataType::LargeBinary];
/// Time-of-day types: `Time32` for second/millisecond units and `Time64`
/// for microsecond/nanosecond units.
pub static TIMES: &[DataType] = &[
DataType::Time32(TimeUnit::Second),
DataType::Time32(TimeUnit::Millisecond),
DataType::Time64(TimeUnit::Microsecond),
DataType::Time64(TimeUnit::Nanosecond),
];
/// Validates that the number of `input_types` is acceptable for `signature`.
///
/// * `Uniform` / `Any` require exactly the declared argument count.
/// * `Exact` requires as many arguments as it lists types.
/// * `OneOf` passes when at least one of its variants passes this check.
/// * `VariadicAny` requires at least one argument.
/// * `UserDefined` / `Numeric` are not count-checked here.
///
/// # Errors
///
/// Returns a plan error when the count does not fit the signature, and an
/// internal error for any signature kind not listed above.
pub fn check_arg_count(
    func_name: &str,
    input_types: &[DataType],
    signature: &TypeSignature,
) -> Result<()> {
    let provided = input_types.len();

    match signature {
        TypeSignature::Uniform(expected, _) | TypeSignature::Any(expected) => {
            if provided == *expected {
                Ok(())
            } else {
                plan_err!(
                    "The function {func_name} expects {:?} arguments, but {:?} were provided",
                    expected,
                    provided
                )
            }
        }
        TypeSignature::Exact(types) => {
            let expected = types.len();
            if expected == provided {
                Ok(())
            } else {
                plan_err!(
                    "The function {func_name} expects {:?} arguments, but {:?} were provided",
                    expected,
                    provided
                )
            }
        }
        TypeSignature::OneOf(variants) => {
            // Rejected only when every alternative rejects the argument count.
            let rejected_by_all = variants
                .iter()
                .all(|variant| check_arg_count(func_name, input_types, variant).is_err());
            if rejected_by_all {
                plan_err!(
                    "The function {func_name} does not accept {:?} function arguments.",
                    provided
                )
            } else {
                Ok(())
            }
        }
        TypeSignature::VariadicAny => {
            if provided == 0 {
                plan_err!("The function {func_name} expects at least one argument")
            } else {
                Ok(())
            }
        }
        // No argument-count validation is performed for these signatures here
        // (presumably they are validated elsewhere -- confirm against callers).
        TypeSignature::UserDefined | TypeSignature::Numeric(_) => Ok(()),
        _ => internal_err!("Aggregate functions do not support this {signature:?}"),
    }
}
/// Returns the result type of `SUM` for an argument of type `arg_type`.
///
/// * `Int64`, `UInt64` and `Float64` are returned unchanged.
/// * `Decimal128(p, s)` / `Decimal256(p, s)` keep their scale and widen the
///   precision by 10, clamped to the maximum precision of the decimal width.
///
/// # Errors
///
/// Returns a plan error for every other type, including narrower integer and
/// float types (presumably callers coerce those to a 64-bit type first --
/// confirm against callers).
pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
    match arg_type {
        DataType::Int64 => Ok(DataType::Int64),
        DataType::UInt64 => Ok(DataType::UInt64),
        DataType::Float64 => Ok(DataType::Float64),
        DataType::Decimal128(precision, scale) => {
            // `saturating_add` keeps an out-of-range `u8` precision from
            // overflowing (panic in debug, wrap-around in release) before
            // the clamp is applied.
            let new_precision = DECIMAL128_MAX_PRECISION.min(precision.saturating_add(10));
            Ok(DataType::Decimal128(new_precision, *scale))
        }
        DataType::Decimal256(precision, scale) => {
            let new_precision = DECIMAL256_MAX_PRECISION.min(precision.saturating_add(10));
            Ok(DataType::Decimal256(new_precision, *scale))
        }
        other => plan_err!("SUM does not support type \"{other:?}\""),
    }
}
/// Returns the result type of a variance aggregate: `Float64` for every type
/// listed in [`NUMERICS`].
///
/// # Errors
///
/// Returns a plan error when `arg_type` is not in [`NUMERICS`].
pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> {
    if !NUMERICS.contains(arg_type) {
        return plan_err!("VAR does not support {arg_type:?}");
    }
    Ok(DataType::Float64)
}
/// Returns the result type of a covariance aggregate: `Float64` for every
/// type listed in [`NUMERICS`].
///
/// # Errors
///
/// Returns a plan error when `arg_type` is not in [`NUMERICS`].
pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> {
    match arg_type {
        numeric if NUMERICS.contains(numeric) => Ok(DataType::Float64),
        _ => plan_err!("COVAR does not support {arg_type:?}"),
    }
}
/// Returns the result type of a correlation aggregate: `Float64` for every
/// type listed in [`NUMERICS`].
///
/// # Errors
///
/// Returns a plan error when `arg_type` is not in [`NUMERICS`].
pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
    let is_numeric = NUMERICS.contains(arg_type);
    match is_numeric {
        true => Ok(DataType::Float64),
        false => plan_err!("CORR does not support {arg_type:?}"),
    }
}
/// Returns the result type of an average aggregate named `func_name` for an
/// argument of type `arg_type`.
///
/// * `Decimal128(p, s)` / `Decimal256(p, s)` widen both precision and scale
///   by 4, each clamped to the maximum for that decimal width.
/// * Any type in [`NUMERICS`] yields `Float64`.
/// * `Dictionary` types are resolved through their value type.
///
/// # Errors
///
/// Returns a plan error, mentioning `func_name`, for every other type.
pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType> {
    match arg_type {
        DataType::Decimal128(precision, scale) => {
            // Saturating arithmetic avoids `u8`/`i8` overflow (panic in debug,
            // wrap-around in release) for out-of-range inputs before clamping.
            let new_precision = DECIMAL128_MAX_PRECISION.min(precision.saturating_add(4));
            let new_scale = DECIMAL128_MAX_SCALE.min(scale.saturating_add(4));
            Ok(DataType::Decimal128(new_precision, new_scale))
        }
        DataType::Decimal256(precision, scale) => {
            let new_precision = DECIMAL256_MAX_PRECISION.min(precision.saturating_add(4));
            let new_scale = DECIMAL256_MAX_SCALE.min(scale.saturating_add(4));
            Ok(DataType::Decimal256(new_precision, new_scale))
        }
        arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
        DataType::Dictionary(_, dict_value_type) => {
            avg_return_type(func_name, dict_value_type.as_ref())
        }
        other => plan_err!("{func_name} does not support {other:?}"),
    }
}
/// Returns the type computed by this function for the sum side of an average
/// over `arg_type`.
///
/// * `Decimal128(p, s)` / `Decimal256(p, s)` keep their scale and widen the
///   precision by 10, clamped to the maximum precision of the decimal width.
/// * Any type in [`NUMERICS`] yields `Float64`.
/// * `Dictionary` types are resolved through their value type.
///
/// # Errors
///
/// Returns a plan error for every other type.
pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
    match arg_type {
        DataType::Decimal128(precision, scale) => {
            // `saturating_add` keeps an out-of-range `u8` precision from
            // overflowing (panic in debug, wrap-around in release) before
            // the clamp is applied.
            let new_precision = DECIMAL128_MAX_PRECISION.min(precision.saturating_add(10));
            Ok(DataType::Decimal128(new_precision, *scale))
        }
        DataType::Decimal256(precision, scale) => {
            let new_precision = DECIMAL256_MAX_PRECISION.min(precision.saturating_add(10));
            Ok(DataType::Decimal256(new_precision, *scale))
        }
        arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
        DataType::Dictionary(_, dict_value_type) => avg_sum_type(dict_value_type.as_ref()),
        other => plan_err!("AVG does not support {other:?}"),
    }
}
/// Reports whether `arg_type` is accepted by this check for `SUM`.
///
/// Accepted types are those in [`NUMERICS`], `Decimal128`, `Decimal256`, and
/// `Dictionary` types whose value type is itself accepted.
pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
    if let DataType::Dictionary(_, value_type) = arg_type {
        return is_sum_support_arg_type(value_type.as_ref());
    }

    let is_decimal = matches!(
        arg_type,
        DataType::Decimal128(_, _) | DataType::Decimal256(_, _)
    );
    is_decimal || NUMERICS.contains(arg_type)
}
/// Reports whether `arg_type` is accepted by this check for `AVG`.
///
/// Accepted types are those in [`NUMERICS`], `Decimal128`, `Decimal256`, and
/// `Dictionary` types whose value type is itself accepted.
pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
    // Unwrap any (possibly nested) dictionary layers down to the value type.
    let mut current = arg_type;
    while let DataType::Dictionary(_, value_type) = current {
        current = value_type.as_ref();
    }

    match current {
        DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => true,
        other => NUMERICS.contains(other),
    }
}
/// Reports whether `arg_type` is accepted by this check for variance, i.e.
/// whether it is one of [`NUMERICS`].
pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool {
    NUMERICS.contains(arg_type)
}
/// Reports whether `arg_type` is accepted by this check for covariance, i.e.
/// whether it is one of [`NUMERICS`].
pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool {
    NUMERICS.iter().any(|candidate| candidate == arg_type)
}
/// Reports whether `arg_type` is accepted by this check for correlation,
/// i.e. whether it is one of [`NUMERICS`].
pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
    let supported: &[DataType] = NUMERICS;
    supported.contains(arg_type)
}
pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
arg_type.is_integer()
}
/// Computes the type the first argument of an average aggregate named
/// `func_name` is coerced to, returned as a one-element vector.
///
/// * `Decimal128` / `Decimal256` are kept as-is.
/// * Any other numeric type (per `DataType::is_numeric`) becomes `Float64`.
/// * `Dictionary` types are resolved through their value type.
///
/// Only `arg_types[0]` is inspected; any further entries are ignored.
///
/// # Errors
///
/// Returns a plan error when `arg_types` is empty or when the first argument
/// has an unsupported type.
pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<DataType>> {
    // Maps a single input type to its coerced type, recursing through
    // dictionary value types.
    fn coerced_type(func_name: &str, data_type: &DataType) -> Result<DataType> {
        match data_type {
            DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
            DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
            d if d.is_numeric() => Ok(DataType::Float64),
            DataType::Dictionary(_, v) => coerced_type(func_name, v.as_ref()),
            _ => plan_err!(
                "The function {:?} does not support inputs of type {:?}.",
                func_name,
                data_type
            ),
        }
    }

    // Use `first()` rather than indexing so an empty argument list produces
    // a plan error instead of an out-of-bounds panic.
    match arg_types.first() {
        Some(first) => Ok(vec![coerced_type(func_name, first)?]),
        None => plan_err!(
            "The function {:?} expects at least one argument",
            func_name
        ),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Variance yields `Float64` for a float input and rejects decimals.
    #[test]
    fn test_variance_return_data_type() -> Result<()> {
        assert_eq!(variance_return_type(&DataType::Float64)?, DataType::Float64);

        let unsupported = DataType::Decimal128(36, 10);
        assert!(variance_return_type(&unsupported).is_err());
        Ok(())
    }

    /// Sum widens decimal precision by 10 and clamps it at 38.
    #[test]
    fn test_sum_return_data_type() -> Result<()> {
        let widened = sum_return_type(&DataType::Decimal128(10, 5))?;
        assert_eq!(widened, DataType::Decimal128(20, 5));

        let clamped = sum_return_type(&DataType::Decimal128(36, 10))?;
        assert_eq!(clamped, DataType::Decimal128(38, 10));
        Ok(())
    }

    /// Covariance yields `Float64` for a float input and rejects decimals.
    #[test]
    fn test_covariance_return_data_type() -> Result<()> {
        assert_eq!(covariance_return_type(&DataType::Float64)?, DataType::Float64);

        let unsupported = DataType::Decimal128(36, 10);
        assert!(covariance_return_type(&unsupported).is_err());
        Ok(())
    }

    /// Correlation yields `Float64` for a float input and rejects decimals.
    #[test]
    fn test_correlation_return_data_type() -> Result<()> {
        assert_eq!(correlation_return_type(&DataType::Float64)?, DataType::Float64);

        let unsupported = DataType::Decimal128(36, 10);
        assert!(correlation_return_type(&unsupported).is_err());
        Ok(())
    }
}