shred_derive/
lib.rs

1#![recursion_limit = "256"]
2
3extern crate proc_macro;
4extern crate proc_macro2;
5#[macro_use]
6extern crate quote;
7#[macro_use]
8extern crate syn;
9
10use proc_macro::TokenStream;
11use syn::{
12    punctuated::Punctuated, token::Comma, Data, DataStruct, DeriveInput, Field, Fields,
13    FieldsNamed, FieldsUnnamed, Ident, Lifetime, Type, WhereClause, WherePredicate,
14};
15
16/// Used to `#[derive]` the trait `SystemData`.
17///
18/// You need to have the following items included in the current scope:
19///
20/// * `SystemData`
21/// * `World`
22/// * `ResourceId`
23///
24/// This macro can either be used directly via `shred-derive`, or by enabling
25/// the `shred-derive` feature for another crate (e.g. `shred` or `specs`, which
26/// both reexport the macro).
27#[proc_macro_derive(SystemData)]
28pub fn system_data(input: TokenStream) -> TokenStream {
29    let ast = syn::parse(input).unwrap();
30
31    let gen = impl_system_data(&ast);
32
33    gen.into()
34}
35
36fn impl_system_data(ast: &DeriveInput) -> proc_macro2::TokenStream {
37    let name = &ast.ident;
38    let mut generics = ast.generics.clone();
39
40    let (fetch_return, tys) = gen_from_body(&ast.data, name);
41    let tys = &tys;
42    // Assumes that the first lifetime is the fetch lt
43    let def_fetch_lt = ast
44        .generics
45        .lifetimes()
46        .next()
47        .expect("There has to be at least one lifetime");
48    let impl_fetch_lt = &def_fetch_lt.lifetime;
49
50    {
51        let where_clause = generics.make_where_clause();
52        constrain_system_data_types(where_clause, impl_fetch_lt, tys);
53    }
54    // Reads and writes are taken from the same types,
55    // but need to be cloned before.
56
57    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
58
59    quote! {
60        impl #impl_generics
61            shred::SystemData< #impl_fetch_lt >
62            for #name #ty_generics #where_clause
63        {
64            fn setup(world: &mut shred::World) {
65                #(
66                    <#tys as shred::SystemData> :: setup(world);
67                )*
68            }
69
70            fn fetch(world: & #impl_fetch_lt shred::World) -> Self {
71                #fetch_return
72            }
73
74            fn reads() -> Vec<shred::ResourceId> {
75                let mut r = Vec::new();
76
77                #( {
78                        let mut reads = <#tys as shred::SystemData> :: reads();
79                        r.append(&mut reads);
80                    } )*
81
82                r
83            }
84
85            fn writes() -> Vec<shred::ResourceId> {
86                let mut r = Vec::new();
87
88                #( {
89                        let mut writes = <#tys as shred::SystemData> :: writes();
90                        r.append(&mut writes);
91                    } )*
92
93                r
94            }
95        }
96    }
97}
98
99fn collect_field_types(fields: &Punctuated<Field, Comma>) -> Vec<Type> {
100    fields.iter().map(|x| x.ty.clone()).collect()
101}
102
103fn gen_identifiers(fields: &Punctuated<Field, Comma>) -> Vec<Ident> {
104    fields.iter().map(|x| x.ident.clone().unwrap()).collect()
105}
106
107/// Adds a `SystemData<'lt>` bound on each of the system data types.
108fn constrain_system_data_types(clause: &mut WhereClause, fetch_lt: &Lifetime, tys: &[Type]) {
109    for ty in tys.iter() {
110        let where_predicate: WherePredicate = parse_quote!(#ty : shred::SystemData< #fetch_lt >);
111        clause.predicates.push(where_predicate);
112    }
113}
114
115fn gen_from_body(ast: &Data, name: &Ident) -> (proc_macro2::TokenStream, Vec<Type>) {
116    enum DataType {
117        Struct,
118        Tuple,
119    }
120
121    let (body, fields) = match *ast {
122        Data::Struct(DataStruct {
123            fields: Fields::Named(FieldsNamed { named: ref x, .. }),
124            ..
125        }) => (DataType::Struct, x),
126        Data::Struct(DataStruct {
127            fields: Fields::Unnamed(FieldsUnnamed { unnamed: ref x, .. }),
128            ..
129        }) => (DataType::Tuple, x),
130        _ => panic!("Enums are not supported"),
131    };
132
133    let tys = collect_field_types(fields);
134
135    let fetch_return = match body {
136        DataType::Struct => {
137            let identifiers = gen_identifiers(fields);
138
139            quote! {
140                #name {
141                    #( #identifiers: shred::SystemData::fetch(world) ),*
142                }
143            }
144        }
145        DataType::Tuple => {
146            let count = tys.len();
147            let fetch = vec![quote! { shred::SystemData::fetch(world) }; count];
148
149            quote! {
150                #name ( #( #fetch ),* )
151            }
152        }
153    };
154
155    (fetch_return, tys)
156}