pyo3_macros_backend/
module.rs

//! Code generation for the function that initializes a Python module and adds classes and functions.
2
3use crate::{
4    attributes::{
5        self, kw, take_attributes, take_pyo3_options, CrateAttribute, GILUsedAttribute,
6        ModuleAttribute, NameAttribute, SubmoduleAttribute,
7    },
8    get_doc,
9    pyclass::PyClassPyO3Option,
10    pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
11    utils::{has_attribute, has_attribute_with_namespace, Ctx, IdentOrStr, LitCStr},
12};
13use proc_macro2::{Span, TokenStream};
14use quote::quote;
15use std::ffi::CString;
16use syn::{
17    ext::IdentExt,
18    parse::{Parse, ParseStream},
19    parse_quote, parse_quote_spanned,
20    punctuated::Punctuated,
21    spanned::Spanned,
22    token::Comma,
23    Item, Meta, Path, Result,
24};
25
26#[derive(Default)]
27pub struct PyModuleOptions {
28    krate: Option<CrateAttribute>,
29    name: Option<NameAttribute>,
30    module: Option<ModuleAttribute>,
31    submodule: Option<kw::submodule>,
32    gil_used: Option<GILUsedAttribute>,
33}
34
35impl Parse for PyModuleOptions {
36    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
37        let mut options: PyModuleOptions = Default::default();
38
39        options.add_attributes(
40            Punctuated::<PyModulePyO3Option, syn::Token![,]>::parse_terminated(input)?,
41        )?;
42
43        Ok(options)
44    }
45}
46
47impl PyModuleOptions {
48    fn take_pyo3_options(&mut self, attrs: &mut Vec<syn::Attribute>) -> Result<()> {
49        self.add_attributes(take_pyo3_options(attrs)?)
50    }
51
52    fn add_attributes(
53        &mut self,
54        attrs: impl IntoIterator<Item = PyModulePyO3Option>,
55    ) -> Result<()> {
56        macro_rules! set_option {
57            ($key:ident $(, $extra:literal)?) => {
58                {
59                    ensure_spanned!(
60                        self.$key.is_none(),
61                        $key.span() => concat!("`", stringify!($key), "` may only be specified once" $(, $extra)?)
62                    );
63                    self.$key = Some($key);
64                }
65            };
66        }
67        for attr in attrs {
68            match attr {
69                PyModulePyO3Option::Crate(krate) => set_option!(krate),
70                PyModulePyO3Option::Name(name) => set_option!(name),
71                PyModulePyO3Option::Module(module) => set_option!(module),
72                PyModulePyO3Option::Submodule(submodule) => set_option!(
73                    submodule,
74                    " (it is implicitly always specified for nested modules)"
75                ),
76                PyModulePyO3Option::GILUsed(gil_used) => {
77                    set_option!(gil_used)
78                }
79            }
80        }
81        Ok(())
82    }
83}
84
85pub fn pymodule_module_impl(
86    module: &mut syn::ItemMod,
87    mut options: PyModuleOptions,
88) -> Result<TokenStream> {
89    let syn::ItemMod {
90        attrs,
91        vis,
92        unsafety: _,
93        ident,
94        mod_token,
95        content,
96        semi: _,
97    } = module;
98    let items = if let Some((_, items)) = content {
99        items
100    } else {
101        bail_spanned!(mod_token.span() => "`#[pymodule]` can only be used on inline modules")
102    };
103    options.take_pyo3_options(attrs)?;
104    let ctx = &Ctx::new(&options.krate, None);
105    let Ctx { pyo3_path, .. } = ctx;
106    let doc = get_doc(attrs, None, ctx);
107    let name = options
108        .name
109        .map_or_else(|| ident.unraw(), |name| name.value.0);
110    let full_name = if let Some(module) = &options.module {
111        format!("{}.{}", module.value.value(), name)
112    } else {
113        name.to_string()
114    };
115
116    let mut module_items = Vec::new();
117    let mut module_items_cfg_attrs = Vec::new();
118
119    fn extract_use_items(
120        source: &syn::UseTree,
121        cfg_attrs: &[syn::Attribute],
122        target_items: &mut Vec<syn::Ident>,
123        target_cfg_attrs: &mut Vec<Vec<syn::Attribute>>,
124    ) -> Result<()> {
125        match source {
126            syn::UseTree::Name(name) => {
127                target_items.push(name.ident.clone());
128                target_cfg_attrs.push(cfg_attrs.to_vec());
129            }
130            syn::UseTree::Path(path) => {
131                extract_use_items(&path.tree, cfg_attrs, target_items, target_cfg_attrs)?
132            }
133            syn::UseTree::Group(group) => {
134                for tree in &group.items {
135                    extract_use_items(tree, cfg_attrs, target_items, target_cfg_attrs)?
136                }
137            }
138            syn::UseTree::Glob(glob) => {
139                bail_spanned!(glob.span() => "#[pymodule] cannot import glob statements")
140            }
141            syn::UseTree::Rename(rename) => {
142                target_items.push(rename.rename.clone());
143                target_cfg_attrs.push(cfg_attrs.to_vec());
144            }
145        }
146        Ok(())
147    }
148
149    let mut pymodule_init = None;
150
151    for item in &mut *items {
152        match item {
153            Item::Use(item_use) => {
154                let is_pymodule_export =
155                    find_and_remove_attribute(&mut item_use.attrs, "pymodule_export");
156                if is_pymodule_export {
157                    let cfg_attrs = get_cfg_attributes(&item_use.attrs);
158                    extract_use_items(
159                        &item_use.tree,
160                        &cfg_attrs,
161                        &mut module_items,
162                        &mut module_items_cfg_attrs,
163                    )?;
164                }
165            }
166            Item::Fn(item_fn) => {
167                ensure_spanned!(
168                    !has_attribute(&item_fn.attrs, "pymodule_export"),
169                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
170                );
171                let is_pymodule_init =
172                    find_and_remove_attribute(&mut item_fn.attrs, "pymodule_init");
173                let ident = &item_fn.sig.ident;
174                if is_pymodule_init {
175                    ensure_spanned!(
176                        !has_attribute(&item_fn.attrs, "pyfunction"),
177                        item_fn.span() => "`#[pyfunction]` cannot be used alongside `#[pymodule_init]`"
178                    );
179                    ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified");
180                    pymodule_init = Some(quote! { #ident(module)?; });
181                } else if has_attribute(&item_fn.attrs, "pyfunction")
182                    || has_attribute_with_namespace(
183                        &item_fn.attrs,
184                        Some(pyo3_path),
185                        &["pyfunction"],
186                    )
187                    || has_attribute_with_namespace(
188                        &item_fn.attrs,
189                        Some(pyo3_path),
190                        &["prelude", "pyfunction"],
191                    )
192                {
193                    module_items.push(ident.clone());
194                    module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs));
195                }
196            }
197            Item::Struct(item_struct) => {
198                ensure_spanned!(
199                    !has_attribute(&item_struct.attrs, "pymodule_export"),
200                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
201                );
202                if has_attribute(&item_struct.attrs, "pyclass")
203                    || has_attribute_with_namespace(
204                        &item_struct.attrs,
205                        Some(pyo3_path),
206                        &["pyclass"],
207                    )
208                    || has_attribute_with_namespace(
209                        &item_struct.attrs,
210                        Some(pyo3_path),
211                        &["prelude", "pyclass"],
212                    )
213                {
214                    module_items.push(item_struct.ident.clone());
215                    module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs));
216                    if !has_pyo3_module_declared::<PyClassPyO3Option>(
217                        &item_struct.attrs,
218                        "pyclass",
219                        |option| matches!(option, PyClassPyO3Option::Module(_)),
220                    )? {
221                        set_module_attribute(&mut item_struct.attrs, &full_name);
222                    }
223                }
224            }
225            Item::Enum(item_enum) => {
226                ensure_spanned!(
227                    !has_attribute(&item_enum.attrs, "pymodule_export"),
228                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
229                );
230                if has_attribute(&item_enum.attrs, "pyclass")
231                    || has_attribute_with_namespace(&item_enum.attrs, Some(pyo3_path), &["pyclass"])
232                    || has_attribute_with_namespace(
233                        &item_enum.attrs,
234                        Some(pyo3_path),
235                        &["prelude", "pyclass"],
236                    )
237                {
238                    module_items.push(item_enum.ident.clone());
239                    module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs));
240                    if !has_pyo3_module_declared::<PyClassPyO3Option>(
241                        &item_enum.attrs,
242                        "pyclass",
243                        |option| matches!(option, PyClassPyO3Option::Module(_)),
244                    )? {
245                        set_module_attribute(&mut item_enum.attrs, &full_name);
246                    }
247                }
248            }
249            Item::Mod(item_mod) => {
250                ensure_spanned!(
251                    !has_attribute(&item_mod.attrs, "pymodule_export"),
252                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
253                );
254                if has_attribute(&item_mod.attrs, "pymodule")
255                    || has_attribute_with_namespace(&item_mod.attrs, Some(pyo3_path), &["pymodule"])
256                    || has_attribute_with_namespace(
257                        &item_mod.attrs,
258                        Some(pyo3_path),
259                        &["prelude", "pymodule"],
260                    )
261                {
262                    module_items.push(item_mod.ident.clone());
263                    module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs));
264                    if !has_pyo3_module_declared::<PyModulePyO3Option>(
265                        &item_mod.attrs,
266                        "pymodule",
267                        |option| matches!(option, PyModulePyO3Option::Module(_)),
268                    )? {
269                        set_module_attribute(&mut item_mod.attrs, &full_name);
270                    }
271                    item_mod
272                        .attrs
273                        .push(parse_quote_spanned!(item_mod.mod_token.span()=> #[pyo3(submodule)]));
274                }
275            }
276            Item::ForeignMod(item) => {
277                ensure_spanned!(
278                    !has_attribute(&item.attrs, "pymodule_export"),
279                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
280                );
281            }
282            Item::Trait(item) => {
283                ensure_spanned!(
284                    !has_attribute(&item.attrs, "pymodule_export"),
285                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
286                );
287            }
288            Item::Const(item) => {
289                ensure_spanned!(
290                    !has_attribute(&item.attrs, "pymodule_export"),
291                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
292                );
293            }
294            Item::Static(item) => {
295                ensure_spanned!(
296                    !has_attribute(&item.attrs, "pymodule_export"),
297                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
298                );
299            }
300            Item::Macro(item) => {
301                ensure_spanned!(
302                    !has_attribute(&item.attrs, "pymodule_export"),
303                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
304                );
305            }
306            Item::ExternCrate(item) => {
307                ensure_spanned!(
308                    !has_attribute(&item.attrs, "pymodule_export"),
309                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
310                );
311            }
312            Item::Impl(item) => {
313                ensure_spanned!(
314                    !has_attribute(&item.attrs, "pymodule_export"),
315                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
316                );
317            }
318            Item::TraitAlias(item) => {
319                ensure_spanned!(
320                    !has_attribute(&item.attrs, "pymodule_export"),
321                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
322                );
323            }
324            Item::Type(item) => {
325                ensure_spanned!(
326                    !has_attribute(&item.attrs, "pymodule_export"),
327                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
328                );
329            }
330            Item::Union(item) => {
331                ensure_spanned!(
332                    !has_attribute(&item.attrs, "pymodule_export"),
333                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
334                );
335            }
336            _ => (),
337        }
338    }
339
340    let module_def = quote! {{
341        use #pyo3_path::impl_::pymodule as impl_;
342        const INITIALIZER: impl_::ModuleInitializer = impl_::ModuleInitializer(__pyo3_pymodule);
343        unsafe {
344           impl_::ModuleDef::new(
345                __PYO3_NAME,
346                #doc,
347                INITIALIZER
348            )
349        }
350    }};
351    let initialization = module_initialization(
352        &name,
353        ctx,
354        module_def,
355        options.submodule.is_some(),
356        options.gil_used.map_or(true, |op| op.value.value),
357    );
358
359    Ok(quote!(
360        #(#attrs)*
361        #vis #mod_token #ident {
362            #(#items)*
363
364            #initialization
365
366            fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
367                use #pyo3_path::impl_::pymodule::PyAddToModule;
368                #(
369                    #(#module_items_cfg_attrs)*
370                    #module_items::_PYO3_DEF.add_to_module(module)?;
371                )*
372                #pymodule_init
373                ::std::result::Result::Ok(())
374            }
375        }
376    ))
377}
378
379/// Generates the function that is called by the python interpreter to initialize the native
380/// module
381pub fn pymodule_function_impl(
382    function: &mut syn::ItemFn,
383    mut options: PyModuleOptions,
384) -> Result<TokenStream> {
385    options.take_pyo3_options(&mut function.attrs)?;
386    process_functions_in_module(&options, function)?;
387    let ctx = &Ctx::new(&options.krate, None);
388    let Ctx { pyo3_path, .. } = ctx;
389    let ident = &function.sig.ident;
390    let name = options
391        .name
392        .map_or_else(|| ident.unraw(), |name| name.value.0);
393    let vis = &function.vis;
394    let doc = get_doc(&function.attrs, None, ctx);
395
396    let initialization = module_initialization(
397        &name,
398        ctx,
399        quote! { MakeDef::make_def() },
400        false,
401        options.gil_used.map_or(true, |op| op.value.value),
402    );
403
404    // Module function called with optional Python<'_> marker as first arg, followed by the module.
405    let mut module_args = Vec::new();
406    if function.sig.inputs.len() == 2 {
407        module_args.push(quote!(module.py()));
408    }
409    module_args
410        .push(quote!(::std::convert::Into::into(#pyo3_path::impl_::pymethods::BoundRef(module))));
411
412    Ok(quote! {
413        #[doc(hidden)]
414        #vis mod #ident {
415            #initialization
416        }
417
418        // Generate the definition inside an anonymous function in the same scope as the original function -
419        // this avoids complications around the fact that the generated module has a different scope
420        // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pymodule] is
421        // inside a function body)
422        #[allow(unknown_lints, non_local_definitions)]
423        impl #ident::MakeDef {
424            const fn make_def() -> #pyo3_path::impl_::pymodule::ModuleDef {
425                fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
426                    #ident(#(#module_args),*)
427                }
428
429                const INITIALIZER: #pyo3_path::impl_::pymodule::ModuleInitializer = #pyo3_path::impl_::pymodule::ModuleInitializer(__pyo3_pymodule);
430                unsafe {
431                    #pyo3_path::impl_::pymodule::ModuleDef::new(
432                        #ident::__PYO3_NAME,
433                        #doc,
434                        INITIALIZER
435                    )
436                }
437            }
438        }
439    })
440}
441
442fn module_initialization(
443    name: &syn::Ident,
444    ctx: &Ctx,
445    module_def: TokenStream,
446    is_submodule: bool,
447    gil_used: bool,
448) -> TokenStream {
449    let Ctx { pyo3_path, .. } = ctx;
450    let pyinit_symbol = format!("PyInit_{}", name);
451    let name = name.to_string();
452    let pyo3_name = LitCStr::new(CString::new(name).unwrap(), Span::call_site(), ctx);
453
454    let mut result = quote! {
455        #[doc(hidden)]
456        pub const __PYO3_NAME: &'static ::std::ffi::CStr = #pyo3_name;
457
458        pub(super) struct MakeDef;
459        #[doc(hidden)]
460        pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = #module_def;
461        #[doc(hidden)]
462        // so wrapped submodules can see what gil_used is
463        pub static __PYO3_GIL_USED: bool = #gil_used;
464    };
465    if !is_submodule {
466        result.extend(quote! {
467            /// This autogenerated function is called by the python interpreter when importing
468            /// the module.
469            #[doc(hidden)]
470            #[export_name = #pyinit_symbol]
471            pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
472                unsafe { #pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py, #gil_used)) }
473            }
474        });
475    }
476    result
477}
478
479/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
480fn process_functions_in_module(options: &PyModuleOptions, func: &mut syn::ItemFn) -> Result<()> {
481    let ctx = &Ctx::new(&options.krate, None);
482    let Ctx { pyo3_path, .. } = ctx;
483    let mut stmts: Vec<syn::Stmt> = Vec::new();
484
485    for mut stmt in func.block.stmts.drain(..) {
486        if let syn::Stmt::Item(Item::Fn(func)) = &mut stmt {
487            if let Some(pyfn_args) = get_pyfn_attr(&mut func.attrs)? {
488                let module_name = pyfn_args.modname;
489                let wrapped_function = impl_wrap_pyfunction(func, pyfn_args.options)?;
490                let name = &func.sig.ident;
491                let statements: Vec<syn::Stmt> = syn::parse_quote! {
492                    #wrapped_function
493                    {
494                        use #pyo3_path::types::PyModuleMethods;
495                        #module_name.add_function(#pyo3_path::wrap_pyfunction!(#name, #module_name.as_borrowed())?)?;
496                    }
497                };
498                stmts.extend(statements);
499            }
500        };
501        stmts.push(stmt);
502    }
503
504    func.block.stmts = stmts;
505    Ok(())
506}
507
508pub struct PyFnArgs {
509    modname: Path,
510    options: PyFunctionOptions,
511}
512
513impl Parse for PyFnArgs {
514    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
515        let modname = input.parse().map_err(
516            |e| err_spanned!(e.span() => "expected module as first argument to #[pyfn()]"),
517        )?;
518
519        if input.is_empty() {
520            return Ok(Self {
521                modname,
522                options: Default::default(),
523            });
524        }
525
526        let _: Comma = input.parse()?;
527
528        Ok(Self {
529            modname,
530            options: input.parse()?,
531        })
532    }
533}
534
535/// Extracts the data from the #[pyfn(...)] attribute of a function
536fn get_pyfn_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<PyFnArgs>> {
537    let mut pyfn_args: Option<PyFnArgs> = None;
538
539    take_attributes(attrs, |attr| {
540        if attr.path().is_ident("pyfn") {
541            ensure_spanned!(
542                pyfn_args.is_none(),
543                attr.span() => "`#[pyfn] may only be specified once"
544            );
545            pyfn_args = Some(attr.parse_args()?);
546            Ok(true)
547        } else {
548            Ok(false)
549        }
550    })?;
551
552    if let Some(pyfn_args) = &mut pyfn_args {
553        pyfn_args
554            .options
555            .add_attributes(take_pyo3_options(attrs)?)?;
556    }
557
558    Ok(pyfn_args)
559}
560
561fn get_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
562    attrs
563        .iter()
564        .filter(|attr| attr.path().is_ident("cfg"))
565        .cloned()
566        .collect()
567}
568
569fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bool {
570    let mut found = false;
571    attrs.retain(|attr| {
572        if attr.path().is_ident(ident) {
573            found = true;
574            false
575        } else {
576            true
577        }
578    });
579    found
580}
581
582impl PartialEq<syn::Ident> for IdentOrStr<'_> {
583    fn eq(&self, other: &syn::Ident) -> bool {
584        match self {
585            IdentOrStr::Str(s) => other == s,
586            IdentOrStr::Ident(i) => other == i,
587        }
588    }
589}
590
591fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) {
592    attrs.push(parse_quote!(#[pyo3(module = #module_name)]));
593}
594
595fn has_pyo3_module_declared<T: Parse>(
596    attrs: &[syn::Attribute],
597    root_attribute_name: &str,
598    is_module_option: impl Fn(&T) -> bool + Copy,
599) -> Result<bool> {
600    for attr in attrs {
601        if (attr.path().is_ident("pyo3") || attr.path().is_ident(root_attribute_name))
602            && matches!(attr.meta, Meta::List(_))
603        {
604            for option in &attr.parse_args_with(Punctuated::<T, Comma>::parse_terminated)? {
605                if is_module_option(option) {
606                    return Ok(true);
607                }
608            }
609        }
610    }
611    Ok(false)
612}
613
614enum PyModulePyO3Option {
615    Submodule(SubmoduleAttribute),
616    Crate(CrateAttribute),
617    Name(NameAttribute),
618    Module(ModuleAttribute),
619    GILUsed(GILUsedAttribute),
620}
621
622impl Parse for PyModulePyO3Option {
623    fn parse(input: ParseStream<'_>) -> Result<Self> {
624        let lookahead = input.lookahead1();
625        if lookahead.peek(attributes::kw::name) {
626            input.parse().map(PyModulePyO3Option::Name)
627        } else if lookahead.peek(syn::Token![crate]) {
628            input.parse().map(PyModulePyO3Option::Crate)
629        } else if lookahead.peek(attributes::kw::module) {
630            input.parse().map(PyModulePyO3Option::Module)
631        } else if lookahead.peek(attributes::kw::submodule) {
632            input.parse().map(PyModulePyO3Option::Submodule)
633        } else if lookahead.peek(attributes::kw::gil_used) {
634            input.parse().map(PyModulePyO3Option::GILUsed)
635        } else {
636            Err(lookahead.error())
637        }
638    }
639}