1#![allow(clippy::arithmetic_side_effects)]
4use {
5 serde::{
6 de::{Error as _, SeqAccess, Visitor},
7 ser::SerializeTuple,
8 Deserializer, Serializer,
9 },
10 std::{fmt, marker::PhantomData},
11};
12
13pub trait VarInt: Sized {
14 fn visit_seq<'de, A>(seq: A) -> Result<Self, A::Error>
15 where
16 A: SeqAccess<'de>;
17
18 fn serialize<S>(self, serializer: S) -> Result<S::Ok, S::Error>
19 where
20 S: Serializer;
21}
22
/// Serde visitor that hands a tuple sequence to `T`'s `VarInt::visit_seq`.
///
/// It stores no data. The marker only records which `T` to decode.
struct VarIntVisitor<T> {
    phantom: std::marker::PhantomData<T>,
}
26
27impl<'de, T> Visitor<'de> for VarIntVisitor<T>
28where
29 T: VarInt,
30{
31 type Value = T;
32
33 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
34 formatter.write_str("a VarInt")
35 }
36
37 fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
38 where
39 A: SeqAccess<'de>,
40 {
41 T::visit_seq(seq)
42 }
43}
44
45pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
46where
47 T: Copy + VarInt,
48 S: Serializer,
49{
50 (*value).serialize(serializer)
51}
52
53pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
54where
55 D: Deserializer<'de>,
56 T: VarInt,
57{
58 deserializer.deserialize_tuple(
59 (std::mem::size_of::<T>() * 8 + 6) / 7,
60 VarIntVisitor {
61 phantom: PhantomData,
62 },
63 )
64}
65
66macro_rules! impl_var_int {
67 ($type:ty) => {
68 impl VarInt for $type {
69 fn visit_seq<'de, A>(mut seq: A) -> Result<Self, A::Error>
70 where
71 A: SeqAccess<'de>,
72 {
73 let mut out = 0;
74 let mut shift = 0u32;
75 while shift < <$type>::BITS {
76 let Some(byte) = seq.next_element::<u8>()? else {
77 return Err(A::Error::custom("Invalid Sequence"));
78 };
79 out |= ((byte & 0x7F) as Self) << shift;
80 if byte & 0x80 == 0 {
81 if (out >> shift) as u8 != byte {
84 return Err(A::Error::custom("Last Byte Truncated"));
85 }
86 if byte == 0u8 && (shift != 0 || out != 0) {
89 return Err(A::Error::custom("Invalid Trailing Zeros"));
90 }
91 return Ok(out);
92 }
93 shift += 7;
94 }
95 Err(A::Error::custom("Left Shift Overflows"))
96 }
97
98 fn serialize<S>(mut self, serializer: S) -> Result<S::Ok, S::Error>
99 where
100 S: Serializer,
101 {
102 let bits = <$type>::BITS - self.leading_zeros();
103 let num_bytes = ((bits + 6) / 7).max(1) as usize;
104 let mut seq = serializer.serialize_tuple(num_bytes)?;
105 while self >= 0x80 {
106 let byte = ((self & 0x7F) | 0x80) as u8;
107 seq.serialize_element(&byte)?;
108 self >>= 7;
109 }
110 seq.serialize_element(&(self as u8))?;
111 seq.end()
112 }
113 }
114 };
115}
116
117impl_var_int!(u16);
118impl_var_int!(u32);
119impl_var_int!(u64);
120
#[cfg(test)]
mod tests {
    use {
        rand::Rng,
        serde_derive::{Deserialize, Serialize},
        solana_short_vec::ShortU16,
    };

    // Interleaves varint-encoded fields (`a`, `c`) with plain fields (`b`, `d`).
    // Round-trips therefore also exercise reading ordinary fields that follow
    // a variable-length one.
    #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
    struct Dummy {
        #[serde(with = "super")]
        a: u32,
        b: u64,
        #[serde(with = "super")]
        c: u64,
        d: u32,
    }

    #[test]
    fn test_serde_varint() {
        // Maximum tuple lengths requested by `deserialize`: 5 bytes for u32, 10 for u64.
        assert_eq!((std::mem::size_of::<u32>() * 8 + 6) / 7, 5);
        assert_eq!((std::mem::size_of::<u64>() * 8 + 6) / 7, 10);
        let dummy = Dummy {
            a: 698,
            b: 370,
            c: 146,
            d: 796,
        };
        let bytes = bincode::serialize(&dummy).unwrap();
        // 698 and 146 each need 2 varint bytes; `b` and `d` take 8 + 4 bytes.
        assert_eq!(bytes.len(), 16);
        let other: Dummy = bincode::deserialize(&bytes).unwrap();
        assert_eq!(other, dummy);
    }

    #[test]
    fn test_serde_varint_zero() {
        let dummy = Dummy {
            a: 0,
            b: 0,
            c: 0,
            d: 0,
        };
        let bytes = bincode::serialize(&dummy).unwrap();
        // Zero encodes as a single byte: 1 + 8 + 1 + 4.
        assert_eq!(bytes.len(), 14);
        let other: Dummy = bincode::deserialize(&bytes).unwrap();
        assert_eq!(other, dummy);
    }

    #[test]
    fn test_serde_varint_max() {
        let dummy = Dummy {
            a: u32::MAX,
            b: u64::MAX,
            c: u64::MAX,
            d: u32::MAX,
        };
        let bytes = bincode::serialize(&dummy).unwrap();
        // Longest encodings: u32::MAX takes 5 bytes, u64::MAX 10 bytes: 5 + 8 + 10 + 4.
        assert_eq!(bytes.len(), 27);
        let other: Dummy = bincode::deserialize(&bytes).unwrap();
        assert_eq!(other, dummy);
    }

    #[test]
    fn test_serde_varint_rand() {
        let mut rng = rand::thread_rng();
        for _ in 0..100_000 {
            // Right-shifting by a random amount spreads values over every
            // possible encoded length, not just the longest.
            let dummy = Dummy {
                a: rng.gen::<u32>() >> rng.gen_range(0..u32::BITS),
                b: rng.gen::<u64>() >> rng.gen_range(0..u64::BITS),
                c: rng.gen::<u64>() >> rng.gen_range(0..u64::BITS),
                d: rng.gen::<u32>() >> rng.gen_range(0..u32::BITS),
            };
            let bytes = bincode::serialize(&dummy).unwrap();
            let other: Dummy = bincode::deserialize(&bytes).unwrap();
            assert_eq!(other, dummy);
        }
    }

    #[test]
    fn test_serde_varint_trailing_zeros() {
        // Four continuation bytes, then 0x00 as the final byte of `a`.
        let buffer = [0x93, 0xc2, 0xa9, 0x8d, 0x0];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(
            format!("{out:?}"),
            r#"Err(Custom("Invalid Trailing Zeros"))"#
        );
        // A padded (non-canonical) two-byte encoding of 0.
        let buffer = [0x80, 0x0];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(
            format!("{out:?}"),
            r#"Err(Custom("Invalid Trailing Zeros"))"#
        );
    }

    #[test]
    fn test_serde_varint_last_byte_truncated() {
        // The fifth byte of `a` (0x6f, placed at shift 28) has payload bits
        // above bit 31, which do not fit in a u32.
        let buffer = [0xe4, 0xd7, 0x88, 0xf6, 0x6f, 0xd4, 0xb9, 0x59];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(format!("{out:?}"), r#"Err(Custom("Last Byte Truncated"))"#);
    }

    #[test]
    fn test_serde_varint_shift_overflow() {
        // All five bytes allowed for the u32 `a` have the continuation bit set.
        let buffer = [0x84, 0xdf, 0x96, 0xfa, 0xef];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(format!("{out:?}"), r#"Err(Custom("Left Shift Overflows"))"#);
    }

    #[test]
    fn test_serde_varint_short_buffer() {
        // Input ends in the middle of `a`, so the reader's EOF error surfaces.
        let buffer = [0x84, 0xdf, 0x96, 0xfa];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(format!("{out:?}"), r#"Err(Io(Kind(UnexpectedEof)))"#);
    }

    #[test]
    fn test_serde_varint_fuzz() {
        let mut rng = rand::thread_rng();
        // 36 bytes exceeds the longest `Dummy` encoding (27 bytes).
        let mut buffer = [0u8; 36];
        let mut num_errors = 0;
        for _ in 0..200_000 {
            rng.fill(&mut buffer[..]);
            match bincode::deserialize::<Dummy>(&buffer) {
                Err(_) => {
                    num_errors += 1;
                }
                Ok(dummy) => {
                    // Anything accepted must re-encode to exactly the consumed
                    // prefix, i.e. only canonical encodings are accepted.
                    let bytes = bincode::serialize(&dummy).unwrap();
                    assert_eq!(bytes, &buffer[..bytes.len()]);
                }
            }
        }
        // Sanity band on the rejection rate of random input; the bounds look
        // empirical (NOTE(review): confirm their derivation if they need tuning).
        assert!(
            (3_000..23_000).contains(&num_errors),
            "num errors: {num_errors}"
        );
    }

    #[test]
    fn test_serde_varint_cross_fuzz() {
        // Cross-checks the u16 varint against `ShortU16`. Both must reject the
        // same inputs, and accepted inputs must decode to the same value and
        // re-encode to the same bytes.
        #[derive(Serialize, Deserialize)]
        struct U16(#[serde(with = "super")] u16);
        let mut rng = rand::thread_rng();
        let mut buffer = [0u8; 16];
        let mut num_errors = 0;
        for _ in 0..200_000 {
            rng.fill(&mut buffer[..]);
            match bincode::deserialize::<U16>(&buffer) {
                Err(_) => {
                    assert!(bincode::deserialize::<ShortU16>(&buffer).is_err());
                    num_errors += 1;
                }
                Ok(k) => {
                    let bytes = bincode::serialize(&k).unwrap();
                    assert_eq!(bytes, &buffer[..bytes.len()]);
                    assert_eq!(bytes, bincode::serialize(&ShortU16(k.0)).unwrap());
                    assert_eq!(bincode::deserialize::<ShortU16>(&buffer).unwrap().0, k.0);
                }
            }
        }
        // Sanity band on the rejection rate of random input (bounds appear empirical).
        assert!(
            (30_000..70_000).contains(&num_errors),
            "num errors: {num_errors}"
        );
    }
}