polars_plan/dsl/function_expr/
rolling.rs

1#[cfg(feature = "cov")]
2use std::ops::BitAnd;
3
4use polars_core::utils::Container;
5use polars_time::chunkedarray::*;
6
7use super::*;
8#[cfg(feature = "cov")]
9use crate::dsl::pow::pow;
10
/// The rolling-window functions that can appear in a function expression.
///
/// Most variants only carry the fixed-window options that are forwarded to
/// the matching rolling kernel. Each variant is displayed under its
/// `rolling_*` name (see the `Display` impl).
#[derive(Clone, PartialEq, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum RollingFunction {
    /// Rolling minimum (`rolling_min`).
    Min(RollingOptionsFixedWindow),
    /// Rolling maximum (`rolling_max`).
    Max(RollingOptionsFixedWindow),
    /// Rolling mean (`rolling_mean`).
    Mean(RollingOptionsFixedWindow),
    /// Rolling sum (`rolling_sum`).
    Sum(RollingOptionsFixedWindow),
    /// Rolling quantile (`rolling_quantile`).
    Quantile(RollingOptionsFixedWindow),
    /// Rolling variance (`rolling_var`).
    Var(RollingOptionsFixedWindow),
    /// Rolling standard deviation (`rolling_std`).
    Std(RollingOptionsFixedWindow),
    /// Rolling skewness (`rolling_skew`); the fields are the window size and
    /// the `bias` flag, in that order. Requires the `moment` feature.
    #[cfg(feature = "moment")]
    Skew(usize, bool),
    /// Rolling correlation or covariance of two inputs. Requires the `cov`
    /// feature.
    #[cfg(feature = "cov")]
    CorrCov {
        /// Fixed-window options used for the rolling kernels.
        rolling_options: RollingOptionsFixedWindow,
        /// Covariance-specific options (e.g. `ddof`).
        corr_cov_options: RollingCovOptions,
        /// `true` selects correlation (`rolling_corr`), `false` selects
        /// covariance (`rolling_cov`).
        is_corr: bool,
    },
}
31
32impl Display for RollingFunction {
33    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34        use RollingFunction::*;
35
36        let name = match self {
37            Min(_) => "rolling_min",
38            Max(_) => "rolling_max",
39            Mean(_) => "rolling_mean",
40            Sum(_) => "rolling_sum",
41            Quantile(_) => "rolling_quantile",
42            Var(_) => "rolling_var",
43            Std(_) => "rolling_std",
44            #[cfg(feature = "moment")]
45            Skew(..) => "rolling_skew",
46            #[cfg(feature = "cov")]
47            CorrCov { is_corr, .. } => {
48                if *is_corr {
49                    "rolling_corr"
50                } else {
51                    "rolling_cov"
52                }
53            },
54        };
55
56        write!(f, "{name}")
57    }
58}
59
60impl Hash for RollingFunction {
61    fn hash<H: Hasher>(&self, state: &mut H) {
62        use RollingFunction::*;
63
64        std::mem::discriminant(self).hash(state);
65        match self {
66            #[cfg(feature = "moment")]
67            Skew(window_size, bias) => {
68                window_size.hash(state);
69                bias.hash(state)
70            },
71            #[cfg(feature = "cov")]
72            CorrCov { is_corr, .. } => {
73                is_corr.hash(state);
74            },
75            _ => {},
76        }
77    }
78}
79
80pub(super) fn rolling_min(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
81    // @scalar-opt
82    s.as_materialized_series()
83        .rolling_min(options)
84        .map(Column::from)
85}
86
87pub(super) fn rolling_max(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
88    // @scalar-opt
89    s.as_materialized_series()
90        .rolling_max(options)
91        .map(Column::from)
92}
93
94pub(super) fn rolling_mean(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
95    // @scalar-opt
96    s.as_materialized_series()
97        .rolling_mean(options)
98        .map(Column::from)
99}
100
101pub(super) fn rolling_sum(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
102    // @scalar-opt
103    s.as_materialized_series()
104        .rolling_sum(options)
105        .map(Column::from)
106}
107
108pub(super) fn rolling_quantile(
109    s: &Column,
110    options: RollingOptionsFixedWindow,
111) -> PolarsResult<Column> {
112    // @scalar-opt
113    s.as_materialized_series()
114        .rolling_quantile(options)
115        .map(Column::from)
116}
117
118pub(super) fn rolling_var(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
119    // @scalar-opt
120    s.as_materialized_series()
121        .rolling_var(options)
122        .map(Column::from)
123}
124
125pub(super) fn rolling_std(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
126    // @scalar-opt
127    s.as_materialized_series()
128        .rolling_std(options)
129        .map(Column::from)
130}
131
/// Computes the rolling skewness of `s` over windows of `window_size` rows.
///
/// `bias` is forwarded unchanged to the series-level `rolling_skew`. The
/// result is wrapped back into a `Column`; errors are returned unchanged.
#[cfg(feature = "moment")]
pub(super) fn rolling_skew(s: &Column, window_size: usize, bias: bool) -> PolarsResult<Column> {
    // @scalar-opt
    let series = s.as_materialized_series();
    let out = series.rolling_skew(window_size, bias)?;
    Ok(Column::from(out))
}
139
/// Builds the per-row window counts used by `rolling_corr_cov` when neither
/// input contains nulls.
///
/// Row `i` (zero-based) receives `min(window_size, i + 1)`, stored as a float
/// of the requested `dtype`. The returned series has `len` rows and an empty
/// name.
///
/// # Panics
/// Hits `unreachable!()` when `dtype` is neither `Float64` nor `Float32`.
#[cfg(feature = "cov")]
fn det_count_x_y(window_size: usize, len: usize, dtype: &DataType) -> Series {
    // Lazy iterator shared by both float widths; nothing runs until collected.
    let counts = (0..len).map(|row| window_size.min(row + 1));

    match dtype {
        DataType::Float64 => {
            let as_f64 = counts.map(|c| c as f64).collect::<Vec<_>>();
            Series::new(PlSmallStr::EMPTY, as_f64)
        },
        DataType::Float32 => {
            let as_f32 = counts.map(|c| c as f32).collect::<Vec<_>>();
            Series::new(PlSmallStr::EMPTY, as_f32)
        },
        _ => unreachable!(),
    }
}
158
/// Computes the rolling covariance of `s[0]` and `s[1]`, or their rolling
/// correlation when `is_corr` is `true`.
///
/// Inputs that are not floating point are cast to `Float64`. When either input
/// contains nulls, rows where either side is null are masked out on both sides
/// before the per-window means and variances are taken.
///
/// The covariance is `(mean(x * y) - mean(x) * mean(y)) * n / (n - ddof)`,
/// where `n` is the per-window count of valid pairs. The correlation divides
/// that by `(var(x) * var(y)) ^ 0.5`.
///
/// # Errors
/// Propagates any error raised by the casts, the rolling kernels, the series
/// arithmetic or `pow`.
///
/// # Panics
/// Panics if `s` holds fewer than two columns, or if `pow` yields no column.
#[cfg(feature = "cov")]
pub(super) fn rolling_corr_cov(
    s: &[Column],
    rolling_options: RollingOptionsFixedWindow,
    cov_options: RollingCovOptions,
    is_corr: bool,
) -> PolarsResult<Column> {
    let mut x = s[0].as_materialized_series().rechunk();
    let mut y = s[1].as_materialized_series().rechunk();

    if !x.dtype().is_float() {
        x = x.cast(&DataType::Float64)?;
    }
    if !y.dtype().is_float() {
        y = y.cast(&DataType::Float64)?;
    }
    let dtype = x.dtype().clone();

    // Taken before the validity masks are rewritten below.
    let mean_x_y = (&x * &y)?.rolling_mean(rolling_options.clone())?;

    let count_x_y = if (x.null_count() + y.null_count()) > 0 {
        // Mask out nulls on both sides before computing mean/var, so that both
        // inputs contribute exactly the same rows to every window.
        let valids = x.is_not_null().bitand(y.is_not_null());
        let valids_arr = valids.clone().downcast_into_array();
        let valids_bitmap = valids_arr.values();

        // SAFETY: both series were rechunked above, so index 0 addresses their
        // (presumably single) chunk; only the validity of that chunk is
        // replaced and the cached lengths are recomputed right after.
        // NOTE(review): confirm this covers the full `chunks_mut` contract.
        unsafe {
            let xarr = &mut x.chunks_mut()[0];
            *xarr = xarr.with_validity(Some(valids_bitmap.clone()));
            let yarr = &mut y.chunks_mut()[0];
            *yarr = yarr.with_validity(Some(valids_bitmap.clone()));
            x.compute_len();
            y.compute_len();
        }

        // Count the valid pairs per window; `min_periods: 0` keeps a count for
        // every row.
        let rolling_options_count = RollingOptionsFixedWindow {
            window_size: rolling_options.window_size,
            min_periods: 0,
            ..Default::default()
        };
        valids.cast(&dtype)?.rolling_sum(rolling_options_count)?
    } else {
        det_count_x_y(rolling_options.window_size, x.len(), &dtype)
    };

    let mean_x = x.rolling_mean(rolling_options.clone())?;
    let mean_y = y.rolling_mean(rolling_options.clone())?;
    let ddof = Series::new(
        PlSmallStr::EMPTY,
        &[AnyValue::from(cov_options.ddof).cast(&dtype)],
    );

    // E[xy] - E[x]E[y], rescaled by n / (n - ddof).
    let raw_cov = (mean_x_y - (mean_x * mean_y)?)?;
    let ddof_scale = (count_x_y.clone() / (count_x_y - ddof)?)?;
    let numerator = (raw_cov * ddof_scale)?;

    if !is_corr {
        return Ok(numerator.into_column());
    }

    let var_x = x.rolling_var(rolling_options.clone())?;
    let var_y = y.rolling_var(rolling_options)?;

    let base = (var_x * var_y)?;
    // Cast the exponent to the dtype the scalar is declared with, so value and
    // dtype agree even when `x` and `y` are different float types.
    let base_dtype = base.dtype().clone();
    let exponent = AnyValue::Float64(0.5).cast(&base_dtype).into_static();
    let sc = Scalar::new(base_dtype, exponent);
    let denominator = pow(&mut [base.into_column(), sc.into_column("".into())])?
        .expect("`pow` did not produce an output column")
        .take_materialized_series();

    Ok((numerator / denominator)?.into_column())
}