// datafusion_physical_expr_common/datum.rs
use arrow::array::BooleanArray;
use arrow::array::{make_comparator, ArrayRef, Datum};
use arrow::buffer::NullBuffer;
use arrow::compute::SortOptions;
use arrow::error::ArrowError;
use datafusion_common::DataFusionError;
use datafusion_common::{arrow_datafusion_err, internal_err};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_expr_common::operator::Operator;
use std::sync::Arc;
/// Runs the binary kernel `f` over two [`ColumnarValue`]s.
///
/// Each operand is handed to `f` as a [`Datum`]: arrays are passed as-is and
/// scalars are first converted with `ScalarValue::to_scalar`.
///
/// # Returns
///
/// * `ColumnarValue::Array` holding the kernel output when at least one
///   operand is an array.
/// * `ColumnarValue::Scalar` built from element `0` of the kernel output when
///   both operands are scalars.
///
/// # Errors
///
/// Propagates any error from converting a scalar operand, from `f` itself, or
/// from reading the scalar result back out of the kernel output.
pub fn apply(
    lhs: &ColumnarValue,
    rhs: &ColumnarValue,
    f: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
) -> Result<ColumnarValue> {
    let evaluated = match (lhs, rhs) {
        // array <op> array
        (ColumnarValue::Array(l_arr), ColumnarValue::Array(r_arr)) => {
            let out = f(&l_arr.as_ref(), &r_arr.as_ref())?;
            ColumnarValue::Array(out)
        }
        // scalar <op> array
        (ColumnarValue::Scalar(l_val), ColumnarValue::Array(r_arr)) => {
            let l_datum = l_val.to_scalar()?;
            let out = f(&l_datum, &r_arr.as_ref())?;
            ColumnarValue::Array(out)
        }
        // array <op> scalar
        (ColumnarValue::Array(l_arr), ColumnarValue::Scalar(r_val)) => {
            let r_datum = r_val.to_scalar()?;
            let out = f(&l_arr.as_ref(), &r_datum)?;
            ColumnarValue::Array(out)
        }
        // scalar <op> scalar: collapse the kernel output back into a scalar
        (ColumnarValue::Scalar(l_val), ColumnarValue::Scalar(r_val)) => {
            let l_datum = l_val.to_scalar()?;
            let r_datum = r_val.to_scalar()?;
            let out = f(&l_datum, &r_datum)?;
            ColumnarValue::Scalar(ScalarValue::try_from_array(out.as_ref(), 0)?)
        }
    };
    Ok(evaluated)
}
pub fn apply_cmp(
lhs: &ColumnarValue,
rhs: &ColumnarValue,
f: impl Fn(&dyn Datum, &dyn Datum) -> Result<BooleanArray, ArrowError>,
) -> Result<ColumnarValue> {
apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?)))
}
/// Applies a comparison operator to two [`ColumnarValue`]s using
/// [`compare_op_for_nested`] as the kernel.
///
/// Accepted operators are `Eq`, `NotEq`, `Lt`, `Gt`, `LtEq`, `GtEq`,
/// `IsDistinctFrom` and `IsNotDistinctFrom`.
///
/// # Errors
///
/// Returns an internal error for any other operator, and propagates errors
/// raised by [`apply`] or [`compare_op_for_nested`].
pub fn apply_cmp_for_nested(
    op: Operator,
    lhs: &ColumnarValue,
    rhs: &ColumnarValue,
) -> Result<ColumnarValue> {
    match op {
        Operator::Eq
        | Operator::NotEq
        | Operator::Lt
        | Operator::Gt
        | Operator::LtEq
        | Operator::GtEq
        | Operator::IsDistinctFrom
        | Operator::IsNotDistinctFrom => apply(lhs, rhs, |left, right| {
            let mask = compare_op_for_nested(op, left, right)?;
            Ok(Arc::new(mask))
        }),
        _ => internal_err!("invalid operator for nested"),
    }
}
/// Tests two [`Datum`]s for equality.
///
/// When `is_nested` is `true` the comparison is delegated to
/// [`compare_op_for_nested`] with `Operator::Eq`; otherwise the arrow `eq`
/// comparison kernel is used.
///
/// # Errors
///
/// Errors from the arrow kernel are converted with `arrow_datafusion_err!`;
/// errors from the nested comparison are returned unchanged.
pub fn compare_with_eq(
    lhs: &dyn Datum,
    rhs: &dyn Datum,
    is_nested: bool,
) -> Result<BooleanArray> {
    if !is_nested {
        return arrow::compute::kernels::cmp::eq(lhs, rhs)
            .map_err(|e| arrow_datafusion_err!(e));
    }
    compare_op_for_nested(Operator::Eq, lhs, rhs)
}
/// Compares two [`Datum`]s element-wise with `op`, using a comparator built by
/// `make_comparator` (so it works for types the plain comparison kernels do
/// not cover).
///
/// Either side may be a scalar (a length-1 array flagged as scalar by
/// `Datum::get`), in which case it is compared against every row of the other
/// side.
///
/// # Null handling
///
/// * For `IsDistinctFrom` / `IsNotDistinctFrom` the result never contains
///   nulls; the comparator alone decides how null rows compare.
/// * For every other operator a row is null when either input row is null. If
///   a scalar operand is itself null, an all-null result is returned without
///   comparing anything.
///
/// # Errors
///
/// Returns an internal error when both sides are arrays of different lengths,
/// and propagates any error from building the comparator.
///
/// # Panics
///
/// Panics if `op` is not one of `Eq`, `NotEq`, `Lt`, `Gt`, `LtEq`, `GtEq`,
/// `IsDistinctFrom` or `IsNotDistinctFrom` and at least one row is compared.
pub fn compare_op_for_nested(
    op: Operator,
    lhs: &dyn Datum,
    rhs: &dyn Datum,
) -> Result<BooleanArray> {
    let (l, is_l_scalar) = lhs.get();
    let (r, is_r_scalar) = rhs.get();
    let l_len = l.len();
    let r_len = r.len();

    if l_len != r_len && !is_l_scalar && !is_r_scalar {
        return internal_err!("len mismatch");
    }

    // The output has as many rows as the non-scalar side.
    let len = if is_l_scalar { r_len } else { l_len };

    let is_distinct_op =
        matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom);

    // Fast path: comparing against a null scalar with a non-"distinct"
    // operator yields null for every row.
    if !is_distinct_op
        && (is_l_scalar && l.null_count() == 1 || is_r_scalar && r.null_count() == 1)
    {
        return Ok(BooleanArray::new_null(len));
    }

    let cmp = make_comparator(l, r, SortOptions::default())?;

    let cmp_with_op = |i, j| match op {
        Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(),
        Operator::Lt => cmp(i, j).is_lt(),
        Operator::Gt => cmp(i, j).is_gt(),
        Operator::LtEq => !cmp(i, j).is_gt(),
        Operator::GtEq => !cmp(i, j).is_lt(),
        Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(),
        _ => unreachable!("unexpected operator found"),
    };

    // A scalar side is always read at index 0.
    let values = match (is_l_scalar, is_r_scalar) {
        (false, false) => (0..len).map(|i| cmp_with_op(i, i)).collect(),
        (true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(),
        (false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(),
        (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(),
    };

    if is_distinct_op {
        // "Distinct" operators always produce a definite true/false.
        Ok(BooleanArray::new(values, None))
    } else {
        // A row is null when either input row is null. When exactly one side
        // is a scalar, the fast path above has already returned if that scalar
        // was null, so only the array side can contribute nulls. Its buffer is
        // used on its own: a scalar may still carry an all-valid null buffer
        // of length 1, whose length would not match the `len`-row output.
        let nulls = match (is_l_scalar, is_r_scalar) {
            (true, false) => r.nulls().cloned(),
            (false, true) => l.nulls().cloned(),
            _ => NullBuffer::union(l.nulls(), r.nulls()),
        };
        Ok(BooleanArray::new(values, nulls))
    }
}