1use crate::{
4 attributes::{
5 self, kw, take_attributes, take_pyo3_options, CrateAttribute, GILUsedAttribute,
6 ModuleAttribute, NameAttribute, SubmoduleAttribute,
7 },
8 get_doc,
9 pyclass::PyClassPyO3Option,
10 pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
11 utils::{has_attribute, has_attribute_with_namespace, Ctx, IdentOrStr, LitCStr},
12};
13use proc_macro2::{Span, TokenStream};
14use quote::quote;
15use std::ffi::CString;
16use syn::{
17 ext::IdentExt,
18 parse::{Parse, ParseStream},
19 parse_quote, parse_quote_spanned,
20 punctuated::Punctuated,
21 spanned::Spanned,
22 token::Comma,
23 Item, Meta, Path, Result,
24};
25
/// Options that configure a `#[pymodule]` expansion.
///
/// Every field starts out as `None` (via `Default`) and is filled in at most once by
/// `add_attributes`; specifying the same option twice is reported as an error there.
#[derive(Default)]
pub struct PyModuleOptions {
    /// `crate` option; passed to `Ctx::new` by the expansion functions.
    krate: Option<CrateAttribute>,
    /// `name` option; when absent, the un-rawed Rust identifier is used as the module name.
    name: Option<NameAttribute>,
    /// `module` option; when present its value is joined with the name as `"<module>.<name>"`
    /// to form the full name given to nested items.
    module: Option<ModuleAttribute>,
    /// `submodule` flag; when set, no `PyInit_*` entry point is generated.
    submodule: Option<kw::submodule>,
    /// `gil_used` option; treated as `true` when not specified.
    gil_used: Option<GILUsedAttribute>,
}
34
35impl Parse for PyModuleOptions {
36 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
37 let mut options: PyModuleOptions = Default::default();
38
39 options.add_attributes(
40 Punctuated::<PyModulePyO3Option, syn::Token![,]>::parse_terminated(input)?,
41 )?;
42
43 Ok(options)
44 }
45}
46
47impl PyModuleOptions {
48 fn take_pyo3_options(&mut self, attrs: &mut Vec<syn::Attribute>) -> Result<()> {
49 self.add_attributes(take_pyo3_options(attrs)?)
50 }
51
52 fn add_attributes(
53 &mut self,
54 attrs: impl IntoIterator<Item = PyModulePyO3Option>,
55 ) -> Result<()> {
56 macro_rules! set_option {
57 ($key:ident $(, $extra:literal)?) => {
58 {
59 ensure_spanned!(
60 self.$key.is_none(),
61 $key.span() => concat!("`", stringify!($key), "` may only be specified once" $(, $extra)?)
62 );
63 self.$key = Some($key);
64 }
65 };
66 }
67 for attr in attrs {
68 match attr {
69 PyModulePyO3Option::Crate(krate) => set_option!(krate),
70 PyModulePyO3Option::Name(name) => set_option!(name),
71 PyModulePyO3Option::Module(module) => set_option!(module),
72 PyModulePyO3Option::Submodule(submodule) => set_option!(
73 submodule,
74 " (it is implicitly always specified for nested modules)"
75 ),
76 PyModulePyO3Option::GILUsed(gil_used) => {
77 set_option!(gil_used)
78 }
79 }
80 }
81 Ok(())
82 }
83}
84
85pub fn pymodule_module_impl(
86 module: &mut syn::ItemMod,
87 mut options: PyModuleOptions,
88) -> Result<TokenStream> {
89 let syn::ItemMod {
90 attrs,
91 vis,
92 unsafety: _,
93 ident,
94 mod_token,
95 content,
96 semi: _,
97 } = module;
98 let items = if let Some((_, items)) = content {
99 items
100 } else {
101 bail_spanned!(mod_token.span() => "`#[pymodule]` can only be used on inline modules")
102 };
103 options.take_pyo3_options(attrs)?;
104 let ctx = &Ctx::new(&options.krate, None);
105 let Ctx { pyo3_path, .. } = ctx;
106 let doc = get_doc(attrs, None, ctx);
107 let name = options
108 .name
109 .map_or_else(|| ident.unraw(), |name| name.value.0);
110 let full_name = if let Some(module) = &options.module {
111 format!("{}.{}", module.value.value(), name)
112 } else {
113 name.to_string()
114 };
115
116 let mut module_items = Vec::new();
117 let mut module_items_cfg_attrs = Vec::new();
118
119 fn extract_use_items(
120 source: &syn::UseTree,
121 cfg_attrs: &[syn::Attribute],
122 target_items: &mut Vec<syn::Ident>,
123 target_cfg_attrs: &mut Vec<Vec<syn::Attribute>>,
124 ) -> Result<()> {
125 match source {
126 syn::UseTree::Name(name) => {
127 target_items.push(name.ident.clone());
128 target_cfg_attrs.push(cfg_attrs.to_vec());
129 }
130 syn::UseTree::Path(path) => {
131 extract_use_items(&path.tree, cfg_attrs, target_items, target_cfg_attrs)?
132 }
133 syn::UseTree::Group(group) => {
134 for tree in &group.items {
135 extract_use_items(tree, cfg_attrs, target_items, target_cfg_attrs)?
136 }
137 }
138 syn::UseTree::Glob(glob) => {
139 bail_spanned!(glob.span() => "#[pymodule] cannot import glob statements")
140 }
141 syn::UseTree::Rename(rename) => {
142 target_items.push(rename.rename.clone());
143 target_cfg_attrs.push(cfg_attrs.to_vec());
144 }
145 }
146 Ok(())
147 }
148
149 let mut pymodule_init = None;
150
151 for item in &mut *items {
152 match item {
153 Item::Use(item_use) => {
154 let is_pymodule_export =
155 find_and_remove_attribute(&mut item_use.attrs, "pymodule_export");
156 if is_pymodule_export {
157 let cfg_attrs = get_cfg_attributes(&item_use.attrs);
158 extract_use_items(
159 &item_use.tree,
160 &cfg_attrs,
161 &mut module_items,
162 &mut module_items_cfg_attrs,
163 )?;
164 }
165 }
166 Item::Fn(item_fn) => {
167 ensure_spanned!(
168 !has_attribute(&item_fn.attrs, "pymodule_export"),
169 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
170 );
171 let is_pymodule_init =
172 find_and_remove_attribute(&mut item_fn.attrs, "pymodule_init");
173 let ident = &item_fn.sig.ident;
174 if is_pymodule_init {
175 ensure_spanned!(
176 !has_attribute(&item_fn.attrs, "pyfunction"),
177 item_fn.span() => "`#[pyfunction]` cannot be used alongside `#[pymodule_init]`"
178 );
179 ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified");
180 pymodule_init = Some(quote! { #ident(module)?; });
181 } else if has_attribute(&item_fn.attrs, "pyfunction")
182 || has_attribute_with_namespace(
183 &item_fn.attrs,
184 Some(pyo3_path),
185 &["pyfunction"],
186 )
187 || has_attribute_with_namespace(
188 &item_fn.attrs,
189 Some(pyo3_path),
190 &["prelude", "pyfunction"],
191 )
192 {
193 module_items.push(ident.clone());
194 module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs));
195 }
196 }
197 Item::Struct(item_struct) => {
198 ensure_spanned!(
199 !has_attribute(&item_struct.attrs, "pymodule_export"),
200 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
201 );
202 if has_attribute(&item_struct.attrs, "pyclass")
203 || has_attribute_with_namespace(
204 &item_struct.attrs,
205 Some(pyo3_path),
206 &["pyclass"],
207 )
208 || has_attribute_with_namespace(
209 &item_struct.attrs,
210 Some(pyo3_path),
211 &["prelude", "pyclass"],
212 )
213 {
214 module_items.push(item_struct.ident.clone());
215 module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs));
216 if !has_pyo3_module_declared::<PyClassPyO3Option>(
217 &item_struct.attrs,
218 "pyclass",
219 |option| matches!(option, PyClassPyO3Option::Module(_)),
220 )? {
221 set_module_attribute(&mut item_struct.attrs, &full_name);
222 }
223 }
224 }
225 Item::Enum(item_enum) => {
226 ensure_spanned!(
227 !has_attribute(&item_enum.attrs, "pymodule_export"),
228 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
229 );
230 if has_attribute(&item_enum.attrs, "pyclass")
231 || has_attribute_with_namespace(&item_enum.attrs, Some(pyo3_path), &["pyclass"])
232 || has_attribute_with_namespace(
233 &item_enum.attrs,
234 Some(pyo3_path),
235 &["prelude", "pyclass"],
236 )
237 {
238 module_items.push(item_enum.ident.clone());
239 module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs));
240 if !has_pyo3_module_declared::<PyClassPyO3Option>(
241 &item_enum.attrs,
242 "pyclass",
243 |option| matches!(option, PyClassPyO3Option::Module(_)),
244 )? {
245 set_module_attribute(&mut item_enum.attrs, &full_name);
246 }
247 }
248 }
249 Item::Mod(item_mod) => {
250 ensure_spanned!(
251 !has_attribute(&item_mod.attrs, "pymodule_export"),
252 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
253 );
254 if has_attribute(&item_mod.attrs, "pymodule")
255 || has_attribute_with_namespace(&item_mod.attrs, Some(pyo3_path), &["pymodule"])
256 || has_attribute_with_namespace(
257 &item_mod.attrs,
258 Some(pyo3_path),
259 &["prelude", "pymodule"],
260 )
261 {
262 module_items.push(item_mod.ident.clone());
263 module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs));
264 if !has_pyo3_module_declared::<PyModulePyO3Option>(
265 &item_mod.attrs,
266 "pymodule",
267 |option| matches!(option, PyModulePyO3Option::Module(_)),
268 )? {
269 set_module_attribute(&mut item_mod.attrs, &full_name);
270 }
271 item_mod
272 .attrs
273 .push(parse_quote_spanned!(item_mod.mod_token.span()=> #[pyo3(submodule)]));
274 }
275 }
276 Item::ForeignMod(item) => {
277 ensure_spanned!(
278 !has_attribute(&item.attrs, "pymodule_export"),
279 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
280 );
281 }
282 Item::Trait(item) => {
283 ensure_spanned!(
284 !has_attribute(&item.attrs, "pymodule_export"),
285 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
286 );
287 }
288 Item::Const(item) => {
289 ensure_spanned!(
290 !has_attribute(&item.attrs, "pymodule_export"),
291 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
292 );
293 }
294 Item::Static(item) => {
295 ensure_spanned!(
296 !has_attribute(&item.attrs, "pymodule_export"),
297 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
298 );
299 }
300 Item::Macro(item) => {
301 ensure_spanned!(
302 !has_attribute(&item.attrs, "pymodule_export"),
303 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
304 );
305 }
306 Item::ExternCrate(item) => {
307 ensure_spanned!(
308 !has_attribute(&item.attrs, "pymodule_export"),
309 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
310 );
311 }
312 Item::Impl(item) => {
313 ensure_spanned!(
314 !has_attribute(&item.attrs, "pymodule_export"),
315 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
316 );
317 }
318 Item::TraitAlias(item) => {
319 ensure_spanned!(
320 !has_attribute(&item.attrs, "pymodule_export"),
321 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
322 );
323 }
324 Item::Type(item) => {
325 ensure_spanned!(
326 !has_attribute(&item.attrs, "pymodule_export"),
327 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
328 );
329 }
330 Item::Union(item) => {
331 ensure_spanned!(
332 !has_attribute(&item.attrs, "pymodule_export"),
333 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
334 );
335 }
336 _ => (),
337 }
338 }
339
340 let module_def = quote! {{
341 use #pyo3_path::impl_::pymodule as impl_;
342 const INITIALIZER: impl_::ModuleInitializer = impl_::ModuleInitializer(__pyo3_pymodule);
343 unsafe {
344 impl_::ModuleDef::new(
345 __PYO3_NAME,
346 #doc,
347 INITIALIZER
348 )
349 }
350 }};
351 let initialization = module_initialization(
352 &name,
353 ctx,
354 module_def,
355 options.submodule.is_some(),
356 options.gil_used.map_or(true, |op| op.value.value),
357 );
358
359 Ok(quote!(
360 #(#attrs)*
361 #vis #mod_token #ident {
362 #(#items)*
363
364 #initialization
365
366 fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
367 use #pyo3_path::impl_::pymodule::PyAddToModule;
368 #(
369 #(#module_items_cfg_attrs)*
370 #module_items::_PYO3_DEF.add_to_module(module)?;
371 )*
372 #pymodule_init
373 ::std::result::Result::Ok(())
374 }
375 }
376 ))
377}
378
379pub fn pymodule_function_impl(
382 function: &mut syn::ItemFn,
383 mut options: PyModuleOptions,
384) -> Result<TokenStream> {
385 options.take_pyo3_options(&mut function.attrs)?;
386 process_functions_in_module(&options, function)?;
387 let ctx = &Ctx::new(&options.krate, None);
388 let Ctx { pyo3_path, .. } = ctx;
389 let ident = &function.sig.ident;
390 let name = options
391 .name
392 .map_or_else(|| ident.unraw(), |name| name.value.0);
393 let vis = &function.vis;
394 let doc = get_doc(&function.attrs, None, ctx);
395
396 let initialization = module_initialization(
397 &name,
398 ctx,
399 quote! { MakeDef::make_def() },
400 false,
401 options.gil_used.map_or(true, |op| op.value.value),
402 );
403
404 let mut module_args = Vec::new();
406 if function.sig.inputs.len() == 2 {
407 module_args.push(quote!(module.py()));
408 }
409 module_args
410 .push(quote!(::std::convert::Into::into(#pyo3_path::impl_::pymethods::BoundRef(module))));
411
412 Ok(quote! {
413 #[doc(hidden)]
414 #vis mod #ident {
415 #initialization
416 }
417
418 #[allow(unknown_lints, non_local_definitions)]
423 impl #ident::MakeDef {
424 const fn make_def() -> #pyo3_path::impl_::pymodule::ModuleDef {
425 fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
426 #ident(#(#module_args),*)
427 }
428
429 const INITIALIZER: #pyo3_path::impl_::pymodule::ModuleInitializer = #pyo3_path::impl_::pymodule::ModuleInitializer(__pyo3_pymodule);
430 unsafe {
431 #pyo3_path::impl_::pymodule::ModuleDef::new(
432 #ident::__PYO3_NAME,
433 #doc,
434 INITIALIZER
435 )
436 }
437 }
438 }
439 })
440}
441
442fn module_initialization(
443 name: &syn::Ident,
444 ctx: &Ctx,
445 module_def: TokenStream,
446 is_submodule: bool,
447 gil_used: bool,
448) -> TokenStream {
449 let Ctx { pyo3_path, .. } = ctx;
450 let pyinit_symbol = format!("PyInit_{}", name);
451 let name = name.to_string();
452 let pyo3_name = LitCStr::new(CString::new(name).unwrap(), Span::call_site(), ctx);
453
454 let mut result = quote! {
455 #[doc(hidden)]
456 pub const __PYO3_NAME: &'static ::std::ffi::CStr = #pyo3_name;
457
458 pub(super) struct MakeDef;
459 #[doc(hidden)]
460 pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = #module_def;
461 #[doc(hidden)]
462 pub static __PYO3_GIL_USED: bool = #gil_used;
464 };
465 if !is_submodule {
466 result.extend(quote! {
467 #[doc(hidden)]
470 #[export_name = #pyinit_symbol]
471 pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
472 unsafe { #pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py, #gil_used)) }
473 }
474 });
475 }
476 result
477}
478
479fn process_functions_in_module(options: &PyModuleOptions, func: &mut syn::ItemFn) -> Result<()> {
481 let ctx = &Ctx::new(&options.krate, None);
482 let Ctx { pyo3_path, .. } = ctx;
483 let mut stmts: Vec<syn::Stmt> = Vec::new();
484
485 for mut stmt in func.block.stmts.drain(..) {
486 if let syn::Stmt::Item(Item::Fn(func)) = &mut stmt {
487 if let Some(pyfn_args) = get_pyfn_attr(&mut func.attrs)? {
488 let module_name = pyfn_args.modname;
489 let wrapped_function = impl_wrap_pyfunction(func, pyfn_args.options)?;
490 let name = &func.sig.ident;
491 let statements: Vec<syn::Stmt> = syn::parse_quote! {
492 #wrapped_function
493 {
494 use #pyo3_path::types::PyModuleMethods;
495 #module_name.add_function(#pyo3_path::wrap_pyfunction!(#name, #module_name.as_borrowed())?)?;
496 }
497 };
498 stmts.extend(statements);
499 }
500 };
501 stmts.push(stmt);
502 }
503
504 func.block.stmts = stmts;
505 Ok(())
506}
507
/// Parsed contents of a `#[pyfn(...)]` attribute.
pub struct PyFnArgs {
    /// First argument of the attribute: the path of the module value the function is added to.
    modname: Path,
    /// Function options that follow the module path (defaulted when none are given).
    options: PyFunctionOptions,
}
512
513impl Parse for PyFnArgs {
514 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
515 let modname = input.parse().map_err(
516 |e| err_spanned!(e.span() => "expected module as first argument to #[pyfn()]"),
517 )?;
518
519 if input.is_empty() {
520 return Ok(Self {
521 modname,
522 options: Default::default(),
523 });
524 }
525
526 let _: Comma = input.parse()?;
527
528 Ok(Self {
529 modname,
530 options: input.parse()?,
531 })
532 }
533}
534
535fn get_pyfn_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<PyFnArgs>> {
537 let mut pyfn_args: Option<PyFnArgs> = None;
538
539 take_attributes(attrs, |attr| {
540 if attr.path().is_ident("pyfn") {
541 ensure_spanned!(
542 pyfn_args.is_none(),
543 attr.span() => "`#[pyfn] may only be specified once"
544 );
545 pyfn_args = Some(attr.parse_args()?);
546 Ok(true)
547 } else {
548 Ok(false)
549 }
550 })?;
551
552 if let Some(pyfn_args) = &mut pyfn_args {
553 pyfn_args
554 .options
555 .add_attributes(take_pyo3_options(attrs)?)?;
556 }
557
558 Ok(pyfn_args)
559}
560
561fn get_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
562 attrs
563 .iter()
564 .filter(|attr| attr.path().is_ident("cfg"))
565 .cloned()
566 .collect()
567}
568
569fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bool {
570 let mut found = false;
571 attrs.retain(|attr| {
572 if attr.path().is_ident(ident) {
573 found = true;
574 false
575 } else {
576 true
577 }
578 });
579 found
580}
581
582impl PartialEq<syn::Ident> for IdentOrStr<'_> {
583 fn eq(&self, other: &syn::Ident) -> bool {
584 match self {
585 IdentOrStr::Str(s) => other == s,
586 IdentOrStr::Ident(i) => other == i,
587 }
588 }
589}
590
591fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) {
592 attrs.push(parse_quote!(#[pyo3(module = #module_name)]));
593}
594
595fn has_pyo3_module_declared<T: Parse>(
596 attrs: &[syn::Attribute],
597 root_attribute_name: &str,
598 is_module_option: impl Fn(&T) -> bool + Copy,
599) -> Result<bool> {
600 for attr in attrs {
601 if (attr.path().is_ident("pyo3") || attr.path().is_ident(root_attribute_name))
602 && matches!(attr.meta, Meta::List(_))
603 {
604 for option in &attr.parse_args_with(Punctuated::<T, Comma>::parse_terminated)? {
605 if is_module_option(option) {
606 return Ok(true);
607 }
608 }
609 }
610 }
611 Ok(false)
612}
613
/// A single option accepted in a `#[pymodule]` option list.
///
/// Each variant is selected by peeking at the leading keyword of the option.
enum PyModulePyO3Option {
    /// Selected by the `submodule` keyword.
    Submodule(SubmoduleAttribute),
    /// Selected by the `crate` token.
    Crate(CrateAttribute),
    /// Selected by the `name` keyword.
    Name(NameAttribute),
    /// Selected by the `module` keyword.
    Module(ModuleAttribute),
    /// Selected by the `gil_used` keyword.
    GILUsed(GILUsedAttribute),
}
621
622impl Parse for PyModulePyO3Option {
623 fn parse(input: ParseStream<'_>) -> Result<Self> {
624 let lookahead = input.lookahead1();
625 if lookahead.peek(attributes::kw::name) {
626 input.parse().map(PyModulePyO3Option::Name)
627 } else if lookahead.peek(syn::Token![crate]) {
628 input.parse().map(PyModulePyO3Option::Crate)
629 } else if lookahead.peek(attributes::kw::module) {
630 input.parse().map(PyModulePyO3Option::Module)
631 } else if lookahead.peek(attributes::kw::submodule) {
632 input.parse().map(PyModulePyO3Option::Submodule)
633 } else if lookahead.peek(attributes::kw::gil_used) {
634 input.parse().map(PyModulePyO3Option::GILUsed)
635 } else {
636 Err(lookahead.error())
637 }
638 }
639}