use std::collections::HashMap;
use arbitrary::Arbitrary;
use itertools::Itertools;
use num_traits::ConstOne;
use triton_vm::memory_layout::MemoryRegion;
use triton_vm::prelude::*;
use crate::prelude::*;
/// Raw `u64` value of [`STATIC_MEMORY_FIRST_ADDRESS`]: one less than
/// `BFieldElement::MAX`. Kept as a plain integer so that the address constants
/// below can be derived from it with ordinary integer arithmetic.
const STATIC_MEMORY_FIRST_ADDRESS_AS_U64: u64 = BFieldElement::MAX - 1;
/// The address at which static memory begins.
///
/// [`Library::kmalloc`] hands out chunks at this address and below, i.e.,
/// static memory grows towards smaller addresses.
pub const STATIC_MEMORY_FIRST_ADDRESS: BFieldElement =
BFieldElement::new(STATIC_MEMORY_FIRST_ADDRESS_AS_U64);
/// The lower end of static memory: `u32::MAX` addresses below
/// [`STATIC_MEMORY_FIRST_ADDRESS`].
///
/// This is the start address of [`Library::kmalloc_memory_region`].
pub const STATIC_MEMORY_LAST_ADDRESS: BFieldElement =
BFieldElement::new(STATIC_MEMORY_FIRST_ADDRESS_AS_U64 - u32::MAX as u64);
/// Collects the code of imported snippets and keeps track of how much static
/// memory has been allocated.
#[derive(Clone, Debug)]
pub struct Library {
/// Code of every snippet registered so far, keyed by the snippet's entrypoint
/// (or by the explicitly supplied name).
seen_snippets: HashMap<String, Vec<LabelledInstruction>>,
/// Number of words of static memory that have already been allocated.
num_allocated_words: u32,
}
/// A contiguous chunk of static memory, as returned by [`Library::kmalloc`].
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct StaticAllocation {
/// Address of the first word of the chunk. For chunks produced by
/// [`Library::kmalloc`], this is the smallest address of the chunk.
write_address: BFieldElement,
/// Size of the chunk in words.
num_words: u32,
}
impl StaticAllocation {
    /// The address of the last word of the allocation, that is,
    /// `write_address + num_words - 1`.
    ///
    /// NOTE(review): presumably this is the address to start from when reading
    /// the allocation back; confirm against the VM's memory-read semantics.
    pub fn read_address(&self) -> BFieldElement {
        let one_past_the_end = self.write_address + bfe!(self.num_words);
        one_past_the_end - BFieldElement::ONE
    }

    /// The address of the first word of the allocation.
    pub fn write_address(&self) -> BFieldElement {
        let Self { write_address, .. } = *self;
        write_address
    }

    /// The size of the allocation, in words.
    pub fn num_words(&self) -> u32 {
        let Self { num_words, .. } = *self;
        num_words
    }
}
impl Default for Library {
fn default() -> Self {
Self::new()
}
}
impl Library {
    /// The region of memory from which [`Self::kmalloc`] allocates.
    ///
    /// The region starts at [`STATIC_MEMORY_LAST_ADDRESS`] and spans 2^32
    /// addresses.
    pub fn kmalloc_memory_region() -> MemoryRegion {
        let num_addresses = 1_usize << 32;
        MemoryRegion::new(STATIC_MEMORY_LAST_ADDRESS, num_addresses)
    }

    /// Create a library without any imported snippets and without any
    /// statically allocated memory.
    pub fn new() -> Self {
        let seen_snippets = HashMap::new();
        let num_allocated_words = 0;

        Self {
            seen_snippets,
            num_allocated_words,
        }
    }

    /// Create an empty library. Same as [`Self::new`].
    pub fn empty() -> Self {
        Self::new()
    }

    /// Create an otherwise empty library in which `words_statically_allocated`
    /// words of static memory count as already allocated.
    #[cfg(test)]
    pub fn with_preallocated_memory(words_statically_allocated: u32) -> Self {
        let mut library = Self::new();
        library.num_allocated_words = words_statically_allocated;
        library
    }

    /// Register `snippet` with the library and return its entrypoint.
    ///
    /// If no snippet with the same entrypoint is known yet, the snippet's
    /// annotated code is generated and stored. Code generation is handed this
    /// library, so the snippet can in turn import its own dependencies.
    /// Importing an already known entrypoint does not generate any code.
    pub fn import(&mut self, snippet: Box<dyn BasicSnippet>) -> String {
        let entrypoint = snippet.entrypoint();
        if self.seen_snippets.contains_key(&entrypoint) {
            return entrypoint;
        }

        // The entry API cannot be used here: generating the code needs mutable
        // access to the entire library.
        let annotated_code = snippet.annotated_code(self);
        self.seen_snippets
            .insert(entrypoint.clone(), annotated_code);

        entrypoint
    }

    /// Register the code `body` under the label `name` and return the label.
    ///
    /// If `name` is already known, the stored code is left untouched.
    pub fn explicit_import(&mut self, name: &str, body: &[LabelledInstruction]) -> String {
        let label = name.to_string();
        self.seen_snippets
            .entry(label.clone())
            .or_insert_with(|| body.to_vec());

        label
    }

    /// The code of all registered snippets, one list of instructions per
    /// snippet, ordered by snippet name.
    pub fn all_external_dependencies(&self) -> Vec<Vec<LabelledInstruction>> {
        // Iteration order of a `HashMap` is arbitrary; sorting by name makes the
        // result deterministic. Names are unique, so an unstable sort suffices.
        let mut snippets = self.seen_snippets.iter().collect_vec();
        snippets.sort_unstable_by_key(|&(name, _)| name);

        snippets
            .into_iter()
            .map(|(_, code)| code.clone())
            .collect()
    }

    /// The names of all registered snippets, in ascending order.
    pub fn get_all_snippet_names(&self) -> Vec<String> {
        self.seen_snippets.keys().cloned().sorted().collect()
    }

    /// The code of all registered snippets as one flat list of instructions,
    /// ordered by snippet name.
    pub fn all_imports(&self) -> Vec<LabelledInstruction> {
        self.all_external_dependencies()
            .into_iter()
            .flatten()
            .collect()
    }

    /// Statically allocate `num_words` words of memory.
    ///
    /// Allocations start at [`STATIC_MEMORY_FIRST_ADDRESS`] and grow towards
    /// smaller addresses. The returned allocation's write address is the
    /// smallest address of the newly allocated chunk.
    ///
    /// # Panics
    ///
    /// Panics if
    /// - `num_words` is zero, or
    /// - the total number of statically allocated words would exceed
    ///   `u32::MAX`.
    pub fn kmalloc(&mut self, num_words: u32) -> StaticAllocation {
        assert!(num_words > 0, "must allocate a positive number of words");

        // Skip everything that has been handed out before, then go down to the
        // lowest address of the new chunk.
        let previously_allocated = bfe!(self.num_allocated_words);
        let offset_of_last_word = bfe!(num_words - 1);
        let write_address =
            STATIC_MEMORY_FIRST_ADDRESS - previously_allocated - offset_of_last_word;

        let total_allocated = self
            .num_allocated_words
            .checked_add(num_words)
            .expect("Cannot allocate more that u32::MAX words through `kmalloc`.");
        self.num_allocated_words = total_allocated;

        StaticAllocation {
            write_address,
            num_words,
        }
    }
}
#[cfg(test)]
mod tests {
    use triton_vm::prelude::triton_asm;
    use triton_vm::prelude::Program;

    use super::*;
    use crate::mmr::calculate_new_peaks_from_leaf_mutation::MmrCalculateNewPeaksFromLeafMutationMtIndices;
    use crate::test_prelude::*;

    /// Argument type for snippets that do not take any arguments.
    #[derive(Debug, Copy, Clone, BFieldCodec)]
    struct ZeroSizedType;

    /// Test snippet that depends on both [`DummyTestSnippetB`] and
    /// [`DummyTestSnippetC`].
    #[derive(Debug)]
    struct DummyTestSnippetA;

    /// Test snippet that depends on [`DummyTestSnippetC`].
    #[derive(Debug)]
    struct DummyTestSnippetB;

    /// Test snippet without any dependencies.
    #[derive(Debug)]
    struct DummyTestSnippetC;

    impl BasicSnippet for DummyTestSnippetA {
        fn inputs(&self) -> Vec<(DataType, String)> {
            Vec::new()
        }

        fn outputs(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Xfe, String::from("dummy"))]
        }

        fn entrypoint(&self) -> String {
            String::from("tasmlib_a_dummy_test_value")
        }

        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
            let entrypoint = self.entrypoint();
            let snippet_b = library.import(Box::new(DummyTestSnippetB));
            let snippet_c = library.import(Box::new(DummyTestSnippetC));

            triton_asm!(
                {entrypoint}:
                    call {snippet_b}
                    call {snippet_c}
                    return
            )
        }
    }

    impl BasicSnippet for DummyTestSnippetB {
        fn inputs(&self) -> Vec<(DataType, String)> {
            Vec::new()
        }

        fn outputs(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Bfe, String::from("1")); 2]
        }

        fn entrypoint(&self) -> String {
            String::from("tasmlib_b_dummy_test_value")
        }

        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
            let entrypoint = self.entrypoint();
            let snippet_c = library.import(Box::new(DummyTestSnippetC));

            triton_asm!(
                {entrypoint}:
                    call {snippet_c}
                    call {snippet_c}
                    return
            )
        }
    }

    impl BasicSnippet for DummyTestSnippetC {
        fn inputs(&self) -> Vec<(DataType, String)> {
            Vec::new()
        }

        fn outputs(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Bfe, String::from("1"))]
        }

        fn entrypoint(&self) -> String {
            String::from("tasmlib_c_dummy_test_value")
        }

        fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
            let entrypoint = self.entrypoint();

            triton_asm!({entrypoint}: push 1 return)
        }
    }

    impl Closure for DummyTestSnippetA {
        type Args = ZeroSizedType;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let dummy = xfe![[1, 1, 1]];
            push_encodable(stack, &dummy);
        }

        fn pseudorandom_args(&self, _: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            ZeroSizedType
        }
    }

    impl Closure for DummyTestSnippetB {
        type Args = ZeroSizedType;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            // one element per call to snippet C
            stack.extend([bfe!(1); 2]);
        }

        fn pseudorandom_args(&self, _: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            ZeroSizedType
        }
    }

    impl Closure for DummyTestSnippetC {
        type Args = ZeroSizedType;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            stack.push(bfe!(1));
        }

        fn pseudorandom_args(&self, _: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            ZeroSizedType
        }
    }

    /// Each dummy snippet, together with its imports, behaves like its rust
    /// shadow.
    #[test]
    fn library_includes() {
        ShadowedClosure::new(DummyTestSnippetA).test();
        ShadowedClosure::new(DummyTestSnippetB).test();
        ShadowedClosure::new(DummyTestSnippetC).test();
    }

    /// Importing snippet A also registers its dependencies B and C.
    #[test]
    fn get_all_snippet_names_test_a() {
        let mut library = Library::new();
        library.import(Box::new(DummyTestSnippetA));

        let expected_names = vec![
            "tasmlib_a_dummy_test_value",
            "tasmlib_b_dummy_test_value",
            "tasmlib_c_dummy_test_value",
        ];
        assert_eq!(expected_names, library.get_all_snippet_names());
    }

    /// Importing snippet B registers B and its dependency C, but not A.
    #[test]
    fn get_all_snippet_names_test_b() {
        let mut library = Library::new();
        library.import(Box::new(DummyTestSnippetB));

        let expected_names = vec!["tasmlib_b_dummy_test_value", "tasmlib_c_dummy_test_value"];
        assert_eq!(expected_names, library.get_all_snippet_names());
    }

    /// Smoke test: collecting all imports works after repeated imports of the
    /// same snippets.
    #[test]
    fn all_imports_as_instruction_lists() {
        let mut library = Library::new();
        library.import(Box::new(DummyTestSnippetA));
        library.import(Box::new(DummyTestSnippetA));
        library.import(Box::new(DummyTestSnippetC));

        let _ = library.all_imports();
    }

    /// Assembling the same program repeatedly gives the same result every time.
    #[test]
    fn program_is_deterministic() {
        fn smaller_program() -> Program {
            let mut library = Library::new();
            let memcpy = library.import(Box::new(MemCpy));
            let calculate_new_peaks_from_leaf_mutation =
                library.import(Box::new(MmrCalculateNewPeaksFromLeafMutationMtIndices));

            let entrypoint_code = triton_asm!(
                lala_entrypoint:
                    push 1 call {memcpy}
                    call {calculate_new_peaks_from_leaf_mutation}
                    return
            );

            // both ways of collecting the imports must agree
            let imports = library.all_imports();
            let imports_repeated = library.all_external_dependencies().concat();
            assert_eq!(imports, imports_repeated);

            let source = [entrypoint_code, imports].concat();
            Program::new(&source)
        }

        for _ in 0..100 {
            assert_eq!(smaller_program(), smaller_program());
        }
    }

    /// Consecutive allocations are placed at ever smaller addresses.
    #[test]
    fn kmalloc_test() {
        const MINUS_TWO: BFieldElement = BFieldElement::new(BFieldElement::MAX - 1);

        let mut library = Library::new();
        let sizes_and_expected_addresses = [
            (1, MINUS_TWO),
            (7, -bfe!(9)),
            (1000, -bfe!(1009)),
        ];
        for (num_words, expected_address) in sizes_and_expected_addresses {
            let chunk = library.kmalloc(num_words);
            assert_eq!(expected_address, chunk.write_address());
        }
    }
}