use std::any::Any;
use std::collections::VecDeque;
use std::sync::Arc;
use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray};
use arrow_schema::{DataType, Field, Fields};
use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx};
use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Expr, ReversedUDAF, Signature, Volatility,
};
use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays;
use datafusion_physical_expr_common::aggregate::utils::ordering_fields;
use datafusion_physical_expr_common::sort_expr::{
limited_convert_logical_sort_exprs_to_physical_with_dfschema, LexOrdering,
PhysicalSortExpr,
};
// Generates the `nth_value` expression-builder function and the
// `nth_value_udaf` accessor for `NthValueAgg` (macro defined elsewhere
// in this crate).
make_udaf_expr_and_func!(
    NthValueAgg,
    nth_value,
    "Returns the nth value in a group of values.",
    nth_value_udaf
);
/// Aggregate UDF implementing `NTH_VALUE(expr, n)`: returns the n-th value
/// of `expr` within a group (1-based; see `NthValueAccumulator`).
#[derive(Debug)]
pub struct NthValueAgg {
    /// Accepted call signatures: any two arguments (the value expression and
    /// the `n` index), immutable volatility.
    signature: Signature,
}
impl NthValueAgg {
    /// Creates a new `NthValueAgg` accepting exactly two arguments of any
    /// type (the value expression and the `n` index literal).
    pub fn new() -> Self {
        let signature = Signature::any(2, Volatility::Immutable);
        Self { signature }
    }
}
impl Default for NthValueAgg {
    /// Delegates to [`NthValueAgg::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl AggregateUDFImpl for NthValueAgg {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"nth_value"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(arg_types[0].clone())
}
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let n = match acc_args.input_exprs[1] {
Expr::Literal(ScalarValue::Int64(Some(value))) => {
if acc_args.is_reversed {
Ok(-value)
} else {
Ok(value)
}
}
_ => not_impl_err!(
"{} not supported for n: {}",
self.name(),
&acc_args.input_exprs[1]
),
}?;
let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema(
acc_args.sort_exprs,
acc_args.dfschema,
)?;
let ordering_dtypes = ordering_req
.iter()
.map(|e| e.expr.data_type(acc_args.schema))
.collect::<Result<Vec<_>>>()?;
NthValueAccumulator::try_new(
n,
&acc_args.input_types[0],
&ordering_dtypes,
ordering_req,
)
.map(|acc| Box::new(acc) as _)
}
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let mut fields = vec![Field::new_list(
format_state_name(self.name(), "nth_value"),
Field::new("item", args.input_types[0].clone(), true),
false,
)];
let orderings = args.ordering_fields.to_vec();
if !orderings.is_empty() {
fields.push(Field::new_list(
format_state_name(self.name(), "nth_value_orderings"),
Field::new("item", DataType::Struct(Fields::from(orderings)), true),
false,
));
}
Ok(fields)
}
fn aliases(&self) -> &[String] {
&[]
}
fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::Reversed(nth_value_udaf())
}
}
/// Accumulator state for `NTH_VALUE(expr, n)`.
#[derive(Debug)]
pub struct NthValueAccumulator {
    /// Target index, 1-based. Positive counts from the start of the group,
    /// negative counts from the end. Never zero (enforced in `try_new`).
    n: i64,
    /// Buffered candidate values for the result.
    values: VecDeque<ScalarValue>,
    /// Ordering-column values for each buffered row; kept in lock-step with
    /// `values` (same length, same order).
    ordering_values: VecDeque<Vec<ScalarValue>>,
    /// `datatypes[0]` is the value type; `datatypes[1..]` are the ordering
    /// column types.
    datatypes: Vec<DataType>,
    /// Ordering requirement of the aggregation (may be empty).
    ordering_req: LexOrdering,
}
impl NthValueAccumulator {
    /// Builds an accumulator for `NTH_VALUE(expr, n)`.
    ///
    /// `n` is 1-based; a negative `n` counts from the end of the group.
    /// Returns an internal error when `n == 0`, which is not a valid index.
    pub fn try_new(
        n: i64,
        datatype: &DataType,
        ordering_dtypes: &[DataType],
        ordering_req: LexOrdering,
    ) -> Result<Self> {
        if n == 0 {
            return internal_err!("Nth value indices are 1 based. 0 is invalid index");
        }
        // Value type first, then the ordering column types.
        let datatypes: Vec<DataType> = std::iter::once(datatype.clone())
            .chain(ordering_dtypes.iter().cloned())
            .collect();
        Ok(Self {
            n,
            values: VecDeque::new(),
            ordering_values: VecDeque::new(),
            datatypes,
            ordering_req,
        })
    }
}
impl Accumulator for NthValueAccumulator {
    /// Folds a batch of input rows into the buffers, keeping only as many
    /// rows as can still influence the final result.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }
        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        if from_start {
            // Positive `n`: only the first `n_required` rows can be the
            // answer, so cap how many more rows are appended.
            let n_remaining = n_required.saturating_sub(self.values.len());
            self.append_new_data(values, Some(n_remaining))?;
        } else {
            // Negative `n` (counting from the end): append everything, then
            // drop rows from the front so only the last `n_required` remain.
            self.append_new_data(values, None)?;
            let start_offset = self.values.len().saturating_sub(n_required);
            if start_offset > 0 {
                self.values.drain(0..start_offset);
                self.ordering_values.drain(0..start_offset);
            }
        }
        Ok(())
    }
    /// Merges partial states produced by `state()` on other partitions.
    /// `states[0]` is the list of buffered values; `states[1]` (present only
    /// when an ordering requirement exists) is the list of ordering structs.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        let array_agg_values = &states[0];
        let n_required = self.n.unsigned_abs() as usize;
        if self.ordering_req.is_empty() {
            // No ordering: concatenate partition buffers, stopping early once
            // more than `n_required` values have accumulated.
            let array_agg_res =
                ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
            for v in array_agg_res.into_iter() {
                self.values.extend(v);
                if self.values.len() > n_required {
                    break;
                }
            }
        } else if let Some(agg_orderings) = states[1].as_list_opt::<i32>() {
            // Ordered case: collect this accumulator's buffer plus every
            // incoming partition, then merge them in sorted order.
            let mut partition_values: Vec<VecDeque<ScalarValue>> = vec![];
            let mut partition_ordering_values: Vec<VecDeque<Vec<ScalarValue>>> = vec![];
            // Slot 0 in both vectors is the current (self) state.
            partition_values.push(self.values.clone());
            partition_ordering_values.push(self.ordering_values.clone());
            let array_agg_res =
                ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
            for v in array_agg_res.into_iter() {
                partition_values.push(v.into());
            }
            let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
            // Each ordering row is a struct of ordering-column scalars;
            // unpack it into a plain Vec<ScalarValue> per row.
            let ordering_values = orderings.into_iter().map(|partition_ordering_rows| {
                partition_ordering_rows.into_iter().map(|ordering_row| {
                    if let ScalarValue::Struct(s) = ordering_row {
                        let mut ordering_columns_per_row = vec![];
                        for column in s.columns() {
                            let sv = ScalarValue::try_from_array(column, 0)?;
                            ordering_columns_per_row.push(sv);
                        }
                        Ok(ordering_columns_per_row)
                    } else {
                        exec_err!(
                            "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
                            ordering_row.data_type()
                        )
                    }
                }).collect::<Result<Vec<_>>>()
            }).collect::<Result<Vec<_>>>()?;
            for ordering_values in ordering_values.into_iter() {
                partition_ordering_values.push(ordering_values.into());
            }
            let sort_options = self
                .ordering_req
                .iter()
                .map(|sort_expr| sort_expr.options)
                .collect::<Vec<_>>();
            // k-way merge of all partitions according to the sort options;
            // replaces the current buffers with the merged result.
            let (new_values, new_orderings) = merge_ordered_arrays(
                &mut partition_values,
                &mut partition_ordering_values,
                &sort_options,
            )?;
            self.values = new_values.into();
            self.ordering_values = new_orderings.into();
        } else {
            return exec_err!("Expects to receive a list array");
        }
        Ok(())
    }
    /// Serializes the buffers for partial-aggregation exchange; mirrors the
    /// layout declared in `NthValueAgg::state_fields`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let mut result = vec![self.evaluate_values()];
        if !self.ordering_req.is_empty() {
            result.push(self.evaluate_orderings()?);
        }
        Ok(result)
    }
    /// Produces the final nth value, or a NULL of the value type when the
    /// group has fewer than `|n|` rows.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        let nth_value_idx = if from_start {
            // 1-based `n` -> 0-based index; `n_required >= 1` since n != 0.
            let forward_idx = n_required - 1;
            (forward_idx < self.values.len()).then_some(forward_idx)
        } else {
            // Counting from the end; None when fewer than n_required rows.
            self.values.len().checked_sub(n_required)
        };
        if let Some(idx) = nth_value_idx {
            Ok(self.values[idx].clone())
        } else {
            // Not enough rows: a ScalarValue built from the bare DataType is
            // the typed NULL.
            ScalarValue::try_from(self.datatypes[0].clone())
        }
    }
    /// Estimates heap + inline memory used by this accumulator. Each term
    /// subtracts the shallow size already counted by `size_of_val(self)` (or
    /// by the per-element multiplication) so nothing is counted twice.
    fn size(&self) -> usize {
        let mut total = std::mem::size_of_val(self)
            + ScalarValue::size_of_vec_deque(&self.values)
            - std::mem::size_of_val(&self.values);
        total +=
            std::mem::size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
        for row in &self.ordering_values {
            total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row);
        }
        total += std::mem::size_of::<DataType>() * self.datatypes.capacity();
        for dtype in &self.datatypes {
            total += dtype.size() - std::mem::size_of_val(dtype);
        }
        total += std::mem::size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
        total
    }
}
impl NthValueAccumulator {
    /// Packs the buffered ordering values into a single `ScalarValue::List`
    /// of structs, one struct entry per buffered row.
    fn evaluate_orderings(&self) -> Result<ScalarValue> {
        let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
        let struct_field = Fields::from(fields.clone());
        // Transpose the row-wise buffer into one Arrow array per ordering
        // column.
        let mut column_wise_ordering_values = Vec::with_capacity(fields.len());
        for (col_idx, field) in fields.iter().enumerate() {
            let column_values: Vec<ScalarValue> = self
                .ordering_values
                .iter()
                .map(|row| row[col_idx].clone())
                .collect();
            let array = if column_values.is_empty() {
                // iter_to_array cannot infer a type from zero scalars.
                new_empty_array(field.data_type())
            } else {
                ScalarValue::iter_to_array(column_values)?
            };
            column_wise_ordering_values.push(array);
        }
        let ordering_array =
            StructArray::try_new(struct_field, column_wise_ordering_values, None)?;
        Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable(
            Arc::new(ordering_array),
        ))))
    }
    /// Returns the buffered values as a `ScalarValue::List` of the value type.
    fn evaluate_values(&self) -> ScalarValue {
        let buffered: Vec<ScalarValue> = self.values.iter().cloned().collect();
        ScalarValue::List(ScalarValue::new_list_nullable(
            &buffered,
            &self.datatypes[0],
        ))
    }
    /// Appends up to `fetch` rows (all rows when `fetch` is `None`) from
    /// `values` to the internal buffers.
    ///
    /// `values[0]` is the aggregated expression; index 1 is skipped
    /// (presumably the `n` argument column — confirm against the caller);
    /// `values[2..]` are the ordering expressions.
    fn append_new_data(
        &mut self,
        values: &[ArrayRef],
        fetch: Option<usize>,
    ) -> Result<()> {
        let n_row = values[0].len();
        let n_to_add = fetch.map_or(n_row, |limit| limit.min(n_row));
        for row_idx in 0..n_to_add {
            let row = get_row_at_idx(values, row_idx)?;
            self.values.push_back(row[0].clone());
            self.ordering_values.push_back(row[2..].to_vec());
        }
        Ok(())
    }
}