#![forbid(unsafe_code)]
#![allow(clippy::too_many_arguments)]
#![warn(clippy::cast_possible_truncation)]
mod partial_solution;
pub use partial_solution::*;
mod solution;
pub use solution::*;
mod solution_id;
pub use solution_id::*;
mod solutions;
pub use solutions::*;
use console::{
account::Address,
algorithms::Sha3_256,
collections::kary_merkle_tree::KaryMerkleTree,
prelude::{
FromBits,
Network,
Result,
anyhow,
bail,
cfg_into_iter,
cfg_iter,
cfg_keys,
cfg_values,
ensure,
has_duplicates,
},
types::U64,
};
use aleo_std::prelude::*;
use core::num::NonZeroUsize;
use indexmap::IndexMap;
use lru::LruCache;
use parking_lot::RwLock;
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use std::sync::Arc;
#[cfg(not(feature = "serial"))]
use rayon::prelude::*;
/// Arity parameter passed to the puzzle's `MerkleTree` type.
const ARITY: u8 = 8;
/// Capacity of the proof-target LRU cache (`1 << 10` = 1024 entries).
const CACHE_SIZE: usize = 1 << 10;
/// K-ary Merkle tree used to reduce puzzle leaves to a root; both hasher parameters are SHA3-256.
// NOTE(review): the literal `9` is presumably the tree depth — confirm against `KaryMerkleTree`.
type MerkleTree = KaryMerkleTree<Sha3_256, Sha3_256, 9, { ARITY }>;
/// Interface a concrete puzzle implements so that `Puzzle` can derive Merkle leaves from it.
///
/// Implementors must be `Send + Sync`; `Puzzle` stores them behind an `Arc<dyn PuzzleTrait<N>>`.
pub trait PuzzleTrait<N: Network>: Send + Sync {
/// Creates a new instance of the puzzle implementation.
///
/// Restricted to `Self: Sized` so the trait stays usable as a trait object.
fn new() -> Self
where
Self: Sized;
/// Produces the leaves for one solution from the epoch hash and a caller-supplied RNG.
///
/// Each leaf is a vector of bits.
fn to_leaves(&self, epoch_hash: N::BlockHash, rng: &mut ChaChaRng) -> Result<Vec<Vec<bool>>>;
/// Batch form of `to_leaves`: takes one RNG per solution and returns one list of leaves per entry.
// NOTE(review): `Puzzle::get_proof_targets` zips the output with the solutions, so the result is
// expected to follow the order of `rngs` — confirm each implementation upholds this.
fn to_all_leaves(&self, epoch_hash: N::BlockHash, rngs: Vec<ChaChaRng>) -> Result<Vec<Vec<Vec<bool>>>>;
}
/// A puzzle implementation paired with a cache of already-computed proof targets.
///
/// Both fields are `Arc`s, so a clone shares the same implementation and the same cache.
#[derive(Clone)]
pub struct Puzzle<N: Network> {
/// The type-erased puzzle implementation that produces leaves.
inner: Arc<dyn PuzzleTrait<N>>,
/// LRU cache mapping a solution ID to its proof target, guarded by a read-write lock.
proof_target_cache: Arc<RwLock<LruCache<SolutionID<N>, u64>>>,
}
impl<N: Network> Puzzle<N> {
/// Builds a puzzle backed by the implementation `P`.
///
/// The proof-target cache starts out empty with room for `CACHE_SIZE` entries.
pub fn new<P: PuzzleTrait<N> + 'static>() -> Self {
    // Instantiate the implementation first, then erase its concrete type.
    let inner: Arc<dyn PuzzleTrait<N>> = Arc::new(P::new());
    // `CACHE_SIZE` is a non-zero constant, so this conversion cannot fail.
    let capacity = NonZeroUsize::new(CACHE_SIZE).unwrap();
    let proof_target_cache = Arc::new(RwLock::new(LruCache::new(capacity)));
    Self { inner, proof_target_cache }
}
/// Returns the leaves for the given partial solution.
///
/// The RNG handed to the puzzle implementation is seeded with the solution ID, so the
/// same solution always starts from the same RNG state.
pub fn get_leaves(&self, solution: &PartialSolution<N>) -> Result<Vec<Vec<bool>>> {
    // The solution ID dereferences to the `u64` seed.
    let seed: u64 = *solution.id();
    let mut seeded_rng = ChaChaRng::seed_from_u64(seed);
    let epoch_hash = solution.epoch_hash();
    self.inner.to_leaves(epoch_hash, &mut seeded_rng)
}
pub fn get_all_leaves(&self, solutions: &PuzzleSolutions<N>) -> Result<Vec<Vec<Vec<bool>>>> {
ensure!(
cfg_values!(solutions).all(|solution| solution.epoch_hash() == solutions[0].epoch_hash()),
"The solutions are for different epochs"
);
let rngs = cfg_keys!(solutions).map(|solution_id| ChaChaRng::seed_from_u64(**solution_id)).collect::<Vec<_>>();
self.inner.to_all_leaves(solutions[0].epoch_hash(), rngs)
}
pub fn get_proof_target(&self, solution: &Solution<N>) -> Result<u64> {
let proof_target = self.get_proof_target_unchecked(solution)?;
ensure!(solution.target() == proof_target, "The proof target does not match the expected proof target");
Ok(proof_target)
}
/// Returns the proof target of `solution` without comparing it to `solution.target()`.
///
/// Only the partial solution is consulted; the declared target is ignored.
pub fn get_proof_target_unchecked(&self, solution: &Solution<N>) -> Result<u64> {
    let partial_solution = solution.partial_solution();
    self.get_proof_target_from_partial_solution(partial_solution)
}
/// Returns the proof target of `partial_solution`, consulting the cache first.
///
/// On a cache miss the leaves are generated, reduced to a proof target, and the result is
/// stored in the cache under the solution ID before being returned.
pub fn get_proof_target_from_partial_solution(&self, partial_solution: &PartialSolution<N>) -> Result<u64> {
    // `LruCache::get` needs mutable access (the write lock is taken even for a lookup).
    let cache_hit = self.proof_target_cache.write().get(&partial_solution.id()).copied();
    match cache_hit {
        Some(proof_target) => Ok(proof_target),
        None => {
            let leaves = self.get_leaves(partial_solution)?;
            let proof_target = Self::leaves_to_proof_target(&leaves)?;
            self.proof_target_cache.write().put(partial_solution.id(), proof_target);
            Ok(proof_target)
        }
    }
}
/// Returns the proof target of every solution, in the same order as `solutions`.
///
/// Cached proof targets are reused; the remaining ones are computed as a single batch and
/// inserted into the cache. In both cases the proof target must equal the target declared
/// by the solution.
///
/// # Errors
///
/// Fails if any declared target differs from the cached or computed proof target, or if
/// the batch computation fails.
pub fn get_proof_targets(&self, solutions: &PuzzleSolutions<N>) -> Result<Vec<u64>> {
    // Output slots, filled from the cache first and from the batch computation afterwards.
    let mut proof_targets = vec![0u64; solutions.len()];
    // Cache misses, remembered as (output index, solution ID, solution).
    let mut pending = Vec::new();

    for (index, (solution_id, solution)) in solutions.iter().enumerate() {
        let cached = self.proof_target_cache.write().get(solution_id).copied();
        if let Some(cached_target) = cached {
            ensure!(
                solution.target() == cached_target,
                "The proof target does not match the cached proof target"
            );
            proof_targets[index] = cached_target;
        } else {
            pending.push((index, solution_id, *solution));
        }
    }

    // Everything was served from the cache.
    if pending.is_empty() {
        return Ok(proof_targets);
    }

    // Compute the leaves of all cache misses in one batch.
    let pending_solutions = PuzzleSolutions::new(pending.iter().map(|(_, _, solution)| *solution).collect())?;
    let all_leaves = self.get_all_leaves(&pending_solutions)?;

    // Reduce each set of leaves to a proof target, validate it, and cache it.
    let computed_targets = cfg_iter!(all_leaves)
        .zip(cfg_iter!(pending_solutions))
        .map(|(leaves, (solution_id, solution))| {
            let proof_target = Self::leaves_to_proof_target(leaves)?;
            ensure!(
                solution.target() == proof_target,
                "The proof target does not match the computed proof target"
            );
            self.proof_target_cache.write().put(*solution_id, proof_target);
            Ok((solution_id, proof_target))
        })
        .collect::<Result<IndexMap<_, _>>>()?;

    // Write the computed targets back into their original positions.
    for (index, solution_id, _) in &pending {
        proof_targets[*index] = computed_targets[solution_id];
    }
    Ok(proof_targets)
}
pub fn get_combined_proof_target(&self, solutions: &PuzzleSolutions<N>) -> Result<u128> {
self.get_proof_targets(solutions)?.into_iter().try_fold(0u128, |combined, proof_target| {
combined.checked_add(proof_target as u128).ok_or_else(|| anyhow!("Combined proof target overflowed"))
})
}
/// Builds a solution for the given epoch hash, address, and counter.
///
/// The proof target is computed (or taken from the cache) and stored in the returned solution.
///
/// # Errors
///
/// Fails if the partial solution cannot be constructed, if the proof target cannot be
/// computed, or if `minimum_proof_target` is `Some` and the proof target is below it.
pub fn prove(
    &self,
    epoch_hash: N::BlockHash,
    address: Address<N>,
    counter: u64,
    minimum_proof_target: Option<u64>,
) -> Result<Solution<N>> {
    let partial_solution = PartialSolution::new(epoch_hash, address, counter)?;
    let proof_target = self.get_proof_target_from_partial_solution(&partial_solution)?;
    match minimum_proof_target {
        // A minimum was requested and the solution falls short of it.
        Some(minimum_proof_target) if proof_target < minimum_proof_target => {
            bail!("Solution was below the minimum proof target ({proof_target} < {minimum_proof_target})")
        }
        // Either no minimum was requested or it is satisfied.
        _ => Ok(Solution::new(partial_solution, proof_target)),
    }
}
pub fn check_solution(
&self,
solution: &Solution<N>,
expected_epoch_hash: N::BlockHash,
expected_proof_target: u64,
) -> Result<()> {
if solution.epoch_hash() != expected_epoch_hash {
bail!(
"Solution does not match the expected epoch hash (found '{}', expected '{expected_epoch_hash}')",
solution.epoch_hash()
)
}
let proof_target = self.get_proof_target(solution)?;
if proof_target < expected_proof_target {
bail!("Solution does not meet the proof target requirement ({proof_target} < {expected_proof_target})")
}
Ok(())
}
pub fn check_solution_mut(
&self,
solution: &mut Solution<N>,
expected_epoch_hash: N::BlockHash,
expected_proof_target: u64,
) -> Result<()> {
if solution.epoch_hash() != expected_epoch_hash {
bail!(
"Solution does not match the expected epoch hash (found '{}', expected '{expected_epoch_hash}')",
solution.epoch_hash()
)
}
let proof_target = self.get_proof_target_unchecked(solution)?;
solution.target = proof_target;
if proof_target < expected_proof_target {
bail!("Solution does not meet the proof target requirement ({proof_target} < {expected_proof_target})")
}
Ok(())
}
/// Checks a batch of solutions against the expected epoch hash and proof target.
///
/// # Errors
///
/// Fails if the batch is empty, exceeds `N::MAX_SOLUTIONS`, or contains duplicate solution
/// IDs; if any solution has a different epoch hash; if any declared target does not match
/// its proof target; or if any proof target is below `expected_proof_target`.
pub fn check_solutions(
    &self,
    solutions: &PuzzleSolutions<N>,
    expected_epoch_hash: N::BlockHash,
    expected_proof_target: u64,
) -> Result<()> {
    let timer = timer!("Puzzle::verify");

    // Structural checks on the batch itself.
    ensure!(!solutions.is_empty(), "The solutions are empty");
    let num_solutions = solutions.len();
    ensure!(
        num_solutions <= N::MAX_SOLUTIONS,
        "Exceed the maximum number of solutions ({} > {})",
        num_solutions,
        N::MAX_SOLUTIONS
    );
    ensure!(!has_duplicates(solutions.solution_ids()), "The solutions contain duplicate solution IDs");
    lap!(timer, "Perform initial checks");

    // Every solution must belong to the expected epoch.
    cfg_iter!(solutions).try_for_each(|(solution_id, solution)| {
        let found_epoch_hash = solution.epoch_hash();
        ensure!(
            found_epoch_hash == expected_epoch_hash,
            "Solution '{solution_id}' did not match the expected epoch hash (found '{}', expected '{expected_epoch_hash}')",
            found_epoch_hash
        );
        Ok(())
    })?;
    lap!(timer, "Verify each epoch hash matches");

    // Every proof target must reach the required threshold.
    let proof_targets = self.get_proof_targets(solutions)?;
    cfg_into_iter!(proof_targets).enumerate().try_for_each(|(i, proof_target)| {
        ensure!(
            proof_target >= expected_proof_target,
            "Solution '{:?}' did not meet the proof target requirement ({proof_target} < {expected_proof_target})",
            solutions.get_index(i).map(|(id, _)| id)
        );
        Ok(())
    })?;
    finish!(timer, "Verify each solution");
    Ok(())
}
/// Converts puzzle leaves into a proof target.
///
/// The leaves are hashed into a Merkle tree, the first 64 bits of the root are read as a
/// big-endian `u64`, and the proof target is `u64::MAX` divided by that value. A value of
/// zero maps to `u64::MAX`.
fn leaves_to_proof_target(leaves: &[Vec<bool>]) -> Result<u64> {
    let merkle_tree = MerkleTree::new(&Sha3_256::default(), &Sha3_256::default(), leaves)?;
    let root = merkle_tree.root();
    // Truncate the root to its leading 64 bits.
    let truncated_root: u64 = *U64::<N>::from_bits_be(&root[0..64])?;
    // `checked_div` yields `None` only for a zero divisor, which maps to the maximum target.
    Ok(u64::MAX.checked_div(truncated_root).unwrap_or(u64::MAX))
}
}
#[cfg(test)]
mod tests {
use super::*;
use console::{
account::{Address, PrivateKey},
network::Network,
prelude::{FromBytes, TestRng, ToBits as TBits, ToBytes, Uniform},
types::Field,
};
use anyhow::Result;
use core::marker::PhantomData;
use rand::{CryptoRng, Rng, RngCore, SeedableRng};
use rand_chacha::ChaChaRng;
/// Network that the tests in this module instantiate the puzzle for.
type CurrentNetwork = console::network::MainnetV0;
/// Number of rounds run by the loop-based randomized test.
const ITERATIONS: u64 = 100;
/// Minimal puzzle used by the tests: each leaf is the little-endian bits of a random field element.
pub struct SimplePuzzle<N: Network>(PhantomData<N>);

impl<N: Network> PuzzleTrait<N> for SimplePuzzle<N> {
    /// Creates the (stateless) test puzzle.
    fn new() -> Self {
        Self(PhantomData)
    }

    /// Samples `num_leaves(epoch_hash)` random field elements from `rng`, one leaf each.
    fn to_leaves(&self, epoch_hash: N::BlockHash, rng: &mut ChaChaRng) -> Result<Vec<Vec<bool>>> {
        let num_leaves = self.num_leaves(epoch_hash)?;
        let mut leaves = Vec::with_capacity(num_leaves);
        for _ in 0..num_leaves {
            leaves.push(Field::<N>::rand(rng).to_bits_le());
        }
        Ok(leaves)
    }

    /// Produces one list of leaves per RNG, in the order the RNGs were given.
    fn to_all_leaves(&self, epoch_hash: N::BlockHash, rngs: Vec<ChaChaRng>) -> Result<Vec<Vec<Vec<bool>>>> {
        // The number of leaves depends only on the epoch, so compute it once.
        let num_leaves = self.num_leaves(epoch_hash)?;
        let all_leaves = rngs
            .into_iter()
            .map(|mut rng| (0..num_leaves).map(|_| Field::<N>::rand(&mut rng).to_bits_le()).collect::<Vec<_>>())
            .collect::<Vec<_>>();
        Ok(all_leaves)
    }
}
impl<N: Network> SimplePuzzle<N> {
    /// Returns the number of leaves for the given epoch, a value in `100..200`.
    ///
    /// The count is drawn from an RNG seeded with the first 8 little-endian bytes of the
    /// epoch hash, so it is fixed for a given epoch.
    pub fn num_leaves(&self, epoch_hash: N::BlockHash) -> Result<usize> {
        const MIN_NUMBER_OF_LEAVES: usize = 100;
        const MAX_NUMBER_OF_LEAVES: usize = 200;

        let epoch_bytes = epoch_hash.to_bytes_le()?;
        let seed = u64::from_bytes_le(&epoch_bytes[0..8])?;
        let num_leaves = ChaChaRng::seed_from_u64(seed).gen_range(MIN_NUMBER_OF_LEAVES..MAX_NUMBER_OF_LEAVES);
        Ok(num_leaves)
    }
}
/// Creates a fresh puzzle (with an empty cache) backed by `SimplePuzzle`.
fn sample_puzzle() -> Puzzle<CurrentNetwork> {
    type Implementation = SimplePuzzle<CurrentNetwork>;
    Puzzle::<CurrentNetwork>::new::<Implementation>()
}
/// Batches of every allowed size must verify under the right epoch hash and fail under a wrong one.
#[test]
fn test_puzzle() {
    let mut rng = TestRng::default();
    let puzzle = sample_puzzle();
    let epoch_hash = rng.gen();

    for batch_size in 1..=CurrentNetwork::MAX_SOLUTIONS {
        // Prove `batch_size` solutions with random addresses and counters.
        let mut batch = Vec::with_capacity(batch_size);
        for _ in 0..batch_size {
            let address = rng.gen();
            let counter = rng.gen();
            batch.push(puzzle.prove(epoch_hash, address, counter, None).unwrap());
        }
        let solutions = PuzzleSolutions::new(batch).unwrap();

        // The batch verifies for its own epoch.
        assert!(puzzle.check_solutions(&solutions, epoch_hash, 0u64).is_ok());

        // The batch is rejected for any other epoch.
        let bad_epoch_hash = rng.gen();
        assert!(puzzle.check_solutions(&solutions, bad_epoch_hash, 0u64).is_err());
    }
}
/// A minimum equal to the proof target is accepted; one above it is rejected.
#[test]
fn test_prove_with_minimum_proof_target() {
    let mut rng = TestRng::default();
    let puzzle = sample_puzzle();
    let epoch_hash = rng.gen();

    for _ in 0..ITERATIONS {
        let address = Address::try_from(PrivateKey::<CurrentNetwork>::new(&mut rng).unwrap()).unwrap();
        let counter = u64::rand(&mut rng);

        // Prove once without a minimum to learn the exact proof target.
        let solution = puzzle.prove(epoch_hash, address, counter, None).unwrap();
        let exact_target = puzzle.get_proof_target(&solution).unwrap();
        let raised_target = exact_target.saturating_add(1);

        // Proving respects the minimum.
        assert!(puzzle.prove(epoch_hash, address, counter, Some(exact_target)).is_ok());
        assert!(puzzle.prove(epoch_hash, address, counter, Some(raised_target)).is_err());

        // Batch verification respects the same threshold.
        let solutions = PuzzleSolutions::new(vec![solution]).unwrap();
        assert!(puzzle.check_solutions(&solutions, epoch_hash, exact_target).is_ok());
        assert!(puzzle.check_solutions(&solutions, epoch_hash, raised_target).is_err());
    }
}
/// A solution proved without a minimum proof target passes both the single and the batch check.
#[test]
fn test_prove_with_no_minimum_proof_target() {
    // Use `TestRng`, like the other tests in this module, rather than `rand::thread_rng()`.
    let mut rng = TestRng::default();
    let puzzle = sample_puzzle();
    let epoch_hash = rng.gen();

    let private_key = PrivateKey::<CurrentNetwork>::new(&mut rng).unwrap();
    let address = Address::try_from(private_key).unwrap();
    let solution = puzzle.prove(epoch_hash, address, rng.gen(), None).unwrap();

    // With a required proof target of zero, any valid solution must be accepted.
    assert!(puzzle.check_solution(&solution, epoch_hash, 0u64).is_ok());
    let solutions = PuzzleSolutions::new(vec![solution]).unwrap();
    assert!(puzzle.check_solutions(&solutions, epoch_hash, 0u64).is_ok());
}
/// A solution whose declared target is off by one is rejected, with or without a warm cache.
#[test]
fn test_check_solution_with_incorrect_target_fails() {
    // Use `TestRng`, like the other tests in this module, rather than `rand::thread_rng()`.
    let mut rng = TestRng::default();
    let puzzle = sample_puzzle();
    let epoch_hash = rng.gen();

    let private_key = PrivateKey::<CurrentNetwork>::new(&mut rng).unwrap();
    let address = Address::try_from(private_key).unwrap();
    let solution = puzzle.prove(epoch_hash, address, rng.gen(), None).unwrap();

    // Same partial solution, but with a declared target one above the real one.
    let incorrect_solution = Solution::new(*solution.partial_solution(), solution.target().saturating_add(1));

    // Single check: `puzzle` already cached the target while proving; `new_puzzle` has an empty cache.
    assert!(puzzle.check_solution(&incorrect_solution, epoch_hash, 0u64).is_err());
    let new_puzzle = sample_puzzle();
    assert!(new_puzzle.check_solution(&incorrect_solution, epoch_hash, 0u64).is_err());

    // Batch check: again with the warm cache and with a fresh one.
    let incorrect_solutions = PuzzleSolutions::new(vec![incorrect_solution]).unwrap();
    assert!(puzzle.check_solutions(&incorrect_solutions, epoch_hash, 0u64).is_err());
    let new_puzzle = sample_puzzle();
    assert!(new_puzzle.check_solutions(&incorrect_solutions, epoch_hash, 0u64).is_err());
}
/// Batches in which every declared target is off by one are rejected, with or without a warm cache.
#[test]
fn test_check_solutions_with_incorrect_target_fails() {
    let mut rng = TestRng::default();
    let puzzle = sample_puzzle();
    let epoch_hash = rng.gen();

    for batch_size in 1..=CurrentNetwork::MAX_SOLUTIONS {
        // Prove valid solutions, then re-wrap each with a target one above the real one.
        let mut tampered = Vec::with_capacity(batch_size);
        for _ in 0..batch_size {
            let address = rng.gen();
            let counter = rng.gen();
            let solution = puzzle.prove(epoch_hash, address, counter, None).unwrap();
            tampered.push(Solution::new(*solution.partial_solution(), solution.target().saturating_add(1)));
        }
        let incorrect_solutions = PuzzleSolutions::new(tampered).unwrap();

        // Rejected by the puzzle that proved them (targets already cached)...
        assert!(puzzle.check_solutions(&incorrect_solutions, epoch_hash, 0u64).is_err());
        // ...and by a fresh puzzle that has to compute the targets.
        let new_puzzle = sample_puzzle();
        assert!(new_puzzle.check_solutions(&incorrect_solutions, epoch_hash, 0u64).is_err());
    }
}
/// Batches built from the same (address, counter) pair are rejected as soon as they hold a duplicate.
#[test]
fn test_check_solutions_with_duplicate_nonces() {
    let mut rng = TestRng::default();
    let puzzle = sample_puzzle();
    let epoch_hash = rng.gen();
    // One fixed address and counter, so every proved solution is identical.
    let address = rng.gen();
    let counter = rng.gen();

    for batch_size in 1..=CurrentNetwork::MAX_SOLUTIONS {
        let solutions =
            (0..batch_size).map(|_| puzzle.prove(epoch_hash, address, counter, None).unwrap()).collect::<Vec<_>>();

        if batch_size == 1 {
            // A single solution has no duplicate, so it must build and verify.
            let solutions = PuzzleSolutions::new(solutions).unwrap();
            assert!(puzzle.check_solutions(&solutions, epoch_hash, 0u64).is_ok());
        } else {
            // Two or more identical solutions must already be rejected at construction.
            assert!(PuzzleSolutions::new(solutions).is_err());
        }
    }
}
/// With an empty cache, the batch proof targets equal the per-solution proof targets.
#[test]
fn test_get_proof_targets_without_cache() {
    let mut rng = TestRng::default();
    let epoch_hash = rng.gen();

    for batch_size in 1..=CurrentNetwork::MAX_SOLUTIONS {
        // Prove the batch with a throwaway puzzle.
        let puzzle = sample_puzzle();
        let mut batch = Vec::with_capacity(batch_size);
        for _ in 0..batch_size {
            let address = rng.gen();
            let counter = rng.gen();
            batch.push(puzzle.prove(epoch_hash, address, counter, None).unwrap());
        }
        let solutions = PuzzleSolutions::new(batch).unwrap();

        // Start over with a fresh puzzle so nothing is cached.
        let puzzle = sample_puzzle();
        let proof_targets = puzzle.get_proof_targets(&solutions).unwrap();

        for (solution, proof_target) in solutions.values().zip(proof_targets) {
            assert_eq!(puzzle.get_proof_target(solution).unwrap(), proof_target);
        }
    }
}
/// With a randomly pre-filled cache, the batch proof targets equal the per-solution proof targets.
#[test]
fn test_get_proof_targets_with_partial_cache() {
    let mut rng = TestRng::default();
    let epoch_hash = rng.gen();

    for batch_size in 1..=CurrentNetwork::MAX_SOLUTIONS {
        // Prove the batch with a throwaway puzzle.
        let puzzle = sample_puzzle();
        let mut batch = Vec::with_capacity(batch_size);
        for _ in 0..batch_size {
            let address = rng.gen();
            let counter = rng.gen();
            batch.push(puzzle.prove(epoch_hash, address, counter, None).unwrap());
        }
        let solutions = PuzzleSolutions::new(batch).unwrap();

        // Start over with a fresh puzzle and cache a random subset of the proof targets.
        let puzzle = sample_puzzle();
        for solution in solutions.values() {
            let warm_cache = rng.gen::<bool>();
            if warm_cache {
                puzzle.get_proof_target(solution).unwrap();
            }
        }

        let proof_targets = puzzle.get_proof_targets(&solutions).unwrap();
        for (solution, proof_target) in solutions.values().zip(proof_targets) {
            assert_eq!(puzzle.get_proof_target(solution).unwrap(), proof_target);
        }
    }
}
/// Manual profiling run over a few batch sizes.
///
/// The test always returns an error at the end and is therefore kept behind `#[ignore]`.
#[ignore]
#[test]
fn test_profiler() -> Result<()> {
    /// Samples a fresh address and a random counter.
    fn sample_address_and_counter(rng: &mut (impl CryptoRng + RngCore)) -> (Address<CurrentNetwork>, u64) {
        let signer = PrivateKey::new(rng).unwrap();
        let address = Address::try_from(signer).unwrap();
        (address, rng.next_u64())
    }

    let mut rng = rand::thread_rng();
    let puzzle = sample_puzzle();
    let epoch_hash = rng.gen();

    for batch_size in [1, 2, <CurrentNetwork as Network>::MAX_SOLUTIONS] {
        let mut batch = Vec::with_capacity(batch_size);
        for _ in 0..batch_size {
            let (address, counter) = sample_address_and_counter(&mut rng);
            batch.push(puzzle.prove(epoch_hash, address, counter, None).unwrap());
        }
        let solutions = PuzzleSolutions::new(batch).unwrap();
        puzzle.check_solutions(&solutions, epoch_hash, 0u64).unwrap();
    }

    bail!("\n\nRemember to #[ignore] this test!\n\n")
}
}