1#![allow(clippy::integer_arithmetic)]
2use {
6 borsh::{
7 maybestd::io::{Error, Write},
8 schema::{BorshSchema, Declaration, Definition, Fields},
9 BorshDeserialize, BorshSerialize,
10 },
11 std::collections::HashMap,
12};
13
14fn get_declaration_packed_len(
16 declaration: &str,
17 definitions: &HashMap<Declaration, Definition>,
18) -> usize {
19 match definitions.get(declaration) {
20 Some(Definition::Array { length, elements }) => {
21 *length as usize * get_declaration_packed_len(elements, definitions)
22 }
23 Some(Definition::Enum { variants }) => {
24 1 + variants
25 .iter()
26 .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions))
27 .max()
28 .unwrap_or(0)
29 }
30 Some(Definition::Struct { fields }) => match fields {
31 Fields::NamedFields(named_fields) => named_fields
32 .iter()
33 .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions))
34 .sum(),
35 Fields::UnnamedFields(declarations) => declarations
36 .iter()
37 .map(|declaration| get_declaration_packed_len(declaration, definitions))
38 .sum(),
39 Fields::Empty => 0,
40 },
41 Some(Definition::Sequence {
42 elements: _elements,
43 }) => panic!("Missing support for Definition::Sequence"),
44 Some(Definition::Tuple { elements }) => elements
45 .iter()
46 .map(|element| get_declaration_packed_len(element, definitions))
47 .sum(),
48 None => match declaration {
49 "bool" | "u8" | "i8" => 1,
50 "u16" | "i16" => 2,
51 "u32" | "i32" => 4,
52 "u64" | "i64" => 8,
53 "u128" | "i128" => 16,
54 "nil" => 0,
55 _ => panic!("Missing primitive type: {}", declaration),
56 },
57 }
58}
59
60pub fn get_packed_len<S: BorshSchema>() -> usize {
65 let schema_container = S::schema_container();
66 get_declaration_packed_len(&schema_container.declaration, &schema_container.definitions)
67}
68
69pub fn try_from_slice_unchecked<T: BorshDeserialize>(data: &[u8]) -> Result<T, Error> {
79 let mut data_mut = data;
80 let result = T::deserialize(&mut data_mut)?;
81 Ok(result)
82}
83
/// A [`Write`] sink that discards every byte and only keeps a running total
/// of how many bytes were written to it.
#[derive(Default)]
struct WriteCounter {
    /// Total number of bytes accepted by `write` so far.
    count: usize,
}

impl Write for WriteCounter {
    /// Accepts the whole buffer and adds its length to the running total.
    fn write(&mut self, data: &[u8]) -> Result<usize, Error> {
        let Self { count } = self;
        *count += data.len();
        Ok(data.len())
    }

    /// Nothing is buffered, so flushing is a no-op.
    fn flush(&mut self) -> Result<(), Error> {
        Ok(())
    }
}
101
102pub fn get_instance_packed_len<T: BorshSerialize>(instance: &T) -> Result<usize, Error> {
109 let mut counter = WriteCounter::default();
110 instance.serialize(&mut counter)?;
111 Ok(counter.count)
112}
113
#[cfg(test)]
mod tests {
    use {
        super::*,
        borsh::{maybestd::io::ErrorKind, BorshSchema, BorshSerialize},
        std::{collections::HashMap, mem::size_of},
    };

    #[derive(PartialEq, Eq, Clone, Debug, BorshSerialize, BorshDeserialize, BorshSchema)]
    enum TestEnum {
        NoValue,
        Number(u32),
        Struct {
            #[allow(dead_code)]
            number: u64,
            #[allow(dead_code)]
            array: [u8; 8],
        },
    }

    impl Default for TestEnum {
        fn default() -> Self {
            Self::NoValue
        }
    }

    #[derive(Default, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct TestStruct {
        pub array: [u64; 16],
        pub number_u128: u128,
        pub number_u32: u32,
        pub tuple: (u8, u16),
        pub enumeration: TestEnum,
        pub r#bool: bool,
    }

    #[derive(Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct Child {
        pub data: [u8; 64],
    }

    #[derive(Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct Parent {
        pub data: Vec<Child>,
    }

    #[test]
    fn unchecked_deserialization() {
        let data = vec![
            Child { data: [0u8; 64] },
            Child { data: [1u8; 64] },
            Child { data: [2u8; 64] },
        ];
        let parent = Parent { data };

        // Exactly-sized buffer: 4-byte `Vec` length prefix plus three children.
        let mut byte_vec = vec![0u8; 4 + get_packed_len::<Child>() * 3];
        let mut bytes = byte_vec.as_mut_slice();
        parent.serialize(&mut bytes).unwrap();
        let deserialized = Parent::try_from_slice(&byte_vec).unwrap();
        assert_eq!(deserialized, parent);
        let deserialized = try_from_slice_unchecked::<Parent>(&byte_vec).unwrap();
        assert_eq!(deserialized, parent);

        // Oversized buffer: trailing zero bytes remain after the encoded value,
        // which `try_from_slice` rejects but `try_from_slice_unchecked` ignores.
        let mut byte_vec = vec![0u8; 4 + get_packed_len::<Child>() * 10];
        let mut bytes = byte_vec.as_mut_slice();
        parent.serialize(&mut bytes).unwrap();
        let err = Parent::try_from_slice(&byte_vec).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        let deserialized = try_from_slice_unchecked::<Parent>(&byte_vec).unwrap();
        assert_eq!(deserialized, parent);
    }

    #[test]
    fn packed_len() {
        // One discriminant byte plus the largest variant,
        // `Struct { number: u64, array: [u8; 8] }`.
        assert_eq!(
            get_packed_len::<TestEnum>(),
            size_of::<u8>() + size_of::<u64>() + size_of::<[u8; 8]>()
        );
        assert_eq!(
            get_packed_len::<TestStruct>(),
            size_of::<u64>() * 16
                + size_of::<bool>()
                + size_of::<u128>()
                + size_of::<u32>()
                + size_of::<u8>()
                + size_of::<u16>()
                + get_packed_len::<TestEnum>()
        );
    }

    #[test]
    fn instance_packed_len_matches_packed_len() {
        // Use the largest variant so the instance size equals the worst case.
        let enumeration = TestEnum::Struct {
            number: u64::MAX,
            array: [255; 8],
        };
        assert_eq!(
            get_packed_len::<TestEnum>(),
            get_instance_packed_len(&enumeration).unwrap(),
        );
        let test_struct = TestStruct {
            enumeration,
            ..TestStruct::default()
        };
        assert_eq!(
            get_packed_len::<TestStruct>(),
            get_instance_packed_len(&test_struct).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u8>(),
            get_instance_packed_len(&0u8).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u16>(),
            get_instance_packed_len(&0u16).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u32>(),
            get_instance_packed_len(&0u32).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u64>(),
            get_instance_packed_len(&0u64).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u128>(),
            get_instance_packed_len(&0u128).unwrap(),
        );
        assert_eq!(
            get_packed_len::<[u8; 10]>(),
            get_instance_packed_len(&[0u8; 10]).unwrap(),
        );
        assert_eq!(
            get_packed_len::<(i8, i16, i32, i64, i128)>(),
            get_instance_packed_len(&(i8::MAX, i16::MAX, i32::MAX, i64::MAX, i128::MAX)).unwrap(),
        );
    }

    #[test]
    fn instance_packed_len_with_vec() {
        let data = vec![
            Child { data: [0u8; 64] },
            Child { data: [1u8; 64] },
            Child { data: [2u8; 64] },
            Child { data: [3u8; 64] },
            Child { data: [4u8; 64] },
            Child { data: [5u8; 64] },
        ];
        let parent = Parent { data };
        // 4-byte length prefix followed by each fixed-size child.
        assert_eq!(
            get_instance_packed_len(&parent).unwrap(),
            4 + parent.data.len() * get_packed_len::<Child>()
        );
    }

    #[derive(Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct StructWithHashMap {
        data: HashMap<String, TestEnum>,
    }

    #[test]
    fn instance_packed_len_with_varying_sizes_in_hashmap() {
        let mut data = HashMap::new();
        let string1 = "the first string, it's actually really really long".to_string();
        let enum1 = TestEnum::NoValue;
        let string2 = "second string, shorter".to_string();
        let enum2 = TestEnum::Number(u32::MAX);
        let string3 = "third".to_string();
        let enum3 = TestEnum::Struct {
            number: 0,
            array: [0; 8],
        };
        data.insert(string1.clone(), enum1.clone());
        data.insert(string2.clone(), enum2.clone());
        data.insert(string3.clone(), enum3.clone());
        let instance = StructWithHashMap { data };
        // 4-byte entry count followed by each key and value at its own size.
        assert_eq!(
            get_instance_packed_len(&instance).unwrap(),
            4 + get_instance_packed_len(&string1).unwrap()
                + get_instance_packed_len(&enum1).unwrap()
                + get_instance_packed_len(&string2).unwrap()
                + get_instance_packed_len(&enum2).unwrap()
                + get_instance_packed_len(&string3).unwrap()
                + get_instance_packed_len(&enum3).unwrap()
        );
    }
}