1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
//! All binary files generated by measureme have a simple file header that
//! consists of a 4 byte file magic string and a 4 byte little-endian version
//! number.
use std::convert::TryInto;
use std::error::Error;
use std::path::Path;

/// Version number written into (and required by) every file header.
pub const CURRENT_FILE_FORMAT_VERSION: u32 = 8;

/// File magic for the top-level measureme file (`MMPD`).
pub const FILE_MAGIC_TOP_LEVEL: &[u8; 4] = b"MMPD";
/// File magic for the event stream (`MMES`).
pub const FILE_MAGIC_EVENT_STREAM: &[u8; 4] = b"MMES";
/// File magic for the string table data stream (`MMSD`).
pub const FILE_MAGIC_STRINGTABLE_DATA: &[u8; 4] = b"MMSD";
/// File magic for the string table index stream (`MMSI`).
pub const FILE_MAGIC_STRINGTABLE_INDEX: &[u8; 4] = b"MMSI";

/// File extension, without the leading `.`.
pub const FILE_EXTENSION: &str = "mm_profdata";

/// The size of the file header in bytes: a 4-byte magic followed by a 4-byte
/// version. Note that functions in this module rely on this size to be `8`.
pub const FILE_HEADER_SIZE: usize = 8;

/// Writes the standard measureme file header to `s`: the 4-byte `file_magic`
/// followed by [`CURRENT_FILE_FORMAT_VERSION`] encoded as a little-endian
/// `u32`, for a total of [`FILE_HEADER_SIZE`] bytes.
///
/// # Errors
///
/// Returns an error if writing to `s` fails. The boxed error is the
/// underlying [`std::io::Error`], so callers can recover it with
/// `downcast_ref::<std::io::Error>()`.
pub fn write_file_header(
    s: &mut dyn std::io::Write,
    file_magic: &[u8; 4],
) -> Result<(), Box<dyn Error + Send + Sync>> {
    // The implementation here relies on FILE_HEADER_SIZE to have the value 8.
    // Let's make sure this assumption cannot be violated without being noticed.
    assert_eq!(FILE_HEADER_SIZE, 8);

    // `?` converts `io::Error` directly into the boxed trait object. An extra
    // `map_err(Box::new)` would box twice and hide the `io::Error` from
    // `downcast_ref`.
    s.write_all(file_magic)?;
    s.write_all(&CURRENT_FILE_FORMAT_VERSION.to_le_bytes())?;

    Ok(())
}

/// Checks that `bytes` begins with a measureme file header carrying
/// `expected_magic` and [`CURRENT_FILE_FORMAT_VERSION`].
///
/// Only the first [`FILE_HEADER_SIZE`] bytes are inspected; anything after
/// them is ignored. `diagnostic_file_path` and `stream_tag` are used solely to
/// build error messages. When no path is given, `<in-memory>` is reported
/// instead.
///
/// # Errors
///
/// Returns an error if `bytes` is shorter than [`FILE_HEADER_SIZE`], if its
/// first four bytes differ from `expected_magic`, or if the little-endian
/// `u32` in bytes 4..8 is not [`CURRENT_FILE_FORMAT_VERSION`].
pub fn verify_file_header(
    bytes: &[u8],
    expected_magic: &[u8; 4],
    diagnostic_file_path: Option<&Path>,
    stream_tag: &str,
) -> Result<(), Box<dyn Error + Send + Sync>> {
    // The implementation here relies on FILE_HEADER_SIZE to have the value 8.
    // Let's make sure this assumption cannot be violated without being noticed.
    assert_eq!(FILE_HEADER_SIZE, 8);

    let diagnostic_file_path = diagnostic_file_path.unwrap_or(Path::new("<in-memory>"));

    if bytes.len() < FILE_HEADER_SIZE {
        let msg = format!(
            "Error reading {} stream in file `{}`: Expected file to contain at least `{:?}` bytes but found `{:?}` bytes",
            stream_tag,
            diagnostic_file_path.display(),
            FILE_HEADER_SIZE,
            bytes.len()
        );

        return Err(From::from(msg));
    }

    let actual_magic = &bytes[0..4];

    if actual_magic != expected_magic {
        let msg = format!(
            "Error reading {} stream in file `{}`: Expected file magic `{:?}` but found `{:?}`",
            stream_tag,
            diagnostic_file_path.display(),
            expected_magic,
            actual_magic
        );

        return Err(From::from(msg));
    }

    // Cannot fail: the length check above guarantees that bytes[4..8] exists,
    // and a 4-byte slice always converts to `[u8; 4]`.
    let file_format_version = u32::from_le_bytes(bytes[4..8].try_into().unwrap());

    if file_format_version != CURRENT_FILE_FORMAT_VERSION {
        let msg = format!(
            "Error reading {} stream in file `{}`: Expected file format version {} but found `{}`",
            stream_tag,
            diagnostic_file_path.display(),
            CURRENT_FILE_FORMAT_VERSION,
            file_format_version
        );

        return Err(From::from(msg));
    }

    Ok(())
}

/// Returns everything in `data` after the first [`FILE_HEADER_SIZE`] bytes.
///
/// The header bytes are skipped without being validated.
///
/// # Panics
///
/// Panics if `data` is shorter than [`FILE_HEADER_SIZE`] bytes.
pub fn strip_file_header(data: &[u8]) -> &[u8] {
    let (_header, payload) = data.split_at(FILE_HEADER_SIZE);
    payload
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{PageTag, SerializationSinkBuilder};

    #[test]
    fn roundtrip() {
        let data_sink = SerializationSinkBuilder::new_in_memory().new_sink(PageTag::Events);

        write_file_header(&mut data_sink.as_std_write(), FILE_MAGIC_EVENT_STREAM).unwrap();

        let data = data_sink.into_bytes();

        verify_file_header(&data, FILE_MAGIC_EVENT_STREAM, None, "test").unwrap();
    }

    #[test]
    fn header_layout() {
        // Pins the on-disk layout: 4-byte magic, then the version as little-endian u32.
        let mut buf: Vec<u8> = Vec::new();
        write_file_header(&mut buf, FILE_MAGIC_TOP_LEVEL).unwrap();

        assert_eq!(buf.len(), FILE_HEADER_SIZE);
        assert_eq!(&buf[0..4], FILE_MAGIC_TOP_LEVEL);
        assert_eq!(&buf[4..8], &CURRENT_FILE_FORMAT_VERSION.to_le_bytes());
    }

    #[test]
    fn invalid_magic() {
        let data_sink = SerializationSinkBuilder::new_in_memory().new_sink(PageTag::Events);
        write_file_header(&mut data_sink.as_std_write(), FILE_MAGIC_STRINGTABLE_DATA).unwrap();
        let mut data = data_sink.into_bytes();

        // Invalidate the file magic
        data[2] = 0;
        assert!(verify_file_header(&data, FILE_MAGIC_STRINGTABLE_DATA, None, "test").is_err());
    }

    #[test]
    fn other_version() {
        let data_sink = SerializationSinkBuilder::new_in_memory().new_sink(PageTag::Events);

        write_file_header(&mut data_sink.as_std_write(), FILE_MAGIC_STRINGTABLE_INDEX).unwrap();

        let mut data = data_sink.into_bytes();

        // Change version
        data[4] = 0xFF;
        data[5] = 0xFF;
        data[6] = 0xFF;
        data[7] = 0xFF;
        assert!(verify_file_header(&data, FILE_MAGIC_STRINGTABLE_INDEX, None, "test").is_err());
    }

    #[test]
    fn empty_file() {
        let data: [u8; 0] = [];

        assert!(verify_file_header(&data, FILE_MAGIC_STRINGTABLE_DATA, None, "test").is_err());
    }

    #[test]
    fn truncated_header() {
        // Every proper prefix of a valid header must be rejected, not just the empty one.
        let mut buf: Vec<u8> = Vec::new();
        write_file_header(&mut buf, FILE_MAGIC_EVENT_STREAM).unwrap();

        for len in 0..FILE_HEADER_SIZE {
            assert!(
                verify_file_header(&buf[..len], FILE_MAGIC_EVENT_STREAM, None, "test").is_err()
            );
        }
    }

    #[test]
    fn strip_header() {
        let mut buf: Vec<u8> = Vec::new();
        write_file_header(&mut buf, FILE_MAGIC_EVENT_STREAM).unwrap();
        buf.extend_from_slice(&[1u8, 2, 3]);

        assert_eq!(strip_file_header(&buf), &[1u8, 2, 3][..]);
        assert!(strip_file_header(&buf[..FILE_HEADER_SIZE]).is_empty());
    }
}