peek_poke_derive/
lib.rs

1// Copyright 2019 The Servo Project Developers. See the COPYRIGHT
2// file at the top-level directory of this distribution and at
3// http://rust-lang.org/COPYRIGHT.
4//
5// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8// option. This file may not be copied, modified, or distributed
9// except according to those terms.
10
11use proc_macro2::{Span, TokenStream};
12use quote::quote;
13use syn::{Ident, Index, TraitBound};
14use synstructure::{decl_derive, Structure, BindStyle, AddBounds};
15use unicode_xid::UnicodeXID;
16
17// Internal method for sanitizing an identifier for hygiene purposes.
18fn sanitize_ident(s: &str) -> Ident {
19    let mut res = String::with_capacity(s.len());
20    for mut c in s.chars() {
21        if !UnicodeXID::is_xid_continue(c) {
22            c = '_'
23        }
24        // Deduplicate consecutive _ characters.
25        if res.ends_with('_') && c == '_' {
26            continue;
27        }
28        res.push(c);
29    }
30    Ident::new(&res, Span::call_site())
31}
32
33/// Calculates size type for number of variants (used for enums)
34fn get_discriminant_size_type(len: usize) -> TokenStream {
35    if len <= <u8>::max_value() as usize {
36        quote! { u8 }
37    } else if len <= <u16>::max_value() as usize {
38        quote! { u16 }
39    } else {
40        quote! { u32 }
41    }
42}
43
44fn is_struct(s: &Structure) -> bool {
45    // a single variant with no prefix is 'struct'
46    match &s.variants()[..] {
47        [v] if v.prefix.is_none() => true,
48        _ => false,
49    }
50}
51
52fn derive_max_size(s: &Structure) -> TokenStream {
53    let max_size = s.variants().iter().fold(quote!(0), |acc, vi| {
54        let variant_size = vi.bindings().iter().fold(quote!(0), |acc, bi| {
55            // compute size of each variant by summing the sizes of its bindings
56            let ty = &bi.ast().ty;
57            quote!(#acc + <#ty>::max_size())
58        });
59
60        // find the maximum of each variant
61        quote! {
62            max(#acc, #variant_size)
63        }
64    });
65
66    let body = if is_struct(s) {
67        max_size
68    } else {
69        let discriminant_size_type = get_discriminant_size_type(s.variants().len());
70        quote! {
71            #discriminant_size_type ::max_size() + #max_size
72        }
73    };
74
75    quote! {
76        #[inline(always)]
77        fn max_size() -> usize {
78            use std::cmp::max;
79            #body
80        }
81    }
82}
83
84fn derive_peek_from_for_enum(s: &mut Structure) -> TokenStream {
85    assert!(!is_struct(s));
86    s.bind_with(|_| BindStyle::Move);
87
88    let num_variants = s.variants().len();
89    let discriminant_size_type = get_discriminant_size_type(num_variants);
90    let body = s
91        .variants()
92        .iter()
93        .enumerate()
94        .fold(quote!(), |acc, (i, vi)| {
95            let bindings = vi
96                .bindings()
97                .iter()
98                .map(|bi| quote!(#bi))
99                .collect::<Vec<_>>();
100
101            let variant_pat = Index::from(i);
102            let poke_exprs = bindings.iter().fold(quote!(), |acc, bi| {
103                quote! {
104                    #acc
105                    let (#bi, bytes) = peek_poke::peek_from_default(bytes);
106                }
107            });
108            let construct = vi.construct(|_, i| {
109                let bi = &bindings[i];
110                quote!(#bi)
111            });
112
113            quote! {
114                #acc
115                #variant_pat => {
116                    #poke_exprs
117                    *output = #construct;
118                    bytes
119                }
120            }
121        });
122
123    let type_name = s.ast().ident.to_string();
124    let max_tag_value = num_variants - 1;
125
126    quote! {
127        #[inline(always)]
128        unsafe fn peek_from(bytes: *const u8, output: *mut Self) -> *const u8 {
129            let (variant, bytes) = peek_poke::peek_from_default::<#discriminant_size_type>(bytes);
130            match variant {
131                #body
132                out_of_range_tag => {
133                    panic!("WRDL: memory corruption detected while parsing {} - enum tag should be <= {}, but was {}",
134                        #type_name, #max_tag_value, out_of_range_tag);
135                }
136            }
137        }
138    }
139}
140
141fn derive_peek_from_for_struct(s: &mut Structure) -> TokenStream {
142    assert!(is_struct(&s));
143
144    s.variants_mut()[0].bind_with(|_| BindStyle::RefMut);
145    let pat = s.variants()[0].pat();
146    let peek_exprs = s.variants()[0].bindings().iter().fold(quote!(), |acc, bi| {
147        let ty = &bi.ast().ty;
148        quote! {
149            #acc
150            let bytes = <#ty>::peek_from(bytes, #bi);
151        }
152    });
153
154    let body = quote! {
155        #pat => {
156            #peek_exprs
157            bytes
158        }
159    };
160
161    quote! {
162        #[inline(always)]
163        unsafe fn peek_from(bytes: *const u8, output: *mut Self) -> *const u8 {
164            match &mut (*output) {
165                #body
166            }
167        }
168    }
169}
170
171fn derive_poke_into(s: &Structure) -> TokenStream {
172    let is_struct = is_struct(&s);
173    let discriminant_size_type = get_discriminant_size_type(s.variants().len());
174    let body = s
175        .variants()
176        .iter()
177        .enumerate()
178        .fold(quote!(), |acc, (i, vi)| {
179            let init = if !is_struct {
180                let index = Index::from(i);
181                quote! {
182                    let bytes = #discriminant_size_type::poke_into(&#index, bytes);
183                }
184            } else {
185                quote!()
186            };
187            let variant_pat = vi.pat();
188            let poke_exprs = vi.bindings().iter().fold(init, |acc, bi| {
189                quote! {
190                    #acc
191                    let bytes = #bi.poke_into(bytes);
192                }
193            });
194
195            quote! {
196                #acc
197                #variant_pat => {
198                    #poke_exprs
199                    bytes
200                }
201            }
202        });
203
204    quote! {
205        #[inline(always)]
206        unsafe fn poke_into(&self, bytes: *mut u8) -> *mut u8 {
207            match &*self {
208                #body
209            }
210        }
211    }
212}
213
214fn peek_poke_derive(mut s: Structure) -> TokenStream {
215    s.binding_name(|_, i| Ident::new(&format!("__self_{}", i), Span::call_site()));
216
217    let max_size_fn = derive_max_size(&s);
218    let poke_into_fn = derive_poke_into(&s);
219    let peek_from_fn = if is_struct(&s) {
220        derive_peek_from_for_struct(&mut s)
221    } else {
222        derive_peek_from_for_enum(&mut s)
223    };
224
225    let poke_impl = s.gen_impl(quote! {
226        extern crate peek_poke;
227
228        gen unsafe impl peek_poke::Poke for @Self {
229            #max_size_fn
230            #poke_into_fn
231        }
232    });
233
234    // To implement `fn peek_from` we require that types implement `Default`
235    // trait to create temporary values. This code does the addition all
236    // manually until https://github.com/mystor/synstructure/issues/24 is fixed.
237    let default_trait = syn::parse_str::<TraitBound>("::std::default::Default").unwrap();
238    let peek_trait = syn::parse_str::<TraitBound>("peek_poke::Peek").unwrap();
239
240    let ast = s.ast();
241    let name = &ast.ident;
242    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
243    let mut where_clause = where_clause.cloned();
244    s.add_trait_bounds(&default_trait, &mut where_clause, AddBounds::Generics);
245    s.add_trait_bounds(&peek_trait, &mut where_clause, AddBounds::Generics);
246
247    let dummy_const: Ident = sanitize_ident(&format!("_DERIVE_peek_poke_Peek_FOR_{}", name));
248
249    let peek_impl = quote! {
250        #[allow(non_upper_case_globals)]
251        const #dummy_const: () = {
252            extern crate peek_poke;
253
254            impl #impl_generics peek_poke::Peek for #name #ty_generics #where_clause {
255                #peek_from_fn
256            }
257        };
258    };
259
260    quote! {
261        #poke_impl
262        #peek_impl
263    }
264}
265
// Declares the `PeekPoke` derive macro, implemented by `peek_poke_derive`
// (via synstructure's `decl_derive!`).
decl_derive!([PeekPoke] => peek_poke_derive);