solana_short_vec/
lib.rs

1//! Compact serde-encoding of vectors with small length.
2#![cfg_attr(feature = "frozen-abi", feature(min_specialization))]
3#![allow(clippy::arithmetic_side_effects)]
4#[cfg(feature = "frozen-abi")]
5use solana_frozen_abi_macro::AbiExample;
6use {
7    serde::{
8        de::{self, Deserializer, SeqAccess, Visitor},
9        ser::{self, SerializeTuple, Serializer},
10        Deserialize, Serialize,
11    },
12    std::{convert::TryFrom, fmt, marker::PhantomData},
13};
14
/// A `u16` that serializes to a compact, variable-length form of 1 to 3 bytes.
///
/// Each byte carries 7 bits of the value, least-significant group first. While
/// bits remain to be written, the byte's top bit (0x80) is set to signal that
/// another byte follows. Since a `u16` has only 16 bits, the 3rd byte may only
/// use its 2 least-significant bits; anything larger would overflow the `u16`.
#[cfg_attr(feature = "frozen-abi", derive(AbiExample))]
pub struct ShortU16(pub u16);
22
23impl Serialize for ShortU16 {
24    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
25    where
26        S: Serializer,
27    {
28        // Pass a non-zero value to serialize_tuple() so that serde_json will
29        // generate an open bracket.
30        let mut seq = serializer.serialize_tuple(1)?;
31
32        let mut rem_val = self.0;
33        loop {
34            let mut elem = (rem_val & 0x7f) as u8;
35            rem_val >>= 7;
36            if rem_val == 0 {
37                seq.serialize_element(&elem)?;
38                break;
39            } else {
40                elem |= 0x80;
41                seq.serialize_element(&elem)?;
42            }
43        }
44        seq.end()
45    }
46}
47
/// Outcome of successfully folding one encoded byte into a `ShortU16` value.
enum VisitStatus {
    /// The byte was terminal (top bit clear); carries the fully decoded value.
    Done(u16),
    /// Another byte must follow; carries the value accumulated so far.
    More(u16),
}
52
/// Reasons a `ShortU16` byte sequence is rejected while decoding.
#[derive(Debug)]
enum VisitError {
    /// A byte position past the maximum encoding length; carries the 1-based
    /// position of the offending byte.
    TooLong(usize),
    /// Input ended before a terminal byte; carries the 1-based position of
    /// the byte that was missing.
    TooShort(usize),
    /// The decoded value does not fit in a `u16`; carries the wide value.
    Overflow(u32),
    /// A zero byte after the first one, i.e. a non-canonical encoding.
    Alias,
    /// The third byte still had its continuation bit set.
    ByteThreeContinues,
}
61
62impl VisitError {
63    fn into_de_error<'de, A>(self) -> A::Error
64    where
65        A: SeqAccess<'de>,
66    {
67        match self {
68            VisitError::TooLong(len) => de::Error::invalid_length(len, &"three or fewer bytes"),
69            VisitError::TooShort(len) => de::Error::invalid_length(len, &"more bytes"),
70            VisitError::Overflow(val) => de::Error::invalid_value(
71                de::Unexpected::Unsigned(val as u64),
72                &"a value in the range [0, 65535]",
73            ),
74            VisitError::Alias => de::Error::invalid_value(
75                de::Unexpected::Other("alias encoding"),
76                &"strict form encoding",
77            ),
78            VisitError::ByteThreeContinues => de::Error::invalid_value(
79                de::Unexpected::Other("continue signal on byte-three"),
80                &"a terminal signal on or before byte-three",
81            ),
82        }
83    }
84}
85
86type VisitResult = Result<VisitStatus, VisitError>;
87
88const MAX_ENCODING_LENGTH: usize = 3;
89fn visit_byte(elem: u8, val: u16, nth_byte: usize) -> VisitResult {
90    if elem == 0 && nth_byte != 0 {
91        return Err(VisitError::Alias);
92    }
93
94    let val = u32::from(val);
95    let elem = u32::from(elem);
96    let elem_val = elem & 0x7f;
97    let elem_done = (elem & 0x80) == 0;
98
99    if nth_byte >= MAX_ENCODING_LENGTH {
100        return Err(VisitError::TooLong(nth_byte.saturating_add(1)));
101    } else if nth_byte == MAX_ENCODING_LENGTH.saturating_sub(1) && !elem_done {
102        return Err(VisitError::ByteThreeContinues);
103    }
104
105    let shift = u32::try_from(nth_byte)
106        .unwrap_or(u32::MAX)
107        .saturating_mul(7);
108    let elem_val = elem_val.checked_shl(shift).unwrap_or(u32::MAX);
109
110    let new_val = val | elem_val;
111    let val = u16::try_from(new_val).map_err(|_| VisitError::Overflow(new_val))?;
112
113    if elem_done {
114        Ok(VisitStatus::Done(val))
115    } else {
116        Ok(VisitStatus::More(val))
117    }
118}
119
120struct ShortU16Visitor;
121
122impl<'de> Visitor<'de> for ShortU16Visitor {
123    type Value = ShortU16;
124
125    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
126        formatter.write_str("a ShortU16")
127    }
128
129    fn visit_seq<A>(self, mut seq: A) -> Result<ShortU16, A::Error>
130    where
131        A: SeqAccess<'de>,
132    {
133        // Decodes an unsigned 16 bit integer one-to-one encoded as follows:
134        // 1 byte  : 0xxxxxxx                   => 00000000 0xxxxxxx :      0 -    127
135        // 2 bytes : 1xxxxxxx 0yyyyyyy          => 00yyyyyy yxxxxxxx :    128 - 16,383
136        // 3 bytes : 1xxxxxxx 1yyyyyyy 000000zz => zzyyyyyy yxxxxxxx : 16,384 - 65,535
137        let mut val: u16 = 0;
138        for nth_byte in 0..MAX_ENCODING_LENGTH {
139            let elem: u8 = seq.next_element()?.ok_or_else(|| {
140                VisitError::TooShort(nth_byte.saturating_add(1)).into_de_error::<A>()
141            })?;
142            match visit_byte(elem, val, nth_byte).map_err(|e| e.into_de_error::<A>())? {
143                VisitStatus::Done(new_val) => return Ok(ShortU16(new_val)),
144                VisitStatus::More(new_val) => val = new_val,
145            }
146        }
147
148        Err(VisitError::ByteThreeContinues.into_de_error::<A>())
149    }
150}
151
152impl<'de> Deserialize<'de> for ShortU16 {
153    fn deserialize<D>(deserializer: D) -> Result<ShortU16, D::Error>
154    where
155        D: Deserializer<'de>,
156    {
157        deserializer.deserialize_tuple(3, ShortU16Visitor)
158    }
159}
160
161/// If you don't want to use the ShortVec newtype, you can do ShortVec
162/// serialization on an ordinary vector with the following field annotation:
163///
164/// #[serde(with = "short_vec")]
165///
166pub fn serialize<S: Serializer, T: Serialize>(
167    elements: &[T],
168    serializer: S,
169) -> Result<S::Ok, S::Error> {
170    // Pass a non-zero value to serialize_tuple() so that serde_json will
171    // generate an open bracket.
172    let mut seq = serializer.serialize_tuple(1)?;
173
174    let len = elements.len();
175    if len > u16::MAX as usize {
176        return Err(ser::Error::custom("length larger than u16"));
177    }
178    let short_len = ShortU16(len as u16);
179    seq.serialize_element(&short_len)?;
180
181    for element in elements {
182        seq.serialize_element(element)?;
183    }
184    seq.end()
185}
186
187struct ShortVecVisitor<T> {
188    _t: PhantomData<T>,
189}
190
191impl<'de, T> Visitor<'de> for ShortVecVisitor<T>
192where
193    T: Deserialize<'de>,
194{
195    type Value = Vec<T>;
196
197    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
198        formatter.write_str("a Vec with a multi-byte length")
199    }
200
201    fn visit_seq<A>(self, mut seq: A) -> Result<Vec<T>, A::Error>
202    where
203        A: SeqAccess<'de>,
204    {
205        let short_len: ShortU16 = seq
206            .next_element()?
207            .ok_or_else(|| de::Error::invalid_length(0, &self))?;
208        let len = short_len.0 as usize;
209
210        let mut result = Vec::with_capacity(len);
211        for i in 0..len {
212            let elem = seq
213                .next_element()?
214                .ok_or_else(|| de::Error::invalid_length(i, &self))?;
215            result.push(elem);
216        }
217        Ok(result)
218    }
219}
220
221/// If you don't want to use the ShortVec newtype, you can do ShortVec
222/// deserialization on an ordinary vector with the following field annotation:
223///
224/// #[serde(with = "short_vec")]
225///
226pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
227where
228    D: Deserializer<'de>,
229    T: Deserialize<'de>,
230{
231    let visitor = ShortVecVisitor { _t: PhantomData };
232    deserializer.deserialize_tuple(usize::MAX, visitor)
233}
234
/// Newtype around `Vec<T>` whose serde form is a [`ShortU16`] element count
/// followed by the elements.
pub struct ShortVec<T>(pub Vec<T>);
236
237impl<T: Serialize> Serialize for ShortVec<T> {
238    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
239    where
240        S: Serializer,
241    {
242        serialize(&self.0, serializer)
243    }
244}
245
246impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> {
247    fn deserialize<D>(deserializer: D) -> Result<ShortVec<T>, D::Error>
248    where
249        D: Deserializer<'de>,
250    {
251        deserialize(deserializer).map(ShortVec)
252    }
253}
254
255/// Return the decoded value and how many bytes it consumed.
256#[allow(clippy::result_unit_err)]
257pub fn decode_shortu16_len(bytes: &[u8]) -> Result<(usize, usize), ()> {
258    let mut val = 0;
259    for (nth_byte, byte) in bytes.iter().take(MAX_ENCODING_LENGTH).enumerate() {
260        match visit_byte(*byte, val, nth_byte).map_err(|_| ())? {
261            VisitStatus::More(new_val) => val = new_val,
262            VisitStatus::Done(new_val) => {
263                return Ok((usize::from(new_val), nth_byte.saturating_add(1)));
264            }
265        }
266    }
267    Err(())
268}
269
#[cfg(test)]
mod tests {
    use {
        super::*,
        assert_matches::assert_matches,
        bincode::{deserialize, serialize},
    };

    /// Serializes `len` as a `ShortU16` with bincode and returns the bytes.
    fn encode_len(len: u16) -> Vec<u8> {
        bincode::serialize(&ShortU16(len)).unwrap()
    }

    /// Asserts that `len` encodes to exactly `bytes`, and that
    /// `decode_shortu16_len` maps `bytes` back to `len` using all of them.
    fn assert_len_encoding(len: u16, bytes: &[u8]) {
        assert_eq!(encode_len(len), bytes, "unexpected usize encoding");
        assert_eq!(
            decode_shortu16_len(bytes).unwrap(),
            (usize::from(len), bytes.len()),
            "unexpected usize decoding"
        );
    }

    #[test]
    fn test_short_vec_encode_len() {
        assert_len_encoding(0x0, &[0x0]);
        assert_len_encoding(0x7f, &[0x7f]);
        assert_len_encoding(0x80, &[0x80, 0x01]);
        assert_len_encoding(0xff, &[0xff, 0x01]);
        assert_len_encoding(0x100, &[0x80, 0x02]);
        assert_len_encoding(0x7fff, &[0xff, 0xff, 0x01]);
        assert_len_encoding(0xffff, &[0xff, 0xff, 0x03]);
    }

    fn assert_good_deserialized_value(value: u16, bytes: &[u8]) {
        assert_eq!(value, deserialize::<ShortU16>(bytes).unwrap().0);
    }

    fn assert_bad_deserialized_value(bytes: &[u8]) {
        assert!(deserialize::<ShortU16>(bytes).is_err());
    }

    #[test]
    fn test_deserialize() {
        assert_good_deserialized_value(0x0000, &[0x00]);
        assert_good_deserialized_value(0x007f, &[0x7f]);
        assert_good_deserialized_value(0x0080, &[0x80, 0x01]);
        assert_good_deserialized_value(0x00ff, &[0xff, 0x01]);
        assert_good_deserialized_value(0x0100, &[0x80, 0x02]);
        assert_good_deserialized_value(0x07ff, &[0xff, 0x0f]);
        assert_good_deserialized_value(0x3fff, &[0xff, 0x7f]);
        assert_good_deserialized_value(0x4000, &[0x80, 0x80, 0x01]);
        assert_good_deserialized_value(0xffff, &[0xff, 0xff, 0x03]);

        // aliases
        // 0x0000
        assert_bad_deserialized_value(&[0x80, 0x00]);
        assert_bad_deserialized_value(&[0x80, 0x80, 0x00]);
        // 0x007f
        assert_bad_deserialized_value(&[0xff, 0x00]);
        assert_bad_deserialized_value(&[0xff, 0x80, 0x00]);
        // 0x0080
        assert_bad_deserialized_value(&[0x80, 0x81, 0x00]);
        // 0x00ff
        assert_bad_deserialized_value(&[0xff, 0x81, 0x00]);
        // 0x0100
        assert_bad_deserialized_value(&[0x80, 0x82, 0x00]);
        // 0x07ff
        assert_bad_deserialized_value(&[0xff, 0x8f, 0x00]);
        // 0x3fff
        assert_bad_deserialized_value(&[0xff, 0xff, 0x00]);

        // too short
        assert_bad_deserialized_value(&[]);
        assert_bad_deserialized_value(&[0x80]);

        // too long
        assert_bad_deserialized_value(&[0x80, 0x80, 0x80, 0x00]);

        // too large
        // 0x0001_0000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x04]);
        // 0x0001_8000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x06]);
    }

    #[test]
    fn test_decode_shortu16_len_rejects_bad_input() {
        // too short
        assert_eq!(decode_shortu16_len(&[]), Err(()));
        assert_eq!(decode_shortu16_len(&[0x80]), Err(()));
        assert_eq!(decode_shortu16_len(&[0x80, 0x80]), Err(()));
        // alias of 0x0000
        assert_eq!(decode_shortu16_len(&[0x80, 0x00]), Err(()));
        // continuation bit on the third byte
        assert_eq!(decode_shortu16_len(&[0x80, 0x80, 0x80, 0x00]), Err(()));
        // 0x0001_0000 does not fit in a u16
        assert_eq!(decode_shortu16_len(&[0x80, 0x80, 0x04]), Err(()));
    }

    #[test]
    fn test_decode_shortu16_len_ignores_trailing_bytes() {
        assert_eq!(decode_shortu16_len(&[0x05, 0xff, 0xff]), Ok((5, 1)));
        assert_eq!(decode_shortu16_len(&[0x80, 0x01, 0x00]), Ok((0x80, 2)));
    }

    #[test]
    fn test_short_vec_u8() {
        let vec = ShortVec(vec![4u8; 32]);
        let bytes = serialize(&vec).unwrap();
        assert_eq!(bytes.len(), vec.0.len() + 1);

        let vec1: ShortVec<u8> = deserialize(&bytes).unwrap();
        assert_eq!(vec.0, vec1.0);
    }

    #[test]
    fn test_short_vec_two_byte_len_roundtrip() {
        let vec = ShortVec(vec![7u8; 0x80]);
        let bytes = serialize(&vec).unwrap();
        assert_eq!(&bytes[..2], &[0x80, 0x01]);
        assert_eq!(bytes.len(), 0x80 + 2);

        let vec1: ShortVec<u8> = deserialize(&bytes).unwrap();
        assert_eq!(vec.0, vec1.0);
    }

    #[test]
    fn test_short_vec_u8_too_long() {
        let vec = ShortVec(vec![4u8; u16::MAX as usize]);
        assert_matches!(serialize(&vec), Ok(_));

        let vec = ShortVec(vec![4u8; u16::MAX as usize + 1]);
        assert_matches!(serialize(&vec), Err(_));
    }

    #[test]
    fn test_short_vec_json() {
        let vec = ShortVec(vec![0, 1, 2]);
        let s = serde_json::to_string(&vec).unwrap();
        assert_eq!(s, "[[3],0,1,2]");
    }

    #[test]
    fn test_short_vec_aliased_length() {
        let bytes = [
            0x81, 0x80, 0x00, // 3-byte alias of 1
            0x00,
        ];
        assert!(deserialize::<ShortVec<u8>>(&bytes).is_err());
    }

    #[test]
    fn test_short_vec_truncated_body() {
        // Claims 0xffff elements but supplies only two.
        let bytes = [0xff, 0xff, 0x03, 0x01, 0x02];
        assert!(deserialize::<ShortVec<u8>>(&bytes).is_err());

        // A huge claimed length with a large element type and no body must
        // be rejected as well.
        let bytes = [0xff, 0xff, 0x03];
        assert!(deserialize::<ShortVec<[u64; 32]>>(&bytes).is_err());
    }
}