pulp_macro/
lib.rs

1use proc_macro2::{Ident, Span, TokenStream};
2use quote::quote;
3use syn::punctuated::Punctuated;
4use syn::token::{Colon, PathSep};
5use syn::{
6    ConstParam, FnArg, GenericParam, ItemFn, LifetimeParam, Pat, PatIdent, PatType, Path,
7    PathSegment, Type, TypeParam, TypePath,
8};
9
10#[proc_macro_attribute]
11pub fn with_simd(
12    attr: proc_macro::TokenStream,
13    item: proc_macro::TokenStream,
14) -> proc_macro::TokenStream {
15    let attr: TokenStream = attr.into();
16    let item: TokenStream = item.into();
17    let Ok(syn::Meta::NameValue(attr)) = syn::parse2::<syn::Meta>(attr.clone()) else {
18        return quote! {
19            ::core::compile_error!("pulp::with_simd expected function name and arch expression");
20            #item
21        }
22        .into();
23    };
24    let Some(name) = attr.path.get_ident() else {
25        return quote! {
26            ::core::compile_error!("pulp::with_simd expected function name and arch expression");
27            #item
28        }
29        .into();
30    };
31    let Ok(item) = syn::parse2::<syn::ItemFn>(item.clone()) else {
32        return quote! {
33            ::core::compile_error!("pulp::with_simd expected function");
34            #item
35        }
36        .into();
37    };
38
39    let ItemFn {
40        attrs,
41        vis,
42        sig,
43        block,
44    } = item.clone();
45
46    let mut struct_generics = Vec::new();
47    let mut struct_field_names = Vec::new();
48    let mut struct_field_types = Vec::new();
49
50    let mut first_non_lifetime = usize::MAX;
51    for (idx, param) in sig.generics.params.clone().into_pairs().enumerate() {
52        let (param, _) = param.into_tuple();
53        match &param {
54            syn::GenericParam::Lifetime(_) => {}
55            _ => {
56                if first_non_lifetime == usize::MAX {
57                    first_non_lifetime = idx;
58                    continue;
59                }
60            }
61        }
62    }
63    let mut new_fn_sig = sig.clone();
64    new_fn_sig.generics.params = new_fn_sig
65        .generics
66        .params
67        .into_iter()
68        .enumerate()
69        .filter(|(idx, _)| *idx != first_non_lifetime)
70        .map(|(_, arg)| arg)
71        .collect();
72    new_fn_sig.inputs = new_fn_sig
73        .inputs
74        .into_iter()
75        .skip(1)
76        .enumerate()
77        .map(|(idx, arg)| {
78            FnArg::Typed(PatType {
79                attrs: Vec::new(),
80                pat: Box::new(Pat::Ident(PatIdent {
81                    attrs: Vec::new(),
82                    by_ref: None,
83                    mutability: None,
84                    ident: Ident::new(&format!("__{idx}"), Span::call_site()),
85                    subpat: None,
86                })),
87                colon_token: Colon {
88                    spans: [Span::call_site()],
89                },
90                ty: match arg {
91                    FnArg::Typed(ty) => ty.ty,
92                    FnArg::Receiver(_) => panic!(),
93                },
94            })
95        })
96        .collect();
97    new_fn_sig.ident = name.clone();
98    let mut param_ty = Vec::new();
99
100    for (idx, param) in new_fn_sig.inputs.clone().into_pairs().enumerate() {
101        let (param, _) = param.into_tuple();
102        let FnArg::Typed(param) = param.clone() else {
103            panic!();
104        };
105        let name = *param.pat;
106        let syn::Pat::Ident(name) = name else {
107            panic!();
108        };
109
110        let anon_ty = Ident::new(&format!("__T{idx}"), Span::call_site());
111
112        struct_field_names.push(name.ident.clone());
113        let mut ty = Punctuated::<_, PathSep>::new();
114        ty.push_value(PathSegment {
115            ident: anon_ty.clone(),
116            arguments: syn::PathArguments::None,
117        });
118        struct_field_types.push(Type::Path(TypePath {
119            qself: None,
120            path: Path {
121                leading_colon: None,
122                segments: ty,
123            },
124        }));
125        struct_generics.push(anon_ty);
126        param_ty.push(*param.ty);
127    }
128
129    let output_ty = match sig.output.clone() {
130        syn::ReturnType::Default => quote! { () },
131        syn::ReturnType::Type(_, ty) => quote! { #ty },
132    };
133
134    let fn_name = sig.ident.clone();
135
136    let arch = attr.value;
137    let new_fn_generics = new_fn_sig.generics.clone();
138    let params = new_fn_generics.params.clone();
139    let generics = params.into_iter().collect::<Vec<_>>();
140    let non_lt_generics_names = generics
141        .iter()
142        .map(|p| match p {
143            GenericParam::Type(TypeParam { ident, .. })
144            | GenericParam::Const(ConstParam { ident, .. }) => {
145                quote! { #ident, }
146            }
147            _ => quote! {},
148        })
149        .collect::<Vec<_>>();
150    let generics_decl = generics
151        .iter()
152        .map(|p| match p {
153            GenericParam::Lifetime(LifetimeParam {
154                lifetime,
155                colon_token,
156                bounds,
157                ..
158            }) => {
159                quote! { #lifetime #colon_token #bounds }
160            }
161            GenericParam::Type(TypeParam {
162                ident,
163                colon_token,
164                bounds,
165                ..
166            }) => {
167                quote! { #ident #colon_token #bounds }
168            }
169            GenericParam::Const(ConstParam {
170                const_token,
171                ident,
172                colon_token,
173                ty,
174                ..
175            }) => {
176                quote! { #const_token #ident #colon_token #ty }
177            }
178        })
179        .collect::<Vec<_>>();
180    let generics_where_clause = new_fn_generics.where_clause;
181
182    let code = quote! {
183        #(#attrs)*
184        #vis #new_fn_sig {
185            #[allow(non_camel_case_types)]
186            struct #name<#(#struct_generics,)*> (#(#struct_field_types,)*);
187
188            impl<#(#generics_decl,)*> ::pulp::WithSimd for #name<
189                #(#param_ty,)*
190            > #generics_where_clause {
191                type Output = #output_ty;
192
193                #[inline(always)]
194                fn with_simd<__S: ::pulp::Simd>(self, __simd: __S) -> <Self as ::pulp::WithSimd>::Output {
195                    let Self ( #(#struct_field_names,)* ) = self;
196                    #[allow(unused_unsafe)]
197                    unsafe {
198                        #fn_name::<__S,
199                        #(#non_lt_generics_names)*
200                        >(__simd, #(#struct_field_names,)*)
201                    }
202                }
203            }
204
205            (#arch).dispatch( #name ( #(#struct_field_names,)* ) )
206        }
207
208        #(#attrs)*
209        #vis #sig #block
210    };
211    code.into()
212}