1use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
3use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
4use num_traits::Float;
5
/// An extension trait for `ArrayBase` providing information-theory
/// quantities for arrays whose elements are treated as (unnormalised)
/// discrete probability values.
///
/// None of the methods verify that the values actually form a valid
/// probability distribution (non-negative, summing to one); callers are
/// responsible for that.
pub trait EntropyExt<A, S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Computes the entropy of the array,
    ///
    /// ```text
    /// H = -sum_i x_i * ln(x_i)
    /// ```
    ///
    /// using the convention `0. * ln(0.) == 0.`. NaN elements propagate
    /// into the result.
    ///
    /// Returns `Err(EmptyInput)` if the array is empty.
    fn entropy(&self) -> Result<A, EmptyInput>
    where
        A: Float;

    /// Computes the Kullback-Leibler divergence of `q` from `self`,
    ///
    /// ```text
    /// D(self || q) = sum_i p_i * ln(p_i / q_i)
    /// ```
    ///
    /// using the convention `0. * ln(0.) == 0.`. A positive `p_i` paired
    /// with a zero `q_i` produces `+inf`.
    ///
    /// Returns `Err(MultiInputError::EmptyInput)` if `self` is empty, and
    /// `Err(MultiInputError::ShapeMismatch)` if the shapes of `self` and
    /// `q` differ.
    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
    where
        S2: Data<Elem = A>,
        A: Float;

    /// Computes the cross entropy between `self` and `q`,
    ///
    /// ```text
    /// H(self, q) = -sum_i p_i * ln(q_i)
    /// ```
    ///
    /// using the convention `0. * ln(0.) == 0.`. A positive `p_i` paired
    /// with a zero `q_i` produces `+inf`.
    ///
    /// Returns the same errors as `kl_divergence`.
    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
    where
        S2: Data<Elem = A>,
        A: Float;

    private_decl! {}
}
126
127impl<A, S, D> EntropyExt<A, S, D> for ArrayBase<S, D>
128where
129 S: Data<Elem = A>,
130 D: Dimension,
131{
132 fn entropy(&self) -> Result<A, EmptyInput>
133 where
134 A: Float,
135 {
136 if self.is_empty() {
137 Err(EmptyInput)
138 } else {
139 let entropy = -self
140 .mapv(|x| {
141 if x == A::zero() {
142 A::zero()
143 } else {
144 x * x.ln()
145 }
146 })
147 .sum();
148 Ok(entropy)
149 }
150 }
151
152 fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
153 where
154 A: Float,
155 S2: Data<Elem = A>,
156 {
157 if self.is_empty() {
158 return Err(MultiInputError::EmptyInput);
159 }
160 if self.shape() != q.shape() {
161 return Err(ShapeMismatch {
162 first_shape: self.shape().to_vec(),
163 second_shape: q.shape().to_vec(),
164 }
165 .into());
166 }
167
168 let mut temp = Array::zeros(self.raw_dim());
169 Zip::from(&mut temp)
170 .and(self)
171 .and(q)
172 .for_each(|result, &p, &q| {
173 *result = {
174 if p == A::zero() {
175 A::zero()
176 } else {
177 p * (q / p).ln()
178 }
179 }
180 });
181 let kl_divergence = -temp.sum();
182 Ok(kl_divergence)
183 }
184
185 fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
186 where
187 S2: Data<Elem = A>,
188 A: Float,
189 {
190 if self.is_empty() {
191 return Err(MultiInputError::EmptyInput);
192 }
193 if self.shape() != q.shape() {
194 return Err(ShapeMismatch {
195 first_shape: self.shape().to_vec(),
196 second_shape: q.shape().to_vec(),
197 }
198 .into());
199 }
200
201 let mut temp = Array::zeros(self.raw_dim());
202 Zip::from(&mut temp)
203 .and(self)
204 .and(q)
205 .for_each(|result, &p, &q| {
206 *result = {
207 if p == A::zero() {
208 A::zero()
209 } else {
210 p * q.ln()
211 }
212 }
213 });
214 let cross_entropy = -temp.sum();
215 Ok(cross_entropy)
216 }
217
218 private_impl! {}
219}
220
#[cfg(test)]
mod tests {
    use super::EntropyExt;
    use crate::errors::{EmptyInput, MultiInputError};
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array1};
    use noisy_float::types::n64;
    use std::f64;

    // NaN elements must propagate into the entropy rather than be skipped.
    #[test]
    fn test_entropy_with_nan_values() {
        let a = array![f64::NAN, 1.];
        assert!(a.entropy().unwrap().is_nan());
    }

    // An empty array has no entropy; the error variant is returned.
    #[test]
    fn test_entropy_with_empty_array_of_floats() {
        let a: Array1<f64> = array![];
        assert_eq!(a.entropy(), Err(EmptyInput));
    }

    // Numerical check against a precomputed reference value for a fixed
    // 50-element distribution.
    #[test]
    fn test_entropy_with_array_of_floats() {
        let a: Array1<f64> = array![
            0.03602474, 0.01900344, 0.03510129, 0.03414964, 0.00525311, 0.03368976, 0.00065396,
            0.02906146, 0.00063687, 0.01597306, 0.00787625, 0.00208243, 0.01450896, 0.01803418,
            0.02055336, 0.03029759, 0.03323628, 0.01218822, 0.0001873, 0.01734179, 0.03521668,
            0.02564429, 0.02421992, 0.03540229, 0.03497635, 0.03582331, 0.026558, 0.02460495,
            0.02437716, 0.01212838, 0.00058464, 0.00335236, 0.02146745, 0.00930306, 0.01821588,
            0.02381928, 0.02055073, 0.01483779, 0.02284741, 0.02251385, 0.00976694, 0.02864634,
            0.00802828, 0.03464088, 0.03557152, 0.01398894, 0.01831756, 0.0227171, 0.00736204,
            0.01866295,
        ];
        let expected_entropy = 3.721606155686918;

        assert_abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = 1e-6);
    }

    // NaN in either operand must surface in both cross entropy and KL.
    #[test]
    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), MultiInputError> {
        let a = array![f64::NAN, 1.];
        let b = array![2., 1.];
        assert!(a.cross_entropy(&b)?.is_nan());
        assert!(b.cross_entropy(&a)?.is_nan());
        assert!(a.kl_divergence(&b)?.is_nan());
        assert!(b.kl_divergence(&a)?.is_nan());
        Ok(())
    }

    // Same dimensionality but different lengths is a shape mismatch.
    #[test]
    fn test_cross_entropy_and_kl_with_same_n_dimension_but_different_n_elements() {
        let p = array![f64::NAN, 1.];
        let q = array![2., 1., 5.];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    // Equal element count is not enough — the full shapes must match.
    #[test]
    fn test_cross_entropy_and_kl_with_different_shape_but_same_n_elements() {
        let p = array![[f64::NAN, 1.], [6., 7.], [10., 20.]];
        let q = array![[2., 1., 5.], [1., 1., 7.],];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    // Empty inputs yield the EmptyInput variant of MultiInputError.
    #[test]
    fn test_cross_entropy_and_kl_with_empty_array_of_floats() {
        let p: Array1<f64> = array![];
        let q: Array1<f64> = array![];
        assert!(p.cross_entropy(&q).unwrap_err().is_empty_input());
        assert!(p.kl_divergence(&q).unwrap_err().is_empty_input());
    }

    // For plain f64, ln of a negative q produces NaN (no panic).
    #[test]
    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), MultiInputError> {
        let p = array![1.];
        let q = array![-1.];
        let cross_entropy: f64 = p.cross_entropy(&q)?;
        let kl_divergence: f64 = p.kl_divergence(&q)?;
        assert!(cross_entropy.is_nan());
        assert!(kl_divergence.is_nan());
        Ok(())
    }

    // With noisy_float's N64, producing NaN (ln of a negative) panics.
    #[test]
    #[should_panic]
    fn test_cross_entropy_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.cross_entropy(&q);
    }

    #[test]
    #[should_panic]
    fn test_kl_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.kl_divergence(&q);
    }

    // The 0 * ln(0) == 0 convention: an all-zero p gives 0 even where q is 0.
    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), MultiInputError> {
        let p = array![0., 0.];
        let q = array![0., 0.5];
        assert_eq!(p.cross_entropy(&q)?, 0.);
        assert_eq!(p.kl_divergence(&q)?, 0.);
        Ok(())
    }

    // A zero q against a positive p gives +inf; also exercises passing a
    // mutable view instead of an owned array.
    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
    ) -> Result<(), MultiInputError> {
        let p = array![0.5, 0.5];
        let mut q = array![0.5, 0.];
        assert_eq!(p.cross_entropy(&q.view_mut())?, f64::INFINITY);
        assert_eq!(p.kl_divergence(&q.view_mut())?, f64::INFINITY);
        Ok(())
    }

    // Numerical check of cross entropy against a precomputed reference
    // value for fixed 25-element distributions.
    #[test]
    fn test_cross_entropy() -> Result<(), MultiInputError> {
        let p: Array1<f64> = array![
            0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
            0.0471258, 0.05594036, 0.01630048, 0.07085162, 0.05365855, 0.01959158, 0.05020174,
            0.03801479, 0.00092234, 0.08515856, 0.00580683, 0.0156542, 0.0860375, 0.0724246,
            0.00727477, 0.01004402, 0.01854399, 0.03504082,
        ];
        let q: Array1<f64> = array![
            0.06622616, 0.0478948, 0.03227816, 0.06460884, 0.05795974, 0.01377489, 0.05604812,
            0.01202684, 0.01647579, 0.03392697, 0.01656126, 0.00867528, 0.0625685, 0.07381292,
            0.05489067, 0.01385491, 0.03639174, 0.00511611, 0.05700415, 0.05183825, 0.06703064,
            0.01813342, 0.0007763, 0.0735472, 0.05857833,
        ];
        let expected_cross_entropy = 3.385347705020779;

        assert_abs_diff_eq!(p.cross_entropy(&q)?, expected_cross_entropy, epsilon = 1e-6);
        Ok(())
    }

    // Numerical check of KL divergence against a precomputed reference
    // value for fixed 50-element distributions.
    #[test]
    fn test_kl() -> Result<(), MultiInputError> {
        let p: Array1<f64> = array![
            0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
            0.02213802, 0.02745017, 0.02163975, 0.0324602, 0.03622766, 0.00782343, 0.00222498,
            0.03028156, 0.02346124, 0.00071105, 0.00794496, 0.0127609, 0.02899124, 0.01281487,
            0.0230803, 0.01531864, 0.00518158, 0.02233383, 0.0220279, 0.03196097, 0.03710063,
            0.01817856, 0.03524661, 0.02902393, 0.00853364, 0.01255615, 0.03556958, 0.00400151,
            0.01335932, 0.01864965, 0.02371322, 0.02026543, 0.0035375, 0.01988341, 0.02621831,
            0.03564644, 0.01389121, 0.03151622, 0.03195532, 0.00717521, 0.03547256, 0.00371394,
            0.01108706,
        ];
        let q: Array1<f64> = array![
            0.02038386, 0.03143914, 0.02630206, 0.0171595, 0.0067072, 0.00911324, 0.02635717,
            0.01269113, 0.0302361, 0.02243133, 0.01902902, 0.01297185, 0.02118908, 0.03309548,
            0.01266687, 0.0184529, 0.01830936, 0.03430437, 0.02898924, 0.02238251, 0.0139771,
            0.01879774, 0.02396583, 0.03019978, 0.01421278, 0.02078981, 0.03542451, 0.02887438,
            0.01261783, 0.01014241, 0.03263407, 0.0095969, 0.01923903, 0.0051315, 0.00924686,
            0.00148845, 0.00341391, 0.01480373, 0.01920798, 0.03519871, 0.03315135, 0.02099325,
            0.03251755, 0.00337555, 0.03432165, 0.01763753, 0.02038337, 0.01923023, 0.01438769,
            0.02082707,
        ];
        let expected_kl = 0.3555862567800096;

        assert_abs_diff_eq!(p.kl_divergence(&q)?, expected_kl, epsilon = 1e-6);
        Ok(())
    }
}