use std::collections::HashSet;
use cairo_lang_defs::ids::{LanguageElementId, TraitFunctionId, TraitId};
use cairo_lang_diagnostics::{DiagnosticAdded, Maybe};
use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::items::enm::SemanticEnumEx;
use cairo_lang_semantic::items::structure::SemanticStructEx;
use cairo_lang_semantic::{ConcreteTypeId, GenericArgumentId, TypeId, TypeLongId};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::plugin::consts::{EVENT_ATTR, VIEW_ATTR};
#[cfg(test)]
#[path = "abi_test.rs"]
mod test;
/// ABI of a contract: the list of ABI items returned by `AbiBuilder::from_trait`.
///
/// Because of `#[serde(transparent)]` this type is serialized and deserialized exactly like its
/// single `items` field, without a wrapping object.
#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Contract {
/// The ABI items, in the order in which they were pushed.
pub items: Vec<Item>,
}
impl Contract {
    /// Renders this ABI as a pretty-printed JSON string.
    ///
    /// # Panics
    ///
    /// Panics if `serde_json` fails to serialize the ABI.
    pub fn json(&self) -> String {
        let rendered = serde_json::to_string_pretty(self);
        rendered.unwrap()
    }
}
/// Incrementally builds a [`Contract`] ABI.
pub struct AbiBuilder {
/// The ABI items collected so far.
abi: Contract,
/// Types that were already visited, so that each type is processed at most once.
types: HashSet<TypeId>,
}
impl AbiBuilder {
pub fn from_trait(db: &dyn SemanticGroup, trait_id: TraitId) -> Result<Contract, ABIError> {
if !db.trait_generic_params(trait_id).map_err(|_| ABIError::CompilationError)?.is_empty() {
return Err(ABIError::GenericTraitsUnsupported);
}
let mut builder = Self { abi: Contract::default(), types: HashSet::new() };
for trait_function_id in db.trait_functions(trait_id).unwrap_or_default().values() {
if trait_function_has_attr(db, *trait_function_id, EVENT_ATTR)? {
builder.add_event(db, *trait_function_id)?;
} else {
builder.add_function(db, *trait_function_id)?;
}
}
Ok(builder.abi)
}
fn add_function(
&mut self,
db: &dyn SemanticGroup,
trait_function_id: TraitFunctionId,
) -> Result<(), ABIError> {
let state_mutability = if trait_function_has_attr(db, trait_function_id, VIEW_ATTR)? {
StateMutability::View
} else {
StateMutability::External
};
let defs_db = db.upcast();
let name = trait_function_id.name(defs_db).into();
let signature = db
.trait_function_signature(trait_function_id)
.map_err(|_| ABIError::CompilationError)?;
let mut inputs = vec![];
for param in signature.params.into_iter() {
self.add_type(db, param.ty)?;
inputs.push(Input { name: param.id.name(db.upcast()).into(), ty: param.ty.format(db) });
}
let outputs = if signature.return_type.is_unit(db) {
vec![]
} else {
self.add_type(db, signature.return_type)?;
vec![Output { ty: signature.return_type.format(db) }]
};
self.abi.items.push(Item::Function(Function { name, inputs, outputs, state_mutability }));
Ok(())
}
fn add_event(
&mut self,
db: &dyn SemanticGroup,
trait_function_id: TraitFunctionId,
) -> Result<(), ABIError> {
let defs_db = db.upcast();
let name = trait_function_id.name(defs_db).into();
let signature = db
.trait_function_signature(trait_function_id)
.map_err(|_| ABIError::CompilationError)?;
self.abi.items.push(Item::Event(Event {
name,
inputs: signature
.params
.into_iter()
.map(|param| Input {
name: param.id.name(db.upcast()).into(),
ty: param.ty.format(db),
})
.collect(),
}));
Ok(())
}
fn add_type(&mut self, db: &dyn SemanticGroup, type_id: TypeId) -> Result<(), ABIError> {
if !self.types.insert(type_id) {
return Ok(());
}
match db.lookup_intern_type(type_id) {
TypeLongId::Concrete(concrete) => self.add_concrete_type(db, concrete),
TypeLongId::Tuple(inner_types) => {
for ty in inner_types {
self.add_type(db, ty)?;
}
Ok(())
}
TypeLongId::Snapshot(ty) => self.add_type(db, ty),
TypeLongId::GenericParameter(_) | TypeLongId::Var(_) | TypeLongId::Missing(_) => {
Err(ABIError::UnexpectedType)
}
}
}
fn add_concrete_type(
&mut self,
db: &dyn SemanticGroup,
concrete: ConcreteTypeId,
) -> Result<(), ABIError> {
for generic_arg in concrete.generic_args(db) {
if let GenericArgumentId::Type(type_id) = generic_arg {
self.add_type(db, type_id)?;
}
}
if is_native_type(db, &concrete) {
return Ok(());
}
match concrete {
ConcreteTypeId::Struct(id) => self.abi.items.push(Item::Struct(Struct {
name: concrete.format(db),
members: get_struct_members(db, id).map_err(|_| ABIError::UnexpectedType)?,
})),
ConcreteTypeId::Enum(id) => self.abi.items.push(Item::Enum(Enum {
name: concrete.format(db),
variants: get_enum_variants(db, id).map_err(|_| ABIError::UnexpectedType)?,
})),
ConcreteTypeId::Extern(_) => {}
}
Ok(())
}
}
/// Collects the ABI description (name and formatted type) of every member of the concrete
/// struct `id`.
///
/// Fails if the members of the struct cannot be computed.
fn get_struct_members(
    db: &dyn SemanticGroup,
    id: cairo_lang_semantic::ConcreteStructId,
) -> Maybe<Vec<StructMember>> {
    let members = db.concrete_struct_members(id)?;
    let mut abi_members = Vec::new();
    for (member_name, member) in members.iter() {
        abi_members
            .push(StructMember { name: member_name.to_string(), ty: member.ty.format(db) });
    }
    Ok(abi_members)
}
/// Collects the ABI description (name and formatted type) of every variant of the concrete
/// enum `id`.
///
/// The variants are listed from the generic enum and each one is then resolved against the
/// concrete enum to obtain its type. Fails on the first query that cannot be computed.
fn get_enum_variants(
    db: &dyn SemanticGroup,
    id: cairo_lang_semantic::ConcreteEnumId,
) -> Maybe<Vec<EnumVariant>> {
    let generic_enum = id.enum_id(db);
    let variants = db.enum_variants(generic_enum)?;
    variants
        .iter()
        .map(|(variant_name, variant_id)| -> Result<EnumVariant, DiagnosticAdded> {
            let name = variant_name.to_string();
            let generic_variant = db.variant_semantic(generic_enum, *variant_id)?;
            let concrete_variant = db.concrete_enum_variant(id, &generic_variant)?;
            Ok(EnumVariant { name, ty: concrete_variant.ty.format(db) })
        })
        .collect()
}
/// Returns `true` if the generic type behind `concrete` is defined in the core crate.
fn is_native_type(db: &dyn SemanticGroup, concrete: &ConcreteTypeId) -> bool {
    let defs_db = db.upcast();
    let defining_module = concrete.generic_type(db).parent_module(defs_db);
    let defining_crate = defining_module.owning_crate(defs_db);
    defining_crate == db.core_crate()
}
/// Checks whether the trait function carries an attribute whose id is `attr`.
///
/// Returns [`ABIError::CompilationError`] if the attributes of the function cannot be computed.
fn trait_function_has_attr(
    db: &dyn SemanticGroup,
    trait_function_id: TraitFunctionId,
    attr: &str,
) -> Result<bool, ABIError> {
    let attributes = db
        .trait_function_attributes(trait_function_id)
        .map_err(|_| ABIError::CompilationError)?;
    for attribute in attributes.iter() {
        if attribute.id.to_string() == attr {
            return Ok(true);
        }
    }
    Ok(false)
}
/// Errors reported while building an ABI.
#[derive(Error, Debug)]
pub enum ABIError {
/// The trait declares generic parameters.
#[error("Generic traits are unsupported.")]
GenericTraitsUnsupported,
/// A semantic query (e.g. for a signature or for attributes) failed.
#[error("Compilation error.")]
CompilationError,
/// A type that cannot be described in the ABI was encountered, or the members/variants of a
/// type could not be computed.
#[error("Got unexpected type.")]
UnexpectedType,
}
/// A single entry of a contract ABI.
///
/// Serialized as an internally tagged enum: the variant is identified by the `type` field, whose
/// value is the lowercase name given by the `rename` attribute of each variant.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Item {
/// A function of the contract.
#[serde(rename = "function")]
Function(Function),
/// An event of the contract.
#[serde(rename = "event")]
Event(Event),
/// A struct type used by the contract.
#[serde(rename = "struct")]
Struct(Struct),
/// An enum type used by the contract.
#[serde(rename = "enum")]
Enum(Enum),
}
/// The state mutability of an ABI function.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum StateMutability {
/// Default for functions that do not carry the `VIEW_ATTR` attribute; serialized as `external`.
#[serde(rename = "external")]
External,
/// Used for functions that carry the `VIEW_ATTR` attribute; serialized as `view`.
#[serde(rename = "view")]
View,
}
/// A function entry of the ABI.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Function {
/// Name of the function.
pub name: String,
/// The function parameters, in signature order.
pub inputs: Vec<Input>,
/// The function outputs; empty when the return type is the unit type, otherwise a single
/// entry holding the return type.
pub outputs: Vec<Output>,
/// Whether the function is a view or an external function.
pub state_mutability: StateMutability,
}
/// An event entry of the ABI.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Event {
/// Name of the event.
pub name: String,
/// The event parameters, in signature order.
pub inputs: Vec<Input>,
}
/// A named parameter of a function or an event.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Input {
/// Name of the parameter.
pub name: String,
/// Formatted type of the parameter; serialized under the key `type`.
#[serde(rename = "type")]
pub ty: String,
}
/// An output of a function.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Output {
/// Formatted type of the output; serialized under the key `type`.
#[serde(rename = "type")]
pub ty: String,
}
/// A struct type entry of the ABI.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Struct {
/// Formatted name of the concrete struct type.
pub name: String,
/// The members of the struct.
pub members: Vec<StructMember>,
}
/// A member of a [`Struct`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct StructMember {
/// Name of the member.
pub name: String,
/// Formatted type of the member; serialized under the key `type`.
#[serde(rename = "type")]
pub ty: String,
}
/// An enum type entry of the ABI.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Enum {
/// Formatted name of the concrete enum type.
pub name: String,
/// The variants of the enum.
pub variants: Vec<EnumVariant>,
}
/// A variant of an [`Enum`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct EnumVariant {
/// Name of the variant.
pub name: String,
/// Formatted type of the variant; serialized under the key `type`.
#[serde(rename = "type")]
pub ty: String,
}