1use crate::attributes::{self, get_pyo3_options, CrateAttribute};
2use crate::utils::Ctx;
3use proc_macro2::{Span, TokenStream};
4use quote::{format_ident, quote, quote_spanned};
5use syn::ext::IdentExt;
6use syn::parse::{Parse, ParseStream};
7use syn::spanned::Spanned as _;
8use syn::{
9 parenthesized, parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Index, Result,
10 Token,
11};
12
13enum ContainerPyO3Attribute {
15 Transparent(attributes::kw::transparent),
17 Crate(CrateAttribute),
19}
20
21impl Parse for ContainerPyO3Attribute {
22 fn parse(input: ParseStream<'_>) -> Result<Self> {
23 let lookahead = input.lookahead1();
24 if lookahead.peek(attributes::kw::transparent) {
25 let kw: attributes::kw::transparent = input.parse()?;
26 Ok(ContainerPyO3Attribute::Transparent(kw))
27 } else if lookahead.peek(Token![crate]) {
28 input.parse().map(ContainerPyO3Attribute::Crate)
29 } else {
30 Err(lookahead.error())
31 }
32 }
33}
34
35#[derive(Default)]
36struct ContainerOptions {
37 transparent: Option<attributes::kw::transparent>,
39 krate: Option<CrateAttribute>,
41}
42
43impl ContainerOptions {
44 fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
45 let mut options = ContainerOptions::default();
46
47 for attr in attrs {
48 if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
49 pyo3_attrs
50 .into_iter()
51 .try_for_each(|opt| options.set_option(opt))?;
52 }
53 }
54 Ok(options)
55 }
56
57 fn set_option(&mut self, option: ContainerPyO3Attribute) -> syn::Result<()> {
58 macro_rules! set_option {
59 ($key:ident) => {
60 {
61 ensure_spanned!(
62 self.$key.is_none(),
63 $key.span() => concat!("`", stringify!($key), "` may only be specified once")
64 );
65 self.$key = Some($key);
66 }
67 };
68 }
69
70 match option {
71 ContainerPyO3Attribute::Transparent(transparent) => set_option!(transparent),
72 ContainerPyO3Attribute::Crate(krate) => set_option!(krate),
73 }
74 Ok(())
75 }
76}
77
78#[derive(Debug, Clone)]
79struct ItemOption {
80 field: Option<syn::LitStr>,
81 span: Span,
82}
83
84impl ItemOption {
85 fn span(&self) -> Span {
86 self.span
87 }
88}
89
90enum FieldAttribute {
91 Item(ItemOption),
92}
93
94impl Parse for FieldAttribute {
95 fn parse(input: ParseStream<'_>) -> Result<Self> {
96 let lookahead = input.lookahead1();
97 if lookahead.peek(attributes::kw::attribute) {
98 let attr: attributes::kw::attribute = input.parse()?;
99 bail_spanned!(attr.span => "`attribute` is not supported by `IntoPyObject`");
100 } else if lookahead.peek(attributes::kw::item) {
101 let attr: attributes::kw::item = input.parse()?;
102 if input.peek(syn::token::Paren) {
103 let content;
104 let _ = parenthesized!(content in input);
105 let key = content.parse()?;
106 if !content.is_empty() {
107 return Err(
108 content.error("expected at most one argument: `item` or `item(key)`")
109 );
110 }
111 Ok(FieldAttribute::Item(ItemOption {
112 field: Some(key),
113 span: attr.span,
114 }))
115 } else {
116 Ok(FieldAttribute::Item(ItemOption {
117 field: None,
118 span: attr.span,
119 }))
120 }
121 } else {
122 Err(lookahead.error())
123 }
124 }
125}
126
127#[derive(Clone, Debug, Default)]
128struct FieldAttributes {
129 item: Option<ItemOption>,
130}
131
132impl FieldAttributes {
133 fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
135 let mut options = FieldAttributes::default();
136
137 for attr in attrs {
138 if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
139 pyo3_attrs
140 .into_iter()
141 .try_for_each(|opt| options.set_option(opt))?;
142 }
143 }
144 Ok(options)
145 }
146
147 fn set_option(&mut self, option: FieldAttribute) -> syn::Result<()> {
148 macro_rules! set_option {
149 ($key:ident) => {
150 {
151 ensure_spanned!(
152 self.$key.is_none(),
153 $key.span() => concat!("`", stringify!($key), "` may only be specified once")
154 );
155 self.$key = Some($key);
156 }
157 };
158 }
159
160 match option {
161 FieldAttribute::Item(item) => set_option!(item),
162 }
163 Ok(())
164 }
165}
166
167enum IntoPyObjectTypes {
168 Transparent(syn::Type),
169 Opaque {
170 target: TokenStream,
171 output: TokenStream,
172 error: TokenStream,
173 },
174}
175
176struct IntoPyObjectImpl {
177 types: IntoPyObjectTypes,
178 body: TokenStream,
179}
180
181struct NamedStructField<'a> {
182 ident: &'a syn::Ident,
183 field: &'a syn::Field,
184 item: Option<ItemOption>,
185}
186
187struct TupleStructField<'a> {
188 field: &'a syn::Field,
189}
190
191enum ContainerType<'a> {
195 Struct(Vec<NamedStructField<'a>>),
199 StructNewtype(&'a syn::Field),
203 Tuple(Vec<TupleStructField<'a>>),
208 TupleNewtype(&'a syn::Field),
212}
213
214struct Container<'a> {
218 path: syn::Path,
219 receiver: Option<Ident>,
220 ty: ContainerType<'a>,
221}
222
223impl<'a> Container<'a> {
225 fn new(
228 receiver: Option<Ident>,
229 fields: &'a Fields,
230 path: syn::Path,
231 options: ContainerOptions,
232 ) -> Result<Self> {
233 let style = match fields {
234 Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
235 let mut tuple_fields = unnamed
236 .unnamed
237 .iter()
238 .map(|field| {
239 let attrs = FieldAttributes::from_attrs(&field.attrs)?;
240 ensure_spanned!(
241 attrs.item.is_none(),
242 attrs.item.unwrap().span() => "`item` is not permitted on tuple struct elements."
243 );
244 Ok(TupleStructField { field })
245 })
246 .collect::<Result<Vec<_>>>()?;
247 if tuple_fields.len() == 1 {
248 let TupleStructField { field } = tuple_fields.pop().unwrap();
251 ContainerType::TupleNewtype(field)
252 } else if options.transparent.is_some() {
253 bail_spanned!(
254 fields.span() => "transparent structs and variants can only have 1 field"
255 );
256 } else {
257 ContainerType::Tuple(tuple_fields)
258 }
259 }
260 Fields::Named(named) if !named.named.is_empty() => {
261 if options.transparent.is_some() {
262 ensure_spanned!(
263 named.named.iter().count() == 1,
264 fields.span() => "transparent structs and variants can only have 1 field"
265 );
266
267 let field = named.named.iter().next().unwrap();
268 let attrs = FieldAttributes::from_attrs(&field.attrs)?;
269 ensure_spanned!(
270 attrs.item.is_none(),
271 attrs.item.unwrap().span() => "`transparent` structs may not have `item` for the inner field"
272 );
273 ContainerType::StructNewtype(field)
274 } else {
275 let struct_fields = named
276 .named
277 .iter()
278 .map(|field| {
279 let ident = field
280 .ident
281 .as_ref()
282 .expect("Named fields should have identifiers");
283
284 let attrs = FieldAttributes::from_attrs(&field.attrs)?;
285
286 Ok(NamedStructField {
287 ident,
288 field,
289 item: attrs.item,
290 })
291 })
292 .collect::<Result<Vec<_>>>()?;
293 ContainerType::Struct(struct_fields)
294 }
295 }
296 _ => bail_spanned!(
297 fields.span() => "cannot derive `IntoPyObject` for empty structs"
298 ),
299 };
300
301 let v = Container {
302 path,
303 receiver,
304 ty: style,
305 };
306 Ok(v)
307 }
308
309 fn match_pattern(&self) -> TokenStream {
310 let path = &self.path;
311 let pattern = match &self.ty {
312 ContainerType::Struct(fields) => fields
313 .iter()
314 .enumerate()
315 .map(|(i, f)| {
316 let ident = f.ident;
317 let new_ident = format_ident!("arg{i}");
318 quote! {#ident: #new_ident,}
319 })
320 .collect::<TokenStream>(),
321 ContainerType::StructNewtype(field) => {
322 let ident = field.ident.as_ref().unwrap();
323 quote!(#ident: arg0)
324 }
325 ContainerType::Tuple(fields) => {
326 let i = (0..fields.len()).map(Index::from);
327 let idents = (0..fields.len()).map(|i| format_ident!("arg{i}"));
328 quote! { #(#i: #idents,)* }
329 }
330 ContainerType::TupleNewtype(_) => quote!(0: arg0),
331 };
332
333 quote! { #path{ #pattern } }
334 }
335
336 fn build(&self, ctx: &Ctx) -> IntoPyObjectImpl {
338 match &self.ty {
339 ContainerType::StructNewtype(field) | ContainerType::TupleNewtype(field) => {
340 self.build_newtype_struct(field, ctx)
341 }
342 ContainerType::Tuple(fields) => self.build_tuple_struct(fields, ctx),
343 ContainerType::Struct(fields) => self.build_struct(fields, ctx),
344 }
345 }
346
347 fn build_newtype_struct(&self, field: &syn::Field, ctx: &Ctx) -> IntoPyObjectImpl {
348 let Ctx { pyo3_path, .. } = ctx;
349 let ty = &field.ty;
350
351 let unpack = self
352 .receiver
353 .as_ref()
354 .map(|i| {
355 let pattern = self.match_pattern();
356 quote! { let #pattern = #i;}
357 })
358 .unwrap_or_default();
359
360 IntoPyObjectImpl {
361 types: IntoPyObjectTypes::Transparent(ty.clone()),
362 body: quote_spanned! { ty.span() =>
363 #unpack
364 #pyo3_path::conversion::IntoPyObject::into_pyobject(arg0, py)
365 },
366 }
367 }
368
369 fn build_struct(&self, fields: &[NamedStructField<'_>], ctx: &Ctx) -> IntoPyObjectImpl {
370 let Ctx { pyo3_path, .. } = ctx;
371
372 let unpack = self
373 .receiver
374 .as_ref()
375 .map(|i| {
376 let pattern = self.match_pattern();
377 quote! { let #pattern = #i;}
378 })
379 .unwrap_or_default();
380
381 let setter = fields
382 .iter()
383 .enumerate()
384 .map(|(i, f)| {
385 let key = f
386 .item
387 .as_ref()
388 .and_then(|item| item.field.as_ref())
389 .map(|item| item.value())
390 .unwrap_or_else(|| f.ident.unraw().to_string());
391 let value = Ident::new(&format!("arg{i}"), f.field.ty.span());
392 quote! {
393 #pyo3_path::types::PyDictMethods::set_item(&dict, #key, #value)?;
394 }
395 })
396 .collect::<TokenStream>();
397
398 IntoPyObjectImpl {
399 types: IntoPyObjectTypes::Opaque {
400 target: quote!(#pyo3_path::types::PyDict),
401 output: quote!(#pyo3_path::Bound<'py, Self::Target>),
402 error: quote!(#pyo3_path::PyErr),
403 },
404 body: quote! {
405 #unpack
406 let dict = #pyo3_path::types::PyDict::new(py);
407 #setter
408 ::std::result::Result::Ok::<_, Self::Error>(dict)
409 },
410 }
411 }
412
413 fn build_tuple_struct(&self, fields: &[TupleStructField<'_>], ctx: &Ctx) -> IntoPyObjectImpl {
414 let Ctx { pyo3_path, .. } = ctx;
415
416 let unpack = self
417 .receiver
418 .as_ref()
419 .map(|i| {
420 let pattern = self.match_pattern();
421 quote! { let #pattern = #i;}
422 })
423 .unwrap_or_default();
424
425 let setter = fields
426 .iter()
427 .enumerate()
428 .map(|(i, f)| {
429 let value = Ident::new(&format!("arg{i}"), f.field.ty.span());
430 quote_spanned! { f.field.ty.span() =>
431 #pyo3_path::conversion::IntoPyObject::into_pyobject(#value, py)
432 .map(#pyo3_path::BoundObject::into_any)
433 .map(#pyo3_path::BoundObject::into_bound)?,
434 }
435 })
436 .collect::<TokenStream>();
437
438 IntoPyObjectImpl {
439 types: IntoPyObjectTypes::Opaque {
440 target: quote!(#pyo3_path::types::PyTuple),
441 output: quote!(#pyo3_path::Bound<'py, Self::Target>),
442 error: quote!(#pyo3_path::PyErr),
443 },
444 body: quote! {
445 #unpack
446 #pyo3_path::types::PyTuple::new(py, [#setter])
447 },
448 }
449 }
450}
451
452struct Enum<'a> {
454 variants: Vec<Container<'a>>,
455}
456
457impl<'a> Enum<'a> {
458 fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
463 ensure_spanned!(
464 !data_enum.variants.is_empty(),
465 ident.span() => "cannot derive `IntoPyObject` for empty enum"
466 );
467 let variants = data_enum
468 .variants
469 .iter()
470 .map(|variant| {
471 let attrs = ContainerOptions::from_attrs(&variant.attrs)?;
472 let var_ident = &variant.ident;
473
474 ensure_spanned!(
475 !variant.fields.is_empty(),
476 variant.ident.span() => "cannot derive `IntoPyObject` for empty variants"
477 );
478
479 Container::new(
480 None,
481 &variant.fields,
482 parse_quote!(#ident::#var_ident),
483 attrs,
484 )
485 })
486 .collect::<Result<Vec<_>>>()?;
487
488 Ok(Enum { variants })
489 }
490
491 fn build(&self, ctx: &Ctx) -> IntoPyObjectImpl {
493 let Ctx { pyo3_path, .. } = ctx;
494
495 let variants = self
496 .variants
497 .iter()
498 .map(|v| {
499 let IntoPyObjectImpl { body, .. } = v.build(ctx);
500 let pattern = v.match_pattern();
501 quote! {
502 #pattern => {
503 {#body}
504 .map(#pyo3_path::BoundObject::into_any)
505 .map(#pyo3_path::BoundObject::into_bound)
506 .map_err(::std::convert::Into::<#pyo3_path::PyErr>::into)
507 }
508 }
509 })
510 .collect::<TokenStream>();
511
512 IntoPyObjectImpl {
513 types: IntoPyObjectTypes::Opaque {
514 target: quote!(#pyo3_path::types::PyAny),
515 output: quote!(#pyo3_path::Bound<'py, <Self as #pyo3_path::conversion::IntoPyObject<'py>>::Target>),
516 error: quote!(#pyo3_path::PyErr),
517 },
518 body: quote! {
519 match self {
520 #variants
521 }
522 },
523 }
524 }
525}
526
527fn verify_and_get_lifetime(generics: &syn::Generics) -> Option<&syn::LifetimeParam> {
529 let mut lifetimes = generics.lifetimes();
530 lifetimes.find(|l| l.lifetime.ident == "py")
531}
532
533pub fn build_derive_into_pyobject<const REF: bool>(tokens: &DeriveInput) -> Result<TokenStream> {
534 let options = ContainerOptions::from_attrs(&tokens.attrs)?;
535 let ctx = &Ctx::new(&options.krate, None);
536 let Ctx { pyo3_path, .. } = &ctx;
537
538 let (_, ty_generics, _) = tokens.generics.split_for_impl();
539 let mut trait_generics = tokens.generics.clone();
540 if REF {
541 trait_generics.params.push(parse_quote!('_a));
542 }
543 let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics) {
544 lt.clone()
545 } else {
546 trait_generics.params.push(parse_quote!('py));
547 parse_quote!('py)
548 };
549 let (impl_generics, _, where_clause) = trait_generics.split_for_impl();
550
551 let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
552 for param in trait_generics.type_params() {
553 let gen_ident = ¶m.ident;
554 where_clause.predicates.push(if REF {
555 parse_quote!(&'_a #gen_ident: #pyo3_path::conversion::IntoPyObject<'py>)
556 } else {
557 parse_quote!(#gen_ident: #pyo3_path::conversion::IntoPyObject<'py>)
558 })
559 }
560
561 let IntoPyObjectImpl { types, body } = match &tokens.data {
562 syn::Data::Enum(en) => {
563 if options.transparent.is_some() {
564 bail_spanned!(tokens.span() => "`transparent` is not supported at top level for enums");
565 }
566 let en = Enum::new(en, &tokens.ident)?;
567 en.build(ctx)
568 }
569 syn::Data::Struct(st) => {
570 let ident = &tokens.ident;
571 let st = Container::new(
572 Some(Ident::new("self", Span::call_site())),
573 &st.fields,
574 parse_quote!(#ident),
575 options,
576 )?;
577 st.build(ctx)
578 }
579 syn::Data::Union(_) => bail_spanned!(
580 tokens.span() => "#[derive(`IntoPyObject`)] is not supported for unions"
581 ),
582 };
583
584 let (target, output, error) = match types {
585 IntoPyObjectTypes::Transparent(ty) => {
586 if REF {
587 (
588 quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Target },
589 quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Output },
590 quote! { <&'_a #ty as #pyo3_path::IntoPyObject<'py>>::Error },
591 )
592 } else {
593 (
594 quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Target },
595 quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Output },
596 quote! { <#ty as #pyo3_path::IntoPyObject<'py>>::Error },
597 )
598 }
599 }
600 IntoPyObjectTypes::Opaque {
601 target,
602 output,
603 error,
604 } => (target, output, error),
605 };
606
607 let ident = &tokens.ident;
608 let ident = if REF {
609 quote! { &'_a #ident}
610 } else {
611 quote! { #ident }
612 };
613 Ok(quote!(
614 #[automatically_derived]
615 impl #impl_generics #pyo3_path::conversion::IntoPyObject<#lt_param> for #ident #ty_generics #where_clause {
616 type Target = #target;
617 type Output = #output;
618 type Error = #error;
619
620 fn into_pyobject(self, py: #pyo3_path::Python<#lt_param>) -> ::std::result::Result<
621 <Self as #pyo3_path::conversion::IntoPyObject>::Output,
622 <Self as #pyo3_path::conversion::IntoPyObject>::Error,
623 > {
624 #body
625 }
626 }
627 ))
628}