// solana_accounts_db/tiered_storage/file.rs

1use {
2    super::{error::TieredStorageError, TieredStorageResult},
3    bytemuck::{AnyBitPattern, NoUninit, Zeroable},
4    std::{
5        fs::{File, OpenOptions},
6        io::{BufWriter, Read, Result as IoResult, Seek, SeekFrom, Write},
7        mem,
8        path::Path,
9        ptr,
10    },
11};
12
/// Magic number occupying the final 8 bytes of every valid tiered account
/// storage file: the ASCII tag `AnzaTech` interpreted as a little-endian `u64`.
pub const FILE_MAGIC_NUMBER: u64 = {
    let tag: [u8; 8] = *b"AnzaTech";
    u64::from_le_bytes(tag)
};
15
16#[derive(Debug, PartialEq, Eq, Clone, Copy, bytemuck_derive::Pod, bytemuck_derive::Zeroable)]
17#[repr(C)]
18pub struct TieredStorageMagicNumber(pub u64);
19
20// Ensure there are no implicit padding bytes
21const _: () = assert!(std::mem::size_of::<TieredStorageMagicNumber>() == 8);
22
23impl Default for TieredStorageMagicNumber {
24    fn default() -> Self {
25        Self(FILE_MAGIC_NUMBER)
26    }
27}
28
29#[derive(Debug)]
30pub struct TieredReadableFile(pub File);
31
32impl TieredReadableFile {
33    pub fn new(file_path: impl AsRef<Path>) -> TieredStorageResult<Self> {
34        let file = Self(
35            OpenOptions::new()
36                .read(true)
37                .create(false)
38                .open(&file_path)?,
39        );
40
41        file.check_magic_number()?;
42
43        Ok(file)
44    }
45
46    pub fn new_writable(file_path: impl AsRef<Path>) -> IoResult<Self> {
47        Ok(Self(
48            OpenOptions::new()
49                .create_new(true)
50                .write(true)
51                .open(file_path)?,
52        ))
53    }
54
55    fn check_magic_number(&self) -> TieredStorageResult<()> {
56        self.seek_from_end(-(std::mem::size_of::<TieredStorageMagicNumber>() as i64))?;
57        let mut magic_number = TieredStorageMagicNumber::zeroed();
58        self.read_pod(&mut magic_number)?;
59        if magic_number != TieredStorageMagicNumber::default() {
60            return Err(TieredStorageError::MagicNumberMismatch(
61                TieredStorageMagicNumber::default().0,
62                magic_number.0,
63            ));
64        }
65        Ok(())
66    }
67
68    /// Reads a value of type `T` from the file.
69    ///
70    /// Type T must be plain ol' data.
71    pub fn read_pod<T: NoUninit + AnyBitPattern>(&self, value: &mut T) -> IoResult<()> {
72        // SAFETY: Since T is AnyBitPattern, it is safe to cast bytes to T.
73        unsafe { self.read_type(value) }
74    }
75
76    /// Reads a value of type `T` from the file.
77    ///
78    /// Prefer `read_pod()` when possible, because `read_type()` may cause
79    /// undefined behavior.
80    ///
81    /// # Safety
82    ///
83    /// Caller must ensure casting bytes to T is safe.
84    /// Refer to the Safety sections in std::slice::from_raw_parts()
85    /// and bytemuck's Pod and AnyBitPattern for more information.
86    pub unsafe fn read_type<T>(&self, value: &mut T) -> IoResult<()> {
87        let ptr = ptr::from_mut(value).cast();
88        // SAFETY: The caller ensures it is safe to cast bytes to T,
89        // we ensure the size is safe by querying T directly,
90        // and Rust ensures ptr is aligned.
91        let bytes = unsafe { std::slice::from_raw_parts_mut(ptr, mem::size_of::<T>()) };
92        self.read_bytes(bytes)
93    }
94
95    pub fn seek(&self, offset: u64) -> IoResult<u64> {
96        (&self.0).seek(SeekFrom::Start(offset))
97    }
98
99    pub fn seek_from_end(&self, offset: i64) -> IoResult<u64> {
100        (&self.0).seek(SeekFrom::End(offset))
101    }
102
103    pub fn read_bytes(&self, buffer: &mut [u8]) -> IoResult<()> {
104        (&self.0).read_exact(buffer)
105    }
106}
107
108#[derive(Debug)]
109pub struct TieredWritableFile(pub BufWriter<File>);
110
111impl TieredWritableFile {
112    pub fn new(file_path: impl AsRef<Path>) -> IoResult<Self> {
113        Ok(Self(BufWriter::new(
114            OpenOptions::new()
115                .create_new(true)
116                .write(true)
117                .open(file_path)?,
118        )))
119    }
120
121    /// Writes `value` to the file.
122    ///
123    /// `value` must be plain ol' data.
124    pub fn write_pod<T: NoUninit>(&mut self, value: &T) -> IoResult<usize> {
125        // SAFETY: Since T is NoUninit, it does not contain any uninitialized bytes.
126        unsafe { self.write_type(value) }
127    }
128
129    /// Writes `value` to the file.
130    ///
131    /// Prefer `write_pod` when possible, because `write_value` may cause
132    /// undefined behavior if `value` contains uninitialized bytes.
133    ///
134    /// # Safety
135    ///
136    /// Caller must ensure casting T to bytes is safe.
137    /// Refer to the Safety sections in std::slice::from_raw_parts()
138    /// and bytemuck's Pod and NoUninit for more information.
139    pub unsafe fn write_type<T>(&mut self, value: &T) -> IoResult<usize> {
140        let ptr = ptr::from_ref(value).cast();
141        let bytes = unsafe { std::slice::from_raw_parts(ptr, mem::size_of::<T>()) };
142        self.write_bytes(bytes)
143    }
144
145    pub fn seek(&mut self, offset: u64) -> IoResult<u64> {
146        self.0.seek(SeekFrom::Start(offset))
147    }
148
149    pub fn seek_from_end(&mut self, offset: i64) -> IoResult<u64> {
150        self.0.seek(SeekFrom::End(offset))
151    }
152
153    pub fn write_bytes(&mut self, bytes: &[u8]) -> IoResult<usize> {
154        self.0.write_all(bytes)?;
155
156        Ok(bytes.len())
157    }
158}
159
160impl Drop for TieredWritableFile {
161    fn drop(&mut self) {
162        // BufWriter flushes on Drop, but swallows any errors.
163        // Users should always flush explicitly, so errors can be handled.
164        // However, if flush wasn't called, do it here and panic on error.
165        // This is a programmer bug; it means we have forgotten to call flush somewhere.
166        let result = self.0.flush();
167        if let Err(err) = result {
168            panic!("failed to flush TieredWritableFile on drop: {err}");
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use {
        crate::tiered_storage::{
            error::TieredStorageError,
            file::{TieredReadableFile, TieredWritableFile, FILE_MAGIC_NUMBER},
        },
        std::{io::Write, path::Path},
        tempfile::TempDir,
    };

    /// Writes `numbers` as consecutive `u64`s to a new file at `path`.
    ///
    /// Flushes explicitly rather than relying on `TieredWritableFile`'s `Drop`
    /// impl, which treats an un-flushed writer as a programmer bug.
    fn generate_test_file_with_numbers(path: impl AsRef<Path>, numbers: &[u64]) {
        let mut file = TieredWritableFile::new(path).unwrap();
        for number in numbers {
            file.write_pod(number).unwrap();
        }
        file.0.flush().unwrap();
    }

    /// Writes a single `u64` to a new file at `path`.
    fn generate_test_file_with_number(path: impl AsRef<Path>, number: u64) {
        generate_test_file_with_numbers(path, &[number]);
    }

    #[test]
    fn test_new() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test_new");
        generate_test_file_with_number(&path, FILE_MAGIC_NUMBER);
        assert!(TieredReadableFile::new(&path).is_ok());
    }

    #[test]
    fn test_new_with_leading_data() {
        // Only the trailing magic number is checked; preceding content is allowed.
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test_new_with_leading_data");
        generate_test_file_with_numbers(&path, &[0, u64::MAX, FILE_MAGIC_NUMBER]);
        assert!(TieredReadableFile::new(&path).is_ok());
    }

    #[test]
    fn test_magic_number_mismatch() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test_magic_number_mismatch");
        generate_test_file_with_number(&path, !FILE_MAGIC_NUMBER);
        // The error carries (expected, found), in that order.
        assert!(matches!(
            TieredReadableFile::new(&path),
            Err(TieredStorageError::MagicNumberMismatch(expected, found))
                if expected == FILE_MAGIC_NUMBER && found == !FILE_MAGIC_NUMBER
        ));
    }

    #[test]
    fn test_file_too_short() {
        // A file shorter than the magic number cannot contain one and must be rejected.
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test_file_too_short");
        let mut file = TieredWritableFile::new(&path).unwrap();
        file.write_bytes(&[0xAA; 4]).unwrap();
        file.0.flush().unwrap();
        drop(file);
        assert!(TieredReadableFile::new(&path).is_err());
    }
}
207}