use crate::sema::{ExternalSig, ReturnKind, Sym, Term, TermEnv, TermId, Type, TypeEnv, TypeId};
use crate::serialize::{Block, ControlFlow, EvalStep, MatchArm};
use crate::trie_again::{Binding, BindingId, Constraint, RuleSet};
use crate::StableSet;
use std::fmt::Write;
/// Knobs controlling the shape of the generated Rust module.
#[derive(Debug, Default, Clone)]
pub struct CodegenOptions {
    /// When `true`, the crate-level `#![allow(...)]` lint pragmas are not
    /// written into the header of the generated file.
    pub exclude_global_allow_pragmas: bool,
}
/// Entry point: turn the type/term environments plus the compiled rule set
/// of each internal constructor term into the text of a Rust module.
pub fn codegen(
    typeenv: &TypeEnv,
    termenv: &TermEnv,
    terms: &[(TermId, RuleSet)],
    options: &CodegenOptions,
) -> String {
    let generator = Codegen::compile(typeenv, termenv, terms);
    generator.generate_rust(options)
}
/// Borrowed view of everything the generator needs; holds no state of its own.
#[derive(Debug, Clone)]
struct Codegen<'a> {
    /// Type environment: type definitions, interned symbols and source filenames.
    typeenv: &'a TypeEnv,
    /// Term environment: every declared term and its signatures.
    termenv: &'a TermEnv,
    /// Internal constructor terms paired with their compiled rule sets.
    terms: &'a [(TermId, RuleSet)],
}
/// Mutable state threaded through the emission of one generated function body.
struct BodyContext<'a, W> {
    /// Destination for the generated text.
    out: &'a mut W,
    /// Rule set whose bindings are being emitted.
    ruleset: &'a RuleSet,
    /// Leading whitespace prepended to each emitted statement.
    indent: String,
    /// Bindings whose generated Rust value is a reference rather than owned.
    is_ref: StableSet<BindingId>,
    /// Bindings already assigned to a local `vN` in the current scope.
    is_bound: StableSet<BindingId>,
}
impl<'a, W: Write> BodyContext<'a, W> {
    /// One level of indentation in the generated code. Both `begin_block` and
    /// `end_block_without_newline` use this constant so pushes and pops stay
    /// balanced.
    const INDENT: &'static str = "    ";

    /// Create a context writing to `out` for `ruleset`, with no indentation
    /// and no bindings yet in scope.
    fn new(out: &'a mut W, ruleset: &'a RuleSet) -> Self {
        Self {
            out,
            ruleset,
            indent: Default::default(),
            is_ref: Default::default(),
            is_bound: Default::default(),
        }
    }

    /// Start a new lexical scope. Returns a snapshot of the bound-variable set.
    /// Pass it back to `end_block` to forget bindings introduced inside the scope.
    fn enter_scope(&mut self) -> StableSet<BindingId> {
        self.is_bound.clone()
    }

    /// Write ` {` plus a newline and increase indentation by one level.
    fn begin_block(&mut self) -> std::fmt::Result {
        self.indent.push_str(Self::INDENT);
        writeln!(self.out, " {{")
    }

    /// Restore the bound-variable set saved by `enter_scope`, then close the
    /// block and terminate the line.
    fn end_block(&mut self, scope: StableSet<BindingId>) -> std::fmt::Result {
        self.is_bound = scope;
        self.end_block_without_newline()?;
        writeln!(self.out)
    }

    /// Decrease indentation by one level and write the closing `}` with no
    /// trailing newline.
    ///
    /// # Panics
    ///
    /// Panics if called more times than `begin_block`.
    fn end_block_without_newline(&mut self) -> std::fmt::Result {
        let new_len = self
            .indent
            .len()
            .checked_sub(Self::INDENT.len())
            .expect("end_block called without a matching begin_block");
        self.indent.truncate(new_len);
        write!(self.out, "{}}}", &self.indent)
    }

    /// Record whether `binding` holds a reference. Bindings are never
    /// demoted from "reference" back to "owned"; that is checked in debug builds.
    fn set_ref(&mut self, binding: BindingId, is_ref: bool) {
        if is_ref {
            self.is_ref.insert(binding);
        } else {
            debug_assert!(!self.is_ref.contains(&binding));
        }
    }
}
impl<'a> Codegen<'a> {
    /// Bundle the environments and per-term rule sets into a generator.
    /// No code is produced until `generate_rust` is called.
    fn compile(
        typeenv: &'a TypeEnv,
        termenv: &'a TermEnv,
        terms: &'a [(TermId, RuleSet)],
    ) -> Codegen<'a> {
        Codegen {
            typeenv,
            termenv,
            terms,
        }
    }

    /// Produce the whole generated module: the header and pragmas, the
    /// `Context` trait, internal enum types, and one function per internal
    /// constructor term.
    fn generate_rust(&self, options: &CodegenOptions) -> String {
        let mut code = String::new();

        self.generate_header(&mut code, options);
        self.generate_ctx_trait(&mut code);
        self.generate_internal_types(&mut code);
        // The only error source is `fmt::Write` on a `String`, which never fails.
        self.generate_internal_term_constructors(&mut code).unwrap();

        code
    }

    /// Write the "do not edit" banner, the list of source files, optional
    /// crate-level lint allowances, and the common `use` lines.
    fn generate_header(&self, code: &mut String, options: &CodegenOptions) {
        writeln!(code, "// GENERATED BY ISLE. DO NOT EDIT!").unwrap();
        writeln!(code, "//").unwrap();
        writeln!(
            code,
            "// Generated automatically from the instruction-selection DSL code in:",
        )
        .unwrap();
        for file in &self.typeenv.filenames {
            writeln!(code, "// - {}", file).unwrap();
        }

        if !options.exclude_global_allow_pragmas {
            writeln!(
                code,
                "\n#![allow(dead_code, unreachable_code, unreachable_patterns)]"
            )
            .unwrap();
            writeln!(
                code,
                "#![allow(unused_imports, unused_variables, non_snake_case, unused_mut)]"
            )
            .unwrap();
            writeln!(
                code,
                "#![allow(irrefutable_let_patterns, unused_assignments, non_camel_case_types)]"
            )
            .unwrap();
        }

        writeln!(code, "\nuse super::*; // Pulls in all external types.").unwrap();
        writeln!(code, "use std::marker::PhantomData;").unwrap();
    }

    /// Emit one `Context` trait method signature for an external constructor
    /// or extractor. Iterator-returning terms also get an associated
    /// `<name>_iter` type.
    fn generate_trait_sig(&self, code: &mut String, indent: &str, sig: &ExternalSig) {
        // A single return type is written bare; zero or several become a tuple.
        let ret_tuple = format!(
            "{open_paren}{rets}{close_paren}",
            open_paren = if sig.ret_tys.len() != 1 { "(" } else { "" },
            rets = sig
                .ret_tys
                .iter()
                .map(|&ty| self.type_name(ty, false))
                .collect::<Vec<_>>()
                .join(", "),
            close_paren = if sig.ret_tys.len() != 1 { ")" } else { "" },
        );

        if sig.ret_kind == ReturnKind::Iterator {
            writeln!(
                code,
                "{indent}type {name}_iter: ContextIter<Context = Self, Output = {output}>;",
                indent = indent,
                name = sig.func_name,
                output = ret_tuple,
            )
            .unwrap();
        }

        let ret_ty = match sig.ret_kind {
            ReturnKind::Plain => ret_tuple,
            ReturnKind::Option => format!("Option<{}>", ret_tuple),
            ReturnKind::Iterator => format!("Self::{}_iter", sig.func_name),
        };

        writeln!(
            code,
            "{indent}fn {name}(&mut self, {params}) -> {ret_ty};",
            indent = indent,
            name = sig.func_name,
            params = sig
                .param_tys
                .iter()
                .enumerate()
                .map(|(i, &ty)| format!("arg{}: {}", i, self.type_name(ty, true)))
                .collect::<Vec<_>>()
                .join(", "),
            ret_ty = ret_ty,
        )
        .unwrap();
    }

    /// Emit the `Context` trait (one method per external extractor or
    /// constructor), then the `ContextIter` trait and its `ContextIterWrapper`
    /// adapter.
    fn generate_ctx_trait(&self, code: &mut String) {
        writeln!(code).unwrap();
        writeln!(
            code,
            "/// Context during lowering: an implementation of this trait"
        )
        .unwrap();
        writeln!(
            code,
            "/// must be provided with all external constructors and extractors."
        )
        .unwrap();
        writeln!(
            code,
            "/// A mutable borrow is passed along through all lowering logic."
        )
        .unwrap();
        writeln!(code, "pub trait Context {{").unwrap();
        for term in &self.termenv.terms {
            if term.has_external_extractor() {
                let ext_sig = term.extractor_sig(self.typeenv).unwrap();
                self.generate_trait_sig(code, "    ", &ext_sig);
            }
            if term.has_external_constructor() {
                let ext_sig = term.constructor_sig(self.typeenv).unwrap();
                self.generate_trait_sig(code, "    ", &ext_sig);
            }
        }
        writeln!(code, "}}").unwrap();
        writeln!(
            code,
            r#"
pub trait ContextIter {{
    type Context;
    type Output;
    fn next(&mut self, ctx: &mut Self::Context) -> Option<Self::Output>;
}}
pub struct ContextIterWrapper<Item, I: Iterator < Item = Item>, C: Context> {{
    iter: I,
    _ctx: PhantomData<C>,
}}
impl<Item, I: Iterator<Item = Item>, C: Context> From<I> for ContextIterWrapper<Item, I, C> {{
    fn from(iter: I) -> Self {{
        Self {{ iter, _ctx: PhantomData }}
    }}
}}
impl<Item, I: Iterator<Item = Item>, C: Context> ContextIter for ContextIterWrapper<Item, I, C> {{
    type Context = C;
    type Output = Item;
    fn next(&mut self, _ctx: &mut Self::Context) -> Option<Self::Output> {{
        self.iter.next()
    }}
}}
"#,
        )
        .unwrap();
    }

    /// Emit a `pub enum` for every non-extern enum type.
    ///
    /// Enums whose variants are all fieldless derive `Copy, Clone, PartialEq, Eq`.
    /// Other enums derive only `Clone`. `Debug` is added unless the type is
    /// marked nodebug.
    fn generate_internal_types(&self, code: &mut String) {
        for ty in &self.typeenv.types {
            match ty {
                &Type::Enum {
                    name,
                    is_extern,
                    is_nodebug,
                    ref variants,
                    pos,
                    ..
                } if !is_extern => {
                    let name = &self.typeenv.syms[name.index()];
                    writeln!(
                        code,
                        "\n/// Internal type {}: defined at {}.",
                        name,
                        pos.pretty_print_line(&self.typeenv.filenames[..])
                    )
                    .unwrap();

                    let debug_derive = if is_nodebug { "" } else { ", Debug" };
                    if variants.iter().all(|v| v.fields.is_empty()) {
                        writeln!(
                            code,
                            "#[derive(Copy, Clone, PartialEq, Eq{})]",
                            debug_derive
                        )
                        .unwrap();
                    } else {
                        writeln!(code, "#[derive(Clone{})]", debug_derive).unwrap();
                    }

                    writeln!(code, "pub enum {} {{", name).unwrap();
                    for variant in variants {
                        let name = &self.typeenv.syms[variant.name.index()];
                        if variant.fields.is_empty() {
                            writeln!(code, "    {},", name).unwrap();
                        } else {
                            writeln!(code, "    {} {{", name).unwrap();
                            for field in &variant.fields {
                                let name = &self.typeenv.syms[field.name.index()];
                                let ty_name =
                                    self.typeenv.types[field.ty.index()].name(self.typeenv);
                                writeln!(code, "        {}: {},", name, ty_name).unwrap();
                            }
                            writeln!(code, "    }},").unwrap();
                        }
                    }
                    writeln!(code, "}}").unwrap();
                }
                _ => {}
            }
        }
    }

    /// Rust spelling of `typeid`. Internal enum types get a leading `&` when
    /// `by_ref` is set; primitive types never do.
    fn type_name(&self, typeid: TypeId, by_ref: bool) -> String {
        let (is_ref, sym) = self.ty(typeid);
        let prefix = if by_ref && is_ref { "&" } else { "" };
        format!("{}{}", prefix, self.typeenv.syms[sym.index()])
    }

    /// Emit one `pub fn` for each internal constructor term, with a body
    /// compiled from its rule set.
    fn generate_internal_term_constructors(&self, code: &mut String) -> std::fmt::Result {
        for &(termid, ref ruleset) in self.terms.iter() {
            let root = crate::serialize::serialize(ruleset);
            let mut ctx = BodyContext::new(code, ruleset);

            let termdata = &self.termenv.terms[termid.index()];
            let term_name = &self.typeenv.syms[termdata.name.index()];
            writeln!(ctx.out)?;
            writeln!(
                ctx.out,
                "{}// Generated as internal constructor for term {}.",
                &ctx.indent, term_name,
            )?;

            let sig = termdata.constructor_sig(self.typeenv).unwrap();
            writeln!(
                ctx.out,
                "{}pub fn {}<C: Context>(",
                &ctx.indent, sig.func_name
            )?;

            writeln!(ctx.out, "{}    ctx: &mut C,", &ctx.indent)?;
            for (i, &ty) in sig.param_tys.iter().enumerate() {
                let (is_ref, sym) = self.ty(ty);
                write!(ctx.out, "{}    arg{}: ", &ctx.indent, i)?;
                write!(
                    ctx.out,
                    "{}{}",
                    if is_ref { "&" } else { "" },
                    &self.typeenv.syms[sym.index()]
                )?;
                // Arguments of enum type are passed by reference; remember that
                // so later uses clone or re-borrow them correctly.
                if let Some(binding) = ctx.ruleset.find_binding(&Binding::Argument {
                    index: i.try_into().unwrap(),
                }) {
                    ctx.set_ref(binding, is_ref);
                }
                writeln!(ctx.out, ",")?;
            }

            write!(ctx.out, "{}) -> ", &ctx.indent)?;
            let (_, ret) = self.ty(sig.ret_tys[0]);
            let ret = &self.typeenv.syms[ret.index()];
            match sig.ret_kind {
                ReturnKind::Iterator => {
                    write!(ctx.out, "impl ContextIter<Context = C, Output = {}>", ret)?
                }
                ReturnKind::Option => write!(ctx.out, "Option<{}>", ret)?,
                ReturnKind::Plain => write!(ctx.out, "{}", ret)?,
            };

            let scope = ctx.enter_scope();
            ctx.begin_block()?;
            if sig.ret_kind == ReturnKind::Iterator {
                writeln!(
                    ctx.out,
                    "{}let mut returns = ConstructorVec::new();",
                    &ctx.indent
                )?;
            }

            self.emit_block(&mut ctx, &root, sig.ret_kind)?;

            // Emit whatever must follow the rules when none of them returned.
            match (sig.ret_kind, root.steps.last()) {
                (ReturnKind::Iterator, _) => {
                    writeln!(
                        ctx.out,
                        "{}return ContextIterWrapper::from(returns.into_iter());",
                        &ctx.indent
                    )?;
                }
                (_, Some(EvalStep { check: ControlFlow::Return { .. }, .. })) => {
                    // The last top-level step is itself a `return`, so control
                    // cannot fall off the end and no fallback is needed.
                }
                (ReturnKind::Option, _) => writeln!(ctx.out, "{}None", &ctx.indent)?,
                (ReturnKind::Plain, _) => writeln!(
                    ctx.out,
                    "{}unreachable!(\"no rule matched for term {{}} at {{}}; should it be partial?\", {:?}, {:?})",
                    &ctx.indent,
                    term_name,
                    termdata
                        .decl_pos
                        .pretty_print_line(&self.typeenv.filenames[..])
                )?,
            }

            ctx.end_block(scope)?;
        }
        Ok(())
    }

    /// Classify a type. Returns `(true, name)` for enums, which are passed by
    /// reference, and `(false, name)` for primitives.
    fn ty(&self, typeid: TypeId) -> (bool, Sym) {
        match &self.typeenv.types[typeid.index()] {
            &Type::Primitive(_, sym, _) => (false, sym),
            &Type::Enum { name, .. } => (true, name),
        }
    }

    /// Emit the statements for `block`.
    ///
    /// For each step, first bind its required expressions to `vN` locals,
    /// then emit its control flow (`if`, `if let`, `match`, `while let`, or a
    /// return).
    fn emit_block<W: Write>(
        &self,
        ctx: &mut BodyContext<W>,
        block: &Block,
        ret_kind: ReturnKind,
    ) -> std::fmt::Result {
        if !matches!(ret_kind, ReturnKind::Iterator) {
            // A non-iterator body must contain no loops, and any `return`
            // must be the block's final step.
            assert!(!block
                .steps
                .iter()
                .any(|c| matches!(c.check, ControlFlow::Loop { .. })));
            if let Some(result_pos) = block
                .steps
                .iter()
                .position(|c| matches!(c.check, ControlFlow::Return { .. }))
            {
                assert_eq!(block.steps.len() - 1, result_pos);
            }
        }

        for case in block.steps.iter() {
            for &expr in case.bind_order.iter() {
                write!(ctx.out, "{}let v{} = ", &ctx.indent, expr.index())?;
                self.emit_expr(ctx, expr)?;
                writeln!(ctx.out, ";")?;
                ctx.is_bound.insert(expr);
            }

            match &case.check {
                // A single-arm match becomes an `if` (constants) or `if let`
                // (patterns).
                ControlFlow::Match { source, arms } if arms.len() == 1 => {
                    let arm = &arms[0];
                    let scope = ctx.enter_scope();
                    match arm.constraint {
                        Constraint::ConstInt { .. } | Constraint::ConstPrim { .. } => {
                            write!(ctx.out, "{}if ", &ctx.indent)?;
                            self.emit_expr(ctx, *source)?;
                            write!(ctx.out, " == ")?;
                            self.emit_constraint(ctx, *source, arm)?;
                        }
                        Constraint::Variant { .. } | Constraint::Some => {
                            write!(ctx.out, "{}if let ", &ctx.indent)?;
                            self.emit_constraint(ctx, *source, arm)?;
                            write!(ctx.out, " = ")?;
                            self.emit_source(ctx, *source, arm.constraint)?;
                        }
                    }
                    ctx.begin_block()?;
                    self.emit_block(ctx, &arm.body, ret_kind)?;
                    ctx.end_block(scope)?;
                }

                // Multiple arms become a `match` with an empty catch-all arm.
                ControlFlow::Match { source, arms } => {
                    let scope = ctx.enter_scope();
                    write!(ctx.out, "{}match ", &ctx.indent)?;
                    self.emit_source(ctx, *source, arms[0].constraint)?;
                    ctx.begin_block()?;
                    for arm in arms.iter() {
                        let scope = ctx.enter_scope();
                        write!(ctx.out, "{}", &ctx.indent)?;
                        self.emit_constraint(ctx, *source, arm)?;
                        write!(ctx.out, " =>")?;
                        ctx.begin_block()?;
                        self.emit_block(ctx, &arm.body, ret_kind)?;
                        ctx.end_block(scope)?;
                    }
                    writeln!(ctx.out, "{}_ => {{}}", &ctx.indent)?;
                    ctx.end_block(scope)?;
                }

                ControlFlow::Equal { a, b, body } => {
                    let scope = ctx.enter_scope();
                    write!(ctx.out, "{}if ", &ctx.indent)?;
                    self.emit_expr(ctx, *a)?;
                    write!(ctx.out, " == ")?;
                    self.emit_expr(ctx, *b)?;
                    ctx.begin_block()?;
                    self.emit_block(ctx, body, ret_kind)?;
                    ctx.end_block(scope)?;
                }

                // Drive a `ContextIter` with `while let`, binding each item to
                // the loop's result variable.
                ControlFlow::Loop { result, body } => {
                    let source = match &ctx.ruleset.bindings[result.index()] {
                        Binding::Iterator { source } => source,
                        _ => unreachable!("Loop from a non-Iterator"),
                    };
                    let scope = ctx.enter_scope();

                    write!(ctx.out, "{}let mut v{} = ", &ctx.indent, source.index())?;
                    self.emit_expr(ctx, *source)?;
                    writeln!(ctx.out, ";")?;
                    write!(
                        ctx.out,
                        "{}while let Some(v{}) = v{}.next(ctx)",
                        &ctx.indent,
                        result.index(),
                        source.index()
                    )?;
                    ctx.is_bound.insert(*result);
                    ctx.begin_block()?;
                    self.emit_block(ctx, body, ret_kind)?;
                    ctx.end_block(scope)?;
                }

                // Plain terms `return` the value, Option terms return `Some(..)`,
                // and iterator terms push onto the `returns` vector.
                &ControlFlow::Return { pos, result } => {
                    writeln!(
                        ctx.out,
                        "{}// Rule at {}.",
                        &ctx.indent,
                        pos.pretty_print_line(&self.typeenv.filenames)
                    )?;
                    write!(ctx.out, "{}", &ctx.indent)?;
                    match ret_kind {
                        ReturnKind::Plain => write!(ctx.out, "return ")?,
                        ReturnKind::Option => write!(ctx.out, "return Some(")?,
                        ReturnKind::Iterator => write!(ctx.out, "returns.push(")?,
                    }
                    self.emit_expr(ctx, result)?;
                    // Return values are owned, so borrowed bindings are cloned.
                    if ctx.is_ref.contains(&result) {
                        write!(ctx.out, ".clone()")?;
                    }
                    match ret_kind {
                        ReturnKind::Plain => writeln!(ctx.out, ";")?,
                        ReturnKind::Option | ReturnKind::Iterator => writeln!(ctx.out, ");")?,
                    }
                }
            }
        }

        Ok(())
    }

    /// Emit the Rust expression computing binding `result`.
    ///
    /// If the binding is already bound to a local, emit `vN`; otherwise
    /// emit the expression that computes it inline.
    fn emit_expr<W: Write>(&self, ctx: &mut BodyContext<W>, result: BindingId) -> std::fmt::Result {
        if ctx.is_bound.contains(&result) {
            return write!(ctx.out, "v{}", result.index());
        }

        let binding = &ctx.ruleset.bindings[result.index()];

        // Emit a call to a term's external constructor/extractor. Arguments
        // are cloned or borrowed so each matches the by-value/by-reference
        // convention of the callee's parameter.
        let mut call =
            |term: TermId,
             parameters: &[BindingId],
             get_sig: fn(&Term, &TypeEnv) -> Option<ExternalSig>| {
                let termdata = &self.termenv.terms[term.index()];
                let sig = get_sig(termdata, self.typeenv).unwrap();
                if let &[ret_ty] = &sig.ret_tys[..] {
                    let (is_ref, _) = self.ty(ret_ty);
                    if is_ref {
                        ctx.set_ref(result, true);
                        write!(ctx.out, "&")?;
                    }
                }
                write!(ctx.out, "{}(ctx", sig.full_name)?;
                debug_assert_eq!(parameters.len(), sig.param_tys.len());
                for (&parameter, &arg_ty) in parameters.iter().zip(sig.param_tys.iter()) {
                    let (is_ref, _) = self.ty(arg_ty);
                    write!(ctx.out, ", ")?;
                    let (before, after) = match (is_ref, ctx.is_ref.contains(&parameter)) {
                        (false, true) => ("", ".clone()"),
                        (true, false) => ("&", ""),
                        _ => ("", ""),
                    };
                    write!(ctx.out, "{}", before)?;
                    self.emit_expr(ctx, parameter)?;
                    write!(ctx.out, "{}", after)?;
                }
                write!(ctx.out, ")")
            };

        match binding {
            &Binding::ConstInt { val, ty } => self.emit_int(ctx, val, ty),
            Binding::ConstPrim { val } => write!(ctx.out, "{}", &self.typeenv.syms[val.index()]),
            Binding::Argument { index } => write!(ctx.out, "arg{}", index.index()),
            Binding::Extractor { term, parameter } => {
                call(*term, std::slice::from_ref(parameter), Term::extractor_sig)
            }
            Binding::Constructor {
                term, parameters, ..
            } => call(*term, &parameters[..], Term::constructor_sig),

            Binding::MakeVariant {
                ty,
                variant,
                fields,
            } => {
                let (name, variants) = match &self.typeenv.types[ty.index()] {
                    Type::Enum { name, variants, .. } => (name, variants),
                    _ => unreachable!("MakeVariant with primitive type"),
                };
                let variant = &variants[variant.index()];
                write!(
                    ctx.out,
                    "{}::{}",
                    &self.typeenv.syms[name.index()],
                    &self.typeenv.syms[variant.name.index()]
                )?;
                if !fields.is_empty() {
                    ctx.begin_block()?;
                    for (field, value) in variant.fields.iter().zip(fields.iter()) {
                        write!(
                            ctx.out,
                            "{}{}: ",
                            &ctx.indent,
                            &self.typeenv.syms[field.name.index()],
                        )?;
                        self.emit_expr(ctx, *value)?;
                        // Struct-like variant fields are owned.
                        if ctx.is_ref.contains(&value) {
                            write!(ctx.out, ".clone()")?;
                        }
                        writeln!(ctx.out, ",")?;
                    }
                    ctx.end_block_without_newline()?;
                }
                Ok(())
            }

            &Binding::MatchSome { source } => {
                self.emit_expr(ctx, source)?;
                write!(ctx.out, "?")
            }
            &Binding::MatchTuple { source, field } => {
                self.emit_expr(ctx, source)?;
                write!(ctx.out, ".{}", field.index())
            }

            // NOTE(review): the next two arms emit placeholder code marked
            // `/*FIXME*/` in the generated output.
            &Binding::MatchVariant { source, field, .. } => {
                self.emit_expr(ctx, source)?;
                write!(ctx.out, ".{} /*FIXME*/", field.index())
            }
            &Binding::Iterator { source } => {
                self.emit_expr(ctx, source)?;
                write!(ctx.out, ".next() /*FIXME*/")
            }
        }
    }

    /// Emit the scrutinee of a match. Enum scrutinees are borrowed with `&`
    /// when they are not already references, so they match the `&Enum::Variant`
    /// patterns.
    fn emit_source<W: Write>(
        &self,
        ctx: &mut BodyContext<W>,
        source: BindingId,
        constraint: Constraint,
    ) -> std::fmt::Result {
        if let Constraint::Variant { .. } = constraint {
            if !ctx.is_ref.contains(&source) {
                write!(ctx.out, "&")?;
            }
        }
        self.emit_expr(ctx, source)
    }

    /// Emit the pattern/value for one match arm and mark the arm's bindings as
    /// bound in the current scope.
    fn emit_constraint<W: Write>(
        &self,
        ctx: &mut BodyContext<W>,
        source: BindingId,
        arm: &MatchArm,
    ) -> std::fmt::Result {
        let MatchArm {
            constraint,
            bindings,
            ..
        } = arm;
        for binding in bindings.iter() {
            if let &Some(binding) = binding {
                ctx.is_bound.insert(binding);
            }
        }
        match *constraint {
            Constraint::ConstInt { val, ty } => self.emit_int(ctx, val, ty),
            Constraint::ConstPrim { val } => {
                write!(ctx.out, "{}", &self.typeenv.syms[val.index()])
            }
            Constraint::Variant { ty, variant, .. } => {
                let (name, variants) = match &self.typeenv.types[ty.index()] {
                    Type::Enum { name, variants, .. } => (name, variants),
                    _ => unreachable!("Variant constraint on primitive type"),
                };
                let variant = &variants[variant.index()];
                write!(
                    ctx.out,
                    "&{}::{}",
                    &self.typeenv.syms[name.index()],
                    &self.typeenv.syms[variant.name.index()]
                )?;
                if !bindings.is_empty() {
                    ctx.begin_block()?;
                    let mut skipped_some = false;
                    for (&binding, field) in bindings.iter().zip(variant.fields.iter()) {
                        if let Some(binding) = binding {
                            write!(
                                ctx.out,
                                "{}{}: ",
                                &ctx.indent,
                                &self.typeenv.syms[field.name.index()]
                            )?;
                            // Enum-typed fields are bound with `ref`, so the
                            // resulting local is a reference.
                            let (is_ref, _) = self.ty(field.ty);
                            if is_ref {
                                ctx.set_ref(binding, true);
                                write!(ctx.out, "ref ")?;
                            }
                            writeln!(ctx.out, "v{},", binding.index())?;
                        } else {
                            skipped_some = true;
                        }
                    }
                    if skipped_some {
                        writeln!(ctx.out, "{}..", &ctx.indent)?;
                    }
                    ctx.end_block_without_newline()?;
                }
                Ok(())
            }
            Constraint::Some => {
                write!(ctx.out, "Some(")?;
                if let Some(binding) = bindings[0] {
                    // The payload is a reference exactly when the scrutinee is.
                    ctx.set_ref(binding, ctx.is_ref.contains(&source));
                    write!(ctx.out, "v{}", binding.index())?;
                } else {
                    write!(ctx.out, "_")?;
                }
                write!(ctx.out, ")")
            }
        }
    }

    /// Emit an integer literal in upper-case hex.
    ///
    /// A negative value of a type whose name starts with `i` is written as `-0x..`.
    /// `unsigned_abs` is used so `i128::MIN` does not overflow, as `-val` would.
    fn emit_int<W: Write>(
        &self,
        ctx: &mut BodyContext<W>,
        val: i128,
        ty: TypeId,
    ) -> Result<(), std::fmt::Error> {
        if val < 0
            && self.typeenv.types[ty.index()]
                .name(self.typeenv)
                .starts_with('i')
        {
            write!(ctx.out, "-{:#X}", val.unsigned_abs())
        } else {
            write!(ctx.out, "{:#X}", val)
        }
    }
}