use crate::{
    bits::roundup,
    elf::{LoadingAction, ProgramMetadata},
    machine::SupportMachine,
    memory::{get_page_indices, Memory, FLAG_DIRTY},
    Error, Register, RISCV_GENERAL_REGISTER_NUMBER, RISCV_PAGESIZE,
};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::cmp::min;
use std::collections::HashMap;

const PAGE_SIZE: u64 = RISCV_PAGESIZE as u64;

/// DataSource represents a data source that can stay stable, and possibly
/// immutable, for the entire lifecycle of a VM instance. One example is the
/// enclosing transaction when using CKB-VM in CKB's environment: no matter
/// where and when we run the CKB smart contract, the enclosing transaction
/// will always be the same down to the last byte. As a result, we can
/// leverage DataSource for snapshot optimizations: data that is already
/// locatable in the DataSource does not need to be included in the snapshot
/// again; all we need is an id to locate it, together with an offset / length
/// pair to cut out the correct slice. Just like CKB's syscall design, an
/// extra u64 value is returned here with the remaining full length of data
/// starting from offset, without considering the `length` parameter.
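///
/// A minimal in-memory implementation might look like the sketch below
/// (illustrative only; `VecSource` and its `u32` id type are assumptions,
/// not part of ckb-vm):
///
/// ```ignore
/// struct VecSource(Vec<u8>);
///
/// impl DataSource<u32> for VecSource {
///     fn load_data(&self, _id: &u32, offset: u64, length: u64) -> Option<(Bytes, u64)> {
///         let offset = offset as usize;
///         if offset > self.0.len() {
///             return None;
///         }
///         // Remaining full length of data starting from offset, returned
///         // regardless of how much `length` asks for.
///         let full_length = (self.0.len() - offset) as u64;
///         let real_length = std::cmp::min(length, full_length) as usize;
///         let data = Bytes::copy_from_slice(&self.0[offset..offset + real_length]);
///         Some((data, full_length))
///     }
/// }
/// ```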
pub trait DataSource<I: Clone + PartialEq> {
    fn load_data(&self, id: &I, offset: u64, length: u64) -> Option<(Bytes, u64)>;
}

#[derive(Clone, Debug)]
pub struct Snapshot2Context<I: Clone + PartialEq, D: DataSource<I>> {
    // page index -> (id, offset, flag)
    pages: HashMap<u64, (I, u64, u8)>,
    data_source: D,
}

impl<I: Clone + PartialEq, D: DataSource<I> + Default> Default for Snapshot2Context<I, D> {
    fn default() -> Self {
        Self::new(D::default())
    }
}

impl<I: Clone + PartialEq, D: DataSource<I>> Snapshot2Context<I, D> {
    pub fn new(data_source: D) -> Self {
        Self {
            pages: HashMap::default(),
            data_source,
        }
    }

    /// Resume a previously suspended machine from a snapshot.
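    ///
    /// A hedged usage sketch (machine construction is elided; `data_source`,
    /// `machine` and `machine2` are assumed to be built by the caller):
    ///
    /// ```ignore
    /// let mut ctx = Snapshot2Context::new(data_source.clone());
    /// // ... run `machine` for a while, then suspend it:
    /// let snapshot = ctx.make_snapshot(&mut machine)?;
    /// // Later, restore the captured state into a freshly built machine
    /// // backed by the same data source:
    /// let mut ctx2 = Snapshot2Context::new(data_source);
    /// ctx2.resume(&mut machine2, &snapshot)?;
    /// ```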
    pub fn resume<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        snapshot: &Snapshot2<I>,
    ) -> Result<(), Error> {
        if machine.version() != snapshot.version {
            return Err(Error::InvalidVersion);
        }
        // A resume basically means we reside in a new context
        self.pages.clear();
        for (i, v) in snapshot.registers.iter().enumerate() {
            machine.set_register(i, M::REG::from_u64(*v));
        }
        machine.update_pc(M::REG::from_u64(snapshot.pc));
        machine.commit_pc();
        machine.set_cycles(snapshot.cycles);
        machine.set_max_cycles(snapshot.max_cycles);
        for (address, flag, id, offset, length) in &snapshot.pages_from_source {
            if address % PAGE_SIZE != 0 {
                return Err(Error::MemPageUnalignedAccess);
            }
            let (data, _) = self.load_data(id, *offset, *length)?;
            if data.len() as u64 % PAGE_SIZE != 0 {
                return Err(Error::MemPageUnalignedAccess);
            }
            machine.memory_mut().store_bytes(*address, &data)?;
            for i in 0..(data.len() as u64 / PAGE_SIZE) {
                let page = address / PAGE_SIZE + i;
                machine.memory_mut().set_flag(page, *flag)?;
            }
            self.track_pages(machine, *address, data.len() as u64, id, *offset)?;
        }
        for (address, flag, content) in &snapshot.dirty_pages {
            if address % PAGE_SIZE != 0 {
                return Err(Error::MemPageUnalignedAccess);
            }
            if content.len() as u64 % PAGE_SIZE != 0 {
                return Err(Error::MemPageUnalignedAccess);
            }
            machine.memory_mut().store_bytes(*address, content)?;
            for i in 0..(content.len() as u64 / PAGE_SIZE) {
                let page = address / PAGE_SIZE + i;
                machine.memory_mut().set_flag(page, *flag)?;
            }
        }
        machine
            .memory_mut()
            .set_lr(&M::REG::from_u64(snapshot.load_reservation_address));
        Ok(())
    }

    pub fn load_data(&mut self, id: &I, offset: u64, length: u64) -> Result<(Bytes, u64), Error> {
        self.data_source
            .load_data(id, offset, length)
            .ok_or(Error::SnapshotDataLoadError)
    }

    /// Similar to Memory::store_bytes, but this method also tracks memory
    /// pages whose entire content comes from the DataSource. It returns 2
    /// values: the number of bytes actually written, and the full length of
    /// data starting from offset, ignoring the `length` parameter.
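    ///
    /// A hedged sketch of syscall-style usage (the guest addresses `addr`
    /// and `size_addr` and the data `id` are assumptions for illustration):
    ///
    /// ```ignore
    /// // Copy up to `length` bytes of the data item into guest memory at
    /// // `addr`, starting from source offset 0; the full remaining length
    /// // is written as a u64 to `size_addr`, mirroring CKB's partial
    /// // loading syscall convention.
    /// let (written, full_length) =
    ///     ctx.store_bytes(&mut machine, addr, &id, 0, length, size_addr)?;
    /// ```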
    pub fn store_bytes<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        addr: u64,
        id: &I,
        offset: u64,
        length: u64,
        size_addr: u64,
    ) -> Result<(u64, u64), Error> {
        let (data, full_length) = self.load_data(id, offset, length)?;
        machine
            .memory_mut()
            .store64(&M::REG::from_u64(size_addr), &M::REG::from_u64(full_length))?;
        self.untrack_pages(machine, addr, data.len() as u64)?;
        machine.memory_mut().store_bytes(addr, &data)?;
        self.track_pages(machine, addr, data.len() as u64, id, offset)?;
        Ok((data.len() as u64, full_length))
    }

    /// Due to the current design of ckb-vm, the load_program function does
    /// not belong to SupportMachine yet. For Snapshot2Context to track memory
    /// pages from a program in the DataSource, we have to use the following
    /// steps for now (see the sketch after this list):
    ///
    /// 1. Use elf::parse_elf to generate ProgramMetadata.
    /// 2. Use DefaultMachine::load_program_with_metadata to load the program.
    /// 3. Pass the ProgramMetadata to this method so we can track the memory
    ///    pages from the program, further reducing the size of the generated
    ///    snapshot.
    ///
    /// One can also use the original DefaultMachine::load_program, and parse
    /// the ELF a second time to extract the metadata for this method; however,
    /// the process listed above saves us the time of parsing the ELF again.
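    ///
    /// A hedged sketch of that flow (machine building is elided; `program`
    /// and `id` are assumptions, and the exact signatures should be checked
    /// against the ckb-vm version in use):
    ///
    /// ```ignore
    /// let metadata = parse_elf::<u64>(&program, machine.version())?;
    /// machine.load_program_with_metadata(&program, &metadata, &[])?;
    /// // For a DefaultMachine, pass its inner SupportMachine here.
    /// ctx.mark_program(&mut machine, &metadata, &id, 0)?;
    /// ```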
    pub fn mark_program<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        metadata: &ProgramMetadata,
        id: &I,
        offset: u64,
    ) -> Result<(), Error> {
        for action in &metadata.actions {
            self.init_pages(machine, action, id, offset)?;
        }
        Ok(())
    }

    /// Create a snapshot for the passed machine.
    pub fn make_snapshot<M: SupportMachine>(&self, machine: &mut M) -> Result<Snapshot2<I>, Error> {
        let mut dirty_pages: Vec<(u64, u8, Vec<u8>)> = vec![];
        for i in 0..machine.memory().memory_pages() as u64 {
            let flag = machine.memory_mut().fetch_flag(i)?;
            if flag & FLAG_DIRTY == 0 {
                continue;
            }
            let address = i * PAGE_SIZE;
            let mut data: Vec<u8> = machine.memory_mut().load_bytes(address, PAGE_SIZE)?.into();
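            // Merge contiguous dirty pages that share the same flag into a
            // single entry; note that append() drains `data`, so the push
            // below is skipped whenever the merge happens.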
            if let Some(last) = dirty_pages.last_mut() {
                if last.0 + last.2.len() as u64 == address && last.1 == flag {
                    last.2.append(&mut data);
                }
            }
            if !data.is_empty() {
                dirty_pages.push((address, flag, data));
            }
        }
        let mut pages_from_source: Vec<(u64, u8, I, u64, u64)> = vec![];
        let mut pages: Vec<u64> = self.pages.keys().copied().collect();
        pages.sort_unstable();
        for page in pages {
            // Some pages might be marked as cached pages from the data source,
            // but then receive memory writes later (and get marked as dirty).
            // We can safely skip those pages here, as they will be gathered
            // as dirty pages.
            if machine.memory_mut().fetch_flag(page)? & FLAG_DIRTY != 0 {
                continue;
            }
            let address = page * PAGE_SIZE;
            let (id, offset, flag) = &self.pages[&page];
            let mut appended_to_last = false;
            if let Some((last_address, last_flag, last_id, last_offset, last_length)) =
                pages_from_source.last_mut()
            {
                if *last_address + *last_length == address
                    && *last_flag == *flag
                    && *last_id == *id
                    && *last_offset + *last_length == *offset
                {
                    *last_length += PAGE_SIZE;
                    appended_to_last = true;
                }
            }
            if !appended_to_last {
                pages_from_source.push((address, *flag, id.clone(), *offset, PAGE_SIZE));
            }
        }
        let mut registers = [0u64; RISCV_GENERAL_REGISTER_NUMBER];
        for (i, v) in machine.registers().iter().enumerate() {
            registers[i] = v.to_u64();
        }
        Ok(Snapshot2 {
            pages_from_source,
            dirty_pages,
            version: machine.version(),
            registers,
            pc: machine.pc().to_u64(),
            cycles: machine.cycles(),
            max_cycles: machine.max_cycles(),
            load_reservation_address: machine.memory().lr().to_u64(),
        })
    }

    fn init_pages<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        action: &LoadingAction,
        id: &I,
        offset: u64,
    ) -> Result<(), Error> {
        let start = action.addr + action.offset_from_addr;
        let length = min(
            action.source.end - action.source.start,
            action.size - action.offset_from_addr,
        );
        self.track_pages(machine, start, length, id, offset + action.source.start)
    }

    /// The following methods are only made public for advanced usages; make
    /// sure to exercise extra caution when calling them!
    pub fn track_pages<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        start: u64,
        mut length: u64,
        id: &I,
        mut offset: u64,
    ) -> Result<(), Error> {
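        // Only pages fully covered by [start, start + length) are tracked:
        // bytes before the first page boundary, as well as any partial
        // trailing page, are left to the dirty page mechanism instead.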
        let mut aligned_start = roundup(start, PAGE_SIZE);
        let aligned_bytes = aligned_start - start;
        if length < aligned_bytes {
            return Ok(());
        }
        offset += aligned_bytes;
        length -= aligned_bytes;
        while length >= PAGE_SIZE {
            let page = aligned_start / PAGE_SIZE;
            machine.memory_mut().clear_flag(page, FLAG_DIRTY)?;
            let flag = machine.memory_mut().fetch_flag(page)?;
            self.pages.insert(page, (id.clone(), offset, flag));
            aligned_start += PAGE_SIZE;
            length -= PAGE_SIZE;
            offset += PAGE_SIZE;
        }
        Ok(())
    }

    pub fn untrack_pages<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        start: u64,
        length: u64,
    ) -> Result<(), Error> {
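        // Pages in this range can no longer be assumed to mirror the data
        // source: mark them dirty and drop them from tracking, so that
        // make_snapshot captures their contents verbatim.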
        if length == 0 {
            return Ok(());
        }
        let page_indices = get_page_indices(start, length)?;
        for page in page_indices.0..=page_indices.1 {
            machine.memory_mut().set_flag(page, FLAG_DIRTY)?;
            self.pages.remove(&page);
        }
        Ok(())
    }
}

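/// Snapshot2 captures the state of a machine: registers, PC, cycle counters
/// and the load reservation address, plus memory split into pages that can
/// be recovered from the DataSource (stored by reference) and dirty pages
/// (stored verbatim).
///
/// A hedged persistence sketch (serde_json is an assumption; any serde
/// format works, as long as `I` is serializable too):
///
/// ```ignore
/// let encoded = serde_json::to_vec(&snapshot)?;
/// let decoded: Snapshot2<u32> = serde_json::from_slice(&encoded)?;
/// ```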
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Snapshot2<I: Clone + PartialEq> {
    // (address, flag, id, source offset, source length)
    pub pages_from_source: Vec<(u64, u8, I, u64, u64)>,
    // (address, flag, content)
    pub dirty_pages: Vec<(u64, u8, Vec<u8>)>,
    pub version: u32,
    pub registers: [u64; RISCV_GENERAL_REGISTER_NUMBER],
    pub pc: u64,
    pub cycles: u64,
    pub max_cycles: u64,
    pub load_reservation_address: u64,
}