/// Generates `get_packed_len` (and a private recursive helper) for a borsh crate
/// whose schema API exposes `Definition::Array` and stores its definitions in a
/// `HashMap<Declaration, Definition>`.
///
/// `$borsh` is the path root of that borsh crate; the optional `#[$meta]` is
/// forwarded onto the generated public function (e.g. a deprecation attribute).
macro_rules! impl_get_packed_len_v0 {
    ($borsh:ident $(,#[$meta:meta])?) => {
        /// Get the worst-case packed length of `S`, computed from its `BorshSchema`.
        ///
        /// Enums are sized as a one-byte tag plus their largest variant; arrays,
        /// structs and tuples as the sum of their parts.
        ///
        /// # Panics
        ///
        /// Panics if the schema contains a `Definition::Sequence` (a variable-length
        /// collection has no fixed worst case), or a declaration that has no
        /// definition and is not one of the known primitive type names.
        $(#[$meta])?
        pub fn get_packed_len<S: $borsh::BorshSchema>() -> usize {
            let $borsh::schema::BorshSchemaContainer { declaration, definitions } =
                &S::schema_container();
            get_declaration_packed_len(declaration, definitions)
        }

        /// Worst-case packed length of `declaration`, resolved recursively through
        /// `definitions`. Declarations without a definition are treated as
        /// primitive type names.
        fn get_declaration_packed_len(
            declaration: &str,
            definitions: &std::collections::HashMap<$borsh::schema::Declaration, $borsh::schema::Definition>,
        ) -> usize {
            match definitions.get(declaration) {
                Some($borsh::schema::Definition::Array { length, elements }) => {
                    *length as usize * get_declaration_packed_len(elements, definitions)
                }
                Some($borsh::schema::Definition::Enum { variants }) => {
                    // One tag byte, then room for the largest variant payload.
                    1 + variants
                        .iter()
                        .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions))
                        .max()
                        .unwrap_or(0)
                }
                Some($borsh::schema::Definition::Struct { fields }) => match fields {
                    $borsh::schema::Fields::NamedFields(named_fields) => named_fields
                        .iter()
                        .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions))
                        .sum(),
                    $borsh::schema::Fields::UnnamedFields(declarations) => declarations
                        .iter()
                        .map(|declaration| get_declaration_packed_len(declaration, definitions))
                        .sum(),
                    $borsh::schema::Fields::Empty => 0,
                },
                Some($borsh::schema::Definition::Sequence {
                    elements: _elements,
                }) => panic!("Missing support for Definition::Sequence"),
                Some($borsh::schema::Definition::Tuple { elements }) => elements
                    .iter()
                    .map(|element| get_declaration_packed_len(element, definitions))
                    .sum(),
                None => match declaration {
                    "bool" | "u8" | "i8" => 1,
                    "u16" | "i16" => 2,
                    // Floats are fixed-width like the integers: f32 is 4 bytes, f64 is 8.
                    "u32" | "i32" | "f32" => 4,
                    "u64" | "i64" | "f64" => 8,
                    "u128" | "i128" => 16,
                    "nil" => 0,
                    _ => panic!("Missing primitive type: {declaration}"),
                },
            }
        }
    }
}
63pub(crate) use impl_get_packed_len_v0;
64
/// Generates `get_packed_len` (and a private recursive helper) for a borsh crate
/// whose schema API provides `schema_container_of`, `Definition::Primitive`, and
/// `Definition::Sequence { length_width, length_range, .. }`.
///
/// `$borsh` is the path root of that borsh crate; the optional `#[$meta]` is
/// forwarded onto the generated public function.
macro_rules! impl_get_packed_len_v1 {
    ($borsh:ident $(,#[$meta:meta])?) => {
        /// Compute the worst-case packed length of `S` from its `BorshSchema`.
        ///
        /// # Panics
        ///
        /// Panics when the schema contains a sequence with a non-zero
        /// `length_width` (a length-prefixed, variable-size collection), or when it
        /// references a declaration that has no definition and is not a known
        /// primitive type name.
        $(#[$meta])?
        pub fn get_packed_len<S: $borsh::BorshSchema>() -> usize {
            let schema = $borsh::schema_container_of::<S>();
            let root: &str = schema.declaration();
            get_declaration_packed_len(root, &schema)
        }

        /// Worst-case packed length of `declaration`, resolved recursively against
        /// the definitions held by `container`.
        fn get_declaration_packed_len(
            declaration: &str,
            container: &$borsh::schema::BorshSchemaContainer,
        ) -> usize {
            use $borsh::schema::{Definition, Fields};

            let definition = match container.get_definition(declaration) {
                Some(found) => found,
                // No definition registered: fall back to well-known primitive names.
                None => {
                    return match declaration {
                        "bool" | "u8" | "i8" => 1,
                        "u16" | "i16" => 2,
                        "u32" | "i32" => 4,
                        "u64" | "i64" => 8,
                        "u128" | "i128" => 16,
                        "nil" => 0,
                        _ => panic!("Missing primitive type: {declaration}"),
                    };
                }
            };

            let len_of = |inner: &str| get_declaration_packed_len(inner, container);

            match definition {
                Definition::Primitive(width) => usize::from(*width),
                // No length prefix: a fixed-size array whose element count is the
                // upper bound of `length_range`.
                Definition::Sequence { length_width: 0, length_range, elements } => {
                    *length_range.end() as usize * len_of(elements)
                }
                Definition::Sequence { .. } => panic!("Missing support for Definition::Sequence"),
                Definition::Enum { tag_width, variants } => {
                    let largest_variant = variants
                        .iter()
                        .map(|(_, _, payload)| len_of(payload))
                        .max()
                        .unwrap_or(0);
                    usize::from(*tag_width) + largest_variant
                }
                Definition::Struct { fields: Fields::NamedFields(named) } => {
                    named.iter().map(|(_, field)| len_of(field)).sum()
                }
                Definition::Struct { fields: Fields::UnnamedFields(unnamed) } => {
                    unnamed.iter().map(|field| len_of(field)).sum()
                }
                Definition::Struct { fields: Fields::Empty } => 0,
                Definition::Tuple { elements } => {
                    elements.iter().map(|element| len_of(element)).sum()
                }
            }
        }
    }
}
125pub(crate) use impl_get_packed_len_v1;
126
/// Generates `try_from_slice_unchecked` for the borsh crate rooted at `$borsh`,
/// using `$borsh_io::Error` as the error type. The optional `#[$meta]` is
/// forwarded onto the generated function.
macro_rules! impl_try_from_slice_unchecked {
    ($borsh:ident, $borsh_io:ident $(,#[$meta:meta])?) => {
        /// Deserialize a `T` from the front of `data`, ignoring any bytes left over.
        ///
        /// Unlike a checked deserialization, this does not require the whole slice
        /// to be consumed, so trailing padding (e.g. the unused tail of a
        /// fixed-size buffer) is tolerated.
        ///
        /// # Errors
        ///
        /// Propagates any error returned by `T::deserialize`.
        $(#[$meta])?
        pub fn try_from_slice_unchecked<T: $borsh::BorshDeserialize>(data: &[u8]) -> Result<T, $borsh_io::Error> {
            // `deserialize` advances the slice it is handed; give it a private
            // cursor so whatever it leaves unread is simply dropped.
            let mut cursor: &[u8] = data;
            let decoded = <T as $borsh::BorshDeserialize>::deserialize(&mut cursor)?;
            Ok(decoded)
        }
    }
}
146pub(crate) use impl_try_from_slice_unchecked;
147
/// Generates `get_instance_packed_len` (and its private byte-counting writer) for
/// the borsh crate rooted at `$borsh`, using `$borsh_io` for the `Write` trait and
/// error type. The optional `#[$meta]` is forwarded onto the public function.
macro_rules! impl_get_instance_packed_len {
    ($borsh:ident, $borsh_io:ident $(,#[$meta:meta])?) => {
        /// A `Write` sink that discards its input and only tallies how many bytes
        /// were written to it.
        #[derive(Default)]
        struct WriteCounter {
            /// Total number of bytes accepted so far.
            count: usize,
        }

        impl $borsh_io::Write for WriteCounter {
            /// Accepts the entire buffer and adds its length to the tally.
            fn write(&mut self, data: &[u8]) -> Result<usize, $borsh_io::Error> {
                self.count += data.len();
                Ok(data.len())
            }

            /// Nothing is buffered, so flushing never has work to do.
            fn flush(&mut self) -> Result<(), $borsh_io::Error> {
                Ok(())
            }
        }

        /// Get the packed length of this particular `instance` by serializing it
        /// into a byte counter instead of a real buffer.
        ///
        /// Because the value itself is serialized, the result reflects its actual
        /// encoded size, including any content-dependent parts.
        ///
        /// # Errors
        ///
        /// Returns any error produced by `instance.serialize`; the counter itself
        /// never fails.
        $(#[$meta])?
        pub fn get_instance_packed_len<T: $borsh::BorshSerialize>(instance: &T) -> Result<usize, $borsh_io::Error> {
            let mut tally = WriteCounter::default();
            instance.serialize(&mut tally)?;
            let WriteCounter { count } = tally;
            Ok(count)
        }
    }
}
182pub(crate) use impl_get_instance_packed_len;
183
/// Shared unit tests for the helpers generated by the macros in this file,
/// instantiated once per borsh version via `$borsh` / `$borsh_io`.
#[cfg(test)]
macro_rules! impl_tests {
    ($borsh:ident, $borsh_io:ident) => {
        extern crate alloc;
        use {
            super::*,
            std::{collections::HashMap, mem::size_of},
            $borsh::{BorshDeserialize, BorshSerialize},
            $borsh_io::ErrorKind,
        };

        type Child = [u8; 64];
        type Parent = Vec<Child>;

        /// Serialized size of a `Parent` holding `children` elements: a 4-byte
        /// length prefix followed by that many fixed-size `Child` values.
        fn parent_len(children: usize) -> usize {
            4 + children * get_packed_len::<Child>()
        }

        /// Assert that the schema-derived length of `T` matches the number of
        /// bytes actually produced when serializing `value`.
        fn assert_packed_lens_agree<T: $borsh::BorshSchema + BorshSerialize>(value: &T) {
            assert_eq!(get_packed_len::<T>(), get_instance_packed_len(value).unwrap());
        }

        #[test]
        fn unchecked_deserialization() {
            let parent: Parent = vec![[0u8; 64], [1u8; 64], [2u8; 64]];

            // Buffer sized exactly for three children: both deserializers round-trip.
            let mut exact = vec![0u8; parent_len(3)];
            parent.serialize(&mut exact.as_mut_slice()).unwrap();
            assert_eq!(Parent::try_from_slice(&exact).unwrap(), parent);
            assert_eq!(try_from_slice_unchecked::<Parent>(&exact).unwrap(), parent);

            // Oversized buffer leaves trailing zero bytes: the checked deserializer
            // rejects them while the unchecked one ignores them.
            let mut oversized = vec![0u8; parent_len(10)];
            parent.serialize(&mut oversized.as_mut_slice()).unwrap();
            let err = Parent::try_from_slice(&oversized).unwrap_err();
            assert_eq!(err.kind(), ErrorKind::InvalidData);
            assert_eq!(try_from_slice_unchecked::<Parent>(&oversized).unwrap(), parent);
        }

        #[test]
        fn packed_len() {
            assert_eq!(get_packed_len::<u64>(), size_of::<u64>());
            assert_eq!(get_packed_len::<Child>(), 64 * size_of::<u8>());
        }

        #[test]
        fn instance_packed_len_matches_packed_len() {
            assert_packed_lens_agree::<Child>(&[0u8; 64]);
            assert_packed_lens_agree(&0u8);
            assert_packed_lens_agree(&0u16);
            assert_packed_lens_agree(&0u32);
            assert_packed_lens_agree(&0u64);
            assert_packed_lens_agree(&0u128);
            assert_packed_lens_agree(&[0u8; 10]);
            assert_packed_lens_agree(&(i8::MAX, i16::MAX, i32::MAX, i64::MAX, i128::MAX));
        }

        #[test]
        fn instance_packed_len_with_vec() {
            let parent: Parent = vec![
                [0u8; 64], [1u8; 64], [2u8; 64], [3u8; 64], [4u8; 64], [5u8; 64],
            ];
            assert_eq!(get_instance_packed_len(&parent).unwrap(), parent_len(parent.len()));
        }

        #[test]
        fn instance_packed_len_with_varying_sizes_in_hashmap() {
            let entries = [
                ("the first string, it's actually really really long", ""),
                ("second string, shorter", "a real value"),
                ("third", "an even longer value"),
            ];
            let data: HashMap<String, String> = entries
                .iter()
                .map(|&(key, value)| (key.to_owned(), value.to_owned()))
                .collect();

            // A 4-byte entry count followed by every key and value.
            let expected = 4 + entries
                .iter()
                .map(|&(key, value)| {
                    get_instance_packed_len(&key.to_owned()).unwrap()
                        + get_instance_packed_len(&value.to_owned()).unwrap()
                })
                .sum::<usize>();
            assert_eq!(get_instance_packed_len(&data).unwrap(), expected);
        }
    };
}
300#[cfg(test)]
301pub(crate) use impl_tests;