use crate::{
    bits::roundup,
    elf::{LoadingAction, ProgramMetadata},
    machine::SupportMachine,
    memory::{get_page_indices, Memory, FLAG_DIRTY},
    Error, Register, RISCV_GENERAL_REGISTER_NUMBER, RISCV_PAGESIZE,
};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::cmp::min;
use std::collections::HashMap;

const PAGE_SIZE: u64 = RISCV_PAGESIZE as u64;
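
/// A pluggable source of data addressed by identifiers of type `I`. An
/// implementation returns the requested slice of the entry identified by
/// `id` starting at `offset`, together with the full length of that entry,
/// or `None` when the entry cannot be located.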
pub trait DataSource<I: Clone + PartialEq> {
    fn load_data(&self, id: &I, offset: u64, length: u64) -> Option<(Bytes, u64)>;
}
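
/// Bookkeeping context used together with [`Snapshot2`]. It remembers, for
/// every clean page that was filled from the data source, which `(id, offset)`
/// pair backs it (plus the page flag at tracking time), so that a snapshot can
/// reference the source instead of copying the page contents.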
#[derive(Clone, Debug)]
pub struct Snapshot2Context<I: Clone + PartialEq, D: DataSource<I>> {
    // page index -> (id, offset into the source entry, page flag)
    pages: HashMap<u64, (I, u64, u8)>,
    data_source: D,
}

impl<I: Clone + PartialEq, D: DataSource<I> + Default> Default for Snapshot2Context<I, D> {
    fn default() -> Self {
        Self::new(D::default())
    }
}

impl<I: Clone + PartialEq, D: DataSource<I>> Snapshot2Context<I, D> {
    pub fn new(data_source: D) -> Self {
        Self {
            pages: HashMap::default(),
            data_source,
        }
    }
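
    /// Restores a machine from a previously taken [`Snapshot2`]: registers, pc,
    /// cycle counters, the load reservation address, pages re-read from the data
    /// source, and dirty pages stored verbatim in the snapshot. Clears any page
    /// tracking accumulated so far and fails if the machine version differs from
    /// the snapshot's.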
    pub fn resume<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        snapshot: &Snapshot2<I>,
    ) -> Result<(), Error> {
        if machine.version() != snapshot.version {
            return Err(Error::InvalidVersion);
        }
        self.pages.clear();
        // Restore the register file, pc and cycle counters first.
        for (i, v) in snapshot.registers.iter().enumerate() {
            machine.set_register(i, M::REG::from_u64(*v));
        }
        machine.update_pc(M::REG::from_u64(snapshot.pc));
        machine.commit_pc();
        machine.set_cycles(snapshot.cycles);
        machine.set_max_cycles(snapshot.max_cycles);
        // Pages that were clean at snapshot time are re-read from the data source
        // and tracked again; both the address and the loaded length must be page
        // aligned.
        for (address, flag, id, offset, length) in &snapshot.pages_from_source {
            if address % PAGE_SIZE != 0 {
                return Err(Error::MemPageUnalignedAccess);
            }
            let (data, _) = self.load_data(id, *offset, *length)?;
            if data.len() as u64 % PAGE_SIZE != 0 {
                return Err(Error::MemPageUnalignedAccess);
            }
            machine.memory_mut().store_bytes(*address, &data)?;
            for i in 0..(data.len() as u64 / PAGE_SIZE) {
                let page = address / PAGE_SIZE + i;
                machine.memory_mut().set_flag(page, *flag)?;
            }
            self.track_pages(machine, *address, data.len() as u64, id, *offset)?;
        }
        // Dirty pages carry their contents directly in the snapshot.
        for (address, flag, content) in &snapshot.dirty_pages {
            if address % PAGE_SIZE != 0 {
                return Err(Error::MemPageUnalignedAccess);
            }
            if content.len() as u64 % PAGE_SIZE != 0 {
                return Err(Error::MemPageUnalignedAccess);
            }
            machine.memory_mut().store_bytes(*address, content)?;
            for i in 0..(content.len() as u64 / PAGE_SIZE) {
                let page = address / PAGE_SIZE + i;
                machine.memory_mut().set_flag(page, *flag)?;
            }
        }
        machine
            .memory_mut()
            .set_lr(&M::REG::from_u64(snapshot.load_reservation_address));
        Ok(())
    }
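
    /// Loads data from the underlying data source, turning a missing entry into
    /// `Error::SnapshotDataLoadError`.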
    pub fn load_data(&mut self, id: &I, offset: u64, length: u64) -> Result<(Bytes, u64), Error> {
        self.data_source
            .load_data(id, offset, length)
            .ok_or(Error::SnapshotDataLoadError)
    }
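
    /// Loads data from the data source and writes it into machine memory at
    /// `addr`, storing the full length of the source entry as a 64-bit value at
    /// `size_addr`. The written pages are (re)tracked so that a later snapshot
    /// can reference the source. Returns the number of bytes written and the
    /// full length of the entry.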
    pub fn store_bytes<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        addr: u64,
        id: &I,
        offset: u64,
        length: u64,
        size_addr: u64,
    ) -> Result<(u64, u64), Error> {
        let (data, full_length) = self.load_data(id, offset, length)?;
        machine
            .memory_mut()
            .store64(&M::REG::from_u64(size_addr), &M::REG::from_u64(full_length))?;
        // Drop any stale tracking for the target range before overwriting it,
        // then track the freshly written pages against (id, offset).
        self.untrack_pages(machine, addr, data.len() as u64)?;
        machine.memory_mut().store_bytes(addr, &data)?;
        self.track_pages(machine, addr, data.len() as u64, id, offset)?;
        Ok((data.len() as u64, full_length))
    }
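
    /// Records the pages initialized by loading a program (described by
    /// `metadata`) as backed by the data source entry `id`, so snapshots can
    /// refer to the program binary instead of copying those pages.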
    pub fn mark_program<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        metadata: &ProgramMetadata,
        id: &I,
        offset: u64,
    ) -> Result<(), Error> {
        for action in &metadata.actions {
            self.init_pages(machine, action, id, offset)?;
        }
        Ok(())
    }
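
    /// Builds a [`Snapshot2`] from the current machine state. Dirty pages are
    /// copied verbatim, while clean tracked pages are recorded as references
    /// into the data source; in both cases adjacent pages are merged into a
    /// single entry when their addresses, flags and source offsets line up.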
    pub fn make_snapshot<M: SupportMachine>(&self, machine: &mut M) -> Result<Snapshot2<I>, Error> {
        let mut dirty_pages: Vec<(u64, u8, Vec<u8>)> = vec![];
        for i in 0..machine.memory().memory_pages() as u64 {
            let flag = machine.memory_mut().fetch_flag(i)?;
            if flag & FLAG_DIRTY == 0 {
                continue;
            }
            let address = i * PAGE_SIZE;
            let mut data: Vec<u8> = machine.memory_mut().load_bytes(address, PAGE_SIZE)?.into();
            if let Some(last) = dirty_pages.last_mut() {
                if last.0 + last.2.len() as u64 == address && last.1 == flag {
                    // Contiguous page with the same flag: extend the previous
                    // entry. `append` drains `data`, so the emptiness check
                    // below will not push a duplicate entry.
                    last.2.append(&mut data);
                }
            }
            if !data.is_empty() {
                dirty_pages.push((address, flag, data));
            }
        }
        // Clean pages that are still tracked can be reproduced from the data
        // source; emit (address, flag, id, offset, length) runs, merging pages
        // that are contiguous in both memory and the source entry.
        let mut pages_from_source: Vec<(u64, u8, I, u64, u64)> = vec![];
        let mut pages: Vec<u64> = self.pages.keys().copied().collect();
        pages.sort_unstable();
        for page in pages {
            // Pages dirtied after tracking were already captured verbatim above.
            if machine.memory_mut().fetch_flag(page)? & FLAG_DIRTY != 0 {
                continue;
            }
            let address = page * PAGE_SIZE;
            let (id, offset, flag) = &self.pages[&page];
            let mut appended_to_last = false;
            if let Some((last_address, last_flag, last_id, last_offset, last_length)) =
                pages_from_source.last_mut()
            {
                if *last_address + *last_length == address
                    && *last_flag == *flag
                    && *last_id == *id
                    && *last_offset + *last_length == *offset
                {
                    *last_length += PAGE_SIZE;
                    appended_to_last = true;
                }
            }
            if !appended_to_last {
                pages_from_source.push((address, *flag, id.clone(), *offset, PAGE_SIZE));
            }
        }
        let mut registers = [0u64; RISCV_GENERAL_REGISTER_NUMBER];
        for (i, v) in machine.registers().iter().enumerate() {
            registers[i] = v.to_u64();
        }
        Ok(Snapshot2 {
            pages_from_source,
            dirty_pages,
            version: machine.version(),
            registers,
            pc: machine.pc().to_u64(),
            cycles: machine.cycles(),
            max_cycles: machine.max_cycles(),
            load_reservation_address: machine.memory().lr().to_u64(),
        })
    }
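
    /// Tracks the pages touched by a single ELF loading action. The tracked
    /// length is capped by both the source slice and the region actually
    /// written, so any zero-filled padding past the source bytes is not
    /// attributed to the data source.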
    fn init_pages<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        action: &LoadingAction,
        id: &I,
        offset: u64,
    ) -> Result<(), Error> {
        let start = action.addr + action.offset_from_addr;
        let length = min(
            action.source.end - action.source.start,
            action.size - action.offset_from_addr,
        );
        self.track_pages(machine, start, length, id, offset + action.source.start)
    }
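
    /// Marks the whole pages covered by `[start, start + length)` as backed by
    /// the data source entry `id` at `offset`. Only fully covered, page-aligned
    /// ranges are tracked; partial pages at either end are left alone. Tracked
    /// pages have their dirty flag cleared.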
    pub fn track_pages<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        start: u64,
        mut length: u64,
        id: &I,
        mut offset: u64,
    ) -> Result<(), Error> {
        // Advance to the first page boundary at or after `start`; the partial
        // page before it (if any) is not tracked.
        let mut aligned_start = roundup(start, PAGE_SIZE);
        let aligned_bytes = aligned_start - start;
        if length < aligned_bytes {
            return Ok(());
        }
        offset += aligned_bytes;
        length -= aligned_bytes;
        while length >= PAGE_SIZE {
            let page = aligned_start / PAGE_SIZE;
            machine.memory_mut().clear_flag(page, FLAG_DIRTY)?;
            let flag = machine.memory_mut().fetch_flag(page)?;
            self.pages.insert(page, (id.clone(), offset, flag));
            aligned_start += PAGE_SIZE;
            length -= PAGE_SIZE;
            offset += PAGE_SIZE;
        }
        Ok(())
    }
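
    /// Stops tracking every page overlapping `[start, start + length)` and
    /// marks those pages dirty, so the next snapshot copies their contents
    /// instead of referencing the data source.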
    pub fn untrack_pages<M: SupportMachine>(
        &mut self,
        machine: &mut M,
        start: u64,
        length: u64,
    ) -> Result<(), Error> {
        if length == 0 {
            return Ok(());
        }
        let page_indices = get_page_indices(start, length)?;
        for page in page_indices.0..=page_indices.1 {
            machine.memory_mut().set_flag(page, FLAG_DIRTY)?;
            self.pages.remove(&page);
        }
        Ok(())
    }
}
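
/// A serializable snapshot of machine state. Memory is described in two parts:
/// `pages_from_source` entries are `(address, flag, id, offset, length)` runs
/// that can be re-read from a [`DataSource`], while `dirty_pages` entries are
/// `(address, flag, content)` runs stored verbatim.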
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Snapshot2<I: Clone + PartialEq> {
    pub pages_from_source: Vec<(u64, u8, I, u64, u64)>,
    pub dirty_pages: Vec<(u64, u8, Vec<u8>)>,
    pub version: u32,
    pub registers: [u64; RISCV_GENERAL_REGISTER_NUMBER],
    pub pc: u64,
    pub cycles: u64,
    pub max_cycles: u64,
    pub load_reservation_address: u64,
}