use crate::{context::Context, pretty::DebugWithContext, Constant, ConstantValue, Value};
/// A lightweight, copyable handle to an interned [`TypeContent`].
///
/// The wrapped index refers to an entry in a `Context`'s `types` arena; use
/// [`Type::get_content`] to see what the handle describes.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, DebugWithContext)]
pub struct Type(pub generational_arena::Index);
/// The structural description of an IR type.
///
/// Values of this enum are interned: each distinct `TypeContent` is stored once in a
/// `Context` and referred to through a [`Type`] handle.
#[derive(Clone, Debug, Eq, PartialEq, Hash, DebugWithContext)]
pub enum TypeContent {
    /// The unit type, printed as `()`.
    Unit,
    /// The boolean type.
    Bool,
    /// An unsigned integer; the payload is its width in bits (printed as `u{width}`).
    Uint(u16),
    /// The 256-bit `b256` type.
    B256,
    /// A string slice, printed as `str`.
    StringSlice,
    /// A fixed-length string; the payload is its length in bytes.
    StringArray(u64),
    /// An array of the given element type and element count.
    Array(Type, u64),
    /// A union whose payload lists the variant types.
    Union(Vec<Type>),
    /// A struct whose payload lists the field types in order.
    Struct(Vec<Type>),
    /// A slice.
    Slice,
    /// A pointer to the given pointee type.
    Pointer(Type),
}
impl Type {
    /// Returns the handle for `t`, interning it in `context` the first time it is seen.
    ///
    /// Repeated requests for structurally identical content return the same handle.
    fn get_or_create_unique_type(context: &mut Context, t: TypeContent) -> Type {
        // Fast path: already interned, so a single map lookup suffices.
        if let Some(&existing) = context.type_map.get(&t) {
            return existing;
        }
        let new_type = Type(context.types.insert(t.clone()));
        context.type_map.insert(t, new_type);
        new_type
    }

    /// Looks up an already-interned type by content, without creating it.
    pub fn get_type(context: &Context, t: &TypeContent) -> Option<Type> {
        context.type_map.get(t).copied()
    }

    /// Interns the basic types that the non-creating accessors of this impl expect to exist:
    /// unit, bool, `u8`, `u64`, `u256`, `b256` and slice.
    pub fn create_basic_types(context: &mut Context) {
        Self::get_or_create_unique_type(context, TypeContent::Unit);
        Self::get_or_create_unique_type(context, TypeContent::Bool);
        Self::get_or_create_unique_type(context, TypeContent::Uint(8));
        Self::get_or_create_unique_type(context, TypeContent::Uint(64));
        Self::get_or_create_unique_type(context, TypeContent::Uint(256));
        Self::get_or_create_unique_type(context, TypeContent::B256);
        Self::get_or_create_unique_type(context, TypeContent::Slice);
    }

    /// Returns the content this handle refers to.
    ///
    /// This indexes `context.types` directly, so the handle must belong to `context`.
    pub fn get_content<'a>(&self, context: &'a Context) -> &'a TypeContent {
        &context.types[self.0]
    }

    /// Returns the unit type.
    ///
    /// # Panics
    /// If the unit type has not been interned (see `create_basic_types`).
    pub fn get_unit(context: &Context) -> Type {
        Self::get_type(context, &TypeContent::Unit).expect("create_basic_types not called")
    }

    /// Returns the bool type.
    ///
    /// # Panics
    /// If the bool type has not been interned (see `create_basic_types`).
    pub fn get_bool(context: &Context) -> Type {
        Self::get_type(context, &TypeContent::Bool).expect("create_basic_types not called")
    }

    /// Returns the unsigned integer type of `width` bits, interning it if needed.
    pub fn new_uint(context: &mut Context, width: u16) -> Type {
        Self::get_or_create_unique_type(context, TypeContent::Uint(width))
    }

    /// Returns the `u8` type.
    ///
    /// # Panics
    /// If `u8` has not been interned (see `create_basic_types`).
    pub fn get_uint8(context: &Context) -> Type {
        Self::get_type(context, &TypeContent::Uint(8)).expect("create_basic_types not called")
    }

    /// Returns the `u64` type.
    ///
    /// # Panics
    /// If `u64` has not been interned (see `create_basic_types`).
    pub fn get_uint64(context: &Context) -> Type {
        Self::get_type(context, &TypeContent::Uint(64)).expect("create_basic_types not called")
    }

    /// Returns the `u256` type.
    ///
    /// # Panics
    /// If `u256` has not been interned (see `create_basic_types`).
    pub fn get_uint256(context: &Context) -> Type {
        Self::get_type(context, &TypeContent::Uint(256)).expect("create_basic_types not called")
    }

    /// Returns the unsigned integer type of `width` bits if it is already interned.
    pub fn get_uint(context: &Context, width: u16) -> Option<Type> {
        Self::get_type(context, &TypeContent::Uint(width))
    }

    /// Returns the `b256` type.
    ///
    /// # Panics
    /// If `b256` has not been interned (see `create_basic_types`).
    pub fn get_b256(context: &Context) -> Type {
        Self::get_type(context, &TypeContent::B256).expect("create_basic_types not called")
    }

    /// Returns the fixed-length string type of `len` bytes, interning it if needed.
    pub fn new_string_array(context: &mut Context, len: u64) -> Type {
        Self::get_or_create_unique_type(context, TypeContent::StringArray(len))
    }

    /// Returns the array type `[elm_ty; len]`, interning it if needed.
    pub fn new_array(context: &mut Context, elm_ty: Type, len: u64) -> Type {
        Self::get_or_create_unique_type(context, TypeContent::Array(elm_ty, len))
    }

    /// Returns the union type with the given variant types, interning it if needed.
    pub fn new_union(context: &mut Context, fields: Vec<Type>) -> Type {
        Self::get_or_create_unique_type(context, TypeContent::Union(fields))
    }

    /// Returns the struct type with the given field types, interning it if needed.
    pub fn new_struct(context: &mut Context, fields: Vec<Type>) -> Type {
        Self::get_or_create_unique_type(context, TypeContent::Struct(fields))
    }

    /// Returns the pointer-to-`to_ty` type, interning it if needed.
    pub fn new_ptr(context: &mut Context, to_ty: Type) -> Type {
        Self::get_or_create_unique_type(context, TypeContent::Pointer(to_ty))
    }

    /// Returns the slice type.
    ///
    /// This is a pure lookup, so it only needs a shared borrow, matching the other `get_*`
    /// accessors. Callers holding `&mut Context` still work through reborrow coercion.
    ///
    /// # Panics
    /// If the slice type has not been interned (see `create_basic_types`).
    pub fn get_slice(context: &Context) -> Type {
        Self::get_type(context, &TypeContent::Slice).expect("create_basic_types not called")
    }

    /// Renders this type in IR textual syntax (e.g. `u64`, `[u8; 4]`, `{ u64, bool }`).
    pub fn as_string(&self, context: &Context) -> String {
        let sep_types_str = |agg_content: &Vec<Type>, sep: &str| {
            agg_content
                .iter()
                .map(|ty| ty.as_string(context))
                .collect::<Vec<_>>()
                .join(sep)
        };

        match self.get_content(context) {
            TypeContent::Unit => "()".into(),
            TypeContent::Bool => "bool".into(),
            TypeContent::Uint(nbits) => format!("u{nbits}"),
            TypeContent::B256 => "b256".into(),
            TypeContent::StringSlice => "str".into(),
            TypeContent::StringArray(n) => format!("string<{n}>"),
            TypeContent::Array(ty, cnt) => {
                format!("[{}; {}]", ty.as_string(context), cnt)
            }
            TypeContent::Union(agg) => {
                format!("( {} )", sep_types_str(agg, " | "))
            }
            TypeContent::Struct(agg) => {
                format!("{{ {} }}", sep_types_str(agg, ", "))
            }
            TypeContent::Slice => "slice".into(),
            TypeContent::Pointer(ty) => format!("ptr {}", ty.as_string(context)),
        }
    }

    /// Structural type equality.
    ///
    /// Aggregates compare element-wise. As a special case, a union compares equal to any of its
    /// variant types, in either argument position.
    pub fn eq(&self, context: &Context, other: &Type) -> bool {
        match (self.get_content(context), other.get_content(context)) {
            (TypeContent::Unit, TypeContent::Unit) => true,
            (TypeContent::Bool, TypeContent::Bool) => true,
            (TypeContent::Uint(l), TypeContent::Uint(r)) => l == r,
            (TypeContent::B256, TypeContent::B256) => true,
            (TypeContent::StringSlice, TypeContent::StringSlice) => true,
            (TypeContent::StringArray(l), TypeContent::StringArray(r)) => l == r,
            (TypeContent::Array(l, llen), TypeContent::Array(r, rlen)) => {
                llen == rlen && l.eq(context, r)
            }
            (TypeContent::Struct(l), TypeContent::Struct(r))
            | (TypeContent::Union(l), TypeContent::Union(r)) => {
                l.len() == r.len() && l.iter().zip(r.iter()).all(|(l, r)| l.eq(context, r))
            }
            // Normalise "non-union vs union" into "union vs non-union" handled below.
            (_, TypeContent::Union(_)) => other.eq(context, self),
            (TypeContent::Union(l), _) => l.iter().any(|field_ty| other.eq(context, field_ty)),
            (TypeContent::Slice, TypeContent::Slice) => true,
            (TypeContent::Pointer(l), TypeContent::Pointer(r)) => l.eq(context, r),
            _ => false,
        }
    }

    /// Is this the bool type?
    pub fn is_bool(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::Bool)
    }

    /// Is this the unit type?
    pub fn is_unit(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::Unit)
    }

    /// Is this an unsigned integer of any width?
    pub fn is_uint(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::Uint(_))
    }

    /// Is this `u8`?
    pub fn is_uint8(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::Uint(8))
    }

    /// Is this `u32`?
    pub fn is_uint32(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::Uint(32))
    }

    /// Is this `u64`?
    pub fn is_uint64(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::Uint(64))
    }

    /// Is this an unsigned integer of exactly `width` bits?
    pub fn is_uint_of(&self, context: &Context, width: u16) -> bool {
        matches!(*self.get_content(context), TypeContent::Uint(width_) if width == width_)
    }

    /// Is this `b256`?
    pub fn is_b256(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::B256)
    }

    /// Is this a string slice?
    pub fn is_string_slice(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::StringSlice)
    }

    /// Is this a fixed-length string?
    pub fn is_string_array(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::StringArray(_))
    }

    /// Is this an array?
    pub fn is_array(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::Array(..))
    }

    /// Is this a union?
    pub fn is_union(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::Union(_))
    }

    /// Is this a struct?
    pub fn is_struct(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::Struct(_))
    }

    /// Is this a struct, union or array? (String arrays are not counted.)
    pub fn is_aggregate(&self, context: &Context) -> bool {
        self.is_struct(context) || self.is_union(context) || self.is_array(context)
    }

    /// Is this a slice?
    pub fn is_slice(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::Slice)
    }

    /// Is this a pointer?
    pub fn is_ptr(&self, context: &Context) -> bool {
        matches!(*self.get_content(context), TypeContent::Pointer(_))
    }

    /// Returns the pointee type if this is a pointer.
    pub fn get_pointee_type(&self, context: &Context) -> Option<Type> {
        if let TypeContent::Pointer(to_ty) = self.get_content(context) {
            Some(*to_ty)
        } else {
            None
        }
    }

    /// Returns the bit width if this is an unsigned integer.
    pub fn get_uint_width(&self, context: &Context) -> Option<u16> {
        if let TypeContent::Uint(width) = self.get_content(context) {
            Some(*width)
        } else {
            None
        }
    }

    /// Follows `indices` through nested aggregates and returns the type reached.
    ///
    /// Struct and union indices select a field. Array indices must be less than the array
    /// length. Returns `None` for an empty index list, an out-of-range index, or an attempt to
    /// index into a non-aggregate.
    pub fn get_indexed_type(&self, context: &Context, indices: &[u64]) -> Option<Type> {
        if indices.is_empty() {
            return None;
        }

        indices.iter().try_fold(*self, |ty, idx| {
            ty.get_field_type(context, *idx)
                .or_else(|| match ty.get_content(context) {
                    TypeContent::Array(elem_ty, len) if idx < len => Some(*elem_ty),
                    _ => None,
                })
        })
    }

    /// Computes the byte offset reached by following `indices` from the start of this type.
    ///
    /// Struct fields are laid out back to back with no padding. Union variants all start at the
    /// union's own offset. Array element `i` starts at `i * size_of(element)`. An empty index
    /// list yields `Some(0)`. Returns `None` if a struct/union field index is out of range.
    ///
    /// NOTE: unlike [`Type::get_indexed_type`], array indices are not checked against the array
    /// length here.
    ///
    /// # Panics
    /// If an index is applied to a type that is not a struct, union or array.
    pub fn get_indexed_offset(&self, context: &Context, indices: &[u64]) -> Option<u64> {
        indices
            .iter()
            .try_fold((*self, 0), |(ty, accum_offset), idx| {
                if ty.is_struct(context) {
                    // Sum the sizes of all preceding fields.
                    let prev_idxs_offset = (0..*idx).try_fold(0, |accum, pre_idx| {
                        ty.get_field_type(context, pre_idx)
                            .map(|field_ty| field_ty.size_in_bytes(context) + accum)
                    })?;
                    ty.get_field_type(context, *idx)
                        .map(|field_ty| (field_ty, accum_offset + prev_idxs_offset))
                } else if ty.is_union(context) {
                    ty.get_field_type(context, *idx)
                        .map(|field_ty| (field_ty, accum_offset))
                } else {
                    assert!(
                        ty.is_array(context),
                        "Expected aggregate type when indexing using GEP. Got {}",
                        ty.as_string(context)
                    );
                    // Reuse the element type already fetched instead of looking it up twice.
                    ty.get_array_elem_type(context).map(|elm_ty| {
                        (elm_ty, accum_offset + elm_ty.size_in_bytes(context) * idx)
                    })
                }
            })
            .map(|(_, offset)| offset)
    }

    /// Like [`Type::get_indexed_offset`], but with indices given as IR values.
    ///
    /// Returns `None` unless every index is a constant unsigned integer.
    pub fn get_value_indexed_offset(&self, context: &Context, indices: &[Value]) -> Option<u64> {
        // Collecting into `Option<Vec<_>>` bails out on the first non-constant index.
        let const_indices = indices
            .iter()
            .map(|index_val| match index_val.get_constant(context) {
                Some(Constant {
                    value: ConstantValue::Uint(n),
                    ..
                }) => Some(*n),
                _ => None,
            })
            .collect::<Option<Vec<u64>>>()?;
        self.get_indexed_offset(context, &const_indices)
    }

    /// Returns field `idx` of a struct or union, or `None` if out of range or not a struct/union.
    pub fn get_field_type(&self, context: &Context, idx: u64) -> Option<Type> {
        if let TypeContent::Struct(fields) | TypeContent::Union(fields) = self.get_content(context)
        {
            fields.get(idx as usize).copied()
        } else {
            None
        }
    }

    /// Returns the element type if this is an array.
    pub fn get_array_elem_type(&self, context: &Context) -> Option<Type> {
        if let TypeContent::Array(ty, _) = *self.get_content(context) {
            Some(ty)
        } else {
            None
        }
    }

    /// Returns the element count if this is an array.
    pub fn get_array_len(&self, context: &Context) -> Option<u64> {
        if let TypeContent::Array(_, n) = *self.get_content(context) {
            Some(n)
        } else {
            None
        }
    }

    /// Returns the byte length if this is a fixed-length string.
    pub fn get_string_len(&self, context: &Context) -> Option<u64> {
        if let TypeContent::StringArray(n) = *self.get_content(context) {
            Some(n)
        } else {
            None
        }
    }

    /// Returns the field/variant types of a struct or union, or an empty vector otherwise.
    pub fn get_field_types(&self, context: &Context) -> Vec<Type> {
        match self.get_content(context) {
            TypeContent::Struct(fields) | TypeContent::Union(fields) => fields.clone(),
            _ => vec![],
        }
    }

    /// Returns the size of a value of this type, in bytes.
    ///
    /// Structs are the unpadded sum of their fields. Unions are the size of their largest
    /// variant (0 if empty). Fixed-length strings are rounded up to a whole number of 8-byte
    /// words.
    pub fn size_in_bytes(&self, context: &Context) -> u64 {
        match self.get_content(context) {
            TypeContent::Unit | TypeContent::Bool | TypeContent::Pointer(_) => 8,
            // Round up so a width that is not a multiple of 8 never reports 0 bytes;
            // byte-multiple widths (8, 64, 256, ...) are unaffected.
            TypeContent::Uint(bits) => (u64::from(*bits) + 7) / 8,
            TypeContent::Slice => 16,
            TypeContent::B256 => 32,
            TypeContent::StringSlice => 16,
            TypeContent::StringArray(n) => super::size_bytes_round_up_to_word_alignment!(*n),
            TypeContent::Array(el_ty, cnt) => cnt * el_ty.size_in_bytes(context),
            TypeContent::Struct(field_tys) => field_tys
                .iter()
                .map(|field_ty| field_ty.size_in_bytes(context))
                .sum(),
            TypeContent::Union(field_tys) => field_tys
                .iter()
                .map(|field_ty| field_ty.size_in_bytes(context))
                .max()
                .unwrap_or(0),
        }
    }
}
/// Rounds a byte count up to the next multiple of the 8-byte word size.
///
/// The argument expression is evaluated exactly once. The previous expansion evaluated it
/// twice, which duplicated any side effects and recomputed expensive expressions.
#[macro_export]
macro_rules! size_bytes_round_up_to_word_alignment {
    ($bytes_expr: expr) => {{
        let bytes = $bytes_expr;
        (bytes + 7) - ((bytes + 7) % 8)
    }};
}
/// Predicate helper for optional [`Type`]s.
pub trait TypeOption {
    /// Tests the contained type, if any, against `pred`.
    ///
    /// For the `Option<Type>` implementation in this module, the result is `true` exactly when
    /// the option is `Some(ty)` and `pred(ty, context)` returns `true`. `None` yields `false`.
    fn is(&self, pred: fn(&Type, &Context) -> bool, context: &Context) -> bool;
}
impl TypeOption for Option<Type> {
    /// Returns `true` only when a type is present and satisfies `pred`; `None` is never a match.
    fn is(&self, pred: fn(&Type, &Context) -> bool, context: &Context) -> bool {
        match self {
            Some(ty) => pred(ty, context),
            None => false,
        }
    }
}