polars_plan/dsl/functions/
correlation.rs1use super::*;
2
3pub fn cov(a: Expr, b: Expr, ddof: u8) -> Expr {
5 let input = vec![a, b];
6 let function = FunctionExpr::Correlation {
7 method: CorrelationMethod::Covariance(ddof),
8 };
9 Expr::Function {
10 input,
11 function,
12 options: FunctionOptions {
13 collect_groups: ApplyOptions::GroupWise,
14 cast_options: Some(CastingRules::cast_to_supertypes()),
15 flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
16 ..Default::default()
17 },
18 }
19}
20
21pub fn pearson_corr(a: Expr, b: Expr) -> Expr {
23 let input = vec![a, b];
24 let function = FunctionExpr::Correlation {
25 method: CorrelationMethod::Pearson,
26 };
27 Expr::Function {
28 input,
29 function,
30 options: FunctionOptions {
31 collect_groups: ApplyOptions::GroupWise,
32 cast_options: Some(CastingRules::cast_to_supertypes()),
33 flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
34 ..Default::default()
35 },
36 }
37}
38
/// Build an expression for the Spearman rank correlation of `a` and `b`.
///
/// The result is an `Expr::Function` over the two inputs using
/// `FunctionExpr::Correlation` with
/// `CorrelationMethod::SpearmanRank(propagate_nans)`.
///
/// * `a`, `b` - the two input expressions, passed on in that order.
/// * `propagate_nans` - forwarded unchanged to the Spearman method
///   (presumably controls whether `NaN` inputs yield `NaN`; confirm against
///   the kernel).
///
/// Only compiled when both the `rank` and `propagate_nans` features are on.
#[cfg(all(feature = "rank", feature = "propagate_nans"))]
pub fn spearman_rank_corr(a: Expr, b: Expr, propagate_nans: bool) -> Expr {
    let method = CorrelationMethod::SpearmanRank(propagate_nans);

    // Group-wise application, inputs cast to their supertypes, and the
    // function is flagged as returning a scalar.
    let options = FunctionOptions {
        collect_groups: ApplyOptions::GroupWise,
        cast_options: Some(CastingRules::cast_to_supertypes()),
        flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
        ..Default::default()
    };

    Expr::Function {
        input: vec![a, b],
        function: FunctionExpr::Correlation { method },
        options,
    }
}
63
/// Shared builder behind `rolling_corr` and `rolling_cov`.
///
/// Wraps `x` and `y` in an `Expr::Function` carrying
/// `RollingFunction::CorrCov`. The fixed-window settings are derived from
/// `options.window_size` and `options.min_periods`; every other
/// `RollingOptionsFixedWindow` field keeps its default.
///
/// * `options` - stored in full as `corr_cov_options`.
/// * `is_corr` - stored on the `CorrCov` variant; the public wrappers pass
///   `true` for correlation and `false` for covariance.
#[cfg(all(feature = "rolling_window", feature = "cov"))]
fn dispatch_corr_cov(x: Expr, y: Expr, options: RollingCovOptions, is_corr: bool) -> Expr {
    // Read both sizes before `options` is moved into the variant below.
    let window_size = options.window_size as usize;
    let min_periods = options.min_periods as usize;

    let function = FunctionExpr::RollingExpr(RollingFunction::CorrCov {
        rolling_options: RollingOptionsFixedWindow {
            window_size,
            min_periods,
            ..Default::default()
        },
        corr_cov_options: options,
        is_corr,
    });

    Expr::Function {
        input: vec![x, y],
        function,
        options: Default::default(),
    }
}
83
/// Build a rolling correlation expression over `x` and `y`.
///
/// Delegates to `dispatch_corr_cov` with the correlation flag set.
///
/// * `options` - rolling settings, handed on unchanged.
#[cfg(all(feature = "rolling_window", feature = "cov"))]
pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr {
    let is_corr = true;
    dispatch_corr_cov(x, y, options, is_corr)
}
88
/// Build a rolling covariance expression over `x` and `y`.
///
/// Delegates to `dispatch_corr_cov` with the correlation flag cleared.
///
/// * `options` - rolling settings, handed on unchanged.
#[cfg(all(feature = "rolling_window", feature = "cov"))]
pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr {
    let is_corr = false;
    dispatch_corr_cov(x, y, options, is_corr)
}