// datafusion_functions_aggregate/correlation.rs
use std::any::Any;
use std::fmt::Debug;
use std::mem::size_of_val;
use std::sync::{Arc, OnceLock};
use arrow::compute::{and, filter, is_not_null};
use arrow::{
array::ArrayRef,
datatypes::{DataType, Field},
};
use crate::covariance::CovarianceAccumulator;
use crate::stddev::StddevAccumulator;
use datafusion_common::{plan_err, Result, ScalarValue};
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
type_coercion::aggregates::NUMERICS,
utils::format_state_name,
Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
};
use datafusion_functions_aggregate_common::stats::StatsType;
// Invokes `make_udaf_expr_and_func!` for the `Correlation` type, passing the
// identifiers `corr` and `corr_udaf`, the argument names `y x`, and a one-line
// description string.
// NOTE(review): the macro is defined outside this file; presumably it
// generates the `corr(y, x)` expression helper and the `corr_udaf()` accessor.
// Confirm against the macro definition.
make_udaf_expr_and_func!(
Correlation,
corr,
y x,
"Correlation between two numeric values.",
corr_udaf
);
/// Aggregate UDF registered under the name `corr`.
///
/// The only state it carries is the function signature, which is built in
/// [`Correlation::new`].
#[derive(Debug)]
pub struct Correlation {
// Accepted argument shapes/types for `corr`; returned by `signature()`.
signature: Signature,
}
impl Default for Correlation {
fn default() -> Self {
Self::new()
}
}
impl Correlation {
    /// Creates the `corr` UDAF.
    ///
    /// The signature is uniform over two arguments drawn from `NUMERICS` and
    /// is declared with `Volatility::Immutable`.
    pub fn new() -> Self {
        let signature = Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable);
        Self { signature }
    }
}
impl AggregateUDFImpl for Correlation {
    /// Exposes the concrete type as [`Any`].
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The function name: `corr`.
    fn name(&self) -> &str {
        "corr"
    }

    /// The signature built in [`Correlation::new`].
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Returns `Float64` as the result type.
    ///
    /// # Errors
    ///
    /// Returns a plan error if any entry of `arg_types` is not numeric.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        // Validate every argument, not just the first one. Iterating also
        // avoids an index panic on an empty slice; arity itself is enforced
        // by the signature rather than here.
        if arg_types.iter().any(|data_type| !data_type.is_numeric()) {
            return plan_err!("Correlation requires numeric input types");
        }
        Ok(DataType::Float64)
    }

    /// Creates a fresh [`CorrelationAccumulator`]; the arguments are unused.
    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(CorrelationAccumulator::try_new()?))
    }

    /// Describes the intermediate state as six nullable fields, in order:
    /// `count` (`UInt64`), then `mean1`, `m2_1`, `mean2`, `m2_2` and
    /// `algo_const` (all `Float64`).
    ///
    /// The order matches what `CorrelationAccumulator::state` emits and what
    /// `CorrelationAccumulator::merge_batch` indexes into.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
        let name = args.name;
        // Every state column is nullable and named via `format_state_name`.
        let state_field = |suffix: &str, data_type: DataType| {
            Field::new(format_state_name(name, suffix), data_type, true)
        };
        Ok(vec![
            state_field("count", DataType::UInt64),
            state_field("mean1", DataType::Float64),
            state_field("m2_1", DataType::Float64),
            state_field("mean2", DataType::Float64),
            state_field("m2_2", DataType::Float64),
            state_field("algo_const", DataType::Float64),
        ])
    }

    /// Returns the documentation produced by `get_corr_doc`.
    fn documentation(&self) -> Option<&Documentation> {
        Some(get_corr_doc())
    }
}
/// Cache for the `corr` documentation; filled on the first call to
/// `get_corr_doc`.
static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();

/// Returns the documentation for `corr`, building it on first use.
///
/// # Panics
///
/// Panics on first use if the documentation builder returns an error.
fn get_corr_doc() -> &'static Documentation {
    const DESCRIPTION: &str =
        "Returns the coefficient of correlation between two numeric values.";
    const SYNTAX_EXAMPLE: &str = "corr(expression1, expression2)";
    const SQL_EXAMPLE: &str = r#"```sql
> SELECT corr(column1, column2) FROM table_name;
+--------------------------------+
| corr(column1, column2) |
+--------------------------------+
| 0.85 |
+--------------------------------+
```"#;

    DOCUMENTATION.get_or_init(|| {
        Documentation::builder()
            .with_doc_section(DOC_SECTION_STATISTICAL)
            .with_description(DESCRIPTION)
            .with_syntax_example(SYNTAX_EXAMPLE)
            .with_sql_example(SQL_EXAMPLE)
            .with_standard_argument("expression1", Some("First"))
            .with_standard_argument("expression2", Some("Second"))
            .build()
            .unwrap()
    })
}
/// Accumulator behind `corr`.
///
/// It combines one covariance accumulator over both inputs with one standard
/// deviation accumulator per input; `evaluate` divides the covariance by both
/// standard deviations.
#[derive(Debug)]
pub struct CorrelationAccumulator {
// Covariance of the two inputs (fed both columns).
covar: CovarianceAccumulator,
// Standard deviation of the first input column (`values[0]`).
stddev1: StddevAccumulator,
// Standard deviation of the second input column (`values[1]`).
stddev2: StddevAccumulator,
}
impl CorrelationAccumulator {
    /// Creates an accumulator whose covariance and both standard-deviation
    /// components use `StatsType::Population`.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the inner accumulators.
    pub fn try_new() -> Result<Self> {
        let covar = CovarianceAccumulator::try_new(StatsType::Population)?;
        let stddev1 = StddevAccumulator::try_new(StatsType::Population)?;
        let stddev2 = StddevAccumulator::try_new(StatsType::Population)?;
        Ok(Self {
            covar,
            stddev1,
            stddev2,
        })
    }
}
impl Accumulator for CorrelationAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
let values1 = filter(&values[0], &mask)?;
let values2 = filter(&values[1], &mask)?;
vec![values1, values2]
} else {
values.to_vec()
};
self.covar.update_batch(&values)?;
self.stddev1.update_batch(&values[0..1])?;
self.stddev2.update_batch(&values[1..2])?;
Ok(())
}
fn evaluate(&mut self) -> Result<ScalarValue> {
let covar = self.covar.evaluate()?;
let stddev1 = self.stddev1.evaluate()?;
let stddev2 = self.stddev2.evaluate()?;
if let ScalarValue::Float64(Some(c)) = covar {
if let ScalarValue::Float64(Some(s1)) = stddev1 {
if let ScalarValue::Float64(Some(s2)) = stddev2 {
if s1 == 0_f64 || s2 == 0_f64 {
return Ok(ScalarValue::Float64(Some(0_f64)));
} else {
return Ok(ScalarValue::Float64(Some(c / s1 / s2)));
}
}
}
}
Ok(ScalarValue::Float64(None))
}
fn size(&self) -> usize {
size_of_val(self) - size_of_val(&self.covar) + self.covar.size()
- size_of_val(&self.stddev1)
+ self.stddev1.size()
- size_of_val(&self.stddev2)
+ self.stddev2.size()
}
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.covar.get_count()),
ScalarValue::from(self.covar.get_mean1()),
ScalarValue::from(self.stddev1.get_m2()),
ScalarValue::from(self.covar.get_mean2()),
ScalarValue::from(self.stddev2.get_m2()),
ScalarValue::from(self.covar.get_algo_const()),
])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let states_c = [
Arc::clone(&states[0]),
Arc::clone(&states[1]),
Arc::clone(&states[3]),
Arc::clone(&states[5]),
];
let states_s1 = [
Arc::clone(&states[0]),
Arc::clone(&states[1]),
Arc::clone(&states[2]),
];
let states_s2 = [
Arc::clone(&states[0]),
Arc::clone(&states[3]),
Arc::clone(&states[4]),
];
self.covar.merge_batch(&states_c)?;
self.stddev1.merge_batch(&states_s1)?;
self.stddev2.merge_batch(&states_s2)?;
Ok(())
}
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
let values1 = filter(&values[0], &mask)?;
let values2 = filter(&values[1], &mask)?;
vec![values1, values2]
} else {
values.to_vec()
};
self.covar.retract_batch(&values)?;
self.stddev1.retract_batch(&values[0..1])?;
self.stddev2.retract_batch(&values[1..2])?;
Ok(())
}
}