use arrow::array::{Array, ArrayRef, BooleanArray, OffsetSizeTrait};
use arrow::datatypes::DataType;
use arrow::row::{RowConverter, Rows, SortField};
use arrow_array::{Datum, GenericListArray, Scalar};
use arrow_buffer::BooleanBuffer;
use datafusion_common::cast::as_generic_list_array;
use datafusion_common::utils::string_utils::string_array_to_vec;
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_physical_expr_common::datum::compare_with_eq;
use itertools::Itertools;
use crate::utils::make_scalar_function;
use std::any::Any;
use std::sync::{Arc, OnceLock};
// Registers `ArrayHas` through the project macro `make_udf_expr_and_func!` (defined
// elsewhere). NOTE(review): presumably this generates an `array_has(haystack_array, element)`
// expression helper carrying the given description, plus an `array_has_udf` accessor —
// confirm against the macro definition.
make_udf_expr_and_func!(ArrayHas,
array_has,
haystack_array element, "returns true, if the element appears in the first array, otherwise false.", array_has_udf );
// Same registration for `ArrayHasAll` (`array_has_all(haystack_array, needle_array)`,
// accessor `array_has_all_udf`). NOTE(review): generated items inferred from the macro
// arguments — confirm.
make_udf_expr_and_func!(ArrayHasAll,
array_has_all,
haystack_array needle_array, "returns true if each element of the second array appears in the first array; otherwise, it returns false.", array_has_all_udf );
// Same registration for `ArrayHasAny` (`array_has_any(haystack_array, needle_array)`,
// accessor `array_has_any_udf`). NOTE(review): generated items inferred from the macro
// arguments — confirm.
make_udf_expr_and_func!(ArrayHasAny,
array_has_any,
haystack_array needle_array, "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", array_has_any_udf );
/// Scalar UDF implementing `array_has`, also reachable as `list_has`,
/// `array_contains` and `list_contains`.
#[derive(Debug)]
pub struct ArrayHas {
    /// Signature built by `Signature::array_and_element` with immutable volatility.
    signature: Signature,
    /// Alternative SQL names under which this function is registered.
    aliases: Vec<String>,
}

impl Default for ArrayHas {
    fn default() -> Self {
        ArrayHas::new()
    }
}

impl ArrayHas {
    /// Creates the UDF with its array-and-element signature and aliases.
    pub fn new() -> Self {
        let aliases = ["list_has", "array_contains", "list_contains"]
            .into_iter()
            .map(String::from)
            .collect();
        Self {
            signature: Signature::array_and_element(Volatility::Immutable),
            aliases,
        }
    }
}
impl ScalarUDFImpl for ArrayHas {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "array_has"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always boolean, independent of the argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    /// Evaluates `array_has(haystack, needle)`.
    ///
    /// * Array needle: the haystack is materialized with the needle's length and
    ///   each row is checked against the needle value of the same row.
    /// * Scalar needle: a `NULL` needle returns a scalar `NULL` immediately;
    ///   otherwise the needle is checked against every haystack row, and a scalar
    ///   haystack yields a scalar result.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let needle_arg = &args[1];
        let haystack_arg = &args[0];

        let scalar_needle = match needle_arg {
            ColumnarValue::Array(needle_array) => {
                let haystack =
                    haystack_arg.to_owned().into_array(needle_array.len())?;
                let contained = array_has_inner_for_array(&haystack, needle_array)?;
                return Ok(ColumnarValue::Array(contained));
            }
            ColumnarValue::Scalar(scalar) => scalar,
        };

        if scalar_needle.is_null() {
            return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
        }

        let haystack = haystack_arg.to_owned().into_array(1)?;
        let needle = Scalar::new(scalar_needle.to_array_of_size(1)?);
        let contained = array_has_inner_for_scalar(&haystack, &needle)?;

        if matches!(haystack_arg, ColumnarValue::Scalar(_)) {
            let scalar_value = ScalarValue::try_from_array(&contained, 0)?;
            Ok(ColumnarValue::Scalar(scalar_value))
        } else {
            Ok(ColumnarValue::Array(contained))
        }
    }

    fn aliases(&self) -> &[String] {
        self.aliases.as_slice()
    }

    fn documentation(&self) -> Option<&Documentation> {
        let doc = get_array_has_doc();
        Some(doc)
    }
}
/// Lazily-built documentation cache read by `get_array_has_doc`.
static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();

/// Returns the user-facing documentation for `array_has`, building it on first use.
fn get_array_has_doc() -> &'static Documentation {
    DOCUMENTATION.get_or_init(build_array_has_doc)
}

/// Assembles the documentation entry for `array_has`.
fn build_array_has_doc() -> Documentation {
    let sql_example = r#"```sql
> select array_has([1, 2, 3], 2);
+-----------------------------+
| array_has(List([1,2,3]), 2) |
+-----------------------------+
| true |
+-----------------------------+
```"#;
    Documentation::builder()
        .with_doc_section(DOC_SECTION_ARRAY)
        .with_description("Returns true if the array contains the element.")
        .with_syntax_example("array_has(array, element)")
        .with_sql_example(sql_example)
        .with_argument(
            "array",
            "Array expression. Can be a constant, column, or function, and any combination of array operators.",
        )
        .with_argument(
            "element",
            "Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators.",
        )
        .build()
        .unwrap()
}
/// Routes a scalar-needle `array_has` to the `List` (i32 offsets) or
/// `LargeList` (i64 offsets) implementation; any other haystack type is an
/// execution error.
fn array_has_inner_for_scalar(
    haystack: &ArrayRef,
    needle: &dyn Datum,
) -> Result<ArrayRef> {
    let haystack_type = haystack.data_type();
    if matches!(haystack_type, DataType::List(_)) {
        return array_has_dispatch_for_scalar::<i32>(haystack, needle);
    }
    if matches!(haystack_type, DataType::LargeList(_)) {
        return array_has_dispatch_for_scalar::<i64>(haystack, needle);
    }
    exec_err!("array_has does not support type '{:?}'.", haystack_type)
}
/// Routes a row-wise (array-needle) `array_has` to the `List` or `LargeList`
/// implementation; any other haystack type is an execution error.
fn array_has_inner_for_array(haystack: &ArrayRef, needle: &ArrayRef) -> Result<ArrayRef> {
    match haystack.data_type() {
        DataType::List(_) => array_has_dispatch_for_array::<i32>(haystack, needle),
        DataType::LargeList(_) => array_has_dispatch_for_array::<i64>(haystack, needle),
        unsupported => {
            exec_err!("array_has does not support type '{:?}'.", unsupported)
        }
    }
}
/// Row-wise `array_has` where both the haystack lists and the needles are arrays.
///
/// Row `i` of the result is `NULL` when the haystack list or the needle at `i`
/// is null; otherwise it is `true` iff `compare_with_eq` reports at least one
/// element of the list equal to the needle.
fn array_has_dispatch_for_array<O: OffsetSizeTrait>(
    haystack: &ArrayRef,
    needle: &ArrayRef,
) -> Result<ArrayRef> {
    let lists = as_generic_list_array::<O>(haystack)?;
    let mut results = BooleanArray::builder(lists.len());
    for (row, list) in lists.iter().enumerate() {
        let list = match list {
            Some(list) if !needle.is_null(row) => list,
            _ => {
                results.append_null();
                continue;
            }
        };
        // Compare the whole list against this row's needle, wrapped as a 1-element scalar.
        let needle_at_row = Scalar::new(needle.slice(row, 1));
        let equal = compare_with_eq(&list, &needle_at_row, list.data_type().is_nested())?;
        results.append_value(equal.true_count() > 0);
    }
    Ok(Arc::new(results.finish()))
}
/// Evaluates `array_has` for every row of a list `haystack` against one scalar
/// `needle`.
///
/// The equality comparison runs once over the flattened child values of the
/// list; each row's slice of that comparison is then reduced. Per row:
/// * `NULL`  if the haystack row itself is null,
/// * `false` if the haystack row is an empty list,
/// * `NULL`  if every comparison in the row is null (presumably all elements null),
/// * otherwise `true` iff at least one element equals `needle`.
fn array_has_dispatch_for_scalar<O: OffsetSizeTrait>(
    haystack: &ArrayRef,
    needle: &dyn Datum,
) -> Result<ArrayRef> {
    let haystack = as_generic_list_array::<O>(haystack)?;
    let values = haystack.values();

    // No child values: every non-null row is an empty list and cannot contain the
    // needle, so it is `false`; null rows keep their null.
    if values.is_empty() {
        return Ok(Arc::new(BooleanArray::new(
            BooleanBuffer::new_unset(haystack.len()),
            haystack.nulls().cloned(),
        )));
    }

    let is_nested = values.data_type().is_nested();
    let eq_array = compare_with_eq(values, needle, is_nested)?;
    let offsets = haystack.value_offsets();

    let contained: Vec<Option<bool>> = offsets
        .windows(2)
        .enumerate()
        .map(|(i, window)| {
            // Check validity explicitly: a null list entry may still span a
            // non-empty range of child values, which must be ignored.
            if haystack.is_null(i) {
                return None;
            }
            let start = window[0].as_usize();
            let length = window[1].as_usize() - start;
            // An empty list never contains the needle; this must not depend on
            // what other rows in the batch hold.
            if length == 0 {
                return Some(false);
            }
            let row = eq_array.slice(start, length);
            // Only yield a definite answer if at least one comparison was non-null.
            (row.null_count() != length).then(|| row.true_count() > 0)
        })
        .collect();

    Ok(Arc::new(BooleanArray::from(contained)))
}
/// Implements `array_has_all` over materialized arguments.
///
/// `args[0]` is the haystack list array and `args[1]` the needle list array.
/// Only `List` and `LargeList` haystacks are supported; others produce an
/// execution error naming this function.
fn array_has_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
    match args[0].data_type() {
        DataType::List(_) => {
            array_has_all_and_any_dispatch::<i32>(&args[0], &args[1], ComparisonType::All)
        }
        DataType::LargeList(_) => {
            array_has_all_and_any_dispatch::<i64>(&args[0], &args[1], ComparisonType::All)
        }
        other => exec_err!("array_has_all does not support type '{:?}'.", other),
    }
}
/// Implements `array_has_any` over materialized arguments.
///
/// `args[0]` is the haystack list array and `args[1]` the needle list array.
/// Only `List` and `LargeList` haystacks are supported; others produce an
/// execution error naming this function.
fn array_has_any_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
    match args[0].data_type() {
        DataType::List(_) => {
            array_has_all_and_any_dispatch::<i32>(&args[0], &args[1], ComparisonType::Any)
        }
        DataType::LargeList(_) => {
            array_has_all_and_any_dispatch::<i64>(&args[0], &args[1], ComparisonType::Any)
        }
        other => exec_err!("array_has_any does not support type '{:?}'.", other),
    }
}
/// Scalar UDF implementing `array_has_all` (alias `list_has_all`).
#[derive(Debug)]
pub struct ArrayHasAll {
    /// Signature built by `Signature::any(2, Volatility::Immutable)`.
    signature: Signature,
    /// Alternative SQL names under which this function is registered.
    aliases: Vec<String>,
}

impl Default for ArrayHasAll {
    fn default() -> Self {
        ArrayHasAll::new()
    }
}

impl ArrayHasAll {
    /// Creates the UDF with its two-argument immutable signature and alias.
    pub fn new() -> Self {
        let signature = Signature::any(2, Volatility::Immutable);
        let aliases = vec!["list_has_all".to_owned()];
        Self { signature, aliases }
    }
}
impl ScalarUDFImpl for ArrayHasAll {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "array_has_all"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always boolean, independent of the argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    /// Materializes the arguments and evaluates them with `array_has_all_inner`.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let kernel = make_scalar_function(array_has_all_inner);
        kernel(args)
    }

    fn aliases(&self) -> &[String] {
        self.aliases.as_slice()
    }

    fn documentation(&self) -> Option<&Documentation> {
        let doc = get_array_has_all_doc();
        Some(doc)
    }
}
fn get_array_has_all_doc() -> &'static Documentation {
DOCUMENTATION.get_or_init(|| {
Documentation::builder()
.with_doc_section(DOC_SECTION_ARRAY)
.with_description(
"Returns true if all elements of sub-array exist in array.",
)
.with_syntax_example("array_has_all(array, sub-array)")
.with_sql_example(
r#"```sql
> select array_has_all([1, 2, 3, 4], [2, 3]);
+--------------------------------------------+
| array_has_all(List([1,2,3,4]), List([2,3])) |
+--------------------------------------------+
| true |
+--------------------------------------------+
```"#,
)
.with_argument(
"array",
"Array expression. Can be a constant, column, or function, and any combination of array operators.",
)
.with_argument(
"sub-array",
"Array expression. Can be a constant, column, or function, and any combination of array operators.",
)
.build()
.unwrap()
})
}
/// Scalar UDF implementing `array_has_any` (alias `list_has_any`).
#[derive(Debug)]
pub struct ArrayHasAny {
    /// Signature built by `Signature::any(2, Volatility::Immutable)`.
    signature: Signature,
    /// Alternative SQL names under which this function is registered.
    aliases: Vec<String>,
}

impl Default for ArrayHasAny {
    fn default() -> Self {
        ArrayHasAny::new()
    }
}

impl ArrayHasAny {
    /// Creates the UDF with its two-argument immutable signature and alias.
    pub fn new() -> Self {
        let signature = Signature::any(2, Volatility::Immutable);
        let aliases = vec!["list_has_any".to_owned()];
        Self { signature, aliases }
    }
}
impl ScalarUDFImpl for ArrayHasAny {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "array_has_any"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always boolean, independent of the argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    /// Materializes the arguments and evaluates them with `array_has_any_inner`.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let kernel = make_scalar_function(array_has_any_inner);
        kernel(args)
    }

    fn aliases(&self) -> &[String] {
        self.aliases.as_slice()
    }

    fn documentation(&self) -> Option<&Documentation> {
        let doc = get_array_has_any_doc();
        Some(doc)
    }
}
fn get_array_has_any_doc() -> &'static Documentation {
DOCUMENTATION.get_or_init(|| {
Documentation::builder()
.with_doc_section(DOC_SECTION_ARRAY)
.with_description(
"Returns true if any elements exist in both arrays.",
)
.with_syntax_example("array_has_any(array, sub-array)")
.with_sql_example(
r#"```sql
> select array_has_any([1, 2, 3], [3, 4]);
+------------------------------------------+
| array_has_any(List([1,2,3]), List([3,4])) |
+------------------------------------------+
| true |
+------------------------------------------+
```"#,
)
.with_argument(
"array",
"Array expression. Can be a constant, column, or function, and any combination of array operators.",
)
.with_argument(
"sub-array",
"Array expression. Can be a constant, column, or function, and any combination of array operators.",
)
.build()
.unwrap()
})
}
/// How `array_has_all` / `array_has_any` combine the per-needle-element results.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ComparisonType {
    /// Every needle element must be found in the haystack (`array_has_all`).
    All,
    /// At least one needle element must be found in the haystack (`array_has_any`).
    Any,
}
fn array_has_all_and_any_dispatch<O: OffsetSizeTrait>(
haystack: &ArrayRef,
needle: &ArrayRef,
comparison_type: ComparisonType,
) -> Result<ArrayRef> {
let haystack = as_generic_list_array::<O>(haystack)?;
let needle = as_generic_list_array::<O>(needle)?;
match needle.data_type() {
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {
array_has_all_and_any_string_internal::<O>(haystack, needle, comparison_type)
}
_ => general_array_has_for_all_and_any::<O>(haystack, needle, comparison_type),
}
}
/// Row-wise `array_has_all` / `array_has_any` for string-valued lists.
///
/// A row is `NULL` when either list is null; otherwise both lists are converted
/// to `Vec<Option<&str>>` and evaluated by `array_has_string_kernel`.
fn array_has_all_and_any_string_internal<O: OffsetSizeTrait>(
    array: &GenericListArray<O>,
    needle: &GenericListArray<O>,
    comparison_type: ComparisonType,
) -> Result<ArrayRef> {
    let mut builder = BooleanArray::builder(array.len());
    for pair in array.iter().zip(needle.iter()) {
        let outcome = if let (Some(haystack_list), Some(needle_list)) = pair {
            Some(array_has_string_kernel(
                string_array_to_vec(&haystack_list),
                string_array_to_vec(&needle_list),
                comparison_type,
            ))
        } else {
            None
        };
        builder.append_option(outcome);
    }
    Ok(Arc::new(builder.finish()))
}
/// Checks whether all (`All`) or any (`Any`) of the `needle` strings occur in
/// `haystack`. Null entries compare equal to each other.
///
/// `dedup` only drops *consecutive* repeats; it saves comparisons without
/// affecting the result.
fn array_has_string_kernel(
    haystack: Vec<Option<&str>>,
    needle: Vec<Option<&str>>,
    comparison_type: ComparisonType,
) -> bool {
    let mut needle_found = needle
        .iter()
        .dedup()
        .map(|wanted| haystack.iter().dedup().any(|candidate| candidate == wanted));
    match comparison_type {
        ComparisonType::All => needle_found.all(|found| found),
        ComparisonType::Any => needle_found.any(|found| found),
    }
}
/// Row-wise `array_has_all` / `array_has_any` for arbitrary element types.
///
/// Each pair of lists is converted to the arrow row format with one
/// `RowConverter` built from the haystack element type, then compared by
/// `general_array_has_all_and_any_kernel`. A row is `NULL` when either list is
/// null.
fn general_array_has_for_all_and_any<O: OffsetSizeTrait>(
    haystack: &GenericListArray<O>,
    needle: &GenericListArray<O>,
    comparison_type: ComparisonType,
) -> Result<ArrayRef> {
    let mut builder = BooleanArray::builder(haystack.len());
    let row_converter = RowConverter::new(vec![SortField::new(haystack.value_type())])?;
    for (haystack_list, needle_list) in haystack.iter().zip(needle.iter()) {
        match (haystack_list, needle_list) {
            (Some(haystack_list), Some(needle_list)) => {
                let haystack_rows = row_converter.convert_columns(&[haystack_list])?;
                let needle_rows = row_converter.convert_columns(&[needle_list])?;
                builder.append_value(general_array_has_all_and_any_kernel(
                    haystack_rows,
                    needle_rows,
                    comparison_type,
                ));
            }
            _ => builder.append_null(),
        }
    }
    Ok(Arc::new(builder.finish()))
}
/// Checks whether all (`All`) or any (`Any`) of `needle_rows` occur in
/// `haystack_rows`, using row-format equality. Evaluation short-circuits.
fn general_array_has_all_and_any_kernel(
    haystack_rows: Rows,
    needle_rows: Rows,
    comparison_type: ComparisonType,
) -> bool {
    let mut needle_found = needle_rows.iter().map(|needle_row| {
        haystack_rows
            .iter()
            .any(|haystack_row| haystack_row == needle_row)
    });
    match comparison_type {
        ComparisonType::All => needle_found.all(|found| found),
        ComparisonType::Any => needle_found.any(|found| found),
    }
}