use std::any::Any;
use std::sync::Arc;
use crate::aggregate::utils::{down_cast_any_ref, ordering_fields};
use crate::expressions::format_state_name;
use crate::{AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr};
use arrow::array::ArrayRef;
use arrow::compute;
use arrow::compute::{lexsort_to_indices, SortColumn};
use arrow::datatypes::{DataType, Field};
use arrow_array::cast::AsArray;
use arrow_array::{Array, BooleanArray};
use arrow_schema::SortOptions;
use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::Accumulator;
/// FIRST_VALUE aggregate expression: evaluates to the value of `expr` at the
/// first row of the input (the row that sorts first per `ordering_req` when an
/// ordering requirement is present).
#[derive(Debug)]
pub struct FirstValue {
    /// Display name of this aggregate; also used to derive state field names.
    name: String,
    /// Data type of the aggregated expression `expr`.
    input_data_type: DataType,
    /// Data types of the ORDER BY expressions, parallel to `ordering_req`.
    order_by_data_types: Vec<DataType>,
    /// The physical expression whose first value is returned.
    expr: Arc<dyn PhysicalExpr>,
    /// Ordering requirement that defines which row counts as "first".
    ordering_req: LexOrdering,
}
impl FirstValue {
    /// Creates a new FIRST_VALUE aggregate expression over `expr`.
    pub fn new(
        expr: Arc<dyn PhysicalExpr>,
        name: impl Into<String>,
        input_data_type: DataType,
        ordering_req: LexOrdering,
        order_by_data_types: Vec<DataType>,
    ) -> Self {
        let name = name.into();
        Self {
            name,
            expr,
            input_data_type,
            ordering_req,
            order_by_data_types,
        }
    }
}
impl AggregateExpr for FirstValue {
    /// Return `self` as `Any` to support downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The output field produced by this aggregate (always nullable).
    fn field(&self) -> Result<Field> {
        Ok(Field::new(&self.name, self.input_data_type.clone(), true))
    }

    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        let accumulator = FirstValueAccumulator::try_new(
            &self.input_data_type,
            &self.order_by_data_types,
            self.ordering_req.clone(),
        )?;
        Ok(Box::new(accumulator))
    }

    /// State layout: [first_value, <ordering values...>, is_set].
    fn state_fields(&self) -> Result<Vec<Field>> {
        let value_field = Field::new(
            format_state_name(&self.name, "first_value"),
            self.input_data_type.clone(),
            true,
        );
        let is_set_field = Field::new(
            format_state_name(&self.name, "is_set"),
            DataType::Boolean,
            true,
        );
        let mut fields = vec![value_field];
        fields.extend(ordering_fields(
            &self.ordering_req,
            &self.order_by_data_types,
        ));
        fields.push(is_set_field);
        Ok(fields)
    }

    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        vec![self.expr.clone()]
    }

    fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
        if self.ordering_req.is_empty() {
            return None;
        }
        Some(&self.ordering_req)
    }

    fn name(&self) -> &str {
        &self.name
    }

    /// FIRST_VALUE reversed is LAST_VALUE; rewrite the display name accordingly.
    fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
        let name = match self.name.strip_prefix("FIRST") {
            Some(rest) => format!("LAST{}", rest),
            None => format!("LAST_VALUE({})", self.expr),
        };
        Some(Arc::new(LastValue::new(
            self.expr.clone(),
            name,
            self.input_data_type.clone(),
            self.ordering_req.clone(),
            self.order_by_data_types.clone(),
        )))
    }

    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        let accumulator = FirstValueAccumulator::try_new(
            &self.input_data_type,
            &self.order_by_data_types,
            self.ordering_req.clone(),
        )?;
        Ok(Box::new(accumulator))
    }
}
impl PartialEq<dyn Any> for FirstValue {
    /// Two FIRST_VALUE expressions are equal when name, types and the
    /// underlying expression all match.
    fn eq(&self, other: &dyn Any) -> bool {
        down_cast_any_ref(other)
            .downcast_ref::<Self>()
            .map_or(false, |rhs| {
                self.name == rhs.name
                    && self.input_data_type == rhs.input_data_type
                    && self.order_by_data_types == rhs.order_by_data_types
                    && self.expr.eq(&rhs.expr)
            })
    }
}
/// Accumulator state for FIRST_VALUE.
#[derive(Debug)]
struct FirstValueAccumulator {
    /// The first value seen so far (a typed null until `is_set`).
    first: ScalarValue,
    /// Whether a value has been stored yet.
    is_set: bool,
    /// Ordering values of the row that produced `first`.
    orderings: Vec<ScalarValue>,
    /// Ordering requirement defining which row counts as "first".
    ordering_req: LexOrdering,
}
impl FirstValueAccumulator {
    /// Creates a new `FirstValueAccumulator` for the given input and ordering
    /// data types; all value slots start as typed nulls.
    pub fn try_new(
        data_type: &DataType,
        ordering_dtypes: &[DataType],
        ordering_req: LexOrdering,
    ) -> Result<Self> {
        let mut orderings = Vec::with_capacity(ordering_dtypes.len());
        for dtype in ordering_dtypes {
            orderings.push(ScalarValue::try_from(dtype)?);
        }
        Ok(Self {
            first: ScalarValue::try_from(data_type)?,
            is_set: false,
            orderings,
            ordering_req,
        })
    }

    /// Overwrites the stored value and orderings from `row`
    /// (layout: [value, <ordering values...>]) and marks the state as set.
    fn update_with_new_row(&mut self, row: &[ScalarValue]) {
        self.first = row[0].clone();
        self.orderings = row[1..].to_vec();
        self.is_set = true;
    }
}
impl Accumulator for FirstValueAccumulator {
    fn state(&self) -> Result<Vec<ScalarValue>> {
        // Serialized state layout: [first, <ordering values...>, is_set];
        // must stay in sync with `FirstValue::state_fields`.
        let mut result = vec![self.first.clone()];
        result.extend(self.orderings.iter().cloned());
        result.push(ScalarValue::Boolean(Some(self.is_set)));
        Ok(result)
    }
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // Keep only the very first row ever observed; later batches are
        // ignored once `is_set` is true. NOTE(review): assumes batches arrive
        // already ordered per `ordering_req` — confirm with callers.
        if !values[0].is_empty() && !self.is_set {
            let row = get_row_at_idx(values, 0)?;
            self.update_with_new_row(&row);
        }
        Ok(())
    }
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // Partial state columns: [value, <ordering values...>, is_set];
        // the `is_set` flag is always the last column.
        let is_set_idx = states.len() - 1;
        let flags = states[is_set_idx].as_boolean();
        // Drop rows produced by accumulators that never saw any input.
        let filtered_states = filter_states_according_to_is_set(states, flags)?;
        let sort_cols =
            convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);
        let ordered_states = if sort_cols.is_empty() {
            // No ordering requirement: keep the incoming order.
            filtered_states
        } else {
            // Sort partial states so row 0 is the earliest candidate.
            let indices = lexsort_to_indices(&sort_cols, None)?;
            get_arrayref_at_indices(&filtered_states, &indices)?
        };
        if !ordered_states[0].is_empty() {
            let first_row = get_row_at_idx(&ordered_states, 0)?;
            // BUGFIX: slice up to `is_set_idx` so the trailing `is_set` flag
            // is excluded. Previously `first_row[1..]` leaked the Boolean flag
            // into `self.orderings`, making `state()` emit one more value than
            // `state_fields()` declares after a merge.
            let first_ordering = &first_row[1..is_set_idx];
            let sort_options = get_sort_options(&self.ordering_req);
            // Adopt the candidate if nothing is stored yet, or if it sorts
            // strictly before the currently stored ordering values.
            if !self.is_set
                || compare_rows(first_ordering, &self.orderings, &sort_options)?.is_lt()
            {
                // Pass [value, <orderings...>] only — without the flag.
                self.update_with_new_row(&first_row[0..is_set_idx]);
            }
        }
        Ok(())
    }
    fn evaluate(&self) -> Result<ScalarValue> {
        Ok(self.first.clone())
    }
    fn size(&self) -> usize {
        // Replace the shallow sizes of `first` and `orderings` (already
        // counted inside `size_of_val(self)`) with their deep sizes.
        std::mem::size_of_val(self) - std::mem::size_of_val(&self.first)
            + self.first.size()
            + ScalarValue::size_of_vec(&self.orderings)
            - std::mem::size_of_val(&self.orderings)
    }
}
/// LAST_VALUE aggregate expression: evaluates to the value of `expr` at the
/// last row of the input (the row that sorts last per `ordering_req` when an
/// ordering requirement is present).
#[derive(Debug)]
pub struct LastValue {
    /// Display name of this aggregate; also used to derive state field names.
    name: String,
    /// Data type of the aggregated expression `expr`.
    input_data_type: DataType,
    /// Data types of the ORDER BY expressions, parallel to `ordering_req`.
    order_by_data_types: Vec<DataType>,
    /// The physical expression whose last value is returned.
    expr: Arc<dyn PhysicalExpr>,
    /// Ordering requirement that defines which row counts as "last".
    ordering_req: LexOrdering,
}
impl LastValue {
    /// Creates a new LAST_VALUE aggregate expression over `expr`.
    pub fn new(
        expr: Arc<dyn PhysicalExpr>,
        name: impl Into<String>,
        input_data_type: DataType,
        ordering_req: LexOrdering,
        order_by_data_types: Vec<DataType>,
    ) -> Self {
        let name = name.into();
        Self {
            name,
            expr,
            input_data_type,
            ordering_req,
            order_by_data_types,
        }
    }
}
impl AggregateExpr for LastValue {
    /// Return `self` as `Any` to support downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The output field produced by this aggregate (always nullable).
    fn field(&self) -> Result<Field> {
        Ok(Field::new(&self.name, self.input_data_type.clone(), true))
    }

    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        let accumulator = LastValueAccumulator::try_new(
            &self.input_data_type,
            &self.order_by_data_types,
            self.ordering_req.clone(),
        )?;
        Ok(Box::new(accumulator))
    }

    /// State layout: [last_value, <ordering values...>, is_set].
    fn state_fields(&self) -> Result<Vec<Field>> {
        let value_field = Field::new(
            format_state_name(&self.name, "last_value"),
            self.input_data_type.clone(),
            true,
        );
        let is_set_field = Field::new(
            format_state_name(&self.name, "is_set"),
            DataType::Boolean,
            true,
        );
        let mut fields = vec![value_field];
        fields.extend(ordering_fields(
            &self.ordering_req,
            &self.order_by_data_types,
        ));
        fields.push(is_set_field);
        Ok(fields)
    }

    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        vec![self.expr.clone()]
    }

    fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
        if self.ordering_req.is_empty() {
            return None;
        }
        Some(&self.ordering_req)
    }

    fn name(&self) -> &str {
        &self.name
    }

    /// LAST_VALUE reversed is FIRST_VALUE; rewrite the display name accordingly.
    fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
        let name = match self.name.strip_prefix("LAST") {
            Some(rest) => format!("FIRST{}", rest),
            None => format!("FIRST_VALUE({})", self.expr),
        };
        Some(Arc::new(FirstValue::new(
            self.expr.clone(),
            name,
            self.input_data_type.clone(),
            self.ordering_req.clone(),
            self.order_by_data_types.clone(),
        )))
    }

    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        let accumulator = LastValueAccumulator::try_new(
            &self.input_data_type,
            &self.order_by_data_types,
            self.ordering_req.clone(),
        )?;
        Ok(Box::new(accumulator))
    }
}
impl PartialEq<dyn Any> for LastValue {
    /// Two LAST_VALUE expressions are equal when name, types and the
    /// underlying expression all match.
    fn eq(&self, other: &dyn Any) -> bool {
        down_cast_any_ref(other)
            .downcast_ref::<Self>()
            .map_or(false, |rhs| {
                self.name == rhs.name
                    && self.input_data_type == rhs.input_data_type
                    && self.order_by_data_types == rhs.order_by_data_types
                    && self.expr.eq(&rhs.expr)
            })
    }
}
/// Accumulator state for LAST_VALUE.
#[derive(Debug)]
struct LastValueAccumulator {
    /// The last value seen so far (a typed null until `is_set`).
    last: ScalarValue,
    /// Whether a value has been stored yet.
    is_set: bool,
    /// Ordering values of the row that produced `last`.
    orderings: Vec<ScalarValue>,
    /// Ordering requirement defining which row counts as "last".
    ordering_req: LexOrdering,
}
impl LastValueAccumulator {
    /// Creates a new `LastValueAccumulator` for the given input and ordering
    /// data types; all value slots start as typed nulls.
    pub fn try_new(
        data_type: &DataType,
        ordering_dtypes: &[DataType],
        ordering_req: LexOrdering,
    ) -> Result<Self> {
        let orderings = ordering_dtypes
            .iter()
            .map(ScalarValue::try_from)
            .collect::<Result<Vec<_>>>()?;
        let last = ScalarValue::try_from(data_type)?;
        Ok(Self {
            last,
            is_set: false,
            orderings,
            ordering_req,
        })
    }

    /// Overwrites the stored value and orderings from `row`
    /// (layout: [value, <ordering values...>]) and marks the state as set.
    fn update_with_new_row(&mut self, row: &[ScalarValue]) {
        self.last = row[0].clone();
        self.orderings = row[1..].to_vec();
        self.is_set = true;
    }
}
impl Accumulator for LastValueAccumulator {
    fn state(&self) -> Result<Vec<ScalarValue>> {
        // Serialized state layout: [last, <ordering values...>, is_set];
        // must stay in sync with `LastValue::state_fields`.
        let mut result = vec![self.last.clone()];
        result.extend(self.orderings.clone());
        result.push(ScalarValue::Boolean(Some(self.is_set)));
        Ok(result)
    }
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        // Always track the last row of the most recent non-empty batch.
        // NOTE(review): assumes batches arrive already ordered per
        // `ordering_req` — confirm with callers.
        if !values[0].is_empty() {
            let row = get_row_at_idx(values, values[0].len() - 1)?;
            self.update_with_new_row(&row);
        }
        Ok(())
    }
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // Partial state columns: [value, <ordering values...>, is_set];
        // the `is_set` flag is always the last column.
        let is_set_idx = states.len() - 1;
        let flags = states[is_set_idx].as_boolean();
        // Drop rows produced by accumulators that never saw any input.
        let filtered_states = filter_states_according_to_is_set(states, flags)?;
        let sort_cols =
            convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req);
        let ordered_states = if sort_cols.is_empty() {
            // No ordering requirement: keep the incoming order.
            filtered_states
        } else {
            // Sort partial states so the final row is the latest candidate.
            let indices = lexsort_to_indices(&sort_cols, None)?;
            get_arrayref_at_indices(&filtered_states, &indices)?
        };
        if !ordered_states[0].is_empty() {
            let last_idx = ordered_states[0].len() - 1;
            let last_row = get_row_at_idx(&ordered_states, last_idx)?;
            // BUGFIX: slice up to `is_set_idx` so the trailing `is_set` flag
            // is excluded. Previously `last_row[1..]` leaked the Boolean flag
            // into `self.orderings`, making `state()` emit one more value than
            // `state_fields()` declares after a merge.
            let last_ordering = &last_row[1..is_set_idx];
            let sort_options = get_sort_options(&self.ordering_req);
            // Adopt the candidate if nothing is stored yet, or if it sorts
            // strictly after the currently stored ordering values.
            if !self.is_set
                || compare_rows(last_ordering, &self.orderings, &sort_options)?.is_gt()
            {
                // Pass [value, <orderings...>] only — without the flag.
                self.update_with_new_row(&last_row[0..is_set_idx]);
            }
        }
        Ok(())
    }
    fn evaluate(&self) -> Result<ScalarValue> {
        Ok(self.last.clone())
    }
    fn size(&self) -> usize {
        // Replace the shallow sizes of `last` and `orderings` (already
        // counted inside `size_of_val(self)`) with their deep sizes.
        std::mem::size_of_val(self) - std::mem::size_of_val(&self.last)
            + self.last.size()
            + ScalarValue::size_of_vec(&self.orderings)
            - std::mem::size_of_val(&self.orderings)
    }
}
/// Keeps only the state rows whose `is_set` flag is true, i.e. rows produced
/// by partial accumulators that actually observed input.
fn filter_states_according_to_is_set(
    states: &[ArrayRef],
    flags: &BooleanArray,
) -> Result<Vec<ArrayRef>> {
    let mut filtered = Vec::with_capacity(states.len());
    for state in states {
        let kept = compute::filter(state, flags).map_err(DataFusionError::ArrowError)?;
        filtered.push(kept);
    }
    Ok(filtered)
}
/// Pairs each ordering array with the `SortOptions` of its sort expression,
/// producing the `SortColumn`s used by `lexsort_to_indices`. Extra elements on
/// either side are ignored (zip truncates to the shorter slice).
fn convert_to_sort_cols(
    arrs: &[ArrayRef],
    sort_exprs: &[PhysicalSortExpr],
) -> Vec<SortColumn> {
    let mut columns = Vec::with_capacity(arrs.len().min(sort_exprs.len()));
    for (values, sort_expr) in arrs.iter().zip(sort_exprs) {
        columns.push(SortColumn {
            values: values.clone(),
            options: Some(sort_expr.options),
        });
    }
    columns
}
/// Collects the `SortOptions` of every sort expression in `ordering_req`.
fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec<SortOptions> {
    let mut options = Vec::with_capacity(ordering_req.len());
    for sort_expr in ordering_req {
        options.push(sort_expr.options);
    }
    options
}
#[cfg(test)]
mod tests {
    use crate::aggregate::first_last::{FirstValueAccumulator, LastValueAccumulator};
    use arrow_array::{ArrayRef, Int64Array};
    use arrow_schema::DataType;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::Accumulator;
    use std::sync::Arc;

    /// Feed overlapping Int64 batches to both accumulators (no ordering
    /// requirement) and verify FIRST keeps the very first value seen while
    /// LAST tracks the final value of the most recent batch.
    #[test]
    fn test_first_last_value_value() -> Result<()> {
        let mut first_accumulator =
            FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
        let mut last_accumulator =
            LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
        // Batches cover 0..10, 1..11 and 2..13.
        for (start, end) in [(0i64, 10i64), (1, 11), (2, 13)] {
            let batch: ArrayRef = Arc::new(Int64Array::from_iter_values(start..end));
            first_accumulator.update_batch(&[batch.clone()])?;
            last_accumulator.update_batch(&[batch])?;
        }
        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
        Ok(())
    }
}