1use crate::attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute};
2use crate::utils::Ctx;
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{
6 ext::IdentExt,
7 parenthesized,
8 parse::{Parse, ParseStream},
9 parse_quote,
10 punctuated::Punctuated,
11 spanned::Spanned,
12 Attribute, DataEnum, DeriveInput, Fields, Ident, LitStr, Result, Token,
13};
14
15struct Enum<'a> {
17 enum_ident: &'a Ident,
18 variants: Vec<Container<'a>>,
19}
20
21impl<'a> Enum<'a> {
22 fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
27 ensure_spanned!(
28 !data_enum.variants.is_empty(),
29 ident.span() => "cannot derive FromPyObject for empty enum"
30 );
31 let variants = data_enum
32 .variants
33 .iter()
34 .map(|variant| {
35 let attrs = ContainerOptions::from_attrs(&variant.attrs)?;
36 let var_ident = &variant.ident;
37 Container::new(&variant.fields, parse_quote!(#ident::#var_ident), attrs)
38 })
39 .collect::<Result<Vec<_>>>()?;
40
41 Ok(Enum {
42 enum_ident: ident,
43 variants,
44 })
45 }
46
47 fn build(&self, ctx: &Ctx) -> TokenStream {
49 let Ctx { pyo3_path, .. } = ctx;
50 let mut var_extracts = Vec::new();
51 let mut variant_names = Vec::new();
52 let mut error_names = Vec::new();
53
54 for var in &self.variants {
55 let struct_derive = var.build(ctx);
56 let ext = quote!({
57 let maybe_ret = || -> #pyo3_path::PyResult<Self> {
58 #struct_derive
59 }();
60
61 match maybe_ret {
62 ok @ ::std::result::Result::Ok(_) => return ok,
63 ::std::result::Result::Err(err) => err
64 }
65 });
66
67 var_extracts.push(ext);
68 variant_names.push(var.path.segments.last().unwrap().ident.to_string());
69 error_names.push(&var.err_name);
70 }
71 let ty_name = self.enum_ident.to_string();
72 quote!(
73 let errors = [
74 #(#var_extracts),*
75 ];
76 ::std::result::Result::Err(
77 #pyo3_path::impl_::frompyobject::failed_to_extract_enum(
78 obj.py(),
79 #ty_name,
80 &[#(#variant_names),*],
81 &[#(#error_names),*],
82 &errors
83 )
84 )
85 )
86 }
87}
88
89struct NamedStructField<'a> {
90 ident: &'a syn::Ident,
91 getter: Option<FieldGetter>,
92 from_py_with: Option<FromPyWithAttribute>,
93}
94
95struct TupleStructField {
96 from_py_with: Option<FromPyWithAttribute>,
97}
98
99enum ContainerType<'a> {
103 Struct(Vec<NamedStructField<'a>>),
107 StructNewtype(&'a syn::Ident, Option<FromPyWithAttribute>),
111 Tuple(Vec<TupleStructField>),
116 TupleNewtype(Option<FromPyWithAttribute>),
120}
121
122struct Container<'a> {
126 path: syn::Path,
127 ty: ContainerType<'a>,
128 err_name: String,
129}
130
131impl<'a> Container<'a> {
132 fn new(fields: &'a Fields, path: syn::Path, options: ContainerOptions) -> Result<Self> {
136 let style = match fields {
137 Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
138 let mut tuple_fields = unnamed
139 .unnamed
140 .iter()
141 .map(|field| {
142 let attrs = FieldPyO3Attributes::from_attrs(&field.attrs)?;
143 ensure_spanned!(
144 attrs.getter.is_none(),
145 field.span() => "`getter` is not permitted on tuple struct elements."
146 );
147 Ok(TupleStructField {
148 from_py_with: attrs.from_py_with,
149 })
150 })
151 .collect::<Result<Vec<_>>>()?;
152
153 if tuple_fields.len() == 1 {
154 let field = tuple_fields.pop().unwrap();
157 ContainerType::TupleNewtype(field.from_py_with)
158 } else if options.transparent {
159 bail_spanned!(
160 fields.span() => "transparent structs and variants can only have 1 field"
161 );
162 } else {
163 ContainerType::Tuple(tuple_fields)
164 }
165 }
166 Fields::Named(named) if !named.named.is_empty() => {
167 let mut struct_fields = named
168 .named
169 .iter()
170 .map(|field| {
171 let ident = field
172 .ident
173 .as_ref()
174 .expect("Named fields should have identifiers");
175 let mut attrs = FieldPyO3Attributes::from_attrs(&field.attrs)?;
176
177 if let Some(ref from_item_all) = options.from_item_all {
178 if let Some(replaced) = attrs.getter.replace(FieldGetter::GetItem(None))
179 {
180 match replaced {
181 FieldGetter::GetItem(Some(item_name)) => {
182 attrs.getter = Some(FieldGetter::GetItem(Some(item_name)));
183 }
184 FieldGetter::GetItem(None) => bail_spanned!(from_item_all.span() => "Useless `item` - the struct is already annotated with `from_item_all`"),
185 FieldGetter::GetAttr(_) => bail_spanned!(
186 from_item_all.span() => "The struct is already annotated with `from_item_all`, `attribute` is not allowed"
187 ),
188 }
189 }
190 }
191
192 Ok(NamedStructField {
193 ident,
194 getter: attrs.getter,
195 from_py_with: attrs.from_py_with,
196 })
197 })
198 .collect::<Result<Vec<_>>>()?;
199 if options.transparent {
200 ensure_spanned!(
201 struct_fields.len() == 1,
202 fields.span() => "transparent structs and variants can only have 1 field"
203 );
204 let field = struct_fields.pop().unwrap();
205 ensure_spanned!(
206 field.getter.is_none(),
207 field.ident.span() => "`transparent` structs may not have a `getter` for the inner field"
208 );
209 ContainerType::StructNewtype(field.ident, field.from_py_with)
210 } else {
211 ContainerType::Struct(struct_fields)
212 }
213 }
214 _ => bail_spanned!(
215 fields.span() => "cannot derive FromPyObject for empty structs and variants"
216 ),
217 };
218 let err_name = options.annotation.map_or_else(
219 || path.segments.last().unwrap().ident.to_string(),
220 |lit_str| lit_str.value(),
221 );
222
223 let v = Container {
224 path,
225 ty: style,
226 err_name,
227 };
228 Ok(v)
229 }
230
231 fn name(&self) -> String {
232 let mut value = String::new();
233 for segment in &self.path.segments {
234 if !value.is_empty() {
235 value.push_str("::");
236 }
237 value.push_str(&segment.ident.to_string());
238 }
239 value
240 }
241
242 fn build(&self, ctx: &Ctx) -> TokenStream {
244 match &self.ty {
245 ContainerType::StructNewtype(ident, from_py_with) => {
246 self.build_newtype_struct(Some(ident), from_py_with, ctx)
247 }
248 ContainerType::TupleNewtype(from_py_with) => {
249 self.build_newtype_struct(None, from_py_with, ctx)
250 }
251 ContainerType::Tuple(tups) => self.build_tuple_struct(tups, ctx),
252 ContainerType::Struct(tups) => self.build_struct(tups, ctx),
253 }
254 }
255
256 fn build_newtype_struct(
257 &self,
258 field_ident: Option<&Ident>,
259 from_py_with: &Option<FromPyWithAttribute>,
260 ctx: &Ctx,
261 ) -> TokenStream {
262 let Ctx { pyo3_path, .. } = ctx;
263 let self_ty = &self.path;
264 let struct_name = self.name();
265 if let Some(ident) = field_ident {
266 let field_name = ident.to_string();
267 match from_py_with {
268 None => quote! {
269 Ok(#self_ty {
270 #ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
271 })
272 },
273 Some(FromPyWithAttribute {
274 value: expr_path, ..
275 }) => quote! {
276 Ok(#self_ty {
277 #ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, #field_name)?
278 })
279 },
280 }
281 } else {
282 match from_py_with {
283 None => quote! {
284 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
285 },
286
287 Some(FromPyWithAttribute {
288 value: expr_path, ..
289 }) => quote! {
290 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, 0).map(#self_ty)
291 },
292 }
293 }
294 }
295
296 fn build_tuple_struct(&self, struct_fields: &[TupleStructField], ctx: &Ctx) -> TokenStream {
297 let Ctx { pyo3_path, .. } = ctx;
298 let self_ty = &self.path;
299 let struct_name = &self.name();
300 let field_idents: Vec<_> = (0..struct_fields.len())
301 .map(|i| format_ident!("arg{}", i))
302 .collect();
303 let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
304 match &field.from_py_with {
305 None => quote!(
306 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
307 ),
308 Some(FromPyWithAttribute {
309 value: expr_path, ..
310 }) => quote! (
311 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, &#ident, #struct_name, #index)?
312 ),
313 }
314 });
315
316 quote!(
317 match #pyo3_path::types::PyAnyMethods::extract(obj) {
318 ::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)),
319 ::std::result::Result::Err(err) => ::std::result::Result::Err(err),
320 }
321 )
322 }
323
324 fn build_struct(&self, struct_fields: &[NamedStructField<'_>], ctx: &Ctx) -> TokenStream {
325 let Ctx { pyo3_path, .. } = ctx;
326 let self_ty = &self.path;
327 let struct_name = self.name();
328 let mut fields: Punctuated<TokenStream, Token![,]> = Punctuated::new();
329 for field in struct_fields {
330 let ident = field.ident;
331 let field_name = ident.unraw().to_string();
332 let getter = match field.getter.as_ref().unwrap_or(&FieldGetter::GetAttr(None)) {
333 FieldGetter::GetAttr(Some(name)) => {
334 quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
335 }
336 FieldGetter::GetAttr(None) => {
337 quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #field_name)))
338 }
339 FieldGetter::GetItem(Some(syn::Lit::Str(key))) => {
340 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #key)))
341 }
342 FieldGetter::GetItem(Some(key)) => {
343 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #key))
344 }
345 FieldGetter::GetItem(None) => {
346 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #field_name)))
347 }
348 };
349 let extractor = match &field.from_py_with {
350 None => {
351 quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?)
352 }
353 Some(FromPyWithAttribute {
354 value: expr_path, ..
355 }) => {
356 quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &#getter?, #struct_name, #field_name)?)
357 }
358 };
359
360 fields.push(quote!(#ident: #extractor));
361 }
362
363 quote!(::std::result::Result::Ok(#self_ty{#fields}))
364 }
365}
366
367#[derive(Default)]
368struct ContainerOptions {
369 transparent: bool,
371 from_item_all: Option<attributes::kw::from_item_all>,
373 annotation: Option<syn::LitStr>,
375 krate: Option<CrateAttribute>,
377}
378
379enum ContainerPyO3Attribute {
381 Transparent(attributes::kw::transparent),
383 ItemAll(attributes::kw::from_item_all),
385 ErrorAnnotation(LitStr),
387 Crate(CrateAttribute),
389}
390
391impl Parse for ContainerPyO3Attribute {
392 fn parse(input: ParseStream<'_>) -> Result<Self> {
393 let lookahead = input.lookahead1();
394 if lookahead.peek(attributes::kw::transparent) {
395 let kw: attributes::kw::transparent = input.parse()?;
396 Ok(ContainerPyO3Attribute::Transparent(kw))
397 } else if lookahead.peek(attributes::kw::from_item_all) {
398 let kw: attributes::kw::from_item_all = input.parse()?;
399 Ok(ContainerPyO3Attribute::ItemAll(kw))
400 } else if lookahead.peek(attributes::kw::annotation) {
401 let _: attributes::kw::annotation = input.parse()?;
402 let _: Token![=] = input.parse()?;
403 input.parse().map(ContainerPyO3Attribute::ErrorAnnotation)
404 } else if lookahead.peek(Token![crate]) {
405 input.parse().map(ContainerPyO3Attribute::Crate)
406 } else {
407 Err(lookahead.error())
408 }
409 }
410}
411
412impl ContainerOptions {
413 fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
414 let mut options = ContainerOptions::default();
415
416 for attr in attrs {
417 if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
418 for pyo3_attr in pyo3_attrs {
419 match pyo3_attr {
420 ContainerPyO3Attribute::Transparent(kw) => {
421 ensure_spanned!(
422 !options.transparent,
423 kw.span() => "`transparent` may only be provided once"
424 );
425 options.transparent = true;
426 }
427 ContainerPyO3Attribute::ItemAll(kw) => {
428 ensure_spanned!(
429 options.from_item_all.is_none(),
430 kw.span() => "`from_item_all` may only be provided once"
431 );
432 options.from_item_all = Some(kw);
433 }
434 ContainerPyO3Attribute::ErrorAnnotation(lit_str) => {
435 ensure_spanned!(
436 options.annotation.is_none(),
437 lit_str.span() => "`annotation` may only be provided once"
438 );
439 options.annotation = Some(lit_str);
440 }
441 ContainerPyO3Attribute::Crate(path) => {
442 ensure_spanned!(
443 options.krate.is_none(),
444 path.span() => "`crate` may only be provided once"
445 );
446 options.krate = Some(path);
447 }
448 }
449 }
450 }
451 }
452 Ok(options)
453 }
454}
455
456#[derive(Clone, Debug)]
458struct FieldPyO3Attributes {
459 getter: Option<FieldGetter>,
460 from_py_with: Option<FromPyWithAttribute>,
461}
462
463#[derive(Clone, Debug)]
464enum FieldGetter {
465 GetItem(Option<syn::Lit>),
466 GetAttr(Option<LitStr>),
467}
468
469enum FieldPyO3Attribute {
470 Getter(FieldGetter),
471 FromPyWith(FromPyWithAttribute),
472}
473
474impl Parse for FieldPyO3Attribute {
475 fn parse(input: ParseStream<'_>) -> Result<Self> {
476 let lookahead = input.lookahead1();
477 if lookahead.peek(attributes::kw::attribute) {
478 let _: attributes::kw::attribute = input.parse()?;
479 if input.peek(syn::token::Paren) {
480 let content;
481 let _ = parenthesized!(content in input);
482 let attr_name: LitStr = content.parse()?;
483 if !content.is_empty() {
484 return Err(content.error(
485 "expected at most one argument: `attribute` or `attribute(\"name\")`",
486 ));
487 }
488 ensure_spanned!(
489 !attr_name.value().is_empty(),
490 attr_name.span() => "attribute name cannot be empty"
491 );
492 Ok(FieldPyO3Attribute::Getter(FieldGetter::GetAttr(Some(
493 attr_name,
494 ))))
495 } else {
496 Ok(FieldPyO3Attribute::Getter(FieldGetter::GetAttr(None)))
497 }
498 } else if lookahead.peek(attributes::kw::item) {
499 let _: attributes::kw::item = input.parse()?;
500 if input.peek(syn::token::Paren) {
501 let content;
502 let _ = parenthesized!(content in input);
503 let key = content.parse()?;
504 if !content.is_empty() {
505 return Err(
506 content.error("expected at most one argument: `item` or `item(key)`")
507 );
508 }
509 Ok(FieldPyO3Attribute::Getter(FieldGetter::GetItem(Some(key))))
510 } else {
511 Ok(FieldPyO3Attribute::Getter(FieldGetter::GetItem(None)))
512 }
513 } else if lookahead.peek(attributes::kw::from_py_with) {
514 input.parse().map(FieldPyO3Attribute::FromPyWith)
515 } else {
516 Err(lookahead.error())
517 }
518 }
519}
520
521impl FieldPyO3Attributes {
522 fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
524 let mut getter = None;
525 let mut from_py_with = None;
526
527 for attr in attrs {
528 if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
529 for pyo3_attr in pyo3_attrs {
530 match pyo3_attr {
531 FieldPyO3Attribute::Getter(field_getter) => {
532 ensure_spanned!(
533 getter.is_none(),
534 attr.span() => "only one of `attribute` or `item` can be provided"
535 );
536 getter = Some(field_getter);
537 }
538 FieldPyO3Attribute::FromPyWith(from_py_with_attr) => {
539 ensure_spanned!(
540 from_py_with.is_none(),
541 attr.span() => "`from_py_with` may only be provided once"
542 );
543 from_py_with = Some(from_py_with_attr);
544 }
545 }
546 }
547 }
548 }
549
550 Ok(FieldPyO3Attributes {
551 getter,
552 from_py_with,
553 })
554 }
555}
556
557fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeParam>> {
558 let mut lifetimes = generics.lifetimes();
559 let lifetime = lifetimes.next();
560 ensure_spanned!(
561 lifetimes.next().is_none(),
562 generics.span() => "FromPyObject can be derived with at most one lifetime parameter"
563 );
564 Ok(lifetime)
565}
566
567pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
576 let options = ContainerOptions::from_attrs(&tokens.attrs)?;
577 let ctx = &Ctx::new(&options.krate, None);
578 let Ctx { pyo3_path, .. } = &ctx;
579
580 let (_, ty_generics, _) = tokens.generics.split_for_impl();
581 let mut trait_generics = tokens.generics.clone();
582 let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics)? {
583 lt.clone()
584 } else {
585 trait_generics.params.push(parse_quote!('py));
586 parse_quote!('py)
587 };
588 let (impl_generics, _, where_clause) = trait_generics.split_for_impl();
589
590 let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
591 for param in trait_generics.type_params() {
592 let gen_ident = ¶m.ident;
593 where_clause
594 .predicates
595 .push(parse_quote!(#gen_ident: #pyo3_path::FromPyObject<'py>))
596 }
597
598 let derives = match &tokens.data {
599 syn::Data::Enum(en) => {
600 if options.transparent || options.annotation.is_some() {
601 bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \
602 at top level for enums");
603 }
604 let en = Enum::new(en, &tokens.ident)?;
605 en.build(ctx)
606 }
607 syn::Data::Struct(st) => {
608 if let Some(lit_str) = &options.annotation {
609 bail_spanned!(lit_str.span() => "`annotation` is unsupported for structs");
610 }
611 let ident = &tokens.ident;
612 let st = Container::new(&st.fields, parse_quote!(#ident), options)?;
613 st.build(ctx)
614 }
615 syn::Data::Union(_) => bail_spanned!(
616 tokens.span() => "#[derive(FromPyObject)] is not supported for unions"
617 ),
618 };
619
620 let ident = &tokens.ident;
621 Ok(quote!(
622 #[automatically_derived]
623 impl #impl_generics #pyo3_path::FromPyObject<#lt_param> for #ident #ty_generics #where_clause {
624 fn extract_bound(obj: &#pyo3_path::Bound<#lt_param, #pyo3_path::PyAny>) -> #pyo3_path::PyResult<Self> {
625 #derives
626 }
627 }
628 ))
629}