solana_accounts_db/tiered_storage/
byte_block.rs

//! Utility structs and functions for writing and reading byte blocks for the
//! accounts db tiered storage.
3
4use {
5    crate::tiered_storage::{footer::AccountBlockFormat, meta::AccountMetaOptionalFields},
6    std::{
7        io::{Cursor, Read, Result as IoResult, Write},
8        mem, ptr,
9    },
10};
11
12/// The encoder for the byte-block.
13#[derive(Debug)]
14pub enum ByteBlockEncoder {
15    Raw(Cursor<Vec<u8>>),
16    Lz4(lz4::Encoder<Vec<u8>>),
17}
18
19/// The byte block writer.
20///
21/// All writes (`write_type` and `write`) will be buffered in the internal
22/// buffer of the ByteBlockWriter using the specified encoding.
23///
24/// To finalize all the writes, invoke `finish` to obtain the encoded byte
25/// block.
26#[derive(Debug)]
27pub struct ByteBlockWriter {
28    /// the encoder for the byte-block
29    encoder: ByteBlockEncoder,
30    /// the length of the raw data
31    len: usize,
32}
33
34impl ByteBlockWriter {
35    /// Create a ByteBlockWriter from the specified AccountBlockFormat.
36    pub fn new(encoding: AccountBlockFormat) -> Self {
37        Self {
38            encoder: match encoding {
39                AccountBlockFormat::AlignedRaw => ByteBlockEncoder::Raw(Cursor::new(Vec::new())),
40                AccountBlockFormat::Lz4 => ByteBlockEncoder::Lz4(
41                    lz4::EncoderBuilder::new()
42                        .level(0)
43                        .build(Vec::new())
44                        .unwrap(),
45                ),
46            },
47            len: 0,
48        }
49    }
50
51    /// Return the length of the raw data (i.e. after decoding).
52    pub fn raw_len(&self) -> usize {
53        self.len
54    }
55
56    /// Write plain ol' data to the internal buffer of the ByteBlockWriter instance
57    ///
58    /// Prefer this over `write_type()`, as it prevents some undefined behavior.
59    pub fn write_pod<T: bytemuck::NoUninit>(&mut self, value: &T) -> IoResult<usize> {
60        // SAFETY: Since T is NoUninit, it does not contain any uninitialized bytes.
61        unsafe { self.write_type(value) }
62    }
63
64    /// Write the specified typed instance to the internal buffer of
65    /// the ByteBlockWriter instance.
66    ///
67    /// Prefer `write_pod()` when possible, because `write_type()` may cause
68    /// undefined behavior if `value` contains uninitialized bytes.
69    ///
70    /// # Safety
71    ///
72    /// Caller must ensure casting T to bytes is safe.
73    /// Refer to the Safety sections in std::slice::from_raw_parts()
74    /// and bytemuck's Pod and NoUninit for more information.
75    pub unsafe fn write_type<T>(&mut self, value: &T) -> IoResult<usize> {
76        let size = mem::size_of::<T>();
77        let ptr = ptr::from_ref(value).cast();
78        // SAFETY: The caller ensures that `value` contains no uninitialized bytes,
79        // we ensure the size is safe by querying T directly,
80        // and Rust ensures all values are at least byte-aligned.
81        let slice = unsafe { std::slice::from_raw_parts(ptr, size) };
82        self.write(slice)?;
83        Ok(size)
84    }
85
86    /// Write all the Some fields of the specified AccountMetaOptionalFields.
87    ///
88    /// Note that the existence of each optional field is stored separately in
89    /// AccountMetaFlags.
90    pub fn write_optional_fields(
91        &mut self,
92        opt_fields: &AccountMetaOptionalFields,
93    ) -> IoResult<usize> {
94        let mut size = 0;
95        if let Some(rent_epoch) = opt_fields.rent_epoch {
96            size += self.write_pod(&rent_epoch)?;
97        }
98
99        debug_assert_eq!(size, opt_fields.size());
100
101        Ok(size)
102    }
103
104    /// Write the specified typed bytes to the internal buffer of the
105    /// ByteBlockWriter instance.
106    pub fn write(&mut self, buf: &[u8]) -> IoResult<()> {
107        match &mut self.encoder {
108            ByteBlockEncoder::Raw(cursor) => cursor.write_all(buf)?,
109            ByteBlockEncoder::Lz4(lz4_encoder) => lz4_encoder.write_all(buf)?,
110        };
111        self.len += buf.len();
112        Ok(())
113    }
114
115    /// Flush the internal byte buffer that collects all the previous writes
116    /// into an encoded byte array.
117    pub fn finish(self) -> IoResult<Vec<u8>> {
118        match self.encoder {
119            ByteBlockEncoder::Raw(cursor) => Ok(cursor.into_inner()),
120            ByteBlockEncoder::Lz4(lz4_encoder) => {
121                let (compressed_block, result) = lz4_encoder.finish();
122                result?;
123                Ok(compressed_block)
124            }
125        }
126    }
127}
128
/// Stateless utility type that groups the functions for reading byte blocks.
pub struct ByteBlockReader;
131
132/// Reads the raw part of the input byte_block, at the specified offset, as type T.
133///
134/// Returns None if `offset` + size_of::<T>() exceeds the size of the input byte_block.
135///
136/// Type T must be plain ol' data to ensure no undefined behavior.
137pub fn read_pod<T: bytemuck::AnyBitPattern>(byte_block: &[u8], offset: usize) -> Option<&T> {
138    // SAFETY: Since T is AnyBitPattern, it is safe to cast bytes to T.
139    unsafe { read_type(byte_block, offset) }
140}
141
/// Reads the raw part of the input byte_block at the specified offset
/// as type T.
///
/// If `offset` + size_of::<T>() overflows or exceeds the size of the input
/// byte_block, then None will be returned.
///
/// Prefer `read_pod()` when possible, because `read_type()` may cause
/// undefined behavior.
///
/// # Panics
///
/// Panics if the address of `byte_block[offset]` is not aligned for `T`.
/// The check runs in every build profile: creating a misaligned reference
/// is undefined behavior, so it must not be left to a debug-only assertion.
///
/// # Safety
///
/// Caller must ensure casting bytes to T is safe.
/// Refer to the Safety sections in std::slice::from_raw_parts()
/// and bytemuck's Pod and AnyBitPattern for more information.
pub unsafe fn read_type<T>(byte_block: &[u8], offset: usize) -> Option<&T> {
    // Both an overflowing end offset and an out-of-bounds range yield None.
    let end = offset.checked_add(std::mem::size_of::<T>())?;
    let bytes = byte_block.get(offset..end)?;

    let ptr = bytes.as_ptr().cast::<T>();
    assert!(
        ptr as usize % std::mem::align_of::<T>() == 0,
        "byte block address at offset {} is not aligned to {} bytes",
        offset,
        std::mem::align_of::<T>(),
    );

    // SAFETY: The caller ensures it is safe to cast bytes to T,
    // `bytes` spans exactly size_of::<T>() in-bounds bytes,
    // and the assertion above guarantees the ptr is aligned for T.
    Some(unsafe { &*ptr })
}
168
169impl ByteBlockReader {
170    /// Decode the input byte array using the specified format.
171    ///
172    /// Typically, the input byte array is the output of ByteBlockWriter::finish().
173    ///
174    /// Note that calling this function with AccountBlockFormat::AlignedRaw encoding
175    /// will result in panic as the input is already decoded.
176    pub fn decode(encoding: AccountBlockFormat, input: &[u8]) -> IoResult<Vec<u8>> {
177        match encoding {
178            AccountBlockFormat::Lz4 => {
179                let mut decoder = lz4::Decoder::new(input).unwrap();
180                let mut output = vec![];
181                decoder.read_to_end(&mut output)?;
182                Ok(output)
183            }
184            AccountBlockFormat::AlignedRaw => panic!("the input buffer is already decoded"),
185        }
186    }
187}
188
#[cfg(test)]
mod tests {
    use {super::*, solana_sdk::stake_history::Epoch};

    /// Copies a `T` out of `buffer` at `offset` without requiring the address
    /// to be aligned for `T`.
    ///
    /// Returns the value together with the offset of the first byte after it.
    /// Panics if the read would run past the end of `buffer`.
    fn read_type_unaligned<T>(buffer: &[u8], offset: usize) -> (T, usize) {
        let size = std::mem::size_of::<T>();
        let (next, overflow) = offset.overflowing_add(size);
        assert!(!overflow && next <= buffer.len());
        let ptr = buffer[offset..next].as_ptr().cast::<T>();

        // SAFETY: `ptr` points at size_of::<T>() in-bounds bytes (asserted above)
        // and read_unaligned() has no alignment requirement. The tests only
        // instantiate T with integer-only types, for which any bytes are valid.
        (unsafe { std::ptr::read_unaligned(ptr) }, next)
    }

    /// Returns the raw bytes of a finished block: the buffer itself for the
    /// AlignedRaw format, the decoded buffer for every other format.
    fn into_raw_bytes(format: AccountBlockFormat, buffer: Vec<u8>) -> Vec<u8> {
        if format == AccountBlockFormat::AlignedRaw {
            buffer
        } else {
            ByteBlockReader::decode(format, &buffer).unwrap()
        }
    }

    /// Writes a single u32 and verifies that it round-trips.
    fn write_single(format: AccountBlockFormat) {
        let mut writer = ByteBlockWriter::new(format);
        let value: u32 = 42;

        writer.write_pod(&value).unwrap();
        assert_eq!(writer.raw_len(), mem::size_of::<u32>());

        let decoded_buffer = into_raw_bytes(format, writer.finish().unwrap());
        assert_eq!(decoded_buffer.len(), mem::size_of::<u32>());

        let (value_from_buffer, next) = read_type_unaligned::<u32>(&decoded_buffer, 0);
        assert_eq!(value, value_from_buffer);
        // The single value accounts for the whole buffer in every format.
        assert_eq!(next, mem::size_of::<u32>());
    }

    #[test]
    fn test_write_single_raw_format() {
        write_single(AccountBlockFormat::AlignedRaw);
    }

    #[test]
    fn test_write_single_encoded_format() {
        write_single(AccountBlockFormat::Lz4);
    }

    /// A stand-in account meta used to exercise `write_type()`.
    ///
    /// `#[repr(C)]` with a u64 followed by two usize fields leaves no padding
    /// on 32-bit and 64-bit targets, so the struct contains no uninitialized
    /// bytes and satisfies the safety contract of `write_type()`.
    #[repr(C)]
    #[derive(Debug, PartialEq)]
    struct TestMetaStruct {
        lamports: u64,
        owner_index: usize,
        data_len: usize,
    }

    /// Writes three metas interleaved with their data and verifies that every
    /// meta and every data section round-trips.
    fn write_multiple(format: AccountBlockFormat) {
        let mut writer = ByteBlockWriter::new(format);
        let test_metas: Vec<TestMetaStruct> = vec![
            TestMetaStruct {
                lamports: 10,
                owner_index: 0,
                data_len: 100,
            },
            TestMetaStruct {
                lamports: 20,
                owner_index: 1,
                data_len: 200,
            },
            TestMetaStruct {
                lamports: 30,
                owner_index: 2,
                data_len: 300,
            },
        ];
        let test_data1 = [11u8; 100];
        let test_data2 = [22u8; 200];
        let test_data3 = [33u8; 300];

        // Write the above meta and data in an interleaving way.
        //
        // SAFETY: TestMetaStruct is padding-free (see its docs) and byte arrays
        // have no padding, so every written byte is initialized.
        unsafe {
            writer.write_type(&test_metas[0]).unwrap();
            writer.write_type(&test_data1).unwrap();
            writer.write_type(&test_metas[1]).unwrap();
            writer.write_type(&test_data2).unwrap();
            writer.write_type(&test_metas[2]).unwrap();
            writer.write_type(&test_data3).unwrap();
        }

        let expected_len = mem::size_of::<TestMetaStruct>() * 3
            + mem::size_of_val(&test_data1)
            + mem::size_of_val(&test_data2)
            + mem::size_of_val(&test_data3);
        assert_eq!(writer.raw_len(), expected_len);

        let decoded_buffer = into_raw_bytes(format, writer.finish().unwrap());
        assert_eq!(decoded_buffer.len(), expected_len);

        // Walk the buffer and verify each meta followed by its data.
        let expected_data: [&[u8]; 3] = [&test_data1, &test_data2, &test_data3];
        let mut offset = 0;
        for (expected_meta, expected_bytes) in test_metas.iter().zip(expected_data.iter()) {
            let (meta_from_buffer, data_start) =
                read_type_unaligned::<TestMetaStruct>(&decoded_buffer, offset);
            assert_eq!(expected_meta, &meta_from_buffer);

            let data_end = data_start + meta_from_buffer.data_len;
            assert_eq!(*expected_bytes, &decoded_buffer[data_start..data_end]);

            offset = data_end;
        }

        // Nothing is left over once all metas and data have been consumed.
        assert_eq!(offset, decoded_buffer.len());
    }

    #[test]
    fn test_write_multiple_raw_format() {
        write_multiple(AccountBlockFormat::AlignedRaw);
    }

    #[test]
    fn test_write_multiple_lz4_format() {
        write_multiple(AccountBlockFormat::Lz4);
    }

    /// Writes every combination of optional fields and verifies that exactly
    /// the Some fields round-trip.
    fn write_optional_fields(format: AccountBlockFormat) {
        let test_epoch: Epoch = 5432312;

        let mut writer = ByteBlockWriter::new(format);

        // prepare a vector of optional fields that contains all combinations
        // of Some and None.
        let opt_fields_vec: Vec<AccountMetaOptionalFields> = vec![None, Some(test_epoch)]
            .into_iter()
            .map(|rent_epoch| AccountMetaOptionalFields { rent_epoch })
            .collect();
        let some_count = opt_fields_vec
            .iter()
            .filter(|opt_fields| opt_fields.rent_epoch.is_some())
            .count();

        // write all the combinations of the optional fields
        let mut expected_size = 0;
        for opt_fields in &opt_fields_vec {
            writer.write_optional_fields(opt_fields).unwrap();
            expected_size += opt_fields.size();
        }

        let decoded_buffer = into_raw_bytes(format, writer.finish().unwrap());

        // first, verify whether the size of the decoded data matches the
        // expected size.
        assert_eq!(decoded_buffer.len(), expected_size);

        // verify the correctness of the written optional fields
        let mut verified_count = 0;
        let mut offset = 0;
        for opt_fields in &opt_fields_vec {
            if let Some(expected_rent_epoch) = opt_fields.rent_epoch {
                let rent_epoch = read_pod::<Epoch>(&decoded_buffer, offset).unwrap();
                assert_eq!(*rent_epoch, expected_rent_epoch);
                verified_count += 1;
                offset += std::mem::size_of::<Epoch>();
            }
        }

        // make sure the number of Some fields matches the number of fields we
        // have verified.
        assert_eq!(some_count, verified_count);
    }

    #[test]
    fn test_write_optional_fields_raw_format() {
        write_optional_fields(AccountBlockFormat::AlignedRaw);
    }

    #[test]
    fn test_write_optional_fields_lz4_format() {
        write_optional_fields(AccountBlockFormat::Lz4);
    }
}