solana_borsh/
macros.rs

1//! Macros for implementing functions across multiple versions of Borsh
2
macro_rules! impl_get_packed_len_v0 {
    ($borsh:ident $(,#[$meta:meta])?) => {
        /// Get the worst-case packed length for the given BorshSchema
        ///
        /// Note: due to the serializer currently used by Borsh, this function cannot
        /// be used on-chain in the Solana SBF execution environment.
        ///
        /// # Panics
        ///
        /// Panics if the schema contains a `Definition::Sequence` (borsh uses this for
        /// variable-length collections such as `Vec`), or a declaration that has no
        /// definition and is not one of the fixed-size primitive names handled by
        /// `get_declaration_packed_len` (e.g. `string`).
        $(#[$meta])?
        pub fn get_packed_len<S: $borsh::BorshSchema>() -> usize {
            let $borsh::schema::BorshSchemaContainer { declaration, definitions } =
                &S::schema_container();
            get_declaration_packed_len(declaration, definitions)
        }

        /// Get packed length for the given BorshSchema Declaration
        ///
        /// Declarations that have an entry in `definitions` are sized recursively:
        /// arrays multiply their element size by their length, structs and tuples sum
        /// their members, and enums add a one-byte tag to their largest variant.
        /// Declarations without an entry are treated as primitives and sized by name.
        fn get_declaration_packed_len(
            declaration: &str,
            definitions: &std::collections::HashMap<$borsh::schema::Declaration, $borsh::schema::Definition>,
        ) -> usize {
            match definitions.get(declaration) {
                Some($borsh::schema::Definition::Array { length, elements }) => {
                    *length as usize * get_declaration_packed_len(elements, definitions)
                }
                Some($borsh::schema::Definition::Enum { variants }) => {
                    // One tag byte plus the largest payload (worst case).
                    1 + variants
                        .iter()
                        .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions))
                        .max()
                        .unwrap_or(0)
                }
                Some($borsh::schema::Definition::Struct { fields }) => match fields {
                    $borsh::schema::Fields::NamedFields(named_fields) => named_fields
                        .iter()
                        .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions))
                        .sum(),
                    $borsh::schema::Fields::UnnamedFields(declarations) => declarations
                        .iter()
                        .map(|declaration| get_declaration_packed_len(declaration, definitions))
                        .sum(),
                    $borsh::schema::Fields::Empty => 0,
                },
                // Variable-length sequences have no static upper bound.
                Some($borsh::schema::Definition::Sequence {
                    elements: _elements,
                }) => panic!("Missing support for Definition::Sequence"),
                Some($borsh::schema::Definition::Tuple { elements }) => elements
                    .iter()
                    .map(|element| get_declaration_packed_len(element, definitions))
                    .sum(),
                // Borsh 0.x registers no definition for primitives, so size them by name.
                // `f32`/`f64` are declared as primitives by borsh's schema and serialize
                // as 4 and 8 little-endian bytes respectively.
                None => match declaration {
                    "bool" | "u8" | "i8" => 1,
                    "u16" | "i16" => 2,
                    "u32" | "i32" | "f32" => 4,
                    "u64" | "i64" | "f64" => 8,
                    "u128" | "i128" => 16,
                    "nil" => 0,
                    _ => panic!("Missing primitive type: {declaration}"),
                },
            }
        }
    }
}
63pub(crate) use impl_get_packed_len_v0;
64
macro_rules! impl_get_packed_len_v1 {
    ($borsh:ident $(,#[$meta:meta])?) => {
        /// Get the worst-case packed length for the given BorshSchema
        ///
        /// Note: due to the serializer currently used by Borsh, this function cannot
        /// be used on-chain in the Solana SBF execution environment.
        ///
        /// # Panics
        ///
        /// Panics if the schema contains a sequence that carries a length prefix
        /// (non-zero `length_width`), or a declaration with no definition whose name
        /// is not in the fallback primitive table.
        $(#[$meta])?
        pub fn get_packed_len<S: $borsh::BorshSchema>() -> usize {
            let schema_container = $borsh::schema_container_of::<S>();
            let root = schema_container.declaration();
            get_declaration_packed_len(root, &schema_container)
        }

        /// Get packed length for the given BorshSchema Declaration
        ///
        /// Definitions found in the container are sized recursively. A declaration
        /// with no definition falls back to a table of well-known primitive names.
        fn get_declaration_packed_len(
            decl: &str,
            schema_container: &$borsh::schema::BorshSchemaContainer,
        ) -> usize {
            match schema_container.get_definition(decl) {
                Some($borsh::schema::Definition::Primitive(width)) => usize::from(*width),
                Some($borsh::schema::Definition::Tuple { elements }) => {
                    elements.iter().fold(0usize, |total, element| {
                        total + get_declaration_packed_len(element, schema_container)
                    })
                }
                Some($borsh::schema::Definition::Struct { fields }) => match fields {
                    $borsh::schema::Fields::Empty => 0,
                    $borsh::schema::Fields::UnnamedFields(field_decls) => field_decls
                        .iter()
                        .map(|field_decl| get_declaration_packed_len(field_decl, schema_container))
                        .sum(),
                    $borsh::schema::Fields::NamedFields(named_fields) => named_fields
                        .iter()
                        .map(|(_name, field_decl)| get_declaration_packed_len(field_decl, schema_container))
                        .sum(),
                },
                Some($borsh::schema::Definition::Enum { tag_width, variants }) => {
                    // Worst case: the tag plus the largest variant payload.
                    let largest_variant = variants.iter().fold(
                        0usize,
                        |largest, (_discriminant, _name, variant_decl)| {
                            largest.max(get_declaration_packed_len(variant_decl, schema_container))
                        },
                    );
                    usize::from(*tag_width) + largest_variant
                }
                Some($borsh::schema::Definition::Sequence { length_width, length_range, elements }) => {
                    // Only sequences serialized without a length prefix (length_width == 0,
                    // which borsh uses for fixed-size arrays) are supported; the upper end
                    // of `length_range` is taken as the element count.
                    if *length_width != 0 {
                        panic!("Missing support for Definition::Sequence");
                    }
                    let element_count = *length_range.end() as usize;
                    element_count * get_declaration_packed_len(elements, schema_container)
                }
                None => match decl {
                    "bool" | "u8" | "i8" => 1,
                    "u16" | "i16" => 2,
                    "u32" | "i32" => 4,
                    "u64" | "i64" => 8,
                    "u128" | "i128" => 16,
                    "nil" => 0,
                    _ => panic!("Missing primitive type: {decl}"),
                },
            }
        }
    }
}
125pub(crate) use impl_get_packed_len_v1;
126
macro_rules! impl_try_from_slice_unchecked {
    ($borsh:ident, $borsh_io:ident $(,#[$meta:meta])?) => {
        /// Deserializes a value from the front of `data`, ignoring any trailing bytes
        ///
        /// Unlike `try_from_slice`, this does not require the deserializer to consume
        /// the whole slice.
        ///
        /// Use with care: any buffer at least as long as the encoded value will decode
        /// successfully. For example, a buffer that was actually meant for a different
        /// type may deserialize without an error, so mistakes are harder to catch.
        ///
        /// # Errors
        ///
        /// Propagates whatever error `T`'s deserializer reports, e.g. when `data` is
        /// too short to contain a `T`.
        $(#[$meta])?
        pub fn try_from_slice_unchecked<T: $borsh::BorshDeserialize>(data: &[u8]) -> Result<T, $borsh_io::Error> {
            // The deserializer advances this cursor past the bytes it reads; whatever
            // remains afterwards is simply dropped instead of being reported.
            let mut unread: &[u8] = data;
            let decoded = <T as $borsh::BorshDeserialize>::deserialize(&mut unread)?;
            Ok(decoded)
        }
    }
}
146pub(crate) use impl_try_from_slice_unchecked;
147
macro_rules! impl_get_instance_packed_len {
    ($borsh:ident, $borsh_io:ident $(,#[$meta:meta])?) => {
        /// Byte-counting sink: accepts every write without storing the data and
        /// only tallies how many bytes were offered.
        #[derive(Default)]
        struct WriteCounter {
            count: usize,
        }

        impl $borsh_io::Write for WriteCounter {
            fn write(&mut self, data: &[u8]) -> Result<usize, $borsh_io::Error> {
                // Every byte is accepted, so report the full length as written.
                self.count += data.len();
                Ok(data.len())
            }

            fn flush(&mut self) -> Result<(), $borsh_io::Error> {
                // Nothing is buffered, so there is nothing to flush.
                Ok(())
            }
        }

        /// Returns how many bytes `instance` occupies once serialized.
        ///
        /// Schema-based sizing cannot account for variable-length members such as a
        /// `Vec` or `HashMap`. This instead serializes the concrete value into a
        /// counting sink, which makes it suitable for sizing an account for data that
        /// already exists.
        ///
        /// # Errors
        ///
        /// Returns any error raised by `instance`'s serializer.
        $(#[$meta])?
        pub fn get_instance_packed_len<T: $borsh::BorshSerialize>(instance: &T) -> Result<usize, $borsh_io::Error> {
            let mut tally = WriteCounter { count: 0 };
            <T as $borsh::BorshSerialize>::serialize(instance, &mut tally)?;
            Ok(tally.count)
        }
    }
}
182pub(crate) use impl_get_instance_packed_len;
183
#[cfg(test)]
macro_rules! impl_tests {
    ($borsh:ident, $borsh_io:ident) => {
        // NOTE(review): `alloc` is not referenced directly by these tests; presumably
        // kept for the benefit of macro expansions — confirm before removing.
        extern crate alloc;
        use {
            super::*,
            std::{collections::HashMap, mem::size_of},
            $borsh::{BorshDeserialize, BorshSerialize},
            $borsh_io::ErrorKind,
        };

        /// Fixed-size element whose packed length is known from its schema alone.
        type Child = [u8; 64];
        /// Variable-length container of `Child`, serialized behind a 4-byte length.
        type Parent = Vec<Child>;

        /// Asserts that the schema-derived length of `T` equals the number of bytes
        /// actually produced when serializing `value`.
        fn assert_schema_len_matches_instance<T>(value: &T)
        where
            T: $borsh::BorshSchema + BorshSerialize,
        {
            assert_eq!(
                get_packed_len::<T>(),
                get_instance_packed_len(value).unwrap(),
            );
        }

        #[test]
        fn unchecked_deserialization() {
            let parent: Parent = vec![[0u8; 64], [1u8; 64], [2u8; 64]];
            // Serializes `parent` at the front of a zeroed buffer that has room for
            // the length prefix plus `slots` children.
            let serialize_padded = |slots: usize| -> Vec<u8> {
                let mut buffer = vec![0u8; 4 + get_packed_len::<Child>() * slots];
                parent.serialize(&mut buffer.as_mut_slice()).unwrap();
                buffer
            };

            // Exact size: checked and unchecked deserialization both succeed.
            let exact = serialize_padded(3);
            assert_eq!(Parent::try_from_slice(&exact).unwrap(), parent);
            assert_eq!(try_from_slice_unchecked::<Parent>(&exact).unwrap(), parent);

            // Trailing bytes: only the unchecked variant tolerates them.
            let oversized = serialize_padded(10);
            assert_eq!(
                Parent::try_from_slice(&oversized).unwrap_err().kind(),
                ErrorKind::InvalidData
            );
            assert_eq!(try_from_slice_unchecked::<Parent>(&oversized).unwrap(), parent);
        }

        #[test]
        fn packed_len() {
            let expected_u64 = size_of::<u64>();
            let expected_child = 64 * size_of::<u8>();
            assert_eq!(get_packed_len::<u64>(), expected_u64);
            assert_eq!(get_packed_len::<Child>(), expected_child);
        }

        #[test]
        fn instance_packed_len_matches_packed_len() {
            assert_schema_len_matches_instance(&[0u8; 64]);
            assert_schema_len_matches_instance(&0u8);
            assert_schema_len_matches_instance(&0u16);
            assert_schema_len_matches_instance(&0u32);
            assert_schema_len_matches_instance(&0u64);
            assert_schema_len_matches_instance(&0u128);
            assert_schema_len_matches_instance(&[0u8; 10]);
            assert_schema_len_matches_instance(&(i8::MAX, i16::MAX, i32::MAX, i64::MAX, i128::MAX));
        }

        #[test]
        fn instance_packed_len_with_vec() {
            let parent: Parent = (0u8..6).map(|fill| [fill; 64]).collect();
            let expected = 4 + parent.len() * get_packed_len::<Child>();
            assert_eq!(get_instance_packed_len(&parent).unwrap(), expected);
        }

        #[test]
        fn instance_packed_len_with_varying_sizes_in_hashmap() {
            let entries = [
                ("the first string, it's actually really really long", ""),
                ("second string, shorter", "a real value"),
                ("third", "an even longer value"),
            ];
            let data: HashMap<String, String> = entries
                .iter()
                .map(|&(key, value)| (key.to_owned(), value.to_owned()))
                .collect();
            // 4-byte entry count followed by each serialized key and value.
            let expected = entries.iter().fold(4usize, |total, &(key, value)| {
                total
                    + get_instance_packed_len(&key.to_owned()).unwrap()
                    + get_instance_packed_len(&value.to_owned()).unwrap()
            });
            assert_eq!(get_instance_packed_len(&data).unwrap(), expected);
        }
    };
}
300#[cfg(test)]
301pub(crate) use impl_tests;