use pliron::{
arg_err_noloc,
basic_block::BasicBlock,
builtin::{
attr_interfaces::TypedAttrInterface,
attributes::TypeAttr,
op_interfaces::{
BranchOpInterface, IsTerminatorInterface, OneOpdInterface, OneResultInterface,
SameOperandsAndResultType, SameOperandsType, SameResultsType, ZeroResultInterface,
ZeroResultVerifyErr,
},
types::{IntegerType, Signedness},
},
common_traits::Verify,
context::{Context, Ptr},
dialect::Dialect,
error::{Error, ErrorKind, Result},
identifier::Identifier,
impl_canonical_syntax, impl_op_interface, impl_verify_succ, input_err,
irfmt::parsers::ssa_opd_parser,
location::{Located, Location},
op::{Op, OpObj},
operation::Operation,
parsable::{self, Parsable, ParseResult},
printable::{self, Printable},
r#type::{TypeObj, TypePtr},
use_def_lists::Value,
vec_exns::VecExtns,
verify_err,
};
use crate::{
op_interfaces::{BinArithOp, IntBinArithOp, IntBinArithOpWithOverflowFlag, PointerTypeResult},
types::{ArrayType, StructType},
};
use combine::parser::Parser;
use pliron_derive::def_op;
use thiserror::Error;
use super::{
attributes::{GepIndexAttr, GepIndicesAttr, ICmpPredicateAttr},
types::PointerType,
};
/// Terminator op `llvm.return` carrying exactly one operand: the returned value.
///
/// It produces no results; [ReturnOp::new] creates it with an empty result list.
#[def_op("llvm.return")]
pub struct ReturnOp {}
impl ReturnOp {
    /// Create a new [ReturnOp] whose single operand is `value`.
    pub fn new(ctx: &mut Context, value: Value) -> Self {
        let operands = vec![value];
        Self {
            op: Operation::new(ctx, Self::get_opid_static(), vec![], operands, vec![], 0),
        }
    }
}
/// Prints the op as `<opid> <operand 0>`.
///
/// # Panics
/// Panics if the underlying operation has no operand at index 0.
impl Printable for ReturnOp {
    fn fmt(
        &self,
        ctx: &Context,
        _state: &printable::State,
        f: &mut core::fmt::Formatter<'_>,
    ) -> core::fmt::Result {
        let opid = self.get_opid();
        let returned = self.get_operation().deref(ctx).get_operand(0).unwrap();
        write!(f, "{} {}", opid.disp(ctx), returned.disp(ctx))
    }
}
impl_verify_succ!(ReturnOp);
/// Parses the single SSA operand of a [ReturnOp], mirroring its `Printable` form.
///
/// `results` holds any result names written for this op. [ReturnOp] has no
/// results, so a non-empty list is reported as a [ZeroResultVerifyErr] at the
/// current stream location.
impl Parsable for ReturnOp {
    type Arg = Vec<(Identifier, Location)>;
    type Parsed = OpObj;
    fn parse<'a>(
        state_stream: &mut parsable::StateStream<'a>,
        results: Self::Arg,
    ) -> ParseResult<'a, Self::Parsed> {
        let has_result_names = !results.is_empty();
        if has_result_names {
            input_err!(
                state_stream.loc(),
                ZeroResultVerifyErr(Self::get_opid_static().to_string())
            )?
        }
        ssa_opd_parser()
            .parse_stream(state_stream)
            .map(|returned_value| -> OpObj {
                Box::new(Self::new(state_stream.state.ctx, returned_value))
            })
            .into()
    }
}
impl_op_interface!(IsTerminatorInterface for ReturnOp {});
/// Defines an integer binary-arithmetic op type.
///
/// * `$op_name`: name of the generated struct.
/// * `$op_id`: opid string handed to `#[def_op]`.
/// * Any leading attributes, including `///` doc comments, are forwarded onto the struct.
///
/// The generated op uses the canonical syntax and gets its op-specific `Verify`
/// from `impl_verify_succ!`. It implements `OneResultInterface`,
/// `SameOperandsType`, `SameResultsType`, `SameOperandsAndResultType`,
/// `BinArithOp` and `IntBinArithOp`.
macro_rules! new_int_bin_op {
    ( $(#[$outer:meta])*
        $op_name:ident, $op_id:literal
    ) => {
        #[def_op($op_id)]
        $(#[$outer])*
        pub struct $op_name {}
        impl_verify_succ!($op_name);
        impl_canonical_syntax!($op_name);
        impl_op_interface!(OneResultInterface for $op_name {});
        impl_op_interface!(SameOperandsType for $op_name {});
        impl_op_interface!(SameResultsType for $op_name {});
        impl_op_interface!(SameOperandsAndResultType for $op_name {});
        impl_op_interface!(BinArithOp for $op_name {});
        impl_op_interface!(IntBinArithOp for $op_name {});
    }
}
/// Same as `new_int_bin_op!`, but the generated op additionally implements
/// `IntBinArithOpWithOverflowFlag`.
///
/// Leading attributes (e.g. doc comments) are forwarded unchanged to `new_int_bin_op!`.
macro_rules! new_int_bin_op_with_overflow {
    ( $(#[$outer:meta])*
        $op_name:ident, $op_id:literal
    ) => {
        new_int_bin_op!(
            $(#[$outer])*
            $op_name,
            $op_id
        );
        impl_op_interface!(IntBinArithOpWithOverflowFlag for $op_name {});
    }
}
new_int_bin_op_with_overflow!(
    /// Integer addition op `llvm.add` (mirrors LLVM's `add`); supports overflow flags.
    AddOp,
    "llvm.add"
);
new_int_bin_op_with_overflow!(
    /// Integer subtraction op `llvm.sub` (mirrors LLVM's `sub`); supports overflow flags.
    SubOp,
    "llvm.sub"
);
new_int_bin_op_with_overflow!(
    /// Integer multiplication op `llvm.mul` (mirrors LLVM's `mul`); supports overflow flags.
    MulOp,
    "llvm.mul"
);
new_int_bin_op_with_overflow!(
    /// Shift-left op `llvm.shl` (mirrors LLVM's `shl`); supports overflow flags.
    ShlOp,
    "llvm.shl"
);
new_int_bin_op!(
    /// Unsigned integer division op `llvm.udiv` (mirrors LLVM's `udiv`).
    UDivOp,
    "llvm.udiv"
);
new_int_bin_op!(
    /// Signed integer division op `llvm.sdiv` (mirrors LLVM's `sdiv`).
    SDivOp,
    "llvm.sdiv"
);
new_int_bin_op!(
    /// Unsigned integer remainder op `llvm.urem` (mirrors LLVM's `urem`).
    URemOp,
    "llvm.urem"
);
new_int_bin_op!(
    /// Signed integer remainder op `llvm.srem` (mirrors LLVM's `srem`).
    SRemOp,
    "llvm.srem"
);
new_int_bin_op!(
    /// Bitwise AND op `llvm.and` (mirrors LLVM's `and`).
    AndOp,
    "llvm.and"
);
new_int_bin_op!(
    /// Bitwise OR op `llvm.or` (mirrors LLVM's `or`).
    OrOp,
    "llvm.or"
);
new_int_bin_op!(
    /// Bitwise XOR op `llvm.xor` (mirrors LLVM's `xor`).
    XorOp,
    "llvm.xor"
);
new_int_bin_op!(
    /// Logical shift-right op `llvm.lshr` (mirrors LLVM's `lshr`).
    LShrOp,
    "llvm.lshr"
);
new_int_bin_op!(
    /// Arithmetic shift-right op `llvm.ashr` (mirrors LLVM's `ashr`).
    AShrOp,
    "llvm.ashr"
);
/// Verification errors reported by [ICmpOp].
#[derive(Error, Debug)]
pub enum ICmpOpVerifyErr {
    /// The result is an integer type whose width is not 1.
    #[error("Result must be 1-bit integer (bool)")]
    ResultNotBool,
    /// The (common) operand type is neither an [IntegerType] nor a [PointerType].
    #[error("Operand must be integer or pointer types")]
    IncorrectOperandsType,
    /// The predicate attribute is absent or is not an [ICmpPredicateAttr].
    #[error("Missing or incorrect predicate attribute")]
    PredAttrErr,
}
/// Comparison op `llvm.icmp` over integer or pointer operands.
///
/// It takes two operands of one common type (`SameOperandsType`) and produces a
/// single 1-bit signless integer result. The predicate is stored as an
/// [ICmpPredicateAttr] under [ICmpOp::ATTR_KEY_PREDICATE].
#[def_op("llvm.icmp")]
pub struct ICmpOp {}
impl ICmpOp {
    /// Attribute key under which the comparison predicate is stored.
    pub const ATTR_KEY_PREDICATE: &'static str = "llvm.icmp_predicate";

    /// Create a new [ICmpOp] comparing `lhs` with `rhs` under predicate `pred`.
    pub fn new(ctx: &mut Context, pred: ICmpPredicateAttr, lhs: Value, rhs: Value) -> Self {
        let i1_ty = IntegerType::get(ctx, 1, Signedness::Signless).into();
        let op = Operation::new(
            ctx,
            Self::get_opid_static(),
            vec![i1_ty],
            vec![lhs, rhs],
            vec![],
            0,
        );
        op.deref_mut(ctx)
            .attributes
            .set(Self::ATTR_KEY_PREDICATE, pred);
        Self { op }
    }
}
/// Verifies an [ICmpOp]. The checks run in this order:
/// 1. the predicate attribute exists with type [ICmpPredicateAttr];
/// 2. the result type is an [IntegerType] (otherwise the conversion error is
///    returned with this op's location) of width 1;
/// 3. the operand type is an [IntegerType] or a [PointerType].
impl Verify for ICmpOp {
    fn verify(&self, ctx: &Context) -> Result<()> {
        let op = &*self.op.deref(ctx);
        let loc = op.loc();

        let has_predicate = op
            .attributes
            .get::<ICmpPredicateAttr>(Self::ATTR_KEY_PREDICATE)
            .is_some();
        if !has_predicate {
            return verify_err!(loc, ICmpOpVerifyErr::PredAttrErr);
        }

        let res_ty = TypePtr::<IntegerType>::from_ptr(self.result_type(ctx), ctx).map_err(
            |mut err| {
                err.set_loc(loc.clone());
                err
            },
        )?;
        if res_ty.deref(ctx).get_width() != 1 {
            return verify_err!(loc, ICmpOpVerifyErr::ResultNotBool);
        }

        let opd_ty = self.operand_type(ctx).deref(ctx);
        let int_or_ptr = opd_ty.is::<IntegerType>() || opd_ty.is::<PointerType>();
        if !int_or_ptr {
            return verify_err!(loc, ICmpOpVerifyErr::IncorrectOperandsType);
        }
        Ok(())
    }
}
impl_canonical_syntax!(ICmpOp);
impl_op_interface!(SameOperandsType for ICmpOp {});
impl_op_interface!(OneResultInterface for ICmpOp {});
/// Verification errors reported by [AllocaOp].
#[derive(Error, Debug)]
pub enum AllocaOpVerifyErr {
    /// The operand type is not an [IntegerType].
    /// NOTE(review): the message says "signless", but `AllocaOp`'s verifier only
    /// checks for an integer type, not its signedness.
    #[error("Operand must be a signless integer")]
    OperandType,
    /// The element-type attribute is absent or is not a [TypeAttr].
    #[error("Missing or incorrect type of attribute for element type")]
    ElemTypeAttr,
}
/// Allocation op `llvm.alloca` (mirrors LLVM's `alloca`).
///
/// It has one integer operand (`size` in [AllocaOp::new]) and one [PointerType]
/// result. The allocated element type is stored as a [TypeAttr] under
/// [AllocaOp::ATTR_KEY_ELEM_TYPE].
#[def_op("llvm.alloca")]
pub struct AllocaOp {}
impl_canonical_syntax!(AllocaOp);
/// Checks that the single operand is an [IntegerType] and that the element-type
/// attribute is present as a [TypeAttr].
impl Verify for AllocaOp {
    fn verify(&self, ctx: &Context) -> Result<()> {
        let loc = self.get_operation().deref(ctx).loc();

        // NOTE(review): the error text mentions a *signless* integer, but only the
        // IntegerType kind is checked here; signedness is not enforced.
        let size_is_integer = self.operand_type(ctx).deref(ctx).is::<IntegerType>();
        if !size_is_integer {
            return verify_err!(loc, AllocaOpVerifyErr::OperandType);
        }

        let has_elem_type = self
            .op
            .deref(ctx)
            .attributes
            .get::<TypeAttr>(Self::ATTR_KEY_ELEM_TYPE)
            .is_some();
        if has_elem_type {
            Ok(())
        } else {
            verify_err!(loc, AllocaOpVerifyErr::ElemTypeAttr)
        }
    }
}
impl_op_interface!(OneResultInterface for AllocaOp {});
impl_op_interface!(OneOpdInterface for AllocaOp {});
impl_op_interface!(PointerTypeResult for AllocaOp {
    // The pointee type is whatever was recorded in the element-type attribute.
    // Panics if that attribute is missing or is not a TypeAttr.
    fn result_pointee_type(&self, ctx: &Context) -> Ptr<TypeObj> {
        let op = self.op.deref(ctx);
        let elem_ty_attr = op
            .attributes
            .get::<TypeAttr>(Self::ATTR_KEY_ELEM_TYPE)
            .expect("AllocaOp missing or incorrect type for elem_type attribute");
        elem_ty_attr.get_type()
    }
});
impl AllocaOp {
    /// Attribute key under which the allocated element type is stored.
    pub const ATTR_KEY_ELEM_TYPE: &'static str = "llvm.element_type";

    /// Create a new [AllocaOp] for `elem_type`, with `size` as its only operand.
    ///
    /// The op's single result has [PointerType].
    pub fn new(ctx: &mut Context, elem_type: Ptr<TypeObj>, size: Value) -> Self {
        let result_types = vec![PointerType::get(ctx).into()];
        let op = Operation::new(
            ctx,
            Self::get_opid_static(),
            result_types,
            vec![size],
            vec![],
            0,
        );
        let elem_type_attr = TypeAttr::new(elem_type);
        op.deref_mut(ctx)
            .attributes
            .set(Self::ATTR_KEY_ELEM_TYPE, elem_type_attr);
        Self { op }
    }
}
/// Cast op `llvm.bitcast` with one operand and one result.
///
/// NOTE(review): its op-specific verifier comes from `impl_verify_succ!`, so any
/// relationship between operand and result types is presumably not checked
/// here. Confirm whether that is intended.
#[def_op("llvm.bitcast")]
pub struct BitcastOp {}
impl_canonical_syntax!(BitcastOp);
impl_verify_succ!(BitcastOp);
impl_op_interface!(OneResultInterface for BitcastOp {});
impl_op_interface!(OneOpdInterface for BitcastOp {});
/// Unconditional branch terminator `llvm.br`.
///
/// It has exactly one successor, and every operand of the op is forwarded to it.
#[def_op("llvm.br")]
pub struct BrOp {}
impl_canonical_syntax!(BrOp);
impl_verify_succ!(BrOp);
impl_op_interface!(IsTerminatorInterface for BrOp {});
impl_op_interface!(BranchOpInterface for BrOp {
    fn successor_operands(&self, ctx: &Context, succ_idx: usize) -> Vec<Value> {
        assert!(succ_idx == 0, "BrOp has exactly one successor");
        // With a single successor, all operands belong to it.
        self.get_operation()
            .deref(ctx)
            .operands()
            .collect::<Vec<_>>()
    }
});
impl BrOp {
    /// Create a new [BrOp] that jumps to `dest`, passing `dest_opds` as operands.
    pub fn new(ctx: &mut Context, dest: Ptr<BasicBlock>, dest_opds: Vec<Value>) -> Self {
        let successors = vec![dest];
        let op = Operation::new(
            ctx,
            Self::get_opid_static(),
            vec![],
            dest_opds,
            successors,
            0,
        );
        Self { op }
    }
}
/// Conditional branch terminator `llvm.cond_br`.
///
/// Operand layout: `[condition, true_dest_opds..., false_dest_opds...]`.
/// Successors: `[true_dest, false_dest]`.
#[def_op("llvm.cond_br")]
pub struct CondBrOp {}
impl CondBrOp {
    /// Create a new [CondBrOp] branching on `condition`.
    ///
    /// `true_dest_opds` and `false_dest_opds` are appended, in that order, after
    /// the condition operand.
    pub fn new(
        ctx: &mut Context,
        condition: Value,
        true_dest: Ptr<BasicBlock>,
        true_dest_opds: Vec<Value>,
        false_dest: Ptr<BasicBlock>,
        false_dest_opds: Vec<Value>,
    ) -> Self {
        let operands: Vec<Value> = std::iter::once(condition)
            .chain(true_dest_opds)
            .chain(false_dest_opds)
            .collect();
        let successors = vec![true_dest, false_dest];
        let op = Operation::new(
            ctx,
            Self::get_opid_static(),
            vec![],
            operands,
            successors,
            0,
        );
        Self { op }
    }
}
impl_canonical_syntax!(CondBrOp);
impl_verify_succ!(CondBrOp);
impl_op_interface!(IsTerminatorInterface for CondBrOp {});
impl_op_interface!(BranchOpInterface for CondBrOp {
    // Operand 0 is the condition. The next N operands go to successor 0, where N
    // is successor 0's current number of block arguments. All remaining operands
    // go to successor 1.
    fn successor_operands(&self, ctx: &Context, succ_idx: usize) -> Vec<Value> {
        assert!(succ_idx == 0 || succ_idx == 1, "CondBrOp has exactly two successors");
        let true_dest_nargs = self
            .get_operation()
            .deref(ctx)
            .get_successor(0)
            .unwrap()
            .deref(ctx)
            .get_num_arguments();
        let (first, count) = match succ_idx {
            0 => (1, true_dest_nargs),
            _ => (1 + true_dest_nargs, usize::MAX),
        };
        self.get_operation()
            .deref(ctx)
            .operands()
            .skip(first)
            .take(count)
            .collect()
    }
});
/// One index of a [GetElementPtrOp], as passed to [GetElementPtrOp::new] and
/// returned by [GetElementPtrOp::indices].
#[derive(Clone)]
pub enum GepIndex {
    /// A compile-time constant index, stored directly in the indices attribute.
    Constant(u32),
    /// A dynamic index supplied as an SSA value, stored as an extra operand.
    Value(Value),
}
/// Errors reported for [GetElementPtrOp].
#[derive(Error, Debug)]
pub enum GetElementPtrOpErr {
    /// The indices attribute is absent or is not a [GepIndicesAttr].
    #[error("GetElementPtrOp has no or incorrect indices attribute")]
    IndicesAttrErr,
    /// Produced by `indexed_type` in three cases: a struct is indexed by a
    /// non-constant or out-of-range index, or a type that is neither a struct
    /// nor an array is indexed further.
    #[error("The indices on this GEP are invalid for its source element type")]
    IndicesErr,
}
/// Address-computation op `llvm.gep` (mirrors LLVM's `getelementptr`).
///
/// Operand 0 is the base pointer, and any further operands are dynamic indices.
/// The index list is stored in a [GepIndicesAttr], and the source element type
/// in a [TypeAttr].
#[def_op("llvm.gep")]
pub struct GetElementPtrOp {}
impl_canonical_syntax!(GetElementPtrOp);
impl_op_interface!(OneResultInterface for GetElementPtrOp {});
impl_op_interface!(PointerTypeResult for GetElementPtrOp {
    // Pointee type = the source element type after applying this op's indices.
    // Panics if the attributes are missing or the indices are invalid.
    fn result_pointee_type(&self, ctx: &Context) -> Ptr<TypeObj> {
        let src_elem_ty = self.src_elem_type(ctx);
        let indices = self.indices(ctx);
        Self::indexed_type(ctx, src_elem_ty, &indices).expect("Invalid indices for GEP")
    }
});
impl Verify for GetElementPtrOp {
fn verify(&self, ctx: &Context) -> Result<()> {
let op = &*self.op.deref(ctx);
if op
.attributes
.get::<GepIndicesAttr>(Self::ATTR_KEY_INDICES)
.is_none()
{
verify_err!(op.loc(), GetElementPtrOpErr::IndicesAttrErr)?
}
if let Err(Error { kind: _, err, loc }) =
Self::indexed_type(ctx, self.src_elem_type(ctx), &self.indices(ctx))
{
return Err(Error {
kind: ErrorKind::VerificationFailed,
err,
loc,
});
}
Ok(())
}
}
impl GetElementPtrOp {
pub const ATTR_KEY_INDICES: &'static str = "llvm.gep_indices";
pub const ATTR_KEY_SRC_ELEM_TYPE: &'static str = "llvm.gep_src_elem_type";
pub fn new(
ctx: &mut Context,
base: Value,
indices: Vec<GepIndex>,
elem_type: TypeAttr,
) -> Self {
let mut attr: Vec<GepIndexAttr> = Vec::new();
let mut opds: Vec<Value> = vec![base];
for idx in indices {
match idx {
GepIndex::Constant(c) => {
attr.push(GepIndexAttr::Constant(c));
}
GepIndex::Value(v) => {
attr.push(GepIndexAttr::OperandIdx(opds.push_back(v)));
}
}
}
let op = Operation::new(ctx, Self::get_opid_static(), vec![], opds, vec![], 0);
op.deref_mut(ctx)
.attributes
.set(Self::ATTR_KEY_INDICES, GepIndicesAttr(attr));
op.deref_mut(ctx)
.attributes
.set(Self::ATTR_KEY_SRC_ELEM_TYPE, elem_type);
GetElementPtrOp { op }
}
pub fn src_elem_type(&self, ctx: &Context) -> Ptr<TypeObj> {
self.op
.deref(ctx)
.attributes
.get::<TypeAttr>(Self::ATTR_KEY_SRC_ELEM_TYPE)
.expect("GetElementPtrOp missing or has incorrect src_elem_type attribute type")
.get_type()
}
pub fn src_ptr(&self, ctx: &Context) -> Value {
self.get_operation().deref(ctx).get_operand(0).unwrap()
}
pub fn indices(&self, ctx: &Context) -> Vec<GepIndex> {
let op = &*self.op.deref(ctx);
op.attributes
.get::<GepIndicesAttr>(Self::ATTR_KEY_INDICES)
.unwrap()
.0
.iter()
.map(|index| match index {
GepIndexAttr::Constant(c) => GepIndex::Constant(*c),
GepIndexAttr::OperandIdx(i) => GepIndex::Value(op.get_operand(*i).unwrap()),
})
.collect()
}
pub fn indexed_type(
ctx: &Context,
src_elem_type: Ptr<TypeObj>,
indices: &[GepIndex],
) -> Result<Ptr<TypeObj>> {
fn indexed_type_inner(
ctx: &Context,
src_elem_type: Ptr<TypeObj>,
mut idx_itr: impl Iterator<Item = GepIndex>,
) -> Result<Ptr<TypeObj>> {
let Some(idx) = idx_itr.next() else {
return Ok(src_elem_type);
};
let src_elem_type = &*src_elem_type.deref(ctx);
if let Some(st) = src_elem_type.downcast_ref::<StructType>() {
let GepIndex::Constant(i) = idx else {
return arg_err_noloc!(GetElementPtrOpErr::IndicesErr);
};
if i as usize >= st.num_fields() {
return arg_err_noloc!(GetElementPtrOpErr::IndicesErr);
}
indexed_type_inner(ctx, st.field_type(i as usize), idx_itr)
} else if let Some(at) = src_elem_type.downcast_ref::<ArrayType>() {
indexed_type_inner(ctx, at.elem_type(), idx_itr)
} else {
arg_err_noloc!(GetElementPtrOpErr::IndicesErr)
}
}
indexed_type_inner(ctx, src_elem_type, indices.iter().skip(1).cloned())
}
}
/// Verification errors reported by [LoadOp].
#[derive(Error, Debug)]
pub enum LoadOpVerifyErr {
    /// The single operand's type is not a [PointerType].
    #[error("Load operand must be a pointer")]
    OperandTypeErr,
}
/// Load op `llvm.load`: one pointer operand and one result of the loaded type.
#[def_op("llvm.load")]
pub struct LoadOp {}
impl LoadOp {
    /// Create a new [LoadOp] reading from `ptr` and producing a value of type `res_ty`.
    pub fn new(ctx: &mut Context, ptr: Value, res_ty: Ptr<TypeObj>) -> Self {
        let op = Operation::new(
            ctx,
            Self::get_opid_static(),
            vec![res_ty],
            vec![ptr],
            vec![],
            0,
        );
        Self { op }
    }
}
impl_canonical_syntax!(LoadOp);
/// Checks that the single (address) operand has [PointerType].
impl Verify for LoadOp {
    fn verify(&self, ctx: &Context) -> Result<()> {
        let addr_is_ptr = self.operand_type(ctx).deref(ctx).is::<PointerType>();
        if addr_is_ptr {
            return Ok(());
        }
        let loc = self.get_operation().deref(ctx).loc();
        verify_err!(loc, LoadOpVerifyErr::OperandTypeErr)
    }
}
impl_op_interface!(OneResultInterface for LoadOp {});
impl_op_interface!(OneOpdInterface for LoadOp {});
/// Verification errors reported by [StoreOp].
#[derive(Error, Debug)]
pub enum StoreOpVerifyErr {
    /// The op does not have exactly two operands.
    #[error("Store operand must have two operands")]
    NumOpdsErr,
    /// Operand 1 (the address) does not have [PointerType].
    #[error("Store operand must have a pointer as its second argument")]
    AddrOpdTypeErr,
}
/// Store op `llvm.store`.
///
/// Operands: `[value, address]`. It produces no results.
#[def_op("llvm.store")]
pub struct StoreOp {}
impl StoreOp {
    /// Create a new [StoreOp] writing `value` to `ptr`.
    pub fn new(ctx: &mut Context, value: Value, ptr: Value) -> Self {
        let operands = vec![value, ptr];
        let op = Operation::new(ctx, Self::get_opid_static(), vec![], operands, vec![], 0);
        Self { op }
    }

    /// The value being stored (operand 0).
    ///
    /// # Panics
    /// Panics if operand 0 does not exist.
    pub fn value_opd(&self, ctx: &Context) -> Value {
        let op = self.op.deref(ctx);
        op.get_operand(0).unwrap()
    }

    /// The address being stored to (operand 1).
    ///
    /// # Panics
    /// Panics if operand 1 does not exist.
    pub fn address_opd(&self, ctx: &Context) -> Value {
        let op = self.op.deref(ctx);
        op.get_operand(1).unwrap()
    }
}
impl_canonical_syntax!(StoreOp);
/// Checks that the op has exactly two operands and that operand 1 (the address)
/// has [PointerType].
impl Verify for StoreOp {
    fn verify(&self, ctx: &Context) -> Result<()> {
        use pliron::r#type::Typed;

        let op = &*self.op.deref(ctx);
        let loc = op.loc();
        if op.get_num_operands() != 2 {
            return verify_err!(loc, StoreOpVerifyErr::NumOpdsErr);
        }

        let addr_ty = op.get_operand(1).unwrap().get_type(ctx);
        if addr_ty.deref(ctx).is::<PointerType>() {
            Ok(())
        } else {
            verify_err!(loc, StoreOpVerifyErr::AddrOpdTypeErr)
        }
    }
}
// StoreOp is created without results (`StoreOp::new` passes an empty result
// list), so it is the op that carries `ZeroResultInterface`. LoadOp already
// implements `OneResultInterface` and is created with one result, so tagging it
// as zero-result as well would be contradictory.
impl_op_interface!(ZeroResultInterface for StoreOp {});
pub fn register(ctx: &mut Context, dialect: &mut Dialect) {
AddOp::register(ctx, dialect, AddOp::parser_fn);
SubOp::register(ctx, dialect, SubOp::parser_fn);
MulOp::register(ctx, dialect, MulOp::parser_fn);
ShlOp::register(ctx, dialect, ShlOp::parser_fn);
UDivOp::register(ctx, dialect, UDivOp::parser_fn);
SDivOp::register(ctx, dialect, SDivOp::parser_fn);
URemOp::register(ctx, dialect, URemOp::parser_fn);
SRemOp::register(ctx, dialect, SRemOp::parser_fn);
AndOp::register(ctx, dialect, AndOp::parser_fn);
OrOp::register(ctx, dialect, OrOp::parser_fn);
XorOp::register(ctx, dialect, XorOp::parser_fn);
LShrOp::register(ctx, dialect, LShrOp::parser_fn);
AShrOp::register(ctx, dialect, AShrOp::parser_fn);
ICmpOp::register(ctx, dialect, ICmpOp::parser_fn);
AllocaOp::register(ctx, dialect, AllocaOp::parser_fn);
BitcastOp::register(ctx, dialect, BitcastOp::parser_fn);
BrOp::register(ctx, dialect, BrOp::parser_fn);
CondBrOp::register(ctx, dialect, CondBrOp::parser_fn);
GetElementPtrOp::register(ctx, dialect, GetElementPtrOp::parser_fn);
LoadOp::register(ctx, dialect, LoadOp::parser_fn);
StoreOp::register(ctx, dialect, StoreOp::parser_fn);
ReturnOp::register(ctx, dialect, ReturnOp::parser_fn);
}