pyo3_macros_backend/
intopyobject.rs

1use crate::attributes::{self, get_pyo3_options, CrateAttribute};
2use crate::utils::Ctx;
3use proc_macro2::{Span, TokenStream};
4use quote::{format_ident, quote, quote_spanned};
5use syn::ext::IdentExt;
6use syn::parse::{Parse, ParseStream};
7use syn::spanned::Spanned as _;
8use syn::{
9    parenthesized, parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Index, Result,
10    Token,
11};
12
13/// Attributes for deriving `IntoPyObject` scoped on containers.
14enum ContainerPyO3Attribute {
15    /// Treat the Container as a Wrapper, directly convert its field into the output object.
16    Transparent(attributes::kw::transparent),
17    /// Change the path for the pyo3 crate
18    Crate(CrateAttribute),
19}
20
21impl Parse for ContainerPyO3Attribute {
22    fn parse(input: ParseStream<'_>) -> Result<Self> {
23        let lookahead = input.lookahead1();
24        if lookahead.peek(attributes::kw::transparent) {
25            let kw: attributes::kw::transparent = input.parse()?;
26            Ok(ContainerPyO3Attribute::Transparent(kw))
27        } else if lookahead.peek(Token![crate]) {
28            input.parse().map(ContainerPyO3Attribute::Crate)
29        } else {
30            Err(lookahead.error())
31        }
32    }
33}
34
35#[derive(Default)]
36struct ContainerOptions {
37    /// Treat the Container as a Wrapper, directly convert its field into the output object.
38    transparent: Option<attributes::kw::transparent>,
39    /// Change the path for the pyo3 crate
40    krate: Option<CrateAttribute>,
41}
42
43impl ContainerOptions {
44    fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
45        let mut options = ContainerOptions::default();
46
47        for attr in attrs {
48            if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
49                pyo3_attrs
50                    .into_iter()
51                    .try_for_each(|opt| options.set_option(opt))?;
52            }
53        }
54        Ok(options)
55    }
56
57    fn set_option(&mut self, option: ContainerPyO3Attribute) -> syn::Result<()> {
58        macro_rules! set_option {
59            ($key:ident) => {
60                {
61                    ensure_spanned!(
62                        self.$key.is_none(),
63                        $key.span() => concat!("`", stringify!($key), "` may only be specified once")
64                    );
65                    self.$key = Some($key);
66                }
67            };
68        }
69
70        match option {
71            ContainerPyO3Attribute::Transparent(transparent) => set_option!(transparent),
72            ContainerPyO3Attribute::Crate(krate) => set_option!(krate),
73        }
74        Ok(())
75    }
76}
77
78#[derive(Debug, Clone)]
79struct ItemOption {
80    field: Option<syn::LitStr>,
81    span: Span,
82}
83
84impl ItemOption {
85    fn span(&self) -> Span {
86        self.span
87    }
88}
89
90enum FieldAttribute {
91    Item(ItemOption),
92}
93
94impl Parse for FieldAttribute {
95    fn parse(input: ParseStream<'_>) -> Result<Self> {
96        let lookahead = input.lookahead1();
97        if lookahead.peek(attributes::kw::attribute) {
98            let attr: attributes::kw::attribute = input.parse()?;
99            bail_spanned!(attr.span => "`attribute` is not supported by `IntoPyObject`");
100        } else if lookahead.peek(attributes::kw::item) {
101            let attr: attributes::kw::item = input.parse()?;
102            if input.peek(syn::token::Paren) {
103                let content;
104                let _ = parenthesized!(content in input);
105                let key = content.parse()?;
106                if !content.is_empty() {
107                    return Err(
108                        content.error("expected at most one argument: `item` or `item(key)`")
109                    );
110                }
111                Ok(FieldAttribute::Item(ItemOption {
112                    field: Some(key),
113                    span: attr.span,
114                }))
115            } else {
116                Ok(FieldAttribute::Item(ItemOption {
117                    field: None,
118                    span: attr.span,
119                }))
120            }
121        } else {
122            Err(lookahead.error())
123        }
124    }
125}
126
127#[derive(Clone, Debug, Default)]
128struct FieldAttributes {
129    item: Option<ItemOption>,
130}
131
132impl FieldAttributes {
133    /// Extract the field attributes.
134    fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
135        let mut options = FieldAttributes::default();
136
137        for attr in attrs {
138            if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
139                pyo3_attrs
140                    .into_iter()
141                    .try_for_each(|opt| options.set_option(opt))?;
142            }
143        }
144        Ok(options)
145    }
146
147    fn set_option(&mut self, option: FieldAttribute) -> syn::Result<()> {
148        macro_rules! set_option {
149            ($key:ident) => {
150                {
151                    ensure_spanned!(
152                        self.$key.is_none(),
153                        $key.span() => concat!("`", stringify!($key), "` may only be specified once")
154                    );
155                    self.$key = Some($key);
156                }
157            };
158        }
159
160        match option {
161            FieldAttribute::Item(item) => set_option!(item),
162        }
163        Ok(())
164    }
165}
166
167enum IntoPyObjectTypes {
168    Transparent(syn::Type),
169    Opaque {
170        target: TokenStream,
171        output: TokenStream,
172        error: TokenStream,
173    },
174}
175
176struct IntoPyObjectImpl {
177    types: IntoPyObjectTypes,
178    body: TokenStream,
179}
180
181struct NamedStructField<'a> {
182    ident: &'a syn::Ident,
183    field: &'a syn::Field,
184    item: Option<ItemOption>,
185}
186
187struct TupleStructField<'a> {
188    field: &'a syn::Field,
189}
190
191/// Container Style
192///
193/// Covers Structs, Tuplestructs and corresponding Newtypes.
194enum ContainerType<'a> {
195    /// Struct Container, e.g. `struct Foo { a: String }`
196    ///
197    /// Variant contains the list of field identifiers and the corresponding extraction call.
198    Struct(Vec<NamedStructField<'a>>),
199    /// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }`
200    ///
201    /// The field specified by the identifier is extracted directly from the object.
202    StructNewtype(&'a syn::Field),
203    /// Tuple struct, e.g. `struct Foo(String)`.
204    ///
205    /// Variant contains a list of conversion methods for each of the fields that are directly
206    ///  extracted from the tuple.
207    Tuple(Vec<TupleStructField<'a>>),
208    /// Tuple newtype, e.g. `#[transparent] struct Foo(String)`
209    ///
210    /// The wrapped field is directly extracted from the object.
211    TupleNewtype(&'a syn::Field),
212}
213
214/// Data container
215///
216/// Either describes a struct or an enum variant.
217struct Container<'a> {
218    path: syn::Path,
219    receiver: Option<Ident>,
220    ty: ContainerType<'a>,
221}
222
223/// Construct a container based on fields, identifier and attributes.
224impl<'a> Container<'a> {
225    ///
226    /// Fails if the variant has no fields or incompatible attributes.
227    fn new(
228        receiver: Option<Ident>,
229        fields: &'a Fields,
230        path: syn::Path,
231        options: ContainerOptions,
232    ) -> Result<Self> {
233        let style = match fields {
234            Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
235                let mut tuple_fields = unnamed
236                    .unnamed
237                    .iter()
238                    .map(|field| {
239                        let attrs = FieldAttributes::from_attrs(&field.attrs)?;
240                        ensure_spanned!(
241                            attrs.item.is_none(),
242                            attrs.item.unwrap().span() => "`item` is not permitted on tuple struct elements."
243                        );
244                        Ok(TupleStructField { field })
245                    })
246                    .collect::<Result<Vec<_>>>()?;
247                if tuple_fields.len() == 1 {
248                    // Always treat a 1-length tuple struct as "transparent", even without the
249                    // explicit annotation.
250                    let TupleStructField { field } = tuple_fields.pop().unwrap();
251                    ContainerType::TupleNewtype(field)
252                } else if options.transparent.is_some() {
253                    bail_spanned!(
254                        fields.span() => "transparent structs and variants can only have 1 field"
255                    );
256                } else {
257                    ContainerType::Tuple(tuple_fields)
258                }
259            }
260            Fields::Named(named) if !named.named.is_empty() => {
261                if options.transparent.is_some() {
262                    ensure_spanned!(
263                        named.named.iter().count() == 1,
264                        fields.span() => "transparent structs and variants can only have 1 field"
265                    );
266
267                    let field = named.named.iter().next().unwrap();
268                    let attrs = FieldAttributes::from_attrs(&field.attrs)?;
269                    ensure_spanned!(
270                        attrs.item.is_none(),
271                        attrs.item.unwrap().span() => "`transparent` structs may not have `item` for the inner field"
272                    );
273                    ContainerType::StructNewtype(field)
274                } else {
275                    let struct_fields = named
276                        .named
277                        .iter()
278                        .map(|field| {
279                            let ident = field
280                                .ident
281                                .as_ref()
282                                .expect("Named fields should have identifiers");
283
284                            let attrs = FieldAttributes::from_attrs(&field.attrs)?;
285
286                            Ok(NamedStructField {
287                                ident,
288                                field,
289                                item: attrs.item,
290                            })
291                        })
292                        .collect::<Result<Vec<_>>>()?;
293                    ContainerType::Struct(struct_fields)
294                }
295            }
296            _ => bail_spanned!(
297                fields.span() => "cannot derive `IntoPyObject` for empty structs"
298            ),
299        };
300
301        let v = Container {
302            path,
303            receiver,
304            ty: style,
305        };
306        Ok(v)
307    }
308
309    fn match_pattern(&self) -> TokenStream {
310        let path = &self.path;
311        let pattern = match &self.ty {
312            ContainerType::Struct(fields) => fields
313                .iter()
314                .enumerate()
315                .map(|(i, f)| {
316                    let ident = f.ident;
317                    let new_ident = format_ident!("arg{i}");
318                    quote! {#ident: #new_ident,}
319                })
320                .collect::<TokenStream>(),
321            ContainerType::StructNewtype(field) => {
322                let ident = field.ident.as_ref().unwrap();
323                quote!(#ident: arg0)
324            }
325            ContainerType::Tuple(fields) => {
326                let i = (0..fields.len()).map(Index::from);
327                let idents = (0..fields.len()).map(|i| format_ident!("arg{i}"));
328                quote! { #(#i: #idents,)* }
329            }
330            ContainerType::TupleNewtype(_) => quote!(0: arg0),
331        };
332
333        quote! { #path{ #pattern } }
334    }
335
336    /// Build derivation body for a struct.
337    fn build(&self, ctx: &Ctx) -> IntoPyObjectImpl {
338        match &self.ty {
339            ContainerType::StructNewtype(field) | ContainerType::TupleNewtype(field) => {
340                self.build_newtype_struct(field, ctx)
341            }
342            ContainerType::Tuple(fields) => self.build_tuple_struct(fields, ctx),
343            ContainerType::Struct(fields) => self.build_struct(fields, ctx),
344        }
345    }
346
347    fn build_newtype_struct(&self, field: &syn::Field, ctx: &Ctx) -> IntoPyObjectImpl {
348        let Ctx { pyo3_path, .. } = ctx;
349        let ty = &field.ty;
350
351        let unpack = self
352            .receiver
353            .as_ref()
354            .map(|i| {
355                let pattern = self.match_pattern();
356                quote! { let #pattern = #i;}
357            })
358            .unwrap_or_default();
359
360        IntoPyObjectImpl {
361            types: IntoPyObjectTypes::Transparent(ty.clone()),
362            body: quote_spanned! { ty.span() =>
363                #unpack
364                #pyo3_path::conversion::IntoPyObject::into_pyobject(arg0, py)
365            },
366        }
367    }
368
369    fn build_struct(&self, fields: &[NamedStructField<'_>], ctx: &Ctx) -> IntoPyObjectImpl {
370        let Ctx { pyo3_path, .. } = ctx;
371
372        let unpack = self
373            .receiver
374            .as_ref()
375            .map(|i| {
376                let pattern = self.match_pattern();
377                quote! { let #pattern = #i;}
378            })
379            .unwrap_or_default();
380
381        let setter = fields
382            .iter()
383            .enumerate()
384            .map(|(i, f)| {
385                let key = f
386                    .item
387                    .as_ref()
388                    .and_then(|item| item.field.as_ref())
389                    .map(|item| item.value())
390                    .unwrap_or_else(|| f.ident.unraw().to_string());
391                let value = Ident::new(&format!("arg{i}"), f.field.ty.span());
392                quote! {
393                    #pyo3_path::types::PyDictMethods::set_item(&dict, #key, #value)?;
394                }
395            })
396            .collect::<TokenStream>();
397
398        IntoPyObjectImpl {
399            types: IntoPyObjectTypes::Opaque {
400                target: quote!(#pyo3_path::types::PyDict),
401                output: quote!(#pyo3_path::Bound<'py, Self::Target>),
402                error: quote!(#pyo3_path::PyErr),
403            },
404            body: quote! {
405                #unpack
406                let dict = #pyo3_path::types::PyDict::new(py);
407                #setter
408                ::std::result::Result::Ok::<_, Self::Error>(dict)
409            },
410        }
411    }
412
413    fn build_tuple_struct(&self, fields: &[TupleStructField<'_>], ctx: &Ctx) -> IntoPyObjectImpl {
414        let Ctx { pyo3_path, .. } = ctx;
415
416        let unpack = self
417            .receiver
418            .as_ref()
419            .map(|i| {
420                let pattern = self.match_pattern();
421                quote! { let #pattern = #i;}
422            })
423            .unwrap_or_default();
424
425        let setter = fields
426            .iter()
427            .enumerate()
428            .map(|(i, f)| {
429                let value = Ident::new(&format!("arg{i}"), f.field.ty.span());
430                quote_spanned! { f.field.ty.span() =>
431                    #pyo3_path::conversion::IntoPyObject::into_pyobject(#value, py)
432                        .map(#pyo3_path::BoundObject::into_any)
433                        .map(#pyo3_path::BoundObject::into_bound)?,
434                }
435            })
436            .collect::<TokenStream>();
437
438        IntoPyObjectImpl {
439            types: IntoPyObjectTypes::Opaque {
440                target: quote!(#pyo3_path::types::PyTuple),
441                output: quote!(#pyo3_path::Bound<'py, Self::Target>),
442                error: quote!(#pyo3_path::PyErr),
443            },
444            body: quote! {
445                #unpack
446                #pyo3_path::types::PyTuple::new(py, [#setter])
447            },
448        }
449    }
450}
451
452/// Describes derivation input of an enum.
453struct Enum<'a> {
454    variants: Vec<Container<'a>>,
455}
456
457impl<'a> Enum<'a> {
458    /// Construct a new enum representation.
459    ///
460    /// `data_enum` is the `syn` representation of the input enum, `ident` is the
461    /// `Identifier` of the enum.
462    fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
463        ensure_spanned!(
464            !data_enum.variants.is_empty(),
465            ident.span() => "cannot derive `IntoPyObject` for empty enum"
466        );
467        let variants = data_enum
468            .variants
469            .iter()
470            .map(|variant| {
471                let attrs = ContainerOptions::from_attrs(&variant.attrs)?;
472                let var_ident = &variant.ident;
473
474                ensure_spanned!(
475                    !variant.fields.is_empty(),
476                    variant.ident.span() => "cannot derive `IntoPyObject` for empty variants"
477                );
478
479                Container::new(
480                    None,
481                    &variant.fields,
482                    parse_quote!(#ident::#var_ident),
483                    attrs,
484                )
485            })
486            .collect::<Result<Vec<_>>>()?;
487
488        Ok(Enum { variants })
489    }
490
491    /// Build derivation body for enums.
492    fn build(&self, ctx: &Ctx) -> IntoPyObjectImpl {
493        let Ctx { pyo3_path, .. } = ctx;
494
495        let variants = self
496            .variants
497            .iter()
498            .map(|v| {
499                let IntoPyObjectImpl { body, .. } = v.build(ctx);
500                let pattern = v.match_pattern();
501                quote! {
502                    #pattern => {
503                        {#body}
504                            .map(#pyo3_path::BoundObject::into_any)
505                            .map(#pyo3_path::BoundObject::into_bound)
506                            .map_err(::std::convert::Into::<#pyo3_path::PyErr>::into)
507                    }
508                }
509            })
510            .collect::<TokenStream>();
511
512        IntoPyObjectImpl {
513            types: IntoPyObjectTypes::Opaque {
514                target: quote!(#pyo3_path::types::PyAny),
515                output: quote!(#pyo3_path::Bound<'py, <Self as #pyo3_path::conversion::IntoPyObject<'py>>::Target>),
516                error: quote!(#pyo3_path::PyErr),
517            },
518            body: quote! {
519                match self {
520                    #variants
521                }
522            },
523        }
524    }
525}
526
527// if there is a `'py` lifetime, we treat it as the `Python<'py>` lifetime
528fn verify_and_get_lifetime(generics: &syn::Generics) -> Option<&syn::LifetimeParam> {
529    let mut lifetimes = generics.lifetimes();
530    lifetimes.find(|l| l.lifetime.ident == "py")
531}
532
533pub fn build_derive_into_pyobject<const REF: bool>(tokens: &DeriveInput) -> Result<TokenStream> {
534    let options = ContainerOptions::from_attrs(&tokens.attrs)?;
535    let ctx = &Ctx::new(&options.krate, None);
536    let Ctx { pyo3_path, .. } = &ctx;
537
538    let (_, ty_generics, _) = tokens.generics.split_for_impl();
539    let mut trait_generics = tokens.generics.clone();
540    if REF {
541        trait_generics.params.push(parse_quote!('_a));
542    }
543    let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics) {
544        lt.clone()
545    } else {
546        trait_generics.params.push(parse_quote!('py));
547        parse_quote!('py)
548    };
549    let (impl_generics, _, where_clause) = trait_generics.split_for_impl();
550
551    let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
552    for param in trait_generics.type_params() {
553        let gen_ident = &param.ident;
554        where_clause.predicates.push(if REF {
555            parse_quote!(&'_a #gen_ident: #pyo3_path::conversion::IntoPyObject<'py>)
556        } else {
557            parse_quote!(#gen_ident: #pyo3_path::conversion::IntoPyObject<'py>)
558        })
559    }
560
561    let IntoPyObjectImpl { types, body } = match &tokens.data {
562        syn::Data::Enum(en) => {
563            if options.transparent.is_some() {
564                bail_spanned!(tokens.span() => "`transparent` is not supported at top level for enums");
565            }
566            let en = Enum::new(en, &tokens.ident)?;
567            en.build(ctx)
568        }
569        syn::Data::Struct(st) => {
570            let ident = &tokens.ident;
571            let st = Container::new(
572                Some(Ident::new("self", Span::call_site())),
573                &st.fields,
574                parse_quote!(#ident),
575                options,
576            )?;
577            st.build(ctx)
578        }
579        syn::Data::Union(_) => bail_spanned!(
580            tokens.span() => "#[derive(`IntoPyObject`)] is not supported for unions"
581        ),
582    };
583
584    let (target, output, error) = match types {
585        IntoPyObjectTypes::Transparent(ty) => {
586            if REF {
587                (
588                    quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Target },
589                    quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Output },
590                    quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Error },
591                )
592            } else {
593                (
594                    quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Target },
595                    quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Output },
596                    quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Error },
597                )
598            }
599        }
600        IntoPyObjectTypes::Opaque {
601            target,
602            output,
603            error,
604        } => (target, output, error),
605    };
606
607    let ident = &tokens.ident;
608    let ident = if REF {
609        quote! { &'_a #ident}
610    } else {
611        quote! { #ident }
612    };
613    Ok(quote!(
614        #[automatically_derived]
615        impl #impl_generics #pyo3_path::conversion::IntoPyObject<#lt_param> for #ident #ty_generics #where_clause {
616            type Target = #target;
617            type Output = #output;
618            type Error = #error;
619
620            fn into_pyobject(self, py: #pyo3_path::Python<#lt_param>) -> ::std::result::Result<
621                <Self as #pyo3_path::conversion::IntoPyObject>::Output,
622                <Self as #pyo3_path::conversion::IntoPyObject>::Error,
623            > {
624                #body
625            }
626        }
627    ))
628}