stats/
online.rs

1use std::default::Default;
2use std::fmt;
3use std::iter::{FromIterator, IntoIterator};
4
5use num_traits::ToPrimitive;
6
7use Commute;
8
9/// Compute the standard deviation of a stream in constant space.
10pub fn stddev<I>(it: I) -> f64
11        where I: Iterator, <I as Iterator>::Item: ToPrimitive {
12    it.collect::<OnlineStats>().stddev()
13}
14
15/// Compute the variance of a stream in constant space.
16pub fn variance<I>(it: I) -> f64
17        where I: Iterator, <I as Iterator>::Item: ToPrimitive {
18    it.collect::<OnlineStats>().variance()
19}
20
21/// Compute the mean of a stream in constant space.
22pub fn mean<I>(it: I) -> f64
23        where I: Iterator, <I as Iterator>::Item: ToPrimitive {
24    it.collect::<OnlineStats>().mean()
25}
26
/// Running state for computing mean, variance and standard deviation
/// of a stream without storing the individual samples.
#[derive(Copy, Clone)]
pub struct OnlineStats {
    /// Number of samples folded in so far.
    size: u64,
    /// Mean of the samples seen so far.
    mean: f64,
    /// Variance of the samples seen so far (sum of squared deviations
    /// divided by `size`).
    variance: f64,
}
34
35impl OnlineStats {
36    /// Create initial state.
37    ///
38    /// Population size, variance and mean are set to `0`.
39    pub fn new() -> OnlineStats {
40        Default::default()
41    }
42
43    /// Initializes variance from a sample.
44    pub fn from_slice<T: ToPrimitive>(samples: &[T]) -> OnlineStats {
45        samples.iter().map(|n| n.to_f64().unwrap()).collect()
46    }
47
48    /// Return the current mean.
49    pub fn mean(&self) -> f64 {
50        self.mean
51    }
52
53    /// Return the current standard deviation.
54    pub fn stddev(&self) -> f64 {
55        self.variance.sqrt()
56    }
57
58    /// Return the current variance.
59    pub fn variance(&self) -> f64 {
60        self.variance
61    }
62
63    /// Add a new sample.
64    pub fn add<T: ToPrimitive>(&mut self, sample: T) {
65        let sample = sample.to_f64().unwrap();
66        // Taken from: http://goo.gl/JKeqvj
67        // See also: http://goo.gl/qTtI3V
68        let oldmean = self.mean;
69        let prevq = self.variance * (self.size as f64);
70
71        self.size += 1;
72        self.mean += (sample - oldmean) / (self.size as f64);
73        self.variance = (prevq + (sample - oldmean) * (sample - self.mean))
74                        / (self.size as f64);
75    }
76
77    /// Add a new NULL value to the population.
78    ///
79    /// This increases the population size by `1`.
80    pub fn add_null(&mut self) {
81        self.add(0usize);
82    }
83
84    /// Returns the number of data points.
85    pub fn len(&self) -> usize {
86        self.size as usize
87    }
88}
89
90impl Commute for OnlineStats {
91    fn merge(&mut self, v: OnlineStats) {
92        // Taken from: http://goo.gl/iODi28
93        let (s1, s2) = (self.size as f64, v.size as f64);
94        let meandiffsq = (self.mean - v.mean) * (self.mean - v.mean);
95        let mean = ((s1 * self.mean) + (s2 * v.mean)) / (s1 + s2);
96        let var = (((s1 * self.variance) + (s2 * v.variance))
97                   / (s1 + s2))
98                  +
99                  ((s1 * s2 * meandiffsq) / ((s1 + s2) * (s1 + s2)));
100        self.size += v.size;
101        self.mean = mean;
102        self.variance = var;
103    }
104}
105
106impl Default for OnlineStats {
107    fn default() -> OnlineStats {
108        OnlineStats {
109            size: 0,
110            mean: 0.0,
111            variance: 0.0,
112        }
113    }
114}
115
116impl fmt::Debug for OnlineStats {
117    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
118        write!(f, "{:.10} +/- {:.10}", self.mean(), self.stddev())
119    }
120}
121
122impl<T: ToPrimitive> FromIterator<T> for OnlineStats {
123    fn from_iter<I: IntoIterator<Item=T>>(it: I) -> OnlineStats {
124        let mut v = OnlineStats::new();
125        v.extend(it);
126        v
127    }
128}
129
130impl<T: ToPrimitive> Extend<T> for OnlineStats {
131    fn extend<I: IntoIterator<Item=T>>(&mut self, it: I) {
132        for sample in it {
133            self.add(sample)
134        }
135    }
136}
137
#[cfg(test)]
mod test {
    use {Commute, merge_all};
    use super::OnlineStats;

    /// Merging two partial states must give the same standard deviation
    /// as one sequential pass over the concatenated samples.
    #[test]
    fn stddev() {
        // TODO: Convert this to a quickcheck test.
        let sequential = OnlineStats::from_slice(&[1usize, 2, 3, 2, 4, 6]);

        let mut merged = OnlineStats::from_slice(&[1usize, 2, 3]);
        let second_half = OnlineStats::from_slice(&[2usize, 4, 6]);
        merged.merge(second_half);

        assert_eq!(sequential.stddev(), merged.stddev());
    }

    /// Same check as `stddev`, but merging three partial states through
    /// `merge_all`.
    #[test]
    fn stddev_many() {
        // TODO: Convert this to a quickcheck test.
        let sequential = OnlineStats::from_slice(
            &[1usize, 2, 3, 2, 4, 6, 3, 6, 9]);

        let parts = vec![
            OnlineStats::from_slice(&[1usize, 2, 3]),
            OnlineStats::from_slice(&[2usize, 4, 6]),
            OnlineStats::from_slice(&[3usize, 6, 9]),
        ];
        let merged = merge_all(parts.into_iter()).unwrap();

        assert_eq!(sequential.stddev(), merged.stddev());
    }
}