use ahash::RandomState;
use datafusion_physical_expr_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
use std::collections::HashSet;
use std::ops::BitAnd;
use std::{fmt::Debug, sync::Arc};
use arrow::{
array::{ArrayRef, AsArray},
compute,
datatypes::{
DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
UInt16Type, UInt32Type, UInt64Type, UInt8Type,
},
};
use arrow::{
array::{Array, BooleanArray, Int64Array, PrimitiveArray},
buffer::BooleanBuffer,
};
use datafusion_common::{
downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
EmitTo, GroupsAccumulator, Signature, Volatility,
};
use datafusion_expr::{Expr, ReversedUDAF, TypeSignature};
use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
use datafusion_physical_expr_common::{
aggregate::count_distinct::{
BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
PrimitiveDistinctCountAccumulator,
},
binary_map::OutputType,
};
// Declares the expression-level helper and UDAF accessor for `Count`.
// `count_udaf()` (used by `count_distinct` below) is not defined elsewhere in
// this view, so it is presumably generated here together with a `count(expr)`
// helper. NOTE(review): confirm the exact generated items against the
// `make_udaf_expr_and_func!` macro definition.
make_udaf_expr_and_func!(
    Count,
    count,
    expr,
    "Count the number of non-null values in the column",
    count_udaf
);
/// Builds a `COUNT(DISTINCT expr)` aggregate expression over `expr`.
///
/// This is the same `count` UDAF returned by `count_udaf()`, with the
/// distinct flag set and no additional options.
pub fn count_distinct(expr: Expr) -> Expr {
    let is_distinct = true;
    let arguments = vec![expr];
    let aggregate = datafusion_expr::expr::AggregateFunction::new_udf(
        count_udaf(),
        arguments,
        is_distinct,
        None,
        None,
        None,
    );
    Expr::AggregateFunction(aggregate)
}
/// The `COUNT` aggregate function.
///
/// It counts rows whose argument(s) are non-null, with optional `DISTINCT`
/// support (see the `AggregateUDFImpl` implementation).
pub struct Count {
    /// Accepted argument shapes; fixed at construction time.
    signature: Signature,
}

impl Count {
    /// Creates a `Count` whose signature is either `VariadicAny` or `Any(0)`,
    /// with immutable volatility.
    pub fn new() -> Self {
        let accepted = vec![TypeSignature::VariadicAny, TypeSignature::Any(0)];
        let signature = Signature::one_of(accepted, Volatility::Immutable);
        Self { signature }
    }
}

impl Default for Count {
    fn default() -> Self {
        Count::new()
    }
}

impl Debug for Count {
    /// Renders the UDF name and signature.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut builder = f.debug_struct("Count");
        builder.field("name", &self.name());
        builder.field("signature", &self.signature);
        builder.finish()
    }
}
impl AggregateUDFImpl for Count {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn name(&self) -> &str {
"count"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int64)
}
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
if args.is_distinct {
Ok(vec![Field::new_list(
format_state_name(args.name, "count distinct"),
Field::new("item", args.input_types[0].clone(), true),
false,
)])
} else {
Ok(vec![Field::new(
format_state_name(args.name, "count"),
DataType::Int64,
true,
)])
}
}
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
if !acc_args.is_distinct {
return Ok(Box::new(CountAccumulator::new()));
}
if acc_args.input_exprs.len() > 1 {
return not_impl_err!("COUNT DISTINCT with multiple arguments");
}
let data_type = &acc_args.input_types[0];
Ok(match data_type {
DataType::Int8 => Box::new(
PrimitiveDistinctCountAccumulator::<Int8Type>::new(data_type),
),
DataType::Int16 => Box::new(
PrimitiveDistinctCountAccumulator::<Int16Type>::new(data_type),
),
DataType::Int32 => Box::new(
PrimitiveDistinctCountAccumulator::<Int32Type>::new(data_type),
),
DataType::Int64 => Box::new(
PrimitiveDistinctCountAccumulator::<Int64Type>::new(data_type),
),
DataType::UInt8 => Box::new(
PrimitiveDistinctCountAccumulator::<UInt8Type>::new(data_type),
),
DataType::UInt16 => Box::new(
PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
),
DataType::UInt32 => Box::new(
PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
),
DataType::UInt64 => Box::new(
PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
),
DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
Decimal128Type,
>::new(data_type)),
DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
Decimal256Type,
>::new(data_type)),
DataType::Date32 => Box::new(
PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
),
DataType::Date64 => Box::new(
PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
),
DataType::Time32(TimeUnit::Millisecond) => Box::new(
PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(
data_type,
),
),
DataType::Time32(TimeUnit::Second) => Box::new(
PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
),
DataType::Time64(TimeUnit::Microsecond) => Box::new(
PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(
data_type,
),
),
DataType::Time64(TimeUnit::Nanosecond) => Box::new(
PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
),
DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(
data_type,
),
),
DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(
data_type,
),
),
DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(
data_type,
),
),
DataType::Timestamp(TimeUnit::Second, _) => Box::new(
PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
),
DataType::Float16 => {
Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
}
DataType::Float32 => {
Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
}
DataType::Float64 => {
Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
}
DataType::Utf8 => {
Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
}
DataType::Utf8View => {
Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View))
}
DataType::LargeUtf8 => {
Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
}
DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
OutputType::Binary,
)),
DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new(
OutputType::BinaryView,
)),
DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
OutputType::Binary,
)),
_ => Box::new(DistinctCountAccumulator {
values: HashSet::default(),
state_data_type: data_type.clone(),
}),
})
}
fn aliases(&self) -> &[String] {
&[]
}
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
if args.is_distinct {
return false;
}
args.input_exprs.len() == 1
}
fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(CountGroupsAccumulator::new()))
}
fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::Identical
}
}
/// Row-at-a-time accumulator for plain (non-distinct) `COUNT(...)`.
///
/// A row is counted only when none of the argument columns is null at that
/// row, as computed by `null_count_for_multiple_cols`.
#[derive(Debug)]
struct CountAccumulator {
    /// Rows counted so far; `retract_batch` subtracts from it.
    count: i64,
}

impl CountAccumulator {
    /// Creates an accumulator whose count starts at zero.
    pub fn new() -> Self {
        CountAccumulator { count: 0 }
    }

    /// Number of rows in `values` where every argument column is non-null.
    fn counted_rows(values: &[ArrayRef]) -> i64 {
        let total_rows = values[0].len();
        let null_rows = null_count_for_multiple_cols(values);
        (total_rows - null_rows) as i64
    }
}

impl Accumulator for CountAccumulator {
    /// The intermediate state is the running count as one `Int64` scalar.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![ScalarValue::Int64(Some(self.count))])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        self.count += Self::counted_rows(values);
        Ok(())
    }

    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        self.count -= Self::counted_rows(values);
        Ok(())
    }

    /// Folds partial counts (an `Int64` array) into the running count.
    ///
    /// If `sum` yields `None`, the count is left unchanged.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let partial_counts = downcast_value!(states[0], Int64Array);
        if let Some(total) = arrow::compute::sum(partial_counts) {
            self.count += total;
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.count)))
    }

    fn supports_retract_batch(&self) -> bool {
        true
    }

    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }
}
/// Grouped accumulator for plain `COUNT`, keeping one running counter per group index.
#[derive(Debug)]
struct CountGroupsAccumulator {
    /// `counts[g]` is the running count for group `g`.
    counts: Vec<i64>,
}

impl CountGroupsAccumulator {
    /// Creates an accumulator that does not track any groups yet.
    pub fn new() -> Self {
        Self {
            counts: Vec::new(),
        }
    }
}
impl GroupsAccumulator for CountGroupsAccumulator {
fn update_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = &values[0];
self.counts.resize(total_num_groups, 0);
accumulate_indices(
group_indices,
values.logical_nulls().as_ref(),
opt_filter,
|group_index| {
self.counts[group_index] += 1;
},
);
Ok(())
}
fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "one argument to merge_batch");
let partial_counts = values[0].as_primitive::<Int64Type>();
assert_eq!(partial_counts.null_count(), 0);
let partial_counts = partial_counts.values();
self.counts.resize(total_num_groups, 0);
match opt_filter {
Some(filter) => filter
.iter()
.zip(group_indices.iter())
.zip(partial_counts.iter())
.for_each(|((filter_value, &group_index), partial_count)| {
if let Some(true) = filter_value {
self.counts[group_index] += partial_count;
}
}),
None => group_indices.iter().zip(partial_counts.iter()).for_each(
|(&group_index, partial_count)| {
self.counts[group_index] += partial_count;
},
),
}
Ok(())
}
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
let counts = emit_to.take_needed(&mut self.counts);
let nulls = None;
let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
Ok(Arc::new(array))
}
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let counts = emit_to.take_needed(&mut self.counts);
let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); Ok(vec![Arc::new(counts) as ArrayRef])
}
fn convert_to_state(
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
let values = &values[0];
let state_array = match (values.logical_nulls(), opt_filter) {
(None, None) => {
Arc::new(Int64Array::from_value(1, values.len()))
}
(Some(nulls), None) => {
let nulls = BooleanArray::new(nulls.into_inner(), None);
compute::cast(&nulls, &DataType::Int64)?
}
(None, Some(filter)) => {
let (filter_values, filter_nulls) = filter.clone().into_parts();
let state_buf = match filter_nulls {
Some(filter_nulls) => &filter_values & filter_nulls.inner(),
None => filter_values,
};
let boolean_state = BooleanArray::new(state_buf, None);
compute::cast(&boolean_state, &DataType::Int64)?
}
(Some(nulls), Some(filter)) => {
let (filter_values, filter_nulls) = filter.clone().into_parts();
let filter_buf = match filter_nulls {
Some(filter_nulls) => &filter_values & filter_nulls.inner(),
None => filter_values,
};
let state_buf = &filter_buf & nulls.inner();
let boolean_state = BooleanArray::new(state_buf, None);
compute::cast(&boolean_state, &DataType::Int64)?
}
};
Ok(vec![state_array])
}
fn supports_convert_to_state(&self) -> bool {
true
}
fn size(&self) -> usize {
self.counts.capacity() * std::mem::size_of::<usize>()
}
}
/// Counts the rows that are null in at least one of the `values` columns.
///
/// For a single column, this is simply its logical null count. For several
/// columns, the logical validity bitmaps are ANDed together, and the number of
/// rows not valid in every column is returned.
///
/// # Panics
/// Panics if `values` is empty.
fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
    if values.len() <= 1 {
        return values[0]
            .logical_nulls()
            .map_or(0, |nulls| nulls.null_count());
    }

    // Columns without a validity bitmap are all-valid and do not constrain the AND.
    let all_valid: Option<BooleanBuffer> = values
        .iter()
        .filter_map(|column| column.logical_nulls())
        .map(|nulls| nulls.into_inner())
        .reduce(|combined, next| combined.bitand(&next));

    match all_valid {
        Some(valid_rows) => values[0].len() - valid_rows.count_set_bits(),
        None => 0,
    }
}
/// Fallback `COUNT(DISTINCT ...)` accumulator for types without a specialized one.
///
/// It stores each distinct value as a `ScalarValue`.
#[derive(Debug)]
struct DistinctCountAccumulator {
    /// Distinct values seen so far.
    values: HashSet<ScalarValue, RandomState>,
    /// Element type of the list emitted as intermediate state.
    state_data_type: DataType,
}

impl DistinctCountAccumulator {
    /// Bytes a stored value owns beyond its inline `ScalarValue` footprint.
    fn extra_bytes(value: &ScalarValue) -> usize {
        ScalarValue::size(value) - std::mem::size_of_val(value)
    }

    /// Portion of the estimate that does not depend on the stored values'
    /// contents: the struct, one `ScalarValue` slot per set capacity entry,
    /// and one `DataType`.
    fn base_size(&self) -> usize {
        let slot_bytes = std::mem::size_of::<ScalarValue>() * self.values.capacity();
        std::mem::size_of_val(self) + slot_bytes + std::mem::size_of::<DataType>()
    }

    /// Cheap estimate that adds the extra bytes of only one (arbitrary) stored
    /// value, if any.
    fn fixed_size(&self) -> usize {
        let one_value = self.values.iter().next().map_or(0, Self::extra_bytes);
        self.base_size() + one_value
    }

    /// Full estimate that adds the extra bytes of every stored value.
    fn full_size(&self) -> usize {
        let all_values: usize = self.values.iter().map(Self::extra_bytes).sum();
        self.base_size() + all_values
    }
}
impl Accumulator for DistinctCountAccumulator {
    /// Emits the distinct set as one list scalar whose element type is
    /// `state_data_type`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
        let arr =
            ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type);
        Ok(vec![ScalarValue::List(arr)])
    }

    /// Inserts every logically non-null value of the first argument into the set.
    ///
    /// Empty input and `Null`-typed arrays are no-ops.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let arr = &values[0];
        if arr.data_type() == &DataType::Null {
            return Ok(());
        }

        // `Array::is_null` only consults the physical validity buffer. That
        // misses logical nulls, e.g. a dictionary key that points at a null
        // value, which would then be counted as a distinct value. Use
        // `logical_nulls`, as the other COUNT accumulators in this file do.
        let logical_nulls = arr.logical_nulls();
        (0..arr.len()).try_for_each(|index| {
            let is_valid = logical_nulls
                .as_ref()
                .map_or(true, |buffer| buffer.is_valid(index));
            if is_valid {
                let scalar = ScalarValue::try_from_array(arr, index)?;
                self.values.insert(scalar);
            }
            Ok(())
        })
    }

    /// Merges list-encoded partial distinct sets produced by `state`.
    ///
    /// # Errors
    /// Returns an internal error if more than one state array is passed, or if
    /// any list entry is null.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        if states.len() != 1 {
            return internal_err!(
                "COUNT DISTINCT expects exactly one state array, got {}",
                states.len()
            );
        }

        let array = &states[0];
        let list_array = array.as_list::<i32>();
        for inner_array in list_array.iter() {
            let Some(inner_array) = inner_array else {
                return internal_err!(
                    "Intermediate results of COUNT DISTINCT should always be non null"
                );
            };
            self.update_batch(&[inner_array])?;
        }
        Ok(())
    }

    /// The result is the number of distinct values collected.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
    }

    /// Uses the cheaper fixed estimate for boolean, null and primitive types,
    /// and the full per-value estimate otherwise.
    fn size(&self) -> usize {
        match &self.state_data_type {
            DataType::Boolean | DataType::Null => self.fixed_size(),
            d if d.is_primitive() => self.fixed_size(),
            _ => self.full_size(),
        }
    }
}