solana_program/
borsh.rs

1#![allow(clippy::integer_arithmetic)]
2//! Utilities for the [borsh] serialization format.
3//!
4//! [borsh]: https://borsh.io/
5use {
6    borsh::{
7        maybestd::io::{Error, Write},
8        schema::{BorshSchema, Declaration, Definition, Fields},
9        BorshDeserialize, BorshSerialize,
10    },
11    std::collections::HashMap,
12};
13
14/// Get packed length for the given BorchSchema Declaration
15fn get_declaration_packed_len(
16    declaration: &str,
17    definitions: &HashMap<Declaration, Definition>,
18) -> usize {
19    match definitions.get(declaration) {
20        Some(Definition::Array { length, elements }) => {
21            *length as usize * get_declaration_packed_len(elements, definitions)
22        }
23        Some(Definition::Enum { variants }) => {
24            1 + variants
25                .iter()
26                .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions))
27                .max()
28                .unwrap_or(0)
29        }
30        Some(Definition::Struct { fields }) => match fields {
31            Fields::NamedFields(named_fields) => named_fields
32                .iter()
33                .map(|(_, declaration)| get_declaration_packed_len(declaration, definitions))
34                .sum(),
35            Fields::UnnamedFields(declarations) => declarations
36                .iter()
37                .map(|declaration| get_declaration_packed_len(declaration, definitions))
38                .sum(),
39            Fields::Empty => 0,
40        },
41        Some(Definition::Sequence {
42            elements: _elements,
43        }) => panic!("Missing support for Definition::Sequence"),
44        Some(Definition::Tuple { elements }) => elements
45            .iter()
46            .map(|element| get_declaration_packed_len(element, definitions))
47            .sum(),
48        None => match declaration {
49            "bool" | "u8" | "i8" => 1,
50            "u16" | "i16" => 2,
51            "u32" | "i32" => 4,
52            "u64" | "i64" => 8,
53            "u128" | "i128" => 16,
54            "nil" => 0,
55            _ => panic!("Missing primitive type: {}", declaration),
56        },
57    }
58}
59
60/// Get the worst-case packed length for the given BorshSchema
61///
62/// Note: due to the serializer currently used by Borsh, this function cannot
63/// be used on-chain in the Safecoin BPF execution environment.
64pub fn get_packed_len<S: BorshSchema>() -> usize {
65    let schema_container = S::schema_container();
66    get_declaration_packed_len(&schema_container.declaration, &schema_container.definitions)
67}
68
69/// Deserializes without checking that the entire slice has been consumed
70///
71/// Normally, `try_from_slice` checks the length of the final slice to ensure
72/// that the deserialization uses up all of the bytes in the slice.
73///
74/// Note that there is a potential issue with this function. Any buffer greater than
75/// or equal to the expected size will properly deserialize. For example, if the
76/// user passes a buffer destined for a different type, the error won't get caught
77/// as easily.
78pub fn try_from_slice_unchecked<T: BorshDeserialize>(data: &[u8]) -> Result<T, Error> {
79    let mut data_mut = data;
80    let result = T::deserialize(&mut data_mut)?;
81    Ok(result)
82}
83
/// Helper sink that counts how many bytes would be written during
/// serialization, discarding the bytes themselves.
#[derive(Default)]
struct WriteCounter {
    /// Running total of bytes passed to `write`.
    count: usize,
}

impl Write for WriteCounter {
    /// Accepts the whole buffer and adds its length to the running total.
    fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
        let accepted = buf.len();
        self.count += accepted;
        Ok(accepted)
    }

    /// Nothing is buffered, so flushing always succeeds.
    fn flush(&mut self) -> Result<(), Error> {
        Ok(())
    }
}
101
102/// Get the packed length for the serialized form of this object instance.
103///
104/// Useful when working with instances of types that contain a variable-length
105/// sequence, such as a Vec or HashMap.  Since it is impossible to know the packed
106/// length only from the type's schema, this can be used when an instance already
107/// exists, to figure out how much space to allocate in an account.
108pub fn get_instance_packed_len<T: BorshSerialize>(instance: &T) -> Result<usize, Error> {
109    let mut counter = WriteCounter::default();
110    instance.serialize(&mut counter)?;
111    Ok(counter.count)
112}
113
#[cfg(test)]
mod tests {
    use {
        super::*,
        borsh::{maybestd::io::ErrorKind, BorshSchema, BorshSerialize},
        std::{collections::HashMap, mem::size_of},
    };

    #[derive(PartialEq, Eq, Clone, Debug, BorshSerialize, BorshDeserialize, BorshSchema)]
    enum TestEnum {
        NoValue,
        Number(u32),
        // The tests never read these fields directly; they exist to give this
        // variant a known packed size.
        Struct {
            #[allow(dead_code)]
            number: u64,
            #[allow(dead_code)]
            array: [u8; 8],
        },
    }

    // Manual `Default` so that `TestStruct` below can derive `Default`.
    impl Default for TestEnum {
        fn default() -> Self {
            Self::NoValue
        }
    }

    #[derive(Default, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct TestStruct {
        pub array: [u64; 16],
        pub number_u128: u128,
        pub number_u32: u32,
        pub tuple: (u8, u16),
        pub enumeration: TestEnum,
        pub r#bool: bool,
    }

    #[derive(Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct Child {
        pub data: [u8; 64],
    }

    #[derive(Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct Parent {
        pub data: Vec<Child>,
    }

    #[test]
    fn unchecked_deserialization() {
        let data = vec![
            Child { data: [0u8; 64] },
            Child { data: [1u8; 64] },
            Child { data: [2u8; 64] },
        ];
        let parent = Parent { data };

        // exact size, both work
        // (4 bytes of Vec length prefix + three fixed-size children)
        let mut byte_vec = vec![0u8; 4 + get_packed_len::<Child>() * 3];
        let mut bytes = byte_vec.as_mut_slice();
        parent.serialize(&mut bytes).unwrap();
        let deserialized = Parent::try_from_slice(&byte_vec).unwrap();
        assert_eq!(deserialized, parent);
        let deserialized = try_from_slice_unchecked::<Parent>(&byte_vec).unwrap();
        assert_eq!(deserialized, parent);

        // too big, only unchecked works
        let mut byte_vec = vec![0u8; 4 + get_packed_len::<Child>() * 10];
        let mut bytes = byte_vec.as_mut_slice();
        parent.serialize(&mut bytes).unwrap();
        let err = Parent::try_from_slice(&byte_vec).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
        let deserialized = try_from_slice_unchecked::<Parent>(&byte_vec).unwrap();
        assert_eq!(deserialized, parent);
    }

    #[test]
    fn packed_len() {
        // Worst case is the `Struct` variant: discriminant byte + u64 + [u8; 8].
        assert_eq!(
            get_packed_len::<TestEnum>(),
            size_of::<u8>() + size_of::<u64>() + size_of::<[u8; 8]>()
        );
        assert_eq!(
            get_packed_len::<TestStruct>(),
            size_of::<u64>() * 16
                + size_of::<bool>()
                + size_of::<u128>()
                + size_of::<u32>()
                + size_of::<u8>()
                + size_of::<u16>()
                + get_packed_len::<TestEnum>()
        );
    }

    #[test]
    fn instance_packed_len_matches_packed_len() {
        // Use the largest enum variant so the instance hits the worst case.
        let enumeration = TestEnum::Struct {
            number: u64::MAX,
            array: [255; 8],
        };
        assert_eq!(
            get_packed_len::<TestEnum>(),
            get_instance_packed_len(&enumeration).unwrap(),
        );
        let test_struct = TestStruct {
            enumeration,
            ..TestStruct::default()
        };
        assert_eq!(
            get_packed_len::<TestStruct>(),
            get_instance_packed_len(&test_struct).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u8>(),
            get_instance_packed_len(&0u8).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u16>(),
            get_instance_packed_len(&0u16).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u32>(),
            get_instance_packed_len(&0u32).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u64>(),
            get_instance_packed_len(&0u64).unwrap(),
        );
        assert_eq!(
            get_packed_len::<u128>(),
            get_instance_packed_len(&0u128).unwrap(),
        );
        assert_eq!(
            get_packed_len::<[u8; 10]>(),
            get_instance_packed_len(&[0u8; 10]).unwrap(),
        );
        assert_eq!(
            get_packed_len::<(i8, i16, i32, i64, i128)>(),
            get_instance_packed_len(&(i8::MAX, i16::MAX, i32::MAX, i64::MAX, i128::MAX)).unwrap(),
        );
    }

    #[test]
    fn instance_packed_len_with_vec() {
        let data = vec![
            Child { data: [0u8; 64] },
            Child { data: [1u8; 64] },
            Child { data: [2u8; 64] },
            Child { data: [3u8; 64] },
            Child { data: [4u8; 64] },
            Child { data: [5u8; 64] },
        ];
        let parent = Parent { data };
        assert_eq!(
            get_instance_packed_len(&parent).unwrap(),
            4 + parent.data.len() * get_packed_len::<Child>()
        );
    }

    #[derive(Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize, BorshSchema)]
    struct StructWithHashMap {
        data: HashMap<String, TestEnum>,
    }

    #[test]
    fn instance_packed_len_with_varying_sizes_in_hashmap() {
        let mut data = HashMap::new();
        let string1 = "the first string, it's actually really really long".to_string();
        let enum1 = TestEnum::NoValue;
        let string2 = "second string, shorter".to_string();
        let enum2 = TestEnum::Number(u32::MAX);
        let string3 = "third".to_string();
        let enum3 = TestEnum::Struct {
            number: 0,
            array: [0; 8],
        };
        data.insert(string1.clone(), enum1.clone());
        data.insert(string2.clone(), enum2.clone());
        data.insert(string3.clone(), enum3.clone());
        let instance = StructWithHashMap { data };
        // 4-byte map length prefix + each serialized key and value.
        assert_eq!(
            get_instance_packed_len(&instance).unwrap(),
            4 + get_instance_packed_len(&string1).unwrap()
                + get_instance_packed_len(&enum1).unwrap()
                + get_instance_packed_len(&string2).unwrap()
                + get_instance_packed_len(&enum2).unwrap()
                + get_instance_packed_len(&string3).unwrap()
                + get_instance_packed_len(&enum3).unwrap()
        );
    }
}