nabla_ml/
nab_io.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
use std::fs::File;
use std::io::{self, BufRead, Read, Write};
use flate2::{Compression, write::GzEncoder, read::GzDecoder};
use serde::{Serialize, Deserialize};
use std::collections::HashMap;

use crate::nab_array::NDArray;

/// On-disk representation of an `NDArray` used by `save_nab` / `load_nab`.
///
/// The value is encoded with `bincode`, and serde's derive serializes struct
/// fields in declaration order. Reordering or retyping these fields therefore
/// changes the binary `.nab` format and breaks reading previously saved files.
#[derive(Serialize, Deserialize)]
struct SerializableNDArray {
    /// Array elements, copied from `NDArray::data()` when saving.
    data: Vec<f64>,
    /// Array dimensions, copied from `NDArray::shape()` when saving.
    shape: Vec<usize>,
}

/// Saves an NDArray to a gzip-compressed `.nab` file.
///
/// The array's data and shape are serialized with `bincode`. The resulting
/// bytes are written through a gzip encoder at the default compression level.
/// An existing file at `filename` is truncated.
///
/// # Errors
///
/// Returns an error if serialization fails, if the file cannot be created, or
/// if writing or finishing the gzip stream fails.
pub fn save_nab(filename: &str, array: &NDArray) -> io::Result<()> {
    let serializable_array = SerializableNDArray {
        data: array.data().to_vec(),
        shape: array.shape().to_vec(),
    };
    // Serialize before touching the filesystem, so a serialization failure is
    // reported as an error and does not leave a truncated file behind.
    let serialized_data = bincode::serialize(&serializable_array)
        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

    let file = File::create(filename)?;
    let mut encoder = GzEncoder::new(file, Compression::default());
    encoder.write_all(&serialized_data)?;
    // `finish` writes the gzip trailer; calling it explicitly surfaces any
    // error from that final write to the caller.
    encoder.finish()?;
    Ok(())
}

/// Loads an NDArray from a gzip-compressed `.nab` file written by [`save_nab`].
///
/// The file is fully decompressed into memory and then decoded with `bincode`.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or decompressed. Returns an
/// error of kind [`io::ErrorKind::InvalidData`] if the decompressed bytes are
/// not a valid bincode encoding of the saved data/shape pair.
pub fn load_nab(filename: &str) -> io::Result<NDArray> {
    let file = File::open(filename)?;
    let mut decoder = GzDecoder::new(file);
    let mut serialized_data = Vec::new();
    decoder.read_to_end(&mut serialized_data)?;
    // The file is external input: report corrupt or truncated payloads as an
    // error instead of panicking.
    let serializable_array: SerializableNDArray = bincode::deserialize(&serialized_data)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
    // NOTE(review): a crafted file whose data length does not match its shape
    // is passed straight to `NDArray::new`; confirm how that constructor
    // handles the mismatch.
    Ok(NDArray::new(serializable_array.data, serializable_array.shape))
}

/// Saves multiple NDArrays to a plain-text `.nab` archive, one array per line.
///
/// Each line has the form `name:d0,d1,...;x0,x1,...`:
/// - the `d`s are the shape dimensions;
/// - the `x`s are the data values, written with `f64`'s `Display` formatting,
///   which `str::parse::<f64>` reads back.
///
/// The file at `filename` is created or truncated. Use [`loadz_nab`] to read
/// it back.
///
/// # Arguments
///
/// * `filename` - The name of the file to save the arrays to.
/// * `arrays` - A vector of tuples containing the name and NDArray to save.
///
/// # Errors
///
/// Returns an error of kind [`io::ErrorKind::InvalidInput`] if any name
/// contains `:`, `\n` or `\r`; such names cannot round-trip through the
/// line-based format. In that case nothing is written. Any I/O error from
/// creating or writing the file is propagated.
#[allow(dead_code)]
pub fn savez_nab(filename: &str, arrays: Vec<(&str, &NDArray)>) -> io::Result<()> {
    // Reject names that would corrupt the line format before creating (and
    // truncating) the output file.
    if let Some(bad) = arrays
        .iter()
        .map(|(name, _)| *name)
        .find(|name| name.contains(|c: char| matches!(c, ':' | '\n' | '\r')))
    {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "array name {:?} contains a reserved character (':', '\\n' or '\\r')",
                bad
            ),
        ));
    }

    // Buffer the many small writes below instead of issuing one syscall each.
    let mut writer = io::BufWriter::new(File::create(filename)?);
    for (name, array) in arrays {
        write!(writer, "{}:", name)?;
        write_comma_separated(&mut writer, array.shape().iter())?;
        writer.write_all(b";")?;
        write_comma_separated(&mut writer, array.data().iter())?;
        writer.write_all(b"\n")?;
    }
    // Flush explicitly: `BufWriter`'s `Drop` discards any error from the final flush.
    writer.flush()?;
    Ok(())
}

/// Writes `items` to `writer` separated by commas, with no trailing separator.
fn write_comma_separated<W, I>(writer: &mut W, items: I) -> io::Result<()>
where
    W: Write,
    I: IntoIterator,
    I::Item: std::fmt::Display,
{
    for (i, item) in items.into_iter().enumerate() {
        if i > 0 {
            writer.write_all(b",")?;
        }
        write!(writer, "{}", item)?;
    }
    Ok(())
}

/// Loads the arrays stored in a plain-text `.nab` archive written by [`savez_nab`].
///
/// Every non-blank line must have the form `name:d0,d1,...;x0,x1,...`, made of:
/// - the array name;
/// - a comma-separated shape;
/// - comma-separated `f64` data.
///
/// An empty shape or data section produces an empty `Vec`. If a name appears
/// more than once, the last occurrence wins.
///
/// # Errors
///
/// Propagates I/O errors from opening or reading the file. Returns an error of
/// kind [`io::ErrorKind::InvalidData`] naming the 1-based line number in these
/// cases:
/// - a line lacks its `:` or `;` separator;
/// - a line contains a dimension or value that does not parse.
#[allow(dead_code)]
pub fn loadz_nab(filename: &str) -> io::Result<HashMap<String, NDArray>> {
    let file = File::open(filename)?;
    let mut arrays = HashMap::new();

    for (index, line) in io::BufReader::new(file).lines().enumerate() {
        let line = line?;
        // Blank lines carry no array; skip them as the original format reader did.
        if line.trim().is_empty() {
            continue;
        }
        let line_no = index + 1;
        let invalid = |msg: &str| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("line {}: {}", line_no, msg),
            )
        };

        // Split at the LAST ':': the shape and data sections are numbers and
        // never contain one, so a ':' inside the name cannot shift the boundary.
        let (name, shape_and_data) = line
            .rsplit_once(':')
            .ok_or_else(|| invalid("missing ':' after array name"))?;
        let (shape_str, data_str) = shape_and_data
            .split_once(';')
            .ok_or_else(|| invalid("missing ';' between shape and data"))?;

        // Parse strictly: silently dropping a bad token would yield data that
        // no longer matches its shape.
        let shape: Vec<usize> = parse_comma_separated(shape_str)
            .map_err(|e| invalid(&format!("invalid shape entry {}", e)))?;
        let data: Vec<f64> = parse_comma_separated(data_str)
            .map_err(|e| invalid(&format!("invalid data value {}", e)))?;

        arrays.insert(name.to_string(), NDArray::new(data, shape));
    }

    Ok(arrays)
}

/// Parses a comma-separated list of `T`; a blank string yields an empty list.
///
/// On failure, returns a message naming the offending token and the parse error.
fn parse_comma_separated<T>(text: &str) -> Result<Vec<T>, String>
where
    T: std::str::FromStr,
    T::Err: std::fmt::Display,
{
    if text.trim().is_empty() {
        return Ok(Vec::new());
    }
    text.split(',')
        .map(|token| {
            token
                .trim()
                .parse::<T>()
                .map_err(|e| format!("{:?}: {}", token, e))
        })
        .collect()
}