1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
use protobuf::Enum;
use crate::{attribute_proto::AttributeType, AttributeProto, GraphProto, SparseTensorProto, TensorProto, TypeProto};
use crate::tensor_proto::DataType;

/// Borrowed, typed view of the payload carried by an ONNX `AttributeProto`.
///
/// Scalar payloads are held by value; every other payload is borrowed for `'a`
/// from the attribute it was read from, so the enum is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AttributeValue<'a> {
    /// No typed payload could be selected; carries the raw attribute type code.
    Unknown(i32),
    /// Single `f32` (attribute type `FLOAT`).
    Float32(f32),
    /// List of `f32` (attribute type `FLOATS`).
    Float32s(&'a [f32]),
    /// Single `i64` (attribute type `INT`).
    Integer64(i64),
    /// List of `i64` (attribute type `INTS`).
    Integer64s(&'a [i64]),
    /// Single byte string (attribute type `STRING`); kept as raw bytes.
    String(&'a [u8]),
    /// List of byte strings (attribute type `STRINGS`); kept as raw bytes.
    Strings(&'a [Vec<u8>]),
    /// Single type description (attribute type `TYPE_PROTO`).
    Type(&'a TypeProto),
    /// List of type descriptions (attribute type `TYPE_PROTOS`).
    Types(&'a [TypeProto]),
    /// Single graph (attribute type `GRAPH`).
    Graph(&'a GraphProto),
    /// List of graphs (attribute type `GRAPHS`).
    Graphs(&'a [GraphProto]),
    /// Single tensor (attribute type `TENSOR`).
    Tensor(&'a TensorProto),
    /// List of tensors (attribute type `TENSORS`).
    Tensors(&'a [TensorProto]),
    /// Single sparse tensor (attribute type `SPARSE_TENSOR`).
    SparseTensor(&'a SparseTensorProto),
    /// List of sparse tensors (attribute type `SPARSE_TENSORS`).
    SparseTensors(&'a [SparseTensorProto]),
}

impl AttributeProto {
    /// Borrows this attribute's payload as a typed [`AttributeValue`].
    ///
    /// The `type_` field selects which payload field is exposed. The result is
    /// [`AttributeValue::Unknown`], carrying the raw type code, when:
    ///
    /// * the code does not map to a known `AttributeType`,
    /// * the type is `UNDEFINED` (reported as code `0`), or
    /// * the type names a single-message payload (`TENSOR`, `GRAPH`,
    ///   `SPARSE_TENSOR`, `TYPE_PROTO`) but that message field is not set.
    ///
    /// This method never panics: an attribute whose declared type and payload
    /// disagree is reported as `Unknown` instead of aborting the caller.
    pub fn as_value(&self) -> AttributeValue<'_> {
        let kind = match self.type_.enum_value() {
            Ok(kind) => kind,
            // Unrecognised numeric type: hand the raw code back to the caller.
            Err(code) => return AttributeValue::Unknown(code),
        };
        // Raw type code, reported when a message-typed payload is absent.
        let code = self.type_.value();

        match kind {
            AttributeType::UNDEFINED => AttributeValue::Unknown(0),
            AttributeType::FLOAT => AttributeValue::Float32(self.f),
            AttributeType::FLOATS => AttributeValue::Float32s(&self.floats),
            AttributeType::INT => AttributeValue::Integer64(self.i),
            AttributeType::INTS => AttributeValue::Integer64s(&self.ints),
            AttributeType::STRING => AttributeValue::String(&self.s),
            AttributeType::STRINGS => AttributeValue::Strings(&self.strings),
            AttributeType::TENSOR => match self.t.0.as_deref() {
                Some(tensor) => AttributeValue::Tensor(tensor),
                None => AttributeValue::Unknown(code),
            },
            AttributeType::TENSORS => AttributeValue::Tensors(&self.tensors),
            AttributeType::GRAPH => match self.g.0.as_deref() {
                Some(graph) => AttributeValue::Graph(graph),
                None => AttributeValue::Unknown(code),
            },
            AttributeType::GRAPHS => AttributeValue::Graphs(&self.graphs),
            AttributeType::SPARSE_TENSOR => match self.sparse_tensor.0.as_deref() {
                Some(sparse) => AttributeValue::SparseTensor(sparse),
                None => AttributeValue::Unknown(code),
            },
            AttributeType::SPARSE_TENSORS => AttributeValue::SparseTensors(&self.sparse_tensors),
            AttributeType::TYPE_PROTO => match self.tp.0.as_deref() {
                Some(type_proto) => AttributeValue::Type(type_proto),
                None => AttributeValue::Unknown(code),
            },
            AttributeType::TYPE_PROTOS => AttributeValue::Types(&self.type_protos),
        }
    }
}

/// Borrowed, typed view of the element storage of an ONNX `TensorProto`.
///
/// Every data-carrying variant borrows a slice for `'p`, so the enum is cheap
/// to copy and can be compared or debug-printed directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TensorValue<'p> {
    /// No typed view is available; carries the raw `data_type` code.
    Unknown(i32),
    /// `f32` elements.
    Float32(&'p [f32]),
    /// `f64` elements.
    Float64(&'p [f64]),
    /// `u32` elements.
    Unsigned32(&'p [u32]),
    /// `u64` elements.
    Unsigned64(&'p [u64]),
    /// `i8` elements.
    Integer8(&'p [i8]),
    /// `i16` elements.
    Integer16(&'p [i16]),
    /// `i32` elements.
    Integer32(&'p [i32]),
    /// `i64` elements.
    Integer64(&'p [i64]),
}

impl TensorProto {
    /// Convert to typed value
    pub fn as_value(&self) -> TensorValue {
        match DataType::from_i32(self.data_type) {
            Some(o) => {
                match o {
                    DataType::UNDEFINED => { TensorValue::Unknown(0) }
                    DataType::FLOAT8E4M3FN => {
                        let _ = &self.int32_data;
                        todo!()
                    }
                    DataType::FLOAT8E4M3FNUZ => {
                        let _ = &self.int32_data;
                        todo!()
                    }
                    DataType::FLOAT8E5M2 => {
                        let _ = &self.int32_data;
                        todo!()
                    }
                    DataType::FLOAT8E5M2FNUZ => {
                        let _ = &self.int32_data;
                        todo!()
                    }
                    DataType::FLOAT16 => {
                        let _ = &self.int32_data;
                        todo!()
                    }
                    DataType::BFLOAT16 => {
                        let _ = &self.int32_data;
                        todo!()
                    }
                    DataType::FLOAT => {
                        TensorValue::Float32(&self.float_data)
                    }
                    DataType::DOUBLE => {
                        TensorValue::Float64(&self.double_data)
                    }
                    DataType::UINT4 => {
                        let _ = &self.int32_data;
                        todo!()
                    }
                    DataType::UINT8 => {
                        let _ = &self.int32_data;
                        todo!()
                    }
                    DataType::UINT16 => {
                        let _ = &self.int32_data;
                        todo!()
                    }
                    DataType::UINT32 => {
                        let u64_slice: &[u64] = &self.uint64_data;
                        let u32_slice: &[u32] = unsafe { core::slice::from_raw_parts(u64_slice.as_ptr() as *const u32, u64_slice.len() * 2) };
                        TensorValue::Unsigned32(u32_slice)
                    }
                    DataType::UINT64 => {
                        TensorValue::Unsigned64(&self.uint64_data)
                    }
                    DataType::BOOL => {
                        let _ = &self.int32_data;
                        todo!()
                    }
                    DataType::INT4 => {
                        let _ = &self.int32_data;
                        todo!()
                    }
                    DataType::INT8 => {
                        let i32_slice: &[i32] = &self.int32_data;
                        let i8_slice: &[i8] = unsafe { core::slice::from_raw_parts(i32_slice.as_ptr() as *const i8, i32_slice.len() * 4) };
                        TensorValue::Integer8(i8_slice)
                    }
                    DataType::INT16 => {
                        let i32_slice: &[i32] = &self.int32_data;
                        let i16_slice: &[i16] = unsafe { core::slice::from_raw_parts(i32_slice.as_ptr() as *const i16, i32_slice.len() * 2) };
                        TensorValue::Integer16(i16_slice)
                    }
                    DataType::INT32 => {
                        TensorValue::Integer32(&self.int32_data)
                    }
                    DataType::INT64 => {
                        TensorValue::Integer64(&self.int64_data)
                    }
                    DataType::COMPLEX64 => {
                        let _ = &self.float_data;
                        todo!()
                    }
                    DataType::COMPLEX128 => {
                        let _ = &self.double_data;
                        todo!()
                    }
                    DataType::STRING => { todo!() }
                }
            }
            None => TensorValue::Unknown(self.data_type),
        }
    }
}