pyo3_macros_backend/
frompyobject.rs

1use crate::attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute};
2use crate::utils::Ctx;
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{
6    ext::IdentExt,
7    parenthesized,
8    parse::{Parse, ParseStream},
9    parse_quote,
10    punctuated::Punctuated,
11    spanned::Spanned,
12    Attribute, DataEnum, DeriveInput, Fields, Ident, LitStr, Result, Token,
13};
14
/// Describes the derivation input of an enum.
///
/// Pairs the enum's identifier with one [`Container`] per variant.
struct Enum<'a> {
    /// Identifier of the enum the derive is applied to; used as the type name in the
    /// generated extraction error.
    enum_ident: &'a Ident,
    /// One container per variant, in declaration order.
    variants: Vec<Container<'a>>,
}
20
21impl<'a> Enum<'a> {
22    /// Construct a new enum representation.
23    ///
24    /// `data_enum` is the `syn` representation of the input enum, `ident` is the
25    /// `Identifier` of the enum.
26    fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
27        ensure_spanned!(
28            !data_enum.variants.is_empty(),
29            ident.span() => "cannot derive FromPyObject for empty enum"
30        );
31        let variants = data_enum
32            .variants
33            .iter()
34            .map(|variant| {
35                let attrs = ContainerOptions::from_attrs(&variant.attrs)?;
36                let var_ident = &variant.ident;
37                Container::new(&variant.fields, parse_quote!(#ident::#var_ident), attrs)
38            })
39            .collect::<Result<Vec<_>>>()?;
40
41        Ok(Enum {
42            enum_ident: ident,
43            variants,
44        })
45    }
46
47    /// Build derivation body for enums.
48    fn build(&self, ctx: &Ctx) -> TokenStream {
49        let Ctx { pyo3_path, .. } = ctx;
50        let mut var_extracts = Vec::new();
51        let mut variant_names = Vec::new();
52        let mut error_names = Vec::new();
53
54        for var in &self.variants {
55            let struct_derive = var.build(ctx);
56            let ext = quote!({
57                let maybe_ret = || -> #pyo3_path::PyResult<Self> {
58                    #struct_derive
59                }();
60
61                match maybe_ret {
62                    ok @ ::std::result::Result::Ok(_) => return ok,
63                    ::std::result::Result::Err(err) => err
64                }
65            });
66
67            var_extracts.push(ext);
68            variant_names.push(var.path.segments.last().unwrap().ident.to_string());
69            error_names.push(&var.err_name);
70        }
71        let ty_name = self.enum_ident.to_string();
72        quote!(
73            let errors = [
74                #(#var_extracts),*
75            ];
76            ::std::result::Result::Err(
77                #pyo3_path::impl_::frompyobject::failed_to_extract_enum(
78                    obj.py(),
79                    #ty_name,
80                    &[#(#variant_names),*],
81                    &[#(#error_names),*],
82                    &errors
83                )
84            )
85        )
86    }
87}
88
/// A named field of a struct or of a struct-like enum variant.
struct NamedStructField<'a> {
    /// Identifier of the field as written in the source.
    ident: &'a syn::Ident,
    /// How the value is looked up on the Python object. `None` means no explicit getter was
    /// selected; the struct builder then falls back to an attribute lookup by field name.
    getter: Option<FieldGetter>,
    /// Custom conversion supplied through `from_py_with`, if any.
    from_py_with: Option<FromPyWithAttribute>,
}
94
/// A positional field of a tuple struct or of a tuple-like enum variant.
struct TupleStructField {
    /// Custom conversion supplied through `from_py_with`, if any.
    from_py_with: Option<FromPyWithAttribute>,
}
98
/// Container Style
///
/// Covers structs, tuple structs and the corresponding newtypes.
enum ContainerType<'a> {
    /// Struct container, e.g. `struct Foo { a: String }`
    ///
    /// Variant contains one entry per named field, describing how it is looked up and
    /// converted.
    Struct(Vec<NamedStructField<'a>>),
    /// Newtype struct container, e.g. `#[pyo3(transparent)] struct Foo { a: String }`
    ///
    /// The field specified by the identifier is extracted directly from the object.
    StructNewtype(&'a syn::Ident, Option<FromPyWithAttribute>),
    /// Tuple struct with more than one field, e.g. `struct Foo(String, i32)`.
    ///
    /// Variant contains one entry per field; the fields are extracted from the elements of a
    /// tuple.
    Tuple(Vec<TupleStructField>),
    /// Tuple newtype, e.g. `struct Foo(String)`
    ///
    /// A tuple struct with exactly one field is always treated this way, with or without the
    /// `transparent` option. The wrapped field is directly extracted from the object.
    TupleNewtype(Option<FromPyWithAttribute>),
}
121
/// Data container
///
/// Either describes a struct or an enum variant.
struct Container<'a> {
    /// Path used to construct the value: the struct identifier, or `Enum::Variant`.
    path: syn::Path,
    /// Shape of the container, which decides how its fields are extracted.
    ty: ContainerType<'a>,
    /// Name reported for this container in enum extraction errors: the `annotation` option if
    /// given, otherwise the last segment of `path`.
    err_name: String,
}
130
131impl<'a> Container<'a> {
132    /// Construct a container based on fields, identifier and attributes.
133    ///
134    /// Fails if the variant has no fields or incompatible attributes.
135    fn new(fields: &'a Fields, path: syn::Path, options: ContainerOptions) -> Result<Self> {
136        let style = match fields {
137            Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
138                let mut tuple_fields = unnamed
139                    .unnamed
140                    .iter()
141                    .map(|field| {
142                        let attrs = FieldPyO3Attributes::from_attrs(&field.attrs)?;
143                        ensure_spanned!(
144                            attrs.getter.is_none(),
145                            field.span() => "`getter` is not permitted on tuple struct elements."
146                        );
147                        Ok(TupleStructField {
148                            from_py_with: attrs.from_py_with,
149                        })
150                    })
151                    .collect::<Result<Vec<_>>>()?;
152
153                if tuple_fields.len() == 1 {
154                    // Always treat a 1-length tuple struct as "transparent", even without the
155                    // explicit annotation.
156                    let field = tuple_fields.pop().unwrap();
157                    ContainerType::TupleNewtype(field.from_py_with)
158                } else if options.transparent {
159                    bail_spanned!(
160                        fields.span() => "transparent structs and variants can only have 1 field"
161                    );
162                } else {
163                    ContainerType::Tuple(tuple_fields)
164                }
165            }
166            Fields::Named(named) if !named.named.is_empty() => {
167                let mut struct_fields = named
168                    .named
169                    .iter()
170                    .map(|field| {
171                        let ident = field
172                            .ident
173                            .as_ref()
174                            .expect("Named fields should have identifiers");
175                        let mut attrs = FieldPyO3Attributes::from_attrs(&field.attrs)?;
176
177                        if let Some(ref from_item_all) = options.from_item_all {
178                            if let Some(replaced) = attrs.getter.replace(FieldGetter::GetItem(None))
179                            {
180                                match replaced {
181                                    FieldGetter::GetItem(Some(item_name)) => {
182                                        attrs.getter = Some(FieldGetter::GetItem(Some(item_name)));
183                                    }
184                                    FieldGetter::GetItem(None) => bail_spanned!(from_item_all.span() => "Useless `item` - the struct is already annotated with `from_item_all`"),
185                                    FieldGetter::GetAttr(_) => bail_spanned!(
186                                        from_item_all.span() => "The struct is already annotated with `from_item_all`, `attribute` is not allowed"
187                                    ),
188                                }
189                            }
190                        }
191
192                        Ok(NamedStructField {
193                            ident,
194                            getter: attrs.getter,
195                            from_py_with: attrs.from_py_with,
196                        })
197                    })
198                    .collect::<Result<Vec<_>>>()?;
199                if options.transparent {
200                    ensure_spanned!(
201                        struct_fields.len() == 1,
202                        fields.span() => "transparent structs and variants can only have 1 field"
203                    );
204                    let field = struct_fields.pop().unwrap();
205                    ensure_spanned!(
206                        field.getter.is_none(),
207                        field.ident.span() => "`transparent` structs may not have a `getter` for the inner field"
208                    );
209                    ContainerType::StructNewtype(field.ident, field.from_py_with)
210                } else {
211                    ContainerType::Struct(struct_fields)
212                }
213            }
214            _ => bail_spanned!(
215                fields.span() => "cannot derive FromPyObject for empty structs and variants"
216            ),
217        };
218        let err_name = options.annotation.map_or_else(
219            || path.segments.last().unwrap().ident.to_string(),
220            |lit_str| lit_str.value(),
221        );
222
223        let v = Container {
224            path,
225            ty: style,
226            err_name,
227        };
228        Ok(v)
229    }
230
231    fn name(&self) -> String {
232        let mut value = String::new();
233        for segment in &self.path.segments {
234            if !value.is_empty() {
235                value.push_str("::");
236            }
237            value.push_str(&segment.ident.to_string());
238        }
239        value
240    }
241
242    /// Build derivation body for a struct.
243    fn build(&self, ctx: &Ctx) -> TokenStream {
244        match &self.ty {
245            ContainerType::StructNewtype(ident, from_py_with) => {
246                self.build_newtype_struct(Some(ident), from_py_with, ctx)
247            }
248            ContainerType::TupleNewtype(from_py_with) => {
249                self.build_newtype_struct(None, from_py_with, ctx)
250            }
251            ContainerType::Tuple(tups) => self.build_tuple_struct(tups, ctx),
252            ContainerType::Struct(tups) => self.build_struct(tups, ctx),
253        }
254    }
255
256    fn build_newtype_struct(
257        &self,
258        field_ident: Option<&Ident>,
259        from_py_with: &Option<FromPyWithAttribute>,
260        ctx: &Ctx,
261    ) -> TokenStream {
262        let Ctx { pyo3_path, .. } = ctx;
263        let self_ty = &self.path;
264        let struct_name = self.name();
265        if let Some(ident) = field_ident {
266            let field_name = ident.to_string();
267            match from_py_with {
268                None => quote! {
269                    Ok(#self_ty {
270                        #ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
271                    })
272                },
273                Some(FromPyWithAttribute {
274                    value: expr_path, ..
275                }) => quote! {
276                    Ok(#self_ty {
277                        #ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, #field_name)?
278                    })
279                },
280            }
281        } else {
282            match from_py_with {
283                None => quote! {
284                    #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
285                },
286
287                Some(FromPyWithAttribute {
288                    value: expr_path, ..
289                }) => quote! {
290                    #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, 0).map(#self_ty)
291                },
292            }
293        }
294    }
295
296    fn build_tuple_struct(&self, struct_fields: &[TupleStructField], ctx: &Ctx) -> TokenStream {
297        let Ctx { pyo3_path, .. } = ctx;
298        let self_ty = &self.path;
299        let struct_name = &self.name();
300        let field_idents: Vec<_> = (0..struct_fields.len())
301            .map(|i| format_ident!("arg{}", i))
302            .collect();
303        let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
304            match &field.from_py_with {
305                None => quote!(
306                    #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
307                ),
308                Some(FromPyWithAttribute {
309                    value: expr_path, ..
310                }) => quote! (
311                    #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, &#ident, #struct_name, #index)?
312                ),
313            }
314        });
315
316        quote!(
317            match #pyo3_path::types::PyAnyMethods::extract(obj) {
318                ::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)),
319                ::std::result::Result::Err(err) => ::std::result::Result::Err(err),
320            }
321        )
322    }
323
324    fn build_struct(&self, struct_fields: &[NamedStructField<'_>], ctx: &Ctx) -> TokenStream {
325        let Ctx { pyo3_path, .. } = ctx;
326        let self_ty = &self.path;
327        let struct_name = self.name();
328        let mut fields: Punctuated<TokenStream, Token![,]> = Punctuated::new();
329        for field in struct_fields {
330            let ident = field.ident;
331            let field_name = ident.unraw().to_string();
332            let getter = match field.getter.as_ref().unwrap_or(&FieldGetter::GetAttr(None)) {
333                FieldGetter::GetAttr(Some(name)) => {
334                    quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
335                }
336                FieldGetter::GetAttr(None) => {
337                    quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #field_name)))
338                }
339                FieldGetter::GetItem(Some(syn::Lit::Str(key))) => {
340                    quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #key)))
341                }
342                FieldGetter::GetItem(Some(key)) => {
343                    quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #key))
344                }
345                FieldGetter::GetItem(None) => {
346                    quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #field_name)))
347                }
348            };
349            let extractor = match &field.from_py_with {
350                None => {
351                    quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?)
352                }
353                Some(FromPyWithAttribute {
354                    value: expr_path, ..
355                }) => {
356                    quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &#getter?, #struct_name, #field_name)?)
357                }
358            };
359
360            fields.push(quote!(#ident: #extractor));
361        }
362
363        quote!(::std::result::Result::Ok(#self_ty{#fields}))
364    }
365}
366
/// Options collected from the attributes of a struct, enum or enum variant.
#[derive(Default)]
struct ContainerOptions {
    /// Treat the Container as a Wrapper, directly extract its fields from the input object.
    transparent: bool,
    /// Force every field to be extracted from item of source Python object.
    /// The keyword token is kept so that errors can point at it.
    from_item_all: Option<attributes::kw::from_item_all>,
    /// Change the name of an enum variant in the generated error message.
    annotation: Option<syn::LitStr>,
    /// Change the path for the pyo3 crate
    krate: Option<CrateAttribute>,
}
378
/// Attributes for deriving FromPyObject scoped on containers.
///
/// Each variant is one parsed option; `ContainerOptions::from_attrs` folds them into a
/// `ContainerOptions`.
enum ContainerPyO3Attribute {
    /// Treat the Container as a Wrapper, directly extract its fields from the input object.
    Transparent(attributes::kw::transparent),
    /// Force every field to be extracted from item of source Python object.
    ItemAll(attributes::kw::from_item_all),
    /// Change the name of an enum variant in the generated error message.
    /// Written as `annotation = "..."`.
    ErrorAnnotation(LitStr),
    /// Change the path for the pyo3 crate
    Crate(CrateAttribute),
}
390
391impl Parse for ContainerPyO3Attribute {
392    fn parse(input: ParseStream<'_>) -> Result<Self> {
393        let lookahead = input.lookahead1();
394        if lookahead.peek(attributes::kw::transparent) {
395            let kw: attributes::kw::transparent = input.parse()?;
396            Ok(ContainerPyO3Attribute::Transparent(kw))
397        } else if lookahead.peek(attributes::kw::from_item_all) {
398            let kw: attributes::kw::from_item_all = input.parse()?;
399            Ok(ContainerPyO3Attribute::ItemAll(kw))
400        } else if lookahead.peek(attributes::kw::annotation) {
401            let _: attributes::kw::annotation = input.parse()?;
402            let _: Token![=] = input.parse()?;
403            input.parse().map(ContainerPyO3Attribute::ErrorAnnotation)
404        } else if lookahead.peek(Token![crate]) {
405            input.parse().map(ContainerPyO3Attribute::Crate)
406        } else {
407            Err(lookahead.error())
408        }
409    }
410}
411
412impl ContainerOptions {
413    fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
414        let mut options = ContainerOptions::default();
415
416        for attr in attrs {
417            if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
418                for pyo3_attr in pyo3_attrs {
419                    match pyo3_attr {
420                        ContainerPyO3Attribute::Transparent(kw) => {
421                            ensure_spanned!(
422                                !options.transparent,
423                                kw.span() => "`transparent` may only be provided once"
424                            );
425                            options.transparent = true;
426                        }
427                        ContainerPyO3Attribute::ItemAll(kw) => {
428                            ensure_spanned!(
429                                options.from_item_all.is_none(),
430                                kw.span() => "`from_item_all` may only be provided once"
431                            );
432                            options.from_item_all = Some(kw);
433                        }
434                        ContainerPyO3Attribute::ErrorAnnotation(lit_str) => {
435                            ensure_spanned!(
436                                options.annotation.is_none(),
437                                lit_str.span() => "`annotation` may only be provided once"
438                            );
439                            options.annotation = Some(lit_str);
440                        }
441                        ContainerPyO3Attribute::Crate(path) => {
442                            ensure_spanned!(
443                                options.krate.is_none(),
444                                path.span() => "`crate` may only be provided once"
445                            );
446                            options.krate = Some(path);
447                        }
448                    }
449                }
450            }
451        }
452        Ok(options)
453    }
454}
455
/// Attributes for deriving FromPyObject scoped on fields.
#[derive(Clone, Debug)]
struct FieldPyO3Attributes {
    /// Lookup strategy selected with `attribute` or `item`, if any.
    getter: Option<FieldGetter>,
    /// Custom conversion selected with `from_py_with`, if any.
    from_py_with: Option<FromPyWithAttribute>,
}
462
/// How a named field's value is looked up on the source Python object.
#[derive(Clone, Debug)]
enum FieldGetter {
    /// Item lookup (`get_item`), using the given literal as key, or the field name when `None`.
    GetItem(Option<syn::Lit>),
    /// Attribute lookup (`getattr`), using the given name, or the field name when `None`.
    GetAttr(Option<LitStr>),
}
468
/// A single parsed field-level option.
enum FieldPyO3Attribute {
    /// `attribute`, `attribute("name")`, `item` or `item(key)`.
    Getter(FieldGetter),
    /// A `from_py_with` option.
    FromPyWith(FromPyWithAttribute),
}
473
474impl Parse for FieldPyO3Attribute {
475    fn parse(input: ParseStream<'_>) -> Result<Self> {
476        let lookahead = input.lookahead1();
477        if lookahead.peek(attributes::kw::attribute) {
478            let _: attributes::kw::attribute = input.parse()?;
479            if input.peek(syn::token::Paren) {
480                let content;
481                let _ = parenthesized!(content in input);
482                let attr_name: LitStr = content.parse()?;
483                if !content.is_empty() {
484                    return Err(content.error(
485                        "expected at most one argument: `attribute` or `attribute(\"name\")`",
486                    ));
487                }
488                ensure_spanned!(
489                    !attr_name.value().is_empty(),
490                    attr_name.span() => "attribute name cannot be empty"
491                );
492                Ok(FieldPyO3Attribute::Getter(FieldGetter::GetAttr(Some(
493                    attr_name,
494                ))))
495            } else {
496                Ok(FieldPyO3Attribute::Getter(FieldGetter::GetAttr(None)))
497            }
498        } else if lookahead.peek(attributes::kw::item) {
499            let _: attributes::kw::item = input.parse()?;
500            if input.peek(syn::token::Paren) {
501                let content;
502                let _ = parenthesized!(content in input);
503                let key = content.parse()?;
504                if !content.is_empty() {
505                    return Err(
506                        content.error("expected at most one argument: `item` or `item(key)`")
507                    );
508                }
509                Ok(FieldPyO3Attribute::Getter(FieldGetter::GetItem(Some(key))))
510            } else {
511                Ok(FieldPyO3Attribute::Getter(FieldGetter::GetItem(None)))
512            }
513        } else if lookahead.peek(attributes::kw::from_py_with) {
514            input.parse().map(FieldPyO3Attribute::FromPyWith)
515        } else {
516            Err(lookahead.error())
517        }
518    }
519}
520
521impl FieldPyO3Attributes {
522    /// Extract the field attributes.
523    fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
524        let mut getter = None;
525        let mut from_py_with = None;
526
527        for attr in attrs {
528            if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
529                for pyo3_attr in pyo3_attrs {
530                    match pyo3_attr {
531                        FieldPyO3Attribute::Getter(field_getter) => {
532                            ensure_spanned!(
533                                getter.is_none(),
534                                attr.span() => "only one of `attribute` or `item` can be provided"
535                            );
536                            getter = Some(field_getter);
537                        }
538                        FieldPyO3Attribute::FromPyWith(from_py_with_attr) => {
539                            ensure_spanned!(
540                                from_py_with.is_none(),
541                                attr.span() => "`from_py_with` may only be provided once"
542                            );
543                            from_py_with = Some(from_py_with_attr);
544                        }
545                    }
546                }
547            }
548        }
549
550        Ok(FieldPyO3Attributes {
551            getter,
552            from_py_with,
553        })
554    }
555}
556
557fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeParam>> {
558    let mut lifetimes = generics.lifetimes();
559    let lifetime = lifetimes.next();
560    ensure_spanned!(
561        lifetimes.next().is_none(),
562        generics.span() => "FromPyObject can be derived with at most one lifetime parameter"
563    );
564    Ok(lifetime)
565}
566
567/// Derive FromPyObject for enums and structs.
568///
569///   * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier
570///   * At least one field, in case of `#[transparent]`, exactly one field
571///   * At least one variant for enums.
572///   * Fields of input structs and enums must implement `FromPyObject` or be annotated with `from_py_with`
573///   * Derivation for structs with generic fields like `struct<T> Foo(T)`
574///     adds `T: FromPyObject` on the derived implementation.
575pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
576    let options = ContainerOptions::from_attrs(&tokens.attrs)?;
577    let ctx = &Ctx::new(&options.krate, None);
578    let Ctx { pyo3_path, .. } = &ctx;
579
580    let (_, ty_generics, _) = tokens.generics.split_for_impl();
581    let mut trait_generics = tokens.generics.clone();
582    let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics)? {
583        lt.clone()
584    } else {
585        trait_generics.params.push(parse_quote!('py));
586        parse_quote!('py)
587    };
588    let (impl_generics, _, where_clause) = trait_generics.split_for_impl();
589
590    let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
591    for param in trait_generics.type_params() {
592        let gen_ident = &param.ident;
593        where_clause
594            .predicates
595            .push(parse_quote!(#gen_ident: #pyo3_path::FromPyObject<'py>))
596    }
597
598    let derives = match &tokens.data {
599        syn::Data::Enum(en) => {
600            if options.transparent || options.annotation.is_some() {
601                bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \
602                                                at top level for enums");
603            }
604            let en = Enum::new(en, &tokens.ident)?;
605            en.build(ctx)
606        }
607        syn::Data::Struct(st) => {
608            if let Some(lit_str) = &options.annotation {
609                bail_spanned!(lit_str.span() => "`annotation` is unsupported for structs");
610            }
611            let ident = &tokens.ident;
612            let st = Container::new(&st.fields, parse_quote!(#ident), options)?;
613            st.build(ctx)
614        }
615        syn::Data::Union(_) => bail_spanned!(
616            tokens.span() => "#[derive(FromPyObject)] is not supported for unions"
617        ),
618    };
619
620    let ident = &tokens.ident;
621    Ok(quote!(
622        #[automatically_derived]
623        impl #impl_generics #pyo3_path::FromPyObject<#lt_param> for #ident #ty_generics #where_clause {
624            fn extract_bound(obj: &#pyo3_path::Bound<#lt_param, #pyo3_path::PyAny>) -> #pyo3_path::PyResult<Self>  {
625                #derives
626            }
627        }
628    ))
629}