// datafusion_functions_aggregate/nth_value.rs
use std::any::Any;
use std::collections::VecDeque;
use std::mem::{size_of, size_of_val};
use std::sync::{Arc, OnceLock};
use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray};
use arrow_schema::{DataType, Field, Fields};
use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx};
use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue};
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
lit, Accumulator, AggregateUDFImpl, Documentation, ExprFunctionExt, ReversedUDAF,
Signature, SortExpr, Volatility,
};
use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
use datafusion_functions_aggregate_common::utils::ordering_fields;
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
// Registers `NthValueAgg` as a user-defined aggregate and provides the
// `nth_value_udaf()` accessor used below (macro defined elsewhere in this crate).
create_func!(NthValueAgg, nth_value_udaf);
/// Build an expression computing the n-th value of `expr` within a group,
/// optionally ordered by `order_by`.
///
/// `n` is 1-based; sign handling (negative `n` counts from the end) is done
/// by the accumulator, not here.
pub fn nth_value(
    expr: datafusion_expr::Expr,
    n: i64,
    order_by: Vec<SortExpr>,
) -> datafusion_expr::Expr {
    let call_args = vec![expr, lit(n)];
    if order_by.is_empty() {
        nth_value_udaf().call(call_args)
    } else {
        // Attach the ordering requirement via the expression builder.
        nth_value_udaf()
            .call(call_args)
            .order_by(order_by)
            .build()
            .unwrap()
    }
}
/// `NTH_VALUE` user-defined aggregate function: returns the n-th value of the
/// aggregated expression within each group (see `NthValueAccumulator`).
#[derive(Debug)]
pub struct NthValueAgg {
    /// Accepts any two arguments: the value expression and the literal `n`.
    signature: Signature,
}
impl NthValueAgg {
    /// Create a new `NthValueAgg` accepting exactly two arguments of any type
    /// (the target expression and the 1-based index `n`).
    pub fn new() -> Self {
        let signature = Signature::any(2, Volatility::Immutable);
        Self { signature }
    }
}
impl Default for NthValueAgg {
fn default() -> Self {
Self::new()
}
}
impl AggregateUDFImpl for NthValueAgg {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "nth_value"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result type matches the type of the first argument (the value
    /// expression).
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        // `n` (second argument) must be an Int64 literal; anything else is
        // rejected as unsupported.
        let n = match acc_args.exprs[1]
            .as_any()
            .downcast_ref::<Literal>()
            .map(|lit| lit.value())
        {
            Some(ScalarValue::Int64(Some(value))) => {
                if acc_args.is_reversed {
                    // When evaluated over reversed input, the n-th from the
                    // start becomes the n-th from the end, so flip the sign
                    // (negative `n` means "count from the end" in the
                    // accumulator).
                    -*value
                } else {
                    *value
                }
            }
            _ => {
                return not_impl_err!(
                    "{} not supported for n: {}",
                    self.name(),
                    &acc_args.exprs[1]
                )
            }
        };
        // Data types of the ORDER BY expressions; the accumulator tracks
        // these alongside the values so partial states can be merged in order.
        let ordering_dtypes = acc_args
            .ordering_req
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;
        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
        NthValueAccumulator::try_new(
            n,
            &data_type,
            &ordering_dtypes,
            LexOrdering::from_ref(acc_args.ordering_req),
        )
        .map(|acc| Box::new(acc) as _)
    }

    /// Intermediate state: a list of buffered values, plus — only when an
    /// ordering requirement exists — a parallel list of ordering-key structs.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
        let mut fields = vec![Field::new_list(
            format_state_name(self.name(), "nth_value"),
            Field::new("item", args.input_types[0].clone(), true),
            false,
        )];
        let orderings = args.ordering_fields.to_vec();
        if !orderings.is_empty() {
            fields.push(Field::new_list(
                format_state_name(self.name(), "nth_value_orderings"),
                Field::new("item", DataType::Struct(Fields::from(orderings)), true),
                false,
            ));
        }
        Ok(fields)
    }

    fn aliases(&self) -> &[String] {
        &[]
    }

    /// The reverse of `nth_value` is `nth_value` itself; the sign of `n` is
    /// flipped in `accumulator` (via `is_reversed`) to account for the
    /// reversed input order.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Reversed(nth_value_udaf())
    }

    fn documentation(&self) -> Option<&Documentation> {
        Some(get_nth_value_doc())
    }
}
// Cached documentation instance, built once on first access.
static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();

/// Lazily builds (and caches) the user-facing documentation for `nth_value`.
fn get_nth_value_doc() -> &'static Documentation {
    DOCUMENTATION.get_or_init(|| {
        Documentation::builder()
            .with_doc_section(DOC_SECTION_STATISTICAL)
            .with_description(
                "Returns the nth value in a group of values.",
            )
            .with_syntax_example("nth_value(expression, n ORDER BY expression)")
            .with_sql_example(r#"```sql
> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept
FROM employee;
+---------+--------+-------------------------+
| dept_id | salary | second_salary_by_dept |
+---------+--------+-------------------------+
| 1 | 30000 | NULL |
| 1 | 40000 | 40000 |
| 1 | 50000 | 40000 |
| 2 | 35000 | NULL |
| 2 | 45000 | 45000 |
+---------+--------+-------------------------+
```"#)
            .with_argument("expression", "The column or expression to retrieve the nth value from.")
            .with_argument("n", "The position (nth) of the value to retrieve, based on the ordering.")
            .build()
            .unwrap()
    })
}
/// Accumulator for `nth_value`: buffers the candidate values (and their
/// ordering keys) needed to produce the n-th element of a group.
#[derive(Debug)]
pub struct NthValueAccumulator {
    /// 1-based target index; positive counts from the start of the group,
    /// negative from the end. Never 0 (rejected in `try_new`).
    n: i64,
    /// Buffered candidate values, in input/merged order.
    values: VecDeque<ScalarValue>,
    /// For each entry in `values`, the corresponding ORDER BY key values
    /// (one scalar per ordering expression; empty when there is no ordering).
    ordering_values: VecDeque<Vec<ScalarValue>>,
    /// `datatypes[0]` is the value type; `datatypes[1..]` are the ordering
    /// expression types.
    datatypes: Vec<DataType>,
    /// Ordering requirement, used when merging partial states.
    ordering_req: LexOrdering,
}
impl NthValueAccumulator {
    /// Create an accumulator for the `n`-th value of inputs of type
    /// `datatype`, ordered by expressions with types `ordering_dtypes`.
    ///
    /// Returns an internal error when `n == 0`, since indices are 1-based.
    pub fn try_new(
        n: i64,
        datatype: &DataType,
        ordering_dtypes: &[DataType],
        ordering_req: LexOrdering,
    ) -> Result<Self> {
        if n == 0 {
            return internal_err!("Nth value indices are 1 based. 0 is invalid index");
        }
        // Value type first, then the ordering types.
        let datatypes = std::iter::once(datatype.clone())
            .chain(ordering_dtypes.iter().cloned())
            .collect();
        Ok(Self {
            n,
            values: VecDeque::new(),
            ordering_values: VecDeque::new(),
            datatypes,
            ordering_req,
        })
    }
}
impl Accumulator for NthValueAccumulator {
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }
        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        if from_start {
            // Counting from the start: only the first `n_required` values can
            // ever be the answer, so cap how many rows are appended.
            let n_remaining = n_required.saturating_sub(self.values.len());
            self.append_new_data(values, Some(n_remaining))?;
        } else {
            // Counting from the end: append everything, then discard all but
            // the trailing `n_required` entries (and their ordering keys).
            self.append_new_data(values, None)?;
            let start_offset = self.values.len().saturating_sub(n_required);
            if start_offset > 0 {
                self.values.drain(0..start_offset);
                self.ordering_values.drain(0..start_offset);
            }
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        // states[0]: list of buffered values from each partial accumulator
        // (layout defined by `state_fields` / `state`).
        let array_agg_values = &states[0];
        let n_required = self.n.unsigned_abs() as usize;
        if self.ordering_req.is_empty() {
            // No ordering requirement: concatenate partial value lists,
            // stopping once we hold more than `n_required` candidates.
            let array_agg_res =
                ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
            for v in array_agg_res.into_iter() {
                self.values.extend(v);
                if self.values.len() > n_required {
                    break;
                }
            }
        } else if let Some(agg_orderings) = states[1].as_list_opt::<i32>() {
            // states[1]: lists of ordering-key structs parallel to states[0].
            // Collect this accumulator's buffered data plus every partial
            // state, then merge them all respecting the ordering requirement.
            let mut partition_values: Vec<VecDeque<ScalarValue>> = vec![];
            let mut partition_ordering_values: Vec<VecDeque<Vec<ScalarValue>>> = vec![];
            // First partition is the locally accumulated data.
            partition_values.push(self.values.clone());
            partition_ordering_values.push(self.ordering_values.clone());
            let array_agg_res =
                ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
            for v in array_agg_res.into_iter() {
                partition_values.push(v.into());
            }
            let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
            // Unpack each ordering struct row back into a Vec with one scalar
            // per ORDER BY column.
            let ordering_values = orderings.into_iter().map(|partition_ordering_rows| {
                partition_ordering_rows.into_iter().map(|ordering_row| {
                    if let ScalarValue::Struct(s) = ordering_row {
                        let mut ordering_columns_per_row = vec![];
                        for column in s.columns() {
                            let sv = ScalarValue::try_from_array(column, 0)?;
                            ordering_columns_per_row.push(sv);
                        }
                        Ok(ordering_columns_per_row)
                    } else {
                        exec_err!(
                            "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
                            ordering_row.data_type()
                        )
                    }
                }).collect::<Result<Vec<_>>>()
            }).collect::<Result<Vec<_>>>()?;
            for ordering_values in ordering_values.into_iter() {
                partition_ordering_values.push(ordering_values.into());
            }
            // Per-column sort options drive the ordered merge below.
            let sort_options = self
                .ordering_req
                .iter()
                .map(|sort_expr| sort_expr.options)
                .collect::<Vec<_>>();
            let (new_values, new_orderings) = merge_ordered_arrays(
                &mut partition_values,
                &mut partition_ordering_values,
                &sort_options,
            )?;
            self.values = new_values.into();
            self.ordering_values = new_orderings.into();
        } else {
            return exec_err!("Expects to receive a list array");
        }
        Ok(())
    }

    /// State is `[values_list]`, or `[values_list, orderings_list]` when an
    /// ordering requirement exists (must mirror `state_fields`).
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let mut result = vec![self.evaluate_values()];
        if !self.ordering_req.is_empty() {
            result.push(self.evaluate_orderings()?);
        }
        Ok(result)
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        let nth_value_idx = if from_start {
            // `n` is 1-based; the buffer is 0-based.
            let forward_idx = n_required - 1;
            (forward_idx < self.values.len()).then_some(forward_idx)
        } else {
            // Negative `n`: count `n_required` back from the end; `None` when
            // there are too few values.
            self.values.len().checked_sub(n_required)
        };
        if let Some(idx) = nth_value_idx {
            Ok(self.values[idx].clone())
        } else {
            // Not enough values in the group: return a NULL of the value type.
            ScalarValue::try_from(self.datatypes[0].clone())
        }
    }

    /// Estimated in-memory size of this accumulator, in bytes. Inline sizes
    /// already counted via `size_of_val(self)` are subtracted to avoid
    /// double-counting.
    fn size(&self) -> usize {
        let mut total = size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values)
            - size_of_val(&self.values);
        total += size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
        for row in &self.ordering_values {
            total += ScalarValue::size_of_vec(row) - size_of_val(row);
        }
        total += size_of::<DataType>() * self.datatypes.capacity();
        for dtype in &self.datatypes {
            total += dtype.size() - size_of_val(dtype);
        }
        total += size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
        total
    }
}
impl NthValueAccumulator {
    /// Serialize the buffered ordering keys as a single-element list of
    /// structs (one struct field per ORDER BY expression), for `state()`.
    fn evaluate_orderings(&self) -> Result<ScalarValue> {
        let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]);
        let struct_field = Fields::from(fields.clone());
        // Transpose the row-oriented `ordering_values` into one Arrow array
        // per ordering column.
        let mut column_wise_ordering_values = vec![];
        let num_columns = fields.len();
        for i in 0..num_columns {
            let column_values = self
                .ordering_values
                .iter()
                .map(|x| x[i].clone())
                .collect::<Vec<_>>();
            let array = if column_values.is_empty() {
                // `iter_to_array` cannot infer a type from zero scalars.
                new_empty_array(fields[i].data_type())
            } else {
                ScalarValue::iter_to_array(column_values.into_iter())?
            };
            column_wise_ordering_values.push(array);
        }
        let ordering_array =
            StructArray::try_new(struct_field, column_wise_ordering_values, None)?;
        Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable(
            Arc::new(ordering_array),
        ))))
    }

    /// Serialize the buffered values as a list scalar, for `state()`.
    fn evaluate_values(&self) -> ScalarValue {
        // Clone so `make_contiguous` can yield a slice without mutating
        // `self`'s buffer layout (method takes `&self`).
        let mut values_cloned = self.values.clone();
        let values_slice = values_cloned.make_contiguous();
        ScalarValue::List(ScalarValue::new_list_nullable(
            values_slice,
            &self.datatypes[0],
        ))
    }

    /// Append up to `fetch` rows (all rows when `None`) from the input
    /// columns into the internal buffers.
    fn append_new_data(
        &mut self,
        values: &[ArrayRef],
        fetch: Option<usize>,
    ) -> Result<()> {
        let n_row = values[0].len();
        let n_to_add = if let Some(fetch) = fetch {
            std::cmp::min(fetch, n_row)
        } else {
            n_row
        };
        for index in 0..n_to_add {
            let row = get_row_at_idx(values, index)?;
            // row[0] is the value expression; row[1] is the literal `n`
            // argument (skipped — see `accumulator`, which reads exprs[1]
            // as `n`); row[2..] are the ORDER BY expression values.
            self.values.push_back(row[0].clone());
            self.ordering_values.push_back(row[2..].to_vec());
        }
        Ok(())
    }
}