polkavm_common/
varint.rs

/// Returns how many bytes follow the first byte of a varint, given the
/// number of leading zero bits of the value being encoded.
///
/// The number of significant bits (`32 - leading_zeros`) maps to the result
/// as follows:
///
/// | significant bits | extra bytes |
/// |------------------|-------------|
/// | 0..=7            | 0           |
/// | 8..=14           | 1           |
/// | 15..=21          | 2           |
/// | 22..=28          | 3           |
/// | 29..=32          | 4           |
///
/// `leading_zeros` must not exceed 32; the subtraction below overflows
/// otherwise.
#[inline]
fn get_varint_length(leading_zeros: u32) -> u32 {
    let significant_bits = 32 - leading_zeros;
    let whole_bytes = significant_bits / 8;
    // Branch-free form of the table above: adding and then XOR-ing out the
    // whole-byte count moves each threshold from a multiple of 8 to the
    // 7-bits-per-extra-byte boundaries (8, 15, 22, 29).
    ((significant_bits + whole_bytes) ^ whole_bytes) / 8
}
7
/// The largest number of bytes a single encoded `u32` varint can occupy:
/// one prefix byte followed by up to four payload bytes.
pub const MAX_VARINT_LENGTH: usize = 1 + (u32::BITS / 8) as usize;
9
/// Decodes a varint whose first byte has already been read.
///
/// The number of leading one bits in `first_byte` gives the number of
/// payload bytes that follow (0 to 4). The remaining low bits of
/// `first_byte` are the most significant bits of the value, and the payload
/// bytes hold the rest in little-endian order.
///
/// # Parameters
///
/// * `input` - the bytes following `first_byte`; it may be longer than the
///   encoding, only the required prefix is looked at.
/// * `first_byte` - the first byte of the encoded varint.
///
/// # Returns
///
/// `Some((payload_length, value))`, where `payload_length` is the number of
/// bytes consumed from `input` (not counting `first_byte`). Returns `None`
/// if `input` is shorter than the encoding requires, or if `first_byte`
/// announces more than four payload bytes.
///
/// With four payload bytes the whole 32-bit value comes from the payload,
/// so any stray low bits of `first_byte` are ignored.
#[inline]
pub(crate) fn read_varint(input: &[u8], first_byte: u8) -> Option<(usize, u32)> {
    let length = (!first_byte).leading_zeros();
    let payload = input.get(..length as usize)?;
    // Bits of the first byte left over after the unary length prefix.
    let data_bits = u32::from(first_byte) & (0b11111111_u32 >> length);
    let value = match *payload {
        [] => data_bits,
        [b0] => (data_bits << 8) | u32::from(b0),
        [b0, b1] => (data_bits << 16) | u32::from_le_bytes([b0, b1, 0, 0]),
        [b0, b1, b2] => (data_bits << 24) | u32::from_le_bytes([b0, b1, b2, 0]),
        // The first byte's data bits would sit above bit 31 here, so they
        // contribute nothing. Shifting them with `wrapping_shl(32)` would
        // shift by zero and corrupt the low bits of the result instead.
        [b0, b1, b2, b3] => u32::from_le_bytes([b0, b1, b2, b3]),
        _ => return None,
    };

    Some((length as usize, value))
}
27
28#[inline]
29pub fn write_varint(value: u32, buffer: &mut [u8]) -> usize {
30    let varint_length = get_varint_length(value.leading_zeros());
31    match varint_length {
32        0 => buffer[0] = value as u8,
33        1 => {
34            buffer[0] = 0b10000000 | (value >> 8) as u8;
35            let bytes = value.to_le_bytes();
36            buffer[1] = bytes[0];
37        }
38        2 => {
39            buffer[0] = 0b11000000 | (value >> 16) as u8;
40            let bytes = value.to_le_bytes();
41            buffer[1] = bytes[0];
42            buffer[2] = bytes[1];
43        }
44        3 => {
45            buffer[0] = 0b11100000 | (value >> 24) as u8;
46            let bytes = value.to_le_bytes();
47            buffer[1] = bytes[0];
48            buffer[2] = bytes[1];
49            buffer[3] = bytes[2];
50        }
51        4 => {
52            buffer[0] = 0b11110000;
53            let bytes = value.to_le_bytes();
54            buffer[1] = bytes[0];
55            buffer[2] = bytes[1];
56            buffer[3] = bytes[2];
57            buffer[4] = bytes[3];
58        }
59        _ => unreachable!(),
60    }
61
62    varint_length as usize + 1
63}
64
#[cfg(test)]
proptest::proptest! {
    // Round-trip property: whatever `write_varint` emits for a `u32` must be
    // decoded by `read_varint` to the same value, and both must agree on how
    // many bytes the encoding occupies.
    #[allow(clippy::ignored_unit_patterns)]
    #[test]
    fn varint_serialization(value in 0u32..=0xffffffff) {
        let mut encoded = [0_u8; MAX_VARINT_LENGTH];
        let written = write_varint(value, &mut encoded);

        // `read_varint` takes the first byte separately from the payload.
        let first_byte = encoded[0];
        let payload = &encoded[1..];
        let (consumed, decoded) = read_varint(payload, first_byte).unwrap();

        assert_eq!(decoded, value, "value mismatch");
        assert_eq!(consumed + 1, written, "length mismatch");
    }
}
77
/// Lookup table mapping a byte length to the shift `32 - 8 * length` used by
/// `read_simple_varint`.
///
/// Lengths 0 through 3 map to 32, 24, 16 and 8; every length from 4 up to
/// 255 maps to 0.
static LENGTH_TO_SHIFT: [u32; 256] = {
    let mut table = [0; 256];
    table[0] = 32;
    table[1] = 24;
    table[2] = 16;
    table[3] = 8;
    table
};
95
96#[inline(always)]
97pub(crate) fn read_simple_varint(chunk: u32, length: u32) -> u32 {
98    let shift = LENGTH_TO_SHIFT[length as usize];
99    (((u64::from(chunk) << shift) as u32 as i32).wrapping_shr(shift)) as u32
100}
101
/// Returns the smallest number of bytes (0 to 4) from which `value` can be
/// recovered by sign extension.
///
/// Zero needs no bytes at all. Otherwise the value is treated as a signed
/// 32-bit integer: the more copies of the sign bit it starts with, the fewer
/// bytes are needed.
#[inline]
fn get_bytes_required(value: u32) -> u32 {
    if value == 0 {
        return 0;
    }

    // Length of the run of identical bits at the top of the word: leading
    // ones when the sign bit is set, leading zeros otherwise.
    let sign_run = if (value as i32) < 0 {
        value.leading_ones()
    } else {
        value.leading_zeros()
    };

    // N bytes suffice when bit `8 * N - 1` and everything above it are equal,
    // i.e. when the run covers at least `32 - 8 * N + 1` bits.
    match sign_run {
        25..=32 => 1,
        17..=24 => 2,
        9..=16 => 3,
        _ => 4,
    }
}
128
129#[inline]
130pub(crate) fn write_simple_varint(value: u32, buffer: &mut [u8]) -> usize {
131    let varint_length = get_bytes_required(value);
132    match varint_length {
133        0 => {}
134        1 => {
135            buffer[0] = value as u8;
136        }
137        2 => {
138            let bytes = value.to_le_bytes();
139            buffer[0] = bytes[0];
140            buffer[1] = bytes[1];
141        }
142        3 => {
143            let bytes = value.to_le_bytes();
144            buffer[0] = bytes[0];
145            buffer[1] = bytes[1];
146            buffer[2] = bytes[2];
147        }
148        4 => {
149            let bytes = value.to_le_bytes();
150            buffer[0] = bytes[0];
151            buffer[1] = bytes[1];
152            buffer[2] = bytes[2];
153            buffer[3] = bytes[3];
154        }
155        _ => unreachable!(),
156    }
157
158    varint_length as usize
159}
160
#[test]
fn test_simple_varint() {
    // (value, expected result of `get_bytes_required`).
    let size_cases: &[(u32, u32)] = &[
        (0b00000000_00000000_00000000_00000000, 0),
        (0b00000000_00000000_00000000_00000001, 1),
        (0b00000000_00000000_00000000_01000001, 1),
        (0b00000000_00000000_00000000_10000000, 2),
        (0b00000000_00000000_00000000_11111111, 2),
        (0b00000000_00000000_00000001_00000000, 2),
        (0b00000000_00000000_01000000_00000000, 2),
        (0b00000000_00000000_10000000_00000000, 3),
        (0b00000000_00000001_00000000_00000000, 3),
        (0b00000000_01000000_00000000_00000000, 3),
        (0b00000000_10000000_00000000_00000000, 4),
        (0b00000001_00000000_00000000_00000000, 4),
        (0b10000000_00000000_00000000_00000000, 4),
        (0b11111111_11111111_11111111_11111111, 1),
        (0b10111111_11111111_11111111_11111111, 4),
        (0b11111110_11111111_11111111_11111111, 4),
        (0b11111111_01111111_11111111_11111111, 4),
        (0b11111111_10111111_11111111_11111111, 3),
        (0b11111111_11111110_11111111_11111111, 3),
        (0b11111111_11111111_01111111_11111111, 3),
        (0b11111111_11111111_10111111_11111111, 2),
        (0b11111111_11111111_11111110_11111111, 2),
        (0b11111111_11111111_11111111_01111111, 2),
        (0b11111111_11111111_11111111_10111111, 1),
    ];
    for &(value, expected) in size_cases {
        assert_eq!(get_bytes_required(value), expected, "value = {:#034b}", value);
    }

    // Chunks sharing the low byte 0xff but differing in every other byte:
    // the bytes beyond `length` must not influence the decoded value.
    let chunks: [u32; 4] = [0x000000ff, 0x555555ff, 0xaaaaaaff, 0xffffffff];

    // One byte: 0xff sign-extends to all ones.
    for &chunk in chunks.iter() {
        assert_eq!(read_simple_varint(chunk, 1), 0xffffffff, "chunk = {:#010x}", chunk);
    }

    // Zero bytes: always decodes to zero.
    for &chunk in chunks.iter() {
        assert_eq!(read_simple_varint(chunk, 0), 0, "chunk = {:#010x}", chunk);
    }
}
198
#[cfg(test)]
proptest::proptest! {
    // Round-trip property: any `u32` written by `write_simple_varint` must be
    // recovered by `read_simple_varint`, regardless of what the bytes beyond
    // the written length contain.
    #[allow(clippy::ignored_unit_patterns)]
    #[test]
    fn proptest_simple_varint(value in 0u32..=0xffffffff) {
        for fill_byte in [0x00, 0x55, 0xaa, 0xff] {
            // Pre-fill the scratch space so unwritten bytes hold a known
            // pattern that the decoder has to ignore.
            let mut scratch = [fill_byte; 4];
            let written = write_simple_varint(value, &mut scratch);

            let chunk = u32::from_le_bytes(scratch);
            let decoded = read_simple_varint(chunk, written as u32);
            assert_eq!(decoded, value, "value mismatch");
        }
    }
}