use crate::{
Opcode,
Operand,
traits::{RegistersLoad, RegistersLoadCircuit, StackMatches, StackProgram},
};
use console::{
network::prelude::*,
program::{Identifier, Locator, Register, RegisterType, ValueType},
};
/// The callee of a `call` instruction.
///
/// Its text form is either a locator (`<program id>/<name>`) or a bare identifier.
/// Its byte form is a one-byte variant tag followed by the payload (see the `FromBytes`/`ToBytes` impls).
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum CallOperator<N: Network> {
/// A fully-qualified reference to a resource in an external program; it is resolved
/// through `StackProgram::get_external_program`.
Locator(Locator<N>),
/// A bare name, resolved against the calling program itself (`StackProgram::program`).
Resource(Identifier<N>),
}
impl<N: Network> Parser for CallOperator<N> {
    /// Parses a call operator.
    ///
    /// A fully-qualified locator is tried first; if that fails, a bare identifier is parsed.
    #[inline]
    fn parse(string: &str) -> ParserResult<Self> {
        // Order matters: a bare identifier would also match the program-name prefix of a
        // locator, leaving the `.../name` suffix unparsed.
        let external = map(Locator::parse, Self::Locator);
        let internal = map(Identifier::parse, Self::Resource);
        alt((external, internal))(string)
    }
}
impl<N: Network> FromStr for CallOperator<N> {
    type Err = Error;

    /// Parses a call operator from `string`.
    ///
    /// Fails if the parser rejects the input or leaves any trailing characters behind.
    #[inline]
    fn from_str(string: &str) -> Result<Self> {
        let (remainder, operator) = match Self::parse(string) {
            Ok(parsed) => parsed,
            Err(error) => bail!("Failed to parse string. {error}"),
        };
        // The whole input must have been consumed by the parser.
        ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
        Ok(operator)
    }
}
impl<N: Network> Debug for CallOperator<N> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
Display::fmt(self, f)
}
}
impl<N: Network> Display for CallOperator<N> {
    /// Prints the wrapped locator or identifier, forwarding the caller's formatter unchanged.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let inner: &dyn Display = match self {
            Self::Locator(locator) => locator,
            Self::Resource(resource) => resource,
        };
        Display::fmt(inner, f)
    }
}
impl<N: Network> FromBytes for CallOperator<N> {
    /// Reads a one-byte variant tag, then the matching payload.
    ///
    /// Tag `0` is followed by a locator and tag `1` by an identifier; any other tag is rejected.
    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
        match u8::read_le(&mut reader)? {
            0 => Locator::read_le(&mut reader).map(Self::Locator),
            1 => Identifier::read_le(&mut reader).map(Self::Resource),
            _ => Err(error("Failed to read CallOperator. Invalid variant.")),
        }
    }
}
impl<N: Network> ToBytes for CallOperator<N> {
    /// Writes a one-byte variant tag, then the payload.
    ///
    /// Tag `0` is written for a locator and tag `1` for an identifier.
    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
        let variant: u8 = match self {
            Self::Locator(_) => 0,
            Self::Resource(_) => 1,
        };
        variant.write_le(&mut writer)?;
        match self {
            Self::Locator(locator) => locator.write_le(&mut writer),
            Self::Resource(resource) => resource.write_le(&mut writer),
        }
    }
}
/// A `call` instruction, written `call <operator> [operand ...] [into <destination> ...]`.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Call<N: Network> {
/// The closure or function being invoked.
operator: CallOperator<N>,
/// The input operands; parsing and (de)serialization reject more than `N::MAX_OPERANDS`.
operands: Vec<Operand<N>>,
/// The registers that receive the call's outputs; also capped at `N::MAX_OPERANDS`.
destinations: Vec<Register<N>>,
}
impl<N: Network> Call<N> {
    /// The opcode of this instruction (`Opcode::Call`).
    #[inline]
    pub const fn opcode() -> Opcode {
        Opcode::Call
    }

    /// The callee of this call.
    #[inline]
    pub const fn operator(&self) -> &CallOperator<N> {
        &self.operator
    }

    /// The input operands, in call order.
    #[inline]
    pub fn operands(&self) -> &[Operand<N>] {
        self.operands.as_slice()
    }

    /// An owned copy of the destination registers, in output order.
    #[inline]
    pub fn destinations(&self) -> Vec<Register<N>> {
        self.destinations.to_vec()
    }
}
impl<N: Network> Call<N> {
/// Returns `true` if the call operator names a function of the targeted program, and
/// `false` otherwise (a closure, or a name that program does not define as a function).
///
/// A `Locator` is resolved against the external program fetched from `stack`; a
/// `Resource` is resolved against `stack.program()`.
///
/// # Errors
/// Returns an error if the external program named by a `Locator` cannot be retrieved.
#[inline]
pub fn is_function_call(&self, stack: &impl StackProgram<N>) -> Result<bool> {
match self.operator() {
CallOperator::Locator(locator) => {
let program = stack.get_external_program(locator.program_id())?;
Ok(program.contains_function(locator.resource()))
}
CallOperator::Resource(resource) => Ok(stack.program().contains_function(resource)),
}
}
/// Always fails: a `call` cannot be evaluated by this instruction directly.
/// Per the error message, it must be handled by the `Stack`'s own `call` logic instead.
pub fn evaluate(&self, _stack: &impl StackProgram<N>, _registers: &mut impl RegistersLoad<N>) -> Result<()> {
bail!("Forbidden operation: Evaluate cannot invoke a 'call' directly. Use 'call' in 'Stack' instead.")
}
/// Always fails: a `call` cannot be executed (in circuit form) by this instruction directly.
/// Per the error message, it must be handled by the `Stack` instead.
pub fn execute<A: circuit::Aleo<Network = N>>(
&self,
_stack: &impl StackProgram<N>,
_registers: &mut impl RegistersLoadCircuit<N, A>,
) -> Result<()> {
bail!("Forbidden operation: Execute cannot invoke a 'call' directly. Use 'call' in 'Stack' instead.")
}
/// Always fails: a `call` cannot be finalized by this instruction directly.
/// Per the error message, it must be handled by the `Stack` instead.
#[inline]
pub fn finalize(
&self,
_stack: &(impl StackMatches<N> + StackProgram<N>),
_registers: &mut impl RegistersLoad<N>,
) -> Result<()> {
bail!("Forbidden operation: Finalize cannot invoke a 'call' directly. Use 'call' in 'Stack' instead.")
}
/// Returns the register types produced by this call, given the types of its inputs.
///
/// First the callee is resolved:
/// - a `Locator` targets an external program fetched from `stack`;
/// - a `Resource` targets `stack.program()` and must NOT name a function of that program,
///   so only closures may be called locally.
///
/// The name is then looked up as a closure first and, failing that, as a function.
/// Only the counts of operands, input types, and destinations are checked here against the
/// callee's declared inputs/outputs; the input types themselves are not compared.
///
/// # Errors
/// Returns an error if:
/// - the external program cannot be retrieved;
/// - a local `Resource` names a function;
/// - the operand, input-type, or destination counts mismatch the callee;
/// - an external record locator fails to parse;
/// - the name is neither a closure nor a function of the resolved program.
#[inline]
pub fn output_types(
&self,
stack: &impl StackProgram<N>,
input_types: &[RegisterType<N>],
) -> Result<Vec<RegisterType<N>>> {
// Resolve the program that defines the callee, the callee's name within it, and whether it is external.
let (is_external, program, resource) = match &self.operator {
CallOperator::Locator(locator) => {
(true, stack.get_external_program(locator.program_id())?, locator.resource())
}
CallOperator::Resource(resource) => {
// Calling a function of the same program is rejected here; the error message directs the author to a closure.
if stack.program().contains_function(resource) {
bail!("Cannot call '{resource}'. Use a closure ('closure {resource}:') instead.")
}
(false, stack.program(), resource)
}
};
// Closure callee: validate arity, then return the declared output register types as-is.
if let Ok(closure) = program.get_closure(resource) {
if closure.inputs().len() != self.operands.len() {
bail!("Expected {} inputs, found {}", closure.inputs().len(), self.operands.len())
}
if closure.inputs().len() != input_types.len() {
bail!("Expected {} input types, found {}", closure.inputs().len(), input_types.len())
}
if closure.outputs().len() != self.destinations.len() {
bail!("Expected {} outputs, found {}", closure.outputs().len(), self.destinations.len())
}
Ok(closure.outputs().iter().map(|output| output.register_type()).cloned().collect())
}
// Function callee (reachable locally only for non-function names, so in practice external): validate arity, then convert value types.
else if let Ok(function) = program.get_function(resource) {
if function.inputs().len() != self.operands.len() {
bail!("Expected {} inputs, found {}", function.inputs().len(), self.operands.len())
}
if function.inputs().len() != input_types.len() {
bail!("Expected {} input types, found {}", function.inputs().len(), input_types.len())
}
if function.outputs().len() != self.destinations.len() {
bail!("Expected {} outputs, found {}", function.outputs().len(), self.destinations.len())
}
function
.output_types()
.into_iter()
.map(|output_type| match (is_external, output_type) {
// A record produced by an external program is typed as an external record `<program id>/<record name>`.
(true, ValueType::Record(record_name)) => Ok(RegisterType::ExternalRecord(Locator::from_str(
&format!("{}/{}", program.id(), record_name),
)?)),
// Every other output maps directly onto a register type.
(_, output_type) => Ok(RegisterType::from(output_type)),
})
.collect::<Result<Vec<_>>>()
}
// The resolved program defines neither a closure nor a function with this name.
else {
bail!("Call operator '{}' is invalid or unsupported.", self.operator)
}
}
}
impl<N: Network> Parser for Call<N> {
#[inline]
fn parse(string: &str) -> ParserResult<Self> {
fn parse_operand<N: Network>(string: &str) -> ParserResult<Operand<N>> {
let (string, _) = Sanitizer::parse_whitespaces(string)?;
Operand::parse(string)
}
fn parse_destination<N: Network>(string: &str) -> ParserResult<Register<N>> {
let (string, _) = Sanitizer::parse_whitespaces(string)?;
Register::parse(string)
}
let (string, _) = tag(*Self::opcode())(string)?;
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, operator) = CallOperator::parse(string)?;
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, operands) = map_res(many0(complete(parse_operand)), |operands: Vec<Operand<N>>| {
match operands.len() <= N::MAX_OPERANDS {
true => Ok(operands),
false => Err(error("Failed to parse 'call' opcode: too many operands")),
}
})(string)?;
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, destinations) = match opt(tag("into"))(string)? {
(string, None) => (string, vec![]),
(string, Some(_)) => {
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, destinations) =
map_res(many1(complete(parse_destination)), |destinations: Vec<Register<N>>| {
match destinations.len() <= N::MAX_OPERANDS {
true => Ok(destinations),
false => Err(error("Failed to parse 'call' opcode: too many destinations")),
}
})(string)?;
(string, destinations)
}
};
Ok((string, Self { operator, operands, destinations }))
}
}
impl<N: Network> FromStr for Call<N> {
    type Err = Error;

    /// Parses a `call` instruction from `string`.
    ///
    /// Fails if the parser rejects the input or leaves any trailing characters behind.
    #[inline]
    fn from_str(string: &str) -> Result<Self> {
        let (remainder, call) = match Self::parse(string) {
            Ok(parsed) => parsed,
            Err(error) => bail!("Failed to parse string. {error}"),
        };
        // The whole input must have been consumed by the parser.
        ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
        Ok(call)
    }
}
impl<N: Network> Debug for Call<N> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
Display::fmt(self, f)
}
}
impl<N: Network> Display for Call<N> {
    /// Prints `call <operator> [operand ...] [into <destination> ...]`.
    ///
    /// Returns `fmt::Error` if either list exceeds `N::MAX_OPERANDS`, matching the limits
    /// enforced by the parser.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let over_limit = self.operands.len() > N::MAX_OPERANDS || self.destinations.len() > N::MAX_OPERANDS;
        if over_limit {
            return Err(fmt::Error);
        }

        write!(f, "{} {}", Self::opcode(), self.operator)?;
        for operand in &self.operands {
            write!(f, " {operand}")?;
        }

        // The `into` clause is emitted only when there is somewhere to store outputs.
        if !self.destinations.is_empty() {
            f.write_str(" into")?;
            for destination in &self.destinations {
                write!(f, " {destination}")?;
            }
        }
        Ok(())
    }
}
impl<N: Network> FromBytes for Call<N> {
    /// Reads the operator, a `u8` operand count plus the operands, then a `u8` destination
    /// count plus the destinations.
    ///
    /// Either count exceeding `N::MAX_OPERANDS` is rejected.
    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
        let operator = CallOperator::read_le(&mut reader)?;

        let operand_count = usize::from(u8::read_le(&mut reader)?);
        if operand_count > N::MAX_OPERANDS {
            return Err(error(format!("The number of operands must be <= {}", N::MAX_OPERANDS)));
        }
        let operands =
            (0..operand_count).map(|_| Operand::read_le(&mut reader)).collect::<IoResult<Vec<_>>>()?;

        let destination_count = usize::from(u8::read_le(&mut reader)?);
        if destination_count > N::MAX_OPERANDS {
            return Err(error(format!("The number of destinations must be <= {}", N::MAX_OPERANDS)));
        }
        let destinations =
            (0..destination_count).map(|_| Register::read_le(&mut reader)).collect::<IoResult<Vec<_>>>()?;

        Ok(Self { operator, operands, destinations })
    }
}
impl<N: Network> ToBytes for Call<N> {
fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
if self.operands.len() > N::MAX_OPERANDS {
return Err(error(format!("The number of operands must be <= {}", N::MAX_OPERANDS)));
}
if self.destinations.len() > N::MAX_OPERANDS {
return Err(error(format!("The number of destinations must be <= {}", N::MAX_OPERANDS)));
}
self.operator.write_le(&mut writer)?;
u8::try_from(self.operands.len()).map_err(|e| error(e.to_string()))?.write_le(&mut writer)?;
self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))?;
u8::try_from(self.destinations.len()).map_err(|e| error(e.to_string()))?.write_le(&mut writer)?;
self.destinations.iter().try_for_each(|destination| destination.write_le(&mut writer))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use console::{
        network::MainnetV0,
        program::{Access, Address, Identifier, Literal, U64},
    };

    type CurrentNetwork = MainnetV0;

    /// Canonical `call` instructions.
    ///
    /// Each must round-trip unchanged through both the string form and the byte form.
    const TEST_CASES: &[&str] = &[
        "call foo",
        "call foo r0",
        "call foo r0.owner",
        "call foo r0 r1",
        "call foo into r0",
        "call foo into r0 r1",
        "call foo into r0 r1 r2",
        "call foo r0 into r1",
        "call foo r0 r1 into r2",
        "call foo r0 r1 into r2 r3",
        "call foo r0 r1 r2 into r3 r4",
        "call foo r0 r1 r2 into r3 r4 r5",
    ];

    /// Parses `input` as a `call` instruction and checks the result.
    ///
    /// Asserts that the entire input was consumed and that the operator, operands, and
    /// destinations equal the expected values.
    fn check_parser(
        input: &str,
        expected_operator: CallOperator<CurrentNetwork>,
        expected_operands: Vec<Operand<CurrentNetwork>>,
        expected_destinations: Vec<Register<CurrentNetwork>>,
    ) {
        let (remainder, call) = Call::<CurrentNetwork>::parse(input).unwrap();
        assert!(remainder.is_empty(), "Parser did not consume all of the string: '{remainder}'");
        assert_eq!(call.operator, expected_operator, "The call operator is incorrect");

        assert_eq!(call.operands.len(), expected_operands.len(), "The number of operands is incorrect");
        call.operands.iter().zip(expected_operands.iter()).enumerate().for_each(|(i, (given, expected))| {
            assert_eq!(given, expected, "The {i}-th operand is incorrect");
        });

        assert_eq!(call.destinations.len(), expected_destinations.len(), "The number of destinations is incorrect");
        call.destinations.iter().zip(expected_destinations.iter()).enumerate().for_each(|(i, (given, expected))| {
            assert_eq!(given, expected, "The {i}-th destination is incorrect");
        });
    }

    #[test]
    fn test_parse() {
        // Register-access operands with several destinations.
        check_parser(
            "call transfer r0.owner r0.token_amount into r1 r2 r3",
            CallOperator::from_str("transfer").unwrap(),
            vec![
                Operand::Register(Register::Access(0, vec![Access::from(Identifier::from_str("owner").unwrap())])),
                Operand::Register(Register::Access(0, vec![Access::from(
                    Identifier::from_str("token_amount").unwrap(),
                )])),
            ],
            vec![Register::Locator(1), Register::Locator(2), Register::Locator(3)],
        );

        // Literal operands with no destinations.
        check_parser(
            "call mint_public aleo1wfyyj2uvwuqw0c0dqa5x70wrawnlkkvuepn4y08xyaqfqqwweqys39jayw 100u64",
            CallOperator::from_str("mint_public").unwrap(),
            vec![
                Operand::Literal(Literal::Address(
                    Address::from_str("aleo1wfyyj2uvwuqw0c0dqa5x70wrawnlkkvuepn4y08xyaqfqqwweqys39jayw").unwrap(),
                )),
                Operand::Literal(Literal::U64(U64::from_str("100u64").unwrap())),
            ],
            vec![],
        );

        // No operands, one destination.
        check_parser(
            "call get_magic_number into r0",
            CallOperator::from_str("get_magic_number").unwrap(),
            vec![],
            vec![Register::Locator(0)],
        );

        // Neither operands nor destinations.
        check_parser("call noop", CallOperator::from_str("noop").unwrap(), vec![], vec![]);
    }

    #[test]
    fn test_display() {
        TEST_CASES.iter().for_each(|&expected| {
            let call = Call::<CurrentNetwork>::from_str(expected).unwrap();
            assert_eq!(call.to_string(), expected);
        });
    }

    #[test]
    fn test_bytes() {
        for &case in TEST_CASES {
            let call = Call::<CurrentNetwork>::from_str(case).unwrap();
            let bytes = call.to_bytes_le().unwrap();
            let decoded = Call::<CurrentNetwork>::read_le(&bytes[..]).unwrap();
            assert_eq!(call, decoded);
        }
    }
}