solana_serde_varint/
lib.rs

1//! Integers that serialize to variable size.
2
3#![allow(clippy::arithmetic_side_effects)]
4use {
5    serde::{
6        de::{Error as _, SeqAccess, Visitor},
7        ser::SerializeTuple,
8        Deserializer, Serializer,
9    },
10    std::{fmt, marker::PhantomData},
11};
12
13pub trait VarInt: Sized {
14    fn visit_seq<'de, A>(seq: A) -> Result<Self, A::Error>
15    where
16        A: SeqAccess<'de>;
17
18    fn serialize<S>(self, serializer: S) -> Result<S::Ok, S::Error>
19    where
20        S: Serializer;
21}
22
23struct VarIntVisitor<T> {
24    phantom: PhantomData<T>,
25}
26
27impl<'de, T> Visitor<'de> for VarIntVisitor<T>
28where
29    T: VarInt,
30{
31    type Value = T;
32
33    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
34        formatter.write_str("a VarInt")
35    }
36
37    fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
38    where
39        A: SeqAccess<'de>,
40    {
41        T::visit_seq(seq)
42    }
43}
44
45pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
46where
47    T: Copy + VarInt,
48    S: Serializer,
49{
50    (*value).serialize(serializer)
51}
52
53pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
54where
55    D: Deserializer<'de>,
56    T: VarInt,
57{
58    deserializer.deserialize_tuple(
59        (std::mem::size_of::<T>() * 8 + 6) / 7,
60        VarIntVisitor {
61            phantom: PhantomData,
62        },
63    )
64}
65
66macro_rules! impl_var_int {
67    ($type:ty) => {
68        impl VarInt for $type {
69            fn visit_seq<'de, A>(mut seq: A) -> Result<Self, A::Error>
70            where
71                A: SeqAccess<'de>,
72            {
73                let mut out = 0;
74                let mut shift = 0u32;
75                while shift < <$type>::BITS {
76                    let Some(byte) = seq.next_element::<u8>()? else {
77                        return Err(A::Error::custom("Invalid Sequence"));
78                    };
79                    out |= ((byte & 0x7F) as Self) << shift;
80                    if byte & 0x80 == 0 {
81                        // Last byte should not have been truncated when it was
82                        // shifted to the left above.
83                        if (out >> shift) as u8 != byte {
84                            return Err(A::Error::custom("Last Byte Truncated"));
85                        }
86                        // Last byte can be zero only if there was only one
87                        // byte and the output is also zero.
88                        if byte == 0u8 && (shift != 0 || out != 0) {
89                            return Err(A::Error::custom("Invalid Trailing Zeros"));
90                        }
91                        return Ok(out);
92                    }
93                    shift += 7;
94                }
95                Err(A::Error::custom("Left Shift Overflows"))
96            }
97
98            fn serialize<S>(mut self, serializer: S) -> Result<S::Ok, S::Error>
99            where
100                S: Serializer,
101            {
102                let bits = <$type>::BITS - self.leading_zeros();
103                let num_bytes = ((bits + 6) / 7).max(1) as usize;
104                let mut seq = serializer.serialize_tuple(num_bytes)?;
105                while self >= 0x80 {
106                    let byte = ((self & 0x7F) | 0x80) as u8;
107                    seq.serialize_element(&byte)?;
108                    self >>= 7;
109                }
110                seq.serialize_element(&(self as u8))?;
111                seq.end()
112            }
113        }
114    };
115}
116
117impl_var_int!(u16);
118impl_var_int!(u32);
119impl_var_int!(u64);
120
#[cfg(test)]
mod tests {
    use {
        rand::Rng,
        serde_derive::{Deserialize, Serialize},
        solana_short_vec::ShortU16,
    };

    /// Fixture interleaving varint-encoded fields (`a`, `c`) with fields
    /// that use their derived serde encoding (`b`, `d`).
    #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
    struct Dummy {
        #[serde(with = "super")]
        a: u32,
        b: u64,
        #[serde(with = "super")]
        c: u64,
        d: u32,
    }

    /// Encodes `dummy` with bincode, optionally checks the encoded length,
    /// then decodes the bytes and checks the value survived unchanged.
    fn assert_round_trip(dummy: &Dummy, expected_len: Option<usize>) {
        let bytes = bincode::serialize(dummy).unwrap();
        if let Some(len) = expected_len {
            assert_eq!(bytes.len(), len);
        }
        let decoded: Dummy = bincode::deserialize(&bytes).unwrap();
        assert_eq!(&decoded, dummy);
    }

    /// Checks that decoding `buffer` as a `Dummy` fails and that the debug
    /// rendering of the result equals `expected_debug`.
    fn assert_decode_error(buffer: &[u8], expected_debug: &str) {
        let out = bincode::deserialize::<Dummy>(buffer);
        assert!(out.is_err());
        assert_eq!(format!("{out:?}"), expected_debug);
    }

    #[test]
    fn test_serde_varint() {
        assert_eq!((std::mem::size_of::<u32>() * 8 + 6) / 7, 5);
        assert_eq!((std::mem::size_of::<u64>() * 8 + 6) / 7, 10);
        let dummy = Dummy {
            a: 698,
            b: 370,
            c: 146,
            d: 796,
        };
        assert_round_trip(&dummy, Some(16));
    }

    #[test]
    fn test_serde_varint_zero() {
        let dummy = Dummy {
            a: 0,
            b: 0,
            c: 0,
            d: 0,
        };
        assert_round_trip(&dummy, Some(14));
    }

    #[test]
    fn test_serde_varint_max() {
        let dummy = Dummy {
            a: u32::MAX,
            b: u64::MAX,
            c: u64::MAX,
            d: u32::MAX,
        };
        assert_round_trip(&dummy, Some(27));
    }

    #[test]
    fn test_serde_varint_rand() {
        let mut rng = rand::thread_rng();
        for _ in 0..100_000 {
            // The random right shift varies the magnitude of each value.
            let dummy = Dummy {
                a: rng.gen::<u32>() >> rng.gen_range(0..u32::BITS),
                b: rng.gen::<u64>() >> rng.gen_range(0..u64::BITS),
                c: rng.gen::<u64>() >> rng.gen_range(0..u64::BITS),
                d: rng.gen::<u32>() >> rng.gen_range(0..u32::BITS),
            };
            assert_round_trip(&dummy, None);
        }
    }

    #[test]
    fn test_serde_varint_trailing_zeros() {
        assert_decode_error(
            &[0x93, 0xc2, 0xa9, 0x8d, 0x0],
            r#"Err(Custom("Invalid Trailing Zeros"))"#,
        );
        assert_decode_error(&[0x80, 0x0], r#"Err(Custom("Invalid Trailing Zeros"))"#);
    }

    #[test]
    fn test_serde_varint_last_byte_truncated() {
        assert_decode_error(
            &[0xe4, 0xd7, 0x88, 0xf6, 0x6f, 0xd4, 0xb9, 0x59],
            r#"Err(Custom("Last Byte Truncated"))"#,
        );
    }

    #[test]
    fn test_serde_varint_shift_overflow() {
        assert_decode_error(
            &[0x84, 0xdf, 0x96, 0xfa, 0xef],
            r#"Err(Custom("Left Shift Overflows"))"#,
        );
    }

    #[test]
    fn test_serde_varint_short_buffer() {
        assert_decode_error(
            &[0x84, 0xdf, 0x96, 0xfa],
            r#"Err(Io(Kind(UnexpectedEof)))"#,
        );
    }

    #[test]
    fn test_serde_varint_fuzz() {
        let mut rng = rand::thread_rng();
        let mut buffer = [0u8; 36];
        let mut num_errors = 0;
        for _ in 0..200_000 {
            rng.fill(&mut buffer[..]);
            let Ok(dummy) = bincode::deserialize::<Dummy>(&buffer) else {
                num_errors += 1;
                continue;
            };
            // Anything that decodes must re-encode to the bytes it came from.
            let bytes = bincode::serialize(&dummy).unwrap();
            assert_eq!(bytes, &buffer[..bytes.len()]);
        }
        assert!(
            (3_000..23_000).contains(&num_errors),
            "num errors: {num_errors}"
        );
    }

    #[test]
    fn test_serde_varint_cross_fuzz() {
        #[derive(Serialize, Deserialize)]
        struct U16(#[serde(with = "super")] u16);
        let mut rng = rand::thread_rng();
        let mut buffer = [0u8; 16];
        let mut num_errors = 0;
        for _ in 0..200_000 {
            rng.fill(&mut buffer[..]);
            let Ok(k) = bincode::deserialize::<U16>(&buffer) else {
                // Input rejected here must be rejected by `ShortU16` too.
                assert!(bincode::deserialize::<ShortU16>(&buffer).is_err());
                num_errors += 1;
                continue;
            };
            // Accepted input must re-encode identically and agree with
            // `ShortU16` in both directions.
            let bytes = bincode::serialize(&k).unwrap();
            assert_eq!(bytes, &buffer[..bytes.len()]);
            assert_eq!(bytes, bincode::serialize(&ShortU16(k.0)).unwrap());
            assert_eq!(bincode::deserialize::<ShortU16>(&buffer).unwrap().0, k.0);
        }
        assert!(
            (30_000..70_000).contains(&num_errors),
            "num errors: {num_errors}"
        );
    }
}