// ndarray_stats/deviation.rs

1use ndarray::{ArrayBase, Data, Dimension, Zip};
2use num_traits::{Signed, ToPrimitive};
3use std::convert::Into;
4use std::ops::AddAssign;
5
6use crate::errors::MultiInputError;
7
8/// An extension trait for `ArrayBase` providing functions
9/// to compute different deviation measures.
10pub trait DeviationExt<A, S, D>
11where
12    S: Data<Elem = A>,
13    D: Dimension,
14{
15    /// Counts the number of indices at which the elements of the arrays `self`
16    /// and `other` are equal.
17    ///
18    /// The following **errors** may be returned:
19    ///
20    /// * `MultiInputError::EmptyInput` if `self` is empty
21    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
22    fn count_eq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
23    where
24        A: PartialEq,
25        T: Data<Elem = A>;
26
27    /// Counts the number of indices at which the elements of the arrays `self`
28    /// and `other` are not equal.
29    ///
30    /// The following **errors** may be returned:
31    ///
32    /// * `MultiInputError::EmptyInput` if `self` is empty
33    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
34    fn count_neq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
35    where
36        A: PartialEq,
37        T: Data<Elem = A>;
38
39    /// Computes the [squared L2 distance] between `self` and `other`.
40    ///
41    /// ```text
42    ///  n
43    ///  ∑  |aᵢ - bᵢ|²
44    /// i=1
45    /// ```
46    ///
47    /// where `self` is `a` and `other` is `b`.
48    ///
49    /// The following **errors** may be returned:
50    ///
51    /// * `MultiInputError::EmptyInput` if `self` is empty
52    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
53    ///
54    /// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance
55    fn sq_l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
56    where
57        A: AddAssign + Clone + Signed,
58        T: Data<Elem = A>;
59
60    /// Computes the [L2 distance] between `self` and `other`.
61    ///
62    /// ```text
63    ///      n
64    /// √ (  ∑  |aᵢ - bᵢ|² )
65    ///     i=1
66    /// ```
67    ///
68    /// where `self` is `a` and `other` is `b`.
69    ///
70    /// The following **errors** may be returned:
71    ///
72    /// * `MultiInputError::EmptyInput` if `self` is empty
73    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
74    ///
75    /// **Panics** if the type cast from `A` to `f64` fails.
76    ///
77    /// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance
78    fn l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
79    where
80        A: AddAssign + Clone + Signed + ToPrimitive,
81        T: Data<Elem = A>;
82
83    /// Computes the [L1 distance] between `self` and `other`.
84    ///
85    /// ```text
86    ///  n
87    ///  ∑  |aᵢ - bᵢ|
88    /// i=1
89    /// ```
90    ///
91    /// where `self` is `a` and `other` is `b`.
92    ///
93    /// The following **errors** may be returned:
94    ///
95    /// * `MultiInputError::EmptyInput` if `self` is empty
96    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
97    ///
98    /// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry
99    fn l1_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
100    where
101        A: AddAssign + Clone + Signed,
102        T: Data<Elem = A>;
103
104    /// Computes the [L∞ distance] between `self` and `other`.
105    ///
106    /// ```text
107    /// max(|aᵢ - bᵢ|)
108    ///  ᵢ
109    /// ```
110    ///
111    /// where `self` is `a` and `other` is `b`.
112    ///
113    /// The following **errors** may be returned:
114    ///
115    /// * `MultiInputError::EmptyInput` if `self` is empty
116    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
117    ///
118    /// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance
119    fn linf_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
120    where
121        A: Clone + PartialOrd + Signed,
122        T: Data<Elem = A>;
123
124    /// Computes the [mean absolute error] between `self` and `other`.
125    ///
126    /// ```text
127    ///        n
128    /// 1/n *  ∑  |aᵢ - bᵢ|
129    ///       i=1
130    /// ```
131    ///
132    /// where `self` is `a` and `other` is `b`.
133    ///
134    /// The following **errors** may be returned:
135    ///
136    /// * `MultiInputError::EmptyInput` if `self` is empty
137    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
138    ///
139    /// **Panics** if the type cast from `A` to `f64` fails.
140    ///
141    /// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error
142    fn mean_abs_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
143    where
144        A: AddAssign + Clone + Signed + ToPrimitive,
145        T: Data<Elem = A>;
146
147    /// Computes the [mean squared error] between `self` and `other`.
148    ///
149    /// ```text
150    ///        n
151    /// 1/n *  ∑  |aᵢ - bᵢ|²
152    ///       i=1
153    /// ```
154    ///
155    /// where `self` is `a` and `other` is `b`.
156    ///
157    /// The following **errors** may be returned:
158    ///
159    /// * `MultiInputError::EmptyInput` if `self` is empty
160    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
161    ///
162    /// **Panics** if the type cast from `A` to `f64` fails.
163    ///
164    /// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error
165    fn mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
166    where
167        A: AddAssign + Clone + Signed + ToPrimitive,
168        T: Data<Elem = A>;
169
170    /// Computes the unnormalized [root-mean-square error] between `self` and `other`.
171    ///
172    /// ```text
173    /// √ mse(a, b)
174    /// ```
175    ///
176    /// where `self` is `a`, `other` is `b` and `mse` is the mean-squared-error.
177    ///
178    /// The following **errors** may be returned:
179    ///
180    /// * `MultiInputError::EmptyInput` if `self` is empty
181    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
182    ///
183    /// **Panics** if the type cast from `A` to `f64` fails.
184    ///
185    /// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation
186    fn root_mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
187    where
188        A: AddAssign + Clone + Signed + ToPrimitive,
189        T: Data<Elem = A>;
190
191    /// Computes the [peak signal-to-noise ratio] between `self` and `other`.
192    ///
193    /// ```text
194    /// 10 * log10(maxv^2 / mse(a, b))
195    /// ```
196    ///
197    /// where `self` is `a`, `other` is `b`, `mse` is the mean-squared-error
198    /// and `maxv` is the maximum possible value either array can take.
199    ///
200    /// The following **errors** may be returned:
201    ///
202    /// * `MultiInputError::EmptyInput` if `self` is empty
203    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
204    ///
205    /// **Panics** if the type cast from `A` to `f64` fails.
206    ///
207    /// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
208    fn peak_signal_to_noise_ratio<T>(
209        &self,
210        other: &ArrayBase<T, D>,
211        maxv: A,
212    ) -> Result<f64, MultiInputError>
213    where
214        A: AddAssign + Clone + Signed + ToPrimitive,
215        T: Data<Elem = A>;
216
217    private_decl! {}
218}
219
220impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
221where
222    S: Data<Elem = A>,
223    D: Dimension,
224{
225    fn count_eq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
226    where
227        A: PartialEq,
228        T: Data<Elem = A>,
229    {
230        return_err_if_empty!(self);
231        return_err_unless_same_shape!(self, other);
232
233        let mut count = 0;
234
235        Zip::from(self).and(other).for_each(|a, b| {
236            if a == b {
237                count += 1;
238            }
239        });
240
241        Ok(count)
242    }
243
244    fn count_neq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
245    where
246        A: PartialEq,
247        T: Data<Elem = A>,
248    {
249        self.count_eq(other).map(|n_eq| self.len() - n_eq)
250    }
251
252    fn sq_l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
253    where
254        A: AddAssign + Clone + Signed,
255        T: Data<Elem = A>,
256    {
257        return_err_if_empty!(self);
258        return_err_unless_same_shape!(self, other);
259
260        let mut result = A::zero();
261
262        Zip::from(self).and(other).for_each(|self_i, other_i| {
263            let (a, b) = (self_i.clone(), other_i.clone());
264            let diff = a - b;
265            result += diff.clone() * diff;
266        });
267
268        Ok(result)
269    }
270
271    fn l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
272    where
273        A: AddAssign + Clone + Signed + ToPrimitive,
274        T: Data<Elem = A>,
275    {
276        let sq_l2_dist = self
277            .sq_l2_dist(other)?
278            .to_f64()
279            .expect("failed cast from type A to f64");
280
281        Ok(sq_l2_dist.sqrt())
282    }
283
284    fn l1_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
285    where
286        A: AddAssign + Clone + Signed,
287        T: Data<Elem = A>,
288    {
289        return_err_if_empty!(self);
290        return_err_unless_same_shape!(self, other);
291
292        let mut result = A::zero();
293
294        Zip::from(self).and(other).for_each(|self_i, other_i| {
295            let (a, b) = (self_i.clone(), other_i.clone());
296            result += (a - b).abs();
297        });
298
299        Ok(result)
300    }
301
302    fn linf_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
303    where
304        A: Clone + PartialOrd + Signed,
305        T: Data<Elem = A>,
306    {
307        return_err_if_empty!(self);
308        return_err_unless_same_shape!(self, other);
309
310        let mut max = A::zero();
311
312        Zip::from(self).and(other).for_each(|self_i, other_i| {
313            let (a, b) = (self_i.clone(), other_i.clone());
314            let diff = (a - b).abs();
315            if diff > max {
316                max = diff;
317            }
318        });
319
320        Ok(max)
321    }
322
323    fn mean_abs_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
324    where
325        A: AddAssign + Clone + Signed + ToPrimitive,
326        T: Data<Elem = A>,
327    {
328        let l1_dist = self
329            .l1_dist(other)?
330            .to_f64()
331            .expect("failed cast from type A to f64");
332        let n = self.len() as f64;
333
334        Ok(l1_dist / n)
335    }
336
337    fn mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
338    where
339        A: AddAssign + Clone + Signed + ToPrimitive,
340        T: Data<Elem = A>,
341    {
342        let sq_l2_dist = self
343            .sq_l2_dist(other)?
344            .to_f64()
345            .expect("failed cast from type A to f64");
346        let n = self.len() as f64;
347
348        Ok(sq_l2_dist / n)
349    }
350
351    fn root_mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
352    where
353        A: AddAssign + Clone + Signed + ToPrimitive,
354        T: Data<Elem = A>,
355    {
356        let msd = self.mean_sq_err(other)?;
357        Ok(msd.sqrt())
358    }
359
360    fn peak_signal_to_noise_ratio<T>(
361        &self,
362        other: &ArrayBase<T, D>,
363        maxv: A,
364    ) -> Result<f64, MultiInputError>
365    where
366        A: AddAssign + Clone + Signed + ToPrimitive,
367        T: Data<Elem = A>,
368    {
369        let maxv_f = maxv.to_f64().expect("failed cast from type A to f64");
370        let msd = self.mean_sq_err(&other)?;
371        let psnr = 10. * f64::log10(maxv_f * maxv_f / msd);
372
373        Ok(psnr)
374    }
375
376    private_impl! {}
377}