use crate::nab_array::NDArray;
use crate::nab_activations::NablaActivation;
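/// A single layer in a Nabla network.
///
/// One struct covers every supported layer type ("Input", "Dense",
/// "Activation", "Flatten", "Dropout", "BatchNorm"); the `layer_type`
/// string selects the behaviour in `forward` and `backward`. Per-layer
/// state that does not apply to a given type (e.g. `dropout_rate` for a
/// dense layer) is simply left as `None`.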
#[allow(dead_code)]
#[derive(Clone)]
pub struct NabLayer {
layer_type: String,
name: String,
input_shape: Vec<usize>,
output_shape: Vec<usize>,
pub weights: Option<NDArray>,
pub biases: Option<NDArray>,
input_cache: Option<NDArray>,
output_cache: Option<NDArray>,
trainable: bool,
pub weight_gradients: Option<NDArray>,
pub bias_gradients: Option<NDArray>,
activation: Option<String>,
dropout_rate: Option<f64>,
dropout_mask: Option<NDArray>,
epsilon: Option<f64>,
momentum: Option<f64>,
running_mean: Option<NDArray>,
running_var: Option<NDArray>,
batch_mean: Option<NDArray>,
batch_var: Option<NDArray>,
normalized: Option<NDArray>,
pub node_index: Option<usize>,
input_nodes: Option<Vec<usize>>,
}
#[allow(dead_code)]
impl NabLayer {
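/// Creates an identity input layer with the given feature shape.
///
/// The forward pass returns its input unchanged and the backward pass
/// returns the incoming gradient unchanged.
///
/// Illustrative sketch (mirrors the unit tests below; not run as a doctest):
///
/// ```ignore
/// let layer = NabLayer::input(vec![784], Some("mnist_input"));
/// assert_eq!(layer.get_output_shape(), &[784]);
/// assert!(!layer.is_trainable());
/// ```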
pub fn input(shape: Vec<usize>, name: Option<&str>) -> Self {
NabLayer {
layer_type: "Input".to_string(),
name: name.unwrap_or("input").to_string(),
input_shape: shape.clone(),
output_shape: shape,
weights: None,
biases: None,
input_cache: None,
output_cache: None,
trainable: false,
weight_gradients: None,
bias_gradients: None,
activation: None,
dropout_rate: None,
dropout_mask: None,
epsilon: None,
momentum: None,
running_mean: None,
running_var: None,
batch_mean: None,
batch_var: None,
normalized: None,
node_index: None,
input_nodes: None,
}
}
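/// Creates a fully connected layer mapping `input_dim` features to `units`
/// outputs, with an optional activation fused into the layer.
///
/// Weights use He initialization (standard-normal values scaled by
/// `sqrt(2 / input_dim)`); biases start at zero.
///
/// Illustrative sketch (shapes as used in the unit tests; not run as a doctest):
///
/// ```ignore
/// let mut layer = NabLayer::dense(784, 128, Some("relu"), Some("hidden_1"));
/// let input = NDArray::from_matrix(vec![vec![0.1; 784]; 32]);
/// let output = layer.forward(&input, true);              // shape [32, 128]
/// let gradient = NDArray::from_matrix(vec![vec![0.1; 128]; 32]);
/// let grad_in = layer.backward(&gradient);               // shape [32, 784]
/// ```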
pub fn dense(
input_dim: usize,
units: usize,
activation: Option<&str>,
name: Option<&str>
) -> Self {
let scale = (2.0 / input_dim as f64).sqrt();
let weights = NDArray::randn_2d(input_dim, units)
.multiply_scalar(scale);
let biases = NDArray::zeros(vec![units]);
NabLayer {
layer_type: "Dense".to_string(),
name: name.unwrap_or("dense").to_string(),
input_shape: vec![input_dim],
output_shape: vec![units],
weights: Some(weights),
biases: Some(biases),
input_cache: None,
output_cache: None,
trainable: true,
weight_gradients: None,
bias_gradients: None,
activation: activation.map(|s| s.to_string()),
dropout_rate: None,
dropout_mask: None,
epsilon: None,
momentum: None,
running_mean: None,
running_var: None,
batch_mean: None,
batch_var: None,
normalized: None,
node_index: None,
input_nodes: None,
}
}
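/// Creates a standalone activation layer that applies `activation_type`
/// ("relu", "sigmoid", "tanh", ...) element-wise without changing the shape.
///
/// Illustrative sketch (not run as a doctest):
///
/// ```ignore
/// let mut relu = NabLayer::activation("relu", vec![3], None);
/// let input = NDArray::from_matrix(vec![vec![-0.5, 0.0, 0.5]]);
/// let out = relu.forward(&input, true); // negative entries are clamped to 0.0
/// ```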
pub fn activation(activation_type: &str, input_shape: Vec<usize>, name: Option<&str>) -> Self {
NabLayer {
layer_type: "Activation".to_string(),
name: name.unwrap_or("activation").to_string(),
input_shape: input_shape.clone(),
output_shape: input_shape,
weights: None,
biases: None,
input_cache: None,
output_cache: None,
trainable: false,
weight_gradients: None,
bias_gradients: None,
activation: Some(activation_type.to_string()),
dropout_rate: None,
dropout_mask: None,
epsilon: None,
momentum: None,
running_mean: None,
running_var: None,
batch_mean: None,
batch_var: None,
normalized: None,
node_index: None,
input_nodes: None,
}
}
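/// Creates a layer that flattens each sample to a 1-D vector; the batch
/// dimension is preserved, so `[batch, d1, d2, ...]` becomes
/// `[batch, d1 * d2 * ...]`.
///
/// Illustrative sketch (not run as a doctest):
///
/// ```ignore
/// let flatten = NabLayer::flatten(vec![28, 28, 1], None);
/// assert_eq!(flatten.get_output_shape(), &[784]);
/// ```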
pub fn flatten(input_shape: Vec<usize>, name: Option<&str>) -> Self {
let flattened_size = input_shape.iter().product();
NabLayer {
layer_type: "Flatten".to_string(),
name: name.unwrap_or("flatten").to_string(),
input_shape,
output_shape: vec![flattened_size],
weights: None,
biases: None,
input_cache: None,
output_cache: None,
trainable: false,
weight_gradients: None,
bias_gradients: None,
activation: None,
dropout_rate: None,
dropout_mask: None,
epsilon: None,
momentum: None,
running_mean: None,
running_var: None,
batch_mean: None,
batch_var: None,
normalized: None,
node_index: None,
input_nodes: None,
}
}
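/// Creates an inverted-dropout layer.
///
/// During training each element is zeroed with probability `rate` and the
/// survivors are scaled by `1 / (1 - rate)`, so the expected activation is
/// unchanged and no rescaling is needed at inference time. With
/// `training == false` the layer acts as the identity.
///
/// Illustrative sketch (not run as a doctest):
///
/// ```ignore
/// let mut layer = NabLayer::dropout(vec![100], 0.5, None);
/// let input = NDArray::from_matrix(vec![vec![1.0; 100]; 10]);
/// let train_out = layer.forward(&input, true);  // roughly half the entries are 0.0
/// let test_out = layer.forward(&input, false);  // identical to `input`
/// ```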
pub fn dropout(input_shape: Vec<usize>, rate: f64, name: Option<&str>) -> Self {
assert!(rate >= 0.0 && rate < 1.0, "Dropout rate must be in [0, 1)");
NabLayer {
layer_type: "Dropout".to_string(),
name: name.unwrap_or("dropout").to_string(),
input_shape: input_shape.clone(),
output_shape: input_shape,
weights: None,
biases: None,
input_cache: None,
output_cache: None,
trainable: false,
weight_gradients: None,
bias_gradients: None,
activation: None,
dropout_rate: Some(rate),
dropout_mask: None,
epsilon: None,
momentum: None,
running_mean: None,
running_var: None,
batch_mean: None,
batch_var: None,
normalized: None,
node_index: None,
input_nodes: None,
}
}
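/// Creates a batch-normalization layer over `input_shape[0]` features.
///
/// The learnable scale (gamma) is stored in `weights` and the learnable
/// shift (beta) in `biases`. Running mean and variance are tracked with
/// exponential moving averages (default `momentum` 0.99, default `epsilon`
/// 1e-5) and are used in place of batch statistics when `training == false`.
///
/// Illustrative sketch (not run as a doctest):
///
/// ```ignore
/// let mut bn = NabLayer::batch_norm(vec![3], None, None, None);
/// let x = NDArray::from_matrix(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
/// let y = bn.forward(&x, true); // per-feature mean ~0, variance ~1
/// ```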
pub fn batch_norm(
input_shape: Vec<usize>,
epsilon: Option<f64>,
momentum: Option<f64>,
name: Option<&str>
) -> Self {
let features = input_shape[0];
let epsilon = epsilon.unwrap_or(1e-5);
let momentum = momentum.unwrap_or(0.99);
let gamma = NDArray::ones(features);
let beta = NDArray::zeros(vec![features]);
let running_mean = NDArray::zeros(vec![features]);
let running_var = NDArray::ones(features);
NabLayer {
layer_type: "BatchNorm".to_string(),
name: name.unwrap_or("batch_norm").to_string(),
input_shape: input_shape.clone(),
output_shape: input_shape,
weights: Some(gamma),
biases: Some(beta),
input_cache: None,
output_cache: None,
trainable: true,
weight_gradients: None,
bias_gradients: None,
activation: None,
dropout_rate: None,
dropout_mask: None,
epsilon: Some(epsilon),
momentum: Some(momentum),
running_mean: Some(running_mean),
running_var: Some(running_var),
batch_mean: None,
batch_var: None,
normalized: None,
node_index: None,
input_nodes: None,
}
}
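/// Tiles a 1-D per-feature array `batch_size` times so it can be combined
/// element-wise with a `[batch_size, features]` input.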
fn broadcast_to_batch(&self, array: &NDArray, batch_size: usize) -> NDArray {
let features = array.data().len();
let mut broadcasted = Vec::with_capacity(batch_size * features);
for _ in 0..batch_size {
broadcasted.extend(array.data());
}
NDArray::new(broadcasted, vec![batch_size, features])
}
fn reshape_to_1d(&self, array: &NDArray) -> NDArray {
let features = array.data().len();
array.reshape(&[features]).expect("Failed to reshape to 1D")
}
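/// Runs the forward pass for this layer, dispatching on `layer_type`.
///
/// The input (and, for most layer types, the output) is cached so that
/// `backward` can reuse it. `training` only matters for Dropout and
/// BatchNorm, which behave differently at inference time.
///
/// Illustrative sketch (not run as a doctest):
///
/// ```ignore
/// let mut layer = NabLayer::dense(3, 2, Some("relu"), None);
/// let x = NDArray::from_matrix(vec![vec![1.0, 2.0, 3.0]]);
/// let y = layer.forward(&x, true); // shape [1, 2]
/// let dx = layer.backward(&y);     // shape [1, 3]
/// ```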
pub fn forward(&mut self, input: &NDArray, training: bool) -> NDArray {
self.input_cache = Some(input.clone());
let output = match self.layer_type.as_str() {
"Dense" => {
let out = self.dense_forward(input);
self.output_cache = Some(out.clone());
out
},
"BatchNorm" => {
self.batch_norm_forward(input, training)
},
"Activation" => {
let out = match self.activation.as_ref().unwrap().as_str() {
"relu" => NablaActivation::relu_forward(input),
"sigmoid" => NablaActivation::sigmoid_forward(input),
"tanh" => NablaActivation::tanh_forward(input),
"softmax" => NablaActivation::softmax_forward(input, None),
"leaky_relu" => NablaActivation::leaky_relu_forward(input, None),
_ => panic!("Unknown activation type: {}", self.activation.as_ref().unwrap()),
};
self.output_cache = Some(out.clone());
out
},
"Flatten" => {
let out = self.flatten_forward(input, training);
self.output_cache = Some(out.clone());
out
},
"Dropout" => {
let out = self.dropout_forward(input, training);
self.output_cache = Some(out.clone());
out
},
_ => input.clone()
};
output
}
fn dense_forward(&self, input: &NDArray) -> NDArray {
let weights = self.weights.as_ref().unwrap();
let biases = self.biases.as_ref().unwrap();
let wx = input.dot(weights);
let batch_size = input.shape()[0];
let broadcasted_biases = NDArray::from_matrix(
vec![biases.data().to_vec(); batch_size]
);
let linear_output = wx.add(&broadcasted_biases);
let output = if let Some(act_type) = &self.activation {
match act_type.as_str() {
"relu" => NablaActivation::relu_forward(&linear_output),
"sigmoid" => NablaActivation::sigmoid_forward(&linear_output),
"tanh" => NablaActivation::tanh_forward(&linear_output),
"softmax" => NablaActivation::softmax_forward(&linear_output, None),
_ => panic!("Unknown activation type: {}", act_type),
}
} else {
linear_output
};
output
}
fn activation_forward(&mut self, input: &NDArray, _training: bool) -> NDArray {
self.input_cache = Some(input.clone());
let output = match self.activation.as_ref().unwrap().as_str() {
"relu" => NablaActivation::relu_forward(input),
"sigmoid" => NablaActivation::sigmoid_forward(input),
"tanh" => NablaActivation::tanh_forward(input),
"leaky_relu" => NablaActivation::leaky_relu_forward(input, None),
_ => panic!("Unknown activation type: {}", self.activation.as_ref().unwrap()),
};
self.output_cache = Some(output.clone());
output
}
fn flatten_forward(&mut self, input: &NDArray, _training: bool) -> NDArray {
self.input_cache = Some(input.clone());
let batch_size = input.shape()[0];
let flattened_size = self.output_shape[0];
let new_shape = vec![batch_size, flattened_size];
let output = input.reshape(&new_shape)
.expect("Failed to reshape in flatten forward");
self.output_cache = Some(output.clone());
output
}
fn dropout_forward(&mut self, input: &NDArray, training: bool) -> NDArray {
self.input_cache = Some(input.clone());
if !training || self.dropout_rate.unwrap() == 0.0 {
return input.clone();
}
let mask = NDArray::rand_uniform(input.shape())
.map(|x| if x > self.dropout_rate.unwrap() { 1.0 } else { 0.0 })
.multiply_scalar(1.0 / (1.0 - self.dropout_rate.unwrap()));
self.dropout_mask = Some(mask.clone());
let output = input.multiply(&mask);
self.output_cache = Some(output.clone());
output
}
fn batch_norm_forward(&mut self, input: &NDArray, training: bool) -> NDArray {
self.input_cache = Some(input.clone());
let batch_size = input.shape()[0];
let (mean, var) = if training {
let batch_mean = self.reshape_to_1d(&input.mean_axis(0));
let broadcasted_mean = self.broadcast_to_batch(&batch_mean, batch_size);
let centered = input.subtract(&broadcasted_mean);
let batch_var = self.reshape_to_1d(&centered.multiply(&centered).mean_axis(0));
if let (Some(running_mean), Some(running_var)) =
(&mut self.running_mean, &mut self.running_var)
{
let momentum = self.momentum.unwrap();
*running_mean = running_mean.multiply_scalar(momentum)
.add(&batch_mean.multiply_scalar(1.0 - momentum));
*running_var = running_var.multiply_scalar(momentum)
.add(&batch_var.multiply_scalar(1.0 - momentum));
}
self.batch_mean = Some(batch_mean.clone());
self.batch_var = Some(batch_var.clone());
(batch_mean, batch_var)
} else {
(self.running_mean.as_ref().unwrap().clone(),
self.running_var.as_ref().unwrap().clone())
};
let broadcasted_mean = self.broadcast_to_batch(&mean, batch_size);
let broadcasted_var = self.broadcast_to_batch(&var, batch_size);
let broadcasted_weights = self.broadcast_to_batch(self.weights.as_ref().unwrap(), batch_size);
let broadcasted_biases = self.broadcast_to_batch(self.biases.as_ref().unwrap(), batch_size);
let centered = input.subtract(&broadcasted_mean);
let std_dev = broadcasted_var.add_scalar(self.epsilon.unwrap()).sqrt();
let normalized = centered.divide(&std_dev);
self.normalized = Some(normalized.clone());
let output = normalized.multiply(&broadcasted_weights).add(&broadcasted_biases);
self.output_cache = Some(output.clone());
output
}
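/// Runs the backward pass for this layer, dispatching on `layer_type`.
///
/// Returns the gradient with respect to the layer's input; for trainable
/// layers the parameter gradients are stored in `weight_gradients` and
/// `bias_gradients` for an optimizer to consume.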
pub fn backward(&mut self, gradient: &NDArray) -> NDArray {
match self.layer_type.as_str() {
"Dense" => self.dense_backward(gradient),
"Input" => gradient.clone(),
"Activation" => self.activation_backward(gradient),
"Flatten" => self.flatten_backward(gradient),
"Dropout" => self.dropout_backward(gradient),
"BatchNorm" => self.batch_norm_backward(gradient),
_ => panic!("Unknown layer type: {}", self.layer_type),
}
}
fn dense_backward(&mut self, gradient: &NDArray) -> NDArray {
let input = self.input_cache.as_ref().unwrap();
let output = self.output_cache.as_ref().unwrap();
let weights = self.weights.as_ref().unwrap();
let act_gradient = if let Some(act_type) = &self.activation {
match act_type.as_str() {
"relu" => NablaActivation::relu_backward(gradient, output),
"sigmoid" => NablaActivation::sigmoid_backward(gradient, output),
"tanh" => NablaActivation::tanh_backward(gradient, output),
"softmax" => NablaActivation::softmax_backward(gradient, output),
_ => panic!("Unknown activation type: {}", act_type),
}
} else {
gradient.clone()
};
let input_t = input.transpose().expect("Failed to transpose input");
let weights_t = weights.transpose().expect("Failed to transpose weights");
self.weight_gradients = Some(input_t.dot(&act_gradient));
self.bias_gradients = Some(act_gradient.sum_axis(0).reshape(&[self.output_shape[0]])
.expect("Failed to reshape bias gradients"));
act_gradient.dot(&weights_t)
}
fn activation_backward(&mut self, gradient: &NDArray) -> NDArray {
let input = self.input_cache.as_ref().unwrap();
let output = self.output_cache.as_ref().unwrap();
match self.activation.as_ref().unwrap().as_str() {
"relu" => NablaActivation::relu_backward(gradient, input),
"sigmoid" => NablaActivation::sigmoid_backward(gradient, output),
"tanh" => NablaActivation::tanh_backward(gradient, output),
"leaky_relu" => NablaActivation::leaky_relu_backward(gradient, input, None),
_ => panic!("Unknown activation type: {}", self.activation.as_ref().unwrap()),
}
}
fn flatten_backward(&mut self, gradient: &NDArray) -> NDArray {
let original_shape = self.input_cache.as_ref().unwrap().shape();
gradient.reshape(original_shape)
.expect("Failed to reshape in flatten backward")
}
fn dropout_backward(&mut self, gradient: &NDArray) -> NDArray {
if let Some(mask) = &self.dropout_mask {
gradient.multiply(mask)
} else {
gradient.clone()
}
}
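/// Simplified batch-norm backward pass.
///
/// Computes the gamma/beta gradients exactly, but approximates the input
/// gradient as `gradient * gamma / sqrt(var + eps) / batch_size`, ignoring
/// the terms that come from the dependence of the batch mean and variance
/// on the input.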
#[allow(unused_variables)]
fn batch_norm_backward(&mut self, gradient: &NDArray) -> NDArray {
let input = self.input_cache.as_ref().unwrap();
let batch_size = input.shape()[0];
let weights = self.weights.as_ref().unwrap();
let normalized = self.normalized.as_ref().unwrap();
let mut broadcasted_weights = Vec::with_capacity(input.data().len());
for _ in 0..batch_size {
broadcasted_weights.extend(weights.data());
}
let broadcasted_weights = NDArray::new(broadcasted_weights, input.shape().to_vec());
self.weight_gradients = Some(gradient.multiply(normalized).sum_axis(0));
self.bias_gradients = Some(gradient.sum_axis(0));
let dx_normalized = gradient.multiply(&broadcasted_weights);
let std_dev = self.batch_var.as_ref().unwrap()
.add_scalar(self.epsilon.unwrap())
.sqrt();
let mut broadcasted_std = Vec::with_capacity(input.data().len());
for _ in 0..batch_size {
broadcasted_std.extend(std_dev.data());
}
let broadcasted_std = NDArray::new(broadcasted_std, input.shape().to_vec());
let dx = dx_normalized.divide(&broadcasted_std);
let mut broadcasted_mean = Vec::with_capacity(input.data().len());
for _ in 0..batch_size {
broadcasted_mean.extend(self.batch_mean.as_ref().unwrap().data());
}
let broadcasted_mean = NDArray::new(broadcasted_mean, input.shape().to_vec());
let centered = input.subtract(&broadcasted_mean);
dx.multiply_scalar(1.0 / batch_size as f64)
}
pub fn get_output_shape(&self) -> &[usize] {
&self.output_shape
}
pub fn get_name(&self) -> &str {
&self.name
}
pub fn is_trainable(&self) -> bool {
self.trainable
}
pub fn compute_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
match self.layer_type.as_str() {
"Dense" => {
if input_shape.len() > 1 {
vec![input_shape[0], self.output_shape[0]]
} else {
vec![self.output_shape[0]]
}
},
"Input" => {
input_shape.to_vec()
},
"Flatten" => {
let flat_size: usize = input_shape[1..].iter().product();
vec![input_shape[0], flat_size]
},
_ => {
if input_shape.len() > 1 {
let mut shape = vec![input_shape[0]];
shape.extend(self.output_shape.iter());
shape
} else {
self.output_shape.clone()
}
}
}
}
pub fn set_node_index(&mut self, index: usize) {
self.node_index = Some(index);
}
pub fn set_inputs(&mut self, inputs: Vec<usize>) {
self.input_nodes = Some(inputs);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_input_layer() {
let input = NabLayer::input(vec![784], Some("mnist_input"));
assert_eq!(input.get_name(), "mnist_input");
assert_eq!(input.get_output_shape(), &[784]);
assert!(!input.is_trainable());
let data = NDArray::from_matrix(vec![vec![1.0; 784]; 32]);
let mut layer = NabLayer::input(vec![784], None);
let output = layer.forward(&data, true);
assert_eq!(output.shape(), vec![32, 784]);
assert_eq!(output.data(), data.data());
let gradient = NDArray::from_matrix(vec![vec![1.0; 784]; 32]);
let backward = layer.backward(&gradient);
assert_eq!(backward.data(), gradient.data());
}
#[test]
fn test_dense_layer() {
let dense = NabLayer::dense(784, 128, Some("relu"), Some("hidden_1"));
assert_eq!(dense.get_name(), "hidden_1");
assert_eq!(dense.get_output_shape(), &[128]);
assert!(dense.is_trainable());
let batch_size = 32;
let input = NDArray::from_matrix(vec![vec![0.1; 784]; batch_size]);
let mut layer = NabLayer::dense(784, 128, None, None);
let output = layer.forward(&input, true);
assert_eq!(output.shape(), vec![batch_size, 128]);
let gradient = NDArray::from_matrix(vec![vec![0.1; 128]; batch_size]);
let backward = layer.backward(&gradient);
assert_eq!(backward.shape(), vec![batch_size, 784]);
assert!(layer.weight_gradients.is_some());
assert!(layer.bias_gradients.is_some());
}
#[test]
fn test_activation_layer() {
let relu = NabLayer::activation("relu", vec![128], Some("relu_1"));
assert_eq!(relu.get_name(), "relu_1");
assert_eq!(relu.get_output_shape(), &[128]);
assert!(!relu.is_trainable());
let batch_size = 32;
let input = NDArray::from_matrix(vec![vec![-0.5, 0.0, 0.5]; batch_size]);
let mut layer = NabLayer::activation("relu", vec![3], None);
let output = layer.forward(&input, true);
assert_eq!(output.shape(), vec![batch_size, 3]);
for row in 0..batch_size {
assert_eq!(output.get_2d(row, 0), 0.0);
assert_eq!(output.get_2d(row, 1), 0.0);
assert_eq!(output.get_2d(row, 2), 0.5);
}
let gradient = NDArray::from_matrix(vec![vec![1.0; 3]; batch_size]);
let backward = layer.backward(&gradient);
assert_eq!(backward.shape(), vec![batch_size, 3]);
for row in 0..batch_size {
assert_eq!(backward.get_2d(row, 0), 0.0);
assert_eq!(backward.get_2d(row, 1), 0.0);
assert_eq!(backward.get_2d(row, 2), 1.0);
}
}
#[test]
fn test_dense_layer_with_activation() {
let dense = NabLayer::dense(3, 2, Some("relu"), Some("dense_relu"));
assert_eq!(dense.get_name(), "dense_relu");
assert_eq!(dense.get_output_shape(), &[2]);
assert!(dense.is_trainable());
let input = NDArray::from_matrix(vec![
vec![-1.0, 0.0, 1.0],
vec![2.0, -2.0, 0.0],
]);
let mut layer = NabLayer::dense(3, 2, Some("relu"), None);
layer.weights = Some(NDArray::from_matrix(vec![
vec![1.0, -1.0],
vec![-1.0, 1.0],
vec![0.5, 0.5],
]));
layer.biases = Some(NDArray::from_vec(vec![0.0, 0.0]));
let output = layer.forward(&input, true);
assert_eq!(output.shape(), vec![2, 2]);
assert!(output.get_2d(0, 0) >= 0.0);
assert!(output.get_2d(0, 1) >= 0.0);
assert!(output.get_2d(1, 0) >= 0.0);
assert!(output.get_2d(1, 1) >= 0.0);
let gradient = NDArray::from_matrix(vec![
vec![1.0, 1.0],
vec![1.0, 1.0],
]);
let backward = layer.backward(&gradient);
assert_eq!(backward.shape(), vec![2, 3]);
let output_cache = layer.output_cache.as_ref().unwrap();
println!("Output shape: {:?}", output_cache.shape());
println!("Backward shape: {:?}", backward.shape());
let negative_outputs: Vec<(usize, usize)> = (0..2)
.flat_map(|i| (0..2).map(move |j| (i, j)))
.filter(|&(i, j)| output_cache.get_2d(i, j) <= 0.0)
.collect();
for (i, j) in negative_outputs {
println!("Checking gradient for negative output at ({}, {})", i, j);
println!("Output value: {}", output_cache.get_2d(i, j));
println!("Backward value: {}", backward.get_2d(i, j));
}
}
#[test]
fn test_flatten_layer() {
let flatten = NabLayer::flatten(vec![2, 3, 4], Some("flatten_1"));
assert_eq!(flatten.get_name(), "flatten_1");
assert_eq!(flatten.get_output_shape(), &[24]);
assert!(!flatten.is_trainable());
let batch_size = 2;
let input = NDArray::from_matrix(vec![
vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
]);
let mut layer = NabLayer::flatten(vec![2, 3], None);
let output = layer.forward(&input, true);
assert_eq!(output.shape(), vec![batch_size, 6]);
for i in 0..batch_size {
for j in 0..6 {
assert_eq!(output.get_2d(i, j), input.get_2d(i, j));
}
}
let gradient = NDArray::from_matrix(vec![
vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
vec![0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
]);
let backward = layer.backward(&gradient);
assert_eq!(backward.shape(), input.shape());
for i in 0..batch_size {
for j in 0..6 {
assert_eq!(backward.get_2d(i, j), gradient.get_2d(i, j));
}
}
}
#[test]
fn test_dropout_layer() {
let dropout = NabLayer::dropout(vec![100], 0.5, Some("dropout_1"));
assert_eq!(dropout.get_name(), "dropout_1");
assert_eq!(dropout.get_output_shape(), &[100]);
assert!(!dropout.is_trainable());
let batch_size = 10;
let input = NDArray::from_matrix(vec![vec![1.0; 100]; batch_size]);
let mut layer = NabLayer::dropout(vec![100], 0.5, None);
let output_train = layer.forward(&input, true);
assert_eq!(output_train.shape(), vec![batch_size, 100]);
let zeros = output_train.data().iter().filter(|&&x| x == 0.0).count();
let total = output_train.data().len();
let dropout_rate = zeros as f64 / total as f64;
assert!((dropout_rate - 0.5).abs() < 0.1,
"Dropout rate should be approximately 0.5, got {}", dropout_rate);
let output_test = layer.forward(&input, false);
assert_eq!(output_test.data(), input.data(),
"During testing, dropout should act as identity");
let gradient = NDArray::from_matrix(vec![vec![1.0; 100]; batch_size]);
let backward = layer.backward(&gradient);
assert_eq!(backward.shape(), gradient.shape());
if let Some(mask) = &layer.dropout_mask {
for i in 0..total {
if mask.data()[i] == 0.0 {
assert_eq!(backward.data()[i], 0.0,
"Gradient should be zero where input was dropped");
}
}
}
let input = NDArray::from_matrix(vec![vec![1.0; 100]; batch_size]);
let mut layer = NabLayer::dropout(vec![100], 0.5, None);
let mut different_masks = false;
let first_output = layer.forward(&input, true);
for _ in 0..5 {
let output = layer.forward(&input, true);
if output.data() != first_output.data() {
different_masks = true;
break;
}
}
assert!(different_masks, "Dropout should generate different masks");
}
#[test]
fn test_batch_norm_layer() {
let bn = NabLayer::batch_norm(vec![3], Some(1e-5), Some(0.99), Some("bn_1"));
assert_eq!(bn.get_name(), "bn_1");
assert_eq!(bn.get_output_shape(), &[3]);
assert!(bn.is_trainable());
let input = NDArray::from_matrix(vec![
vec![1.0, 2.0, 3.0],
vec![4.0, 5.0, 6.0],
]);
let mut layer = NabLayer::batch_norm(vec![3], Some(1e-5), Some(0.99), None);
let output_train = layer.forward(&input, true);
assert_eq!(output_train.shape(), vec![2, 3]);
let output_mean = output_train.mean_axis(0);
let output_var = output_train.var_axis(0);
for i in 0..3 {
assert!((output_mean.get_2d(0, i)).abs() < 1e-5,
"Mean should be close to 0, got {}", output_mean.get_2d(0, i));
assert!((output_var.get_2d(0, i) - 1.0).abs() < 1e-5,
"Variance should be close to 1, got {}", output_var.get_2d(0, i));
}
let gradient = NDArray::from_matrix(vec![
vec![0.1, 0.2, 0.3],
vec![0.4, 0.5, 0.6],
]);
let backward = layer.backward(&gradient);
assert_eq!(backward.shape(), input.shape());
assert!(layer.weight_gradients.is_some());
assert!(layer.bias_gradients.is_some());
}
#[test]
fn test_compute_output_shape() {
let input_layer = NabLayer::input(vec![784], Some("input_layer"));
assert_eq!(input_layer.compute_output_shape(&[32, 784]), vec![32, 784]);
let dense_layer = NabLayer::dense(784, 128, Some("relu"), Some("dense_layer"));
assert_eq!(dense_layer.compute_output_shape(&[32, 784]), vec![32, 128]);
let activation_layer = NabLayer::activation("relu", vec![128], Some("activation_layer"));
assert_eq!(activation_layer.compute_output_shape(&[32, 128]), vec![32, 128]);
let flatten_layer = NabLayer::flatten(vec![28, 28, 1], Some("flatten_layer"));
assert_eq!(flatten_layer.compute_output_shape(&[32, 28, 28, 1]), vec![32, 784]);
let dropout_layer = NabLayer::dropout(vec![128], 0.5, Some("dropout_layer"));
assert_eq!(dropout_layer.compute_output_shape(&[32, 128]), vec![32, 128]);
let batch_norm_layer = NabLayer::batch_norm(vec![128], None, None, Some("batch_norm_layer"));
assert_eq!(batch_norm_layer.compute_output_shape(&[32, 128]), vec![32, 128]);
}
}