// polars_plan/dsl/function_expr/rolling.rs
#[cfg(feature = "cov")]
2use std::ops::BitAnd;
3
4use polars_core::utils::Container;
5use polars_time::chunkedarray::*;
6
7use super::*;
8#[cfg(feature = "cov")]
9use crate::dsl::pow::pow;
10
11#[derive(Clone, PartialEq, Debug)]
12#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
13pub enum RollingFunction {
14 Min(RollingOptionsFixedWindow),
15 Max(RollingOptionsFixedWindow),
16 Mean(RollingOptionsFixedWindow),
17 Sum(RollingOptionsFixedWindow),
18 Quantile(RollingOptionsFixedWindow),
19 Var(RollingOptionsFixedWindow),
20 Std(RollingOptionsFixedWindow),
21 #[cfg(feature = "moment")]
22 Skew(usize, bool),
23 #[cfg(feature = "cov")]
24 CorrCov {
25 rolling_options: RollingOptionsFixedWindow,
26 corr_cov_options: RollingCovOptions,
27 is_corr: bool,
29 },
30}
31
32impl Display for RollingFunction {
33 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34 use RollingFunction::*;
35
36 let name = match self {
37 Min(_) => "rolling_min",
38 Max(_) => "rolling_max",
39 Mean(_) => "rolling_mean",
40 Sum(_) => "rolling_sum",
41 Quantile(_) => "rolling_quantile",
42 Var(_) => "rolling_var",
43 Std(_) => "rolling_std",
44 #[cfg(feature = "moment")]
45 Skew(..) => "rolling_skew",
46 #[cfg(feature = "cov")]
47 CorrCov { is_corr, .. } => {
48 if *is_corr {
49 "rolling_corr"
50 } else {
51 "rolling_cov"
52 }
53 },
54 };
55
56 write!(f, "{name}")
57 }
58}
59
60impl Hash for RollingFunction {
61 fn hash<H: Hasher>(&self, state: &mut H) {
62 use RollingFunction::*;
63
64 std::mem::discriminant(self).hash(state);
65 match self {
66 #[cfg(feature = "moment")]
67 Skew(window_size, bias) => {
68 window_size.hash(state);
69 bias.hash(state)
70 },
71 #[cfg(feature = "cov")]
72 CorrCov { is_corr, .. } => {
73 is_corr.hash(state);
74 },
75 _ => {},
76 }
77 }
78}
79
80pub(super) fn rolling_min(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
81 s.as_materialized_series()
83 .rolling_min(options)
84 .map(Column::from)
85}
86
87pub(super) fn rolling_max(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
88 s.as_materialized_series()
90 .rolling_max(options)
91 .map(Column::from)
92}
93
94pub(super) fn rolling_mean(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
95 s.as_materialized_series()
97 .rolling_mean(options)
98 .map(Column::from)
99}
100
101pub(super) fn rolling_sum(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
102 s.as_materialized_series()
104 .rolling_sum(options)
105 .map(Column::from)
106}
107
108pub(super) fn rolling_quantile(
109 s: &Column,
110 options: RollingOptionsFixedWindow,
111) -> PolarsResult<Column> {
112 s.as_materialized_series()
114 .rolling_quantile(options)
115 .map(Column::from)
116}
117
118pub(super) fn rolling_var(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
119 s.as_materialized_series()
121 .rolling_var(options)
122 .map(Column::from)
123}
124
125pub(super) fn rolling_std(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
126 s.as_materialized_series()
128 .rolling_std(options)
129 .map(Column::from)
130}
131
/// Runs `rolling_skew` with the given `window_size` and `bias` flag on the materialized series
/// behind `s`.
///
/// The resulting series is wrapped back into a [`Column`]; any error is returned as-is.
#[cfg(feature = "moment")]
pub(super) fn rolling_skew(s: &Column, window_size: usize, bias: bool) -> PolarsResult<Column> {
    let rolled = s.as_materialized_series().rolling_skew(window_size, bias)?;
    Ok(Column::from(rolled))
}
139
/// Builds the per-row observation count for a null-free rolling window.
///
/// Row `i` (0-based) gets `min(window_size, i + 1)`, stored as a float series of type `dtype`
/// with an empty name.
///
/// # Panics
///
/// Hits `unreachable!()` when `dtype` is neither `Float64` nor `Float32`.
#[cfg(feature = "cov")]
fn det_count_x_y(window_size: usize, len: usize, dtype: &DataType) -> Series {
    // Shared, lazily evaluated counts; each arm only picks the float width.
    let counts = (0..len).map(|row| window_size.min(row + 1));

    match dtype {
        DataType::Float64 => {
            let as_f64: Vec<f64> = counts.map(|n| n as f64).collect();
            Series::new(PlSmallStr::EMPTY, as_f64)
        },
        DataType::Float32 => {
            let as_f32: Vec<f32> = counts.map(|n| n as f32).collect();
            Series::new(PlSmallStr::EMPTY, as_f32)
        },
        _ => unreachable!(),
    }
}
158
/// Computes the rolling covariance or correlation of `s[0]` (x) and `s[1]` (y).
///
/// Both inputs are rechunked and, when not already floating point, cast to `Float64`.
/// The covariance is evaluated as
/// `(mean(x * y) - mean(x) * mean(y)) * n / (n - ddof)`, where every mean is a rolling mean
/// using `rolling_options`, `n` is the rolling count of rows where both inputs are non-null,
/// and `ddof` comes from `cov_options`.
/// When `is_corr` is `true` the covariance is divided by `pow(var(x) * var(y), 0.5)`, with
/// `var` being the rolling variance under the same `rolling_options`.
///
/// # Errors
///
/// Returns the underlying error if a cast, a rolling aggregate, an arithmetic operation
/// between the intermediate series, or `pow` fails.
///
/// # Panics
///
/// Panics if `s` holds fewer than two columns, if the dtype of `x` after the cast step is not
/// one handled by `det_count_x_y`, or if `pow` yields no column.
#[cfg(feature = "cov")]
pub(super) fn rolling_corr_cov(
    s: &[Column],
    rolling_options: RollingOptionsFixedWindow,
    cov_options: RollingCovOptions,
    is_corr: bool,
) -> PolarsResult<Column> {
    // Rechunk so that each series is backed by a single chunk; the validity rewrite below
    // relies on chunk 0 being the whole array.
    let mut x = s[0].as_materialized_series().rechunk();
    let mut y = s[1].as_materialized_series().rechunk();

    if !x.dtype().is_float() {
        x = x.cast(&DataType::Float64)?;
    }
    if !y.dtype().is_float() {
        y = y.cast(&DataType::Float64)?;
    }
    // All helper series (counts, ddof, exponent) are built in the dtype of `x`.
    let dtype = x.dtype().clone();

    // Computed before the validity rewrite: the product is evaluated on the original inputs.
    let mean_x_y = (&x * &y)?.rolling_mean(rolling_options.clone())?;
    // Counting must not be suppressed by `min_periods`, hence 0 here.
    let rolling_options_count = RollingOptionsFixedWindow {
        window_size: rolling_options.window_size,
        min_periods: 0,
        ..Default::default()
    };

    let count_x_y = if (x.null_count() + y.null_count()) > 0 {
        // A row counts only when both inputs are non-null.
        let valids = x.is_not_null().bitand(y.is_not_null());
        let valids_arr = valids.clone().downcast_into_array();
        let valids_bitmap = valids_arr.values();

        // SAFETY: both series were rechunked above, so index 0 is their only chunk. Only the
        // validity of that chunk is replaced (values and dtype are untouched) and the cached
        // lengths are refreshed with `compute_len` afterwards.
        // NOTE(review): assumes `x` and `y` have equal length so the shared bitmap fits both.
        unsafe {
            let xarr = &mut x.chunks_mut()[0];
            *xarr = xarr.with_validity(Some(valids_bitmap.clone()));
            let yarr = &mut y.chunks_mut()[0];
            *yarr = yarr.with_validity(Some(valids_bitmap.clone()));
            x.compute_len();
            y.compute_len();
        }
        // Rolling sum of the 0/1 validity mask gives the pairwise-valid count per window.
        valids.cast(&dtype)?.rolling_sum(rolling_options_count)?
    } else {
        // No nulls: the count is fully determined by the row position and window size.
        det_count_x_y(rolling_options.window_size, x.len(), &dtype)
    };

    let mean_x = x.rolling_mean(rolling_options.clone())?;
    let mean_y = y.rolling_mean(rolling_options.clone())?;
    let ddof = Series::new(
        PlSmallStr::EMPTY,
        &[AnyValue::from(cov_options.ddof).cast(&dtype)],
    );

    // (E[xy] - E[x] * E[y]) * n / (n - ddof); arithmetic errors are propagated, not unwrapped.
    let numerator = ((mean_x_y - (mean_x * mean_y)?)?
        * (count_x_y.clone() / (count_x_y - ddof)?)?)?;

    if is_corr {
        let var_x = x.rolling_var(rolling_options.clone())?;
        // Last use of `rolling_options`, so it can be moved instead of cloned.
        let var_y = y.rolling_var(rolling_options)?;

        let base = (var_x * var_y)?;
        // Exponent 0.5, expressed in the working dtype.
        let sc = Scalar::new(
            base.dtype().clone(),
            AnyValue::Float64(0.5).cast(&dtype).into_static(),
        );
        let denominator = pow(&mut [base.into_column(), sc.into_column("".into())])?
            .expect("`pow` returned no column")
            .take_materialized_series();

        Ok((numerator / denominator)?.into_column())
    } else {
        Ok(numerator.into_column())
    }
}