1#![cfg_attr(feature = "frozen-abi", feature(min_specialization))]
3#![allow(clippy::arithmetic_side_effects)]
4#[cfg(feature = "frozen-abi")]
5use solana_frozen_abi_macro::AbiExample;
6use {
7 serde::{
8 de::{self, Deserializer, SeqAccess, Visitor},
9 ser::{self, SerializeTuple, Serializer},
10 Deserialize, Serialize,
11 },
12 std::{convert::TryFrom, fmt, marker::PhantomData},
13};
14
/// Same as `u16`, but serialized with 1 to 3 bytes.
///
/// Each byte carries 7 bits of the value, least-significant group first. A
/// byte with its high bit (0x80) set means another byte follows. The third byte
/// may not set the high bit, and the decoded value must fit in a `u16`.
/// Non-minimal encodings (a zero byte after the first byte) are rejected when
/// deserializing.
#[cfg_attr(feature = "frozen-abi", derive(AbiExample))]
pub struct ShortU16(pub u16);
22
23impl Serialize for ShortU16 {
24 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
25 where
26 S: Serializer,
27 {
28 let mut seq = serializer.serialize_tuple(1)?;
31
32 let mut rem_val = self.0;
33 loop {
34 let mut elem = (rem_val & 0x7f) as u8;
35 rem_val >>= 7;
36 if rem_val == 0 {
37 seq.serialize_element(&elem)?;
38 break;
39 } else {
40 elem |= 0x80;
41 seq.serialize_element(&elem)?;
42 }
43 }
44 seq.end()
45 }
46}
47
/// Outcome of successfully decoding one byte of a ShortU16 encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VisitStatus {
    /// The byte terminated the encoding; holds the fully decoded value.
    Done(u16),
    /// The byte had its continuation bit set; holds the value accumulated so far.
    More(u16),
}
52
/// Reasons a ShortU16 encoding is rejected while decoding.
#[derive(Debug)]
enum VisitError {
    /// A byte was supplied at or beyond `MAX_ENCODING_LENGTH`; holds the
    /// 1-based count of bytes seen.
    TooLong(usize),
    /// The input ended before a terminating byte; holds the 1-based position
    /// of the missing byte.
    TooShort(usize),
    /// The accumulated value does not fit in a `u16`; holds that value.
    Overflow(u32),
    /// A zero byte after the first byte, i.e. a non-minimal (aliased) encoding.
    Alias,
    /// The third byte has its continuation bit (0x80) set.
    ByteThreeContinues,
}
61
62impl VisitError {
63 fn into_de_error<'de, A>(self) -> A::Error
64 where
65 A: SeqAccess<'de>,
66 {
67 match self {
68 VisitError::TooLong(len) => de::Error::invalid_length(len, &"three or fewer bytes"),
69 VisitError::TooShort(len) => de::Error::invalid_length(len, &"more bytes"),
70 VisitError::Overflow(val) => de::Error::invalid_value(
71 de::Unexpected::Unsigned(val as u64),
72 &"a value in the range [0, 65535]",
73 ),
74 VisitError::Alias => de::Error::invalid_value(
75 de::Unexpected::Other("alias encoding"),
76 &"strict form encoding",
77 ),
78 VisitError::ByteThreeContinues => de::Error::invalid_value(
79 de::Unexpected::Other("continue signal on byte-three"),
80 &"a terminal signal on or before byte-three",
81 ),
82 }
83 }
84}
85
/// Result of decoding a single byte with [`visit_byte`].
type VisitResult = Result<VisitStatus, VisitError>;
87
/// Maximum number of bytes in a valid ShortU16 encoding (3 bytes of 7 bits
/// each cover all 16 bits of a `u16`).
const MAX_ENCODING_LENGTH: usize = 3;
89fn visit_byte(elem: u8, val: u16, nth_byte: usize) -> VisitResult {
90 if elem == 0 && nth_byte != 0 {
91 return Err(VisitError::Alias);
92 }
93
94 let val = u32::from(val);
95 let elem = u32::from(elem);
96 let elem_val = elem & 0x7f;
97 let elem_done = (elem & 0x80) == 0;
98
99 if nth_byte >= MAX_ENCODING_LENGTH {
100 return Err(VisitError::TooLong(nth_byte.saturating_add(1)));
101 } else if nth_byte == MAX_ENCODING_LENGTH.saturating_sub(1) && !elem_done {
102 return Err(VisitError::ByteThreeContinues);
103 }
104
105 let shift = u32::try_from(nth_byte)
106 .unwrap_or(u32::MAX)
107 .saturating_mul(7);
108 let elem_val = elem_val.checked_shl(shift).unwrap_or(u32::MAX);
109
110 let new_val = val | elem_val;
111 let val = u16::try_from(new_val).map_err(|_| VisitError::Overflow(new_val))?;
112
113 if elem_done {
114 Ok(VisitStatus::Done(val))
115 } else {
116 Ok(VisitStatus::More(val))
117 }
118}
119
/// serde visitor that decodes a [`ShortU16`] from a sequence of bytes.
struct ShortU16Visitor;
121
122impl<'de> Visitor<'de> for ShortU16Visitor {
123 type Value = ShortU16;
124
125 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
126 formatter.write_str("a ShortU16")
127 }
128
129 fn visit_seq<A>(self, mut seq: A) -> Result<ShortU16, A::Error>
130 where
131 A: SeqAccess<'de>,
132 {
133 let mut val: u16 = 0;
138 for nth_byte in 0..MAX_ENCODING_LENGTH {
139 let elem: u8 = seq.next_element()?.ok_or_else(|| {
140 VisitError::TooShort(nth_byte.saturating_add(1)).into_de_error::<A>()
141 })?;
142 match visit_byte(elem, val, nth_byte).map_err(|e| e.into_de_error::<A>())? {
143 VisitStatus::Done(new_val) => return Ok(ShortU16(new_val)),
144 VisitStatus::More(new_val) => val = new_val,
145 }
146 }
147
148 Err(VisitError::ByteThreeContinues.into_de_error::<A>())
149 }
150}
151
152impl<'de> Deserialize<'de> for ShortU16 {
153 fn deserialize<D>(deserializer: D) -> Result<ShortU16, D::Error>
154 where
155 D: Deserializer<'de>,
156 {
157 deserializer.deserialize_tuple(3, ShortU16Visitor)
158 }
159}
160
161pub fn serialize<S: Serializer, T: Serialize>(
167 elements: &[T],
168 serializer: S,
169) -> Result<S::Ok, S::Error> {
170 let mut seq = serializer.serialize_tuple(1)?;
173
174 let len = elements.len();
175 if len > u16::MAX as usize {
176 return Err(ser::Error::custom("length larger than u16"));
177 }
178 let short_len = ShortU16(len as u16);
179 seq.serialize_element(&short_len)?;
180
181 for element in elements {
182 seq.serialize_element(element)?;
183 }
184 seq.end()
185}
186
/// serde visitor that decodes a `Vec<T>` prefixed by a [`ShortU16`] length.
struct ShortVecVisitor<T> {
    // Marks the element type without storing a value of it.
    _t: PhantomData<T>,
}
190
191impl<'de, T> Visitor<'de> for ShortVecVisitor<T>
192where
193 T: Deserialize<'de>,
194{
195 type Value = Vec<T>;
196
197 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
198 formatter.write_str("a Vec with a multi-byte length")
199 }
200
201 fn visit_seq<A>(self, mut seq: A) -> Result<Vec<T>, A::Error>
202 where
203 A: SeqAccess<'de>,
204 {
205 let short_len: ShortU16 = seq
206 .next_element()?
207 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
208 let len = short_len.0 as usize;
209
210 let mut result = Vec::with_capacity(len);
211 for i in 0..len {
212 let elem = seq
213 .next_element()?
214 .ok_or_else(|| de::Error::invalid_length(i, &self))?;
215 result.push(elem);
216 }
217 Ok(result)
218 }
219}
220
221pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
227where
228 D: Deserializer<'de>,
229 T: Deserialize<'de>,
230{
231 let visitor = ShortVecVisitor { _t: PhantomData };
232 deserializer.deserialize_tuple(usize::MAX, visitor)
233}
234
/// A `Vec<T>` that serializes as a [`ShortU16`] length prefix followed by its
/// elements. Serializing more than `u16::MAX` elements fails.
pub struct ShortVec<T>(pub Vec<T>);
236
237impl<T: Serialize> Serialize for ShortVec<T> {
238 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
239 where
240 S: Serializer,
241 {
242 serialize(&self.0, serializer)
243 }
244}
245
246impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> {
247 fn deserialize<D>(deserializer: D) -> Result<ShortVec<T>, D::Error>
248 where
249 D: Deserializer<'de>,
250 {
251 deserialize(deserializer).map(ShortVec)
252 }
253}
254
255#[allow(clippy::result_unit_err)]
257pub fn decode_shortu16_len(bytes: &[u8]) -> Result<(usize, usize), ()> {
258 let mut val = 0;
259 for (nth_byte, byte) in bytes.iter().take(MAX_ENCODING_LENGTH).enumerate() {
260 match visit_byte(*byte, val, nth_byte).map_err(|_| ())? {
261 VisitStatus::More(new_val) => val = new_val,
262 VisitStatus::Done(new_val) => {
263 return Ok((usize::from(new_val), nth_byte.saturating_add(1)));
264 }
265 }
266 }
267 Err(())
268}
269
#[cfg(test)]
mod tests {
    use {
        super::*,
        assert_matches::assert_matches,
        bincode::{deserialize, serialize},
    };

    /// Encodes `len` as a `ShortU16` with bincode.
    fn encode_len(len: u16) -> Vec<u8> {
        bincode::serialize(&ShortU16(len)).unwrap()
    }

    /// Asserts that `len` encodes to exactly `bytes`, and that `bytes` decodes
    /// back to `len` using every byte.
    fn assert_len_encoding(len: u16, bytes: &[u8]) {
        assert_eq!(encode_len(len), bytes, "unexpected usize encoding");
        assert_eq!(
            decode_shortu16_len(bytes).unwrap(),
            (usize::from(len), bytes.len()),
            "unexpected usize decoding"
        );
    }

    #[test]
    fn test_short_vec_encode_len() {
        assert_len_encoding(0x0, &[0x0]);
        assert_len_encoding(0x7f, &[0x7f]);
        assert_len_encoding(0x80, &[0x80, 0x01]);
        assert_len_encoding(0xff, &[0xff, 0x01]);
        assert_len_encoding(0x100, &[0x80, 0x02]);
        assert_len_encoding(0x7fff, &[0xff, 0xff, 0x01]);
        assert_len_encoding(0xffff, &[0xff, 0xff, 0x03]);
    }

    #[test]
    fn test_short_u16_roundtrip_all_values() {
        for value in 0..=u16::MAX {
            let bytes = encode_len(value);
            assert!((1..=MAX_ENCODING_LENGTH).contains(&bytes.len()));
            assert_eq!(
                decode_shortu16_len(&bytes),
                Ok((usize::from(value), bytes.len()))
            );
            assert_eq!(deserialize::<ShortU16>(&bytes).unwrap().0, value);
        }
    }

    #[test]
    fn test_decode_shortu16_len_rejects_bad_input() {
        // too short
        assert_eq!(decode_shortu16_len(&[]), Err(()));
        assert_eq!(decode_shortu16_len(&[0x80]), Err(()));
        assert_eq!(decode_shortu16_len(&[0x80, 0x80]), Err(()));
        // alias of 0x0000
        assert_eq!(decode_shortu16_len(&[0x80, 0x00]), Err(()));
        // continuation bit on the third byte
        assert_eq!(decode_shortu16_len(&[0x80, 0x80, 0x80, 0x00]), Err(()));
        // 0x0001_0000 does not fit in a u16
        assert_eq!(decode_shortu16_len(&[0x80, 0x80, 0x04]), Err(()));
    }

    #[test]
    fn test_decode_shortu16_len_ignores_trailing_bytes() {
        assert_eq!(decode_shortu16_len(&[0x05, 0xff, 0xff]), Ok((5, 1)));
        assert_eq!(decode_shortu16_len(&[0x80, 0x01, 0x00]), Ok((0x80, 2)));
    }

    fn assert_good_deserialized_value(value: u16, bytes: &[u8]) {
        assert_eq!(value, deserialize::<ShortU16>(bytes).unwrap().0);
    }

    fn assert_bad_deserialized_value(bytes: &[u8]) {
        assert!(deserialize::<ShortU16>(bytes).is_err());
    }

    #[test]
    fn test_deserialize() {
        assert_good_deserialized_value(0x0000, &[0x00]);
        assert_good_deserialized_value(0x007f, &[0x7f]);
        assert_good_deserialized_value(0x0080, &[0x80, 0x01]);
        assert_good_deserialized_value(0x00ff, &[0xff, 0x01]);
        assert_good_deserialized_value(0x0100, &[0x80, 0x02]);
        assert_good_deserialized_value(0x07ff, &[0xff, 0x0f]);
        assert_good_deserialized_value(0x3fff, &[0xff, 0x7f]);
        assert_good_deserialized_value(0x4000, &[0x80, 0x80, 0x01]);
        assert_good_deserialized_value(0xffff, &[0xff, 0xff, 0x03]);

        // aliases
        // 0x0000
        assert_bad_deserialized_value(&[0x80, 0x00]);
        assert_bad_deserialized_value(&[0x80, 0x80, 0x00]);
        // 0x007f
        assert_bad_deserialized_value(&[0xff, 0x00]);
        assert_bad_deserialized_value(&[0xff, 0x80, 0x00]);
        // 0x0080
        assert_bad_deserialized_value(&[0x80, 0x81, 0x00]);
        // 0x00ff
        assert_bad_deserialized_value(&[0xff, 0x81, 0x00]);
        // 0x0100
        assert_bad_deserialized_value(&[0x80, 0x82, 0x00]);
        // 0x07ff
        assert_bad_deserialized_value(&[0xff, 0x8f, 0x00]);
        // 0x3fff
        assert_bad_deserialized_value(&[0xff, 0xff, 0x00]);

        // too short
        assert_bad_deserialized_value(&[]);
        assert_bad_deserialized_value(&[0x80]);

        // too long
        assert_bad_deserialized_value(&[0x80, 0x80, 0x80, 0x00]);

        // too large
        // 0x0001_0000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x04]);
        // 0x0001_8000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x06]);
    }

    #[test]
    fn test_short_vec_u8() {
        let vec = ShortVec(vec![4u8; 32]);
        let bytes = serialize(&vec).unwrap();
        assert_eq!(bytes.len(), vec.0.len() + 1);

        let vec1: ShortVec<u8> = deserialize(&bytes).unwrap();
        assert_eq!(vec.0, vec1.0);
    }

    #[test]
    fn test_short_vec_u8_too_long() {
        let vec = ShortVec(vec![4u8; u16::MAX as usize]);
        assert_matches!(serialize(&vec), Ok(_));

        let vec = ShortVec(vec![4u8; u16::MAX as usize + 1]);
        assert_matches!(serialize(&vec), Err(_));
    }

    #[test]
    fn test_short_vec_truncated() {
        // Missing length prefix entirely.
        assert!(deserialize::<ShortVec<u8>>(&[]).is_err());
        // Prefix claims 65535 elements but no element bytes follow.
        assert!(deserialize::<ShortVec<u8>>(&[0xff, 0xff, 0x03]).is_err());
    }

    #[test]
    fn test_short_vec_json() {
        let vec = ShortVec(vec![0, 1, 2]);
        let s = serde_json::to_string(&vec).unwrap();
        assert_eq!(s, "[[3],0,1,2]");
    }

    #[test]
    fn test_short_vec_aliased_length() {
        let bytes = [
            0x81, 0x80, 0x00, // 3-byte alias of 1
            0x00,
        ];
        assert!(deserialize::<ShortVec<u8>>(&bytes).is_err());
    }
}