use arrow::array::{Array, ArrayRef, BooleanArray, OffsetSizeTrait};
use arrow::datatypes::DataType;
use arrow::row::{RowConverter, SortField};
use datafusion_common::cast::as_generic_list_array;
use datafusion_common::{exec_err, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use itertools::Itertools;
use crate::utils::check_datatypes;
use std::any::Any;
use std::sync::Arc;
// The three invocations below wire each UDF struct defined in this file to a
// public helper name, an argument list, a doc string and a `*_udf` accessor name.
// NOTE(review): `make_udf_expr_and_func!` is defined outside this file; the
// exact items it generates should be confirmed against the macro definition.

// `array_has(haystack_array, element)` backed by `ArrayHas`.
make_udf_expr_and_func!(ArrayHas,
array_has,
haystack_array element, "returns true, if the element appears in the first array, otherwise false.", array_has_udf );
// `array_has_all(haystack_array, needle_array)` backed by `ArrayHasAll`.
make_udf_expr_and_func!(ArrayHasAll,
array_has_all,
haystack_array needle_array, "returns true if each element of the second array appears in the first array; otherwise, it returns false.", array_has_all_udf );
// `array_has_any(haystack_array, needle_array)` backed by `ArrayHasAny`.
make_udf_expr_and_func!(ArrayHasAny,
array_has_any,
haystack_array needle_array, "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", array_has_any_udf );
/// Scalar UDF state for `array_has`: the accepted signature plus the alternative
/// names under which the function is also exposed.
#[derive(Debug)]
pub struct ArrayHas {
    /// Signature reported through `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative names reported through `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}

impl Default for ArrayHas {
    /// Equivalent to [`ArrayHas::new`].
    fn default() -> Self {
        ArrayHas::new()
    }
}

impl ArrayHas {
    /// Creates the UDF with an immutable `array_and_element` signature and the
    /// aliases `list_has`, `array_contains` and `list_contains`.
    pub fn new() -> Self {
        let aliases = ["list_has", "array_contains", "list_contains"]
            .iter()
            .map(|alias| alias.to_string())
            .collect();
        let signature = Signature::array_and_element(Volatility::Immutable);
        Self { signature, aliases }
    }
}
impl ScalarUDFImpl for ArrayHas {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Primary name of the function.
    fn name(&self) -> &str {
        "array_has"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always a boolean, regardless of the argument types.
    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    /// Materializes the arguments as arrays and dispatches on the type of the
    /// first one: `List` uses `i32` offsets, `LargeList` uses `i64` offsets.
    ///
    /// # Errors
    /// Fails when the number of arguments is not two, when the first argument
    /// is neither `List` nor `LargeList`, or when the dispatch itself fails.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let arrays = ColumnarValue::values_to_arrays(args)?;
        let (haystack, needle) = match arrays.as_slice() {
            [haystack, needle] => (haystack, needle),
            _ => return exec_err!("array_has needs two arguments"),
        };
        let array_type = haystack.data_type();
        let found = match array_type {
            DataType::List(_) => general_array_has_dispatch::<i32>(
                haystack,
                needle,
                ComparisonType::Single,
            )?,
            DataType::LargeList(_) => general_array_has_dispatch::<i64>(
                haystack,
                needle,
                ComparisonType::Single,
            )?,
            _ => return exec_err!("array_has does not support type '{array_type:?}'."),
        };
        Ok(ColumnarValue::Array(found))
    }

    fn aliases(&self) -> &[String] {
        &self.aliases
    }
}
/// Scalar UDF state for `array_has_all`: the accepted signature plus the
/// alternative names under which the function is also exposed.
#[derive(Debug)]
pub struct ArrayHasAll {
    /// Signature reported through `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative names reported through `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}

impl Default for ArrayHasAll {
    /// Equivalent to [`ArrayHasAll::new`].
    fn default() -> Self {
        ArrayHasAll::new()
    }
}

impl ArrayHasAll {
    /// Creates the UDF with an immutable two-argument `any` signature and the
    /// alias `list_has_all`.
    pub fn new() -> Self {
        let aliases = vec!["list_has_all".to_string()];
        let signature = Signature::any(2, Volatility::Immutable);
        Self { signature, aliases }
    }
}
impl ScalarUDFImpl for ArrayHasAll {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Primary name of the function.
    fn name(&self) -> &str {
        "array_has_all"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always a boolean, regardless of the argument types.
    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    /// Materializes the arguments as arrays and dispatches on the type of the
    /// first one: `List` uses `i32` offsets, `LargeList` uses `i64` offsets.
    ///
    /// # Errors
    /// Fails when the number of arguments is not two, when the first argument
    /// is neither `List` nor `LargeList`, or when the dispatch itself fails.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let arrays = ColumnarValue::values_to_arrays(args)?;
        let (haystack, needle) = match arrays.as_slice() {
            [haystack, needle] => (haystack, needle),
            _ => return exec_err!("array_has_all needs two arguments"),
        };
        let array_type = haystack.data_type();
        let contained = match array_type {
            DataType::List(_) => {
                general_array_has_dispatch::<i32>(haystack, needle, ComparisonType::All)?
            }
            DataType::LargeList(_) => {
                general_array_has_dispatch::<i64>(haystack, needle, ComparisonType::All)?
            }
            _ => return exec_err!("array_has_all does not support type '{array_type:?}'."),
        };
        Ok(ColumnarValue::Array(contained))
    }

    fn aliases(&self) -> &[String] {
        &self.aliases
    }
}
/// Scalar UDF state for `array_has_any`: the accepted signature plus the
/// alternative names under which the function is also exposed.
#[derive(Debug)]
pub struct ArrayHasAny {
    /// Signature reported through `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative names reported through `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}

impl Default for ArrayHasAny {
    /// Equivalent to [`ArrayHasAny::new`].
    fn default() -> Self {
        ArrayHasAny::new()
    }
}

impl ArrayHasAny {
    /// Creates the UDF with an immutable two-argument `any` signature and the
    /// alias `list_has_any`.
    pub fn new() -> Self {
        let aliases = vec!["list_has_any".to_string()];
        let signature = Signature::any(2, Volatility::Immutable);
        Self { signature, aliases }
    }
}
impl ScalarUDFImpl for ArrayHasAny {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Primary name of the function.
    fn name(&self) -> &str {
        "array_has_any"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always a boolean, regardless of the argument types.
    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    /// Materializes the arguments as arrays and dispatches on the type of the
    /// first one: `List` uses `i32` offsets, `LargeList` uses `i64` offsets.
    ///
    /// # Errors
    /// Fails when the number of arguments is not two, when the first argument
    /// is neither `List` nor `LargeList`, or when the dispatch itself fails.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let arrays = ColumnarValue::values_to_arrays(args)?;
        let (haystack, needle) = match arrays.as_slice() {
            [haystack, needle] => (haystack, needle),
            _ => return exec_err!("array_has_any needs two arguments"),
        };
        let array_type = haystack.data_type();
        let overlapping = match array_type {
            DataType::List(_) => {
                general_array_has_dispatch::<i32>(haystack, needle, ComparisonType::Any)?
            }
            DataType::LargeList(_) => {
                general_array_has_dispatch::<i64>(haystack, needle, ComparisonType::Any)?
            }
            _ => return exec_err!("array_has_any does not support type '{array_type:?}'."),
        };
        Ok(ColumnarValue::Array(overlapping))
    }

    fn aliases(&self) -> &[String] {
        &self.aliases
    }
}
/// Selects which membership test `general_array_has_dispatch` performs for
/// each row.
#[derive(Debug, PartialEq)]
enum ComparisonType {
/// Every element of the needle list must be present in the haystack list
/// (used by `array_has_all`).
All,
/// At least one element of the needle list must be present in the haystack
/// list (used by `array_has_any`).
Any,
/// The needle is a single element, not a list, and must be present in the
/// haystack list (used by `array_has`).
Single,
}
/// Shared row-by-row implementation behind `array_has`, `array_has_all` and
/// `array_has_any`.
///
/// * `haystack` - list array (offset type `O`) that is searched.
/// * `needle` - for [`ComparisonType::Single`], an array holding one candidate
///   element per row; otherwise a list array (offset type `O`) holding the
///   candidate elements for each row.
/// * `comparison_type` - which membership test to apply.
///
/// Returns a `BooleanArray` with one entry per haystack row. An entry is null
/// when the haystack row is null or, for `All` / `Any`, when the needle list
/// row is null.
///
/// # Errors
/// Propagates failures from `check_datatypes`, from the list downcasts and
/// from the row conversion.
fn general_array_has_dispatch<O: OffsetSizeTrait>(
    haystack: &ArrayRef,
    needle: &ArrayRef,
    comparison_type: ComparisonType,
) -> Result<ArrayRef> {
    let is_single = comparison_type == ComparisonType::Single;

    // NOTE: the "array_has" label is used in the type-mismatch error for all
    // three comparison modes; it is kept as-is so existing messages do not change.
    let array = if is_single {
        // The needle must have the type of the haystack's *elements*.
        let arr = as_generic_list_array::<O>(haystack)?;
        check_datatypes("array_has", &[arr.values(), needle])?;
        arr
    } else {
        // The needle must have the same list type as the haystack.
        check_datatypes("array_has", &[haystack, needle])?;
        as_generic_list_array::<O>(haystack)?
    };

    let mut boolean_builder = BooleanArray::builder(array.len());
    // Elements are compared through their row encoding so that any element
    // type (including nested ones) can be tested for equality uniformly.
    let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;

    if is_single {
        // Encode the needle column once up front; re-encoding the whole column
        // for every haystack row would make the scan quadratic in the row count.
        let needle_rows = converter.convert_columns(&[Arc::clone(needle)])?;
        for (row_idx, arr) in array.iter().enumerate() {
            match arr {
                Some(arr) => {
                    let arr_values = converter.convert_columns(&[arr])?;
                    let target = needle_rows.row(row_idx);
                    boolean_builder.append_value(arr_values.iter().any(|x| x == target));
                }
                None => boolean_builder.append_null(),
            }
        }
    } else {
        let require_all = comparison_type == ComparisonType::All;
        let sub_array = as_generic_list_array::<O>(needle)?;
        for (arr, sub_arr) in array.iter().zip(sub_array.iter()) {
            match (arr, sub_arr) {
                (Some(arr), Some(sub_arr)) => {
                    let arr_values = converter.convert_columns(&[arr])?;
                    let sub_arr_values = converter.convert_columns(&[sub_arr])?;
                    // Skipping consecutive duplicate needles avoids rescanning
                    // the haystack for a value that was just looked up.
                    let mut needles = sub_arr_values.iter().dedup();
                    let res = if require_all {
                        needles.all(|elem| arr_values.iter().any(|x| x == elem))
                    } else {
                        needles.any(|elem| arr_values.iter().any(|x| x == elem))
                    };
                    boolean_builder.append_value(res);
                }
                (_, _) => boolean_builder.append_null(),
            }
        }
    }

    Ok(Arc::new(boolean_builder.finish()))
}