pyo3_macros_backend/
pyfunction.rs1use crate::utils::Ctx;
2use crate::{
3 attributes::{
4 self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute,
5 FromPyWithAttribute, NameAttribute, TextSignatureAttribute,
6 },
7 method::{self, CallingConvention, FnArg},
8 pymethod::check_generic,
9};
10use proc_macro2::TokenStream;
11use quote::{format_ident, quote};
12use syn::{ext::IdentExt, spanned::Spanned, Result};
13use syn::{
14 parse::{Parse, ParseStream},
15 token::Comma,
16};
17
18mod signature;
19
20pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};
21
22#[derive(Clone, Debug)]
23pub struct PyFunctionArgPyO3Attributes {
24 pub from_py_with: Option<FromPyWithAttribute>,
25 pub cancel_handle: Option<attributes::kw::cancel_handle>,
26}
27
28enum PyFunctionArgPyO3Attribute {
29 FromPyWith(FromPyWithAttribute),
30 CancelHandle(attributes::kw::cancel_handle),
31}
32
33impl Parse for PyFunctionArgPyO3Attribute {
34 fn parse(input: ParseStream<'_>) -> Result<Self> {
35 let lookahead = input.lookahead1();
36 if lookahead.peek(attributes::kw::cancel_handle) {
37 input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
38 } else if lookahead.peek(attributes::kw::from_py_with) {
39 input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
40 } else {
41 Err(lookahead.error())
42 }
43 }
44}
45
46impl PyFunctionArgPyO3Attributes {
47 pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
49 let mut attributes = PyFunctionArgPyO3Attributes {
50 from_py_with: None,
51 cancel_handle: None,
52 };
53 take_attributes(attrs, |attr| {
54 if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
55 for attr in pyo3_attrs {
56 match attr {
57 PyFunctionArgPyO3Attribute::FromPyWith(from_py_with) => {
58 ensure_spanned!(
59 attributes.from_py_with.is_none(),
60 from_py_with.span() => "`from_py_with` may only be specified once per argument"
61 );
62 attributes.from_py_with = Some(from_py_with);
63 }
64 PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
65 ensure_spanned!(
66 attributes.cancel_handle.is_none(),
67 cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
68 );
69 attributes.cancel_handle = Some(cancel_handle);
70 }
71 }
72 ensure_spanned!(
73 attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
74 attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
75 );
76 }
77 Ok(true)
78 } else {
79 Ok(false)
80 }
81 })?;
82 Ok(attributes)
83 }
84}
85
86#[derive(Default)]
87pub struct PyFunctionOptions {
88 pub pass_module: Option<attributes::kw::pass_module>,
89 pub name: Option<NameAttribute>,
90 pub signature: Option<SignatureAttribute>,
91 pub text_signature: Option<TextSignatureAttribute>,
92 pub krate: Option<CrateAttribute>,
93}
94
95impl Parse for PyFunctionOptions {
96 fn parse(input: ParseStream<'_>) -> Result<Self> {
97 let mut options = PyFunctionOptions::default();
98
99 while !input.is_empty() {
100 let lookahead = input.lookahead1();
101 if lookahead.peek(attributes::kw::name)
102 || lookahead.peek(attributes::kw::pass_module)
103 || lookahead.peek(attributes::kw::signature)
104 || lookahead.peek(attributes::kw::text_signature)
105 {
106 options.add_attributes(std::iter::once(input.parse()?))?;
107 if !input.is_empty() {
108 let _: Comma = input.parse()?;
109 }
110 } else if lookahead.peek(syn::Token![crate]) {
111 options.krate = Some(input.parse()?);
113 } else {
114 return Err(lookahead.error());
115 }
116 }
117
118 Ok(options)
119 }
120}
121
122pub enum PyFunctionOption {
123 Name(NameAttribute),
124 PassModule(attributes::kw::pass_module),
125 Signature(SignatureAttribute),
126 TextSignature(TextSignatureAttribute),
127 Crate(CrateAttribute),
128}
129
130impl Parse for PyFunctionOption {
131 fn parse(input: ParseStream<'_>) -> Result<Self> {
132 let lookahead = input.lookahead1();
133 if lookahead.peek(attributes::kw::name) {
134 input.parse().map(PyFunctionOption::Name)
135 } else if lookahead.peek(attributes::kw::pass_module) {
136 input.parse().map(PyFunctionOption::PassModule)
137 } else if lookahead.peek(attributes::kw::signature) {
138 input.parse().map(PyFunctionOption::Signature)
139 } else if lookahead.peek(attributes::kw::text_signature) {
140 input.parse().map(PyFunctionOption::TextSignature)
141 } else if lookahead.peek(syn::Token![crate]) {
142 input.parse().map(PyFunctionOption::Crate)
143 } else {
144 Err(lookahead.error())
145 }
146 }
147}
148
149impl PyFunctionOptions {
150 pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
151 let mut options = PyFunctionOptions::default();
152 options.add_attributes(take_pyo3_options(attrs)?)?;
153 Ok(options)
154 }
155
156 pub fn add_attributes(
157 &mut self,
158 attrs: impl IntoIterator<Item = PyFunctionOption>,
159 ) -> Result<()> {
160 macro_rules! set_option {
161 ($key:ident) => {
162 {
163 ensure_spanned!(
164 self.$key.is_none(),
165 $key.span() => concat!("`", stringify!($key), "` may only be specified once")
166 );
167 self.$key = Some($key);
168 }
169 };
170 }
171 for attr in attrs {
172 match attr {
173 PyFunctionOption::Name(name) => set_option!(name),
174 PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
175 PyFunctionOption::Signature(signature) => set_option!(signature),
176 PyFunctionOption::TextSignature(text_signature) => set_option!(text_signature),
177 PyFunctionOption::Crate(krate) => set_option!(krate),
178 }
179 }
180 Ok(())
181 }
182}
183
184pub fn build_py_function(
185 ast: &mut syn::ItemFn,
186 mut options: PyFunctionOptions,
187) -> syn::Result<TokenStream> {
188 options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
189 impl_wrap_pyfunction(ast, options)
190}
191
192pub fn impl_wrap_pyfunction(
195 func: &mut syn::ItemFn,
196 options: PyFunctionOptions,
197) -> syn::Result<TokenStream> {
198 check_generic(&func.sig)?;
199 let PyFunctionOptions {
200 pass_module,
201 name,
202 signature,
203 text_signature,
204 krate,
205 } = options;
206
207 let ctx = &Ctx::new(&krate, Some(&func.sig));
208 let Ctx { pyo3_path, .. } = &ctx;
209
210 let python_name = name
211 .as_ref()
212 .map_or_else(|| &func.sig.ident, |name| &name.value.0)
213 .unraw();
214
215 let tp = if pass_module.is_some() {
216 let span = match func.sig.inputs.first() {
217 Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
218 Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
219 func.sig.paren_token.span.join() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`"
220 ),
221 };
222 method::FnType::FnModule(span)
223 } else {
224 method::FnType::FnStatic
225 };
226
227 let arguments = func
228 .sig
229 .inputs
230 .iter_mut()
231 .skip(if tp.skip_first_rust_argument_in_python_signature() {
232 1
233 } else {
234 0
235 })
236 .map(FnArg::parse)
237 .collect::<syn::Result<Vec<_>>>()?;
238
239 let signature = if let Some(signature) = signature {
240 FunctionSignature::from_arguments_and_attribute(arguments, signature)?
241 } else {
242 FunctionSignature::from_arguments(arguments)?
243 };
244
245 let spec = method::FnSpec {
246 tp,
247 name: &func.sig.ident,
248 convention: CallingConvention::from_signature(&signature),
249 python_name,
250 signature,
251 text_signature,
252 asyncness: func.sig.asyncness,
253 unsafety: func.sig.unsafety,
254 };
255
256 let vis = &func.vis;
257 let name = &func.sig.ident;
258
259 let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
260 let wrapper = spec.get_wrapper_function(&wrapper_ident, None, ctx)?;
261 let methoddef = spec.get_methoddef(wrapper_ident, &spec.get_doc(&func.attrs, ctx), ctx);
262
263 let wrapped_pyfunction = quote! {
264
265 #[doc(hidden)]
268 #vis mod #name {
269 pub(crate) struct MakeDef;
270 pub const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = MakeDef::_PYO3_DEF;
271 }
272
273 #[allow(unknown_lints, non_local_definitions)]
278 impl #name::MakeDef {
279 const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = #methoddef;
280 }
281
282 #[allow(non_snake_case)]
283 #wrapper
284 };
285 Ok(wrapped_pyfunction)
286}