use crate::utils::{get_scalar_value_from_args, get_signed_integer};
use datafusion_common::arrow::array::ArrayRef;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_common::arrow::datatypes::Field;
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
use datafusion_expr::{
Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
Volatility, WindowUDFImpl,
};
use datafusion_functions_window_common::expr::ExpressionArgs;
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::any::Any;
use std::cmp::min;
use std::collections::VecDeque;
use std::ops::{Neg, Range};
use std::sync::{Arc, OnceLock};
get_or_init_udwf!(
Lag,
lag,
"Returns the row value that precedes the current row by a specified \
offset within partition. If no such row exists, then returns the \
default value.",
WindowShift::lag
);
get_or_init_udwf!(
Lead,
lead,
"Returns the value from a row that follows the current row by a \
specified offset within the partition. If no such row exists, then \
returns the default value.",
WindowShift::lead
);
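/// Creates a `lag` window expression: the value of `arg` from `shift_offset`
/// rows before the current row within the partition, or `default_value`
/// when no such row exists.
///
/// Illustrative sketch only; it assumes a column named `amount` resolvable
/// through `datafusion_expr::col`:
///
/// ```ignore
/// use datafusion_expr::col;
/// // `amount` from two rows back, defaulting to 0 when out of range
/// let expr = lag(col("amount"), Some(2), Some(ScalarValue::from(0i64)));
/// ```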
pub fn lag(
arg: datafusion_expr::Expr,
shift_offset: Option<i64>,
default_value: Option<ScalarValue>,
) -> datafusion_expr::Expr {
let shift_offset_lit = shift_offset
.map(|v| v.lit())
.unwrap_or(ScalarValue::Null.lit());
let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
lag_udwf().call(vec![arg, shift_offset_lit, default_lit])
}
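/// Creates a `lead` window expression: the value of `arg` from `shift_offset`
/// rows after the current row within the partition, or `default_value` when
/// no such row exists. See [`lag`] for an analogous usage sketch.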
pub fn lead(
arg: datafusion_expr::Expr,
shift_offset: Option<i64>,
default_value: Option<ScalarValue>,
) -> datafusion_expr::Expr {
let shift_offset_lit = shift_offset
.map(|v| v.lit())
.unwrap_or(ScalarValue::Null.lit());
let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
lead_udwf().call(vec![arg, shift_offset_lit, default_lit])
}
#[derive(Debug)]
enum WindowShiftKind {
Lag,
Lead,
}
impl WindowShiftKind {
fn name(&self) -> &'static str {
match self {
WindowShiftKind::Lag => "lag",
WindowShiftKind::Lead => "lead",
}
}
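    /// Normalizes the user-supplied offset into the internal signed
    /// convention: positive values shift backwards (lag), negative values
    /// shift forwards (lead). A missing offset defaults to one row.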
fn shift_offset(&self, value: Option<i64>) -> i64 {
match self {
WindowShiftKind::Lag => value.unwrap_or(1),
WindowShiftKind::Lead => value.map(|v| v.neg()).unwrap_or(-1),
}
}
}
#[derive(Debug)]
pub struct WindowShift {
signature: Signature,
kind: WindowShiftKind,
}
impl WindowShift {
fn new(kind: WindowShiftKind) -> Self {
Self {
signature: Signature::one_of(
vec![
TypeSignature::Any(1),
TypeSignature::Any(2),
TypeSignature::Any(3),
],
Volatility::Immutable,
),
kind,
}
}
pub fn lag() -> Self {
Self::new(WindowShiftKind::Lag)
}
pub fn lead() -> Self {
Self::new(WindowShiftKind::Lead)
}
}
static LAG_DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
fn get_lag_doc() -> &'static Documentation {
LAG_DOCUMENTATION.get_or_init(|| {
Documentation::builder()
.with_doc_section(DOC_SECTION_ANALYTICAL)
.with_description(
"Returns value evaluated at the row that is offset rows before the \
current row within the partition; if there is no such row, instead return default \
(which must be of the same type as value).",
)
.with_syntax_example("lag(expression, offset, default)")
.with_argument("expression", "Expression to operate on")
.with_argument("offset", "Integer. Specifies how many rows back \
the value of expression should be retrieved. Defaults to 1.")
.with_argument("default", "The default value if the offset is \
not within the partition. Must be of the same type as expression.")
.build()
.unwrap()
})
}
static LEAD_DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
fn get_lead_doc() -> &'static Documentation {
LEAD_DOCUMENTATION.get_or_init(|| {
Documentation::builder()
.with_doc_section(DOC_SECTION_ANALYTICAL)
.with_description(
"Returns value evaluated at the row that is offset rows after the \
current row within the partition; if there is no such row, instead return default \
(which must be of the same type as value).",
)
.with_syntax_example("lead(expression, offset, default)")
.with_argument("expression", "Expression to operate on")
.with_argument("offset", "Integer. Specifies how many rows \
forward the value of expression should be retrieved. Defaults to 1.")
.with_argument("default", "The default value if the offset is \
not within the partition. Must be of the same type as expression.")
.build()
.unwrap()
})
}
impl WindowUDFImpl for WindowShift {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
self.kind.name()
}
fn signature(&self) -> &Signature {
&self.signature
}
fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
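        // Only the shifted expression itself is reported; `parse_expr` may
        // substitute a typed NULL literal when the input type is Null, and a
        // parsing error collapses to an empty vector here.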
parse_expr(expr_args.input_exprs(), expr_args.input_types())
.into_iter()
.collect::<Vec<_>>()
}
fn partition_evaluator(
&self,
partition_evaluator_args: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
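        // The optional second argument is the integer offset. It is converted
        // to the internal sign convention for lag/lead and negated again when
        // the window function is evaluated in reverse order.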
let shift_offset =
get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)?
.map(get_signed_integer)
.map_or(Ok(None), |v| v.map(Some))
.map(|n| self.kind.shift_offset(n))
.map(|offset| {
if partition_evaluator_args.is_reversed() {
-offset
} else {
offset
}
})?;
let default_value = parse_default_value(
partition_evaluator_args.input_exprs(),
partition_evaluator_args.input_types(),
)?;
Ok(Box::new(WindowShiftEvaluator {
shift_offset,
default_value,
ignore_nulls: partition_evaluator_args.ignore_nulls(),
non_null_offsets: VecDeque::new(),
}))
}
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
let return_type = parse_expr_type(field_args.input_types())?;
Ok(Field::new(field_args.name(), return_type, true))
}
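    /// Reversing lag/lead keeps the same function: `partition_evaluator`
    /// negates the stored offset when `is_reversed()` is set.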
fn reverse_expr(&self) -> ReversedUDWF {
match self.kind {
WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()),
WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()),
}
}
fn documentation(&self) -> Option<&Documentation> {
match self.kind {
WindowShiftKind::Lag => Some(get_lag_doc()),
WindowShiftKind::Lead => Some(get_lead_doc()),
}
}
}
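/// Resolves the expression whose values are shifted. When the first argument
/// has a concrete type it is returned unchanged; when it is typed as Null, a
/// NULL literal matching the default value's type is substituted so the
/// output column carries a usable type.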
fn parse_expr(
input_exprs: &[Arc<dyn PhysicalExpr>],
input_types: &[DataType],
) -> Result<Arc<dyn PhysicalExpr>> {
assert!(!input_exprs.is_empty());
assert!(!input_types.is_empty());
let expr = Arc::clone(input_exprs.first().unwrap());
let expr_type = input_types.first().unwrap();
if !expr_type.is_null() {
return Ok(expr);
}
let default_value = get_scalar_value_from_args(input_exprs, 2)?;
default_value.map_or(Ok(expr), |value| {
ScalarValue::try_from(&value.data_type()).map(|v| {
Arc::new(datafusion_physical_expr::expressions::Literal::new(v))
as Arc<dyn PhysicalExpr>
})
})
}
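/// Determines the output type: the type of the shifted expression, or, when
/// that is Null, the type of the default value (falling back to Null).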
fn parse_expr_type(input_types: &[DataType]) -> Result<DataType> {
assert!(!input_types.is_empty());
let expr_type = input_types.first().unwrap_or(&DataType::Null);
if !expr_type.is_null() {
return Ok(expr_type.clone());
}
let default_value_type = input_types.get(2).unwrap_or(&DataType::Null);
Ok(default_value_type.clone())
}
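/// Builds the default scalar returned when the offset falls outside the
/// partition: the third argument cast to the output type, or a typed NULL
/// when no non-null default was supplied.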
fn parse_default_value(
input_exprs: &[Arc<dyn PhysicalExpr>],
input_types: &[DataType],
) -> Result<ScalarValue> {
let expr_type = parse_expr_type(input_types)?;
let unparsed = get_scalar_value_from_args(input_exprs, 2)?;
unparsed
.filter(|v| !v.data_type().is_null())
.map(|v| v.cast_to(&expr_type))
.unwrap_or(ScalarValue::try_from(expr_type))
}
#[derive(Debug)]
struct WindowShiftEvaluator {
shift_offset: i64,
default_value: ScalarValue,
ignore_nulls: bool,
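    /// With IGNORE NULLS, tracks the row distances between recently seen
    /// non-null values so `evaluate`/`get_range` can locate the shifted row.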
non_null_offsets: VecDeque<usize>,
}
impl WindowShiftEvaluator {
fn is_lag(&self) -> bool {
self.shift_offset > 0
}
}
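/// Shifts `array` by `offset` positions counted over non-null values only.
/// Each row's position among the valid (non-null) indices is found by binary
/// search and then moved |offset| entries forward (lead) or backward (lag);
/// rows whose target falls outside the valid indices get `default_value`.
///
/// Note: `array.nulls().unwrap()` assumes the input carries a null buffer.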
fn evaluate_all_with_ignore_null(
array: &ArrayRef,
offset: i64,
default_value: &ScalarValue,
is_lag: bool,
) -> Result<ArrayRef, DataFusionError> {
let valid_indices: Vec<usize> =
array.nulls().unwrap().valid_indices().collect::<Vec<_>>();
let direction = !is_lag;
let new_array_results: Result<Vec<_>, DataFusionError> = (0..array.len())
.map(|id| {
let result_index = match valid_indices.binary_search(&id) {
                Ok(pos) => if direction {
                    // `offset` is negative for lead, so step forward by its
                    // magnitude among the valid indices.
                    pos.checked_add(offset.unsigned_abs() as usize)
                } else {
                    // lag: step backwards among the valid indices
                    pos.checked_sub(offset.unsigned_abs() as usize)
                }
.and_then(|new_pos| {
if new_pos < valid_indices.len() {
Some(valid_indices[new_pos])
} else {
None
}
}),
                Err(pos) => if direction {
                    pos.checked_add(offset.unsigned_abs() as usize)
} else if pos > 0 {
pos.checked_sub(offset.unsigned_abs() as usize)
} else {
None
}
.and_then(|new_pos| {
if new_pos < valid_indices.len() {
Some(valid_indices[new_pos])
} else {
None
}
}),
};
match result_index {
Some(index) => ScalarValue::try_from_array(array, index),
None => Ok(default_value.clone()),
}
})
.collect();
let new_array = new_array_results?;
ScalarValue::iter_to_array(new_array)
}
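/// Shifts `array` by `offset` rows, filling the vacated positions with
/// `default_value` rather than nulls: a positive offset (lag) prepends
/// defaults and drops rows from the end, a negative offset (lead) appends
/// defaults and drops rows from the front.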
fn shift_with_default_value(
array: &ArrayRef,
offset: i64,
default_value: &ScalarValue,
) -> Result<ArrayRef> {
use datafusion_common::arrow::compute::concat;
let value_len = array.len() as i64;
if offset == 0 {
Ok(Arc::clone(array))
} else if offset == i64::MIN || offset.abs() >= value_len {
default_value.to_array_of_size(value_len as usize)
} else {
let slice_offset = (-offset).clamp(0, value_len) as usize;
let length = array.len() - offset.unsigned_abs() as usize;
let slice = array.slice(slice_offset, length);
let nulls = offset.unsigned_abs() as usize;
let default_values = default_value.to_array_of_size(nulls)?;
if offset > 0 {
concat(&[default_values.as_ref(), slice.as_ref()])
.map_err(|e| arrow_datafusion_err!(e))
} else {
concat(&[slice.as_ref(), default_values.as_ref()])
.map_err(|e| arrow_datafusion_err!(e))
}
}
}
impl PartitionEvaluator for WindowShiftEvaluator {
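    /// Reports which input rows the value at `idx` depends on: a window
    /// ending at the current row for lag, starting at it for lead. With
    /// IGNORE NULLS the bounds come from the tracked non-null offsets,
    /// widening to the partition edge until enough non-null rows are seen.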
fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
if self.is_lag() {
let start = if self.non_null_offsets.len() == self.shift_offset as usize {
let offset: usize = self.non_null_offsets.iter().sum();
idx.saturating_sub(offset)
} else if !self.ignore_nulls {
let offset = self.shift_offset as usize;
idx.saturating_sub(offset)
} else {
0
};
let end = idx + 1;
Ok(Range { start, end })
} else {
let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize {
let offset: usize = self.non_null_offsets.iter().sum();
min(idx + offset + 1, n_rows)
} else if !self.ignore_nulls {
let offset = (-self.shift_offset) as usize;
min(idx + offset, n_rows)
} else {
n_rows
};
Ok(Range { start: idx, end })
}
}
fn is_causal(&self) -> bool {
self.is_lag()
}
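    /// Incremental (bounded-execution) path: produces the shifted value for
    /// the current row implied by `range` (its last row for lag, its first
    /// row for lead). With IGNORE NULLS, `non_null_offsets` is updated as
    /// rows arrive so the target non-null row is found without rescanning
    /// the partition.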
fn evaluate(
&mut self,
values: &[ArrayRef],
range: &Range<usize>,
) -> Result<ScalarValue> {
let array = &values[0];
let len = array.len();
let i = if self.is_lag() {
(range.end as i64 - self.shift_offset - 1) as usize
} else {
(range.start as i64 - self.shift_offset) as usize
};
let mut idx: Option<usize> = if i < len { Some(i) } else { None };
if self.ignore_nulls && self.is_lag() {
idx = if self.non_null_offsets.len() == self.shift_offset as usize {
let total_offset: usize = self.non_null_offsets.iter().sum();
Some(range.end - 1 - total_offset)
} else {
None
};
if array.is_valid(range.end - 1) {
self.non_null_offsets.push_back(1);
if self.non_null_offsets.len() > self.shift_offset as usize {
self.non_null_offsets.pop_front();
}
} else if !self.non_null_offsets.is_empty() {
let end_idx = self.non_null_offsets.len() - 1;
self.non_null_offsets[end_idx] += 1;
}
} else if self.ignore_nulls && !self.is_lag() {
let non_null_row_count = (-self.shift_offset) as usize;
if self.non_null_offsets.is_empty() {
let mut offset_val = 1;
for idx in range.start + 1..range.end {
if array.is_valid(idx) {
self.non_null_offsets.push_back(offset_val);
offset_val = 1;
} else {
offset_val += 1;
}
if self.non_null_offsets.len() == non_null_row_count + 1 {
break;
}
}
            } else if range.end < len && array.is_valid(range.end) {
                // The guard above already ensures the row at `range.end` is
                // non-null, so it always begins a fresh offset of one.
                self.non_null_offsets.push_back(1);
            }
idx = if self.non_null_offsets.len() >= non_null_row_count {
let total_offset: usize =
self.non_null_offsets.iter().take(non_null_row_count).sum();
Some(range.start + total_offset)
} else {
None
};
if !self.non_null_offsets.is_empty() {
self.non_null_offsets[0] -= 1;
if self.non_null_offsets[0] == 0 {
self.non_null_offsets.pop_front();
}
}
}
        // Return the shifted value when the target row exists (and, with
        // IGNORE NULLS, is non-null); otherwise fall back to the default.
        match idx {
            Some(i) if !(self.ignore_nulls && array.is_null(i)) => {
                ScalarValue::try_from_array(array, i)
            }
            _ => Ok(self.default_value.clone()),
        }
}
fn evaluate_all(
&mut self,
values: &[ArrayRef],
_num_rows: usize,
) -> Result<ArrayRef> {
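        // lag/lead take a single value column, so only values[0] is used.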
let value = &values[0];
if !self.ignore_nulls {
shift_with_default_value(value, self.shift_offset, &self.default_value)
} else {
evaluate_all_with_ignore_null(
value,
self.shift_offset,
&self.default_value,
self.is_lag(),
)
}
}
fn supports_bounded_execution(&self) -> bool {
true
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::*;
use datafusion_common::cast::as_int32_array;
use datafusion_physical_expr::expressions::{Column, Literal};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
fn test_i32_result(
expr: WindowShift,
partition_evaluator_args: PartitionEvaluatorArgs,
expected: Int32Array,
) -> Result<()> {
let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
let values = vec![arr];
        let num_rows = values[0].len();
let result = expr
.partition_evaluator(partition_evaluator_args)?
.evaluate_all(&values, num_rows)?;
let result = as_int32_array(&result)?;
assert_eq!(expected, *result);
Ok(())
}
#[test]
fn lead_lag_get_range() -> Result<()> {
let lag_fn = WindowShiftEvaluator {
shift_offset: 2,
default_value: ScalarValue::Null,
ignore_nulls: false,
non_null_offsets: Default::default(),
};
assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 4, end: 7 });
assert_eq!(lag_fn.get_range(0, 10)?, Range { start: 0, end: 1 });
let lag_fn = WindowShiftEvaluator {
shift_offset: 2,
default_value: ScalarValue::Null,
ignore_nulls: true,
            non_null_offsets: vec![2, 2].into(),
        };
assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 2, end: 7 });
let lead_fn = WindowShiftEvaluator {
shift_offset: -2,
default_value: ScalarValue::Null,
ignore_nulls: false,
non_null_offsets: Default::default(),
};
assert_eq!(lead_fn.get_range(6, 10)?, Range { start: 6, end: 8 });
assert_eq!(lead_fn.get_range(9, 10)?, Range { start: 9, end: 10 });
let lead_fn = WindowShiftEvaluator {
shift_offset: -2,
default_value: ScalarValue::Null,
ignore_nulls: true,
non_null_offsets: vec![2, 2].into(),
};
assert_eq!(lead_fn.get_range(4, 10)?, Range { start: 4, end: 9 });
Ok(())
}
#[test]
fn test_lead_window_shift() -> Result<()> {
let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
test_i32_result(
WindowShift::lead(),
PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false),
[
Some(-2),
Some(3),
Some(-4),
Some(5),
Some(-6),
Some(7),
Some(8),
None,
]
.iter()
.collect::<Int32Array>(),
)
}
#[test]
fn test_lag_window_shift() -> Result<()> {
let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
test_i32_result(
WindowShift::lag(),
PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false),
[
None,
Some(1),
Some(-2),
Some(3),
Some(-4),
Some(5),
Some(-6),
Some(7),
]
.iter()
.collect::<Int32Array>(),
)
}
#[test]
fn test_lag_with_default() -> Result<()> {
let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
let shift_offset =
Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100))))
as Arc<dyn PhysicalExpr>;
let input_exprs = &[expr, shift_offset, default_value];
let input_types: &[DataType] =
&[DataType::Int32, DataType::Int32, DataType::Int32];
test_i32_result(
WindowShift::lag(),
PartitionEvaluatorArgs::new(input_exprs, input_types, false, false),
[
Some(100),
Some(1),
Some(-2),
Some(3),
Some(-4),
Some(5),
Some(-6),
Some(7),
]
.iter()
.collect::<Int32Array>(),
)
}
}