1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
use super::*;

/// Build an expression that computes the covariance between two columns.
///
/// The resulting function expression is flagged as group-wise and as
/// returning a scalar, and requests casting of its inputs to a supertype.
///
/// # Arguments
/// * a, b
///     The two column expressions to compare.
/// * ddof
///     Delta degrees of freedom
pub fn cov(a: Expr, b: Expr, ddof: u8) -> Expr {
    let options = FunctionOptions {
        collect_groups: ApplyOptions::GroupWise,
        cast_to_supertypes: Some(Default::default()),
        returns_scalar: true,
        ..Default::default()
    };

    Expr::Function {
        input: vec![a, b],
        function: FunctionExpr::Correlation {
            method: CorrelationMethod::Covariance,
            ddof,
        },
        options,
    }
}

/// Build an expression that computes the pearson correlation between two columns.
///
/// The resulting function expression is flagged as group-wise and as
/// returning a scalar, and requests casting of its inputs to a supertype.
///
/// # Arguments
/// * a, b
///     The two column expressions to correlate.
/// * ddof
///     Delta degrees of freedom
pub fn pearson_corr(a: Expr, b: Expr, ddof: u8) -> Expr {
    let method = CorrelationMethod::Pearson;
    let columns = vec![a, b];

    Expr::Function {
        input: columns,
        function: FunctionExpr::Correlation { method, ddof },
        options: FunctionOptions {
            returns_scalar: true,
            cast_to_supertypes: Some(Default::default()),
            collect_groups: ApplyOptions::GroupWise,
            ..Default::default()
        },
    }
}

/// Build an expression that computes the spearman rank correlation between
/// two columns.
///
/// Missing data will be excluded from the computation.
///
/// # Arguments
/// * a, b
///     The two column expressions to correlate.
/// * ddof
///     Delta degrees of freedom
/// * propagate_nans
///     If `true`, any `NaN` encountered will lead to `NaN` in the output.
///     If `false`, `NaN` values are regarded as larger than any finite number
///     and thus receive the highest rank.
#[cfg(all(feature = "rank", feature = "propagate_nans"))]
pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> Expr {
    let options = FunctionOptions {
        collect_groups: ApplyOptions::GroupWise,
        cast_to_supertypes: Some(Default::default()),
        returns_scalar: true,
        ..Default::default()
    };
    let method = CorrelationMethod::SpearmanRank(propagate_nans);

    Expr::Function {
        input: vec![a, b],
        function: FunctionExpr::Correlation { method, ddof },
        options,
    }
}

/// Build an expression for the rolling correlation of `x` and `y` over a
/// fixed-size window.
///
/// The result is the bias-corrected rolling covariance divided by the square
/// root of the product of the two rolling variances.
///
/// # Arguments
/// * x, y
///     The two column expressions to correlate.
/// * options
///     Supplies `window_size`, `min_periods` and `ddof`.
#[cfg(feature = "rolling_window")]
pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr {
    // see: https://github.com/pandas-dev/pandas/blob/v1.5.1/pandas/core/window/rolling.py#L1780-L1804
    let window_size = options.window_size as usize;
    let min_periods = options.min_periods as usize;
    let ddof = options.ddof as f64;

    // Options shared by the mean and variance aggregations.
    let stat_options = RollingOptionsFixedWindow {
        window_size,
        min_periods,
        ..Default::default()
    };
    // The pair count uses the same window but `min_periods: 0`.
    let count_options = RollingOptionsFixedWindow {
        window_size,
        min_periods: 0,
        ..Default::default()
    };

    let xy_mean = (x.clone() * y.clone()).rolling_mean(stat_options.clone());
    let x_mean = x.clone().rolling_mean(stat_options.clone());
    let y_mean = y.clone().rolling_mean(stat_options.clone());
    // NOTE(review): `options.ddof` is not forwarded to `rolling_var`; the
    // variances use the default rolling parameters. Confirm this matches the
    // ddof applied to the numerator below.
    let x_var = x.clone().rolling_var(stat_options.clone());
    let y_var = y.clone().rolling_var(stat_options);

    // Number of rows per window where `x + y` is non-null (presumably the
    // rows where both inputs are present — verify null propagation of `+`).
    let pair_present = (x + y).is_not_null();
    let pair_count = pair_present
        .cast(DataType::Float64)
        .rolling_sum(count_options);

    // n / (n - ddof) correction applied to the population covariance.
    let correction = pair_count.clone() / (pair_count - lit(ddof));
    let covariance = (xy_mean - x_mean * y_mean) * correction;
    let scale = (x_var * y_var).pow(lit(0.5));

    covariance / scale
}

/// Build an expression for the rolling covariance of `x` and `y` over a
/// fixed-size window.
///
/// Computed as `mean(x * y) - mean(x) * mean(y)` per window, scaled by
/// `n / (n - ddof)` where `n` is the rolling count of rows for which `x + y`
/// is non-null.
///
/// # Arguments
/// * x, y
///     The two column expressions.
/// * options
///     Supplies `window_size`, `min_periods` and `ddof`.
#[cfg(feature = "rolling_window")]
pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr {
    // see: https://github.com/pandas-dev/pandas/blob/91111fd99898d9dcaa6bf6bedb662db4108da6e6/pandas/core/window/rolling.py#L1700
    let window_size = options.window_size as usize;
    let ddof = options.ddof as f64;

    // Options for the rolling means.
    let mean_options = RollingOptionsFixedWindow {
        window_size,
        min_periods: options.min_periods as usize,
        ..Default::default()
    };
    // The pair count uses the same window but `min_periods: 0`.
    let count_options = RollingOptionsFixedWindow {
        window_size,
        min_periods: 0,
        ..Default::default()
    };

    let product_mean = (x.clone() * y.clone()).rolling_mean(mean_options.clone());
    let x_mean = x.clone().rolling_mean(mean_options.clone());
    let y_mean = y.clone().rolling_mean(mean_options);

    // Rolling count of rows where `x + y` is non-null, as a float.
    let pair_present = (x + y).is_not_null();
    let pair_count = pair_present
        .cast(DataType::Float64)
        .rolling_sum(count_options);

    // n / (n - ddof) correction applied to the population covariance.
    let correction = pair_count.clone() / (pair_count - lit(ddof));

    (product_mean - x_mean * y_mean) * correction
}