1use proc_macro2::{Span, TokenStream};
12use quote::quote;
13use syn::{Ident, Index, TraitBound};
14use synstructure::{decl_derive, Structure, BindStyle, AddBounds};
15use unicode_xid::UnicodeXID;
16
17fn sanitize_ident(s: &str) -> Ident {
19 let mut res = String::with_capacity(s.len());
20 for mut c in s.chars() {
21 if !UnicodeXID::is_xid_continue(c) {
22 c = '_'
23 }
24 if res.ends_with('_') && c == '_' {
26 continue;
27 }
28 res.push(c);
29 }
30 Ident::new(&res, Span::call_site())
31}
32
33fn get_discriminant_size_type(len: usize) -> TokenStream {
35 if len <= <u8>::max_value() as usize {
36 quote! { u8 }
37 } else if len <= <u16>::max_value() as usize {
38 quote! { u16 }
39 } else {
40 quote! { u32 }
41 }
42}
43
44fn is_struct(s: &Structure) -> bool {
45 match &s.variants()[..] {
47 [v] if v.prefix.is_none() => true,
48 _ => false,
49 }
50}
51
52fn derive_max_size(s: &Structure) -> TokenStream {
53 let max_size = s.variants().iter().fold(quote!(0), |acc, vi| {
54 let variant_size = vi.bindings().iter().fold(quote!(0), |acc, bi| {
55 let ty = &bi.ast().ty;
57 quote!(#acc + <#ty>::max_size())
58 });
59
60 quote! {
62 max(#acc, #variant_size)
63 }
64 });
65
66 let body = if is_struct(s) {
67 max_size
68 } else {
69 let discriminant_size_type = get_discriminant_size_type(s.variants().len());
70 quote! {
71 #discriminant_size_type ::max_size() + #max_size
72 }
73 };
74
75 quote! {
76 #[inline(always)]
77 fn max_size() -> usize {
78 use std::cmp::max;
79 #body
80 }
81 }
82}
83
84fn derive_peek_from_for_enum(s: &mut Structure) -> TokenStream {
85 assert!(!is_struct(s));
86 s.bind_with(|_| BindStyle::Move);
87
88 let num_variants = s.variants().len();
89 let discriminant_size_type = get_discriminant_size_type(num_variants);
90 let body = s
91 .variants()
92 .iter()
93 .enumerate()
94 .fold(quote!(), |acc, (i, vi)| {
95 let bindings = vi
96 .bindings()
97 .iter()
98 .map(|bi| quote!(#bi))
99 .collect::<Vec<_>>();
100
101 let variant_pat = Index::from(i);
102 let poke_exprs = bindings.iter().fold(quote!(), |acc, bi| {
103 quote! {
104 #acc
105 let (#bi, bytes) = peek_poke::peek_from_default(bytes);
106 }
107 });
108 let construct = vi.construct(|_, i| {
109 let bi = &bindings[i];
110 quote!(#bi)
111 });
112
113 quote! {
114 #acc
115 #variant_pat => {
116 #poke_exprs
117 *output = #construct;
118 bytes
119 }
120 }
121 });
122
123 let type_name = s.ast().ident.to_string();
124 let max_tag_value = num_variants - 1;
125
126 quote! {
127 #[inline(always)]
128 unsafe fn peek_from(bytes: *const u8, output: *mut Self) -> *const u8 {
129 let (variant, bytes) = peek_poke::peek_from_default::<#discriminant_size_type>(bytes);
130 match variant {
131 #body
132 out_of_range_tag => {
133 panic!("WRDL: memory corruption detected while parsing {} - enum tag should be <= {}, but was {}",
134 #type_name, #max_tag_value, out_of_range_tag);
135 }
136 }
137 }
138 }
139}
140
141fn derive_peek_from_for_struct(s: &mut Structure) -> TokenStream {
142 assert!(is_struct(&s));
143
144 s.variants_mut()[0].bind_with(|_| BindStyle::RefMut);
145 let pat = s.variants()[0].pat();
146 let peek_exprs = s.variants()[0].bindings().iter().fold(quote!(), |acc, bi| {
147 let ty = &bi.ast().ty;
148 quote! {
149 #acc
150 let bytes = <#ty>::peek_from(bytes, #bi);
151 }
152 });
153
154 let body = quote! {
155 #pat => {
156 #peek_exprs
157 bytes
158 }
159 };
160
161 quote! {
162 #[inline(always)]
163 unsafe fn peek_from(bytes: *const u8, output: *mut Self) -> *const u8 {
164 match &mut (*output) {
165 #body
166 }
167 }
168 }
169}
170
171fn derive_poke_into(s: &Structure) -> TokenStream {
172 let is_struct = is_struct(&s);
173 let discriminant_size_type = get_discriminant_size_type(s.variants().len());
174 let body = s
175 .variants()
176 .iter()
177 .enumerate()
178 .fold(quote!(), |acc, (i, vi)| {
179 let init = if !is_struct {
180 let index = Index::from(i);
181 quote! {
182 let bytes = #discriminant_size_type::poke_into(&#index, bytes);
183 }
184 } else {
185 quote!()
186 };
187 let variant_pat = vi.pat();
188 let poke_exprs = vi.bindings().iter().fold(init, |acc, bi| {
189 quote! {
190 #acc
191 let bytes = #bi.poke_into(bytes);
192 }
193 });
194
195 quote! {
196 #acc
197 #variant_pat => {
198 #poke_exprs
199 bytes
200 }
201 }
202 });
203
204 quote! {
205 #[inline(always)]
206 unsafe fn poke_into(&self, bytes: *mut u8) -> *mut u8 {
207 match &*self {
208 #body
209 }
210 }
211 }
212}
213
214fn peek_poke_derive(mut s: Structure) -> TokenStream {
215 s.binding_name(|_, i| Ident::new(&format!("__self_{}", i), Span::call_site()));
216
217 let max_size_fn = derive_max_size(&s);
218 let poke_into_fn = derive_poke_into(&s);
219 let peek_from_fn = if is_struct(&s) {
220 derive_peek_from_for_struct(&mut s)
221 } else {
222 derive_peek_from_for_enum(&mut s)
223 };
224
225 let poke_impl = s.gen_impl(quote! {
226 extern crate peek_poke;
227
228 gen unsafe impl peek_poke::Poke for @Self {
229 #max_size_fn
230 #poke_into_fn
231 }
232 });
233
234 let default_trait = syn::parse_str::<TraitBound>("::std::default::Default").unwrap();
238 let peek_trait = syn::parse_str::<TraitBound>("peek_poke::Peek").unwrap();
239
240 let ast = s.ast();
241 let name = &ast.ident;
242 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
243 let mut where_clause = where_clause.cloned();
244 s.add_trait_bounds(&default_trait, &mut where_clause, AddBounds::Generics);
245 s.add_trait_bounds(&peek_trait, &mut where_clause, AddBounds::Generics);
246
247 let dummy_const: Ident = sanitize_ident(&format!("_DERIVE_peek_poke_Peek_FOR_{}", name));
248
249 let peek_impl = quote! {
250 #[allow(non_upper_case_globals)]
251 const #dummy_const: () = {
252 extern crate peek_poke;
253
254 impl #impl_generics peek_poke::Peek for #name #ty_generics #where_clause {
255 #peek_from_fn
256 }
257 };
258 };
259
260 quote! {
261 #poke_impl
262 #peek_impl
263 }
264}
265
266decl_derive!([PeekPoke] => peek_poke_derive);