ndarray_stats/entropy.rs

//! Information theory (e.g. entropy, KL divergence, etc.).
use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
use num_traits::Float;

/// Extension trait for `ArrayBase` providing methods
/// to compute information theory quantities
/// (e.g. entropy, Kullback–Leibler divergence, etc.).
pub trait EntropyExt<A, S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Computes the [entropy] *S* of the array values, defined as
    ///
    /// ```text
    ///       n
    /// S = - ∑ xᵢ ln(xᵢ)
    ///      i=1
    /// ```
    ///
    /// If the array is empty, `Err(EmptyInput)` is returned.
    ///
    /// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`).
    ///
    /// ## Remarks
    ///
    /// The entropy is a measure used in [Information Theory]
    /// to describe a probability distribution: it only makes sense
    /// when the array values sum to 1, with each entry between
    /// 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the entropy, to avoid introducing unnecessary numerical
    /// error (e.g. when the array is already normalised).
    ///
    /// By definition, *xᵢ ln(xᵢ)* is set to 0 if *xᵢ* is 0.
    ///
    /// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory)
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
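    ///
    /// ## Example
    ///
    /// A minimal usage sketch (this example assumes the trait is re-exported at the
    /// crate root as `ndarray_stats::EntropyExt`; the expected value is ln 2 for a
    /// uniform two-point distribution):
    ///
    /// ```
    /// use ndarray::array;
    /// use ndarray_stats::EntropyExt;
    ///
    /// // Uniform distribution over two outcomes: S = -(0.5 ln 0.5 + 0.5 ln 0.5) = ln 2.
    /// let uniform = array![0.5, 0.5];
    /// let entropy = uniform.entropy().unwrap();
    /// assert!((entropy - 2_f64.ln()).abs() < 1e-12);
    /// ```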
    fn entropy(&self) -> Result<A, EmptyInput>
    where
        A: Float;

    /// Computes the [Kullback-Leibler divergence] *Dₖₗ(p,q)* between two arrays,
    /// where `self`=*p*.
    ///
    /// The Kullback-Leibler divergence is defined as:
    ///
    /// ```text
    ///              n
    /// Dₖₗ(p,q) = - ∑ pᵢ ln(qᵢ/pᵢ)
    ///             i=1
    /// ```
    ///
    /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
    /// If the array shapes are not identical,
    /// `Err(MultiInputError::ShapeMismatch)` is returned.
    ///
    /// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*,
    /// computing *ln(qᵢ/pᵢ)* panics for `A`.
    ///
    /// ## Remarks
    ///
    /// The Kullback-Leibler divergence is a measure used in [Information Theory]
    /// to describe the relationship between two probability distributions:
    /// it only makes sense when each array sums to 1 with entries between
    /// 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the divergence, to avoid introducing unnecessary numerical
    /// error (e.g. when the arrays are already normalised).
    ///
    /// By definition, *pᵢ ln(qᵢ/pᵢ)* is set to 0 if *pᵢ* is 0.
    ///
    /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
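    ///
    /// ## Example
    ///
    /// A minimal usage sketch (this example assumes the trait is re-exported at the
    /// crate root as `ndarray_stats::EntropyExt`; the expected value is computed by
    /// hand from the definition above):
    ///
    /// ```
    /// use ndarray::array;
    /// use ndarray_stats::EntropyExt;
    ///
    /// let p = array![0.5, 0.5];
    /// let q = array![0.25, 0.75];
    /// // Dₖₗ(p,q) = -(0.5 ln(0.25/0.5) + 0.5 ln(0.75/0.5))
    /// let expected = -(0.5 * 0.5_f64.ln() + 0.5 * 1.5_f64.ln());
    /// assert!((p.kl_divergence(&q).unwrap() - expected).abs() < 1e-12);
    /// ```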
    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
    where
        S2: Data<Elem = A>,
        A: Float;

    /// Computes the [cross entropy] *H(p,q)* between two arrays,
    /// where `self`=*p*.
    ///
    /// The cross entropy is defined as:
    ///
    /// ```text
    ///            n
    /// H(p,q) = - ∑ pᵢ ln(qᵢ)
    ///           i=1
    /// ```
    ///
    /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
    /// If the array shapes are not identical,
    /// `Err(MultiInputError::ShapeMismatch)` is returned.
    ///
    /// **Panics** if any element in *q* is negative and taking the logarithm
    /// of a negative number panics for `A`.
    ///
    /// ## Remarks
    ///
    /// The cross entropy is a measure used in [Information Theory]
    /// to describe the relationship between two probability distributions:
    /// it only makes sense when each array sums to 1 with entries between
    /// 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the cross entropy, to avoid introducing unnecessary numerical
    /// error (e.g. when the arrays are already normalised).
    ///
    /// The cross entropy is often used as an objective/loss function in
    /// [optimization problems], including [machine learning].
    ///
    /// By definition, *pᵢ ln(qᵢ)* is set to 0 if *pᵢ* is 0.
    ///
    /// [cross entropy]: https://en.wikipedia.org/wiki/Cross-entropy
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
    /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method
    /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
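    ///
    /// ## Example
    ///
    /// A minimal usage sketch (this example assumes the trait is re-exported at the
    /// crate root as `ndarray_stats::EntropyExt`; the expected value is computed by
    /// hand from the definition above):
    ///
    /// ```
    /// use ndarray::array;
    /// use ndarray_stats::EntropyExt;
    ///
    /// let p = array![0.5, 0.5];
    /// let q = array![0.25, 0.75];
    /// // H(p,q) = -(0.5 ln(0.25) + 0.5 ln(0.75))
    /// let expected = -(0.5 * 0.25_f64.ln() + 0.5 * 0.75_f64.ln());
    /// assert!((p.cross_entropy(&q).unwrap() - expected).abs() < 1e-12);
    /// ```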
    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
    where
        S2: Data<Elem = A>,
        A: Float;

    private_decl! {}
}

impl<A, S, D> EntropyExt<A, S, D> for ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    fn entropy(&self) -> Result<A, EmptyInput>
    where
        A: Float,
    {
        if self.is_empty() {
            Err(EmptyInput)
        } else {
            let entropy = -self
                .mapv(|x| {
                    if x == A::zero() {
                        A::zero()
                    } else {
                        x * x.ln()
                    }
                })
                .sum();
            Ok(entropy)
        }
    }

    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
    where
        A: Float,
        S2: Data<Elem = A>,
    {
        if self.is_empty() {
            return Err(MultiInputError::EmptyInput);
        }
        if self.shape() != q.shape() {
            return Err(ShapeMismatch {
                first_shape: self.shape().to_vec(),
                second_shape: q.shape().to_vec(),
            }
            .into());
        }

        let mut temp = Array::zeros(self.raw_dim());
        Zip::from(&mut temp)
            .and(self)
            .and(q)
            .for_each(|result, &p, &q| {
                *result = {
                    if p == A::zero() {
                        A::zero()
                    } else {
                        p * (q / p).ln()
                    }
                }
            });
        let kl_divergence = -temp.sum();
        Ok(kl_divergence)
    }

    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
    where
        S2: Data<Elem = A>,
        A: Float,
    {
        if self.is_empty() {
            return Err(MultiInputError::EmptyInput);
        }
        if self.shape() != q.shape() {
            return Err(ShapeMismatch {
                first_shape: self.shape().to_vec(),
                second_shape: q.shape().to_vec(),
            }
            .into());
        }

        let mut temp = Array::zeros(self.raw_dim());
        Zip::from(&mut temp)
            .and(self)
            .and(q)
            .for_each(|result, &p, &q| {
                *result = {
                    if p == A::zero() {
                        A::zero()
                    } else {
                        p * q.ln()
                    }
                }
            });
        let cross_entropy = -temp.sum();
        Ok(cross_entropy)
    }

    private_impl! {}
}

#[cfg(test)]
mod tests {
    use super::EntropyExt;
    use crate::errors::{EmptyInput, MultiInputError};
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array1};
    use noisy_float::types::n64;
    use std::f64;

    #[test]
    fn test_entropy_with_nan_values() {
        let a = array![f64::NAN, 1.];
        assert!(a.entropy().unwrap().is_nan());
    }

    #[test]
    fn test_entropy_with_empty_array_of_floats() {
        let a: Array1<f64> = array![];
        assert_eq!(a.entropy(), Err(EmptyInput));
    }

    #[test]
    fn test_entropy_with_array_of_floats() {
        // Array of probability values - normalized and positive.
        let a: Array1<f64> = array![
            0.03602474, 0.01900344, 0.03510129, 0.03414964, 0.00525311, 0.03368976, 0.00065396,
            0.02906146, 0.00063687, 0.01597306, 0.00787625, 0.00208243, 0.01450896, 0.01803418,
            0.02055336, 0.03029759, 0.03323628, 0.01218822, 0.0001873, 0.01734179, 0.03521668,
            0.02564429, 0.02421992, 0.03540229, 0.03497635, 0.03582331, 0.026558, 0.02460495,
            0.02437716, 0.01212838, 0.00058464, 0.00335236, 0.02146745, 0.00930306, 0.01821588,
            0.02381928, 0.02055073, 0.01483779, 0.02284741, 0.02251385, 0.00976694, 0.02864634,
            0.00802828, 0.03464088, 0.03557152, 0.01398894, 0.01831756, 0.0227171, 0.00736204,
            0.01866295,
        ];
        // Computed using scipy.stats.entropy
        let expected_entropy = 3.721606155686918;

        assert_abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = 1e-6);
    }

    #[test]
    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), MultiInputError> {
        let a = array![f64::NAN, 1.];
        let b = array![2., 1.];
        assert!(a.cross_entropy(&b)?.is_nan());
        assert!(b.cross_entropy(&a)?.is_nan());
        assert!(a.kl_divergence(&b)?.is_nan());
        assert!(b.kl_divergence(&a)?.is_nan());
        Ok(())
    }

    #[test]
    fn test_cross_entropy_and_kl_with_same_n_dimension_but_different_n_elements() {
        let p = array![f64::NAN, 1.];
        let q = array![2., 1., 5.];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    #[test]
    fn test_cross_entropy_and_kl_with_different_shape_but_same_n_elements() {
        // p: 3x2, 6 elements
        let p = array![[f64::NAN, 1.], [6., 7.], [10., 20.]];
        // q: 2x3, 6 elements
        let q = array![[2., 1., 5.], [1., 1., 7.],];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    #[test]
    fn test_cross_entropy_and_kl_with_empty_array_of_floats() {
        let p: Array1<f64> = array![];
        let q: Array1<f64> = array![];
        assert!(p.cross_entropy(&q).unwrap_err().is_empty_input());
        assert!(p.kl_divergence(&q).unwrap_err().is_empty_input());
    }

    #[test]
    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), MultiInputError> {
        let p = array![1.];
        let q = array![-1.];
        let cross_entropy: f64 = p.cross_entropy(&q)?;
        let kl_divergence: f64 = p.kl_divergence(&q)?;
        assert!(cross_entropy.is_nan());
        assert!(kl_divergence.is_nan());
        Ok(())
    }

    #[test]
    #[should_panic]
    fn test_cross_entropy_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.cross_entropy(&q);
    }

    #[test]
    #[should_panic]
    fn test_kl_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.kl_divergence(&q);
    }

    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), MultiInputError> {
        let p = array![0., 0.];
        let q = array![0., 0.5];
        assert_eq!(p.cross_entropy(&q)?, 0.);
        assert_eq!(p.kl_divergence(&q)?, 0.);
        Ok(())
    }

    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
    ) -> Result<(), MultiInputError> {
        let p = array![0.5, 0.5];
        let mut q = array![0.5, 0.];
        assert_eq!(p.cross_entropy(&q.view_mut())?, f64::INFINITY);
        assert_eq!(p.kl_divergence(&q.view_mut())?, f64::INFINITY);
        Ok(())
    }

    #[test]
    fn test_cross_entropy() -> Result<(), MultiInputError> {
        // Arrays of probability values - normalized and positive.
        let p: Array1<f64> = array![
            0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
            0.0471258, 0.05594036, 0.01630048, 0.07085162, 0.05365855, 0.01959158, 0.05020174,
            0.03801479, 0.00092234, 0.08515856, 0.00580683, 0.0156542, 0.0860375, 0.0724246,
            0.00727477, 0.01004402, 0.01854399, 0.03504082,
        ];
        let q: Array1<f64> = array![
            0.06622616, 0.0478948, 0.03227816, 0.06460884, 0.05795974, 0.01377489, 0.05604812,
            0.01202684, 0.01647579, 0.03392697, 0.01656126, 0.00867528, 0.0625685, 0.07381292,
            0.05489067, 0.01385491, 0.03639174, 0.00511611, 0.05700415, 0.05183825, 0.06703064,
            0.01813342, 0.0007763, 0.0735472, 0.05857833,
        ];
        // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
        let expected_cross_entropy = 3.385347705020779;

        assert_abs_diff_eq!(p.cross_entropy(&q)?, expected_cross_entropy, epsilon = 1e-6);
        Ok(())
    }

    #[test]
    fn test_kl() -> Result<(), MultiInputError> {
        // Arrays of probability values - normalized and positive.
        let p: Array1<f64> = array![
            0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
            0.02213802, 0.02745017, 0.02163975, 0.0324602, 0.03622766, 0.00782343, 0.00222498,
            0.03028156, 0.02346124, 0.00071105, 0.00794496, 0.0127609, 0.02899124, 0.01281487,
            0.0230803, 0.01531864, 0.00518158, 0.02233383, 0.0220279, 0.03196097, 0.03710063,
            0.01817856, 0.03524661, 0.02902393, 0.00853364, 0.01255615, 0.03556958, 0.00400151,
            0.01335932, 0.01864965, 0.02371322, 0.02026543, 0.0035375, 0.01988341, 0.02621831,
            0.03564644, 0.01389121, 0.03151622, 0.03195532, 0.00717521, 0.03547256, 0.00371394,
            0.01108706,
        ];
        let q: Array1<f64> = array![
            0.02038386, 0.03143914, 0.02630206, 0.0171595, 0.0067072, 0.00911324, 0.02635717,
            0.01269113, 0.0302361, 0.02243133, 0.01902902, 0.01297185, 0.02118908, 0.03309548,
            0.01266687, 0.0184529, 0.01830936, 0.03430437, 0.02898924, 0.02238251, 0.0139771,
            0.01879774, 0.02396583, 0.03019978, 0.01421278, 0.02078981, 0.03542451, 0.02887438,
            0.01261783, 0.01014241, 0.03263407, 0.0095969, 0.01923903, 0.0051315, 0.00924686,
            0.00148845, 0.00341391, 0.01480373, 0.01920798, 0.03519871, 0.03315135, 0.02099325,
            0.03251755, 0.00337555, 0.03432165, 0.01763753, 0.02038337, 0.01923023, 0.01438769,
            0.02082707,
        ];
        // Computed using scipy.stats.entropy(p, q)
        let expected_kl = 0.3555862567800096;

        assert_abs_diff_eq!(p.kl_divergence(&q)?, expected_kl, epsilon = 1e-6);
        Ok(())
    }
}