1use proc_macro2::{Ident, Span, TokenStream};
2use quote::quote;
3use syn::punctuated::Punctuated;
4use syn::token::{Colon, PathSep};
5use syn::{
6 ConstParam, FnArg, GenericParam, ItemFn, LifetimeParam, Pat, PatIdent, PatType, Path,
7 PathSegment, Type, TypeParam, TypePath,
8};
9
10#[proc_macro_attribute]
11pub fn with_simd(
12 attr: proc_macro::TokenStream,
13 item: proc_macro::TokenStream,
14) -> proc_macro::TokenStream {
15 let attr: TokenStream = attr.into();
16 let item: TokenStream = item.into();
17 let Ok(syn::Meta::NameValue(attr)) = syn::parse2::<syn::Meta>(attr.clone()) else {
18 return quote! {
19 ::core::compile_error!("pulp::with_simd expected function name and arch expression");
20 #item
21 }
22 .into();
23 };
24 let Some(name) = attr.path.get_ident() else {
25 return quote! {
26 ::core::compile_error!("pulp::with_simd expected function name and arch expression");
27 #item
28 }
29 .into();
30 };
31 let Ok(item) = syn::parse2::<syn::ItemFn>(item.clone()) else {
32 return quote! {
33 ::core::compile_error!("pulp::with_simd expected function");
34 #item
35 }
36 .into();
37 };
38
39 let ItemFn {
40 attrs,
41 vis,
42 sig,
43 block,
44 } = item.clone();
45
46 let mut struct_generics = Vec::new();
47 let mut struct_field_names = Vec::new();
48 let mut struct_field_types = Vec::new();
49
50 let mut first_non_lifetime = usize::MAX;
51 for (idx, param) in sig.generics.params.clone().into_pairs().enumerate() {
52 let (param, _) = param.into_tuple();
53 match ¶m {
54 syn::GenericParam::Lifetime(_) => {}
55 _ => {
56 if first_non_lifetime == usize::MAX {
57 first_non_lifetime = idx;
58 continue;
59 }
60 }
61 }
62 }
63 let mut new_fn_sig = sig.clone();
64 new_fn_sig.generics.params = new_fn_sig
65 .generics
66 .params
67 .into_iter()
68 .enumerate()
69 .filter(|(idx, _)| *idx != first_non_lifetime)
70 .map(|(_, arg)| arg)
71 .collect();
72 new_fn_sig.inputs = new_fn_sig
73 .inputs
74 .into_iter()
75 .skip(1)
76 .enumerate()
77 .map(|(idx, arg)| {
78 FnArg::Typed(PatType {
79 attrs: Vec::new(),
80 pat: Box::new(Pat::Ident(PatIdent {
81 attrs: Vec::new(),
82 by_ref: None,
83 mutability: None,
84 ident: Ident::new(&format!("__{idx}"), Span::call_site()),
85 subpat: None,
86 })),
87 colon_token: Colon {
88 spans: [Span::call_site()],
89 },
90 ty: match arg {
91 FnArg::Typed(ty) => ty.ty,
92 FnArg::Receiver(_) => panic!(),
93 },
94 })
95 })
96 .collect();
97 new_fn_sig.ident = name.clone();
98 let mut param_ty = Vec::new();
99
100 for (idx, param) in new_fn_sig.inputs.clone().into_pairs().enumerate() {
101 let (param, _) = param.into_tuple();
102 let FnArg::Typed(param) = param.clone() else {
103 panic!();
104 };
105 let name = *param.pat;
106 let syn::Pat::Ident(name) = name else {
107 panic!();
108 };
109
110 let anon_ty = Ident::new(&format!("__T{idx}"), Span::call_site());
111
112 struct_field_names.push(name.ident.clone());
113 let mut ty = Punctuated::<_, PathSep>::new();
114 ty.push_value(PathSegment {
115 ident: anon_ty.clone(),
116 arguments: syn::PathArguments::None,
117 });
118 struct_field_types.push(Type::Path(TypePath {
119 qself: None,
120 path: Path {
121 leading_colon: None,
122 segments: ty,
123 },
124 }));
125 struct_generics.push(anon_ty);
126 param_ty.push(*param.ty);
127 }
128
129 let output_ty = match sig.output.clone() {
130 syn::ReturnType::Default => quote! { () },
131 syn::ReturnType::Type(_, ty) => quote! { #ty },
132 };
133
134 let fn_name = sig.ident.clone();
135
136 let arch = attr.value;
137 let new_fn_generics = new_fn_sig.generics.clone();
138 let params = new_fn_generics.params.clone();
139 let generics = params.into_iter().collect::<Vec<_>>();
140 let non_lt_generics_names = generics
141 .iter()
142 .map(|p| match p {
143 GenericParam::Type(TypeParam { ident, .. })
144 | GenericParam::Const(ConstParam { ident, .. }) => {
145 quote! { #ident, }
146 }
147 _ => quote! {},
148 })
149 .collect::<Vec<_>>();
150 let generics_decl = generics
151 .iter()
152 .map(|p| match p {
153 GenericParam::Lifetime(LifetimeParam {
154 lifetime,
155 colon_token,
156 bounds,
157 ..
158 }) => {
159 quote! { #lifetime #colon_token #bounds }
160 }
161 GenericParam::Type(TypeParam {
162 ident,
163 colon_token,
164 bounds,
165 ..
166 }) => {
167 quote! { #ident #colon_token #bounds }
168 }
169 GenericParam::Const(ConstParam {
170 const_token,
171 ident,
172 colon_token,
173 ty,
174 ..
175 }) => {
176 quote! { #const_token #ident #colon_token #ty }
177 }
178 })
179 .collect::<Vec<_>>();
180 let generics_where_clause = new_fn_generics.where_clause;
181
182 let code = quote! {
183 #(#attrs)*
184 #vis #new_fn_sig {
185 #[allow(non_camel_case_types)]
186 struct #name<#(#struct_generics,)*> (#(#struct_field_types,)*);
187
188 impl<#(#generics_decl,)*> ::pulp::WithSimd for #name<
189 #(#param_ty,)*
190 > #generics_where_clause {
191 type Output = #output_ty;
192
193 #[inline(always)]
194 fn with_simd<__S: ::pulp::Simd>(self, __simd: __S) -> <Self as ::pulp::WithSimd>::Output {
195 let Self ( #(#struct_field_names,)* ) = self;
196 #[allow(unused_unsafe)]
197 unsafe {
198 #fn_name::<__S,
199 #(#non_lt_generics_names)*
200 >(__simd, #(#struct_field_names,)*)
201 }
202 }
203 }
204
205 (#arch).dispatch( #name ( #(#struct_field_names,)* ) )
206 }
207
208 #(#attrs)*
209 #vis #sig #block
210 };
211 code.into()
212}