// nabla_ml/nab_loss.rs

use crate::nab_array::NDArray;

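/// A loss function: a scalar objective (`forward`) plus its gradient with
/// respect to the predictions (`backward`).
///
/// # Examples
///
/// A minimal sketch of a custom loss implementing this trait; `MyAbsLoss` is
/// hypothetical and not part of the crate:
///
/// ```ignore
/// struct MyAbsLoss;
///
/// impl Loss for MyAbsLoss {
///     // Mean absolute error over all elements.
///     fn forward(&self, predictions: &NDArray, targets: &NDArray) -> f64 {
///         let n = predictions.data().len() as f64;
///         predictions.data().iter().zip(targets.data().iter())
///             .map(|(p, t)| (p - t).abs())
///             .sum::<f64>() / n
///     }
///
///     // Subgradient of MAE: sign(p - t) / n.
///     fn backward(&self, predictions: &NDArray, targets: &NDArray) -> NDArray {
///         let n = predictions.data().len() as f64;
///         let grad: Vec<f64> = predictions.data().iter().zip(targets.data().iter())
///             .map(|(p, t)| (p - t).signum() / n)
///             .collect();
///         NDArray::from_vec(grad)
///     }
/// }
/// ```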
pub trait Loss {
    /// Computes the scalar loss for a batch of predictions.
    fn forward(&self, predictions: &NDArray, targets: &NDArray) -> f64;
    /// Computes the gradient of the loss with respect to `predictions`.
    fn backward(&self, predictions: &NDArray, targets: &NDArray) -> NDArray;
}

// pub struct CategoricalCrossentropy;

// impl Loss for CategoricalCrossentropy {
//     fn forward(&self, y_pred: &NDArray, y_true: &NDArray) -> f64 {
//         // Compute cross-entropy loss
//         let epsilon = 1e-8;
//         let clipped_pred = y_pred.clip(epsilon, 1.0 - epsilon);
//         -y_true.multiply(&clipped_pred.log()).sum() / y_true.shape()[0] as f64
//     }

//     fn backward(&self, y_pred: &NDArray, y_true: &NDArray) -> NDArray {
//         NDArray::crossentropy_nabla(y_pred, y_true)
//     }
// }

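/// Built-in loss helpers. Implements [`Loss`] with mean squared error, and
/// also provides standalone MSE and cross-entropy functions.
///
/// # Examples
///
/// A minimal usage sketch (the array values are illustrative only):
///
/// ```ignore
/// let loss = NabLoss;
/// let predictions = NDArray::from_vec(vec![0.9, 0.2]);
/// let targets = NDArray::from_vec(vec![1.0, 0.0]);
/// let value = loss.forward(&predictions, &targets); // scalar MSE
/// let grad = loss.backward(&predictions, &targets); // dL/dpredictions
/// ```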
pub struct NabLoss;

impl NabLoss {
    /// Calculates the Mean Squared Error (MSE) between two arrays
    ///
    /// # Arguments
    ///
    /// * `y_true` - The true values as an NDArray.
    /// * `y_pred` - The predicted values as an NDArray.
    ///
    /// # Returns
    ///
    /// The MSE as an `f64`.
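    ///
    /// # Examples
    ///
    /// A minimal sketch (values illustrative; `from_vec` as used in the tests below):
    ///
    /// ```ignore
    /// let y_true = NDArray::from_vec(vec![1.0, 0.0]);
    /// let y_pred = NDArray::from_vec(vec![0.9, 0.1]);
    /// // (0.1^2 + 0.1^2) / 2 = 0.01
    /// let mse = NabLoss::mean_squared_error(&y_true, &y_pred);
    /// ```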
    #[allow(dead_code)]
    pub fn mean_squared_error(y_true: &NDArray, y_pred: &NDArray) -> f64 {
        assert_eq!(y_true.shape(), y_pred.shape(), "Shapes of y_true and y_pred must match");
        let n = y_true.data().len() as f64;
        y_true.data().iter()
            .zip(y_pred.data().iter())
            .map(|(t, p)| (t - p).powi(2))
            .sum::<f64>() / n
    }

    /// Calculates the Cross-Entropy Loss between two arrays
    ///
    /// # Arguments
    ///
    /// * `y_true` - The true values as an NDArray (one-hot encoded).
    /// * `y_pred` - The predicted probabilities as an NDArray.
    ///
    /// # Returns
    ///
    /// The Cross-Entropy Loss as an `f64`.
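    ///
    /// # Examples
    ///
    /// A minimal sketch (values illustrative; `from_matrix` as used in the tests below):
    ///
    /// ```ignore
    /// let y_true = NDArray::from_matrix(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
    /// let y_pred = NDArray::from_matrix(vec![vec![0.8, 0.2], vec![0.3, 0.7]]);
    /// // -(ln 0.8 + ln 0.7) / 2 ≈ 0.2899
    /// let ce = NabLoss::cross_entropy_loss(&y_true, &y_pred);
    /// ```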
    #[allow(dead_code)]
    pub fn cross_entropy_loss(y_true: &NDArray, y_pred: &NDArray) -> f64 {
        assert_eq!(y_true.shape(), y_pred.shape(), "Shapes of y_true and y_pred must match");
        let epsilon = 1e-8;
        // Clip predictions away from 0 and 1 so ln() stays finite.
        let sum: f64 = y_true.data().iter()
            .zip(y_pred.data().iter())
            .map(|(t, p)| t * p.clamp(epsilon, 1.0 - epsilon).ln())
            .sum();
        -sum / y_true.shape()[0] as f64
    }
}

impl Loss for NabLoss {
    fn forward(&self, predictions: &NDArray, targets: &NDArray) -> f64 {
        NabLoss::mean_squared_error(predictions, targets)
    }

    fn backward(&self, predictions: &NDArray, targets: &NDArray) -> NDArray {
        // Gradient of MSE: dL/dp = 2 * (p - t) / n, normalized by the total
        // element count to match the averaging in `mean_squared_error`.
        let n = predictions.data().len() as f64;
        predictions.subtract(targets).multiply_scalar(2.0 / n)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mean_squared_error() {
        let y_true = NDArray::from_vec(vec![1.0, 0.0, 1.0, 1.0]);
        let y_pred = NDArray::from_vec(vec![0.9, 0.2, 0.8, 0.6]);
        let mse = NabLoss::mean_squared_error(&y_true, &y_pred);
        // (0.1^2 + 0.2^2 + 0.2^2 + 0.4^2) / 4 = 0.25 / 4 = 0.0625
        assert!((mse - 0.0625).abs() < 1e-4);
    }

    #[test]
    fn test_cross_entropy_loss() {
        let y_true = NDArray::from_matrix(vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ]);
        let y_pred = NDArray::from_matrix(vec![
            vec![0.7, 0.2, 0.1],
            vec![0.1, 0.8, 0.1],
            vec![0.05, 0.15, 0.8],
        ]);
        let cross_entropy = NabLoss::cross_entropy_loss(&y_true, &y_pred);
        // -(ln 0.7 + ln 0.8 + ln 0.8) / 3 ≈ 0.267654
        assert!((cross_entropy - 0.267654016).abs() < 1e-4);
    }
}