polars_plan/dsl/functions/correlation.rs

1use super::*;
2
3/// Compute the covariance between two columns.
4pub fn cov(a: Expr, b: Expr, ddof: u8) -> Expr {
5    let input = vec![a, b];
6    let function = FunctionExpr::Correlation {
7        method: CorrelationMethod::Covariance(ddof),
8    };
9    Expr::Function {
10        input,
11        function,
12        options: FunctionOptions {
13            collect_groups: ApplyOptions::GroupWise,
14            cast_options: Some(CastingRules::cast_to_supertypes()),
15            flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
16            ..Default::default()
17        },
18    }
19}
20
21/// Compute the pearson correlation between two columns.
22pub fn pearson_corr(a: Expr, b: Expr) -> Expr {
23    let input = vec![a, b];
24    let function = FunctionExpr::Correlation {
25        method: CorrelationMethod::Pearson,
26    };
27    Expr::Function {
28        input,
29        function,
30        options: FunctionOptions {
31            collect_groups: ApplyOptions::GroupWise,
32            cast_options: Some(CastingRules::cast_to_supertypes()),
33            flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
34            ..Default::default()
35        },
36    }
37}
38
/// Compute the Spearman rank correlation between two columns.
///
/// Missing data will be excluded from the computation.
///
/// The returned function expression is applied group-wise
/// (`ApplyOptions::GroupWise`), requests `CastingRules::cast_to_supertypes`
/// for its two inputs, and is flagged `RETURNS_SCALAR`.
///
/// # Arguments
///
/// * `a`, `b` - The two input expressions.
/// * `propagate_nans` - If `true`, any `NaN` encountered will lead to `NaN` in
///   the output. If `false`, `NaN`s are regarded as larger than any finite
///   number and thus receive the highest rank.
#[cfg(all(feature = "rank", feature = "propagate_nans"))]
pub fn spearman_rank_corr(a: Expr, b: Expr, propagate_nans: bool) -> Expr {
    let input = vec![a, b];
    let function = FunctionExpr::Correlation {
        method: CorrelationMethod::SpearmanRank(propagate_nans),
    };
    Expr::Function {
        input,
        function,
        options: FunctionOptions {
            collect_groups: ApplyOptions::GroupWise,
            cast_options: Some(CastingRules::cast_to_supertypes()),
            flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
            ..Default::default()
        },
    }
}
63
/// Shared builder behind `rolling_corr` and `rolling_cov`.
///
/// Copies `window_size` and `min_periods` from `options` into a
/// `RollingOptionsFixedWindow` (all other fields defaulted). It then wraps
/// `x` and `y` in a `RollingFunction::CorrCov` expression that also carries
/// the original `options`. `is_corr` is `true` when called from
/// `rolling_corr` and `false` when called from `rolling_cov`.
#[cfg(all(feature = "rolling_window", feature = "cov"))]
fn dispatch_corr_cov(x: Expr, y: Expr, options: RollingCovOptions, is_corr: bool) -> Expr {
    // see: https://github.com/pandas-dev/pandas/blob/v1.5.1/pandas/core/window/rolling.py#L1780-L1804
    let window_size = options.window_size as usize;
    let min_periods = options.min_periods as usize;
    let rolling_options = RollingOptionsFixedWindow {
        window_size,
        min_periods,
        ..Default::default()
    };

    let function = FunctionExpr::RollingExpr(RollingFunction::CorrCov {
        rolling_options,
        corr_cov_options: options,
        is_corr,
    });

    Expr::Function {
        input: vec![x, y],
        function,
        options: Default::default(),
    }
}
83
/// Build a rolling-window correlation expression between `x` and `y`.
///
/// Window size and minimum number of periods are taken from `options`.
#[cfg(all(feature = "rolling_window", feature = "cov"))]
pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr {
    let is_corr = true;
    dispatch_corr_cov(x, y, options, is_corr)
}
88
/// Build a rolling-window covariance expression between `x` and `y`.
///
/// Window size and minimum number of periods are taken from `options`.
#[cfg(all(feature = "rolling_window", feature = "cov"))]
pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr {
    let is_corr = false;
    dispatch_corr_cov(x, y, options, is_corr)
}