datafusion_functions_nested/
distance.rsuse crate::utils::{downcast_arg, make_scalar_function};
use arrow_array::{
Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait,
};
use arrow_schema::DataType;
use arrow_schema::DataType::{FixedSizeList, Float64, LargeList, List};
use core::any::type_name;
use datafusion_common::cast::{
as_float32_array, as_float64_array, as_generic_list_array, as_int32_array,
as_int64_array,
};
use datafusion_common::utils::coerced_fixed_size_list_to_list;
use datafusion_common::DataFusionError;
use datafusion_common::{exec_err, Result};
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use std::any::Any;
use std::sync::{Arc, OnceLock};
make_udf_expr_and_func!(
ArrayDistance,
array_distance,
array,
"returns the Euclidean distance between two numeric arrays.",
array_distance_udf
);
#[derive(Debug)]
pub(super) struct ArrayDistance {
signature: Signature,
aliases: Vec<String>,
}
impl ArrayDistance {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec!["list_distance".to_string()],
}
}
}
impl ScalarUDFImpl for ArrayDistance {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"array_distance"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
match arg_types[0] {
List(_) | LargeList(_) | FixedSizeList(_, _) => Ok(Float64),
_ => exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."),
}
}
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 2 {
return exec_err!("array_distance expects exactly two arguments");
}
let mut result = Vec::new();
for arg_type in arg_types {
match arg_type {
List(_) | LargeList(_) | FixedSizeList(_, _) => result.push(coerced_fixed_size_list_to_list(arg_type)),
_ => return exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."),
}
}
Ok(result)
}
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
make_scalar_function(array_distance_inner)(args)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
fn documentation(&self) -> Option<&Documentation> {
Some(get_array_distance_doc())
}
}
static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
fn get_array_distance_doc() -> &'static Documentation {
DOCUMENTATION.get_or_init(|| {
Documentation::builder()
.with_doc_section(DOC_SECTION_ARRAY)
.with_description(
"Returns the Euclidean distance between two input arrays of equal length.",
)
.with_syntax_example("array_distance(array1, array2)")
.with_sql_example(
r#"```sql
> select array_distance([1, 2], [1, 4]);
+------------------------------------+
| array_distance(List([1,2], [1,4])) |
+------------------------------------+
| 2.0 |
+------------------------------------+
```"#,
)
.with_argument(
"array1",
"Array expression. Can be a constant, column, or function, and any combination of array operators.",
)
.with_argument(
"array2",
"Array expression. Can be a constant, column, or function, and any combination of array operators.",
)
.build()
.unwrap()
})
}
pub fn array_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_distance expects exactly two arguments");
}
match (&args[0].data_type(), &args[1].data_type()) {
(List(_), List(_)) => general_array_distance::<i32>(args),
(LargeList(_), LargeList(_)) => general_array_distance::<i64>(args),
(array_type1, array_type2) => {
exec_err!("array_distance does not support types '{array_type1:?}' and '{array_type2:?}'")
}
}
}
fn general_array_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
let result = list_array1
.iter()
.zip(list_array2.iter())
.map(|(arr1, arr2)| compute_array_distance(arr1, arr2))
.collect::<Result<Float64Array>>()?;
Ok(Arc::new(result) as ArrayRef)
}
fn compute_array_distance(
arr1: Option<ArrayRef>,
arr2: Option<ArrayRef>,
) -> Result<Option<f64>> {
let value1 = match arr1 {
Some(arr) => arr,
None => return Ok(None),
};
let value2 = match arr2 {
Some(arr) => arr,
None => return Ok(None),
};
let mut value1 = value1;
let mut value2 = value2;
loop {
match value1.data_type() {
List(_) => {
if downcast_arg!(value1, ListArray).null_count() > 0 {
return Ok(None);
}
value1 = downcast_arg!(value1, ListArray).value(0);
}
LargeList(_) => {
if downcast_arg!(value1, LargeListArray).null_count() > 0 {
return Ok(None);
}
value1 = downcast_arg!(value1, LargeListArray).value(0);
}
_ => break,
}
match value2.data_type() {
List(_) => {
if downcast_arg!(value2, ListArray).null_count() > 0 {
return Ok(None);
}
value2 = downcast_arg!(value2, ListArray).value(0);
}
LargeList(_) => {
if downcast_arg!(value2, LargeListArray).null_count() > 0 {
return Ok(None);
}
value2 = downcast_arg!(value2, LargeListArray).value(0);
}
_ => break,
}
}
if value1.null_count() != 0 || value2.null_count() != 0 {
return Ok(None);
}
let values1 = convert_to_f64_array(&value1)?;
let values2 = convert_to_f64_array(&value2)?;
if values1.len() != values2.len() {
return exec_err!("Both arrays must have the same length");
}
let sum_squares: f64 = values1
.iter()
.zip(values2.iter())
.map(|(v1, v2)| {
let diff = v1.unwrap_or(0.0) - v2.unwrap_or(0.0);
diff * diff
})
.sum();
Ok(Some(sum_squares.sqrt()))
}
fn convert_to_f64_array(array: &ArrayRef) -> Result<Float64Array> {
match array.data_type() {
Float64 => Ok(as_float64_array(array)?.clone()),
DataType::Float32 => {
let array = as_float32_array(array)?;
let converted: Float64Array =
array.iter().map(|v| v.map(|v| v as f64)).collect();
Ok(converted)
}
DataType::Int64 => {
let array = as_int64_array(array)?;
let converted: Float64Array =
array.iter().map(|v| v.map(|v| v as f64)).collect();
Ok(converted)
}
DataType::Int32 => {
let array = as_int32_array(array)?;
let converted: Float64Array =
array.iter().map(|v| v.map(|v| v as f64)).collect();
Ok(converted)
}
_ => exec_err!("Unsupported array type for conversion to Float64Array"),
}
}