use super::*;
#[inline]
fn new_is_min<T: NativeType + IsFloat + PartialOrd>(old: &T, new: &T) -> bool {
compare_fn_nan_min(old, new).is_ge()
}
#[inline]
fn new_is_max<T: NativeType + IsFloat + PartialOrd>(old: &T, new: &T) -> bool {
compare_fn_nan_max(old, new).is_le()
}
#[inline]
unsafe fn get_min_and_idx<T>(
slice: &[T],
start: usize,
end: usize,
sorted_to: usize,
) -> Option<(usize, &T)>
where
T: NativeType + IsFloat + PartialOrd,
{
if sorted_to >= end {
Some((start, slice.get_unchecked(start)))
} else if sorted_to <= start {
slice
.get_unchecked(start..end)
.iter()
.enumerate()
.rev()
.min_by(|&a, &b| compare_fn_nan_min(a.1, b.1))
.map(|v| (v.0 + start, v.1))
} else {
let s = (start, slice.get_unchecked(start));
slice
.get_unchecked(sorted_to..end)
.iter()
.enumerate()
.rev()
.min_by(|&a, &b| compare_fn_nan_min(a.1, b.1))
.map(|v| {
if new_is_min(s.1, v.1) {
(v.0 + sorted_to, v.1)
} else {
s
}
})
}
}
#[inline]
unsafe fn get_max_and_idx<T>(
slice: &[T],
start: usize,
end: usize,
sorted_to: usize,
) -> Option<(usize, &T)>
where
T: NativeType + IsFloat + PartialOrd,
{
if sorted_to >= end {
Some((start, slice.get_unchecked(start)))
} else if sorted_to <= start {
slice
.get_unchecked(start..end)
.iter()
.enumerate()
.max_by(|&a, &b| compare_fn_nan_max(a.1, b.1))
.map(|v| (v.0 + start, v.1))
} else {
let s = (start, slice.get_unchecked(start));
slice
.get_unchecked(sorted_to..end)
.iter()
.enumerate()
.max_by(|&a, &b| compare_fn_nan_max(a.1, b.1))
.map(|v| {
if new_is_max(s.1, v.1) {
(v.0 + sorted_to, v.1)
} else {
s
}
})
}
}
#[inline]
fn n_sorted_past_min<T: NativeType + IsFloat + PartialOrd>(slice: &[T]) -> usize {
slice
.windows(2)
.position(|x| compare_fn_nan_min(&x[0], &x[1]).is_gt())
.unwrap_or(slice.len() - 1)
}
#[inline]
fn n_sorted_past_max<T: NativeType + IsFloat + PartialOrd>(slice: &[T]) -> usize {
slice
.windows(2)
.position(|x| compare_fn_nan_max(&x[0], &x[1]).is_lt())
.unwrap_or(slice.len() - 1)
}
macro_rules! minmax_window {
($m_window:tt, $get_m_and_idx:ident, $new_is_m:ident, $n_sorted_past:ident) => {
pub struct $m_window<'a, T: NativeType + PartialOrd + IsFloat> {
slice: &'a [T],
m: T,
m_idx: usize,
sorted_to: usize,
last_start: usize,
last_end: usize,
}
impl<'a, T: NativeType + IsFloat + PartialOrd> $m_window<'a, T> {
#[inline]
unsafe fn update_m_and_m_idx(&mut self, m_and_idx: (usize, &T)) {
self.m = *m_and_idx.1;
self.m_idx = m_and_idx.0;
if self.sorted_to <= self.m_idx {
self.sorted_to =
self.m_idx + 1 + $n_sorted_past(&self.slice.get_unchecked(self.m_idx..));
}
}
}
impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNoNulls<'a, T>
for $m_window<'a, T>
{
fn new(
slice: &'a [T],
start: usize,
end: usize,
_params: Option<RollingFnParams>,
) -> Self {
let (idx, m) =
unsafe { $get_m_and_idx(slice, start, end, 0).unwrap_or((0, &slice[start])) };
Self {
slice,
m: *m,
m_idx: idx,
sorted_to: idx + 1 + $n_sorted_past(&slice[idx..]),
last_start: start,
last_end: end,
}
}
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
self.last_start = start; let old_last_end = self.last_end; self.last_end = end;
let entering_start = std::cmp::max(old_last_end, start);
let entering = if end - entering_start == 1 {
Some((entering_start, self.slice.get_unchecked(entering_start)))
} else if old_last_end == end {
None
} else {
$get_m_and_idx(self.slice, entering_start, end, self.sorted_to)
};
let empty_overlap = old_last_end <= start;
if entering.map(|em| $new_is_m(&self.m, em.1) || empty_overlap) == Some(true) {
self.update_m_and_m_idx(entering.unwrap());
return Some(self.m);
} else if self.m_idx >= start || empty_overlap {
return Some(self.m);
}
match (
$get_m_and_idx(self.slice, start, old_last_end, self.sorted_to),
entering,
) {
(Some(pm), Some(em)) => {
if $new_is_m(pm.1, em.1) {
self.update_m_and_m_idx(em);
} else {
self.update_m_and_m_idx(pm);
}
},
(Some(pm), None) => self.update_m_and_m_idx(pm),
(None, Some(em)) => self.update_m_and_m_idx(em),
(None, None) => unreachable!(),
}
Some(self.m)
}
}
};
}
minmax_window!(MinWindow, get_min_and_idx, new_is_min, n_sorted_past_min);
minmax_window!(MaxWindow, get_max_and_idx, new_is_max, n_sorted_past_max);
pub(crate) fn compute_min_weights<T>(values: &[T], weights: &[T]) -> T
where
T: NativeType + PartialOrd + std::ops::Mul<Output = T>,
{
values
.iter()
.zip(weights)
.map(|(v, w)| *v * *w)
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
}
pub(crate) fn compute_max_weights<T>(values: &[T], weights: &[T]) -> T
where
T: NativeType + PartialOrd + IsFloat + Bounded + Mul<Output = T>,
{
let mut max = T::min_value();
for v in values.iter().zip(weights).map(|(v, w)| *v * *w) {
if T::is_float() && v.is_nan() {
return v;
}
if v > max {
max = v
}
}
max
}
macro_rules! rolling_minmax_func {
($rolling_m:ident, $window:tt, $wtd_f:ident) => {
pub fn $rolling_m<T>(
values: &[T],
window_size: usize,
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
_params: Option<RollingFnParams>,
) -> PolarsResult<ArrayRef>
where
T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul<Output = T> + Num,
{
let offset_fn = match center {
true => det_offsets_center,
false => det_offsets,
};
match weights {
None => rolling_apply_agg_window::<$window<_>, _, _>(
values,
window_size,
min_periods,
offset_fn,
None,
),
Some(weights) => {
assert!(
T::is_float(),
"implementation error, should only be reachable by float types"
);
let weights = weights
.iter()
.map(|v| NumCast::from(*v).unwrap())
.collect::<Vec<_>>();
no_nulls::rolling_apply_weights(
values,
window_size,
min_periods,
offset_fn,
$wtd_f,
&weights,
)
},
}
}
};
}
rolling_minmax_func!(rolling_min, MinWindow, compute_min_weights);
rolling_minmax_func!(rolling_max, MaxWindow, compute_max_weights);
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_rolling_min_max() {
let values = &[1.0f64, 5.0, 3.0, 4.0];
let out = rolling_min(values, 2, 2, false, None, None).unwrap();
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, Some(1.0), Some(3.0), Some(3.0)]);
let out = rolling_max(values, 2, 2, false, None, None).unwrap();
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, Some(5.0), Some(5.0), Some(4.0)]);
let out = rolling_min(values, 2, 1, false, None, None).unwrap();
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(3.0)]);
let out = rolling_max(values, 2, 1, false, None, None).unwrap();
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.0), Some(5.0), Some(5.0), Some(4.0)]);
let out = rolling_max(values, 3, 1, false, None, None).unwrap();
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.0), Some(5.0), Some(5.0), Some(5.0)]);
let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
let out = rolling_min(values, 3, 3, false, None, None).unwrap();
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(
format!("{:?}", out.as_slice()),
format!(
"{:?}",
&[
None,
None,
Some(1.0),
Some(f64::nan()),
Some(f64::nan()),
Some(f64::nan()),
Some(5.0)
]
)
);
let out = rolling_max(values, 3, 3, false, None, None).unwrap();
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(
format!("{:?}", out.as_slice()),
format!(
"{:?}",
&[
None,
None,
Some(3.0),
Some(f64::nan()),
Some(f64::nan()),
Some(f64::nan()),
Some(7.0)
]
)
);
}
}