pyo3_macros_backend/
pyclass.rs

1use std::borrow::Cow;
2use std::fmt::Debug;
3
4use proc_macro2::{Ident, Span, TokenStream};
5use quote::{format_ident, quote, quote_spanned, ToTokens};
6use syn::ext::IdentExt;
7use syn::parse::{Parse, ParseStream};
8use syn::punctuated::Punctuated;
9use syn::{parse_quote, parse_quote_spanned, spanned::Spanned, ImplItemFn, Result, Token};
10
11use crate::attributes::kw::frozen;
12use crate::attributes::{
13    self, kw, take_pyo3_options, CrateAttribute, ErrorCombiner, ExtendsAttribute,
14    FreelistAttribute, ModuleAttribute, NameAttribute, NameLitStr, RenameAllAttribute,
15    StrFormatterAttribute,
16};
17use crate::konst::{ConstAttributes, ConstSpec};
18use crate::method::{FnArg, FnSpec, PyArg, RegularArg};
19use crate::pyfunction::ConstructorAttribute;
20use crate::pyimpl::{gen_py_const, get_cfg_attributes, PyClassMethodsType};
21use crate::pymethod::{
22    impl_py_class_attribute, impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef,
23    MethodAndSlotDef, PropertyType, SlotDef, __GETITEM__, __HASH__, __INT__, __LEN__, __REPR__,
24    __RICHCMP__, __STR__,
25};
26use crate::pyversions::is_abi3_before;
27use crate::utils::{self, apply_renaming_rule, Ctx, LitCStr, PythonDoc};
28use crate::PyFunctionOptions;
29
/// Records whether the class is derived from a Rust `struct` or an `enum`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum PyClassKind {
    /// The class is derived from a Rust `struct`.
    Struct,
    /// The class is derived from a Rust `enum`.
    Enum,
}
36
/// The parsed arguments of the `#[pyclass(...)]` macro.
#[derive(Clone)]
pub struct PyClassArgs {
    /// Whether the annotated item is a `struct` or an `enum`.
    pub class_kind: PyClassKind,
    /// The options given in the macro arguments; `#[pyo3(...)]` options on the
    /// item are merged in later via `PyClassPyO3Options::take_pyo3_options`.
    pub options: PyClassPyO3Options,
}
43
44impl PyClassArgs {
45    fn parse(input: ParseStream<'_>, kind: PyClassKind) -> Result<Self> {
46        Ok(PyClassArgs {
47            class_kind: kind,
48            options: PyClassPyO3Options::parse(input)?,
49        })
50    }
51
52    pub fn parse_struct_args(input: ParseStream<'_>) -> syn::Result<Self> {
53        Self::parse(input, PyClassKind::Struct)
54    }
55
56    pub fn parse_enum_args(input: ParseStream<'_>) -> syn::Result<Self> {
57        Self::parse(input, PyClassKind::Enum)
58    }
59}
60
/// The class-level options accepted by `#[pyclass(...)]`.
///
/// The same options may also be given in `#[pyo3(...)]` attributes on the
/// class item; `take_pyo3_options` merges those in. Each field stays `None`
/// until its option is given, and `set_option` rejects a second occurrence.
#[derive(Clone, Default)]
pub struct PyClassPyO3Options {
    /// The `crate` option.
    pub krate: Option<CrateAttribute>,
    /// Rejected by `set_option` for `abi3` builds targeting Python < 3.9.
    pub dict: Option<kw::dict>,
    pub eq: Option<kw::eq>,
    pub eq_int: Option<kw::eq_int>,
    /// Rejected on enums (see `build_py_enum`).
    pub extends: Option<ExtendsAttribute>,
    /// Marks every struct field as `get` (see `build_py_class`).
    pub get_all: Option<kw::get_all>,
    pub freelist: Option<FreelistAttribute>,
    pub frozen: Option<kw::frozen>,
    pub hash: Option<kw::hash>,
    pub mapping: Option<kw::mapping>,
    pub module: Option<ModuleAttribute>,
    /// Overrides the class name exposed to Python (see `get_class_python_name`).
    pub name: Option<NameAttribute>,
    pub ord: Option<kw::ord>,
    /// Renaming rule applied to enum variant names and passed on to field
    /// descriptor generation.
    pub rename_all: Option<RenameAllAttribute>,
    pub sequence: Option<kw::sequence>,
    /// Marks every struct field as `set` (see `build_py_class`).
    pub set_all: Option<kw::set_all>,
    /// Requests a generated `__str__`; its optional format string is rejected
    /// on enums and when combined with `name` / `rename_all`.
    pub str: Option<StrFormatterAttribute>,
    /// Rejected on enums (see `build_py_enum`).
    pub subclass: Option<kw::subclass>,
    pub unsendable: Option<kw::unsendable>,
    /// Rejected by `set_option` for `abi3` builds targeting Python < 3.9.
    pub weakref: Option<kw::weakref>,
}
84
/// A single option parsed from `#[pyclass(...)]` or a class-level
/// `#[pyo3(...)]`; the accumulated set lives in `PyClassPyO3Options`.
pub enum PyClassPyO3Option {
    Crate(CrateAttribute),
    Dict(kw::dict),
    Eq(kw::eq),
    EqInt(kw::eq_int),
    Extends(ExtendsAttribute),
    Freelist(FreelistAttribute),
    Frozen(kw::frozen),
    GetAll(kw::get_all),
    Hash(kw::hash),
    Mapping(kw::mapping),
    Module(ModuleAttribute),
    Name(NameAttribute),
    Ord(kw::ord),
    RenameAll(RenameAllAttribute),
    Sequence(kw::sequence),
    SetAll(kw::set_all),
    Str(StrFormatterAttribute),
    Subclass(kw::subclass),
    Unsendable(kw::unsendable),
    Weakref(kw::weakref),
}
107
108impl Parse for PyClassPyO3Option {
109    fn parse(input: ParseStream<'_>) -> Result<Self> {
110        let lookahead = input.lookahead1();
111        if lookahead.peek(Token![crate]) {
112            input.parse().map(PyClassPyO3Option::Crate)
113        } else if lookahead.peek(kw::dict) {
114            input.parse().map(PyClassPyO3Option::Dict)
115        } else if lookahead.peek(kw::eq) {
116            input.parse().map(PyClassPyO3Option::Eq)
117        } else if lookahead.peek(kw::eq_int) {
118            input.parse().map(PyClassPyO3Option::EqInt)
119        } else if lookahead.peek(kw::extends) {
120            input.parse().map(PyClassPyO3Option::Extends)
121        } else if lookahead.peek(attributes::kw::freelist) {
122            input.parse().map(PyClassPyO3Option::Freelist)
123        } else if lookahead.peek(attributes::kw::frozen) {
124            input.parse().map(PyClassPyO3Option::Frozen)
125        } else if lookahead.peek(attributes::kw::get_all) {
126            input.parse().map(PyClassPyO3Option::GetAll)
127        } else if lookahead.peek(attributes::kw::hash) {
128            input.parse().map(PyClassPyO3Option::Hash)
129        } else if lookahead.peek(attributes::kw::mapping) {
130            input.parse().map(PyClassPyO3Option::Mapping)
131        } else if lookahead.peek(attributes::kw::module) {
132            input.parse().map(PyClassPyO3Option::Module)
133        } else if lookahead.peek(kw::name) {
134            input.parse().map(PyClassPyO3Option::Name)
135        } else if lookahead.peek(attributes::kw::ord) {
136            input.parse().map(PyClassPyO3Option::Ord)
137        } else if lookahead.peek(kw::rename_all) {
138            input.parse().map(PyClassPyO3Option::RenameAll)
139        } else if lookahead.peek(attributes::kw::sequence) {
140            input.parse().map(PyClassPyO3Option::Sequence)
141        } else if lookahead.peek(attributes::kw::set_all) {
142            input.parse().map(PyClassPyO3Option::SetAll)
143        } else if lookahead.peek(attributes::kw::str) {
144            input.parse().map(PyClassPyO3Option::Str)
145        } else if lookahead.peek(attributes::kw::subclass) {
146            input.parse().map(PyClassPyO3Option::Subclass)
147        } else if lookahead.peek(attributes::kw::unsendable) {
148            input.parse().map(PyClassPyO3Option::Unsendable)
149        } else if lookahead.peek(attributes::kw::weakref) {
150            input.parse().map(PyClassPyO3Option::Weakref)
151        } else {
152            Err(lookahead.error())
153        }
154    }
155}
156
157impl Parse for PyClassPyO3Options {
158    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
159        let mut options: PyClassPyO3Options = Default::default();
160
161        for option in Punctuated::<PyClassPyO3Option, syn::Token![,]>::parse_terminated(input)? {
162            options.set_option(option)?;
163        }
164
165        Ok(options)
166    }
167}
168
169impl PyClassPyO3Options {
170    pub fn take_pyo3_options(&mut self, attrs: &mut Vec<syn::Attribute>) -> syn::Result<()> {
171        take_pyo3_options(attrs)?
172            .into_iter()
173            .try_for_each(|option| self.set_option(option))
174    }
175
176    fn set_option(&mut self, option: PyClassPyO3Option) -> syn::Result<()> {
177        macro_rules! set_option {
178            ($key:ident) => {
179                {
180                    ensure_spanned!(
181                        self.$key.is_none(),
182                        $key.span() => concat!("`", stringify!($key), "` may only be specified once")
183                    );
184                    self.$key = Some($key);
185                }
186            };
187        }
188
189        match option {
190            PyClassPyO3Option::Crate(krate) => set_option!(krate),
191            PyClassPyO3Option::Dict(dict) => {
192                ensure_spanned!(
193                    !is_abi3_before(3, 9),
194                    dict.span() => "`dict` requires Python >= 3.9 when using the `abi3` feature"
195                );
196                set_option!(dict);
197            }
198            PyClassPyO3Option::Eq(eq) => set_option!(eq),
199            PyClassPyO3Option::EqInt(eq_int) => set_option!(eq_int),
200            PyClassPyO3Option::Extends(extends) => set_option!(extends),
201            PyClassPyO3Option::Freelist(freelist) => set_option!(freelist),
202            PyClassPyO3Option::Frozen(frozen) => set_option!(frozen),
203            PyClassPyO3Option::GetAll(get_all) => set_option!(get_all),
204            PyClassPyO3Option::Hash(hash) => set_option!(hash),
205            PyClassPyO3Option::Mapping(mapping) => set_option!(mapping),
206            PyClassPyO3Option::Module(module) => set_option!(module),
207            PyClassPyO3Option::Name(name) => set_option!(name),
208            PyClassPyO3Option::Ord(ord) => set_option!(ord),
209            PyClassPyO3Option::RenameAll(rename_all) => set_option!(rename_all),
210            PyClassPyO3Option::Sequence(sequence) => set_option!(sequence),
211            PyClassPyO3Option::SetAll(set_all) => set_option!(set_all),
212            PyClassPyO3Option::Str(str) => set_option!(str),
213            PyClassPyO3Option::Subclass(subclass) => set_option!(subclass),
214            PyClassPyO3Option::Unsendable(unsendable) => set_option!(unsendable),
215            PyClassPyO3Option::Weakref(weakref) => {
216                ensure_spanned!(
217                    !is_abi3_before(3, 9),
218                    weakref.span() => "`weakref` requires Python >= 3.9 when using the `abi3` feature"
219                );
220                set_option!(weakref);
221            }
222        }
223        Ok(())
224    }
225}
226
227pub fn build_py_class(
228    class: &mut syn::ItemStruct,
229    mut args: PyClassArgs,
230    methods_type: PyClassMethodsType,
231) -> syn::Result<TokenStream> {
232    args.options.take_pyo3_options(&mut class.attrs)?;
233
234    let ctx = &Ctx::new(&args.options.krate, None);
235    let doc = utils::get_doc(&class.attrs, None, ctx);
236
237    if let Some(lt) = class.generics.lifetimes().next() {
238        bail_spanned!(
239            lt.span() => concat!(
240                "#[pyclass] cannot have lifetime parameters. For an explanation, see \
241                https://pyo3.rs/v", env!("CARGO_PKG_VERSION"), "/class.html#no-lifetime-parameters"
242            )
243        );
244    }
245
246    ensure_spanned!(
247        class.generics.params.is_empty(),
248        class.generics.span() => concat!(
249            "#[pyclass] cannot have generic parameters. For an explanation, see \
250            https://pyo3.rs/v", env!("CARGO_PKG_VERSION"), "/class.html#no-generic-parameters"
251        )
252    );
253
254    let mut all_errors = ErrorCombiner(None);
255
256    let mut field_options: Vec<(&syn::Field, FieldPyO3Options)> = match &mut class.fields {
257        syn::Fields::Named(fields) => fields
258            .named
259            .iter_mut()
260            .filter_map(
261                |field| match FieldPyO3Options::take_pyo3_options(&mut field.attrs) {
262                    Ok(options) => Some((&*field, options)),
263                    Err(e) => {
264                        all_errors.combine(e);
265                        None
266                    }
267                },
268            )
269            .collect::<Vec<_>>(),
270        syn::Fields::Unnamed(fields) => fields
271            .unnamed
272            .iter_mut()
273            .filter_map(
274                |field| match FieldPyO3Options::take_pyo3_options(&mut field.attrs) {
275                    Ok(options) => Some((&*field, options)),
276                    Err(e) => {
277                        all_errors.combine(e);
278                        None
279                    }
280                },
281            )
282            .collect::<Vec<_>>(),
283        syn::Fields::Unit => {
284            if let Some(attr) = args.options.set_all {
285                return Err(syn::Error::new_spanned(attr, UNIT_SET));
286            };
287            if let Some(attr) = args.options.get_all {
288                return Err(syn::Error::new_spanned(attr, UNIT_GET));
289            };
290            // No fields for unit struct
291            Vec::new()
292        }
293    };
294
295    all_errors.ensure_empty()?;
296
297    if let Some(attr) = args.options.get_all {
298        for (_, FieldPyO3Options { get, .. }) in &mut field_options {
299            if let Some(old_get) = get.replace(Annotated::Struct(attr)) {
300                return Err(syn::Error::new(old_get.span(), DUPE_GET));
301            }
302        }
303    }
304
305    if let Some(attr) = args.options.set_all {
306        for (_, FieldPyO3Options { set, .. }) in &mut field_options {
307            if let Some(old_set) = set.replace(Annotated::Struct(attr)) {
308                return Err(syn::Error::new(old_set.span(), DUPE_SET));
309            }
310        }
311    }
312
313    impl_class(&class.ident, &args, doc, field_options, methods_type, ctx)
314}
315
/// A field option together with where it was written.
enum Annotated<X, Y> {
    /// Written on the field itself, e.g. `get` / `set` in the field's `#[pyo3(...)]`.
    Field(X),
    /// Inherited from the struct-level `get_all` / `set_all` option.
    Struct(Y),
}
320
321impl<X: Spanned, Y: Spanned> Annotated<X, Y> {
322    fn span(&self) -> Span {
323        match self {
324            Self::Field(x) => x.span(),
325            Self::Struct(y) => y.span(),
326        }
327    }
328}
329
/// `#[pyo3()]` options for pyclass fields.
///
/// Struct-level `get_all` / `set_all` are folded into `get` / `set` later by
/// `build_py_class`.
struct FieldPyO3Options {
    /// Present when the field has `get`, or inherited `get_all`.
    get: Option<Annotated<kw::get, kw::get_all>>,
    /// Present when the field has `set`, or inherited `set_all`.
    set: Option<Annotated<kw::set, kw::set_all>>,
    /// The field's `name` option, if given.
    name: Option<NameAttribute>,
}
336
/// A single option inside a struct field's `#[pyo3(...)]` attribute.
enum FieldPyO3Option {
    Get(attributes::kw::get),
    Set(attributes::kw::set),
    Name(NameAttribute),
}
342
343impl Parse for FieldPyO3Option {
344    fn parse(input: ParseStream<'_>) -> Result<Self> {
345        let lookahead = input.lookahead1();
346        if lookahead.peek(attributes::kw::get) {
347            input.parse().map(FieldPyO3Option::Get)
348        } else if lookahead.peek(attributes::kw::set) {
349            input.parse().map(FieldPyO3Option::Set)
350        } else if lookahead.peek(attributes::kw::name) {
351            input.parse().map(FieldPyO3Option::Name)
352        } else {
353            Err(lookahead.error())
354        }
355    }
356}
357
358impl FieldPyO3Options {
359    fn take_pyo3_options(attrs: &mut Vec<syn::Attribute>) -> Result<Self> {
360        let mut options = FieldPyO3Options {
361            get: None,
362            set: None,
363            name: None,
364        };
365
366        for option in take_pyo3_options(attrs)? {
367            match option {
368                FieldPyO3Option::Get(kw) => {
369                    if options.get.replace(Annotated::Field(kw)).is_some() {
370                        return Err(syn::Error::new(kw.span(), UNIQUE_GET));
371                    }
372                }
373                FieldPyO3Option::Set(kw) => {
374                    if options.set.replace(Annotated::Field(kw)).is_some() {
375                        return Err(syn::Error::new(kw.span(), UNIQUE_SET));
376                    }
377                }
378                FieldPyO3Option::Name(name) => {
379                    if options.name.replace(name).is_some() {
380                        return Err(syn::Error::new(options.name.span(), UNIQUE_NAME));
381                    }
382                }
383            }
384        }
385
386        Ok(options)
387    }
388}
389
390fn get_class_python_name<'a>(cls: &'a syn::Ident, args: &'a PyClassArgs) -> Cow<'a, syn::Ident> {
391    args.options
392        .name
393        .as_ref()
394        .map(|name_attr| Cow::Borrowed(&name_attr.value.0))
395        .unwrap_or_else(|| Cow::Owned(cls.unraw()))
396}
397
398fn impl_class(
399    cls: &syn::Ident,
400    args: &PyClassArgs,
401    doc: PythonDoc,
402    field_options: Vec<(&syn::Field, FieldPyO3Options)>,
403    methods_type: PyClassMethodsType,
404    ctx: &Ctx,
405) -> syn::Result<TokenStream> {
406    let Ctx { pyo3_path, .. } = ctx;
407    let pytypeinfo_impl = impl_pytypeinfo(cls, args, ctx);
408
409    if let Some(str) = &args.options.str {
410        if str.value.is_some() {
411            // check if any renaming is present
412            let no_naming_conflict = field_options.iter().all(|x| x.1.name.is_none())
413                & args.options.name.is_none()
414                & args.options.rename_all.is_none();
415            ensure_spanned!(no_naming_conflict, str.value.span() => "The format string syntax is incompatible with any renaming via `name` or `rename_all`");
416        }
417    }
418
419    let (default_str, default_str_slot) =
420        implement_pyclass_str(&args.options, &syn::parse_quote!(#cls), ctx);
421
422    let (default_richcmp, default_richcmp_slot) =
423        pyclass_richcmp(&args.options, &syn::parse_quote!(#cls), ctx)?;
424
425    let (default_hash, default_hash_slot) =
426        pyclass_hash(&args.options, &syn::parse_quote!(#cls), ctx)?;
427
428    let mut slots = Vec::new();
429    slots.extend(default_richcmp_slot);
430    slots.extend(default_hash_slot);
431    slots.extend(default_str_slot);
432
433    let py_class_impl = PyClassImplsBuilder::new(
434        cls,
435        args,
436        methods_type,
437        descriptors_to_items(
438            cls,
439            args.options.rename_all.as_ref(),
440            args.options.frozen,
441            field_options,
442            ctx,
443        )?,
444        slots,
445    )
446    .doc(doc)
447    .impl_all(ctx)?;
448
449    Ok(quote! {
450        impl #pyo3_path::types::DerefToPyAny for #cls {}
451
452        #pytypeinfo_impl
453
454        #py_class_impl
455
456        #[doc(hidden)]
457        #[allow(non_snake_case)]
458        impl #cls {
459            #default_richcmp
460            #default_hash
461            #default_str
462        }
463    })
464}
465
/// A `#[pyclass]` enum, classified by the shape of its variants (see `PyClassEnum::new`).
enum PyClassEnum<'a> {
    /// Every variant is a unit variant.
    Simple(PyClassSimpleEnum<'a>),
    /// At least one variant has fields.
    Complex(PyClassComplexEnum<'a>),
}
470
471impl<'a> PyClassEnum<'a> {
472    fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result<Self> {
473        let has_only_unit_variants = enum_
474            .variants
475            .iter()
476            .all(|variant| matches!(variant.fields, syn::Fields::Unit));
477
478        Ok(if has_only_unit_variants {
479            let simple_enum = PyClassSimpleEnum::new(enum_)?;
480            Self::Simple(simple_enum)
481        } else {
482            let complex_enum = PyClassComplexEnum::new(enum_)?;
483            Self::Complex(complex_enum)
484        })
485    }
486}
487
488pub fn build_py_enum(
489    enum_: &mut syn::ItemEnum,
490    mut args: PyClassArgs,
491    method_type: PyClassMethodsType,
492) -> syn::Result<TokenStream> {
493    args.options.take_pyo3_options(&mut enum_.attrs)?;
494
495    let ctx = &Ctx::new(&args.options.krate, None);
496    if let Some(extends) = &args.options.extends {
497        bail_spanned!(extends.span() => "enums can't extend from other classes");
498    } else if let Some(subclass) = &args.options.subclass {
499        bail_spanned!(subclass.span() => "enums can't be inherited by other classes");
500    } else if enum_.variants.is_empty() {
501        bail_spanned!(enum_.brace_token.span.join() => "#[pyclass] can't be used on enums without any variants");
502    }
503
504    let doc = utils::get_doc(&enum_.attrs, None, ctx);
505    let enum_ = PyClassEnum::new(enum_)?;
506    impl_enum(enum_, &args, doc, method_type, ctx)
507}
508
/// A `#[pyclass]` enum whose variants are all unit variants.
struct PyClassSimpleEnum<'a> {
    /// The Rust identifier of the enum.
    ident: &'a syn::Ident,
    // The underlying #[repr] of the enum, used to implement __int__ and __richcmp__.
    // This matters when the underlying representation may not fit in `isize`.
    // Defaults to `isize` when no integer repr is found.
    repr_type: syn::Ident,
    /// The enum's variants, in declaration order.
    variants: Vec<PyClassEnumUnitVariant<'a>>,
}
516
517impl<'a> PyClassSimpleEnum<'a> {
518    fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result<Self> {
519        fn is_numeric_type(t: &syn::Ident) -> bool {
520            [
521                "u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize",
522                "isize",
523            ]
524            .iter()
525            .any(|&s| t == s)
526        }
527
528        fn extract_unit_variant_data(
529            variant: &mut syn::Variant,
530        ) -> syn::Result<PyClassEnumUnitVariant<'_>> {
531            use syn::Fields;
532            let ident = match &variant.fields {
533                Fields::Unit => &variant.ident,
534                _ => bail_spanned!(variant.span() => "Must be a unit variant."),
535            };
536            let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?;
537            let cfg_attrs = get_cfg_attributes(&variant.attrs);
538            Ok(PyClassEnumUnitVariant {
539                ident,
540                options,
541                cfg_attrs,
542            })
543        }
544
545        let ident = &enum_.ident;
546
547        // According to the [reference](https://doc.rust-lang.org/reference/items/enumerations.html),
548        // "Under the default representation, the specified discriminant is interpreted as an isize
549        // value", so `isize` should be enough by default.
550        let mut repr_type = syn::Ident::new("isize", proc_macro2::Span::call_site());
551        if let Some(attr) = enum_.attrs.iter().find(|attr| attr.path().is_ident("repr")) {
552            let args =
553                attr.parse_args_with(Punctuated::<TokenStream, Token![!]>::parse_terminated)?;
554            if let Some(ident) = args
555                .into_iter()
556                .filter_map(|ts| syn::parse2::<syn::Ident>(ts).ok())
557                .find(is_numeric_type)
558            {
559                repr_type = ident;
560            }
561        }
562
563        let variants: Vec<_> = enum_
564            .variants
565            .iter_mut()
566            .map(extract_unit_variant_data)
567            .collect::<syn::Result<_>>()?;
568        Ok(Self {
569            ident,
570            repr_type,
571            variants,
572        })
573    }
574}
575
/// A `#[pyclass]` enum with at least one variant that has fields.
struct PyClassComplexEnum<'a> {
    /// The Rust identifier of the enum.
    ident: &'a syn::Ident,
    /// The enum's variants, in declaration order; none of them is a unit variant.
    variants: Vec<PyClassEnumVariant<'a>>,
}
580
581impl<'a> PyClassComplexEnum<'a> {
582    fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result<Self> {
583        let witness = enum_
584            .variants
585            .iter()
586            .find(|variant| !matches!(variant.fields, syn::Fields::Unit))
587            .expect("complex enum has a non-unit variant")
588            .ident
589            .to_owned();
590
591        let extract_variant_data =
592            |variant: &'a mut syn::Variant| -> syn::Result<PyClassEnumVariant<'a>> {
593                use syn::Fields;
594                let ident = &variant.ident;
595                let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?;
596
597                let variant = match &variant.fields {
598                    Fields::Unit => {
599                        bail_spanned!(variant.span() => format!(
600                            "Unit variant `{ident}` is not yet supported in a complex enum\n\
601                            = help: change to an empty tuple variant instead: `{ident}()`\n\
602                            = note: the enum is complex because of non-unit variant `{witness}`",
603                            ident=ident, witness=witness))
604                    }
605                    Fields::Named(fields) => {
606                        let fields = fields
607                            .named
608                            .iter()
609                            .map(|field| PyClassEnumVariantNamedField {
610                                ident: field.ident.as_ref().expect("named field has an identifier"),
611                                ty: &field.ty,
612                                span: field.span(),
613                            })
614                            .collect();
615
616                        PyClassEnumVariant::Struct(PyClassEnumStructVariant {
617                            ident,
618                            fields,
619                            options,
620                        })
621                    }
622                    Fields::Unnamed(types) => {
623                        let fields = types
624                            .unnamed
625                            .iter()
626                            .map(|field| PyClassEnumVariantUnnamedField {
627                                ty: &field.ty,
628                                span: field.span(),
629                            })
630                            .collect();
631
632                        PyClassEnumVariant::Tuple(PyClassEnumTupleVariant {
633                            ident,
634                            fields,
635                            options,
636                        })
637                    }
638                };
639
640                Ok(variant)
641            };
642
643        let ident = &enum_.ident;
644
645        let variants: Vec<_> = enum_
646            .variants
647            .iter_mut()
648            .map(extract_variant_data)
649            .collect::<syn::Result<_>>()?;
650
651        Ok(Self { ident, variants })
652    }
653}
654
/// A variant of a complex enum; unit variants are rejected when the complex
/// enum is built (see `PyClassComplexEnum::new`).
enum PyClassEnumVariant<'a> {
    // TODO(mkovaxx): Unit(PyClassEnumUnitVariant<'a>),
    /// A variant with named fields.
    Struct(PyClassEnumStructVariant<'a>),
    /// A variant with unnamed (tuple) fields.
    Tuple(PyClassEnumTupleVariant<'a>),
}
660
661trait EnumVariant {
662    fn get_ident(&self) -> &syn::Ident;
663    fn get_options(&self) -> &EnumVariantPyO3Options;
664
665    fn get_python_name(&self, args: &PyClassArgs) -> Cow<'_, syn::Ident> {
666        self.get_options()
667            .name
668            .as_ref()
669            .map(|name_attr| Cow::Borrowed(&name_attr.value.0))
670            .unwrap_or_else(|| {
671                let name = self.get_ident().unraw();
672                if let Some(attr) = &args.options.rename_all {
673                    let new_name = apply_renaming_rule(attr.value.rule, &name.to_string());
674                    Cow::Owned(Ident::new(&new_name, Span::call_site()))
675                } else {
676                    Cow::Owned(name)
677                }
678            })
679    }
680}
681
682impl EnumVariant for PyClassEnumVariant<'_> {
683    fn get_ident(&self) -> &syn::Ident {
684        match self {
685            PyClassEnumVariant::Struct(struct_variant) => struct_variant.ident,
686            PyClassEnumVariant::Tuple(tuple_variant) => tuple_variant.ident,
687        }
688    }
689
690    fn get_options(&self) -> &EnumVariantPyO3Options {
691        match self {
692            PyClassEnumVariant::Struct(struct_variant) => &struct_variant.options,
693            PyClassEnumVariant::Tuple(tuple_variant) => &tuple_variant.options,
694        }
695    }
696}
697
/// A unit variant (one without fields) of a simple enum.
struct PyClassEnumUnitVariant<'a> {
    /// The Rust identifier of the variant.
    ident: &'a syn::Ident,
    /// The variant's `#[pyo3(...)]` options.
    options: EnumVariantPyO3Options,
    /// The variant's `cfg` attributes, copied onto code generated per variant
    /// (e.g. the match arms of the generated `__repr__`).
    cfg_attrs: Vec<&'a syn::Attribute>,
}
704
705impl EnumVariant for PyClassEnumUnitVariant<'_> {
706    fn get_ident(&self) -> &syn::Ident {
707        self.ident
708    }
709
710    fn get_options(&self) -> &EnumVariantPyO3Options {
711        &self.options
712    }
713}
714
/// A complex-enum variant with named fields, e.g. `Variant { a: T }`.
struct PyClassEnumStructVariant<'a> {
    /// The Rust identifier of the variant.
    ident: &'a syn::Ident,
    /// The variant's fields, in declaration order.
    fields: Vec<PyClassEnumVariantNamedField<'a>>,
    /// The variant's `#[pyo3(...)]` options.
    options: EnumVariantPyO3Options,
}
721
/// A complex-enum variant with unnamed fields, e.g. `Variant(T)`.
struct PyClassEnumTupleVariant<'a> {
    /// The Rust identifier of the variant.
    ident: &'a syn::Ident,
    /// The variant's fields, in declaration order.
    fields: Vec<PyClassEnumVariantUnnamedField<'a>>,
    /// The variant's `#[pyo3(...)]` options.
    options: EnumVariantPyO3Options,
}
727
/// A named field of a struct variant.
struct PyClassEnumVariantNamedField<'a> {
    /// The field's identifier.
    ident: &'a syn::Ident,
    /// The field's declared type.
    ty: &'a syn::Type,
    /// Span of the whole field declaration.
    span: Span,
}
733
/// An unnamed (positional) field of a tuple variant.
struct PyClassEnumVariantUnnamedField<'a> {
    /// The field's declared type.
    ty: &'a syn::Type,
    /// Span of the whole field declaration.
    span: Span,
}
738
/// `#[pyo3()]` options for pyclass enum variants.
///
/// Each field stays `None` until its option is given; `set_option` rejects a
/// second occurrence.
#[derive(Clone, Default)]
struct EnumVariantPyO3Options {
    /// Overrides the variant name exposed to Python (see `EnumVariant::get_python_name`).
    name: Option<NameAttribute>,
    /// Rejected on simple-enum variants (see `impl_simple_enum`).
    constructor: Option<ConstructorAttribute>,
}
745
/// A single option inside an enum variant's `#[pyo3(...)]` attribute.
enum EnumVariantPyO3Option {
    Name(NameAttribute),
    Constructor(ConstructorAttribute),
}
750
751impl Parse for EnumVariantPyO3Option {
752    fn parse(input: ParseStream<'_>) -> Result<Self> {
753        let lookahead = input.lookahead1();
754        if lookahead.peek(attributes::kw::name) {
755            input.parse().map(EnumVariantPyO3Option::Name)
756        } else if lookahead.peek(attributes::kw::constructor) {
757            input.parse().map(EnumVariantPyO3Option::Constructor)
758        } else {
759            Err(lookahead.error())
760        }
761    }
762}
763
764impl EnumVariantPyO3Options {
765    fn take_pyo3_options(attrs: &mut Vec<syn::Attribute>) -> Result<Self> {
766        let mut options = EnumVariantPyO3Options::default();
767
768        take_pyo3_options(attrs)?
769            .into_iter()
770            .try_for_each(|option| options.set_option(option))?;
771
772        Ok(options)
773    }
774
775    fn set_option(&mut self, option: EnumVariantPyO3Option) -> syn::Result<()> {
776        macro_rules! set_option {
777            ($key:ident) => {
778                {
779                    ensure_spanned!(
780                        self.$key.is_none(),
781                        $key.span() => concat!("`", stringify!($key), "` may only be specified once")
782                    );
783                    self.$key = Some($key);
784                }
785            };
786        }
787
788        match option {
789            EnumVariantPyO3Option::Constructor(constructor) => set_option!(constructor),
790            EnumVariantPyO3Option::Name(name) => set_option!(name),
791        }
792        Ok(())
793    }
794}
795
/// Presumably selects between the `__str__` and `__repr__` formatting
/// protocols; not used within the visible part of this file.
// TODO: remove this dead code allowance once `__repr__` is implemented.
#[allow(dead_code)]
pub enum PyFmtName {
    Str,
    Repr,
}
802
803fn implement_py_formatting(
804    ty: &syn::Type,
805    ctx: &Ctx,
806    option: &StrFormatterAttribute,
807) -> (ImplItemFn, MethodAndSlotDef) {
808    let mut fmt_impl = match &option.value {
809        Some(opt) => {
810            let fmt = &opt.fmt;
811            let args = &opt
812                .args
813                .iter()
814                .map(|member| quote! {self.#member})
815                .collect::<Vec<TokenStream>>();
816            let fmt_impl: ImplItemFn = syn::parse_quote! {
817                fn __pyo3__generated____str__(&self) -> ::std::string::String {
818                    ::std::format!(#fmt, #(#args, )*)
819                }
820            };
821            fmt_impl
822        }
823        None => {
824            let fmt_impl: syn::ImplItemFn = syn::parse_quote! {
825                fn __pyo3__generated____str__(&self) -> ::std::string::String {
826                    ::std::format!("{}", &self)
827                }
828            };
829            fmt_impl
830        }
831    };
832    let fmt_slot = generate_protocol_slot(ty, &mut fmt_impl, &__STR__, "__str__", ctx).unwrap();
833    (fmt_impl, fmt_slot)
834}
835
836fn implement_pyclass_str(
837    options: &PyClassPyO3Options,
838    ty: &syn::Type,
839    ctx: &Ctx,
840) -> (Option<ImplItemFn>, Option<MethodAndSlotDef>) {
841    match &options.str {
842        Some(option) => {
843            let (default_str, default_str_slot) = implement_py_formatting(ty, ctx, option);
844            (Some(default_str), Some(default_str_slot))
845        }
846        _ => (None, None),
847    }
848}
849
850fn impl_enum(
851    enum_: PyClassEnum<'_>,
852    args: &PyClassArgs,
853    doc: PythonDoc,
854    methods_type: PyClassMethodsType,
855    ctx: &Ctx,
856) -> Result<TokenStream> {
857    if let Some(str_fmt) = &args.options.str {
858        ensure_spanned!(str_fmt.value.is_none(), str_fmt.value.span() => "The format string syntax cannot be used with enums")
859    }
860
861    match enum_ {
862        PyClassEnum::Simple(simple_enum) => {
863            impl_simple_enum(simple_enum, args, doc, methods_type, ctx)
864        }
865        PyClassEnum::Complex(complex_enum) => {
866            impl_complex_enum(complex_enum, args, doc, methods_type, ctx)
867        }
868    }
869}
870
871fn impl_simple_enum(
872    simple_enum: PyClassSimpleEnum<'_>,
873    args: &PyClassArgs,
874    doc: PythonDoc,
875    methods_type: PyClassMethodsType,
876    ctx: &Ctx,
877) -> Result<TokenStream> {
878    let cls = simple_enum.ident;
879    let ty: syn::Type = syn::parse_quote!(#cls);
880    let variants = simple_enum.variants;
881    let pytypeinfo = impl_pytypeinfo(cls, args, ctx);
882
883    for variant in &variants {
884        ensure_spanned!(variant.options.constructor.is_none(), variant.options.constructor.span() => "`constructor` can't be used on a simple enum variant");
885    }
886
887    let variant_cfg_check = generate_cfg_check(&variants, cls);
888
889    let (default_repr, default_repr_slot) = {
890        let variants_repr = variants.iter().map(|variant| {
891            let variant_name = variant.ident;
892            let cfg_attrs = &variant.cfg_attrs;
893            // Assuming all variants are unit variants because they are the only type we support.
894            let repr = format!(
895                "{}.{}",
896                get_class_python_name(cls, args),
897                variant.get_python_name(args),
898            );
899            quote! { #(#cfg_attrs)* #cls::#variant_name => #repr, }
900        });
901        let mut repr_impl: syn::ImplItemFn = syn::parse_quote! {
902            fn __pyo3__repr__(&self) -> &'static str {
903                match *self {
904                    #(#variants_repr)*
905                }
906            }
907        };
908        let repr_slot =
909            generate_default_protocol_slot(&ty, &mut repr_impl, &__REPR__, ctx).unwrap();
910        (repr_impl, repr_slot)
911    };
912
913    let (default_str, default_str_slot) = implement_pyclass_str(&args.options, &ty, ctx);
914
915    let repr_type = &simple_enum.repr_type;
916
917    let (default_int, default_int_slot) = {
918        // This implementation allows us to convert &T to #repr_type without implementing `Copy`
919        let variants_to_int = variants.iter().map(|variant| {
920            let variant_name = variant.ident;
921            let cfg_attrs = &variant.cfg_attrs;
922            quote! { #(#cfg_attrs)* #cls::#variant_name => #cls::#variant_name as #repr_type, }
923        });
924        let mut int_impl: syn::ImplItemFn = syn::parse_quote! {
925            fn __pyo3__int__(&self) -> #repr_type {
926                match *self {
927                    #(#variants_to_int)*
928                }
929            }
930        };
931        let int_slot = generate_default_protocol_slot(&ty, &mut int_impl, &__INT__, ctx).unwrap();
932        (int_impl, int_slot)
933    };
934
935    let (default_richcmp, default_richcmp_slot) =
936        pyclass_richcmp_simple_enum(&args.options, &ty, repr_type, ctx)?;
937    let (default_hash, default_hash_slot) = pyclass_hash(&args.options, &ty, ctx)?;
938
939    let mut default_slots = vec![default_repr_slot, default_int_slot];
940    default_slots.extend(default_richcmp_slot);
941    default_slots.extend(default_hash_slot);
942    default_slots.extend(default_str_slot);
943
944    let pyclass_impls = PyClassImplsBuilder::new(
945        cls,
946        args,
947        methods_type,
948        simple_enum_default_methods(
949            cls,
950            variants
951                .iter()
952                .map(|v| (v.ident, v.get_python_name(args), &v.cfg_attrs)),
953            ctx,
954        ),
955        default_slots,
956    )
957    .doc(doc)
958    .impl_all(ctx)?;
959
960    Ok(quote! {
961        #variant_cfg_check
962
963        #pytypeinfo
964
965        #pyclass_impls
966
967        #[doc(hidden)]
968        #[allow(non_snake_case)]
969        impl #cls {
970            #default_repr
971            #default_int
972            #default_richcmp
973            #default_hash
974            #default_str
975        }
976    })
977}
978
979fn impl_complex_enum(
980    complex_enum: PyClassComplexEnum<'_>,
981    args: &PyClassArgs,
982    doc: PythonDoc,
983    methods_type: PyClassMethodsType,
984    ctx: &Ctx,
985) -> Result<TokenStream> {
986    let Ctx { pyo3_path, .. } = ctx;
987    let cls = complex_enum.ident;
988    let ty: syn::Type = syn::parse_quote!(#cls);
989
990    // Need to rig the enum PyClass options
991    let args = {
992        let mut rigged_args = args.clone();
993        // Needs to be frozen to disallow `&mut self` methods, which could break a runtime invariant
994        rigged_args.options.frozen = parse_quote!(frozen);
995        // Needs to be subclassable by the variant PyClasses
996        rigged_args.options.subclass = parse_quote!(subclass);
997        rigged_args
998    };
999
1000    let ctx = &Ctx::new(&args.options.krate, None);
1001    let cls = complex_enum.ident;
1002    let variants = complex_enum.variants;
1003    let pytypeinfo = impl_pytypeinfo(cls, &args, ctx);
1004
1005    let (default_richcmp, default_richcmp_slot) = pyclass_richcmp(&args.options, &ty, ctx)?;
1006    let (default_hash, default_hash_slot) = pyclass_hash(&args.options, &ty, ctx)?;
1007
1008    let (default_str, default_str_slot) = implement_pyclass_str(&args.options, &ty, ctx);
1009
1010    let mut default_slots = vec![];
1011    default_slots.extend(default_richcmp_slot);
1012    default_slots.extend(default_hash_slot);
1013    default_slots.extend(default_str_slot);
1014
1015    let impl_builder = PyClassImplsBuilder::new(
1016        cls,
1017        &args,
1018        methods_type,
1019        complex_enum_default_methods(
1020            cls,
1021            variants
1022                .iter()
1023                .map(|v| (v.get_ident(), v.get_python_name(&args))),
1024            ctx,
1025        ),
1026        default_slots,
1027    )
1028    .doc(doc);
1029
1030    // Need to customize the into_py impl so that it returns the variant PyClass
1031    let enum_into_py_impl = {
1032        let match_arms: Vec<TokenStream> = variants
1033            .iter()
1034            .map(|variant| {
1035                let variant_ident = variant.get_ident();
1036                let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident());
1037                quote! {
1038                    #cls::#variant_ident { .. } => {
1039                        let pyclass_init = <#pyo3_path::PyClassInitializer<Self> as ::std::convert::From<Self>>::from(self).add_subclass(#variant_cls);
1040                        let variant_value = #pyo3_path::Py::new(py, pyclass_init).unwrap();
1041                        #pyo3_path::IntoPy::into_py(variant_value, py)
1042                    }
1043                }
1044            })
1045            .collect();
1046
1047        quote! {
1048            #[allow(deprecated)]
1049            impl #pyo3_path::IntoPy<#pyo3_path::PyObject> for #cls {
1050                fn into_py(self, py: #pyo3_path::Python) -> #pyo3_path::PyObject {
1051                    match self {
1052                        #(#match_arms)*
1053                    }
1054                }
1055            }
1056        }
1057    };
1058
1059    let enum_into_pyobject_impl = {
1060        let match_arms = variants
1061            .iter()
1062            .map(|variant| {
1063                let variant_ident = variant.get_ident();
1064                let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident());
1065                quote! {
1066                    #cls::#variant_ident { .. } => {
1067                        let pyclass_init = <#pyo3_path::PyClassInitializer<Self> as ::std::convert::From<Self>>::from(self).add_subclass(#variant_cls);
1068                        unsafe { #pyo3_path::Bound::new(py, pyclass_init).map(|b| #pyo3_path::types::PyAnyMethods::downcast_into_unchecked(b.into_any())) }
1069                    }
1070                }
1071            });
1072
1073        quote! {
1074            impl<'py> #pyo3_path::conversion::IntoPyObject<'py> for #cls {
1075                type Target = Self;
1076                type Output = #pyo3_path::Bound<'py, <Self as #pyo3_path::conversion::IntoPyObject<'py>>::Target>;
1077                type Error = #pyo3_path::PyErr;
1078
1079                fn into_pyobject(self, py: #pyo3_path::Python<'py>) -> ::std::result::Result<
1080                    <Self as #pyo3_path::conversion::IntoPyObject>::Output,
1081                    <Self as #pyo3_path::conversion::IntoPyObject>::Error,
1082                > {
1083                    match self {
1084                        #(#match_arms)*
1085                    }
1086                }
1087            }
1088        }
1089    };
1090
1091    let pyclass_impls: TokenStream = [
1092        impl_builder.impl_pyclass(ctx),
1093        impl_builder.impl_extractext(ctx),
1094        enum_into_py_impl,
1095        enum_into_pyobject_impl,
1096        impl_builder.impl_pyclassimpl(ctx)?,
1097        impl_builder.impl_add_to_module(ctx),
1098        impl_builder.impl_freelist(ctx),
1099    ]
1100    .into_iter()
1101    .collect();
1102
1103    let mut variant_cls_zsts = vec![];
1104    let mut variant_cls_pytypeinfos = vec![];
1105    let mut variant_cls_pyclass_impls = vec![];
1106    let mut variant_cls_impls = vec![];
1107    for variant in variants {
1108        let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident());
1109
1110        let variant_cls_zst = quote! {
1111            #[doc(hidden)]
1112            #[allow(non_camel_case_types)]
1113            struct #variant_cls;
1114        };
1115        variant_cls_zsts.push(variant_cls_zst);
1116
1117        let variant_args = PyClassArgs {
1118            class_kind: PyClassKind::Struct,
1119            // TODO(mkovaxx): propagate variant.options
1120            options: {
1121                let mut rigged_options: PyClassPyO3Options = parse_quote!(extends = #cls, frozen);
1122                // If a specific module was given to the base class, use it for all variants.
1123                rigged_options.module.clone_from(&args.options.module);
1124                rigged_options
1125            },
1126        };
1127
1128        let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, ctx);
1129        variant_cls_pytypeinfos.push(variant_cls_pytypeinfo);
1130
1131        let (variant_cls_impl, field_getters, mut slots) =
1132            impl_complex_enum_variant_cls(cls, &variant, ctx)?;
1133        variant_cls_impls.push(variant_cls_impl);
1134
1135        let variant_new = complex_enum_variant_new(cls, variant, ctx)?;
1136        slots.push(variant_new);
1137
1138        let pyclass_impl = PyClassImplsBuilder::new(
1139            &variant_cls,
1140            &variant_args,
1141            methods_type,
1142            field_getters,
1143            slots,
1144        )
1145        .impl_all(ctx)?;
1146
1147        variant_cls_pyclass_impls.push(pyclass_impl);
1148    }
1149
1150    Ok(quote! {
1151        #pytypeinfo
1152
1153        #pyclass_impls
1154
1155        #[doc(hidden)]
1156        #[allow(non_snake_case)]
1157        impl #cls {
1158            #default_richcmp
1159            #default_hash
1160            #default_str
1161        }
1162
1163        #(#variant_cls_zsts)*
1164
1165        #(#variant_cls_pytypeinfos)*
1166
1167        #(#variant_cls_pyclass_impls)*
1168
1169        #(#variant_cls_impls)*
1170    })
1171}
1172
1173fn impl_complex_enum_variant_cls(
1174    enum_name: &syn::Ident,
1175    variant: &PyClassEnumVariant<'_>,
1176    ctx: &Ctx,
1177) -> Result<(TokenStream, Vec<MethodAndMethodDef>, Vec<MethodAndSlotDef>)> {
1178    match variant {
1179        PyClassEnumVariant::Struct(struct_variant) => {
1180            impl_complex_enum_struct_variant_cls(enum_name, struct_variant, ctx)
1181        }
1182        PyClassEnumVariant::Tuple(tuple_variant) => {
1183            impl_complex_enum_tuple_variant_cls(enum_name, tuple_variant, ctx)
1184        }
1185    }
1186}
1187
1188fn impl_complex_enum_variant_match_args(
1189    ctx @ Ctx { pyo3_path, .. }: &Ctx,
1190    variant_cls_type: &syn::Type,
1191    field_names: &[Ident],
1192) -> syn::Result<(MethodAndMethodDef, syn::ImplItemFn)> {
1193    let ident = format_ident!("__match_args__");
1194    let field_names_unraw = field_names.iter().map(|name| name.unraw());
1195    let mut match_args_impl: syn::ImplItemFn = {
1196        parse_quote! {
1197            #[classattr]
1198            fn #ident(py: #pyo3_path::Python<'_>) -> #pyo3_path::PyResult<#pyo3_path::Bound<'_, #pyo3_path::types::PyTuple>> {
1199                #pyo3_path::types::PyTuple::new::<&str, _>(py, [
1200                    #(stringify!(#field_names_unraw),)*
1201                ])
1202            }
1203        }
1204    };
1205
1206    let spec = FnSpec::parse(
1207        &mut match_args_impl.sig,
1208        &mut match_args_impl.attrs,
1209        Default::default(),
1210    )?;
1211    let variant_match_args = impl_py_class_attribute(variant_cls_type, &spec, ctx)?;
1212
1213    Ok((variant_match_args, match_args_impl))
1214}
1215
1216fn impl_complex_enum_struct_variant_cls(
1217    enum_name: &syn::Ident,
1218    variant: &PyClassEnumStructVariant<'_>,
1219    ctx: &Ctx,
1220) -> Result<(TokenStream, Vec<MethodAndMethodDef>, Vec<MethodAndSlotDef>)> {
1221    let Ctx { pyo3_path, .. } = ctx;
1222    let variant_ident = &variant.ident;
1223    let variant_cls = gen_complex_enum_variant_class_ident(enum_name, variant.ident);
1224    let variant_cls_type = parse_quote!(#variant_cls);
1225
1226    let mut field_names: Vec<Ident> = vec![];
1227    let mut fields_with_types: Vec<TokenStream> = vec![];
1228    let mut field_getters = vec![];
1229    let mut field_getter_impls: Vec<TokenStream> = vec![];
1230    for field in &variant.fields {
1231        let field_name = field.ident;
1232        let field_type = field.ty;
1233        let field_with_type = quote! { #field_name: #field_type };
1234
1235        let field_getter =
1236            complex_enum_variant_field_getter(&variant_cls_type, field_name, field.span, ctx)?;
1237
1238        let field_getter_impl = quote! {
1239            fn #field_name(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<#pyo3_path::PyObject> {
1240                #[allow(unused_imports)]
1241                use #pyo3_path::impl_::pyclass::Probe;
1242                let py = slf.py();
1243                match &*slf.into_super() {
1244                    #enum_name::#variant_ident { #field_name, .. } =>
1245                        #pyo3_path::impl_::pyclass::ConvertField::<
1246                            { #pyo3_path::impl_::pyclass::IsIntoPyObjectRef::<#field_type>::VALUE },
1247                            { #pyo3_path::impl_::pyclass::IsIntoPyObject::<#field_type>::VALUE },
1248                        >::convert_field::<#field_type>(#field_name, py),
1249                    _ => ::core::unreachable!("Wrong complex enum variant found in variant wrapper PyClass"),
1250                }
1251            }
1252        };
1253
1254        field_names.push(field_name.clone());
1255        fields_with_types.push(field_with_type);
1256        field_getters.push(field_getter);
1257        field_getter_impls.push(field_getter_impl);
1258    }
1259
1260    let (variant_match_args, match_args_const_impl) =
1261        impl_complex_enum_variant_match_args(ctx, &variant_cls_type, &field_names)?;
1262
1263    field_getters.push(variant_match_args);
1264
1265    let cls_impl = quote! {
1266        #[doc(hidden)]
1267        #[allow(non_snake_case)]
1268        impl #variant_cls {
1269            #[allow(clippy::too_many_arguments)]
1270            fn __pymethod_constructor__(py: #pyo3_path::Python<'_>, #(#fields_with_types,)*) -> #pyo3_path::PyClassInitializer<#variant_cls> {
1271                let base_value = #enum_name::#variant_ident { #(#field_names,)* };
1272                <#pyo3_path::PyClassInitializer<#enum_name> as ::std::convert::From<#enum_name>>::from(base_value).add_subclass(#variant_cls)
1273            }
1274
1275            #match_args_const_impl
1276
1277            #(#field_getter_impls)*
1278        }
1279    };
1280
1281    Ok((cls_impl, field_getters, Vec::new()))
1282}
1283
1284fn impl_complex_enum_tuple_variant_field_getters(
1285    ctx: &Ctx,
1286    variant: &PyClassEnumTupleVariant<'_>,
1287    enum_name: &syn::Ident,
1288    variant_cls_type: &syn::Type,
1289    variant_ident: &&Ident,
1290    field_names: &mut Vec<Ident>,
1291    fields_types: &mut Vec<syn::Type>,
1292) -> Result<(Vec<MethodAndMethodDef>, Vec<syn::ImplItemFn>)> {
1293    let Ctx { pyo3_path, .. } = ctx;
1294
1295    let mut field_getters = vec![];
1296    let mut field_getter_impls = vec![];
1297
1298    for (index, field) in variant.fields.iter().enumerate() {
1299        let field_name = format_ident!("_{}", index);
1300        let field_type = field.ty;
1301
1302        let field_getter =
1303            complex_enum_variant_field_getter(variant_cls_type, &field_name, field.span, ctx)?;
1304
1305        // Generate the match arms needed to destructure the tuple and access the specific field
1306        let field_access_tokens: Vec<_> = (0..variant.fields.len())
1307            .map(|i| {
1308                if i == index {
1309                    quote! { val }
1310                } else {
1311                    quote! { _ }
1312                }
1313            })
1314            .collect();
1315        let field_getter_impl: syn::ImplItemFn = parse_quote! {
1316            fn #field_name(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<#pyo3_path::PyObject> {
1317                #[allow(unused_imports)]
1318                use #pyo3_path::impl_::pyclass::Probe;
1319                let py = slf.py();
1320                match &*slf.into_super() {
1321                    #enum_name::#variant_ident ( #(#field_access_tokens), *) =>
1322                        #pyo3_path::impl_::pyclass::ConvertField::<
1323                            { #pyo3_path::impl_::pyclass::IsIntoPyObjectRef::<#field_type>::VALUE },
1324                            { #pyo3_path::impl_::pyclass::IsIntoPyObject::<#field_type>::VALUE },
1325                        >::convert_field::<#field_type>(val, py),
1326                    _ => ::core::unreachable!("Wrong complex enum variant found in variant wrapper PyClass"),
1327                }
1328            }
1329        };
1330
1331        field_names.push(field_name);
1332        fields_types.push(field_type.clone());
1333        field_getters.push(field_getter);
1334        field_getter_impls.push(field_getter_impl);
1335    }
1336
1337    Ok((field_getters, field_getter_impls))
1338}
1339
1340fn impl_complex_enum_tuple_variant_len(
1341    ctx: &Ctx,
1342
1343    variant_cls_type: &syn::Type,
1344    num_fields: usize,
1345) -> Result<(MethodAndSlotDef, syn::ImplItemFn)> {
1346    let Ctx { pyo3_path, .. } = ctx;
1347
1348    let mut len_method_impl: syn::ImplItemFn = parse_quote! {
1349        fn __len__(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<usize> {
1350            ::std::result::Result::Ok(#num_fields)
1351        }
1352    };
1353
1354    let variant_len =
1355        generate_default_protocol_slot(variant_cls_type, &mut len_method_impl, &__LEN__, ctx)?;
1356
1357    Ok((variant_len, len_method_impl))
1358}
1359
1360fn impl_complex_enum_tuple_variant_getitem(
1361    ctx: &Ctx,
1362    variant_cls: &syn::Ident,
1363    variant_cls_type: &syn::Type,
1364    num_fields: usize,
1365) -> Result<(MethodAndSlotDef, syn::ImplItemFn)> {
1366    let Ctx { pyo3_path, .. } = ctx;
1367
1368    let match_arms: Vec<_> = (0..num_fields)
1369        .map(|i| {
1370            let field_access = format_ident!("_{}", i);
1371            quote! { #i =>
1372                #pyo3_path::IntoPyObjectExt::into_py_any(#variant_cls::#field_access(slf)?, py)
1373            }
1374        })
1375        .collect();
1376
1377    let mut get_item_method_impl: syn::ImplItemFn = parse_quote! {
1378        fn __getitem__(slf: #pyo3_path::PyRef<Self>, idx: usize) -> #pyo3_path::PyResult< #pyo3_path::PyObject> {
1379            let py = slf.py();
1380            match idx {
1381                #( #match_arms, )*
1382                _ => ::std::result::Result::Err(#pyo3_path::exceptions::PyIndexError::new_err("tuple index out of range")),
1383            }
1384        }
1385    };
1386
1387    let variant_getitem = generate_default_protocol_slot(
1388        variant_cls_type,
1389        &mut get_item_method_impl,
1390        &__GETITEM__,
1391        ctx,
1392    )?;
1393
1394    Ok((variant_getitem, get_item_method_impl))
1395}
1396
1397fn impl_complex_enum_tuple_variant_cls(
1398    enum_name: &syn::Ident,
1399    variant: &PyClassEnumTupleVariant<'_>,
1400    ctx: &Ctx,
1401) -> Result<(TokenStream, Vec<MethodAndMethodDef>, Vec<MethodAndSlotDef>)> {
1402    let Ctx { pyo3_path, .. } = ctx;
1403    let variant_ident = &variant.ident;
1404    let variant_cls = gen_complex_enum_variant_class_ident(enum_name, variant.ident);
1405    let variant_cls_type = parse_quote!(#variant_cls);
1406
1407    let mut slots = vec![];
1408
1409    // represents the index of the field
1410    let mut field_names: Vec<Ident> = vec![];
1411    let mut field_types: Vec<syn::Type> = vec![];
1412
1413    let (mut field_getters, field_getter_impls) = impl_complex_enum_tuple_variant_field_getters(
1414        ctx,
1415        variant,
1416        enum_name,
1417        &variant_cls_type,
1418        variant_ident,
1419        &mut field_names,
1420        &mut field_types,
1421    )?;
1422
1423    let num_fields = variant.fields.len();
1424
1425    let (variant_len, len_method_impl) =
1426        impl_complex_enum_tuple_variant_len(ctx, &variant_cls_type, num_fields)?;
1427
1428    slots.push(variant_len);
1429
1430    let (variant_getitem, getitem_method_impl) =
1431        impl_complex_enum_tuple_variant_getitem(ctx, &variant_cls, &variant_cls_type, num_fields)?;
1432
1433    slots.push(variant_getitem);
1434
1435    let (variant_match_args, match_args_method_impl) =
1436        impl_complex_enum_variant_match_args(ctx, &variant_cls_type, &field_names)?;
1437
1438    field_getters.push(variant_match_args);
1439
1440    let cls_impl = quote! {
1441        #[doc(hidden)]
1442        #[allow(non_snake_case)]
1443        impl #variant_cls {
1444            #[allow(clippy::too_many_arguments)]
1445            fn __pymethod_constructor__(py: #pyo3_path::Python<'_>, #(#field_names : #field_types,)*) -> #pyo3_path::PyClassInitializer<#variant_cls> {
1446                let base_value = #enum_name::#variant_ident ( #(#field_names,)* );
1447                <#pyo3_path::PyClassInitializer<#enum_name> as ::std::convert::From<#enum_name>>::from(base_value).add_subclass(#variant_cls)
1448            }
1449
1450            #len_method_impl
1451
1452            #getitem_method_impl
1453
1454            #match_args_method_impl
1455
1456            #(#field_getter_impls)*
1457        }
1458    };
1459
1460    Ok((cls_impl, field_getters, slots))
1461}
1462
1463fn gen_complex_enum_variant_class_ident(enum_: &syn::Ident, variant: &syn::Ident) -> syn::Ident {
1464    format_ident!("{}_{}", enum_, variant)
1465}
1466
1467fn generate_protocol_slot(
1468    cls: &syn::Type,
1469    method: &mut syn::ImplItemFn,
1470    slot: &SlotDef,
1471    name: &str,
1472    ctx: &Ctx,
1473) -> syn::Result<MethodAndSlotDef> {
1474    let spec = FnSpec::parse(
1475        &mut method.sig,
1476        &mut Vec::new(),
1477        PyFunctionOptions::default(),
1478    )
1479    .unwrap();
1480    slot.generate_type_slot(&syn::parse_quote!(#cls), &spec, name, ctx)
1481}
1482
1483fn generate_default_protocol_slot(
1484    cls: &syn::Type,
1485    method: &mut syn::ImplItemFn,
1486    slot: &SlotDef,
1487    ctx: &Ctx,
1488) -> syn::Result<MethodAndSlotDef> {
1489    let spec = FnSpec::parse(
1490        &mut method.sig,
1491        &mut Vec::new(),
1492        PyFunctionOptions::default(),
1493    )
1494    .unwrap();
1495    let name = spec.name.to_string();
1496    slot.generate_type_slot(
1497        &syn::parse_quote!(#cls),
1498        &spec,
1499        &format!("__default_{}__", name),
1500        ctx,
1501    )
1502}
1503
1504fn simple_enum_default_methods<'a>(
1505    cls: &'a syn::Ident,
1506    unit_variant_names: impl IntoIterator<
1507        Item = (
1508            &'a syn::Ident,
1509            Cow<'a, syn::Ident>,
1510            &'a Vec<&'a syn::Attribute>,
1511        ),
1512    >,
1513    ctx: &Ctx,
1514) -> Vec<MethodAndMethodDef> {
1515    let cls_type = syn::parse_quote!(#cls);
1516    let variant_to_attribute = |var_ident: &syn::Ident, py_ident: &syn::Ident| ConstSpec {
1517        rust_ident: var_ident.clone(),
1518        attributes: ConstAttributes {
1519            is_class_attr: true,
1520            name: Some(NameAttribute {
1521                kw: syn::parse_quote! { name },
1522                value: NameLitStr(py_ident.clone()),
1523            }),
1524        },
1525    };
1526    unit_variant_names
1527        .into_iter()
1528        .map(|(var, py_name, attrs)| {
1529            let method = gen_py_const(&cls_type, &variant_to_attribute(var, &py_name), ctx);
1530            let associated_method_tokens = method.associated_method;
1531            let method_def_tokens = method.method_def;
1532
1533            let associated_method = quote! {
1534                #(#attrs)*
1535                #associated_method_tokens
1536            };
1537            let method_def = quote! {
1538                #(#attrs)*
1539                #method_def_tokens
1540            };
1541
1542            MethodAndMethodDef {
1543                associated_method,
1544                method_def,
1545            }
1546        })
1547        .collect()
1548}
1549
1550fn complex_enum_default_methods<'a>(
1551    cls: &'a syn::Ident,
1552    variant_names: impl IntoIterator<Item = (&'a syn::Ident, Cow<'a, syn::Ident>)>,
1553    ctx: &Ctx,
1554) -> Vec<MethodAndMethodDef> {
1555    let cls_type = syn::parse_quote!(#cls);
1556    let variant_to_attribute = |var_ident: &syn::Ident, py_ident: &syn::Ident| ConstSpec {
1557        rust_ident: var_ident.clone(),
1558        attributes: ConstAttributes {
1559            is_class_attr: true,
1560            name: Some(NameAttribute {
1561                kw: syn::parse_quote! { name },
1562                value: NameLitStr(py_ident.clone()),
1563            }),
1564        },
1565    };
1566    variant_names
1567        .into_iter()
1568        .map(|(var, py_name)| {
1569            gen_complex_enum_variant_attr(cls, &cls_type, &variant_to_attribute(var, &py_name), ctx)
1570        })
1571        .collect()
1572}
1573
1574pub fn gen_complex_enum_variant_attr(
1575    cls: &syn::Ident,
1576    cls_type: &syn::Type,
1577    spec: &ConstSpec,
1578    ctx: &Ctx,
1579) -> MethodAndMethodDef {
1580    let Ctx { pyo3_path, .. } = ctx;
1581    let member = &spec.rust_ident;
1582    let wrapper_ident = format_ident!("__pymethod_variant_cls_{}__", member);
1583    let python_name = spec.null_terminated_python_name(ctx);
1584
1585    let variant_cls = format_ident!("{}_{}", cls, member);
1586    let associated_method = quote! {
1587        fn #wrapper_ident(py: #pyo3_path::Python<'_>) -> #pyo3_path::PyResult<#pyo3_path::PyObject> {
1588            ::std::result::Result::Ok(py.get_type::<#variant_cls>().into_any().unbind())
1589        }
1590    };
1591
1592    let method_def = quote! {
1593        #pyo3_path::impl_::pyclass::MaybeRuntimePyMethodDef::Static(
1594            #pyo3_path::impl_::pymethods::PyMethodDefType::ClassAttribute({
1595                #pyo3_path::impl_::pymethods::PyClassAttributeDef::new(
1596                    #python_name,
1597                    #cls_type::#wrapper_ident
1598                )
1599            })
1600        )
1601    };
1602
1603    MethodAndMethodDef {
1604        associated_method,
1605        method_def,
1606    }
1607}
1608
1609fn complex_enum_variant_new<'a>(
1610    cls: &'a syn::Ident,
1611    variant: PyClassEnumVariant<'a>,
1612    ctx: &Ctx,
1613) -> Result<MethodAndSlotDef> {
1614    match variant {
1615        PyClassEnumVariant::Struct(struct_variant) => {
1616            complex_enum_struct_variant_new(cls, struct_variant, ctx)
1617        }
1618        PyClassEnumVariant::Tuple(tuple_variant) => {
1619            complex_enum_tuple_variant_new(cls, tuple_variant, ctx)
1620        }
1621    }
1622}
1623
1624fn complex_enum_struct_variant_new<'a>(
1625    cls: &'a syn::Ident,
1626    variant: PyClassEnumStructVariant<'a>,
1627    ctx: &Ctx,
1628) -> Result<MethodAndSlotDef> {
1629    let Ctx { pyo3_path, .. } = ctx;
1630    let variant_cls = format_ident!("{}_{}", cls, variant.ident);
1631    let variant_cls_type: syn::Type = parse_quote!(#variant_cls);
1632
1633    let arg_py_ident: syn::Ident = parse_quote!(py);
1634    let arg_py_type: syn::Type = parse_quote!(#pyo3_path::Python<'_>);
1635
1636    let args = {
1637        let mut args = vec![
1638            // py: Python<'_>
1639            FnArg::Py(PyArg {
1640                name: &arg_py_ident,
1641                ty: &arg_py_type,
1642            }),
1643        ];
1644
1645        for field in &variant.fields {
1646            args.push(FnArg::Regular(RegularArg {
1647                name: Cow::Borrowed(field.ident),
1648                ty: field.ty,
1649                from_py_with: None,
1650                default_value: None,
1651                option_wrapped_type: None,
1652            }));
1653        }
1654        args
1655    };
1656
1657    let signature = if let Some(constructor) = variant.options.constructor {
1658        crate::pyfunction::FunctionSignature::from_arguments_and_attribute(
1659            args,
1660            constructor.into_signature(),
1661        )?
1662    } else {
1663        crate::pyfunction::FunctionSignature::from_arguments(args)
1664    };
1665
1666    let spec = FnSpec {
1667        tp: crate::method::FnType::FnNew,
1668        name: &format_ident!("__pymethod_constructor__"),
1669        python_name: format_ident!("__new__"),
1670        signature,
1671        convention: crate::method::CallingConvention::TpNew,
1672        text_signature: None,
1673        asyncness: None,
1674        unsafety: None,
1675    };
1676
1677    crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec, ctx)
1678}
1679
1680fn complex_enum_tuple_variant_new<'a>(
1681    cls: &'a syn::Ident,
1682    variant: PyClassEnumTupleVariant<'a>,
1683    ctx: &Ctx,
1684) -> Result<MethodAndSlotDef> {
1685    let Ctx { pyo3_path, .. } = ctx;
1686
1687    let variant_cls: Ident = format_ident!("{}_{}", cls, variant.ident);
1688    let variant_cls_type: syn::Type = parse_quote!(#variant_cls);
1689
1690    let arg_py_ident: syn::Ident = parse_quote!(py);
1691    let arg_py_type: syn::Type = parse_quote!(#pyo3_path::Python<'_>);
1692
1693    let args = {
1694        let mut args = vec![FnArg::Py(PyArg {
1695            name: &arg_py_ident,
1696            ty: &arg_py_type,
1697        })];
1698
1699        for (i, field) in variant.fields.iter().enumerate() {
1700            args.push(FnArg::Regular(RegularArg {
1701                name: std::borrow::Cow::Owned(format_ident!("_{}", i)),
1702                ty: field.ty,
1703                from_py_with: None,
1704                default_value: None,
1705                option_wrapped_type: None,
1706            }));
1707        }
1708        args
1709    };
1710
1711    let signature = if let Some(constructor) = variant.options.constructor {
1712        crate::pyfunction::FunctionSignature::from_arguments_and_attribute(
1713            args,
1714            constructor.into_signature(),
1715        )?
1716    } else {
1717        crate::pyfunction::FunctionSignature::from_arguments(args)
1718    };
1719
1720    let spec = FnSpec {
1721        tp: crate::method::FnType::FnNew,
1722        name: &format_ident!("__pymethod_constructor__"),
1723        python_name: format_ident!("__new__"),
1724        signature,
1725        convention: crate::method::CallingConvention::TpNew,
1726        text_signature: None,
1727        asyncness: None,
1728        unsafety: None,
1729    };
1730
1731    crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec, ctx)
1732}
1733
1734fn complex_enum_variant_field_getter<'a>(
1735    variant_cls_type: &'a syn::Type,
1736    field_name: &'a syn::Ident,
1737    field_span: Span,
1738    ctx: &Ctx,
1739) -> Result<MethodAndMethodDef> {
1740    let signature = crate::pyfunction::FunctionSignature::from_arguments(vec![]);
1741
1742    let self_type = crate::method::SelfType::TryFromBoundRef(field_span);
1743
1744    let spec = FnSpec {
1745        tp: crate::method::FnType::Getter(self_type.clone()),
1746        name: field_name,
1747        python_name: field_name.unraw(),
1748        signature,
1749        convention: crate::method::CallingConvention::Noargs,
1750        text_signature: None,
1751        asyncness: None,
1752        unsafety: None,
1753    };
1754
1755    let property_type = crate::pymethod::PropertyType::Function {
1756        self_type: &self_type,
1757        spec: &spec,
1758        doc: crate::get_doc(&[], None, ctx),
1759    };
1760
1761    let getter = crate::pymethod::impl_py_getter_def(variant_cls_type, property_type, ctx)?;
1762    Ok(getter)
1763}
1764
1765fn descriptors_to_items(
1766    cls: &syn::Ident,
1767    rename_all: Option<&RenameAllAttribute>,
1768    frozen: Option<frozen>,
1769    field_options: Vec<(&syn::Field, FieldPyO3Options)>,
1770    ctx: &Ctx,
1771) -> syn::Result<Vec<MethodAndMethodDef>> {
1772    let ty = syn::parse_quote!(#cls);
1773    let mut items = Vec::new();
1774    for (field_index, (field, options)) in field_options.into_iter().enumerate() {
1775        if let FieldPyO3Options {
1776            name: Some(name),
1777            get: None,
1778            set: None,
1779        } = options
1780        {
1781            return Err(syn::Error::new_spanned(name, USELESS_NAME));
1782        }
1783
1784        if options.get.is_some() {
1785            let getter = impl_py_getter_def(
1786                &ty,
1787                PropertyType::Descriptor {
1788                    field_index,
1789                    field,
1790                    python_name: options.name.as_ref(),
1791                    renaming_rule: rename_all.map(|rename_all| rename_all.value.rule),
1792                },
1793                ctx,
1794            )?;
1795            items.push(getter);
1796        }
1797
1798        if let Some(set) = options.set {
1799            ensure_spanned!(frozen.is_none(), set.span() => "cannot use `#[pyo3(set)]` on a `frozen` class");
1800            let setter = impl_py_setter_def(
1801                &ty,
1802                PropertyType::Descriptor {
1803                    field_index,
1804                    field,
1805                    python_name: options.name.as_ref(),
1806                    renaming_rule: rename_all.map(|rename_all| rename_all.value.rule),
1807                },
1808                ctx,
1809            )?;
1810            items.push(setter);
1811        };
1812    }
1813    Ok(items)
1814}
1815
1816fn impl_pytypeinfo(cls: &syn::Ident, attr: &PyClassArgs, ctx: &Ctx) -> TokenStream {
1817    let Ctx { pyo3_path, .. } = ctx;
1818    let cls_name = get_class_python_name(cls, attr).to_string();
1819
1820    let module = if let Some(ModuleAttribute { value, .. }) = &attr.options.module {
1821        quote! { ::core::option::Option::Some(#value) }
1822    } else {
1823        quote! { ::core::option::Option::None }
1824    };
1825
1826    quote! {
1827        unsafe impl #pyo3_path::type_object::PyTypeInfo for #cls {
1828            const NAME: &'static str = #cls_name;
1829            const MODULE: ::std::option::Option<&'static str> = #module;
1830
1831            #[inline]
1832            fn type_object_raw(py: #pyo3_path::Python<'_>) -> *mut #pyo3_path::ffi::PyTypeObject {
1833                use #pyo3_path::prelude::PyTypeMethods;
1834                <#cls as #pyo3_path::impl_::pyclass::PyClassImpl>::lazy_type_object()
1835                    .get_or_init(py)
1836                    .as_type_ptr()
1837            }
1838        }
1839    }
1840}
1841
1842fn pyclass_richcmp_arms(
1843    options: &PyClassPyO3Options,
1844    ctx: &Ctx,
1845) -> std::result::Result<TokenStream, syn::Error> {
1846    let Ctx { pyo3_path, .. } = ctx;
1847
1848    let eq_arms = options
1849        .eq
1850        .map(|eq| eq.span)
1851        .or(options.eq_int.map(|eq_int| eq_int.span))
1852        .map(|span| {
1853            quote_spanned! { span =>
1854                #pyo3_path::pyclass::CompareOp::Eq => {
1855                    #pyo3_path::IntoPyObjectExt::into_py_any(self_val == other, py)
1856                },
1857                #pyo3_path::pyclass::CompareOp::Ne => {
1858                    #pyo3_path::IntoPyObjectExt::into_py_any(self_val != other, py)
1859                },
1860            }
1861        })
1862        .unwrap_or_default();
1863
1864    if let Some(ord) = options.ord {
1865        ensure_spanned!(options.eq.is_some(), ord.span() => "The `ord` option requires the `eq` option.");
1866    }
1867
1868    let ord_arms = options
1869        .ord
1870        .map(|ord| {
1871            quote_spanned! { ord.span() =>
1872                #pyo3_path::pyclass::CompareOp::Gt => {
1873                    #pyo3_path::IntoPyObjectExt::into_py_any(self_val > other, py)
1874                },
1875                #pyo3_path::pyclass::CompareOp::Lt => {
1876                    #pyo3_path::IntoPyObjectExt::into_py_any(self_val < other, py)
1877                 },
1878                #pyo3_path::pyclass::CompareOp::Le => {
1879                    #pyo3_path::IntoPyObjectExt::into_py_any(self_val <= other, py)
1880                 },
1881                #pyo3_path::pyclass::CompareOp::Ge => {
1882                    #pyo3_path::IntoPyObjectExt::into_py_any(self_val >= other, py)
1883                 },
1884            }
1885        })
1886        .unwrap_or_else(|| quote! { _ => ::std::result::Result::Ok(py.NotImplemented()) });
1887
1888    Ok(quote! {
1889        #eq_arms
1890        #ord_arms
1891    })
1892}
1893
1894fn pyclass_richcmp_simple_enum(
1895    options: &PyClassPyO3Options,
1896    cls: &syn::Type,
1897    repr_type: &syn::Ident,
1898    ctx: &Ctx,
1899) -> Result<(Option<syn::ImplItemFn>, Option<MethodAndSlotDef>)> {
1900    let Ctx { pyo3_path, .. } = ctx;
1901
1902    if let Some(eq_int) = options.eq_int {
1903        ensure_spanned!(options.eq.is_some(), eq_int.span() => "The `eq_int` option requires the `eq` option.");
1904    }
1905
1906    if options.eq.is_none() && options.eq_int.is_none() {
1907        return Ok((None, None));
1908    }
1909
1910    let arms = pyclass_richcmp_arms(options, ctx)?;
1911
1912    let eq = options.eq.map(|eq| {
1913        quote_spanned! { eq.span() =>
1914            let self_val = self;
1915            if let ::std::result::Result::Ok(other) = #pyo3_path::types::PyAnyMethods::downcast::<Self>(other) {
1916                let other = &*other.borrow();
1917                return match op {
1918                    #arms
1919                }
1920            }
1921        }
1922    });
1923
1924    let eq_int = options.eq_int.map(|eq_int| {
1925        quote_spanned! { eq_int.span() =>
1926            let self_val = self.__pyo3__int__();
1927            if let ::std::result::Result::Ok(other) = #pyo3_path::types::PyAnyMethods::extract::<#repr_type>(other).or_else(|_| {
1928                #pyo3_path::types::PyAnyMethods::downcast::<Self>(other).map(|o| o.borrow().__pyo3__int__())
1929            }) {
1930                return match op {
1931                    #arms
1932                }
1933            }
1934        }
1935    });
1936
1937    let mut richcmp_impl = parse_quote! {
1938        fn __pyo3__generated____richcmp__(
1939            &self,
1940            py: #pyo3_path::Python,
1941            other: &#pyo3_path::Bound<'_, #pyo3_path::PyAny>,
1942            op: #pyo3_path::pyclass::CompareOp
1943        ) -> #pyo3_path::PyResult<#pyo3_path::PyObject> {
1944            #eq
1945
1946            #eq_int
1947
1948            ::std::result::Result::Ok(py.NotImplemented())
1949        }
1950    };
1951    let richcmp_slot = if options.eq.is_some() {
1952        generate_protocol_slot(cls, &mut richcmp_impl, &__RICHCMP__, "__richcmp__", ctx).unwrap()
1953    } else {
1954        generate_default_protocol_slot(cls, &mut richcmp_impl, &__RICHCMP__, ctx).unwrap()
1955    };
1956    Ok((Some(richcmp_impl), Some(richcmp_slot)))
1957}
1958
1959fn pyclass_richcmp(
1960    options: &PyClassPyO3Options,
1961    cls: &syn::Type,
1962    ctx: &Ctx,
1963) -> Result<(Option<syn::ImplItemFn>, Option<MethodAndSlotDef>)> {
1964    let Ctx { pyo3_path, .. } = ctx;
1965    if let Some(eq_int) = options.eq_int {
1966        bail_spanned!(eq_int.span() => "`eq_int` can only be used on simple enums.")
1967    }
1968
1969    let arms = pyclass_richcmp_arms(options, ctx)?;
1970    if options.eq.is_some() {
1971        let mut richcmp_impl = parse_quote! {
1972            fn __pyo3__generated____richcmp__(
1973                &self,
1974                py: #pyo3_path::Python,
1975                other: &#pyo3_path::Bound<'_, #pyo3_path::PyAny>,
1976                op: #pyo3_path::pyclass::CompareOp
1977            ) -> #pyo3_path::PyResult<#pyo3_path::PyObject> {
1978                let self_val = self;
1979                if let ::std::result::Result::Ok(other) = #pyo3_path::types::PyAnyMethods::downcast::<Self>(other) {
1980                    let other = &*other.borrow();
1981                    match op {
1982                        #arms
1983                    }
1984                } else {
1985                    ::std::result::Result::Ok(py.NotImplemented())
1986                }
1987            }
1988        };
1989        let richcmp_slot =
1990            generate_protocol_slot(cls, &mut richcmp_impl, &__RICHCMP__, "__richcmp__", ctx)
1991                .unwrap();
1992        Ok((Some(richcmp_impl), Some(richcmp_slot)))
1993    } else {
1994        Ok((None, None))
1995    }
1996}
1997
1998fn pyclass_hash(
1999    options: &PyClassPyO3Options,
2000    cls: &syn::Type,
2001    ctx: &Ctx,
2002) -> Result<(Option<syn::ImplItemFn>, Option<MethodAndSlotDef>)> {
2003    if options.hash.is_some() {
2004        ensure_spanned!(
2005            options.frozen.is_some(), options.hash.span() => "The `hash` option requires the `frozen` option.";
2006            options.eq.is_some(), options.hash.span() => "The `hash` option requires the `eq` option.";
2007        );
2008    }
2009    // FIXME: Use hash.map(...).unzip() on MSRV >= 1.66
2010    match options.hash {
2011        Some(opt) => {
2012            let mut hash_impl = parse_quote_spanned! { opt.span() =>
2013                fn __pyo3__generated____hash__(&self) -> u64 {
2014                    let mut s = ::std::collections::hash_map::DefaultHasher::new();
2015                    ::std::hash::Hash::hash(self, &mut s);
2016                    ::std::hash::Hasher::finish(&s)
2017                }
2018            };
2019            let hash_slot =
2020                generate_protocol_slot(cls, &mut hash_impl, &__HASH__, "__hash__", ctx).unwrap();
2021            Ok((Some(hash_impl), Some(hash_slot)))
2022        }
2023        None => Ok((None, None)),
2024    }
2025}
2026
/// Implements most traits used by `#[pyclass]`.
///
/// Specifically, it implements traits that only depend on the class name,
/// the attributes of `#[pyclass]`, and docstrings.
/// Therefore it doesn't implement traits that depend on struct fields and enum variants.
struct PyClassImplsBuilder<'a> {
    /// Rust identifier of the class being generated.
    cls: &'a syn::Ident,
    /// Parsed `#[pyclass(...)]` arguments.
    attr: &'a PyClassArgs,
    /// How `#[pymethods]` items are collected: specialization or an `inventory` type.
    methods_type: PyClassMethodsType,
    /// Extra methods emitted into the class's intrinsic items and `impl` block
    /// (presumably macro-generated ones — confirm at the construction site).
    default_methods: Vec<MethodAndMethodDef>,
    /// Extra slots emitted into the class's intrinsic items and `impl` block
    /// (presumably macro-generated ones such as `__richcmp__`/`__hash__` — confirm).
    default_slots: Vec<MethodAndSlotDef>,
    /// Class docstring; when `None`, an empty C string is emitted instead.
    doc: Option<PythonDoc>,
}
2040
2041impl<'a> PyClassImplsBuilder<'a> {
2042    fn new(
2043        cls: &'a syn::Ident,
2044        attr: &'a PyClassArgs,
2045        methods_type: PyClassMethodsType,
2046        default_methods: Vec<MethodAndMethodDef>,
2047        default_slots: Vec<MethodAndSlotDef>,
2048    ) -> Self {
2049        Self {
2050            cls,
2051            attr,
2052            methods_type,
2053            default_methods,
2054            default_slots,
2055            doc: None,
2056        }
2057    }
2058
2059    fn doc(self, doc: PythonDoc) -> Self {
2060        Self {
2061            doc: Some(doc),
2062            ..self
2063        }
2064    }
2065
2066    fn impl_all(&self, ctx: &Ctx) -> Result<TokenStream> {
2067        let tokens = [
2068            self.impl_pyclass(ctx),
2069            self.impl_extractext(ctx),
2070            self.impl_into_py(ctx),
2071            self.impl_pyclassimpl(ctx)?,
2072            self.impl_add_to_module(ctx),
2073            self.impl_freelist(ctx),
2074        ]
2075        .into_iter()
2076        .collect();
2077        Ok(tokens)
2078    }
2079
2080    fn impl_pyclass(&self, ctx: &Ctx) -> TokenStream {
2081        let Ctx { pyo3_path, .. } = ctx;
2082        let cls = self.cls;
2083
2084        let frozen = if self.attr.options.frozen.is_some() {
2085            quote! { #pyo3_path::pyclass::boolean_struct::True }
2086        } else {
2087            quote! { #pyo3_path::pyclass::boolean_struct::False }
2088        };
2089
2090        quote! {
2091            impl #pyo3_path::PyClass for #cls {
2092                type Frozen = #frozen;
2093            }
2094        }
2095    }
2096    fn impl_extractext(&self, ctx: &Ctx) -> TokenStream {
2097        let Ctx { pyo3_path, .. } = ctx;
2098        let cls = self.cls;
2099        if self.attr.options.frozen.is_some() {
2100            quote! {
2101                impl<'a, 'py> #pyo3_path::impl_::extract_argument::PyFunctionArgument<'a, 'py, false> for &'a #cls
2102                {
2103                    type Holder = ::std::option::Option<#pyo3_path::PyRef<'py, #cls>>;
2104
2105                    #[inline]
2106                    fn extract(obj: &'a #pyo3_path::Bound<'py, #pyo3_path::PyAny>, holder: &'a mut Self::Holder) -> #pyo3_path::PyResult<Self> {
2107                        #pyo3_path::impl_::extract_argument::extract_pyclass_ref(obj, holder)
2108                    }
2109                }
2110            }
2111        } else {
2112            quote! {
2113                impl<'a, 'py> #pyo3_path::impl_::extract_argument::PyFunctionArgument<'a, 'py, false> for &'a #cls
2114                {
2115                    type Holder = ::std::option::Option<#pyo3_path::PyRef<'py, #cls>>;
2116
2117                    #[inline]
2118                    fn extract(obj: &'a #pyo3_path::Bound<'py, #pyo3_path::PyAny>, holder: &'a mut Self::Holder) -> #pyo3_path::PyResult<Self> {
2119                        #pyo3_path::impl_::extract_argument::extract_pyclass_ref(obj, holder)
2120                    }
2121                }
2122
2123                impl<'a, 'py> #pyo3_path::impl_::extract_argument::PyFunctionArgument<'a, 'py, false> for &'a mut #cls
2124                {
2125                    type Holder = ::std::option::Option<#pyo3_path::PyRefMut<'py, #cls>>;
2126
2127                    #[inline]
2128                    fn extract(obj: &'a #pyo3_path::Bound<'py, #pyo3_path::PyAny>, holder: &'a mut Self::Holder) -> #pyo3_path::PyResult<Self> {
2129                        #pyo3_path::impl_::extract_argument::extract_pyclass_ref_mut(obj, holder)
2130                    }
2131                }
2132            }
2133        }
2134    }
2135
2136    fn impl_into_py(&self, ctx: &Ctx) -> TokenStream {
2137        let Ctx { pyo3_path, .. } = ctx;
2138        let cls = self.cls;
2139        let attr = self.attr;
2140        // If #cls is not extended type, we allow Self->PyObject conversion
2141        if attr.options.extends.is_none() {
2142            quote! {
2143                #[allow(deprecated)]
2144                impl #pyo3_path::IntoPy<#pyo3_path::PyObject> for #cls {
2145                    fn into_py(self, py: #pyo3_path::Python<'_>) -> #pyo3_path::PyObject {
2146                        #pyo3_path::IntoPy::into_py(#pyo3_path::Py::new(py, self).unwrap(), py)
2147                    }
2148                }
2149
2150                impl<'py> #pyo3_path::conversion::IntoPyObject<'py> for #cls {
2151                    type Target = Self;
2152                    type Output = #pyo3_path::Bound<'py, <Self as #pyo3_path::conversion::IntoPyObject<'py>>::Target>;
2153                    type Error = #pyo3_path::PyErr;
2154
2155                    fn into_pyobject(self, py: #pyo3_path::Python<'py>) -> ::std::result::Result<
2156                        <Self as #pyo3_path::conversion::IntoPyObject>::Output,
2157                        <Self as #pyo3_path::conversion::IntoPyObject>::Error,
2158                    > {
2159                        #pyo3_path::Bound::new(py, self)
2160                    }
2161                }
2162            }
2163        } else {
2164            quote! {}
2165        }
2166    }
2167    fn impl_pyclassimpl(&self, ctx: &Ctx) -> Result<TokenStream> {
2168        let Ctx { pyo3_path, .. } = ctx;
2169        let cls = self.cls;
2170        let doc = self.doc.as_ref().map_or(
2171            LitCStr::empty(ctx).to_token_stream(),
2172            PythonDoc::to_token_stream,
2173        );
2174        let is_basetype = self.attr.options.subclass.is_some();
2175        let base = match &self.attr.options.extends {
2176            Some(extends_attr) => extends_attr.value.clone(),
2177            None => parse_quote! { #pyo3_path::PyAny },
2178        };
2179        let is_subclass = self.attr.options.extends.is_some();
2180        let is_mapping: bool = self.attr.options.mapping.is_some();
2181        let is_sequence: bool = self.attr.options.sequence.is_some();
2182
2183        ensure_spanned!(
2184            !(is_mapping && is_sequence),
2185            self.cls.span() => "a `#[pyclass]` cannot be both a `mapping` and a `sequence`"
2186        );
2187
2188        let dict_offset = if self.attr.options.dict.is_some() {
2189            quote! {
2190                fn dict_offset() -> ::std::option::Option<#pyo3_path::ffi::Py_ssize_t> {
2191                    ::std::option::Option::Some(#pyo3_path::impl_::pyclass::dict_offset::<Self>())
2192                }
2193            }
2194        } else {
2195            TokenStream::new()
2196        };
2197
2198        // insert space for weak ref
2199        let weaklist_offset = if self.attr.options.weakref.is_some() {
2200            quote! {
2201                fn weaklist_offset() -> ::std::option::Option<#pyo3_path::ffi::Py_ssize_t> {
2202                    ::std::option::Option::Some(#pyo3_path::impl_::pyclass::weaklist_offset::<Self>())
2203                }
2204            }
2205        } else {
2206            TokenStream::new()
2207        };
2208
2209        let thread_checker = if self.attr.options.unsendable.is_some() {
2210            quote! { #pyo3_path::impl_::pyclass::ThreadCheckerImpl }
2211        } else {
2212            quote! { #pyo3_path::impl_::pyclass::SendablePyClass<#cls> }
2213        };
2214
2215        let (pymethods_items, inventory, inventory_class) = match self.methods_type {
2216            PyClassMethodsType::Specialization => (quote! { collector.py_methods() }, None, None),
2217            PyClassMethodsType::Inventory => {
2218                // To allow multiple #[pymethods] block, we define inventory types.
2219                let inventory_class_name = syn::Ident::new(
2220                    &format!("Pyo3MethodsInventoryFor{}", cls.unraw()),
2221                    Span::call_site(),
2222                );
2223                (
2224                    quote! {
2225                        ::std::boxed::Box::new(
2226                            ::std::iter::Iterator::map(
2227                                #pyo3_path::inventory::iter::<<Self as #pyo3_path::impl_::pyclass::PyClassImpl>::Inventory>(),
2228                                #pyo3_path::impl_::pyclass::PyClassInventory::items
2229                            )
2230                        )
2231                    },
2232                    Some(quote! { type Inventory = #inventory_class_name; }),
2233                    Some(define_inventory_class(&inventory_class_name, ctx)),
2234                )
2235            }
2236        };
2237
2238        let default_methods = self
2239            .default_methods
2240            .iter()
2241            .map(|meth| &meth.associated_method)
2242            .chain(
2243                self.default_slots
2244                    .iter()
2245                    .map(|meth| &meth.associated_method),
2246            );
2247
2248        let default_method_defs = self.default_methods.iter().map(|meth| &meth.method_def);
2249        let default_slot_defs = self.default_slots.iter().map(|slot| &slot.slot_def);
2250        let freelist_slots = self.freelist_slots(ctx);
2251
2252        let class_mutability = if self.attr.options.frozen.is_some() {
2253            quote! {
2254                ImmutableChild
2255            }
2256        } else {
2257            quote! {
2258                MutableChild
2259            }
2260        };
2261
2262        let cls = self.cls;
2263        let attr = self.attr;
2264        let dict = if attr.options.dict.is_some() {
2265            quote! { #pyo3_path::impl_::pyclass::PyClassDictSlot }
2266        } else {
2267            quote! { #pyo3_path::impl_::pyclass::PyClassDummySlot }
2268        };
2269
2270        // insert space for weak ref
2271        let weakref = if attr.options.weakref.is_some() {
2272            quote! { #pyo3_path::impl_::pyclass::PyClassWeakRefSlot }
2273        } else {
2274            quote! { #pyo3_path::impl_::pyclass::PyClassDummySlot }
2275        };
2276
2277        let base_nativetype = if attr.options.extends.is_some() {
2278            quote! { <Self::BaseType as #pyo3_path::impl_::pyclass::PyClassBaseType>::BaseNativeType }
2279        } else {
2280            quote! { #pyo3_path::PyAny }
2281        };
2282
2283        let pyclass_base_type_impl = attr.options.subclass.map(|subclass| {
2284            quote_spanned! { subclass.span() =>
2285                impl #pyo3_path::impl_::pyclass::PyClassBaseType for #cls {
2286                    type LayoutAsBase = #pyo3_path::impl_::pycell::PyClassObject<Self>;
2287                    type BaseNativeType = <Self as #pyo3_path::impl_::pyclass::PyClassImpl>::BaseNativeType;
2288                    type Initializer = #pyo3_path::pyclass_init::PyClassInitializer<Self>;
2289                    type PyClassMutability = <Self as #pyo3_path::impl_::pyclass::PyClassImpl>::PyClassMutability;
2290                }
2291            }
2292        });
2293
2294        let assertions = if attr.options.unsendable.is_some() {
2295            TokenStream::new()
2296        } else {
2297            let assert = quote_spanned! { cls.span() => #pyo3_path::impl_::pyclass::assert_pyclass_sync::<#cls>(); };
2298            quote! {
2299                const _: () = {
2300                    #assert
2301                };
2302            }
2303        };
2304
2305        Ok(quote! {
2306            #assertions
2307
2308            #pyclass_base_type_impl
2309
2310            impl #pyo3_path::impl_::pyclass::PyClassImpl for #cls {
2311                const IS_BASETYPE: bool = #is_basetype;
2312                const IS_SUBCLASS: bool = #is_subclass;
2313                const IS_MAPPING: bool = #is_mapping;
2314                const IS_SEQUENCE: bool = #is_sequence;
2315
2316                type BaseType = #base;
2317                type ThreadChecker = #thread_checker;
2318                #inventory
2319                type PyClassMutability = <<#base as #pyo3_path::impl_::pyclass::PyClassBaseType>::PyClassMutability as #pyo3_path::impl_::pycell::PyClassMutability>::#class_mutability;
2320                type Dict = #dict;
2321                type WeakRef = #weakref;
2322                type BaseNativeType = #base_nativetype;
2323
2324                fn items_iter() -> #pyo3_path::impl_::pyclass::PyClassItemsIter {
2325                    use #pyo3_path::impl_::pyclass::*;
2326                    let collector = PyClassImplCollector::<Self>::new();
2327                    static INTRINSIC_ITEMS: PyClassItems = PyClassItems {
2328                        methods: &[#(#default_method_defs),*],
2329                        slots: &[#(#default_slot_defs),* #(#freelist_slots),*],
2330                    };
2331                    PyClassItemsIter::new(&INTRINSIC_ITEMS, #pymethods_items)
2332                }
2333
2334                fn doc(py: #pyo3_path::Python<'_>) -> #pyo3_path::PyResult<&'static ::std::ffi::CStr>  {
2335                    use #pyo3_path::impl_::pyclass::*;
2336                    static DOC: #pyo3_path::sync::GILOnceCell<::std::borrow::Cow<'static, ::std::ffi::CStr>> = #pyo3_path::sync::GILOnceCell::new();
2337                    DOC.get_or_try_init(py, || {
2338                        let collector = PyClassImplCollector::<Self>::new();
2339                        build_pyclass_doc(<Self as #pyo3_path::PyTypeInfo>::NAME, #doc, collector.new_text_signature())
2340                    }).map(::std::ops::Deref::deref)
2341                }
2342
2343                #dict_offset
2344
2345                #weaklist_offset
2346
2347                fn lazy_type_object() -> &'static #pyo3_path::impl_::pyclass::LazyTypeObject<Self> {
2348                    use #pyo3_path::impl_::pyclass::LazyTypeObject;
2349                    static TYPE_OBJECT: LazyTypeObject<#cls> = LazyTypeObject::new();
2350                    &TYPE_OBJECT
2351                }
2352            }
2353
2354            #[doc(hidden)]
2355            #[allow(non_snake_case)]
2356            impl #cls {
2357                #(#default_methods)*
2358            }
2359
2360            #inventory_class
2361        })
2362    }
2363
2364    fn impl_add_to_module(&self, ctx: &Ctx) -> TokenStream {
2365        let Ctx { pyo3_path, .. } = ctx;
2366        let cls = self.cls;
2367        quote! {
2368            impl #cls {
2369                #[doc(hidden)]
2370                pub const _PYO3_DEF: #pyo3_path::impl_::pymodule::AddClassToModule<Self> = #pyo3_path::impl_::pymodule::AddClassToModule::new();
2371            }
2372        }
2373    }
2374
2375    fn impl_freelist(&self, ctx: &Ctx) -> TokenStream {
2376        let cls = self.cls;
2377        let Ctx { pyo3_path, .. } = ctx;
2378
2379        self.attr.options.freelist.as_ref().map_or(quote!{}, |freelist| {
2380            let freelist = &freelist.value;
2381            quote! {
2382                impl #pyo3_path::impl_::pyclass::PyClassWithFreeList for #cls {
2383                    #[inline]
2384                    fn get_free_list(py: #pyo3_path::Python<'_>) -> &'static ::std::sync::Mutex<#pyo3_path::impl_::freelist::PyObjectFreeList> {
2385                        static FREELIST: #pyo3_path::sync::GILOnceCell<::std::sync::Mutex<#pyo3_path::impl_::freelist::PyObjectFreeList>> = #pyo3_path::sync::GILOnceCell::new();
2386                        // If there's a race to fill the cell, the object created
2387                        // by the losing thread will be deallocated via RAII
2388                        &FREELIST.get_or_init(py, || {
2389                            ::std::sync::Mutex::new(#pyo3_path::impl_::freelist::PyObjectFreeList::with_capacity(#freelist))
2390                        })
2391                    }
2392                }
2393            }
2394        })
2395    }
2396
2397    fn freelist_slots(&self, ctx: &Ctx) -> Vec<TokenStream> {
2398        let Ctx { pyo3_path, .. } = ctx;
2399        let cls = self.cls;
2400
2401        if self.attr.options.freelist.is_some() {
2402            vec![
2403                quote! {
2404                    #pyo3_path::ffi::PyType_Slot {
2405                        slot: #pyo3_path::ffi::Py_tp_alloc,
2406                        pfunc: #pyo3_path::impl_::pyclass::alloc_with_freelist::<#cls> as *mut _,
2407                    }
2408                },
2409                quote! {
2410                    #pyo3_path::ffi::PyType_Slot {
2411                        slot: #pyo3_path::ffi::Py_tp_free,
2412                        pfunc: #pyo3_path::impl_::pyclass::free_with_freelist::<#cls> as *mut _,
2413                    }
2414                },
2415            ]
2416        } else {
2417            Vec::new()
2418        }
2419    }
2420}
2421
2422fn define_inventory_class(inventory_class_name: &syn::Ident, ctx: &Ctx) -> TokenStream {
2423    let Ctx { pyo3_path, .. } = ctx;
2424    quote! {
2425        #[doc(hidden)]
2426        pub struct #inventory_class_name {
2427            items: #pyo3_path::impl_::pyclass::PyClassItems,
2428        }
2429        impl #inventory_class_name {
2430            pub const fn new(items: #pyo3_path::impl_::pyclass::PyClassItems) -> Self {
2431                Self { items }
2432            }
2433        }
2434
2435        impl #pyo3_path::impl_::pyclass::PyClassInventory for #inventory_class_name {
2436            fn items(&self) -> &#pyo3_path::impl_::pyclass::PyClassItems {
2437                &self.items
2438            }
2439        }
2440
2441        #pyo3_path::inventory::collect!(#inventory_class_name);
2442    }
2443}
2444
2445fn generate_cfg_check(variants: &[PyClassEnumUnitVariant<'_>], cls: &syn::Ident) -> TokenStream {
2446    if variants.is_empty() {
2447        return quote! {};
2448    }
2449
2450    let mut conditions = Vec::new();
2451
2452    for variant in variants {
2453        let cfg_attrs = &variant.cfg_attrs;
2454
2455        if cfg_attrs.is_empty() {
2456            // There's at least one variant of the enum without cfg attributes,
2457            // so the check is not necessary
2458            return quote! {};
2459        }
2460
2461        for attr in cfg_attrs {
2462            if let syn::Meta::List(meta) = &attr.meta {
2463                let cfg_tokens = &meta.tokens;
2464                conditions.push(quote! { not(#cfg_tokens) });
2465            }
2466        }
2467    }
2468
2469    quote_spanned! {
2470        cls.span() =>
2471        #[cfg(all(#(#conditions),*))]
2472        ::core::compile_error!(concat!("#[pyclass] can't be used on enums without any variants - all variants of enum `", stringify!(#cls), "` have been configured out by cfg attributes"));
2473    }
2474}
2475
// Diagnostics for a `get`, `set` or `name` option given more than once.
const UNIQUE_GET: &str = "`get` may only be specified once";
const UNIQUE_SET: &str = "`set` may only be specified once";
const UNIQUE_NAME: &str = "`name` may only be specified once";

// Diagnostics for a `get`/`set` that repeats a struct-level `get_all`/`set_all`,
// and for `get_all`/`set_all` applied to a unit struct.
// NOTE(review): "an unit struct" should read "a unit struct". Fixing it changes
// user-visible compiler output, so any tests that snapshot these messages would
// need updating at the same time.
const DUPE_SET: &str = "useless `set` - the struct is already annotated with `set_all`";
const DUPE_GET: &str = "useless `get` - the struct is already annotated with `get_all`";
const UNIT_GET: &str =
    "`get_all` on an unit struct does nothing, because unit structs have no fields";
const UNIT_SET: &str =
    "`set_all` on an unit struct does nothing, because unit structs have no fields";

// Reported by `descriptors_to_items` when a field has `name` but neither `get` nor `set`.
const USELESS_NAME: &str = "`name` is useless without `get` or `set`";