use std::collections::{HashMap, HashSet};
use std::mem;
use rustc::hir::def_id::DefId;
use syntax::ast::*;
use syntax::ptr::P;
use syntax::symbol::Symbol;
use crate::ast_manip::{FlatMapNodes, MutVisitNodes, fold_modules};
use crate::ast_manip::fn_edit::mut_visit_fns;
use crate::command::{CommandState, Registry};
use crate::driver::{parse_expr};
use crate::matcher::{Bindings, BindingType, MatchCtxt, Subst, mut_visit_match_with};
use crate::path_edit::fold_resolved_paths;
use crate::transform::Transform;
use c2rust_ast_builder::{mk, IntoSymbol};
use crate::util::dataflow;
use crate::RefactorCtxt;
/// `static_collect_to_struct` transform: in each module, removes the `target`-marked
/// `static` items and replaces them with one generated struct (one field per removed
/// static) plus a `static mut` instance of that struct initialized from the statics'
/// original initializers.
pub struct CollectToStruct {
    /// Name of the generated struct type.
    pub struct_name: String,
    /// Name of the generated `static mut` instance of that struct.
    pub instance_name: String,
}
impl Transform for CollectToStruct {
fn transform(&self, krate: &mut Crate, st: &CommandState, cx: &RefactorCtxt) {
let mut old_statics = HashMap::new();
fold_modules(krate, |curs| {
let mut matches = Vec::new();
let mut insert_point = None;
while let Some((ident, ty, init)) = curs.advance_until_match(
|i| match_or!([i.node] ItemKind::Static(ref ty, _, ref init) =>
Some((i.ident, ty.clone(), init.clone())); None)) {
if !st.marked(curs.next().id, "target") {
curs.advance();
continue;
}
info!("found {:?}: {:?}", ident, ty);
old_statics.insert(ident.name,
cx.node_def_id(curs.next().id));
if insert_point.is_none() {
insert_point = Some(curs.mark());
}
curs.remove();
let mut bnd = Bindings::new();
bnd.add("__x", ident);
bnd.add("__t", ty);
bnd.add("__init", init);
matches.push(bnd);
}
info!("collected {} matching statics", matches.len());
if let Some(insert_point) = insert_point {
curs.seek(insert_point);
curs.insert(build_collected_struct(&self.struct_name, &matches));
curs.insert(build_struct_instance(&self.struct_name,
&self.instance_name,
&matches));
}
});
let ident_pat = parse_expr(cx.session(), "__x");
let ident_repl = parse_expr(cx.session(), "__s.__x");
let mut init_mcx = MatchCtxt::new(st, cx);
init_mcx.set_type("__x", BindingType::Ident);
init_mcx.bindings.add(
"__s", Ident::with_empty_ctxt((&self.instance_name as &str).into_symbol()));
mut_visit_match_with(init_mcx, ident_pat, krate, |orig, mcx| {
let static_id = match old_statics.get(&mcx.bindings.get::<_, Ident>("__x").unwrap().name) {
Some(&x) => x,
None => return,
};
if cx.resolve_expr(&orig) != static_id {
return;
}
*orig = ident_repl.clone().subst(st, cx, &mcx.bindings)
});
}
}
/// Builds the struct item holding the collected statics: for every entry of `matches`,
/// one field named by its `__x` binding and typed by its `__t` binding, in order.
fn build_collected_struct(name: &str, matches: &[Bindings]) -> P<Item> {
    let mut fields = Vec::with_capacity(matches.len());
    for bnd in matches {
        let field_name = bnd.get::<_, Ident>("__x").unwrap();
        let field_ty = bnd.get::<_, P<Ty>>("__t").unwrap();
        fields.push(mk().struct_field(field_name, field_ty));
    }
    mk().struct_item(name, fields)
}
/// Builds `static mut <instance_name>: <struct_name> = <struct_name> { .. };`, where each
/// field (named by an entry's `__x` binding) is initialized with that entry's `__init`
/// expression.
fn build_struct_instance(struct_name: &str,
                         instance_name: &str,
                         matches: &[Bindings]) -> P<Item> {
    let mut field_inits = Vec::with_capacity(matches.len());
    for bnd in matches {
        let field_name = bnd.get::<_, Ident>("__x").unwrap();
        let init_expr = bnd.get::<_, P<Expr>>("__init").unwrap();
        field_inits.push(mk().field(field_name, init_expr));
    }
    let instance_ty = mk().path_ty(vec![struct_name]);
    let instance_init = mk().struct_expr(vec![struct_name], field_inits);
    mk().mutbl().static_item(instance_name, instance_ty, instance_init)
}
/// `static_to_local_ref` transform: threads `target`-marked statics through the
/// `user`-marked functions as explicit reference parameters instead of global accesses.
pub struct Localize;
impl Transform for Localize {
    /// Adds one reference parameter per needed static to every `user`-marked function.
    ///
    /// A `user` function "needs" a `target` static if its body refers to it, or if it
    /// calls another `user` function that needs it (propagated with `dataflow::iterate`).
    /// For each needed static (in sorted `DefId` order) the function gains a parameter
    /// `<name>_: &T` / `&mut T` (matching the static's mutability). Its own uses of the
    /// static become `*<name>_`, and its calls to other `user` functions pass the callee's
    /// parameters along. In every other function, calls to `user` functions are given
    /// `&STATIC` / `&mut STATIC` directly. The static items themselves are left in place.
    fn transform(&self, krate: &mut Crate, st: &CommandState, cx: &RefactorCtxt) {
        // Data needed to build the parameter and arguments for one static.
        struct StaticInfo {
            name: Ident,
            arg_name: Symbol,
            ty: P<Ty>,
            mutbl: Mutability,
        }
        // Pass 1: record every `target`-marked static, keyed by DefId. Items are
        // returned unchanged.
        let mut statics = HashMap::new();
        FlatMapNodes::visit(krate, |i: P<Item>| {
            if !st.marked(i.id, "target") {
                return smallvec![i];
            }
            match i.node {
                ItemKind::Static(ref ty, mutbl, _) => {
                    let def_id = cx.node_def_id(i.id);
                    // The new parameter is named after the static with a trailing `_`.
                    let arg_name_str = format!("{}_", i.ident.name.as_str());
                    let arg_name = (&arg_name_str as &str).into_symbol();
                    statics.insert(def_id, StaticInfo {
                        name: i.ident.clone(),
                        arg_name: arg_name,
                        ty: ty.clone(),
                        mutbl: mutbl,
                    });
                },
                _ => {},
            }
            smallvec![i]
        });
        // Pass 2: for each `user` function, collect the DefIds of every resolved path
        // in its body (functions, statics, and anything else).
        let mut fn_refs = HashMap::new();
        mut_visit_fns(krate, |fl| {
            if !st.marked(fl.id, "user") {
                return;
            }
            let fn_def_id = cx.node_def_id(fl.id);
            let mut refs = HashSet::new();
            fold_resolved_paths(&mut fl.block, cx, |qself, path, def| {
                if let Some(def_id) = def.opt_def_id() {
                    refs.insert(def_id);
                }
                (qself, path)
            });
            fn_refs.insert(fn_def_id, refs);
        });
        // Per-function references split into `user` functions and marked statics.
        struct FnInfo {
            fn_refs: HashSet<DefId>,
            static_refs: HashSet<DefId>,
        }
        let fn_ids = fn_refs.keys().map(|&x| x).collect::<HashSet<_>>();
        let mut fns = fn_refs.into_iter().map(|(k, v)| {
            let fn_refs = v.iter().filter(|id| fn_ids.contains(id))
                .map(|&x| x).collect();
            let static_refs = v.iter().filter(|id| statics.contains_key(id))
                .map(|&x| x).collect();
            (k, FnInfo { fn_refs, static_refs })
        }).collect::<HashMap<_, _>>();
        // Propagate static requirements from callees to callers. The closure reports
        // whether it added anything; NOTE(review): presumably `dataflow::iterate` re-runs
        // it until no function changes — confirm in `util::dataflow`.
        dataflow::iterate(&mut fns, |cur_id, cur, data| {
            let mut changed = false;
            for &other_id in &cur.fn_refs {
                // Self-calls add nothing new.
                if other_id == cur_id {
                    continue;
                }
                for &static_id in &data[other_id].static_refs {
                    if !cur.static_refs.contains(&static_id) {
                        cur.static_refs.insert(static_id);
                        changed = true;
                    }
                }
            }
            changed
        });
        // Fix a deterministic parameter order: each function's statics sorted by DefId.
        // (The inner `statics` shadows the outer map only inside this closure.)
        let fn_statics = fns.into_iter().map(|(k, v)| {
            let mut statics = v.static_refs.into_iter().collect::<Vec<_>>();
            statics.sort();
            (k, statics)
        }).collect::<HashMap<_, _>>();
        // Pass 3: rewrite signatures, static uses and call sites.
        mut_visit_fns(krate, |fl| {
            let fn_def_id = cx.node_def_id(fl.id);
            if let Some(static_ids) = fn_statics.get(&fn_def_id) {
                // `user` function: append `<name>_: &[mut] T` for each needed static.
                for &static_id in static_ids {
                    let info = &statics[&static_id];
                    fl.decl.inputs.push(mk().arg(
                        mk().set_mutbl(info.mutbl).ref_ty(&info.ty),
                        mk().ident_pat(info.arg_name)));
                }
                // Replace each expression resolving to a marked static with `*<name>_`.
                MutVisitNodes::visit(&mut fl.block, |e: &mut P<Expr>| {
                    if let Some(def_id) = cx.try_resolve_expr(&e) {
                        if let Some(info) = statics.get(&def_id) {
                            *e = mk().unary_expr("*", mk().ident_expr(info.arg_name));
                            return;
                        }
                    }
                });
                // Calls to other `user` functions forward this function's own parameters.
                // The propagation above means the caller's static set covers the callee's,
                // so every forwarded `<name>_` is in scope.
                MutVisitNodes::visit(&mut fl.block, |e: &mut P<Expr>| {
                    if let ExprKind::Call(func, args) = &mut e.node {
                        if let Some(func_id) = cx.try_resolve_expr(&func) {
                            if let Some(func_static_ids) = fn_statics.get(&func_id) {
                                for &static_id in func_static_ids {
                                    args.push(mk().ident_expr(statics[&static_id].arg_name));
                                }
                            }
                        }
                    }
                });
            } else {
                // Any other function: calls to `user` functions borrow the statics
                // themselves as `&STATIC` / `&mut STATIC`.
                MutVisitNodes::visit(&mut fl.block, |e: &mut P<Expr>| {
                    if let ExprKind::Call(func, args) = &mut e.node {
                        if let Some(func_id) = cx.try_resolve_expr(&func) {
                            if let Some(func_static_ids) = fn_statics.get(&func_id) {
                                for &static_id in func_static_ids {
                                    let info = &statics[&static_id];
                                    args.push(mk().set_mutbl(info.mutbl).addr_of_expr(
                                        mk().ident_expr(info.name)));
                                }
                            }
                        }
                    }
                });
            }
        });
    }
}
/// `static_to_local` transform: deletes `target`-marked statics and re-creates each one
/// as a `let` binding at the top of every function body that refers to it.
struct StaticToLocal;
impl Transform for StaticToLocal {
    /// Deletes every `target`-marked `static` item. In each function whose body refers to
    /// one of them, it prepends `let [mut] NAME: TY = INIT;`, reusing the static's own
    /// name, type, mutability and initializer, so the body uses the local instead.
    fn transform(&self, krate: &mut Crate, st: &CommandState, cx: &RefactorCtxt) {
        // What is needed to re-create a removed static as a local `let`.
        struct StaticInfo {
            name: Ident,
            ty: P<Ty>,
            mutbl: Mutability,
            expr: P<Expr>,
        }

        // Remove the marked statics from the crate, remembering them by DefId.
        let mut statics = HashMap::new();
        FlatMapNodes::visit(krate, |i: P<Item>| {
            if !st.marked(i.id, "target") {
                return smallvec![i];
            }
            if let ItemKind::Static(ref ty, mutbl, ref expr) = i.node {
                let def_id = cx.node_def_id(i.id);
                statics.insert(def_id, StaticInfo {
                    name: i.ident.clone(),
                    ty: ty.clone(),
                    mutbl,
                    expr: expr.clone(),
                });
                // Returning no items drops the static from its module.
                return smallvec![];
            }
            smallvec![i]
        });

        mut_visit_fns(krate, |fl| {
            // Removed statics this body refers to, each listed once.
            let mut seen_ids = HashSet::new();
            let mut used = Vec::new();
            fold_resolved_paths(&mut fl.block, cx, |qself, path, def| {
                if let Some(def_id) = def.opt_def_id() {
                    if seen_ids.insert(def_id) {
                        if let Some(info) = statics.get(&def_id) {
                            used.push(info);
                        }
                    }
                }
                (qself, path)
            });
            if used.is_empty() {
                return;
            }

            // Emit the bindings in a fixed order keyed by each static's name `Symbol`,
            // independent of where in the body each one is first used.
            used.sort_by_key(|info| info.name.name);

            if let Some(block) = &mut fl.block {
                let new_stmts = Vec::with_capacity(used.len() + block.stmts.len());
                let old_stmts = mem::replace(&mut block.stmts, new_stmts);
                for &info in &used {
                    let pat = mk().set_mutbl(info.mutbl).ident_pat(info.name);
                    let local = mk().local(pat, Some(info.ty.clone()), Some(info.expr.clone()));
                    block.stmts.push(mk().local_stmt(P(local)));
                }
                // The original statements follow the new bindings, unchanged.
                block.stmts.extend(old_stmts);
            }
        });
    }
}
/// Registers this module's commands with `reg`:
/// `static_collect_to_struct STRUCT_NAME INSTANCE_NAME`, `static_to_local_ref`,
/// and `static_to_local`.
pub fn register_commands(reg: &mut Registry) {
    use super::mk;

    reg.register("static_collect_to_struct", |args| {
        let struct_name = args[0].clone();
        let instance_name = args[1].clone();
        mk(CollectToStruct { struct_name, instance_name })
    });
    reg.register("static_to_local_ref", |_args| mk(Localize));
    reg.register("static_to_local", |_args| mk(StaticToLocal));
}