import sys
import json
import textwrap
import re
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional, Dict, Any
import asdl
# Number of spaces emitted per indentation level of generated code.
TABSIZE = 4
# Header template for generated files; contains a single `{}` placeholder.
AUTO_GEN_MESSAGE = "// File automatically generated by {}.\n\n"

# Rust type name used for each ASDL builtin type.
BUILTIN_TYPE_NAMES = {
    "identifier": "Identifier",
    "string": "String",
    "int": "Int",
    "constant": "Constant",
}
# Every ASDL builtin must have a Rust name and vice versa.
assert BUILTIN_TYPE_NAMES.keys() == asdl.builtin_types

# Fields typed `int` in ASDL that get a more specific Rust type, keyed by field name.
BUILTIN_INT_NAMES = {
    "simple": "bool",
    "is_async": "bool",
    "conversion": "ConversionFlag",
}

# ASDL names that are rewritten before being converted to Rust names.
RENAME_MAP = {
    "cmpop": "cmp_op",
    "unaryop": "unary_op",
    "boolop": "bool_op",
    "excepthandler": "except_handler",
    "withitem": "with_item",
}

# Identifiers that cannot be used verbatim as Rust field/method names.
RUST_KEYWORDS = {
    "if",
    "while",
    "for",
    "return",
    "match",
    "try",
    "await",
    "yield",
    "in",
    "mod",
    "type",
}

# The four `int` location attributes.
attributes = [
    asdl.Field("int", attr_name)
    for attr_name in ("lineno", "col_offset", "end_lineno", "end_col_offset")
]

ORIGINAL_NODE_WARNING = "NOTE: This type is different from original Python AST."

# Custom product type: an `arg` paired with an optional default expression.
arg_with_default = asdl.Type(
    "arg_with_default",
    asdl.Product(
        [
            asdl.Field("arg", "def"),
            asdl.Field("expr", "default", opt=True),
        ],
    ),
)
arg_with_default.doc = f"""
An alternative type of AST `arg`. This is used for each function argument that might have a default value.
Used by `Arguments` original type.
{ORIGINAL_NODE_WARNING}
""".strip()

# Custom replacement for the ASDL `arguments` type, built from `arg_with_default`.
alt_arguments = asdl.Type(
    "alt:arguments",
    asdl.Product(
        [
            asdl.Field("arg_with_default", "posonlyargs", seq=True),
            asdl.Field("arg_with_default", "args", seq=True),
            asdl.Field("arg", "vararg", opt=True),
            asdl.Field("arg_with_default", "kwonlyargs", seq=True),
            asdl.Field("arg", "kwarg", opt=True),
        ]
    ),
)
alt_arguments.doc = f"""
An alternative type of AST `arguments`. This is parser-friendly and human-friendly definition of function arguments.
This form also has advantage to implement pre-order traverse.
`defaults` and `kw_defaults` fields are removed and the default values are placed under each `arg_with_default` typed argument.
`vararg` and `kwarg` are still typed as `arg` because they never can have a default value.
The matching Python style AST type is [PythonArguments]. While [PythonArguments] has ordered `kwonlyargs` fields by
default existence, [Arguments] has location-ordered kwonlyargs fields.
{ORIGINAL_NODE_WARNING}
""".strip()

# All custom types defined by this generator (visited in addition to the ASDL module's own).
CUSTOM_TYPES = [
    alt_arguments,
    arg_with_default,
]

# ASDL type name -> custom type that stands in for it.
CUSTOM_REPLACEMENTS = {
    "arguments": alt_arguments,
}

# Custom types that are added alongside the ASDL definitions rather than replacing one.
CUSTOM_ATTACHMENTS = [
    arg_with_default,
]
def maybe_custom(type):
    """Return the custom replacement registered under ``type.name``, or ``type`` itself."""
    try:
        return CUSTOM_REPLACEMENTS[type.name]
    except KeyError:
        return type
def rust_field_name(name):
    """Return the snake_case form of the Rust type name derived from ``name``."""
    type_name = rust_type_name(name)
    # Put "_" in front of every uppercase letter except one at the very start,
    # then lowercase the whole thing.
    with_separators = re.sub(r"(?<!^)(?=[A-Z])", "_", type_name)
    return with_separators.lower()
def rust_type_name(name):
    """Return the Rust type name for an ASDL type name.

    ``RENAME_MAP`` is applied first.  Builtin ASDL types map through
    ``BUILTIN_TYPE_NAMES``; all-lowercase names are converted from
    snake_case to CamelCase; anything else is returned unchanged.
    """
    renamed = RENAME_MAP.get(name, name)
    if renamed in asdl.builtin_types:
        return BUILTIN_TYPE_NAMES[renamed]
    if renamed.islower():
        pieces = renamed.split("_")
        return "".join(map(str.capitalize, pieces))
    return renamed
def is_simple(sum):
    """Return True if no constructor in ``sum.types`` has a non-empty ``fields``."""
    return not any(constructor.fields for constructor in sum.types)
def asdl_of(name, obj):
    """Render ``obj`` back into ASDL-like source text under the given ``name``.

    Products and constructors render as ``name(field, ...)`` (or just ``name``
    when they have no fields).  Anything else is treated as a sum and renders
    as ``name = A | B`` for simple sums, or one constructor per line otherwise.
    """
    if isinstance(obj, (asdl.Product, asdl.Constructor)):
        rendered_fields = ", ".join(str(field) for field in obj.fields)
        if rendered_fields:
            rendered_fields = "({})".format(rendered_fields)
        return "{}{}".format(name, rendered_fields)

    if is_simple(obj):
        alternatives = " | ".join(variant.name for variant in obj.types)
    else:
        # Align each continuation "|" under the "=" of the first line.
        padding = " " * (len(name) + 1)
        separator = "\n{}| ".format(padding)
        alternatives = separator.join(
            asdl_of(variant.name, variant) for variant in obj.types
        )
    return "{} = {}".format(name, alternatives)
class TypeInfo:
    """Per-name record describing one ASDL type or constructor.

    Instances are created and filled in by ``FindUserDataTypesVisitor`` and
    then consulted by the emitting visitors.
    """

    # The wrapped ASDL node (an ``asdl.Type``, or a constructor for sum variants).
    type: asdl.Type
    # Name of the enclosing sum type when this record describes a constructor.
    enum_name: Optional[str]
    # Tri-state: None until determined, otherwise True/False.
    has_user_data: Optional[bool]
    # True when the sum/product declares attributes.
    has_attributes: bool
    # True for sums whose constructors carry no fields (and for their constructors).
    is_simple: bool
    # Set of ``(field.type, field.seq)`` pairs collected from the fields.
    children: set
    fields: Optional[Any]
    # Set by FindUserDataTypesVisitor for sums with >1 constructor and products with >2 fields.
    boxed: bool

    def __init__(self, type):
        self.type = type
        self.enum_name = None
        self.has_user_data = None
        self.has_attributes = False
        self.is_simple = False
        self.children = set()
        self.fields = None
        self.boxed = False

    def __repr__(self):
        return f"<TypeInfo: {self.name}>"

    @property
    def name(self):
        """Name of the wrapped ASDL node."""
        return self.type.name

    @property
    def is_type(self):
        """True when the wrapped node is an ``asdl.Type``."""
        return isinstance(self.type, asdl.Type)

    @property
    def is_product(self):
        """True when this is an ``asdl.Type`` whose value is a product."""
        if not self.is_type:
            return False
        return isinstance(self.type.value, asdl.Product)

    @property
    def is_sum(self):
        """True when this is an ``asdl.Type`` whose value is a sum."""
        if not self.is_type:
            return False
        return isinstance(self.type.value, asdl.Sum)

    @property
    def has_expr(self):
        """True for a product with at least one field whose type is not ``identifier``."""
        if not self.is_product:
            return False
        return any(
            field.type != "identifier" for field in self.type.value.fields
        )

    @property
    def is_custom(self):
        """True when the name matches one of ``CUSTOM_TYPES``."""
        custom_names = [custom.name for custom in CUSTOM_TYPES]
        return self.type.name in custom_names

    @property
    def is_custom_replaced(self):
        """True when a custom replacement is registered for this name."""
        return self.type.name in CUSTOM_REPLACEMENTS

    @property
    def custom(self):
        """The registered custom replacement, or the wrapped node when there is none."""
        return CUSTOM_REPLACEMENTS.get(self.type.name, self.type)

    def no_cfg(self, typeinfo):
        """Return the ``has_attributes`` flag that applies to this entry.

        For a constructor (``enum_name`` set, not a product) the flag of the
        enclosing sum, looked up in ``typeinfo``, is used; otherwise this
        entry's own flag.
        """
        if not self.is_product and self.enum_name:
            return typeinfo[self.enum_name].has_attributes
        return self.has_attributes

    @property
    def rust_name(self):
        """Rust type name derived from the bare ASDL name."""
        return rust_type_name(self.name)

    @property
    def full_field_name(self):
        """snake_case name, prefixed with the enum name for constructors.

        A leading ``alt:`` marker on the name is dropped first.
        """
        base = self.name
        if base.startswith("alt:"):
            base = base[4:]
        if self.enum_name is not None:
            return f"{self.enum_name}_{rust_field_name(base)}"
        return base

    @property
    def full_type_name(self):
        """Rust type name, prefixed with the enum's Rust name for constructors.

        A leading ``alt:`` marker is dropped; names that have a custom
        replacement registered additionally get a ``Python`` prefix.
        """
        base = self.name
        if base.startswith("alt:"):
            base = base[4:]
        result = rust_type_name(base)
        if self.enum_name is not None:
            result = rust_type_name(self.enum_name) + result
        if self.is_custom_replaced:
            result = "Python" + result
        return result

    def determine_user_data(self, type_info, stack):
        """Resolve ``has_user_data`` by walking the children recursively.

        ``stack`` holds the names currently being visited; re-entering a name
        already on it returns None so cycles terminate.  The flag is only
        promoted from None to True when some non-builtin child reports True.
        Returns the (possibly still None) ``has_user_data`` value.
        """
        if self.name in stack:
            return None
        stack.add(self.name)
        for child, _child_seq in self.children:
            if child in asdl.builtin_types:
                continue
            child_result = type_info[child].determine_user_data(type_info, stack)
            if self.has_user_data is None and child_result is True:
                self.has_user_data = True
        stack.remove(self.name)
        return self.has_user_data
class TypeInfoMixin:
    """Helpers for classes that carry a ``type_info`` mapping of name -> TypeInfo."""

    type_info: Dict[str, TypeInfo]

    def customized_type_info(self, type_name):
        """Return the TypeInfo of the custom replacement for ``type_name`` (or of itself)."""
        original = self.type_info[type_name]
        return self.type_info[original.custom.name]

    def has_user_data(self, typ):
        """Return the ``has_user_data`` flag recorded for ``typ``."""
        return self.type_info[typ].has_user_data

    def apply_generics(self, typ, *generics):
        """Return one string per entry of ``generics``.

        Each is ``<g>`` when ``typ`` is not simple, and the empty string when it is.
        """
        if self.type_info[typ].is_simple:
            return ["" for _ in generics]
        return [f"<{generic}>" for generic in generics]
class EmitVisitor(asdl.VisitorBase, TypeInfoMixin):
    """Base visitor that writes generated lines to ``file``."""

    def __init__(self, file, type_info):
        self.file = file
        self.type_info = type_info
        # Identifiers already written by emit_identifier.
        self.identifiers = set()
        super(EmitVisitor, self).__init__()

    def emit_identifier(self, name):
        """Emit a ``_Py_IDENTIFIER`` line for ``name`` unless one was already emitted."""
        text = str(name)
        if text not in self.identifiers:
            self.emit("_Py_IDENTIFIER(%s);" % text, 0)
            self.identifiers.add(text)

    def emit(self, line, depth):
        """Write ``line`` followed by a newline.

        A non-empty ``line`` is passed through ``textwrap.dedent`` and prefixed
        with ``TABSIZE * depth`` spaces; an empty one produces a bare newline.
        """
        if line:
            prefix = " " * TABSIZE * depth
            line = prefix + textwrap.dedent(line)
        self.file.write(line + "\n")
class FindUserDataTypesVisitor(asdl.VisitorBase):
    """Populate ``type_info`` with a TypeInfo for every type and constructor."""

    def __init__(self, type_info):
        self.type_info = type_info
        super().__init__()

    def visitModule(self, mod):
        """Visit all definitions (plus the custom types), then resolve user-data flags."""
        for definition in mod.dfns + CUSTOM_TYPES:
            self.visit(definition)
        in_progress = set()
        for entry in self.type_info.values():
            entry.determine_user_data(self.type_info, in_progress)

    def visitType(self, type):
        """Register a TypeInfo for ``type`` and descend into its value."""
        info = TypeInfo(type)
        self.type_info[type.name] = info
        self.visit(type.value, info)

    def visitSum(self, sum, info):
        """Record simplicity, boxing, attributes and children for a sum type."""
        type = info.type
        info.is_simple = is_simple(sum)
        for constructor in sum.types:
            self.visit(constructor, type, info.is_simple)

        if info.is_simple:
            info.has_user_data = False
            return

        # Children of each constructor, recorded under the constructor's own name.
        for constructor in sum.types:
            self.add_children(constructor.name, constructor.fields)

        if len(sum.types) > 1:
            info.boxed = True
        if sum.attributes:
            info.has_user_data = True
            info.has_attributes = True

        # The same fields are also recorded under the sum's name.
        for constructor in sum.types:
            self.add_children(type.name, constructor.fields)

    def visitConstructor(self, cons, type, simple):
        """Register a TypeInfo for constructor ``cons`` of sum ``type``."""
        info = TypeInfo(cons)
        self.type_info[cons.name] = info
        info.enum_name = type.name
        info.is_simple = simple

    def visitProduct(self, product, info):
        """Record attributes, boxing and children for a product type."""
        type = info.type
        if product.attributes:
            info.has_user_data = True
            info.has_attributes = True
        if len(product.fields) > 2:
            info.boxed = True
        self.add_children(type.name, product.fields)

    def add_children(self, name, fields):
        """Add ``(field.type, field.seq)`` for each field to ``name``'s children set."""
        pairs = ((field.type, field.seq) for field in fields)
        self.type_info[name].children.update(pairs)
def rust_field(field_name):
    """Return ``field_name``, with ``_`` appended if it is listed in ``RUST_KEYWORDS``."""
    return field_name + "_" if field_name in RUST_KEYWORDS else field_name
class StructVisitor(EmitVisitor):
    """Emit the Rust ``Ast`` enum and a struct/enum plus ``Node`` impl per ASDL type."""

    def __init__(self, *args, **kw):
        super().__init__(*args, **kw)

    def emit_attrs(self, depth):
        """Emit the derive attribute shared by all generated items."""
        self.emit("#[derive(Clone, Debug, PartialEq)]", depth)

    def emit_range(self, has_attributes, depth):
        """Emit the ``range`` field one level deeper than ``depth``.

        The field is typed ``R`` when ``has_attributes`` is truthy and
        ``OptionalRange<R>`` otherwise.
        """
        if has_attributes:
            field_line = "pub range: R,"
        else:
            field_line = "pub range: OptionalRange<R>,"
        self.emit(field_line, depth + 1)

    def visitModule(self, mod):
        """Emit the ``Ast`` enum, its ``From`` impls, then every type definition."""
        self.emit_attrs(0)
        self.emit(
            """
#[derive(is_macro::Is)]
pub enum Ast<R=TextRange> {
""",
            0,
        )
        for dfn in mod.dfns:
            info = self.customized_type_info(dfn.name)
            custom_dfn = info.custom
            rust_name = info.full_type_name
            generics = "" if self.type_info[custom_dfn.name].is_simple else "<R>"
            if custom_dfn.name == "mod":
                self.emit('#[is(name = "module")]', 1)
            self.emit(f"{rust_name}({rust_name}{generics}),", 1)
        self.emit(
            """
}
impl<R> Node for Ast<R> {
const NAME: &'static str = "AST";
const FIELD_NAMES: &'static [&'static str] = &[];
}
""",
            0,
        )
        for dfn in mod.dfns:
            rust_name = self.customized_type_info(dfn.name).full_type_name
            generics = "" if self.type_info[dfn.name].is_simple else "<R>"
            self.emit(
                f"""
impl<R> From<{rust_name}{generics}> for Ast<R> {{
fn from(node: {rust_name}{generics}) -> Self {{
Ast::{rust_name}(node)
}}
}}
""",
                0,
            )
        for dfn in mod.dfns + CUSTOM_TYPES:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        """Emit a doc comment for ``type`` and then its definition."""
        if hasattr(type, "doc"):
            # Prefix every line of the custom doc text with "/// ".
            doc = "/// " + "\n/// ".join(type.doc.split("\n")) + "\n"
        else:
            doc = f"/// See also [{type.name}](https://docs.python.org/3/library/ast.html#ast.{type.name})"
        self.emit(doc, depth)
        self.visit(type.value, type, depth)

    def visitSum(self, sum, type, depth):
        """Emit the enum for a sum type followed by its ``Node`` impl."""
        emitter = self.simple_sum if is_simple(sum) else self.sum_with_constructors
        emitter(sum, type, depth)
        (generics_applied,) = self.apply_generics(type.name, "R")
        self.emit(
            f"""
impl{generics_applied} Node for {rust_type_name(type.name)}{generics_applied} {{
const NAME: &'static str = "{type.name}";
const FIELD_NAMES: &'static [&'static str] = &[];
}}
""",
            depth,
        )

    def simple_sum(self, sum, type, depth):
        """Emit a field-less enum, accessor methods, and a unit struct per variant."""
        rust_name = rust_type_name(type.name)
        self.emit_attrs(depth)
        self.emit("#[derive(is_macro::Is, Copy, Hash, Eq)]", depth)
        self.emit(f"pub enum {rust_name} {{", depth)
        for cons in sum.types:
            self.emit(f"{cons.name},", depth + 1)
        self.emit("}", depth)

        self.emit(f"impl {rust_name} {{", depth)
        # If any accessor name would be a Rust keyword, prefix all of them.
        needs_escape = any(rust_field_name(t.name) in RUST_KEYWORDS for t in sum.types)
        prefix = rust_field_name(type.name) + "_" if needs_escape else ""
        for cons in sum.types:
            self.emit(
                f"""
#[inline]
pub const fn {prefix}{rust_field_name(cons.name)}(&self) -> Option<{rust_name}{cons.name}> {{
match self {{
{rust_name}::{cons.name} => Some({rust_name}{cons.name}),
_ => None,
}}
}}
""",
                depth,
            )
        self.emit("}", depth)
        self.emit("", depth)

        for cons in sum.types:
            self.emit(
                f"""
pub struct {rust_name}{cons.name};
impl From<{rust_name}{cons.name}> for {rust_name} {{
fn from(_: {rust_name}{cons.name}) -> Self {{
{rust_name}::{cons.name}
}}
}}
impl<R> From<{rust_name}{cons.name}> for Ast<R> {{
fn from(_: {rust_name}{cons.name}) -> Self {{
{rust_name}::{cons.name}.into()
}}
}}
impl Node for {rust_name}{cons.name} {{
const NAME: &'static str = "{cons.name}";
const FIELD_NAMES: &'static [&'static str] = &[];
}}
impl std::cmp::PartialEq<{rust_name}> for {rust_name}{cons.name} {{
#[inline]
fn eq(&self, other: &{rust_name}) -> bool {{
matches!(other, {rust_name}::{cons.name})
}}
}}
""",
                0,
            )

    def sum_with_constructors(self, sum, type, depth):
        """Emit a generic enum wrapping one payload struct per constructor."""
        type_info = self.type_info[type.name]
        rust_name = rust_type_name(type.name)
        self.emit_attrs(depth)
        self.emit("#[derive(is_macro::Is)]", depth)
        self.emit(f"pub enum {rust_name}<R = TextRange> {{", depth)
        needs_escape = any(rust_field_name(t.name) in RUST_KEYWORDS for t in sum.types)
        for variant in sum.types:
            if needs_escape:
                self.emit(
                    f'#[is(name = "{rust_field_name(variant.name)}_{rust_name.lower()}")]',
                    depth + 1,
                )
            self.emit(f"{variant.name}({rust_name}{variant.name}<R>),", depth + 1)
        self.emit("}", depth)
        self.emit("", depth)
        for variant in sum.types:
            self.sum_subtype_struct(type_info, variant, rust_name, depth)

    def sum_subtype_struct(self, sum_type_info, t, rust_name, depth):
        """Emit the payload struct for constructor ``t`` and its ``Node``/``From`` impls."""
        self.emit(
            f"/// See also [{t.name}](https://docs.python.org/3/library/ast.html#ast.{t.name})",
            depth,
        )
        self.emit_attrs(depth)
        payload_name = f"{rust_name}{t.name}"
        self.emit(f"pub struct {payload_name}<R = TextRange> {{", depth)
        self.emit_range(sum_type_info.has_attributes, depth)
        for field in t.fields:
            self.visit(field, sum_type_info, "pub ", depth + 1, t.name)
        # Sanity check: the constructor's effective flag must match its sum's.
        assert sum_type_info.has_attributes == self.type_info[t.name].no_cfg(
            self.type_info
        )
        self.emit("}", depth)
        field_names = [f'"{f.name}"' for f in t.fields]
        self.emit(
            f"""
impl<R> Node for {payload_name}<R> {{
const NAME: &'static str = "{t.name}";
const FIELD_NAMES: &'static [&'static str] = &[{', '.join(field_names)}];
}}
impl<R> From<{payload_name}<R>> for {rust_name}<R> {{
fn from(payload: {payload_name}<R>) -> Self {{
{rust_name}::{t.name}(payload)
}}
}}
impl<R> From<{payload_name}<R>> for Ast<R> {{
fn from(payload: {payload_name}<R>) -> Self {{
{rust_name}::from(payload).into()
}}
}}
""",
            depth,
        )
        self.emit("", depth)

    def visitConstructor(self, cons, parent, depth):
        """Emit ``cons`` as an enum variant, with a braced field list if it has fields."""
        if not cons.fields:
            self.emit(f"{cons.name},", depth)
            return
        self.emit(f"{cons.name} {{", depth)
        for field in cons.fields:
            self.visit(field, parent, "", depth + 1, cons.name)
        self.emit("},", depth)

    def visitField(self, field, parent, vis, depth, constructor=None):
        """Emit one struct field as ``{vis}{name}: {type},``.

        The Rust type is built inside-out: generic parameter, then ``Box``,
        then ``Option``, then ``Vec``.
        """
        try:
            field_type = self.customized_type_info(field.type)
            typ = field_type.full_type_name
        except KeyError:
            # No TypeInfo for this name: fall back to the plain name mapping.
            field_type = None
            typ = rust_type_name(field.type)

        if field_type and not field_type.is_simple:
            typ = f"{typ}<R>"

        if (
            field_type
            and field_type.boxed
            and ((not parent.is_product and not field.seq) or field.opt)
        ):
            typ = f"Box<{typ}>"

        # The `keys` field of the `Dict` constructor is wrapped in Option as well.
        if field.opt or (constructor == "Dict" and field.name == "keys"):
            typ = f"Option<{typ}>"

        if field.seq:
            typ = f"Vec<{typ}>"

        if typ == "Int":
            typ = BUILTIN_INT_NAMES.get(field.name, typ)

        self.emit(f"{vis}{rust_field(field.name)}: {typ},", depth)

    def visitProduct(self, product, type, depth):
        """Emit the struct for a product type and its ``Node`` impl."""
        type_info = self.type_info[type.name]
        product_name = type_info.full_type_name
        self.emit_attrs(depth)
        self.emit(f"pub struct {product_name}<R = TextRange> {{", depth)
        self.emit_range(product.attributes, depth + 1)
        for field in product.fields:
            self.visit(field, type_info, "pub ", depth + 1)
        assert bool(product.attributes) == type_info.no_cfg(self.type_info)
        self.emit("}", depth)
        field_names = [f'"{f.name}"' for f in product.fields]
        self.emit(
            f"""
impl<R> Node for {product_name}<R> {{
const NAME: &'static str = "{type.name}";
const FIELD_NAMES: &'static [&'static str] = &[
{', '.join(field_names)}
];
}}
""",
            depth,
        )
class FoldTraitDefVisitor(EmitVisitor):
    """Emit the Rust ``Fold`` trait with one ``fold_*`` method per type and constructor."""

    def visitModule(self, mod, depth):
        """Emit the trait header, the user-mapping hooks, and every ``fold_*`` method."""
        self.emit("pub trait Fold<U> {", depth)
        for associated_type in ("type TargetU;", "type Error;", "type UserContext;"):
            self.emit(associated_type, depth + 1)
        self.emit(
            """
fn will_map_user(&mut self, user: &U) -> Self::UserContext;
#[cfg(feature = "all-nodes-with-ranges")]
fn will_map_user_cfg(&mut self, user: &U) -> Self::UserContext {
self.will_map_user(user)
}
#[cfg(not(feature = "all-nodes-with-ranges"))]
fn will_map_user_cfg(&mut self, _user: &crate::EmptyRange<U>) -> crate::EmptyRange<Self::TargetU> {
crate::EmptyRange::default()
}
fn map_user(&mut self, user: U, context: Self::UserContext) -> Result<Self::TargetU, Self::Error>;
#[cfg(feature = "all-nodes-with-ranges")]
fn map_user_cfg(&mut self, user: U, context: Self::UserContext) -> Result<Self::TargetU, Self::Error> {
self.map_user(user, context)
}
#[cfg(not(feature = "all-nodes-with-ranges"))]
fn map_user_cfg(
&mut self,
_user: crate::EmptyRange<U>,
_context: crate::EmptyRange<Self::TargetU>,
) -> Result<crate::EmptyRange<Self::TargetU>, Self::Error> {
Ok(crate::EmptyRange::default())
}
""",
            depth + 1,
        )
        self.emit(
            """
fn fold<X: Foldable<U, Self::TargetU>>(&mut self, node: X) -> Result<X::Mapped, Self::Error> {
node.fold(self)
}""",
            depth + 1,
        )
        for dfn in mod.dfns + [arg_with_default]:
            self.visit(maybe_custom(dfn), depth + 2)
        self.emit("}", depth)

    def visitType(self, type, depth):
        """Emit ``fold_<type>`` and, for non-simple sums, one method per constructor."""
        info = self.type_info[type.name]
        apply_u, apply_target_u = self.apply_generics(info.name, "U", "Self::TargetU")
        enum_name = info.full_type_name
        self.emit(
            f"fn fold_{info.full_field_name}(&mut self, node: {enum_name}{apply_u}) -> Result<{enum_name}{apply_target_u}, Self::Error> {{",
            depth,
        )
        # The default body delegates to the free function of the same name.
        self.emit(f"fold_{info.full_field_name}(self, node)", depth + 1)
        self.emit("}", depth)

        value = type.value
        if isinstance(value, asdl.Sum) and not is_simple(value):
            for cons in value.types:
                self.visit(cons, type, depth)

    def visitConstructor(self, cons, type, depth):
        """Emit ``fold_<type>_<constructor>`` delegating to the free function."""
        info = self.type_info[type.name]
        apply_u, apply_target_u = self.apply_generics(type.name, "U", "Self::TargetU")
        enum_name = rust_type_name(type.name)
        func_name = f"fold_{info.full_field_name}_{rust_field_name(cons.name)}"
        self.emit(
            f"fn {func_name}(&mut self, node: {enum_name}{cons.name}{apply_u}) -> Result<{enum_name}{cons.name}{apply_target_u}, Self::Error> {{",
            depth,
        )
        self.emit(f"{func_name}(self, node)", depth + 1)
        self.emit("}", depth)
class FoldImplVisitor(EmitVisitor):
    """Emit ``Foldable`` impls and free ``fold_*`` functions for every type."""

    def visitModule(self, mod, depth):
        for dfn in mod.dfns + [arg_with_default]:
            self.visit(maybe_custom(dfn), depth)

    def visitType(self, type, depth=0):
        self.visit(type.value, type, depth)

    def _emit_foldable_impl(self, target, apply_t, apply_u, fold_call, depth):
        """Emit ``impl Foldable for target`` whose ``fold`` body is the line ``fold_call``."""
        self.emit(f"impl<T, U> Foldable<T, U> for {target}{apply_t} {{", depth)
        self.emit(f"type Mapped = {target}{apply_u};", depth + 1)
        self.emit(
            "fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {",
            depth + 1,
        )
        self.emit(fold_call, depth + 2)
        self.emit("}", depth + 1)
        self.emit("}", depth)

    def visitSum(self, sum, type, depth):
        """Emit the impl and ``fold_<name>`` for a sum, then handle its constructors.

        For a simple sum the function body just returns the node and the
        constructors are not visited.
        """
        name = type.name
        apply_t, apply_u, apply_target_u = self.apply_generics(
            name, "T", "U", "F::TargetU"
        )
        enum_name = rust_type_name(name)
        self._emit_foldable_impl(
            enum_name, apply_t, apply_u, f"folder.fold_{name}(self)", depth
        )
        self.emit(
            f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {enum_name}{apply_u}) -> Result<{enum_name}{apply_target_u}, F::Error> {{",
            depth,
        )
        if is_simple(sum):
            self.emit("Ok(node) }", depth + 1)
            return

        self.emit("let folded = match node {", depth + 1)
        for cons in sum.types:
            self.emit(
                f"{enum_name}::{cons.name}(cons) => {enum_name}::{cons.name}(Foldable::fold(cons, folder)?),",
                depth + 1,
            )
        self.emit("};", depth + 1)
        self.emit("Ok(folded)", depth + 1)
        self.emit("}", depth)

        for cons in sum.types:
            self.visit(cons, type, depth)

    def visitConstructor(self, cons, type, depth):
        """Emit the impl and ``fold_<type>_<constructor>`` for one constructor."""
        apply_t, apply_u, apply_target_u = self.apply_generics(
            type.name, "T", "U", "F::TargetU"
        )
        info = self.type_info[type.name]
        enum_name = info.full_type_name
        cons_type_name = f"{enum_name}{cons.name}"
        self._emit_foldable_impl(
            cons_type_name,
            apply_t,
            apply_u,
            f"folder.fold_{info.full_field_name}_{rust_field_name(cons.name)}(self)",
            depth,
        )
        self.emit(
            f"pub fn fold_{info.full_field_name}_{rust_field_name(cons.name)}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {cons_type_name}{apply_u}) -> Result<{enum_name}{cons.name}{apply_target_u}, F::Error> {{",
            depth,
        )

        fields_pattern = self.make_pattern(cons.fields)
        # The `_cfg` hook variants are used when the sum has no attributes.
        map_user_suffix = "" if info.has_attributes else "_cfg"
        inner = depth + 3
        self.emit(
            f"""
let {cons_type_name} {{ {fields_pattern} }} = node;
let context = folder.will_map_user{map_user_suffix}(&range);
""",
            inner,
        )
        self.fold_fields(cons.fields, inner)
        self.emit(
            f"let range = folder.map_user{map_user_suffix}(range, context)?;",
            inner,
        )
        self.composite_fields(f"{cons_type_name}", cons.fields, inner)
        self.emit("}", depth + 2)

    def visitProduct(self, product, type, depth):
        """Emit the impl and ``fold_<name>`` for a product type."""
        info = self.type_info[type.name]
        name = type.name
        apply_t, apply_u, apply_target_u = self.apply_generics(
            name, "T", "U", "F::TargetU"
        )
        struct_name = info.full_type_name
        has_attributes = bool(product.attributes)
        self._emit_foldable_impl(
            struct_name,
            apply_t,
            apply_u,
            f"folder.fold_{info.full_field_name}(self)",
            depth,
        )
        self.emit(
            f"pub fn fold_{info.full_field_name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {struct_name}{apply_u}) -> Result<{struct_name}{apply_target_u}, F::Error> {{",
            depth,
        )

        fields_pattern = self.make_pattern(product.fields)
        self.emit(f"let {struct_name} {{ {fields_pattern} }} = node;", depth + 1)
        map_user_suffix = "" if has_attributes else "_cfg"
        self.emit(
            f"let context = folder.will_map_user{map_user_suffix}(&range);", depth + 3
        )
        self.fold_fields(product.fields, depth + 1)
        self.emit(
            f"let range = folder.map_user{map_user_suffix}(range, context)?;", depth + 3
        )
        self.composite_fields(struct_name, product.fields, depth + 1)
        self.emit("}", depth)

    def make_pattern(self, fields):
        """Return the destructuring pattern: the field names, then ``range``, comma-joined."""
        names = ",".join(rust_field(f.name) for f in fields)
        if names:
            names += ","
        return names + "range"

    def fold_fields(self, fields, depth):
        """Emit a ``let x = Foldable::fold(x, folder)?;`` line for every field."""
        for field in fields:
            ident = rust_field(field.name)
            self.emit(f"let {ident} = Foldable::fold({ident}, folder)?;", depth + 1)

    def composite_fields(self, header, fields, depth):
        """Emit ``Ok(header { fields..., range, })`` rebuilding the node."""
        self.emit(f"Ok({header} {{", depth)
        for field in fields:
            self.emit(f"{rust_field(field.name)},", depth + 1)
        self.emit("range,", depth + 1)
        self.emit("})", depth)
class FoldModuleVisitor(EmitVisitor):
    """Run the fold trait-definition visitor, then the fold implementation visitor."""

    def visitModule(self, mod):
        depth = 0
        # Order matters for the output: trait definition first, impls second.
        for visitor_class in (FoldTraitDefVisitor, FoldImplVisitor):
            visitor_class(self.file, self.type_info).visit(mod, depth)
class VisitorModuleVisitor(StructVisitor):
    """Emit the Rust ``Visitor`` trait with ``visit_*`` / ``generic_visit_*`` methods."""

    def full_name(self, name):
        """Return ``name`` prefixed with ``<enum_name>_`` when it is a constructor."""
        enum_name = self.type_info[name].enum_name
        return f"{enum_name}_{name}" if enum_name else name

    def node_type_name(self, name):
        """Return the Rust type name of ``name``, prefixed by its enum's for constructors."""
        enum_name = self.type_info[name].enum_name
        if enum_name:
            return f"{rust_type_name(enum_name)}{rust_type_name(name)}"
        return rust_type_name(name)

    def visitModule(self, mod, depth=0):
        self.emit("#[allow(unused_variables)]", depth)
        self.emit("pub trait Visitor<R=crate::text_size::TextRange> {", depth)
        for dfn in mod.dfns:
            customized = self.customized_type_info(dfn.name).type
            self.visit(customized, depth + 1)
        self.emit("}", depth)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        handler = self.simple_sum if is_simple(sum) else self.sum_with_constructors
        handler(sum, name, depth)

    def emit_visitor(self, nodename, depth, has_node=True):
        """Emit ``visit_<node>``; its body calls ``generic_visit_<node>`` when ``has_node``."""
        type_info = self.type_info[nodename]
        node_type = type_info.full_type_name
        (generic,) = self.apply_generics(nodename, "R")
        self.emit(
            f"fn visit_{type_info.full_field_name}(&mut self, node: {node_type}{generic}) {{",
            depth,
        )
        if has_node:
            self.emit(
                f"self.generic_visit_{type_info.full_field_name}(node)", depth + 1
            )
        self.emit("}", depth)

    def emit_generic_visitor_signature(self, nodename, depth, has_node=True):
        """Emit the opening line of ``generic_visit_<node>`` (body is left open)."""
        type_info = self.type_info[nodename]
        node_type = type_info.full_type_name if has_node else "()"
        (generic,) = self.apply_generics(nodename, "R")
        self.emit(
            f"fn generic_visit_{type_info.full_field_name}(&mut self, node: {node_type}{generic}) {{",
            depth,
        )

    def emit_empty_generic_visitor(self, nodename, depth):
        """Emit a ``generic_visit_<node>`` with an empty body."""
        self.emit_generic_visitor_signature(nodename, depth)
        self.emit("}", depth)

    def simple_sum(self, sum, name, depth):
        self.emit_visitor(name, depth)
        self.emit_empty_generic_visitor(name, depth)

    def visit_match_for_type(self, nodename, rust_name, type_, depth):
        """Emit the match arm dispatching variant ``type_`` to its visit method."""
        self.emit(f"{rust_name}::{type_.name}", depth)
        self.emit("(data)", depth)
        self.emit(
            f"=> self.visit_{nodename}_{rust_field_name(type_.name)}(data),", depth
        )

    def visit_sum_type(self, name, type_, depth):
        """Emit visit/generic_visit for one constructor.

        The generic visitor walks every field whose type has user data,
        unwrapping ``Option``/``Vec`` wrappers and dereferencing boxed values.
        """
        self.emit_visitor(type_.name, depth, has_node=type_.fields)
        if not type_.fields:
            return

        self.emit_generic_visitor_signature(type_.name, depth, has_node=True)
        for field in type_.fields:
            replacement = CUSTOM_REPLACEMENTS.get(field.type)
            type_name = field.type if field.type not in CUSTOM_REPLACEMENTS else replacement.name
            field_name = rust_field(field.name)
            field_type = self.type_info.get(type_name)
            if not (field_type and field_type.has_user_data):
                continue

            if field.opt:
                self.emit(f"if let Some(value) = node.{field_name} {{", depth + 1)
            elif field.seq:
                iterable = f"node.{field_name}"
                if type_.name == "Dict" and field.name == "keys":
                    iterable = f"{iterable}.into_iter().flatten()"
                self.emit(f"for value in {iterable} {{", depth + 1)
            else:
                self.emit("{", depth + 1)
                self.emit(f"let value = node.{field_name};", depth + 2)

            needs_deref = field_type.boxed and (not field.seq or field.opt)
            variable = "*value" if needs_deref else "value"
            target_info = self.type_info[field_type.name]
            self.emit(
                f"self.visit_{target_info.full_field_name}({variable});", depth + 2
            )
            self.emit("}", depth + 1)

        self.emit("}", depth)

    def sum_with_constructors(self, sum, name, depth):
        """Emit visitors for a sum with attributes; sums without attributes emit nothing."""
        if not sum.attributes:
            return

        enum_name = rust_type_name(name)
        self.emit_visitor(name, depth)
        self.emit_generic_visitor_signature(name, depth)
        body_depth = depth + 1
        self.emit("match node {", body_depth)
        for variant in sum.types:
            self.visit_match_for_type(name, enum_name, variant, body_depth + 1)
        self.emit("}", body_depth)
        self.emit("}", depth)

        for variant in sum.types:
            self.visit_sum_type(name, variant, depth)

    def visitProduct(self, product, name, depth):
        self.emit_visitor(name, depth)
        self.emit_empty_generic_visitor(name, depth)
class RangedDefVisitor(EmitVisitor):
    """Emit ``TextRange``-specialised type aliases and ``Ranged`` impls."""

    def visitModule(self, mod):
        for dfn in mod.dfns + CUSTOM_TYPES:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Emit aliases for a sum and its variants, plus ``Ranged`` impls when non-simple.

        Simple sums only get aliases.  Otherwise every variant gets an alias
        and a ``Ranged`` impl, and the sum itself gets an impl that matches on
        the variant and delegates to it.
        """
        info = self.type_info[name]
        self.emit_type_alias(info)

        if info.is_simple:
            for ty in sum.types:
                self.emit_type_alias(self.type_info[ty.name])
            return

        match_arms = []
        for ty in sum.types:
            variant_info = self.type_info[ty.name]
            match_arms.append(
                f" Self::{variant_info.rust_name}(node) => node.range(),"
            )
            self.emit_type_alias(variant_info)
            self.emit_ranged_impl(variant_info)
        sum_match_arms = "".join(match_arms)

        # Without attributes the impl only exists under the feature flag.
        if not info.no_cfg(self.type_info):
            self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
        self.emit(
            f"""
impl Ranged for crate::{info.full_type_name} {{
fn range(&self) -> TextRange {{
match self {{
{sum_match_arms}
}}
}}
}}
""".lstrip(),
            0,
        )

    def visitProduct(self, product, name, depth):
        info = self.type_info[name]
        self.emit_type_alias(info)
        self.emit_ranged_impl(info)

    def emit_type_alias(self, info):
        """Emit ``pub type X = crate::generic::X[::<TextRange>];`` and a blank line."""
        # Fixed: this line used to read `return generics = ...`, which is a
        # syntax error; it must be a plain assignment so the alias is emitted.
        generics = "" if info.is_simple else "::<TextRange>"
        self.emit(
            f"pub type {info.full_type_name} = crate::generic::{info.full_type_name}{generics};",
            0,
        )
        self.emit("", 0)

    def emit_ranged_impl(self, info):
        """Write the ``Ranged`` impl that returns the node's own ``range`` field."""
        if not info.no_cfg(self.type_info):
            self.emit('#[cfg(feature = "all-nodes-with-ranges")]', 0)
        # Written directly (not via emit), so no dedent and no trailing newline.
        self.file.write(
            f"""
impl Ranged for crate::generic::{info.full_type_name}::<TextRange> {{
fn range(&self) -> TextRange {{
self.range
}}
}}
""".strip()
        )
class LocatedDefVisitor(EmitVisitor):
    """Emit ``SourceRange``-specialised aliases and ``Located``/``LocatedMut`` impls."""

    def visitModule(self, mod):
        for dfn in mod.dfns + CUSTOM_TYPES:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def _cfg_attribute(self, info):
        """Return the feature-gate attribute for ``info``, or "" when none is needed."""
        if info.no_cfg(self.type_info):
            return ''
        return '#[cfg(feature = "all-nodes-with-ranges")]'

    def visitSum(self, sum, name, depth):
        """Emit aliases for a sum and its variants, plus impls when it is not simple."""
        info = self.type_info[name]
        self.emit_type_alias(info)

        if info.is_simple:
            for ty in sum.types:
                self.emit_type_alias(self.type_info[ty.name])
            return

        sum_match_arms = ""
        for ty in sum.types:
            variant_info = self.type_info[ty.name]
            sum_match_arms += (
                f" Self::{variant_info.rust_name}(node) => node.range(),"
            )
            self.emit_type_alias(variant_info)
            self.emit_located_impl(variant_info)

        cfg = self._cfg_attribute(info)
        self.emit(
            f"""
{cfg}
impl Located for {info.full_type_name} {{
fn range(&self) -> SourceRange {{
match self {{
{sum_match_arms}
}}
}}
}}
{cfg}
impl LocatedMut for {info.full_type_name} {{
fn range_mut(&mut self) -> &mut SourceRange {{
match self {{
{sum_match_arms.replace('range()', 'range_mut()')}
}}
}}
}}
""".lstrip(),
            0,
        )

    def visitProduct(self, product, name, depth):
        info = self.type_info[name]
        self.emit_type_alias(info)
        self.emit_located_impl(info)

    def emit_type_alias(self, info):
        """Emit ``pub type X = crate::generic::X[::<SourceRange>];`` and a blank line."""
        suffix = "" if info.is_simple else "::<SourceRange>"
        self.emit(
            f"pub type {info.full_type_name} = crate::generic::{info.full_type_name}{suffix};",
            0,
        )
        self.emit("", 0)

    def emit_located_impl(self, info):
        """Emit ``Located``/``LocatedMut`` impls backed by the node's ``range`` field."""
        cfg = self._cfg_attribute(info)
        self.emit(
            f"""
{cfg}
impl Located for {info.full_type_name} {{
fn range(&self) -> SourceRange {{
self.range
}}
}}
{cfg}
impl LocatedMut for {info.full_type_name} {{
fn range_mut(&mut self) -> &mut SourceRange {{
&mut self.range
}}
}}
""",
            0,
        )
class ToPyo3AstVisitor(EmitVisitor):
    """Emit ``ToPyAst`` impls for the node types of one namespace."""

    def __init__(self, namespace, *args, **kw):
        super().__init__(*args, **kw)
        self.namespace = namespace

    @property
    def generics(self):
        """Generic argument for the namespace; any namespace but ranged/located asserts."""
        if self.namespace == "ranged":
            return "<TextRange>"
        if self.namespace == "located":
            return "<SourceRange>"
        assert False, self.namespace

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type):
        self.visit(type.value, type)

    def visitProduct(self, product, type):
        rust_name = self.type_info[type.name].full_type_name
        self.emit_to_pyo3_with_fields(product, type, rust_name)

    def visitSum(self, sum, type):
        """Emit the dispatching impl for a non-simple sum and then its constructors.

        Simple sums produce no output here.
        """
        info = self.type_info[type.name]
        rust_name = info.full_type_name
        simple = is_simple(sum)
        if simple:
            return

        self.emit(
            f"""
impl ToPyAst for ast::{rust_name}{self.generics} {{
#[inline]
fn to_py_ast<'py>(&self, {"_" if simple else ""}py: Python<'py>) -> PyResult<&'py PyAny> {{
let instance = match &self {{
""",
            0,
        )
        for cons in sum.types:
            self.emit(
                f"ast::{rust_name}::{cons.name}(cons) => cons.to_py_ast(py)?,",
                1,
            )
        self.emit(
            """
};
Ok(instance)
}
}
""",
            0,
        )

        for cons in sum.types:
            self.visit(cons, type)

    def visitConstructor(self, cons, type):
        parent = rust_type_name(type.name)
        self.emit_to_pyo3_with_fields(cons, type, f"{parent}{cons.name}")

    def emit_to_pyo3_with_fields(self, cons, type, name):
        """Emit the ``ToPyAst`` impl that builds the Python object for ``name``.

        With fields, the cached type is called with one argument per field;
        without, it is called with no arguments.  For the ``located``
        namespace, types with attributes additionally get the four location
        attributes set on the instance.
        """
        type_info = self.type_info[type.name]
        self.emit(
            f"""
impl ToPyAst for ast::{name}{self.generics} {{
#[inline]
fn to_py_ast<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {{
let cache = Self::py_type_cache().get().unwrap();
""",
            0,
        )
        if cons.fields:
            field_names = ", ".join(rust_field(f.name) for f in cons.fields)
            if not type_info.is_simple:
                field_names += ", range: _range"
            self.emit(
                f"let Self {{ {field_names} }} = self;",
                1,
            )
            self.emit(
                """
let instance = Py::<PyAny>::as_ref(&cache.0, py).call1((
""",
                1,
            )
            for field in cons.fields:
                # `constant` fields and the `level`/`lineno` int fields are
                # converted specially; everything else goes through to_py_ast.
                if field.type == "constant":
                    argument = f"constant_to_object({rust_field(field.name)}, py),"
                elif field.type == "int" and field.name == "level":
                    assert field.opt
                    argument = f"{rust_field(field.name)}.map_or_else(|| py.None(), |level| level.to_u32().to_object(py)),"
                elif field.type == "int" and field.name == "lineno":
                    argument = f"{rust_field(field.name)}.to_u32().to_object(py),"
                else:
                    argument = f"{rust_field(field.name)}.to_py_ast(py)?,"
                self.emit(argument, 3)
            self.emit(
                "))?;",
                0,
            )
        else:
            self.emit(
                "let Self { range: _range } = self;",
                1,
            )
            self.emit(
                """let instance = Py::<PyAny>::as_ref(&cache.0, py).call0()?;""",
                1,
            )

        if type.value.attributes and self.namespace == "located":
            self.emit(
                """
let cache = ast_cache();
instance.setattr(cache.lineno.as_ref(py), _range.start.row.get())?;
instance.setattr(cache.col_offset.as_ref(py), _range.start.column.get())?;
if let Some(end) = _range.end {
instance.setattr(cache.end_lineno.as_ref(py), end.row.get())?;
instance.setattr(cache.end_col_offset.as_ref(py), end.column.get())?;
}
""",
                0,
            )
        self.emit(
            """
Ok(instance)
}
}
""",
            0,
        )
class Pyo3StructVisitor(EmitVisitor):
    """Emit pyo3 wrapper structs (`#[pyclass]`) and their conversions.

    `namespace` must be "ranged" or "located"; it selects the generic range
    parameter and the Python module name used in the emitted attributes.
    While `borrow` is true the wrappers hold `&'static` references and
    per-field getters are emitted.
    """

    def __init__(self, namespace, *args, **kw):
        self.namespace = namespace
        self.borrow = True
        super().__init__(*args, **kw)

    @property
    def generics(self):
        """Rust generic argument list matching the active namespace."""
        if self.namespace == "ranged":
            return "<TextRange>"
        if self.namespace == "located":
            return "<SourceRange>"
        assert False, self.namespace

    @property
    def module_name(self):
        """Python module name written into the `#[pyclass(module=...)]` attribute."""
        return f"rustpython_ast.{self.namespace}"

    @property
    def ref_def(self):
        """Reference prefix used in type positions (`&'static ` when borrowing)."""
        if self.borrow:
            return "&'static "
        return ""

    @property
    def ref(self):
        """Reference prefix used in expression positions (`&` when borrowing)."""
        if self.borrow:
            return "&"
        return ""

    def emit_class(self, info, simple, base="super::Ast"):
        """Emit the wrapper struct for `info`, its `From` impl and `ToPyObject` impl.

        Sum types become unit structs marked `subclass`; everything else wraps
        the inner AST node and additionally gets a `ToPyWrapper` impl.
        """
        inner_name = info.full_type_name
        rust_name = self.type_info[info.custom.name].full_type_name
        generics = "" if simple else self.generics

        if info.is_sum:
            subclass = ", subclass"
            body = ""
            into = f"{rust_name}"
        else:
            subclass = ""
            body = f"(pub {self.ref_def} ast::{inner_name}{generics})"
            into = f"{rust_name}(node)"

        self.emit(
            f"""
#[pyclass(module="{self.module_name}", name="_{info.name}", extends={base}, frozen{subclass})]
#[derive(Clone, Debug)]
pub struct {rust_name} {body};
impl From<{self.ref_def} ast::{inner_name}{generics}> for {rust_name} {{
fn from({"" if body else "_"}node: {self.ref_def} ast::{inner_name}{generics}) -> Self {{
{into}
}}
}}
""",
            0,
        )

        if subclass:
            self.emit(
                f"""
#[pymethods]
impl {rust_name} {{
#[new]
fn new() -> PyClassInitializer<Self> {{
PyClassInitializer::from(Ast)
.add_subclass(Self)
}}
}}
impl ToPyObject for {rust_name} {{
fn to_object(&self, py: Python) -> PyObject {{
let initializer = Self::new();
Py::new(py, initializer).unwrap().into_py(py)
}}
}}
""",
                0,
            )
            return

        # Non-sum wrappers: chain through the intermediate base class, if any.
        add_subclass = f".add_subclass({base})" if base != "super::Ast" else ""
        self.emit(
            f"""
impl ToPyObject for {rust_name} {{
fn to_object(&self, py: Python) -> PyObject {{
let initializer = PyClassInitializer::from(Ast)
{add_subclass}
.add_subclass(self.clone());
Py::new(py, initializer).unwrap().into_py(py)
}}
}}
""",
            0,
        )
        self.emit_wrapper(info)

    def emit_getter(self, owner, type_name):
        """Emit a `#[pymethods]` block with one `#[getter]` per field of `owner`."""
        self.emit(
            f"""
#[pymethods]
impl {type_name} {{
""",
            0,
        )
        for field in owner.fields:
            self.emit(
                f"""
#[getter]
#[inline]
fn get_{field.name}(&self, py: Python) -> PyResult<PyObject> {{
self.0.{rust_field(field.name)}.to_py_wrapper(py)
}}
""",
                3,
            )
        self.emit(
            """
}
""",
            0,
        )

    def emit_getattr(self, owner, type_name):
        """Emit a `__getattr__` method that dispatches on the field name."""
        self.emit(
            f"""
#[pymethods]
impl {type_name} {{
fn __getattr__(&self, py: Python, key: &str) -> PyResult<PyObject> {{
let object: Py<PyAny> = match key {{
""",
            0,
        )
        for field in owner.fields:
            arm = f'"{field.name}" => self.0.{rust_field(field.name)}.to_py_wrapper(py)?,'
            self.emit(arm, 3)
        self.emit(
            """
_ => todo!(),
};
Ok(object)
}
}
""",
            0,
        )

    def emit_wrapper(self, info):
        """Emit the `ToPyWrapper` impl that wraps the inner node in its pyclass."""
        inner_name = info.full_type_name
        rust_name = self.type_info[info.custom.name].full_type_name
        self.emit(
            f"""
impl ToPyWrapper for ast::{inner_name}{self.generics} {{
#[inline]
fn to_py_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
Ok({rust_name}(self).to_object(py))
}}
}}
""",
            0,
        )

    def visitModule(self, mod):
        """Visit every definition of the module."""
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        """Dispatch on the type's value, passing the type itself along."""
        self.visit(type.value, type, depth)

    def visitSum(self, sum, type, depth=0):
        """Emit the base class of a sum type, its dispatcher, then each constructor."""
        info = self.type_info[type.name]
        rust_name = rust_type_name(type.name)
        simple = is_simple(sum)
        self.emit_class(info, simple)

        if not simple:
            # Non-simple sums forward `to_py_wrapper` to the active variant.
            self.emit(
                f"""
impl ToPyWrapper for ast::{rust_name}{self.generics} {{
#[inline]
fn to_py_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
match &self {{
""",
                0,
            )
            for variant in sum.types:
                self.emit(
                    f"Self::{variant.name}(cons) => cons.to_py_wrapper(py),", 3
                )
            self.emit(
                """
}
}
}
""",
                0,
            )

        for variant in sum.types:
            self.visit(variant, rust_name, simple, depth + 1)

    def visitProduct(self, product, type, depth=0):
        """Emit the wrapper class of a product type and, when borrowing, its getters."""
        info = self.type_info[type.name]
        rust_name = rust_type_name(type.name)
        self.emit_class(info, False)
        if self.borrow:
            self.emit_getter(product, rust_name)

    def visitConstructor(self, cons, parent, simple, depth):
        """Emit the class for one constructor.

        Constructors of simple sums become field-less unit structs extending
        `parent`; all others go through `emit_class` with `parent` as base.
        """
        if not simple:
            info = self.type_info[cons.name]
            self.emit_class(info, simple=False, base=parent)
            if self.borrow:
                self.emit_getter(cons, f"{parent}{cons.name}")
            return

        self.emit(
            f"""
#[pyclass(module="{self.module_name}", name="_{cons.name}", extends={parent})]
pub struct {parent}{cons.name};
impl ToPyObject for {parent}{cons.name} {{
fn to_object(&self, py: Python) -> PyObject {{
let initializer = PyClassInitializer::from(Ast)
.add_subclass({parent})
.add_subclass(Self);
Py::new(py, initializer).unwrap().into_py(py)
}}
}}
""",
            depth,
        )
class Pyo3PymoduleVisitor(EmitVisitor):
    """Emit one `super::init_type::<Wrapper, ast::Node>(py, m)?;` line per type."""

    def __init__(self, namespace, *args, **kw):
        self.namespace = namespace
        super().__init__(*args, **kw)

    def visitModule(self, mod):
        """Visit every definition of the module."""
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        """Dispatch on the type's value, passing the type name along."""
        self.visit(type.value, type.name, depth)

    def visitProduct(self, product, name, depth=0):
        """Register the product type itself."""
        self.emit_fields(self.type_info[name], False)

    def visitSum(self, sum, name, depth):
        """Register the sum type, then each of its constructors."""
        simple = is_simple(sum)
        self.emit_fields(self.type_info[name], True)
        for variant in sum.types:
            self.visit(variant, name, simple, depth)

    def visitConstructor(self, cons, parent, simple, depth):
        """Register a single constructor."""
        self.emit_fields(self.type_info[cons.name], simple)

    def emit_fields(self, info, simple):
        """Emit the `init_type` registration line for `info`.

        `simple` is accepted for call-site symmetry but not used here.
        """
        inner_name = info.full_type_name
        rust_name = self.type_info[info.custom.name].full_type_name
        registration = f"super::init_type::<{rust_name}, ast::{inner_name}>(py, m)?;"
        self.emit(registration, 1)
class StdlibClassDefVisitor(EmitVisitor):
    """Emit the `Node*` pyclass struct definitions for the `_ast` module."""

    def visitModule(self, mod):
        """Visit every definition of the module."""
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        """Dispatch on the type's value, passing the type name along."""
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Emit the empty base class of a sum type, then one class per constructor."""
        info = self.type_info[name]
        struct_name = "Node" + info.full_type_name
        pyclass_attr = (
            f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = "NodeAst")]'
        )
        self.emit(pyclass_attr, depth)
        self.emit(f"struct {struct_name};", depth)
        self.emit("#[pyclass(flags(HAS_DICT, BASETYPE))]", depth)
        self.emit(f"impl {struct_name} {{}}", depth)
        for variant in sum.types:
            self.visit(variant, sum.attributes, struct_name, depth)

    def visitConstructor(self, cons, attrs, base, depth):
        """Emit the class of a constructor, deriving from the sum's struct."""
        self.gen_class_def(cons.name, cons.fields, attrs, depth, base)

    def visitProduct(self, product, name, depth):
        """Emit the class of a product type, deriving from `NodeAst`."""
        self.gen_class_def(name, product.fields, product.attributes, depth)

    def gen_class_def(self, name, fields, attrs, depth, base=None):
        """Emit a pyclass struct whose `extend_class` hook sets `_fields` and `_attributes`.

        `base` defaults to "NodeAst" when not given.
        """
        info = self.type_info[self.type_info[name].custom.name]
        if base is None:
            base = "NodeAst"
        struct_name = "Node" + info.full_type_name

        self.emit(
            f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = {json.dumps(base)})]',
            depth,
        )
        self.emit(f"struct {struct_name};", depth)
        self.emit("#[pyclass(flags(HAS_DICT, BASETYPE))]", depth)
        self.emit(f"impl {struct_name} {{", depth)
        self.emit("#[extend_class]", depth + 1)
        self.emit(
            "fn extend_class_with_fields(ctx: &Context, class: &'static Py<PyType>) {",
            depth + 1,
        )

        # `_fields` is stored as a tuple of names, `_attributes` as a list.
        field_items = ",".join(
            f"ctx.new_str(ascii!({json.dumps(f.name)})).into()" for f in fields
        )
        self.emit(
            f"class.set_attr(identifier!(ctx, _fields), ctx.new_tuple(vec![{field_items}]).into());",
            depth + 2,
        )
        attr_items = ",".join(
            f"ctx.new_str(ascii!({json.dumps(attr.name)})).into()" for attr in attrs
        )
        self.emit(
            f"class.set_attr(identifier!(ctx, _attributes), ctx.new_list(vec![{attr_items}]).into());",
            depth + 2,
        )

        self.emit("}", depth + 1)
        self.emit("}", depth)
class StdlibExtendModuleVisitor(EmitVisitor):
    """Emit `extend_module_nodes`, which registers every `Node*` class on the module."""

    def visitModule(self, mod):
        """Emit the function shell and one `extend_module!` entry per definition."""
        outer = 0
        self.emit(
            "pub fn extend_module_nodes(vm: &VirtualMachine, module: &Py<PyModule>) {",
            outer,
        )
        self.emit("extend_module!(vm, module, {", outer + 1)
        for definition in mod.dfns:
            self.visit(definition, outer + 2)
        self.emit("})", outer + 1)
        self.emit("}", outer)

    def visitType(self, type, depth):
        """Dispatch on the type's value, passing the type name along."""
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Register the sum's base class, then each constructor class."""
        rust_name = rust_type_name(name)
        entry = f"{json.dumps(name)} => Node{rust_name}::make_class(&vm.ctx),"
        self.emit(entry, depth)
        for variant in sum.types:
            self.visit(variant, depth, rust_name)

    def visitConstructor(self, cons, depth, rust_name):
        """Register a constructor class, prefixed with the sum's Rust name."""
        self.gen_extension(cons.name, depth, rust_name)

    def visitProduct(self, product, name, depth):
        """Register a product class."""
        self.gen_extension(name, depth)

    def gen_extension(self, name, depth, base=""):
        """Emit a single `"name" => Node{base}{Name}::make_class(&vm.ctx),` entry."""
        rust_name = rust_type_name(name)
        entry = f"{json.dumps(name)} => Node{base}{rust_name}::make_class(&vm.ctx),"
        self.emit(entry, depth)
class StdlibTraitImplVisitor(EmitVisitor):
    """Emit `impl Node for ast::located::*` blocks.

    Each impl provides `ast_to_object` (Rust node to Python object) and
    `ast_from_object` (Python object to Rust node).
    """

    def visitModule(self, mod):
        """Visit every definition of the module."""
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        """Dispatch on the type's value, passing the type name along."""
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Emit the `Node` impl for a sum type and, if not simple, for its constructors."""
        info = self.type_info[name]
        rust_name = info.full_type_name
        self.emit("// sum", depth)
        self.emit(f"impl Node for ast::located::{rust_name} {{", depth)
        self.emit(
            "fn ast_to_object(self, vm: &VirtualMachine) -> PyObjectRef {", depth + 1
        )
        simple = is_simple(sum)
        if simple:
            self.emit("let node_type = match self {", depth + 2)
            for cons in sum.types:
                # Arms sit one level inside the `match`, as in the non-simple branch.
                self.emit(
                    f"ast::located::{rust_name}::{cons.name} => Node{rust_name}{cons.name}::static_type(),",
                    depth + 3,
                )
            # The closing brace lines up with the `let` that opened the match.
            self.emit("};", depth + 2)
            self.emit(
                "NodeAst.into_ref_with_type(vm, node_type.to_owned()).unwrap().into()",
                depth + 2,
            )
        else:
            self.emit("match self {", depth + 2)
            for cons in sum.types:
                self.emit(
                    f"ast::located::{rust_name}::{cons.name}(cons) => cons.ast_to_object(vm),",
                    depth + 3,
                )
            self.emit("}", depth + 2)
        self.emit("}", depth + 1)
        self.emit(
            "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
            depth + 1,
        )
        self.gen_sum_from_object(sum, name, rust_name, depth + 2)
        self.emit("}", depth + 1)
        self.emit("}", depth)
        # Only constructors of non-simple sums are structs that need their own impl.
        if not simple:
            for cons in sum.types:
                self.visit(cons, sum, rust_name, depth)

    def visitConstructor(self, cons, sum, sum_rust_name, depth):
        """Emit the `Node` impl for one constructor struct of a non-simple sum."""
        rust_name = rust_type_name(cons.name)
        self.emit("// constructor", depth)
        self.emit(f"impl Node for ast::located::{sum_rust_name}{rust_name} {{", depth)
        fields_pattern = self.make_pattern(cons.fields)
        self.emit(
            "fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
        )
        # Function-body statements are emitted at depth + 2, matching visitProduct.
        self.emit(
            f"let ast::located::{sum_rust_name}{rust_name} {{ {fields_pattern} }} = self;",
            depth + 2,
        )
        self.make_node(cons.name, sum, cons.fields, depth + 2, sum_rust_name)
        self.emit("}", depth + 1)
        self.emit(
            "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
            depth + 1,
        )
        self.gen_product_from_object(
            cons, cons.name, f"{sum_rust_name}{rust_name}", sum.attributes, depth + 2
        )
        self.emit("}", depth + 1)
        # Closes the `impl` opened at `depth`.
        self.emit("}", depth)

    def visitProduct(self, product, name, depth):
        """Emit the `Node` impl for a product type."""
        info = self.type_info[name]
        struct_name = info.full_type_name
        self.emit("// product", depth)
        self.emit(f"impl Node for ast::located::{struct_name} {{", depth)
        self.emit(
            "fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
        )
        fields_pattern = self.make_pattern(product.fields)
        self.emit(
            f"let ast::located::{struct_name} {{ {fields_pattern} }} = self;",
            depth + 2,
        )
        self.make_node(name, product, product.fields, depth + 2)
        self.emit("}", depth + 1)
        self.emit(
            "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
            depth + 1,
        )
        self.gen_product_from_object(
            product, name, struct_name, product.attributes, depth + 2
        )
        self.emit("}", depth + 1)
        self.emit("}", depth)

    def make_node(self, variant, owner, fields, depth, base=""):
        """Emit the body that builds the Python node object and fills its dict.

        Every field is stored through `dict.set_item`; when `owner.attributes`
        is non-empty the location is added via `node_add_location`.
        """
        rust_variant = rust_type_name(variant)
        self.emit(
            f"let node = NodeAst.into_ref_with_type(_vm, Node{base}{rust_variant}::static_type().to_owned()).unwrap();",
            depth,
        )
        if fields or owner.attributes:
            self.emit("let dict = node.as_object().dict().unwrap();", depth)
        for f in fields:
            self.emit(
                f"dict.set_item({json.dumps(f.name)}, {rust_field(f.name)}.ast_to_object(_vm), _vm).unwrap();",
                depth,
            )
        if owner.attributes:
            self.emit("node_add_location(&dict, _range, _vm);", depth)
        self.emit("node.into()", depth)

    def make_pattern(self, fields):
        """Return the Rust destructuring pattern: every field followed by `range: _range`."""
        return "".join(f"{rust_field(f.name)}," for f in fields) + "range: _range"

    def gen_sum_from_object(self, sum, sum_name, rust_name, depth):
        """Emit an `if`/`else` chain on the object's class that selects the variant.

        The trailing `else` returns a type error naming `sum_name`.
        """
        wraps_struct = not is_simple(sum)
        self.emit("let _cls = _object.class();", depth)
        self.emit("Ok(", depth)
        for cons in sum.types:
            self.emit(
                f"if _cls.is(Node{rust_name}{cons.name}::static_type()) {{", depth
            )
            self.emit(f"ast::located::{rust_name}::{cons.name}", depth + 1)
            if wraps_struct:
                self.emit(
                    f"(ast::located::{rust_name}{cons.name}::ast_from_object(_vm, _object)?)",
                    depth + 1,
                )
            self.emit("} else", depth)
        self.emit("{", depth)
        msg = f'format!("expected some sort of {sum_name}, but got {{}}",_object.repr(_vm)?)'
        self.emit(f"return Err(_vm.new_type_error({msg}));", depth + 1)
        self.emit("})", depth)

    def gen_product_from_object(
        self, product, product_name, struct_name, has_attributes, depth
    ):
        """Emit `Ok(<struct literal>)` decoding every field from `_object`."""
        self.emit("Ok(", depth)
        self.gen_construction(
            struct_name, product, product_name, has_attributes, depth + 1
        )
        self.emit(")", depth)

    def gen_construction_fields(self, cons, name, depth):
        """Emit one `field: <decode expression>,` line per field of `cons`."""
        for field in cons.fields:
            self.emit(
                f"{rust_field(field.name)}: {self.decode_field(field, name)},",
                depth + 1,
            )

    def gen_construction(self, cons_path, cons, name, attributes, depth):
        """Emit the struct literal for `cons_path`.

        `range` is read from the object when `attributes` is truthy and is
        `Default::default()` otherwise.
        """
        self.emit(f"ast::located::{cons_path} {{", depth)
        self.gen_construction_fields(cons, name, depth + 1)
        if attributes:
            self.emit(f'range: range_from_object(_vm, _object, "{name}")?,', depth + 1)
        else:
            self.emit("range: Default::default(),", depth + 1)
        self.emit("}", depth)

    def extract_location(self, typename, depth):
        """Emit a `_location` binding decoded from `lineno` and `col_offset`."""
        row = self.decode_field(asdl.Field("int", "lineno"), typename)
        column = self.decode_field(asdl.Field("int", "col_offset"), typename)
        self.emit(
            f"""
let _location = {{
let row = {row};
let column = {column};
try_location(row, column)
}};
""",
            depth,
        )

    def decode_field(self, field, typename):
        """Return the Rust expression that reads `field` from `_object`.

        Optional non-sequence fields use `get_node_field_opt`; all other
        fields use `get_node_field`, which also receives `typename`.
        """
        name = json.dumps(field.name)
        if field.opt and not field.seq:
            return f"get_node_field_opt(_vm, &_object, {name})?.map(|obj| Node::ast_from_object(_vm, obj)).transpose()?"
        else:
            return f"Node::ast_from_object(_vm, get_node_field(_vm, &_object, {name}, {json.dumps(typename)})?)?"
class ChainOfVisitors:
    """Run a fixed sequence of visitors over the same object."""

    def __init__(self, *visitors):
        self.visitors = visitors

    def visit(self, object):
        """Let each visitor visit `object`, emitting an empty line after each one."""
        for visitor in self.visitors:
            visitor.visit(object)
            visitor.emit("", 0)
def write_ast_def(mod, type_info, f):
    """Write the `TextRange` import line to `f`, then run `StructVisitor` over `mod`."""
    f.write("use crate::text_size::TextRange;")
    struct_visitor = StructVisitor(f, type_info)
    struct_visitor.visit(mod)
def write_fold_def(mod, type_info, f):
    """Run `FoldModuleVisitor` over `mod`, writing its output to `f`."""
    fold_visitor = FoldModuleVisitor(f, type_info)
    fold_visitor.visit(mod)
def write_visitor_def(mod, type_info, f):
    """Run `VisitorModuleVisitor` over `mod`, writing its output to `f`."""
    module_visitor = VisitorModuleVisitor(f, type_info)
    module_visitor.visit(mod)
def write_ranged_def(mod, type_info, f):
    """Run `RangedDefVisitor` over `mod`, writing its output to `f`."""
    ranged_visitor = RangedDefVisitor(f, type_info)
    ranged_visitor.visit(mod)
def write_located_def(mod, type_info, f):
    """Run `LocatedDefVisitor` over `mod`, writing its output to `f`."""
    located_visitor = LocatedDefVisitor(f, type_info)
    located_visitor.visit(mod)
def write_pyo3_node(type_info, f):
    """Write a `PyNode` impl (a static `OnceCell` type cache) for each type.

    Custom entries are skipped when their key equals `info.type.name`;
    other custom entries are written under a `Python`-prefixed Rust name.
    """

    def emit_impl(info: TypeInfo, rust_name: str):
        # Simple types carry no range parameter, so their impl is not generic.
        generics = "" if info.is_simple else "<R>"
        f.write(
            f"""
impl{generics} PyNode for ast::{rust_name}{generics} {{
#[inline]
fn py_type_cache() -> &'static OnceCell<(Py<PyAny>, Py<PyAny>)> {{
static PY_TYPE: OnceCell<(Py<PyAny>, Py<PyAny>)> = OnceCell::new();
&PY_TYPE
}}
}}
""",
        )

    for type_name, info in type_info.items():
        rust_name = info.full_type_name
        if info.is_custom:
            if type_name == info.type.name:
                continue
            rust_name = "Python" + rust_name
        emit_impl(info, rust_name)
def write_to_pyo3(mod, type_info, f):
    """Write the whole `to_py_ast` output to `f`.

    Order: the `PyNode` impls, the `ToPyAst` impls of simple sums, the
    `ToPyAst` impls for the "ranged" and "located" namespaces, and finally
    an `init_types` function that caches every non-custom type.
    """
    write_pyo3_node(type_info, f)
    write_to_pyo3_simple(type_info, f)

    for namespace in ("ranged", "located"):
        namespace_visitor = ToPyo3AstVisitor(namespace, f, type_info)
        namespace_visitor.visit(mod)

    f.write(
        """
fn init_types(py: Python) -> PyResult<()> {
let ast_module = PyModule::import(py, "_ast")?;
"""
    )
    for info in type_info.values():
        if not info.is_custom:
            f.write(f"cache_py_type::<ast::{info.full_type_name}>(ast_module)?;\n")
    f.write("Ok(())\n}")
def write_to_pyo3_simple(type_info, f):
    """Write a `ToPyAst` impl to `f` for every simple sum type in `type_info`.

    Each impl matches on the variant and returns the second element of the
    type cache belonging to the matching constructor struct.
    """
    for info in type_info.values():
        # Only sum types that are also simple get an impl here.
        if not (info.is_sum and info.is_simple):
            continue
        rust_name = info.full_type_name
        f.write(
            f"""
impl ToPyAst for ast::{rust_name} {{
#[inline]
fn to_py_ast<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {{
let cell = match &self {{
""",
        )
        for variant in info.type.value.types:
            f.write(
                f"""ast::{rust_name}::{variant.name} => ast::{rust_name}{variant.name}::py_type_cache(),""",
            )
        f.write(
            """
};
Ok(Py::<PyAny>::as_ref(&cell.get().unwrap().1, py))
}
}
""",
        )
def write_pyo3_wrapper(mod, type_info, namespace, f):
    """Write the pyo3 wrapper output for `namespace` to `f`.

    After the `Pyo3StructVisitor` output, the "located" namespace also gets
    the `ToPyWrapper` impls of simple sum types and of their constructor
    structs. The output ends with an `add_to_module` function whose body is
    produced by `Pyo3PymoduleVisitor`.
    """
    Pyo3StructVisitor(namespace, f, type_info).visit(mod)

    if namespace == "located":
        simple_sums = (
            info for info in type_info.values() if info.is_simple and info.is_sum
        )
        for info in simple_sums:
            rust_name = info.full_type_name
            inner_name = type_info[info.custom.name].full_type_name
            variants = info.type.value.types

            # Dispatcher on the enum itself.
            f.write(
                f"""
impl ToPyWrapper for ast::{inner_name} {{
#[inline]
fn to_py_wrapper(&self, py: Python) -> PyResult<Py<PyAny>> {{
match &self {{
""",
            )
            for variant in variants:
                f.write(
                    f"Self::{variant.name} => Ok({rust_name}{variant.name}.to_object(py)),",
                )
            f.write(
                """
}
}
}
""",
            )

            # One impl per constructor struct.
            for variant in variants:
                f.write(
                    f"""
impl ToPyWrapper for ast::{rust_name}{variant.name} {{
#[inline]
fn to_py_wrapper(&self, py: Python) -> PyResult<Py<PyAny>> {{
Ok({rust_name}{variant.name}.to_object(py))
}}
}}
"""
                )

    f.write(
        """
pub fn add_to_module(py: Python, m: &PyModule) -> PyResult<()> {
super::init_module(py, m)?;
"""
    )
    Pyo3PymoduleVisitor(namespace, f, type_info).visit(mod)
    f.write("Ok(())\n}")
def write_parse_def(mod, type_info, f):
    """Write a `Parse` impl to `f` for every type whose `enum_name` is "expr" or "stmt".

    The emitted impl delegates lexing and parsing to the enclosing enum and
    returns an `InvalidToken` parse error when a different variant is parsed.
    `mod` is not used.
    """
    for info in type_info.values():
        if info.enum_name in ("expr", "stmt"):
            type_name = rust_type_name(info.enum_name)
            cons_name = rust_type_name(info.name)
            f.write(f"""
impl Parse for ast::{info.full_type_name} {{
fn lex_starts_at(
source: &str,
offset: TextSize,
) -> SoftKeywordTransformer<Lexer<std::str::Chars>> {{
ast::{type_name}::lex_starts_at(source, offset)
}}
fn parse_tokens(
lxr: impl IntoIterator<Item = LexResult>,
source_path: &str,
) -> Result<Self, ParseError> {{
let node = ast::{type_name}::parse_tokens(lxr, source_path)?;
match node {{
ast::{type_name}::{cons_name}(node) => Ok(node),
node => Err(ParseError {{
error: ParseErrorType::InvalidToken,
offset: node.range().start(),
source_path: source_path.to_owned(),
}}),
}}
}}
}}
""")
def write_ast_mod(mod, type_info, f):
    """Write the stdlib `_ast` module source to `f`.

    A fixed prelude is written first, followed by the output of the class
    definition, trait impl and module extension visitors, in that order.
    """
    f.write(
        """
#![allow(clippy::all)]
use super::*;
use crate::common::ascii;
"""
    )
    visitor_classes = (
        StdlibClassDefVisitor,
        StdlibTraitImplVisitor,
        StdlibExtendModuleVisitor,
    )
    chain = ChainOfVisitors(*(cls(f, type_info) for cls in visitor_classes))
    chain.visit(mod)
def main(
    input_filename,
    ast_dir,
    parser_dir,
    ast_pyo3_dir,
    module_filename,
    dump_module=False,
):
    """Parse the ASDL input and regenerate every output file.

    Parameters:
        input_filename: ASDL file to parse.
        ast_dir, parser_dir, ast_pyo3_dir: directories receiving the `.rs` outputs.
        module_filename: path of the stdlib `_ast` module source.
        dump_module: when true, print the parsed module before generating.

    Exits with status 1 when `asdl.check` rejects the module.
    """
    from functools import partial

    auto_gen_msg = AUTO_GEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
    mod = asdl.parse(input_filename)
    if dump_module:
        print("Parsed Module:")
        print(mod)
    if not asdl.check(mod):
        sys.exit(1)

    type_info = {}
    FindUserDataTypesVisitor(type_info).visit(mod)

    # (target directory, file stem, writer taking the open file), in write order.
    outputs = [
        (ast_dir, "generic", partial(write_ast_def, mod, type_info)),
        (ast_dir, "fold", partial(write_fold_def, mod, type_info)),
        (ast_dir, "ranged", partial(write_ranged_def, mod, type_info)),
        (ast_dir, "located", partial(write_located_def, mod, type_info)),
        (ast_dir, "visitor", partial(write_visitor_def, mod, type_info)),
        (parser_dir, "parse", partial(write_parse_def, mod, type_info)),
        (ast_pyo3_dir, "to_py_ast", partial(write_to_pyo3, mod, type_info)),
        (
            ast_pyo3_dir,
            "wrapper_located",
            partial(write_pyo3_wrapper, mod, type_info, "located"),
        ),
        (
            ast_pyo3_dir,
            "wrapper_ranged",
            partial(write_pyo3_wrapper, mod, type_info, "ranged"),
        ),
    ]
    for directory, stem, writer in outputs:
        with (directory / f"{stem}.rs").open("w") as out:
            out.write(auto_gen_msg)
            writer(out)

    with module_filename.open("w") as module_file:
        module_file.write(auto_gen_msg)
        write_ast_mod(mod, type_info, module_file)

    print(f"{ast_dir}, {module_filename} regenerated.")
if __name__ == "__main__":
    # Command-line entry point: one positional input plus four required paths.
    cli = ArgumentParser()
    cli.add_argument("input_file", type=Path)
    for short_flag, long_flag in (
        ("-A", "--ast-dir"),
        ("-P", "--parser-dir"),
        ("-O", "--ast-pyo3-dir"),
        ("-M", "--module-file"),
    ):
        cli.add_argument(short_flag, long_flag, type=Path, required=True)
    cli.add_argument("-d", "--dump-module", action="store_true")

    options = cli.parse_args()
    main(
        options.input_file,
        options.ast_dir,
        options.parser_dir,
        options.ast_pyo3_dir,
        options.module_file,
        options.dump_module,
    )