pub mod keys;
pub mod normalisers;
pub mod output;
pub mod string_value;
pub mod version;
pub mod engine;
pub mod origin;
pub mod input_dims;
use keys::KeyBindings;
use normalisers::wrapper::NormaliserType;
use normalisers::NormaliserMap;
use output::Output;
use string_value::StringValue;
use version::Version;
use engine::Engine;
use origin::Origin;
use input_dims::InputDims;
use crate::safe_eject;
use crate::errors::error::{SurrealError, SurrealErrorStatus};
/// In-memory representation of a model header.
///
/// Each field is one section of the serialised header. `Header::to_bytes` writes the
/// sections, and `Header::from_bytes` reads them, in field declaration order, separated by `//=>`.
#[derive(Debug, PartialEq)]
pub struct Header {
    /// Names of the input columns, in the order they were registered.
    pub keys: KeyBindings,
    /// Normalisers attached to individual input columns (columns are resolved against `keys`).
    pub normalisers: NormaliserMap,
    /// Output column name and its optional normaliser.
    pub output: Output,
    /// Model name.
    pub name: StringValue,
    /// Model version.
    pub version: Version,
    /// Free-text model description.
    pub description: StringValue,
    /// Engine the model was exported from / runs on (e.g. `Engine::PyTorch`).
    pub engine: Engine,
    /// Author and origin information for the model.
    pub origin: Origin,
    /// Model input dimensions (serialised in a form such as `1,2`).
    pub input_dims: InputDims,
}
impl Header {
pub fn fresh() -> Self {
Header {
keys: KeyBindings::fresh(),
normalisers: NormaliserMap::fresh(),
output: Output::fresh(),
name: StringValue::fresh(),
version: Version::fresh(),
description: StringValue::fresh(),
engine: Engine::fresh(),
origin: Origin::fresh(),
input_dims: InputDims::fresh(),
}
}
pub fn add_name(&mut self, model_name: String) {
self.name = StringValue::from_string(model_name);
}
pub fn add_version(&mut self, version: String) -> Result<(), SurrealError> {
self.version = Version::from_string(version)?;
Ok(())
}
pub fn add_description(&mut self, description: String) {
self.description = StringValue::from_string(description);
}
pub fn add_column(&mut self, column_name: String) {
self.keys.add_column(column_name);
}
pub fn add_normaliser(&mut self, column_name: String, normaliser: NormaliserType) -> Result<(), SurrealError> {
let _ = self.normalisers.add_normaliser(normaliser, column_name, &self.keys)?;
Ok(())
}
pub fn get_normaliser(&self, column_name: &String) -> Result<Option<&NormaliserType>, SurrealError> {
self.normalisers.get_normaliser(column_name.to_string(), &self.keys)
}
pub fn add_output(&mut self, column_name: String, normaliser: Option<NormaliserType>) {
self.output.name = Some(column_name);
self.output.normaliser = normaliser;
}
pub fn add_engine(&mut self, engine: String) {
self.engine = Engine::from_string(engine);
}
pub fn add_author(&mut self, author: String) {
self.origin.add_author(author);
}
pub fn add_origin(&mut self, origin: String) -> Result<(), SurrealError> {
self.origin.add_origin(origin)
}
fn delimiter() -> &'static str {
"//=>"
}
pub fn from_bytes(data: Vec<u8>) -> Result<Self, SurrealError> {
let string_data = safe_eject!(String::from_utf8(data), SurrealErrorStatus::BadRequest);
let buffer = string_data.split(Self::delimiter()).collect::<Vec<&str>>();
let keys: KeyBindings = KeyBindings::from_string(buffer.get(1).unwrap_or(&"").to_string());
let normalisers = NormaliserMap::from_string(buffer.get(2).unwrap_or(&"").to_string(), &keys)?;
let output = Output::from_string(buffer.get(3).unwrap_or(&"").to_string())?;
let name = StringValue::from_string(buffer.get(4).unwrap_or(&"").to_string());
let version = Version::from_string(buffer.get(5).unwrap_or(&"").to_string())?;
let description = StringValue::from_string(buffer.get(6).unwrap_or(&"").to_string());
let engine = Engine::from_string(buffer.get(7).unwrap_or(&"").to_string());
let origin = Origin::from_string(buffer.get(8).unwrap_or(&"").to_string())?;
let input_dims = InputDims::from_string(buffer.get(9).unwrap_or(&"").to_string());
Ok(Header {keys, normalisers, output, name, version, description, engine, origin, input_dims})
}
pub fn to_bytes(&self) -> (i32, Vec<u8>) {
let buffer = vec![
"".to_string(),
self.keys.to_string(),
self.normalisers.to_string(),
self.output.to_string(),
self.name.to_string(),
self.version.to_string(),
self.description.to_string(),
self.engine.to_string(),
self.origin.to_string(),
self.input_dims.to_string(),
"".to_string(),
];
let buffer = buffer.join(Self::delimiter()).into_bytes();
(buffer.len() as i32, buffer)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::keys::tests::generate_string as generate_key_string;
    use super::normalisers::tests::generate_string as generate_normaliser_string;
    use super::normalisers::{
        clipping::Clipping,
        linear_scaling::LinearScaling,
        log_scale::LogScaling,
        z_score::ZScore,
    };

    /// Column names used by the key-binding fixtures, in registration order.
    const COLUMNS: [&str; 6] = ["a", "b", "c", "d", "e", "f"];

    /// Builds a fully populated serialised header.
    ///
    /// Every section is present, and the string starts and ends with `Header::delimiter()`,
    /// matching the layout written by `Header::to_bytes`.
    pub fn generate_string() -> String {
        let sections = [
            String::new(),
            generate_key_string().to_string(),
            generate_normaliser_string().to_string(),
            "g=>linear_scaling(0.0,1.0)".to_string(),
            "test model name".to_string(),
            "0.0.1".to_string(),
            "test description".to_string(),
            Engine::PyTorch.to_string(),
            Origin::from_string("author=>local".to_string()).unwrap().to_string(),
            InputDims::from_string("1,2".to_string()).to_string(),
            String::new(),
        ];
        sections.join(Header::delimiter())
    }

    /// Byte form of [`generate_string`].
    pub fn generate_bytes() -> Vec<u8> {
        generate_string().into_bytes()
    }

    /// Asserts that `header` holds exactly the columns in `COLUMNS`, in order.
    fn assert_fixture_columns(header: &Header) {
        assert_eq!(header.keys.store.len(), COLUMNS.len());
        assert_eq!(header.keys.reference.len(), COLUMNS.len());
        for (index, expected) in COLUMNS.iter().enumerate() {
            assert_eq!(header.keys.store[index], *expected);
        }
    }

    #[test]
    fn test_from_bytes() {
        let header = Header::from_bytes(generate_bytes()).unwrap();
        assert_fixture_columns(&header);
        assert_eq!(header.normalisers.store.len(), 4);
    }

    #[test]
    fn test_empty_header() {
        let string = "//=>//=>//=>//=>//=>//=>//=>//=>//=>".to_string();
        let header = Header::from_bytes(string.into_bytes()).unwrap();
        assert_eq!(header, Header::fresh());

        let header = Header::from_bytes(Vec::new()).unwrap();
        assert_eq!(header, Header::fresh());
    }

    #[test]
    fn test_to_bytes() {
        let header = Header::from_bytes(generate_bytes()).unwrap();
        let (bytes_num, bytes) = header.to_bytes();
        let string = String::from_utf8(bytes).unwrap();
        let expected_string = "//=>a=>b=>c=>d=>e=>f//=>a=>linear_scaling(0,1)//b=>clipping(0,1.5)//c=>log_scaling(10,0)//e=>z_score(0,1)//=>g=>linear_scaling(0,1)//=>test model name//=>0.0.1//=>test description//=>pytorch//=>author=>local//=>1,2//=>".to_string();
        assert_eq!(string, expected_string);
        assert_eq!(bytes_num, expected_string.len() as i32);

        let (bytes_num, bytes) = Header::fresh().to_bytes();
        let string = String::from_utf8(bytes).unwrap();
        let expected_string = "//=>//=>//=>//=>//=>//=>//=>//=>//=>//=>".to_string();
        assert_eq!(string, expected_string);
        assert_eq!(bytes_num, expected_string.len() as i32);
    }

    #[test]
    fn test_round_trip() {
        // A populated header survives serialisation and deserialisation unchanged.
        let header = Header::from_bytes(generate_bytes()).unwrap();
        let (_, bytes) = header.to_bytes();
        assert_eq!(Header::from_bytes(bytes).unwrap(), header);

        // So does an empty one.
        let (_, bytes) = Header::fresh().to_bytes();
        assert_eq!(Header::from_bytes(bytes).unwrap(), Header::fresh());
    }

    #[test]
    fn test_add_column() {
        let mut header = Header::fresh();
        for column in COLUMNS.iter() {
            header.add_column(column.to_string());
        }
        assert_fixture_columns(&header);
    }

    #[test]
    fn test_add_normalizer() {
        let mut header = Header::fresh();
        for column in COLUMNS.iter() {
            header.add_column(column.to_string());
        }

        // Unwrap each call so a failing insertion is reported at its source.
        header
            .add_normaliser("a".to_string(), NormaliserType::LinearScaling(LinearScaling { min: 0.0, max: 1.0 }))
            .unwrap();
        header
            .add_normaliser("b".to_string(), NormaliserType::Clipping(Clipping { min: Some(0.0), max: Some(1.5) }))
            .unwrap();
        header
            .add_normaliser("c".to_string(), NormaliserType::LogScaling(LogScaling { base: 10.0, min: 0.0 }))
            .unwrap();
        header
            .add_normaliser("e".to_string(), NormaliserType::ZScore(ZScore { mean: 0.0, std_dev: 1.0 }))
            .unwrap();

        assert_eq!(header.normalisers.store.len(), 4);
        assert_eq!(header.normalisers.store[0], NormaliserType::LinearScaling(LinearScaling { min: 0.0, max: 1.0 }));
        assert_eq!(header.normalisers.store[1], NormaliserType::Clipping(Clipping { min: Some(0.0), max: Some(1.5) }));
        assert_eq!(header.normalisers.store[2], NormaliserType::LogScaling(LogScaling { base: 10.0, min: 0.0 }));
        assert_eq!(header.normalisers.store[3], NormaliserType::ZScore(ZScore { mean: 0.0, std_dev: 1.0 }));

        // Lookup by column name finds the attached normaliser...
        let column_a = "a".to_string();
        assert_eq!(
            header.get_normaliser(&column_a).unwrap(),
            Some(&NormaliserType::LinearScaling(LinearScaling { min: 0.0, max: 1.0 }))
        );
        // ...and reports none for a registered column without one.
        let column_d = "d".to_string();
        assert_eq!(header.get_normaliser(&column_d).unwrap(), None);
    }
}