use crate::method::{FnType, SelfType};
use crate::pymethod::{
impl_py_getter_def, impl_py_setter_def, impl_wrap_getter, impl_wrap_setter, PropertyType,
};
use crate::utils;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::ext::IdentExt;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{parse_quote, Expr, Token};
/// Parsed form of the arguments given to `#[pyclass(...)]`.
pub struct PyClassArgs {
    /// Python-visible class name from `name = "..."`; `None` means the Rust identifier is used.
    pub name: Option<syn::Ident>,
    /// String literal from `module = "..."`, emitted as `PyTypeInfo::MODULE`.
    pub module: Option<syn::LitStr>,
    /// Base type, replaced by `extends = ...`; emitted as `PyTypeInfo::BaseType`.
    pub base: syn::TypePath,
    /// Set when `extends = ...` was given explicitly.
    pub has_extends: bool,
    /// Type-flag expressions that are OR-ed together into `PyTypeInfo::FLAGS`.
    pub flags: Vec<syn::Expr>,
    /// Capacity expression from `freelist = ...`; when present a free-list impl is generated.
    pub freelist: Option<syn::Expr>,
    /// Set by the `unsendable` flag; selects `ThreadCheckerImpl` as the thread checker.
    pub has_unsendable: bool,
}
impl Parse for PyClassArgs {
    /// Parses the comma-separated argument list of `#[pyclass(...)]` (a trailing comma is
    /// allowed) and applies each argument, left to right, on top of the default arguments.
    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
        let args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
        args.iter().try_fold(
            PyClassArgs::default(),
            |mut acc, expr| -> syn::parse::Result<PyClassArgs> {
                acc.add_expr(expr)?;
                Ok(acc)
            },
        )
    }
}
impl Default for PyClassArgs {
fn default() -> Self {
PyClassArgs {
freelist: None,
name: None,
module: None,
flags: vec![parse_quote! { 0 }],
base: parse_quote! { pyo3::PyAny },
has_extends: false,
has_unsendable: false,
}
}
}
impl PyClassArgs {
    /// Applies a single `#[pyclass(...)]` argument.
    ///
    /// A single-segment path (e.g. `gc`) is handled as a flag by `add_path`. A `key = value`
    /// assignment is handled by `add_assign`. Anything else is rejected.
    fn add_expr(&mut self, expr: &Expr) -> syn::parse::Result<()> {
        match expr {
            syn::Expr::Path(exp) if exp.path.segments.len() == 1 => self.add_path(exp),
            syn::Expr::Assign(assign) => self.add_assign(assign),
            _ => Err(syn::Error::new_spanned(expr, "Failed to parse arguments")),
        }
    }

    /// Handles a `key = value` argument. Supported keys are `freelist`, `name`, `extends`
    /// and `module`.
    ///
    /// # Errors
    /// Returns an error spanned on the offending tokens for an unknown key or for a value
    /// of the wrong shape.
    fn add_assign(&mut self, assign: &syn::ExprAssign) -> syn::Result<()> {
        let syn::ExprAssign { left, right, .. } = assign;
        let key = match &**left {
            syn::Expr::Path(exp) if exp.path.segments.len() == 1 => {
                exp.path.segments.first().unwrap().ident.to_string()
            }
            _ => {
                return Err(syn::Error::new_spanned(assign, "Failed to parse arguments"));
            }
        };
        // Early-returns an "Expected ..." error spanned on `right`, or on the given tokens.
        // The expansion is a bare `return` expression (no trailing `;`) because the macro is
        // used in match-arm expression position.
        macro_rules! expected {
            ($expected: literal) => {
                expected!($expected, right)
            };
            ($expected: literal, $span: ident) => {
                return Err(syn::Error::new_spanned(
                    $span,
                    concat!("Expected ", $expected),
                ))
            };
        }
        match key.as_str() {
            "freelist" => {
                self.freelist = Some(syn::Expr::clone(right));
            }
            "name" => match &**right {
                syn::Expr::Lit(syn::ExprLit {
                    lit: syn::Lit::Str(lit),
                    ..
                }) => {
                    self.name = Some(lit.parse().map_err(|_| {
                        syn::Error::new_spanned(
                            lit,
                            "expected a single identifier in double-quotes",
                        )
                    })?);
                }
                // `get_ident` is `None` for paths such as `::Foo` or `Foo::<u8>`, even with a
                // single segment. Matching on it, instead of `expect`ing it, avoids panicking
                // the macro on such input.
                syn::Expr::Path(exp) => match exp.path.get_ident() {
                    // A bare identifier was the pre-0.13 spelling; point users at the new one.
                    Some(ident) => {
                        return Err(syn::Error::new_spanned(
                            exp,
                            format!(
                                concat!(
                                    "since PyO3 0.13 a pyclass name should be in double-quotes, ",
                                    "e.g. \"{}\""
                                ),
                                ident
                            ),
                        ))
                    }
                    None => expected!("type name (e.g. \"Name\")"),
                },
                _ => expected!("type name (e.g. \"Name\")"),
            },
            "extends" => match &**right {
                syn::Expr::Path(exp) => {
                    // Keep any qualified-self part so `<T as Trait>::Base` is not
                    // silently truncated to `Trait::Base`.
                    self.base = syn::TypePath {
                        path: exp.path.clone(),
                        qself: exp.qself.clone(),
                    };
                    self.has_extends = true;
                }
                _ => expected!("type path (e.g., my_mod::BaseClass)"),
            },
            "module" => match &**right {
                syn::Expr::Lit(syn::ExprLit {
                    lit: syn::Lit::Str(lit),
                    ..
                }) => {
                    self.module = Some(lit.clone());
                }
                _ => expected!(r#"string literal (e.g., "my_mod")"#),
            },
            _ => expected!("one of freelist/name/extends/module", left),
        };
        Ok(())
    }

    /// Handles a bare flag argument (`gc`, `weakref`, `subclass`, `dict`, `unsendable`).
    ///
    /// The first four append the matching `pyo3::type_flags` constant to `flags`.
    /// `unsendable` sets `has_unsendable`. The caller (`add_expr`) guarantees that the path
    /// has exactly one segment.
    fn add_path(&mut self, exp: &syn::ExprPath) -> syn::Result<()> {
        let flag = exp.path.segments.first().unwrap().ident.to_string();
        let mut push_flag = |flag| {
            self.flags.push(syn::Expr::Path(flag));
        };
        match flag.as_str() {
            "gc" => push_flag(parse_quote! {pyo3::type_flags::GC}),
            "weakref" => push_flag(parse_quote! {pyo3::type_flags::WEAKREF}),
            "subclass" => push_flag(parse_quote! {pyo3::type_flags::BASETYPE}),
            "dict" => push_flag(parse_quote! {pyo3::type_flags::DICT}),
            "unsendable" => {
                self.has_unsendable = true;
            }
            _ => {
                return Err(syn::Error::new_spanned(
                    &exp.path,
                    "Expected one of gc/weakref/subclass/dict/unsendable",
                ))
            }
        };
        Ok(())
    }
}
/// Entry point for `#[pyclass]`.
///
/// Validates the struct and strips the field-level `#[pyo3(...)]` attributes it consumes
/// (the struct is modified in place). Returns the generated impls.
///
/// # Errors
/// Propagates errors from parsing the class's text-signature and doc attributes. Also
/// fails if the struct has generic parameters, if it does not have named fields, or if a
/// field's `#[pyo3(...)]` attribute is invalid.
pub fn build_py_class(class: &mut syn::ItemStruct, attr: &PyClassArgs) -> syn::Result<TokenStream> {
    let text_signature = utils::parse_text_signature_attrs(
        &mut class.attrs,
        &get_class_python_name(&class.ident, attr),
    )?;
    let doc = utils::get_doc(&class.attrs, text_signature, true)?;
    check_generics(class)?;

    let named = match &mut class.fields {
        syn::Fields::Named(named) => named,
        other => {
            return Err(syn::Error::new_spanned(
                other,
                "#[pyclass] can only be used with C-style structs",
            ))
        }
    };

    // Only fields that carry at least one descriptor are forwarded. The clone is taken
    // after `parse_descriptors` has stripped the consumed attributes from the field.
    let mut descriptors = Vec::new();
    for field in named.named.iter_mut() {
        let descs = parse_descriptors(field)?;
        if descs.is_empty() {
            continue;
        }
        descriptors.push((field.clone(), descs));
    }

    impl_class(&class.ident, attr, doc, descriptors)
}
fn parse_descriptors(item: &mut syn::Field) -> syn::Result<Vec<FnType>> {
let mut descs = Vec::new();
let mut new_attrs = Vec::new();
for attr in item.attrs.iter() {
if let Ok(syn::Meta::List(list)) = attr.parse_meta() {
if list.path.is_ident("pyo3") {
for meta in list.nested.iter() {
if let syn::NestedMeta::Meta(metaitem) = meta {
if metaitem.path().is_ident("get") {
descs.push(FnType::Getter(SelfType::Receiver { mutable: false }));
} else if metaitem.path().is_ident("set") {
descs.push(FnType::Setter(SelfType::Receiver { mutable: true }));
} else {
return Err(syn::Error::new_spanned(
metaitem,
"Only get and set are supported",
));
}
}
}
} else {
new_attrs.push(attr.clone())
}
} else {
new_attrs.push(attr.clone());
}
}
item.attrs.clear();
item.attrs.extend(new_attrs);
Ok(descs)
}
/// Generates the hidden `Pyo3MethodsInventoryFor<Class>` type.
///
/// The type stores the class's `PyMethodDefType`s. The generated code implements
/// `PyMethodsInventory` for it, links it to the class through `HasMethodsInventory`, and
/// registers it with `pyo3::inventory::collect!`.
fn impl_methods_inventory(cls: &syn::Ident) -> TokenStream {
    // `unraw` strips a leading `r#`. Formatting a raw identifier directly would give
    // "Pyo3MethodsInventoryForr#Foo", which `syn::Ident::new` rejects with a panic.
    let name = format!("Pyo3MethodsInventoryFor{}", cls.unraw());
    let inventory_cls = syn::Ident::new(&name, Span::call_site());
    quote! {
        #[doc(hidden)]
        pub struct #inventory_cls {
            methods: Vec<pyo3::class::PyMethodDefType>,
        }
        impl pyo3::class::methods::PyMethodsInventory for #inventory_cls {
            fn new(methods: Vec<pyo3::class::PyMethodDefType>) -> Self {
                Self { methods }
            }
            fn get(&'static self) -> &'static [pyo3::class::PyMethodDefType] {
                &self.methods
            }
        }
        impl pyo3::class::methods::HasMethodsInventory for #cls {
            type Methods = #inventory_cls;
        }
        pyo3::inventory::collect!(#inventory_cls);
    }
}
/// Generates the hidden `Pyo3ProtoInventoryFor<Class>` type.
///
/// The type stores a `PyProtoMethodDef`. The generated code implements `PyProtoInventory`
/// for it, links it to the class through `HasProtoInventory`, and registers it with
/// `pyo3::inventory::collect!`.
fn impl_proto_inventory(cls: &syn::Ident) -> TokenStream {
    // `unraw` strips a leading `r#`. Formatting a raw identifier directly would give
    // "Pyo3ProtoInventoryForr#Foo", which `syn::Ident::new` rejects with a panic.
    let name = format!("Pyo3ProtoInventoryFor{}", cls.unraw());
    let inventory_cls = syn::Ident::new(&name, Span::call_site());
    quote! {
        #[doc(hidden)]
        pub struct #inventory_cls {
            def: pyo3::class::proto_methods::PyProtoMethodDef,
        }
        impl pyo3::class::proto_methods::PyProtoInventory for #inventory_cls {
            fn new(def: pyo3::class::proto_methods::PyProtoMethodDef) -> Self {
                Self { def }
            }
            fn get(&'static self) -> &'static pyo3::class::proto_methods::PyProtoMethodDef {
                &self.def
            }
        }
        impl pyo3::class::proto_methods::HasProtoInventory for #cls {
            type ProtoMethods = #inventory_cls;
        }
        pyo3::inventory::collect!(#inventory_cls);
    }
}
/// Returns the identifier the class is known by on the Python side.
///
/// This is the `name = "..."` override if one was given, otherwise the Rust struct's
/// identifier. Any `r#` raw-identifier prefix is stripped, so `struct r#type` is exposed
/// (and matched against its text signature) as `type` rather than `r#type`.
///
/// An owned `Ident` is returned because `unraw` builds a new identifier. Callers that
/// borrow the result (`&get_class_python_name(..)`) or stringify it keep working.
fn get_class_python_name(cls: &syn::Ident, attr: &PyClassArgs) -> syn::Ident {
    attr.name.as_ref().unwrap_or(cls).unraw()
}
/// Generates all trait impls that `#[pyclass]` adds to `cls`.
///
/// The output covers:
/// - `PyTypeInfo`, `PyClass` and `PyClassSend`;
/// - the `ExtractExt` impls for `&cls` / `&mut cls`;
/// - `IntoPy<PyObject>`, only when `extends` was not used;
/// - the method and proto inventories;
/// - either a free-list impl (`freelist = ...`) or a plain `PyClassAlloc` impl;
/// - the `#[pyo3(get, set)]` field descriptors;
/// - for `gc` classes, a compile-time assertion that `cls` implements `PyGCProtocol`.
///
/// # Errors
/// Propagates errors from generating the field descriptors.
fn impl_class(
    cls: &syn::Ident,
    attr: &PyClassArgs,
    doc: syn::LitStr,
    descriptors: Vec<(syn::Field, Vec<FnType>)>,
) -> syn::Result<TokenStream> {
    let cls_name = get_class_python_name(cls, attr).to_string();

    // Allocation: a lazily-created free list when `freelist = N` was given, else the default.
    // NOTE(review): the `static mut` access is presumably serialized by the GIL, as implied
    // by the `Python` token parameter. Confirm before relying on it.
    let extra = if let Some(freelist) = &attr.freelist {
        quote! {
            impl pyo3::freelist::PyClassWithFreeList for #cls {
                #[inline]
                fn get_free_list(_py: pyo3::Python) -> &mut pyo3::freelist::FreeList<*mut pyo3::ffi::PyObject> {
                    static mut FREELIST: *mut pyo3::freelist::FreeList<*mut pyo3::ffi::PyObject> = 0 as *mut _;
                    unsafe {
                        if FREELIST.is_null() {
                            FREELIST = Box::into_raw(Box::new(
                                pyo3::freelist::FreeList::with_capacity(#freelist)));
                        }
                        &mut *FREELIST
                    }
                }
            }
        }
    } else {
        quote! {
            impl pyo3::pyclass::PyClassAlloc for #cls {}
        }
    };

    let extra = if !descriptors.is_empty() {
        let path = syn::Path::from(syn::PathSegment::from(cls.clone()));
        let ty = syn::Type::from(syn::TypePath { path, qself: None });
        let desc_impls = impl_descriptors(&ty, descriptors)?;
        quote! {
            #desc_impls
            #extra
        }
    } else {
        extra
    };

    // Recover which optional features were requested from the flag paths pushed by
    // `PyClassArgs::add_path`.
    let mut has_weakref = false;
    let mut has_dict = false;
    let mut has_gc = false;
    for f in attr.flags.iter() {
        if let syn::Expr::Path(epath) = f {
            if epath.path == parse_quote! { pyo3::type_flags::WEAKREF } {
                has_weakref = true;
            } else if epath.path == parse_quote! { pyo3::type_flags::DICT } {
                has_dict = true;
            } else if epath.path == parse_quote! { pyo3::type_flags::GC } {
                has_gc = true;
            }
        }
    }

    // Weakref/dict slots: own slot if requested, else inherit from the base when
    // `extends` was used, else a dummy slot.
    let weakref = if has_weakref {
        quote! { pyo3::pyclass_slots::PyClassWeakRefSlot }
    } else if attr.has_extends {
        quote! { <Self::BaseType as pyo3::derive_utils::PyBaseTypeUtils>::WeakRef }
    } else {
        quote! { pyo3::pyclass_slots::PyClassDummySlot }
    };
    let dict = if has_dict {
        quote! { pyo3::pyclass_slots::PyClassDictSlot }
    } else if attr.has_extends {
        quote! { <Self::BaseType as pyo3::derive_utils::PyBaseTypeUtils>::Dict }
    } else {
        quote! { pyo3::pyclass_slots::PyClassDummySlot }
    };
    let module = if let Some(m) = &attr.module {
        quote! { Some(#m) }
    } else {
        quote! { None }
    };

    let gc_impl = if has_gc {
        // `unraw` strips a leading `r#`. For a raw identifier, `format!` would otherwise
        // yield "__assertion_closure_r#Foo", on which `syn::Ident::new` panics.
        let closure_name = format!("__assertion_closure_{}", cls.unraw());
        let closure_token = syn::Ident::new(&closure_name, Span::call_site());
        quote! {
            fn #closure_token() {
                fn _assert_implements_protocol<'p, T: pyo3::class::PyGCProtocol<'p>>() {}
                _assert_implements_protocol::<#cls>();
            }
        }
    } else {
        quote! {}
    };

    let impl_inventory = impl_methods_inventory(cls);
    let impl_proto_inventory = impl_proto_inventory(cls);
    let base = &attr.base;
    let flags = &attr.flags;
    let extended = if attr.has_extends {
        quote! { pyo3::type_flags::EXTENDED }
    } else {
        quote! { 0 }
    };
    let base_layout = if attr.has_extends {
        quote! { <Self::BaseType as pyo3::derive_utils::PyBaseTypeUtils>::LayoutAsBase }
    } else {
        quote! { pyo3::pycell::PyCellBase<pyo3::PyAny> }
    };
    let base_nativetype = if attr.has_extends {
        quote! { <Self::BaseType as pyo3::derive_utils::PyBaseTypeUtils>::BaseNativeType }
    } else {
        quote! { pyo3::PyAny }
    };
    let into_pyobject = if !attr.has_extends {
        quote! {
            impl pyo3::IntoPy<pyo3::PyObject> for #cls {
                fn into_py(self, py: pyo3::Python) -> pyo3::PyObject {
                    pyo3::IntoPy::into_py(pyo3::Py::new(py, self).unwrap(), py)
                }
            }
        }
    } else {
        quote! {}
    };
    let thread_checker = if attr.has_unsendable {
        quote! { pyo3::pyclass::ThreadCheckerImpl<#cls> }
    } else if attr.has_extends {
        quote! {
            pyo3::pyclass::ThreadCheckerInherited<#cls, <#cls as pyo3::type_object::PyTypeInfo>::BaseType>
        }
    } else {
        quote! { pyo3::pyclass::ThreadCheckerStub<#cls> }
    };

    Ok(quote! {
        unsafe impl pyo3::type_object::PyTypeInfo for #cls {
            type Type = #cls;
            type BaseType = #base;
            type Layout = pyo3::PyCell<Self>;
            type BaseLayout = #base_layout;
            type Initializer = pyo3::pyclass_init::PyClassInitializer<Self>;
            type AsRefTarget = pyo3::PyCell<Self>;
            const NAME: &'static str = #cls_name;
            const MODULE: Option<&'static str> = #module;
            const DESCRIPTION: &'static str = #doc;
            const FLAGS: usize = #(#flags)|* | #extended;
            #[inline]
            fn type_object_raw(py: pyo3::Python) -> *mut pyo3::ffi::PyTypeObject {
                use pyo3::type_object::LazyStaticType;
                static TYPE_OBJECT: LazyStaticType = LazyStaticType::new();
                TYPE_OBJECT.get_or_init::<Self>(py)
            }
        }
        impl pyo3::PyClass for #cls {
            type Dict = #dict;
            type WeakRef = #weakref;
            type BaseNativeType = #base_nativetype;
        }
        impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a #cls
        {
            type Target = pyo3::PyRef<'a, #cls>;
        }
        impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls
        {
            type Target = pyo3::PyRefMut<'a, #cls>;
        }
        impl pyo3::pyclass::PyClassSend for #cls {
            type ThreadChecker = #thread_checker;
        }
        #into_pyobject
        #impl_inventory
        #impl_proto_inventory
        #extra
        #gc_impl
    })
}
/// Builds a single `pyo3::inventory::submit!` item that registers a getter or setter
/// `PyMethodDefType` for every descriptor collected from the struct's fields.
///
/// The Python attribute name is the field name without any `r#` prefix. Its docstring
/// comes from the field's doc attributes, falling back to the field name if
/// `utils::get_doc` fails.
///
/// # Errors
/// Propagates errors from `impl_wrap_getter` / `impl_wrap_setter`.
fn impl_descriptors(
    cls: &syn::Type,
    descriptors: Vec<(syn::Field, Vec<FnType>)>,
) -> syn::Result<TokenStream> {
    let mut py_methods: Vec<TokenStream> = Vec::new();
    for (field, fns) in &descriptors {
        for desc in fns {
            let name = field.ident.as_ref().unwrap().unraw();
            let doc = match utils::get_doc(&field.attrs, None, true) {
                Ok(doc) => doc,
                Err(_) => syn::LitStr::new(&name.to_string(), name.span()),
            };
            let method_def = match desc {
                FnType::Getter(self_ty) => {
                    let wrapper =
                        impl_wrap_getter(&cls, PropertyType::Descriptor(&field), &self_ty)?;
                    impl_py_getter_def(&name, &doc, &wrapper)
                }
                FnType::Setter(self_ty) => {
                    let wrapper =
                        impl_wrap_setter(&cls, PropertyType::Descriptor(&field), &self_ty)?;
                    impl_py_setter_def(&name, &doc, &wrapper)
                }
                // `parse_descriptors` only ever produces getters and setters.
                _ => unreachable!(),
            };
            py_methods.push(method_def);
        }
    }
    Ok(quote! {
        pyo3::inventory::submit! {
            #![crate = pyo3] {
                type Inventory = <#cls as pyo3::class::methods::HasMethodsInventory>::Methods;
                <Inventory as pyo3::class::methods::PyMethodsInventory>::new(vec![#(#py_methods),*])
            }
        }
    })
}
fn check_generics(class: &mut syn::ItemStruct) -> syn::Result<()> {
if class.generics.params.is_empty() {
Ok(())
} else {
Err(syn::Error::new_spanned(
&class.generics,
"#[pyclass] cannot have generic parameters",
))
}
}