1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
mod checked;
use std::marker::PhantomData;

pub use checked::*;

mod ordered;
pub use ordered::*;

use crate::{
    error::Fallible,
    traits::{Float, InfAdd},
};

/// Marker type selecting sequential (recursive) summation.
///
/// Carries no data: the `PhantomData<T>` field only ties the marker to the float
/// type `T`. Its `SumRelaxation` impl supplies the rounding-error bound used for
/// this summation strategy.
pub struct Sequential<T>(PhantomData<T>);

/// Marker type selecting pairwise (cascading) summation.
///
/// Carries no data: the `PhantomData<T>` field only ties the marker to the float
/// type `T`. Its `SumRelaxation` impl supplies the rounding-error bound used for
/// this summation strategy.
pub struct Pairwise<T>(PhantomData<T>);

#[doc(hidden)]
/// Rounding-error bounds for a particular floating-point summation strategy.
///
/// Implemented by the strategy markers (e.g. `Sequential<T>`, `Pairwise<T>`),
/// each of which fixes `Item` to its float type.
pub trait SumRelaxation {
    /// The floating-point type being summed.
    type Item: Float;

    /// Upper bound on the rounding error incurred when summing `size` values,
    /// each lying in `[lower, upper]`, with this strategy.
    ///
    /// # Errors
    /// Implementations return an error when a cast or a fallible arithmetic step fails.
    fn error(size: usize, lower: Self::Item, upper: Self::Item) -> Fallible<Self::Item>;

    /// Relaxation term: the `error` bound added to itself via `inf_add`.
    ///
    /// # Errors
    /// Fails if `error` fails, or if the doubling `inf_add` fails.
    fn relaxation(size: usize, lower: Self::Item, upper: Self::Item) -> Fallible<Self::Item> {
        Self::error(size, lower, upper).and_then(|bound| bound.inf_add(&bound))
    }
}

impl<T: Float> SumRelaxation for Sequential<T> {
    type Item = T;

    /// Rounding-error bound for summing `size` values in `[lower, upper]` one after another.
    ///
    /// Evaluates `n^2 / 2^(k - 1) * max(|L|, U)`, where `k - 1 = T::MANTISSA_BITS`,
    /// using the crate's `inf_*` arithmetic helpers.
    ///
    /// # Errors
    /// Propagates any failure from the integer-to-float casts or the arithmetic helpers.
    fn error(size: usize, lower: Self::Item, upper: Self::Item) -> Fallible<Self::Item> {
        let n = T::exact_int_cast(size)?;
        let mantissa_bits = T::exact_int_cast(T::MANTISSA_BITS)?;
        let two = T::exact_int_cast(2)?;

        // The steps run in the same order as the equivalent single method chain,
        // so whichever step fails first (and its error) is unchanged.
        let n_squared = n.inf_mul(&n)?;
        let denominator = two.inf_pow(&mantissa_bits)?;
        let growth = n_squared.inf_div(&denominator)?;
        // max(|L|, U) is the largest magnitude any summand can have, given L <= U.
        let magnitude = lower.alerting_abs()?.total_max(upper)?;
        growth.inf_mul(&magnitude)
    }
}

impl<T: Float> SumRelaxation for Pairwise<T> {
    type Item = T;

    /// Rounding-error bound for pairwise (cascading) summation of `size` values
    /// in `[lower, upper]`.
    ///
    /// Evaluates `(uk / (1 - uk)) * n * max(|L|, U)`, where `u = 2^-MANTISSA_BITS`
    /// and `k = log_2(n)`.
    ///
    /// An empty sum (`size == 0`) performs no additions, so its error is exactly 0.
    /// The logarithm is therefore taken of `max(n, 1)`: `log2(0) = -inf` would
    /// otherwise turn `uk / (1 - uk)` into `-inf / +inf` (NaN). For every `n >= 1`
    /// the result is unchanged.
    ///
    /// # Errors
    /// Propagates any failure from the integer-to-float casts or the arithmetic helpers.
    fn error(size: usize, lower: Self::Item, upper: Self::Item) -> Fallible<Self::Item> {
        let n = T::exact_int_cast(size)?;
        let mantissa_bits = T::exact_int_cast(T::MANTISSA_BITS)?;
        let _2 = T::exact_int_cast(2)?;

        // k = log_2(max(n, 1)): clamping keeps k finite; for n = 0 the final
        // multiplication by n makes the whole bound 0.
        let k = T::exact_int_cast(size.max(1))?.inf_log2()?;

        // u * k where u = 2^-MANTISSA_BITS
        let uk = k.inf_div(&_2.inf_pow(&mantissa_bits)?)?;

        // (uk / (1 - uk)) n max(|L|, U)
        // The denominator uses `neg_inf_sub`, presumably rounding toward -inf so the
        // quotient stays an upper bound; NOTE(review): confirm against the trait docs.
        uk.inf_div(&T::one().neg_inf_sub(&uk)?)?
            .inf_mul(&n)?
            .inf_mul(&lower.alerting_abs()?.total_max(upper)?)
    }
}