use std::cmp::Ordering;
use std::fmt::{Debug, Formatter};
use std::mem::{size_of, size_of_val};
use std::sync::Arc;
use arrow::array::{
downcast_integer, ArrowNumericType, BooleanArray, ListArray, PrimitiveArray,
PrimitiveBuilder,
};
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow::{
array::{ArrayRef, AsArray},
datatypes::{
DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type,
Float64Type,
},
};
use arrow::array::Array;
use arrow::array::ArrowNativeTypeOp;
use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType};
use datafusion_common::{
internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue,
};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
Documentation, Signature, Volatility,
};
use datafusion_expr::{EmitTo, GroupsAccumulator};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datafusion_functions_aggregate_common::utils::Hashable;
use datafusion_macros::user_doc;
// Expands `make_udaf_expr_and_func!` for [`Median`]: by its arguments this yields a
// `median(expression)` expression builder and a `median_udaf` accessor for the UDF
// (see the macro definition for exactly what is generated).
make_udaf_expr_and_func!(
Median,
median,
expression,
"Computes the median of a set of numbers",
median_udaf
);
#[user_doc(
doc_section(label = "General Functions"),
description = "Returns the median value in the specified column.",
syntax_example = "median(expression)",
sql_example = r#"```sql
> SELECT median(column_name) FROM table_name;
+----------------------+
| median(column_name) |
+----------------------+
| 45.5 |
+----------------------+
```"#,
standard_argument(name = "expression", prefix = "The")
)]
/// The `median` aggregate UDF: computes the exact median of a numeric column.
///
/// The non-distinct accumulators buffer every accumulated input value in memory until
/// evaluation, so memory grows with the number of input rows. `median(DISTINCT ...)`
/// instead buffers each distinct value once.
///
/// The result has the same data type as the input. For an even number of values, the
/// average of the two middle values is therefore computed in the input type.
pub struct Median {
// Built by `Median::new` as `Signature::numeric(1, Volatility::Immutable)`.
signature: Signature,
}
impl Debug for Median {
    /// Shows the UDF name and its signature.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let mut out = f.debug_struct("Median");
        out.field("name", &self.name());
        out.field("signature", &self.signature);
        out.finish()
    }
}
impl Default for Median {
fn default() -> Self {
Self::new()
}
}
impl Median {
    /// Creates the `median` UDF.
    ///
    /// The UDF accepts exactly one numeric argument and is immutable.
    pub fn new() -> Self {
        let signature = Signature::numeric(1, Volatility::Immutable);
        Self { signature }
    }
}
impl AggregateUDFImpl for Median {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "median"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The median is reported in the same data type as its single input.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        let input_type = &arg_types[0];
        Ok(input_type.clone())
    }

    /// One nullable `List<input type>` state column holding the buffered values.
    ///
    /// The state is named `distinct_median` for `median(DISTINCT ...)` and `median`
    /// otherwise.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
        let state_name = match args.is_distinct {
            true => "distinct_median",
            false => "median",
        };
        let item_field = Field::new_list_field(args.input_types[0].clone(), true);
        let list_type = DataType::List(Arc::new(item_field));
        let state_field =
            Field::new(format_state_name(args.name, state_name), list_type, true);
        Ok(vec![state_field])
    }

    /// Builds a row accumulator for the input's primitive type.
    ///
    /// Returns a distinct or plain accumulator depending on `is_distinct`.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let dt = acc_args.exprs[0].data_type(acc_args.schema)?;

        // Instantiates the distinct or plain accumulator for primitive type `$t`.
        macro_rules! helper {
            ($t:ty, $dt:expr) => {{
                let data_type: DataType = $dt.clone();
                let accumulator: Box<dyn Accumulator> = if acc_args.is_distinct {
                    Box::new(DistinctMedianAccumulator::<$t> {
                        data_type,
                        distinct_values: HashSet::new(),
                    })
                } else {
                    Box::new(MedianAccumulator::<$t> {
                        data_type,
                        all_values: vec![],
                    })
                };
                Ok(accumulator)
            }};
        }

        downcast_integer! {
            dt => (helper, dt),
            DataType::Float16 => helper!(Float16Type, dt),
            DataType::Float32 => helper!(Float32Type, dt),
            DataType::Float64 => helper!(Float64Type, dt),
            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
            _ => Err(DataFusionError::NotImplemented(format!(
                "MedianAccumulator not supported for {} with {}",
                acc_args.name,
                dt,
            ))),
        }
    }

    /// Only the non-distinct variant has a groups accumulator.
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        let distinct = args.is_distinct;
        !distinct
    }

    /// Builds a `MedianGroupsAccumulator` for the input's primitive type.
    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        if args.exprs.len() != 1 {
            return internal_err!(
                "median should only have 1 arg, but found num args:{}",
                args.exprs.len()
            );
        }
        let dt = args.exprs[0].data_type(args.schema)?;

        // Wraps a `MedianGroupsAccumulator` for primitive type `$t`.
        macro_rules! helper {
            ($t:ty, $dt:expr) => {
                Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
            };
        }

        downcast_integer! {
            dt => (helper, dt),
            DataType::Float16 => helper!(Float16Type, dt),
            DataType::Float32 => helper!(Float32Type, dt),
            DataType::Float64 => helper!(Float64Type, dt),
            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
            _ => Err(DataFusionError::NotImplemented(format!(
                "MedianGroupsAccumulator not supported for {} with {}",
                args.name,
                dt,
            ))),
        }
    }

    fn aliases(&self) -> &[String] {
        &[]
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
/// Exact `median` accumulator for a single group.
///
/// Buffers every non-null input value of primitive type `T` and computes the median
/// when `evaluate` is called.
struct MedianAccumulator<T: ArrowNumericType> {
// Data type of the input column; used when materializing state and result scalars.
data_type: DataType,
// All non-null values received so far, through both `update_batch` and `merge_batch`.
all_values: Vec<T::Native>,
}
impl<T: ArrowNumericType> Debug for MedianAccumulator<T> {
    /// Prints only the data type, not the (possibly large) buffered values.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("MedianAccumulator({})", self.data_type))
    }
}
impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
    /// Serializes all buffered values as a single list scalar.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let mut scalars = Vec::with_capacity(self.all_values.len());
        for &value in &self.all_values {
            scalars.push(ScalarValue::new_primitive::<T>(Some(value), &self.data_type)?);
        }
        let list = ScalarValue::new_list_nullable(&scalars, &self.data_type);
        Ok(vec![ScalarValue::List(list)])
    }

    /// Appends the non-null values of the single input column to the buffer.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let input = values[0].as_primitive::<T>();
        let valid_count = input.len() - input.null_count();
        self.all_values.reserve(valid_count);
        for value in input.iter().flatten() {
            self.all_values.push(value);
        }
        Ok(())
    }

    /// Folds each non-null list from a partial state back in via `update_batch`.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let lists = states[0].as_list::<i32>();
        lists
            .iter()
            .flatten()
            .try_for_each(|list| self.update_batch(&[list]))
    }

    /// Computes the median of the buffered values.
    ///
    /// The buffer is moved out, so the accumulator is empty afterwards.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let buffered = std::mem::take(&mut self.all_values);
        ScalarValue::new_primitive::<T>(calculate_median::<T>(buffered), &self.data_type)
    }

    /// Size of the accumulator itself plus the allocated capacity of the value buffer.
    fn size(&self) -> usize {
        let buffered_bytes = self.all_values.capacity() * size_of::<T::Native>();
        size_of_val(self) + buffered_bytes
    }
}
/// Exact `median` groups accumulator: keeps a separate buffer of values per group.
#[derive(Debug)]
struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
// Data type of the input column; applied to the state and result arrays.
data_type: DataType,
// `group_values[g]` holds the values accumulated so far for group index `g`.
group_values: Vec<Vec<T::Native>>,
}
impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
    /// Creates an accumulator with no groups.
    ///
    /// Per-group buffers are allocated as batches reference new group indices.
    pub fn new(data_type: DataType) -> Self {
        let group_values = Vec::new();
        Self {
            data_type,
            group_values,
        }
    }
}
impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
    /// Appends each input value to the buffer of its group.
    ///
    /// Null and filtered-out rows are presumably skipped by `accumulate`.
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "single argument to update_batch");
        let values = values[0].as_primitive::<T>();

        // Make room for any groups first seen in this batch.
        self.group_values.resize(total_num_groups, Vec::new());
        accumulate(
            group_indices,
            values,
            opt_filter,
            |group_index, new_value| {
                self.group_values[group_index].push(new_value);
            },
        );

        Ok(())
    }

    /// Merges partial states into the group buffers.
    ///
    /// The partial states are one list of values per input row, as produced by `state`
    /// or `convert_to_state`. Null lists are skipped.
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        _opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "one argument to merge_batch");
        let input_group_values = values[0].as_list::<i32>();

        self.group_values.resize(total_num_groups, Vec::new());
        group_indices
            .iter()
            .zip(input_group_values.iter())
            .for_each(|(&group_index, values_opt)| {
                if let Some(values) = values_opt {
                    let values = values.as_primitive::<T>();
                    self.group_values[group_index].extend(values.values().iter());
                }
            });

        Ok(())
    }

    /// Emits the selected groups' buffers as a `List` array, one entry per group.
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let emit_group_values = emit_to.take_needed(&mut self.group_values);

        // Offsets are derived from the *emitted* groups. `from_lengths` panics if the
        // total number of values does not fit in an `i32` offset, instead of silently
        // wrapping around.
        let offsets = OffsetBuffer::from_lengths(emit_group_values.iter().map(Vec::len));

        // `concat` allocates the flattened buffer once, at its exact final size.
        let flatten_group_values = emit_group_values.concat();
        let group_values_array =
            PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
                .with_data_type(self.data_type.clone());

        let result_list_array = ListArray::new(
            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
            offsets,
            Arc::new(group_values_array),
            None,
        );

        Ok(vec![Arc::new(result_list_array)])
    }

    /// Computes one median per emitted group.
    ///
    /// Groups with no values produce null.
    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        let emit_group_values = emit_to.take_needed(&mut self.group_values);

        let mut evaluate_result_builder =
            PrimitiveBuilder::<T>::with_capacity(emit_group_values.len())
                .with_data_type(self.data_type.clone());
        for values in emit_group_values {
            let median = calculate_median::<T>(values);
            evaluate_result_builder.append_option(median);
        }

        Ok(Arc::new(evaluate_result_builder.finish()))
    }

    /// Converts raw input rows directly into partial state without grouping.
    ///
    /// Each row becomes a one-element list.
    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> Result<Vec<ArrayRef>> {
        assert_eq!(values.len(), 1, "one argument to convert_to_state");
        let input_array = values[0].as_primitive::<T>();

        // The value buffer is reused as-is, without its null mask. Null/filtered rows are
        // expected to become null list entries via `filtered_null_mask` (TODO confirm it
        // folds in input nulls), and `merge_batch` skips null lists.
        let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
            .with_data_type(self.data_type.clone());

        let offset_end = i32::try_from(input_array.len()).map_err(|e| {
            internal_datafusion_err!(
                "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
            )
        })?;
        let offsets = (0..=offset_end).collect::<Vec<_>>();
        // SAFETY: the offsets are 0, 1, ..., len. They start at zero, increase
        // monotonically, and the last offset equals `values.len()`, so the buffer is a
        // valid offset buffer for `values`.
        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };

        let nulls = filtered_null_mask(opt_filter, input_array);
        let converted_list_array = ListArray::new(
            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
            offsets,
            Arc::new(values),
            nulls,
        );

        Ok(vec![Arc::new(converted_list_array)])
    }

    fn supports_convert_to_state(&self) -> bool {
        true
    }

    /// Approximate heap + inline footprint.
    ///
    /// Counts the accumulator itself, every group buffer's allocated capacity, and the
    /// outer vector of buffers. Buffer sizes use `T::Native`: `T` is a zero-sized Arrow
    /// type marker, and sizing by it would report ~0 bytes.
    fn size(&self) -> usize {
        let values_bytes: usize = self
            .group_values
            .iter()
            .map(|values| values.capacity() * size_of::<T::Native>())
            .sum();
        size_of_val(self)
            + values_bytes
            + self.group_values.capacity() * size_of::<Vec<T::Native>>()
    }
}
/// `median(DISTINCT ...)` accumulator: keeps each distinct non-null value once and
/// computes the median over that set.
struct DistinctMedianAccumulator<T: ArrowNumericType> {
// Data type of the input column; used when materializing state and result scalars.
data_type: DataType,
// Distinct values seen so far, each wrapped in `Hashable` for storage in the set.
distinct_values: HashSet<Hashable<T::Native>>,
}
impl<T: ArrowNumericType> Debug for DistinctMedianAccumulator<T> {
    /// Prints only the data type, not the set of distinct values.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("DistinctMedianAccumulator({})", self.data_type))
    }
}
impl<T: ArrowNumericType> Accumulator for DistinctMedianAccumulator<T> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let all_values = self
.distinct_values
.iter()
.map(|x| ScalarValue::new_primitive::<T>(Some(x.0), &self.data_type))
.collect::<Result<Vec<_>>>()?;
let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type);
Ok(vec![ScalarValue::List(arr)])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}
let array = values[0].as_primitive::<T>();
match array.nulls().filter(|x| x.null_count() > 0) {
Some(n) => {
for idx in n.valid_indices() {
self.distinct_values.insert(Hashable(array.value(idx)));
}
}
None => array.values().iter().for_each(|x| {
self.distinct_values.insert(Hashable(*x));
}),
}
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let array = states[0].as_list::<i32>();
for v in array.iter().flatten() {
self.update_batch(&[v])?
}
Ok(())
}
fn evaluate(&mut self) -> Result<ScalarValue> {
let d = std::mem::take(&mut self.distinct_values)
.into_iter()
.map(|v| v.0)
.collect::<Vec<_>>();
let median = calculate_median::<T>(d);
ScalarValue::new_primitive::<T>(median, &self.data_type)
}
fn size(&self) -> usize {
size_of_val(self) + self.distinct_values.capacity() * size_of::<T::Native>()
}
}
/// Returns the largest element of `array`.
///
/// Ordering uses the native type's total order (`ArrowNativeTypeOp::compare`). This is
/// the same order `calculate_median` uses to partition the values, so the result is
/// exactly the largest element of the lower partition. With `partial_cmp`, comparisons
/// involving NaN are undefined, and the chosen "max" would depend on the arbitrary
/// element order left by the partitioning step.
///
/// # Panics
///
/// Panics if `array` is empty. `calculate_median` only passes the lower partition of an
/// even-length input with at least two elements, which is never empty.
fn slice_max<T>(array: &[T::Native]) -> T::Native
where
    T: ArrowPrimitiveType,
    T::Native: PartialOrd,
{
    debug_assert!(!array.is_empty());
    let total_cmp = |x: &&T::Native, y: &&T::Native| -> Ordering { (**x).compare(**y) };
    *array
        .iter()
        .max_by(total_cmp)
        .expect("slice_max requires a non-empty slice")
}
/// Computes the exact median of `values`, or `None` if `values` is empty.
///
/// Values are ordered with the native type's total order (`ArrowNativeTypeOp::compare`).
/// The middle element(s) are located with `select_nth_unstable_by`, which reorders
/// `values` in place, instead of a full sort.
///
/// For an even count, the result is the mean of the two middle values computed in the
/// native type, so integer and decimal results are truncated toward zero. The sum of
/// the two middle values is overflow-checked. For example, the median of
/// `[i32::MAX, i32::MAX]` is `i32::MAX` rather than a wrapped-around value.
fn calculate_median<T: ArrowNumericType>(
    mut values: Vec<T::Native>,
) -> Option<T::Native> {
    let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
    let len = values.len();
    if len == 0 {
        None
    } else if len % 2 == 0 {
        let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
        // The lower middle value is the largest element of the lower partition.
        let left_max = slice_max::<T>(low);
        let right_min = *high;
        let two = T::Native::usize_as(2);
        let median = match left_max.add_checked(right_min) {
            // Common case: the sum fits. Arrow's float `add_checked` does not report
            // overflow, so float inputs always take this branch.
            Ok(sum) => sum.div_wrapping(two),
            // Integer overflow implies both operands have the same sign. For same-signed
            // operands, `a/2 + b/2 + (a%2 + b%2)/2` equals the truncated mean exactly
            // and cannot overflow.
            Err(_) => {
                let halves = left_max
                    .div_wrapping(two)
                    .add_wrapping(right_min.div_wrapping(two));
                let carry = left_max
                    .mod_wrapping(two)
                    .add_wrapping(right_min.mod_wrapping(two))
                    .div_wrapping(two);
                halves.add_wrapping(carry)
            }
        };
        Some(median)
    } else {
        let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
        Some(*median)
    }
}