pyo3_macros_backend/
pyfunction.rs

1use crate::utils::Ctx;
2use crate::{
3    attributes::{
4        self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute,
5        FromPyWithAttribute, NameAttribute, TextSignatureAttribute,
6    },
7    method::{self, CallingConvention, FnArg},
8    pymethod::check_generic,
9};
10use proc_macro2::TokenStream;
11use quote::{format_ident, quote};
12use syn::parse::{Parse, ParseStream};
13use syn::punctuated::Punctuated;
14use syn::{ext::IdentExt, spanned::Spanned, Result};
15
16mod signature;
17
18pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};
19
20#[derive(Clone, Debug)]
21pub struct PyFunctionArgPyO3Attributes {
22    pub from_py_with: Option<FromPyWithAttribute>,
23    pub cancel_handle: Option<attributes::kw::cancel_handle>,
24}
25
26enum PyFunctionArgPyO3Attribute {
27    FromPyWith(FromPyWithAttribute),
28    CancelHandle(attributes::kw::cancel_handle),
29}
30
31impl Parse for PyFunctionArgPyO3Attribute {
32    fn parse(input: ParseStream<'_>) -> Result<Self> {
33        let lookahead = input.lookahead1();
34        if lookahead.peek(attributes::kw::cancel_handle) {
35            input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
36        } else if lookahead.peek(attributes::kw::from_py_with) {
37            input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
38        } else {
39            Err(lookahead.error())
40        }
41    }
42}
43
44impl PyFunctionArgPyO3Attributes {
45    /// Parses #[pyo3(from_python_with = "func")]
46    pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
47        let mut attributes = PyFunctionArgPyO3Attributes {
48            from_py_with: None,
49            cancel_handle: None,
50        };
51        take_attributes(attrs, |attr| {
52            if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
53                for attr in pyo3_attrs {
54                    match attr {
55                        PyFunctionArgPyO3Attribute::FromPyWith(from_py_with) => {
56                            ensure_spanned!(
57                                attributes.from_py_with.is_none(),
58                                from_py_with.span() => "`from_py_with` may only be specified once per argument"
59                            );
60                            attributes.from_py_with = Some(from_py_with);
61                        }
62                        PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
63                            ensure_spanned!(
64                                attributes.cancel_handle.is_none(),
65                                cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
66                            );
67                            attributes.cancel_handle = Some(cancel_handle);
68                        }
69                    }
70                    ensure_spanned!(
71                        attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
72                        attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
73                    );
74                }
75                Ok(true)
76            } else {
77                Ok(false)
78            }
79        })?;
80        Ok(attributes)
81    }
82}
83
84#[derive(Default)]
85pub struct PyFunctionOptions {
86    pub pass_module: Option<attributes::kw::pass_module>,
87    pub name: Option<NameAttribute>,
88    pub signature: Option<SignatureAttribute>,
89    pub text_signature: Option<TextSignatureAttribute>,
90    pub krate: Option<CrateAttribute>,
91}
92
93impl Parse for PyFunctionOptions {
94    fn parse(input: ParseStream<'_>) -> Result<Self> {
95        let mut options = PyFunctionOptions::default();
96
97        let attrs = Punctuated::<PyFunctionOption, syn::Token![,]>::parse_terminated(input)?;
98        options.add_attributes(attrs)?;
99
100        Ok(options)
101    }
102}
103
104pub enum PyFunctionOption {
105    Name(NameAttribute),
106    PassModule(attributes::kw::pass_module),
107    Signature(SignatureAttribute),
108    TextSignature(TextSignatureAttribute),
109    Crate(CrateAttribute),
110}
111
112impl Parse for PyFunctionOption {
113    fn parse(input: ParseStream<'_>) -> Result<Self> {
114        let lookahead = input.lookahead1();
115        if lookahead.peek(attributes::kw::name) {
116            input.parse().map(PyFunctionOption::Name)
117        } else if lookahead.peek(attributes::kw::pass_module) {
118            input.parse().map(PyFunctionOption::PassModule)
119        } else if lookahead.peek(attributes::kw::signature) {
120            input.parse().map(PyFunctionOption::Signature)
121        } else if lookahead.peek(attributes::kw::text_signature) {
122            input.parse().map(PyFunctionOption::TextSignature)
123        } else if lookahead.peek(syn::Token![crate]) {
124            input.parse().map(PyFunctionOption::Crate)
125        } else {
126            Err(lookahead.error())
127        }
128    }
129}
130
131impl PyFunctionOptions {
132    pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
133        let mut options = PyFunctionOptions::default();
134        options.add_attributes(take_pyo3_options(attrs)?)?;
135        Ok(options)
136    }
137
138    pub fn add_attributes(
139        &mut self,
140        attrs: impl IntoIterator<Item = PyFunctionOption>,
141    ) -> Result<()> {
142        macro_rules! set_option {
143            ($key:ident) => {
144                {
145                    ensure_spanned!(
146                        self.$key.is_none(),
147                        $key.span() => concat!("`", stringify!($key), "` may only be specified once")
148                    );
149                    self.$key = Some($key);
150                }
151            };
152        }
153        for attr in attrs {
154            match attr {
155                PyFunctionOption::Name(name) => set_option!(name),
156                PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
157                PyFunctionOption::Signature(signature) => set_option!(signature),
158                PyFunctionOption::TextSignature(text_signature) => set_option!(text_signature),
159                PyFunctionOption::Crate(krate) => set_option!(krate),
160            }
161        }
162        Ok(())
163    }
164}
165
166pub fn build_py_function(
167    ast: &mut syn::ItemFn,
168    mut options: PyFunctionOptions,
169) -> syn::Result<TokenStream> {
170    options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
171    impl_wrap_pyfunction(ast, options)
172}
173
174/// Generates python wrapper over a function that allows adding it to a python module as a python
175/// function
176pub fn impl_wrap_pyfunction(
177    func: &mut syn::ItemFn,
178    options: PyFunctionOptions,
179) -> syn::Result<TokenStream> {
180    check_generic(&func.sig)?;
181    let PyFunctionOptions {
182        pass_module,
183        name,
184        signature,
185        text_signature,
186        krate,
187    } = options;
188
189    let ctx = &Ctx::new(&krate, Some(&func.sig));
190    let Ctx { pyo3_path, .. } = &ctx;
191
192    let python_name = name
193        .as_ref()
194        .map_or_else(|| &func.sig.ident, |name| &name.value.0)
195        .unraw();
196
197    let tp = if pass_module.is_some() {
198        let span = match func.sig.inputs.first() {
199            Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
200            Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
201                func.sig.paren_token.span.join() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`"
202            ),
203        };
204        method::FnType::FnModule(span)
205    } else {
206        method::FnType::FnStatic
207    };
208
209    let arguments = func
210        .sig
211        .inputs
212        .iter_mut()
213        .skip(if tp.skip_first_rust_argument_in_python_signature() {
214            1
215        } else {
216            0
217        })
218        .map(FnArg::parse)
219        .collect::<syn::Result<Vec<_>>>()?;
220
221    let signature = if let Some(signature) = signature {
222        FunctionSignature::from_arguments_and_attribute(arguments, signature)?
223    } else {
224        FunctionSignature::from_arguments(arguments)
225    };
226
227    let spec = method::FnSpec {
228        tp,
229        name: &func.sig.ident,
230        convention: CallingConvention::from_signature(&signature),
231        python_name,
232        signature,
233        text_signature,
234        asyncness: func.sig.asyncness,
235        unsafety: func.sig.unsafety,
236    };
237
238    let vis = &func.vis;
239    let name = &func.sig.ident;
240
241    let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
242    let wrapper = spec.get_wrapper_function(&wrapper_ident, None, ctx)?;
243    let methoddef = spec.get_methoddef(wrapper_ident, &spec.get_doc(&func.attrs, ctx), ctx);
244
245    let wrapped_pyfunction = quote! {
246
247        // Create a module with the same name as the `#[pyfunction]` - this way `use <the function>`
248        // will actually bring both the module and the function into scope.
249        #[doc(hidden)]
250        #vis mod #name {
251            pub(crate) struct MakeDef;
252            pub const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = MakeDef::_PYO3_DEF;
253        }
254
255        // Generate the definition inside an anonymous function in the same scope as the original function -
256        // this avoids complications around the fact that the generated module has a different scope
257        // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pyfunction] is
258        // inside a function body)
259        #[allow(unknown_lints, non_local_definitions)]
260        impl #name::MakeDef {
261            const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = #methoddef;
262        }
263
264        #[allow(non_snake_case)]
265        #wrapper
266    };
267    Ok(wrapped_pyfunction)
268}