ndarray_stats/deviation.rs
1use ndarray::{ArrayBase, Data, Dimension, Zip};
2use num_traits::{Signed, ToPrimitive};
3use std::convert::Into;
4use std::ops::AddAssign;
5
6use crate::errors::MultiInputError;
7
8/// An extension trait for `ArrayBase` providing functions
9/// to compute different deviation measures.
10pub trait DeviationExt<A, S, D>
11where
12 S: Data<Elem = A>,
13 D: Dimension,
14{
15 /// Counts the number of indices at which the elements of the arrays `self`
16 /// and `other` are equal.
17 ///
18 /// The following **errors** may be returned:
19 ///
20 /// * `MultiInputError::EmptyInput` if `self` is empty
21 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
22 fn count_eq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
23 where
24 A: PartialEq,
25 T: Data<Elem = A>;
26
27 /// Counts the number of indices at which the elements of the arrays `self`
28 /// and `other` are not equal.
29 ///
30 /// The following **errors** may be returned:
31 ///
32 /// * `MultiInputError::EmptyInput` if `self` is empty
33 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
34 fn count_neq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
35 where
36 A: PartialEq,
37 T: Data<Elem = A>;
38
39 /// Computes the [squared L2 distance] between `self` and `other`.
40 ///
41 /// ```text
42 /// n
43 /// ∑ |aᵢ - bᵢ|²
44 /// i=1
45 /// ```
46 ///
47 /// where `self` is `a` and `other` is `b`.
48 ///
49 /// The following **errors** may be returned:
50 ///
51 /// * `MultiInputError::EmptyInput` if `self` is empty
52 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
53 ///
54 /// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance
55 fn sq_l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
56 where
57 A: AddAssign + Clone + Signed,
58 T: Data<Elem = A>;
59
60 /// Computes the [L2 distance] between `self` and `other`.
61 ///
62 /// ```text
63 /// n
64 /// √ ( ∑ |aᵢ - bᵢ|² )
65 /// i=1
66 /// ```
67 ///
68 /// where `self` is `a` and `other` is `b`.
69 ///
70 /// The following **errors** may be returned:
71 ///
72 /// * `MultiInputError::EmptyInput` if `self` is empty
73 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
74 ///
75 /// **Panics** if the type cast from `A` to `f64` fails.
76 ///
77 /// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance
78 fn l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
79 where
80 A: AddAssign + Clone + Signed + ToPrimitive,
81 T: Data<Elem = A>;
82
83 /// Computes the [L1 distance] between `self` and `other`.
84 ///
85 /// ```text
86 /// n
87 /// ∑ |aᵢ - bᵢ|
88 /// i=1
89 /// ```
90 ///
91 /// where `self` is `a` and `other` is `b`.
92 ///
93 /// The following **errors** may be returned:
94 ///
95 /// * `MultiInputError::EmptyInput` if `self` is empty
96 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
97 ///
98 /// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry
99 fn l1_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
100 where
101 A: AddAssign + Clone + Signed,
102 T: Data<Elem = A>;
103
104 /// Computes the [L∞ distance] between `self` and `other`.
105 ///
106 /// ```text
107 /// max(|aᵢ - bᵢ|)
108 /// ᵢ
109 /// ```
110 ///
111 /// where `self` is `a` and `other` is `b`.
112 ///
113 /// The following **errors** may be returned:
114 ///
115 /// * `MultiInputError::EmptyInput` if `self` is empty
116 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
117 ///
118 /// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance
119 fn linf_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
120 where
121 A: Clone + PartialOrd + Signed,
122 T: Data<Elem = A>;
123
124 /// Computes the [mean absolute error] between `self` and `other`.
125 ///
126 /// ```text
127 /// n
128 /// 1/n * ∑ |aᵢ - bᵢ|
129 /// i=1
130 /// ```
131 ///
132 /// where `self` is `a` and `other` is `b`.
133 ///
134 /// The following **errors** may be returned:
135 ///
136 /// * `MultiInputError::EmptyInput` if `self` is empty
137 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
138 ///
139 /// **Panics** if the type cast from `A` to `f64` fails.
140 ///
141 /// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error
142 fn mean_abs_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
143 where
144 A: AddAssign + Clone + Signed + ToPrimitive,
145 T: Data<Elem = A>;
146
147 /// Computes the [mean squared error] between `self` and `other`.
148 ///
149 /// ```text
150 /// n
151 /// 1/n * ∑ |aᵢ - bᵢ|²
152 /// i=1
153 /// ```
154 ///
155 /// where `self` is `a` and `other` is `b`.
156 ///
157 /// The following **errors** may be returned:
158 ///
159 /// * `MultiInputError::EmptyInput` if `self` is empty
160 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
161 ///
162 /// **Panics** if the type cast from `A` to `f64` fails.
163 ///
164 /// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error
165 fn mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
166 where
167 A: AddAssign + Clone + Signed + ToPrimitive,
168 T: Data<Elem = A>;
169
170 /// Computes the unnormalized [root-mean-square error] between `self` and `other`.
171 ///
172 /// ```text
173 /// √ mse(a, b)
174 /// ```
175 ///
176 /// where `self` is `a`, `other` is `b` and `mse` is the mean-squared-error.
177 ///
178 /// The following **errors** may be returned:
179 ///
180 /// * `MultiInputError::EmptyInput` if `self` is empty
181 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
182 ///
183 /// **Panics** if the type cast from `A` to `f64` fails.
184 ///
185 /// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation
186 fn root_mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
187 where
188 A: AddAssign + Clone + Signed + ToPrimitive,
189 T: Data<Elem = A>;
190
191 /// Computes the [peak signal-to-noise ratio] between `self` and `other`.
192 ///
193 /// ```text
194 /// 10 * log10(maxv^2 / mse(a, b))
195 /// ```
196 ///
197 /// where `self` is `a`, `other` is `b`, `mse` is the mean-squared-error
198 /// and `maxv` is the maximum possible value either array can take.
199 ///
200 /// The following **errors** may be returned:
201 ///
202 /// * `MultiInputError::EmptyInput` if `self` is empty
203 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
204 ///
205 /// **Panics** if the type cast from `A` to `f64` fails.
206 ///
207 /// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
208 fn peak_signal_to_noise_ratio<T>(
209 &self,
210 other: &ArrayBase<T, D>,
211 maxv: A,
212 ) -> Result<f64, MultiInputError>
213 where
214 A: AddAssign + Clone + Signed + ToPrimitive,
215 T: Data<Elem = A>;
216
217 private_decl! {}
218}
219
220impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
221where
222 S: Data<Elem = A>,
223 D: Dimension,
224{
225 fn count_eq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
226 where
227 A: PartialEq,
228 T: Data<Elem = A>,
229 {
230 return_err_if_empty!(self);
231 return_err_unless_same_shape!(self, other);
232
233 let mut count = 0;
234
235 Zip::from(self).and(other).for_each(|a, b| {
236 if a == b {
237 count += 1;
238 }
239 });
240
241 Ok(count)
242 }
243
244 fn count_neq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
245 where
246 A: PartialEq,
247 T: Data<Elem = A>,
248 {
249 self.count_eq(other).map(|n_eq| self.len() - n_eq)
250 }
251
252 fn sq_l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
253 where
254 A: AddAssign + Clone + Signed,
255 T: Data<Elem = A>,
256 {
257 return_err_if_empty!(self);
258 return_err_unless_same_shape!(self, other);
259
260 let mut result = A::zero();
261
262 Zip::from(self).and(other).for_each(|self_i, other_i| {
263 let (a, b) = (self_i.clone(), other_i.clone());
264 let diff = a - b;
265 result += diff.clone() * diff;
266 });
267
268 Ok(result)
269 }
270
271 fn l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
272 where
273 A: AddAssign + Clone + Signed + ToPrimitive,
274 T: Data<Elem = A>,
275 {
276 let sq_l2_dist = self
277 .sq_l2_dist(other)?
278 .to_f64()
279 .expect("failed cast from type A to f64");
280
281 Ok(sq_l2_dist.sqrt())
282 }
283
284 fn l1_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
285 where
286 A: AddAssign + Clone + Signed,
287 T: Data<Elem = A>,
288 {
289 return_err_if_empty!(self);
290 return_err_unless_same_shape!(self, other);
291
292 let mut result = A::zero();
293
294 Zip::from(self).and(other).for_each(|self_i, other_i| {
295 let (a, b) = (self_i.clone(), other_i.clone());
296 result += (a - b).abs();
297 });
298
299 Ok(result)
300 }
301
302 fn linf_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
303 where
304 A: Clone + PartialOrd + Signed,
305 T: Data<Elem = A>,
306 {
307 return_err_if_empty!(self);
308 return_err_unless_same_shape!(self, other);
309
310 let mut max = A::zero();
311
312 Zip::from(self).and(other).for_each(|self_i, other_i| {
313 let (a, b) = (self_i.clone(), other_i.clone());
314 let diff = (a - b).abs();
315 if diff > max {
316 max = diff;
317 }
318 });
319
320 Ok(max)
321 }
322
323 fn mean_abs_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
324 where
325 A: AddAssign + Clone + Signed + ToPrimitive,
326 T: Data<Elem = A>,
327 {
328 let l1_dist = self
329 .l1_dist(other)?
330 .to_f64()
331 .expect("failed cast from type A to f64");
332 let n = self.len() as f64;
333
334 Ok(l1_dist / n)
335 }
336
337 fn mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
338 where
339 A: AddAssign + Clone + Signed + ToPrimitive,
340 T: Data<Elem = A>,
341 {
342 let sq_l2_dist = self
343 .sq_l2_dist(other)?
344 .to_f64()
345 .expect("failed cast from type A to f64");
346 let n = self.len() as f64;
347
348 Ok(sq_l2_dist / n)
349 }
350
351 fn root_mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
352 where
353 A: AddAssign + Clone + Signed + ToPrimitive,
354 T: Data<Elem = A>,
355 {
356 let msd = self.mean_sq_err(other)?;
357 Ok(msd.sqrt())
358 }
359
360 fn peak_signal_to_noise_ratio<T>(
361 &self,
362 other: &ArrayBase<T, D>,
363 maxv: A,
364 ) -> Result<f64, MultiInputError>
365 where
366 A: AddAssign + Clone + Signed + ToPrimitive,
367 T: Data<Elem = A>,
368 {
369 let maxv_f = maxv.to_f64().expect("failed cast from type A to f64");
370 let msd = self.mean_sq_err(&other)?;
371 let psnr = 10. * f64::log10(maxv_f * maxv_f / msd);
372
373 Ok(psnr)
374 }
375
376 private_impl! {}
377}