pyo3_macros_backend/
pyfunction.rs

1use crate::utils::Ctx;
2use crate::{
3    attributes::{
4        self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute,
5        FromPyWithAttribute, NameAttribute, TextSignatureAttribute,
6    },
7    method::{self, CallingConvention, FnArg},
8    pymethod::check_generic,
9};
10use proc_macro2::TokenStream;
11use quote::{format_ident, quote};
12use syn::{ext::IdentExt, spanned::Spanned, Result};
13use syn::{
14    parse::{Parse, ParseStream},
15    token::Comma,
16};
17
18mod signature;
19
20pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};
21
22#[derive(Clone, Debug)]
23pub struct PyFunctionArgPyO3Attributes {
24    pub from_py_with: Option<FromPyWithAttribute>,
25    pub cancel_handle: Option<attributes::kw::cancel_handle>,
26}
27
28enum PyFunctionArgPyO3Attribute {
29    FromPyWith(FromPyWithAttribute),
30    CancelHandle(attributes::kw::cancel_handle),
31}
32
33impl Parse for PyFunctionArgPyO3Attribute {
34    fn parse(input: ParseStream<'_>) -> Result<Self> {
35        let lookahead = input.lookahead1();
36        if lookahead.peek(attributes::kw::cancel_handle) {
37            input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
38        } else if lookahead.peek(attributes::kw::from_py_with) {
39            input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
40        } else {
41            Err(lookahead.error())
42        }
43    }
44}
45
46impl PyFunctionArgPyO3Attributes {
47    /// Parses #[pyo3(from_python_with = "func")]
48    pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
49        let mut attributes = PyFunctionArgPyO3Attributes {
50            from_py_with: None,
51            cancel_handle: None,
52        };
53        take_attributes(attrs, |attr| {
54            if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
55                for attr in pyo3_attrs {
56                    match attr {
57                        PyFunctionArgPyO3Attribute::FromPyWith(from_py_with) => {
58                            ensure_spanned!(
59                                attributes.from_py_with.is_none(),
60                                from_py_with.span() => "`from_py_with` may only be specified once per argument"
61                            );
62                            attributes.from_py_with = Some(from_py_with);
63                        }
64                        PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
65                            ensure_spanned!(
66                                attributes.cancel_handle.is_none(),
67                                cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
68                            );
69                            attributes.cancel_handle = Some(cancel_handle);
70                        }
71                    }
72                    ensure_spanned!(
73                        attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
74                        attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
75                    );
76                }
77                Ok(true)
78            } else {
79                Ok(false)
80            }
81        })?;
82        Ok(attributes)
83    }
84}
85
86#[derive(Default)]
87pub struct PyFunctionOptions {
88    pub pass_module: Option<attributes::kw::pass_module>,
89    pub name: Option<NameAttribute>,
90    pub signature: Option<SignatureAttribute>,
91    pub text_signature: Option<TextSignatureAttribute>,
92    pub krate: Option<CrateAttribute>,
93}
94
95impl Parse for PyFunctionOptions {
96    fn parse(input: ParseStream<'_>) -> Result<Self> {
97        let mut options = PyFunctionOptions::default();
98
99        while !input.is_empty() {
100            let lookahead = input.lookahead1();
101            if lookahead.peek(attributes::kw::name)
102                || lookahead.peek(attributes::kw::pass_module)
103                || lookahead.peek(attributes::kw::signature)
104                || lookahead.peek(attributes::kw::text_signature)
105            {
106                options.add_attributes(std::iter::once(input.parse()?))?;
107                if !input.is_empty() {
108                    let _: Comma = input.parse()?;
109                }
110            } else if lookahead.peek(syn::Token![crate]) {
111                // TODO needs duplicate check?
112                options.krate = Some(input.parse()?);
113            } else {
114                return Err(lookahead.error());
115            }
116        }
117
118        Ok(options)
119    }
120}
121
122pub enum PyFunctionOption {
123    Name(NameAttribute),
124    PassModule(attributes::kw::pass_module),
125    Signature(SignatureAttribute),
126    TextSignature(TextSignatureAttribute),
127    Crate(CrateAttribute),
128}
129
130impl Parse for PyFunctionOption {
131    fn parse(input: ParseStream<'_>) -> Result<Self> {
132        let lookahead = input.lookahead1();
133        if lookahead.peek(attributes::kw::name) {
134            input.parse().map(PyFunctionOption::Name)
135        } else if lookahead.peek(attributes::kw::pass_module) {
136            input.parse().map(PyFunctionOption::PassModule)
137        } else if lookahead.peek(attributes::kw::signature) {
138            input.parse().map(PyFunctionOption::Signature)
139        } else if lookahead.peek(attributes::kw::text_signature) {
140            input.parse().map(PyFunctionOption::TextSignature)
141        } else if lookahead.peek(syn::Token![crate]) {
142            input.parse().map(PyFunctionOption::Crate)
143        } else {
144            Err(lookahead.error())
145        }
146    }
147}
148
149impl PyFunctionOptions {
150    pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
151        let mut options = PyFunctionOptions::default();
152        options.add_attributes(take_pyo3_options(attrs)?)?;
153        Ok(options)
154    }
155
156    pub fn add_attributes(
157        &mut self,
158        attrs: impl IntoIterator<Item = PyFunctionOption>,
159    ) -> Result<()> {
160        macro_rules! set_option {
161            ($key:ident) => {
162                {
163                    ensure_spanned!(
164                        self.$key.is_none(),
165                        $key.span() => concat!("`", stringify!($key), "` may only be specified once")
166                    );
167                    self.$key = Some($key);
168                }
169            };
170        }
171        for attr in attrs {
172            match attr {
173                PyFunctionOption::Name(name) => set_option!(name),
174                PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
175                PyFunctionOption::Signature(signature) => set_option!(signature),
176                PyFunctionOption::TextSignature(text_signature) => set_option!(text_signature),
177                PyFunctionOption::Crate(krate) => set_option!(krate),
178            }
179        }
180        Ok(())
181    }
182}
183
184pub fn build_py_function(
185    ast: &mut syn::ItemFn,
186    mut options: PyFunctionOptions,
187) -> syn::Result<TokenStream> {
188    options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
189    impl_wrap_pyfunction(ast, options)
190}
191
192/// Generates python wrapper over a function that allows adding it to a python module as a python
193/// function
194pub fn impl_wrap_pyfunction(
195    func: &mut syn::ItemFn,
196    options: PyFunctionOptions,
197) -> syn::Result<TokenStream> {
198    check_generic(&func.sig)?;
199    let PyFunctionOptions {
200        pass_module,
201        name,
202        signature,
203        text_signature,
204        krate,
205    } = options;
206
207    let ctx = &Ctx::new(&krate, Some(&func.sig));
208    let Ctx { pyo3_path, .. } = &ctx;
209
210    let python_name = name
211        .as_ref()
212        .map_or_else(|| &func.sig.ident, |name| &name.value.0)
213        .unraw();
214
215    let tp = if pass_module.is_some() {
216        let span = match func.sig.inputs.first() {
217            Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
218            Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
219                func.sig.paren_token.span.join() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`"
220            ),
221        };
222        method::FnType::FnModule(span)
223    } else {
224        method::FnType::FnStatic
225    };
226
227    let arguments = func
228        .sig
229        .inputs
230        .iter_mut()
231        .skip(if tp.skip_first_rust_argument_in_python_signature() {
232            1
233        } else {
234            0
235        })
236        .map(FnArg::parse)
237        .collect::<syn::Result<Vec<_>>>()?;
238
239    let signature = if let Some(signature) = signature {
240        FunctionSignature::from_arguments_and_attribute(arguments, signature)?
241    } else {
242        FunctionSignature::from_arguments(arguments)?
243    };
244
245    let spec = method::FnSpec {
246        tp,
247        name: &func.sig.ident,
248        convention: CallingConvention::from_signature(&signature),
249        python_name,
250        signature,
251        text_signature,
252        asyncness: func.sig.asyncness,
253        unsafety: func.sig.unsafety,
254    };
255
256    let vis = &func.vis;
257    let name = &func.sig.ident;
258
259    let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
260    let wrapper = spec.get_wrapper_function(&wrapper_ident, None, ctx)?;
261    let methoddef = spec.get_methoddef(wrapper_ident, &spec.get_doc(&func.attrs, ctx), ctx);
262
263    let wrapped_pyfunction = quote! {
264
265        // Create a module with the same name as the `#[pyfunction]` - this way `use <the function>`
266        // will actually bring both the module and the function into scope.
267        #[doc(hidden)]
268        #vis mod #name {
269            pub(crate) struct MakeDef;
270            pub const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = MakeDef::_PYO3_DEF;
271        }
272
273        // Generate the definition inside an anonymous function in the same scope as the original function -
274        // this avoids complications around the fact that the generated module has a different scope
275        // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pyfunction] is
276        // inside a function body)
277        #[allow(unknown_lints, non_local_definitions)]
278        impl #name::MakeDef {
279            const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = #methoddef;
280        }
281
282        #[allow(non_snake_case)]
283        #wrapper
284    };
285    Ok(wrapped_pyfunction)
286}