use crate::{
Opcode,
Operand,
Operation,
traits::{RegistersLoad, RegistersLoadCircuit, RegistersStore, RegistersStoreCircuit, StackMatches, StackProgram},
};
use console::{
network::prelude::*,
program::{Literal, LiteralType, PlaintextType, Register, RegisterType},
};
use core::marker::PhantomData;
/// A literal instruction with one operand.
pub type UnaryLiteral<N, O> = Literals<N, O, 1>;
/// A literal instruction with two operands.
pub type BinaryLiteral<N, O> = Literals<N, O, 2>;
/// A literal instruction with three operands.
pub type TernaryLiteral<N, O> = Literals<N, O, 3>;

/// An instruction that applies the literal operation `O` to `NUM_OPERANDS` operands and
/// stores the single resulting literal in `destination`.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Literals<N, O, const NUM_OPERANDS: usize>
where
    N: Network,
    O: Operation<N, Literal<N>, LiteralType, NUM_OPERANDS>,
{
    /// The input operands, in the order they are passed to the operation.
    operands: Vec<Operand<N>>,
    /// The register that receives the operation's result.
    destination: Register<N>,
    /// Associates the operation type `O` with the instruction without storing a value of it.
    _phantom: PhantomData<O>,
}
impl<N, O, const NUM_OPERANDS: usize> Literals<N, O, NUM_OPERANDS>
where
    N: Network,
    O: Operation<N, Literal<N>, LiteralType, NUM_OPERANDS>,
{
    /// Returns the opcode associated with the operation `O`.
    #[inline]
    pub const fn opcode() -> Opcode {
        O::OPCODE
    }

    /// Returns the instruction's operands as a slice.
    #[inline]
    pub fn operands(&self) -> &[Operand<N>] {
        self.operands.as_slice()
    }

    /// Returns the destination registers; this instruction always writes exactly one.
    #[inline]
    pub fn destinations(&self) -> Vec<Register<N>> {
        let destination = self.destination.clone();
        Vec::from([destination])
    }
}
impl<N: Network, O: Operation<N, Literal<N>, LiteralType, NUM_OPERANDS>, const NUM_OPERANDS: usize>
Literals<N, O, NUM_OPERANDS>
{
#[inline]
pub fn evaluate(
&self,
stack: &(impl StackMatches<N> + StackProgram<N>),
registers: &mut (impl RegistersLoad<N> + RegistersStore<N>),
) -> Result<()> {
if self.operands.len() != NUM_OPERANDS {
bail!("Instruction '{}' expects {NUM_OPERANDS} operands, found {} operands", O::OPCODE, self.operands.len())
}
let inputs: Vec<_> =
self.operands.iter().map(|operand| registers.load_literal(stack, operand)).try_collect()?;
let input_types: Vec<_> =
inputs.iter().map(|input| RegisterType::Plaintext(PlaintextType::from(input.to_type()))).collect();
let inputs: [Literal<N>; NUM_OPERANDS] =
inputs.try_into().map_err(|_| anyhow!("Failed to prepare operands in evaluate"))?;
let output = O::evaluate(&inputs)?;
let output_type = RegisterType::Plaintext(PlaintextType::from(output.to_type()));
let expected_types = self.output_types(stack, &input_types)?;
ensure!(expected_types.len() == 1, "Expected 1 output type, found {}", expected_types.len());
ensure!(expected_types[0] == output_type, "Expected output type '{}', found {output_type}", expected_types[0]);
registers.store_literal(stack, &self.destination, output)
}
#[inline]
pub fn execute<A: circuit::Aleo<Network = N>>(
&self,
stack: &(impl StackMatches<N> + StackProgram<N>),
registers: &mut (impl RegistersLoadCircuit<N, A> + RegistersStoreCircuit<N, A>),
) -> Result<()> {
if self.operands.len() != NUM_OPERANDS {
bail!("Instruction '{}' expects {NUM_OPERANDS} operands, found {} operands", O::OPCODE, self.operands.len())
}
let inputs: Vec<_> =
self.operands.iter().map(|operand| registers.load_literal_circuit(stack, operand)).try_collect()?;
let input_types: Vec<_> =
inputs.iter().map(|input| RegisterType::Plaintext(PlaintextType::from(input.to_type()))).collect();
let output = O::execute(&inputs.try_into().map_err(|_| anyhow!("Failed to prepare operands in evaluate"))?)?;
let output_type = RegisterType::Plaintext(PlaintextType::from(output.to_type()));
let expected_types = self.output_types(stack, &input_types)?;
ensure!(expected_types.len() == 1, "Expected 1 output type, found {}", expected_types.len());
ensure!(expected_types[0] == output_type, "Expected output type '{}', found {output_type}", expected_types[0]);
registers.store_literal_circuit(stack, &self.destination, output)
}
#[inline]
pub fn finalize(
&self,
stack: &(impl StackMatches<N> + StackProgram<N>),
registers: &mut (impl RegistersLoad<N> + RegistersStore<N>),
) -> Result<()> {
self.evaluate(stack, registers)
}
#[inline]
pub fn output_types(
&self,
_stack: &impl StackProgram<N>,
input_types: &[RegisterType<N>],
) -> Result<Vec<RegisterType<N>>> {
if input_types.len() != NUM_OPERANDS {
bail!("Instruction '{}' expects {NUM_OPERANDS} inputs, found {} inputs", O::OPCODE, input_types.len())
}
if self.operands.len() != NUM_OPERANDS {
bail!("Instruction '{}' expects {NUM_OPERANDS} operands, found {} operands", O::OPCODE, self.operands.len())
}
let input_types = input_types
.iter()
.map(|input_type| match input_type {
RegisterType::Plaintext(PlaintextType::Literal(literal_type)) => Ok(*literal_type),
RegisterType::Plaintext(PlaintextType::Struct(..))
| RegisterType::Plaintext(PlaintextType::Array(..))
| RegisterType::Record(..)
| RegisterType::ExternalRecord(..)
| RegisterType::Future(..) => bail!("Expected literal type, found '{input_type}'"),
})
.collect::<Result<Vec<_>>>()?;
let output = O::output_type(&input_types.try_into().map_err(|_| anyhow!("Failed to prepare operand types"))?)?;
Ok(vec![RegisterType::Plaintext(PlaintextType::Literal(output))])
}
}
impl<N: Network, O: Operation<N, Literal<N>, LiteralType, NUM_OPERANDS>, const NUM_OPERANDS: usize> Parser
    for Literals<N, O, NUM_OPERANDS>
{
    /// Parses an instruction of the form `<opcode> <operand_1> ... <operand_n> into <destination>`.
    ///
    /// Exactly `NUM_OPERANDS` operands are consumed. Each is followed by whatever whitespace
    /// `Sanitizer::parse_whitespaces` accepts.
    #[inline]
    fn parse(string: &str) -> ParserResult<Self> {
        // The opcode comes first, followed by whitespace.
        let (rest, _) = tag(*O::OPCODE)(string)?;
        let (rest, _) = Sanitizer::parse_whitespaces(rest)?;

        // Reject operation types whose arity exceeds the network limit.
        if NUM_OPERANDS > N::MAX_OPERANDS {
            return map_res(fail, |_: ParserResult<Self>| {
                Err(format!("The number of operands must be <= {}", N::MAX_OPERANDS))
            })(rest);
        }

        // Consume the operands one at a time, advancing past each operand and its trailing whitespace.
        let mut rest = rest;
        let mut operands = Vec::with_capacity(NUM_OPERANDS);
        while operands.len() < NUM_OPERANDS {
            let (after_operand, operand) = Operand::parse(rest)?;
            let (after_whitespace, _) = Sanitizer::parse_whitespaces(after_operand)?;
            operands.push(operand);
            rest = after_whitespace;
        }

        // The operands are followed by the `into` keyword and the destination register.
        let (rest, _) = tag("into")(rest)?;
        let (rest, _) = Sanitizer::parse_whitespaces(rest)?;
        let (rest, destination) = Register::parse(rest)?;
        Ok((rest, Self { operands, destination, _phantom: PhantomData }))
    }
}
impl<N: Network, O: Operation<N, Literal<N>, LiteralType, NUM_OPERANDS>, const NUM_OPERANDS: usize> FromStr
    for Literals<N, O, NUM_OPERANDS>
{
    type Err = Error;

    /// Parses a complete instruction string.
    ///
    /// Fails if the parser rejects the input or leaves any unconsumed characters.
    #[inline]
    fn from_str(string: &str) -> Result<Self> {
        let (remainder, object) = match Self::parse(string) {
            Ok(parsed) => parsed,
            Err(error) => bail!("Failed to parse string. {error}"),
        };
        // The whole input must be consumed; trailing text means the string was malformed.
        ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
        Ok(object)
    }
}
impl<N: Network, O: Operation<N, Literal<N>, LiteralType, NUM_OPERANDS>, const NUM_OPERANDS: usize> Debug
    for Literals<N, O, NUM_OPERANDS>
{
    /// Formats the instruction for debugging; the output is exactly its `Display` form.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        <Self as Display>::fmt(self, f)
    }
}
impl<N: Network, O: Operation<N, Literal<N>, LiteralType, NUM_OPERANDS>, const NUM_OPERANDS: usize> Display
    for Literals<N, O, NUM_OPERANDS>
{
    /// Prints the instruction as `<opcode> <operand_1> ... <operand_n> into <destination>`.
    ///
    /// Returns `fmt::Error` in either of these cases:
    /// - `NUM_OPERANDS` exceeds `N::MAX_OPERANDS`;
    /// - the instruction does not hold exactly `NUM_OPERANDS` operands.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        // Ensure the operation's arity is within the network limit.
        if NUM_OPERANDS > N::MAX_OPERANDS {
            return Err(fmt::Error);
        }
        // Require the exact arity: the parser consumes exactly `NUM_OPERANDS` operands, so
        // printing fewer would produce text that does not parse back into this instruction.
        if self.operands.len() != NUM_OPERANDS {
            return Err(fmt::Error);
        }
        write!(f, "{} ", O::OPCODE)?;
        self.operands.iter().try_for_each(|operand| write!(f, "{operand} "))?;
        write!(f, "into {}", self.destination)
    }
}
impl<N: Network, O: Operation<N, Literal<N>, LiteralType, NUM_OPERANDS>, const NUM_OPERANDS: usize> FromBytes
    for Literals<N, O, NUM_OPERANDS>
{
    /// Reads the instruction from little-endian bytes.
    ///
    /// The encoding is exactly `NUM_OPERANDS` operands followed by the destination register.
    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
        // Refuse to decode an operation whose arity exceeds the network limit.
        if NUM_OPERANDS > N::MAX_OPERANDS {
            return Err(error(format!("The number of operands must be <= {}", N::MAX_OPERANDS)));
        }
        // Decode the operands in order, stopping at the first failure.
        let operands = (0..NUM_OPERANDS)
            .map(|_| Operand::read_le(&mut reader))
            .collect::<IoResult<Vec<_>>>()?;
        let destination = Register::read_le(&mut reader)?;
        Ok(Self { operands, destination, _phantom: PhantomData })
    }
}
impl<N: Network, O: Operation<N, Literal<N>, LiteralType, NUM_OPERANDS>, const NUM_OPERANDS: usize> ToBytes
    for Literals<N, O, NUM_OPERANDS>
{
    /// Writes the instruction as little-endian bytes.
    ///
    /// The encoding is the `NUM_OPERANDS` operands followed by the destination register.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - `NUM_OPERANDS` exceeds `N::MAX_OPERANDS`;
    /// - the instruction does not hold exactly `NUM_OPERANDS` operands;
    /// - writing an operand or the destination fails.
    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
        // Ensure the operation's arity is within the network limit.
        if NUM_OPERANDS > N::MAX_OPERANDS {
            return Err(error(format!("The number of operands must be <= {}", N::MAX_OPERANDS)));
        }
        // Require the exact arity: `read_le` decodes exactly `NUM_OPERANDS` operands before the
        // destination. Writing fewer would make the destination bytes be read back as an operand.
        if self.operands.len() != NUM_OPERANDS {
            return Err(error(format!("The number of operands must be {NUM_OPERANDS}")));
        }
        self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))?;
        self.destination.write_le(&mut writer)
    }
}