use crate::{Signature, TypeSignature};
use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::{plan_err, DataFusionError, Result};
/// Computes the argument types a call must be coerced to so that it satisfies
/// `signature`.
///
/// * `current_types` - the data types of the arguments as currently supplied.
/// * `signature` - the function signature whose `type_signature` is matched.
///
/// An empty `current_types` yields an empty vector without consulting the
/// signature. If one of the signature's candidate type lists equals
/// `current_types` exactly, the types are returned unchanged; otherwise the
/// first candidate that every argument can be coerced to is returned.
///
/// # Errors
///
/// Returns a planning error when no candidate accepts `current_types`, or
/// propagates the error raised while expanding the signature.
pub fn data_types(
    current_types: &[DataType],
    signature: &Signature,
) -> Result<Vec<DataType>> {
    if current_types.is_empty() {
        return Ok(vec![]);
    }

    let candidates = get_valid_types(&signature.type_signature, current_types)?;

    // An exact match anywhere in the candidate list wins over any coercion.
    let has_exact_match = candidates
        .iter()
        .any(|candidate| candidate == current_types);
    if has_exact_match {
        return Ok(current_types.to_vec());
    }

    // Otherwise take the first candidate, in signature order, that accepts
    // every argument.
    let first_coercible = candidates
        .iter()
        .find_map(|candidate| maybe_data_types(candidate, current_types));
    if let Some(coerced) = first_coercible {
        return Ok(coerced);
    }

    plan_err!(
        "Coercion from {:?} to the signature {:?} failed.",
        current_types,
        &signature.type_signature
    )
}
/// Expands `signature` into the concrete argument-type lists it admits for a
/// call that supplies `current_types`.
///
/// Each inner vector is one candidate list of argument types. The outer vector
/// may be empty (for example a `OneOf` whose alternatives all fail).
///
/// # Errors
///
/// `TypeSignature::Any(n)` returns a planning error when the number of
/// supplied arguments is not `n`. Errors raised by the alternatives of a
/// `OneOf` are discarded rather than propagated.
fn get_valid_types(
    signature: &TypeSignature,
    current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
    let arg_count = current_types.len();

    let candidates: Vec<Vec<DataType>> = match signature {
        // One candidate per allowed type, repeated once per supplied argument.
        TypeSignature::Variadic(valid_types) => valid_types
            .iter()
            .map(|valid_type| vec![valid_type.clone(); arg_count])
            .collect(),
        // One candidate per allowed type, repeated `number` times.
        TypeSignature::Uniform(number, valid_types) => valid_types
            .iter()
            .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
            .collect(),
        // A single candidate: the first argument's type repeated for every
        // argument (an empty candidate when there are no arguments).
        TypeSignature::VariadicEqual => match current_types.first() {
            Some(first) => vec![vec![first.clone(); arg_count]],
            None => vec![Vec::new()],
        },
        // Anything goes: the supplied types are the only candidate.
        TypeSignature::VariadicAny => vec![current_types.to_vec()],
        TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
        TypeSignature::Any(number) => {
            if arg_count != *number {
                return plan_err!(
                    "The function expected {} arguments but received {}",
                    number,
                    current_types.len()
                );
            }
            // The count matches, so the supplied types are accepted as-is.
            vec![current_types.to_vec()]
        }
        TypeSignature::OneOf(types) => {
            // Concatenate the candidates of every alternative that expands
            // successfully; alternatives that error contribute nothing.
            let mut merged = Vec::new();
            for alternative in types {
                if let Ok(expanded) = get_valid_types(alternative, current_types) {
                    merged.extend(expanded);
                }
            }
            merged
        }
    };

    Ok(candidates)
}
/// Tries to coerce `current_types` position-by-position into `valid_types`.
///
/// Returns `None` when the two slices differ in length or when any argument
/// can be neither matched exactly nor coerced (per [`can_coerce_from`]) to the
/// type at the same position in `valid_types`. Otherwise returns the resulting
/// type for every position: the current type when it already matches, or the
/// valid type it is coerced to.
fn maybe_data_types(
    valid_types: &[DataType],
    current_types: &[DataType],
) -> Option<Vec<DataType>> {
    if valid_types.len() != current_types.len() {
        return None;
    }

    // Collecting into `Option<Vec<_>>` stops at the first position that
    // cannot be coerced and yields `None` for the whole call.
    valid_types
        .iter()
        .zip(current_types)
        .map(|(valid_type, current_type)| {
            if current_type == valid_type {
                // Already the expected type: keep it as is.
                Some(current_type.clone())
            } else if can_coerce_from(valid_type, current_type) {
                // Coercible: the argument takes on the expected type.
                Some(valid_type.clone())
            } else {
                None
            }
        })
        .collect()
}
/// Reports whether a value of type `type_from` may be implicitly coerced to
/// `type_into` when matching function arguments against a signature.
///
/// Identical types always coerce. Beyond that the rules are a fixed table:
/// integer targets accept `Null` and the listed integer types, float targets
/// additionally accept the integer types (and `Float64` accepts `Float32` and
/// `Decimal128`), nanosecond timestamps accept any timestamp, `Date32` and
/// strings, intervals accept strings, and string targets accept everything.
/// A `Null` target defers to arrow's `can_cast_types`. Every other combination
/// is rejected.
pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
    use self::DataType::*;

    if type_into == type_from {
        return true;
    }

    match (type_into, type_from) {
        // String targets accept any source type.
        (Utf8 | LargeUtf8, _) => true,
        // A `Null` target is decided by arrow's cast support.
        (Null, _) => can_cast_types(type_from, type_into),

        // Signed integer targets.
        (Int8, Null | Int8) => true,
        (Int16, Null | Int8 | Int16 | UInt8) => true,
        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => true,
        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => true,

        // Unsigned integer targets.
        (UInt8, Null | UInt8) => true,
        (UInt16, Null | UInt8 | UInt16) => true,
        (UInt32, Null | UInt8 | UInt16 | UInt32) => true,
        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => true,

        // Floating point targets.
        (
            Float32,
            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float32,
        ) => true,
        (
            Float64,
            Null
            | Int8
            | Int16
            | Int32
            | Int64
            | UInt8
            | UInt16
            | UInt32
            | UInt64
            | Float32
            | Float64
            | Decimal128(_, _),
        ) => true,

        // Temporal targets; only nanosecond-precision timestamps are listed.
        (
            Timestamp(TimeUnit::Nanosecond, _),
            Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8,
        ) => true,
        (Interval(_), Utf8 | LargeUtf8) => true,

        // Everything else is not coercible.
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;

    /// Table-driven check of `maybe_data_types`.
    #[test]
    fn test_maybe_data_types() {
        // Each case is (valid_types, current_types, expected result).
        let cases = vec![
            // Identical types are returned unchanged.
            (
                vec![DataType::UInt8, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt8, DataType::UInt16]),
            ),
            // UInt8 widens to UInt16.
            (
                vec![DataType::UInt16, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt16, DataType::UInt16]),
            ),
            // No arguments at all.
            (vec![], vec![], Some(vec![])),
            // UInt8 cannot be coerced to Boolean.
            (
                vec![DataType::Boolean, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                None,
            ),
            // Exact match in the first slot, UInt16 widens to UInt32 in the second.
            (
                vec![DataType::Boolean, DataType::UInt32],
                vec![DataType::Boolean, DataType::UInt16],
                Some(vec![DataType::Boolean, DataType::UInt32]),
            ),
        ];

        for (valid_types, current_types, expected) in cases {
            assert_eq!(maybe_data_types(&valid_types, &current_types), expected);
        }
    }

    /// `OneOf` keeps only the alternatives whose arity matches the call.
    #[test]
    fn test_get_valid_types_one_of() -> Result<()> {
        let signature =
            TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

        // Three arguments match neither alternative, so nothing is produced.
        let three_args = [DataType::Int32, DataType::Int32, DataType::Int32];
        let no_candidates = get_valid_types(&signature, &three_args)?;
        assert!(no_candidates.is_empty());

        // Two arguments match `Any(2)` only.
        let two_args = vec![DataType::Int32, DataType::Int32];
        let candidates = get_valid_types(&signature, &two_args)?;
        assert_eq!(candidates, vec![two_args]);

        // One argument matches `Any(1)` only.
        let one_arg = vec![DataType::Int32];
        let candidates = get_valid_types(&signature, &one_arg)?;
        assert_eq!(candidates, vec![one_arg]);

        Ok(())
    }
}