libp2p_core_derive/
lib.rs

// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
21#![recursion_limit = "256"]
22
23
24
25use quote::quote;
26use proc_macro::TokenStream;
27use syn::{parse_macro_input, DeriveInput, Data, DataStruct, Ident};
28
29/// Generates a delegating `NetworkBehaviour` implementation for the struct this is used for. See
30/// the trait documentation for better description.
31#[proc_macro_derive(NetworkBehaviour, attributes(behaviour))]
32pub fn hello_macro_derive(input: TokenStream) -> TokenStream {
33    let ast = parse_macro_input!(input as DeriveInput);
34    build(&ast)
35}
36
37/// The actual implementation.
38fn build(ast: &DeriveInput) -> TokenStream {
39    match ast.data {
40        Data::Struct(ref s) => build_struct(ast, s),
41        Data::Enum(_) => unimplemented!("Deriving NetworkBehaviour is not implemented for enums"),
42        Data::Union(_) => unimplemented!("Deriving NetworkBehaviour is not implemented for unions"),
43    }
44}
45
46/// The version for structs
47fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream {
48    let name = &ast.ident;
49    let (_, ty_generics, where_clause) = ast.generics.split_for_impl();
50    let multiaddr = quote!{::libp2p::core::Multiaddr};
51    let trait_to_impl = quote!{::libp2p::swarm::NetworkBehaviour};
52    let net_behv_event_proc = quote!{::libp2p::swarm::NetworkBehaviourEventProcess};
53    let either_ident = quote!{::libp2p::core::either::EitherOutput};
54    let network_behaviour_action = quote!{::libp2p::swarm::NetworkBehaviourAction};
55    let into_protocols_handler = quote!{::libp2p::swarm::IntoProtocolsHandler};
56    let protocols_handler = quote!{::libp2p::swarm::ProtocolsHandler};
57    let into_proto_select_ident = quote!{::libp2p::swarm::IntoProtocolsHandlerSelect};
58    let peer_id = quote!{::libp2p::core::PeerId};
59    let connection_id = quote!{::libp2p::core::connection::ConnectionId};
60    let connected_point = quote!{::libp2p::core::ConnectedPoint};
61    let listener_id = quote!{::libp2p::core::connection::ListenerId};
62
63    let poll_parameters = quote!{::libp2p::swarm::PollParameters};
64
65    // Build the generics.
66    let impl_generics = {
67        let tp = ast.generics.type_params();
68        let lf = ast.generics.lifetimes();
69        let cst = ast.generics.const_params();
70        quote!{<#(#lf,)* #(#tp,)* #(#cst,)*>}
71    };
72
73    // Whether or not we require the `NetworkBehaviourEventProcess` trait to be implemented.
74    let event_process = {
75        let mut event_process = true; // Default to true for backwards compatibility
76
77        for meta_items in ast.attrs.iter().filter_map(get_meta_items) {
78            for meta_item in meta_items {
79                match meta_item {
80                    syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) if m.path.is_ident("event_process") => {
81                        if let syn::Lit::Bool(ref b) = m.lit {
82                            event_process = b.value
83                        }
84                    }
85                    _ => ()
86                }
87            }
88        }
89
90        event_process
91    };
92
93    // The final out event.
94    // If we find a `#[behaviour(out_event = "Foo")]` attribute on the struct, we set `Foo` as
95    // the out event. Otherwise we use `()`.
96    let out_event = {
97        let mut out = quote!{()};
98        for meta_items in ast.attrs.iter().filter_map(get_meta_items) {
99            for meta_item in meta_items {
100                match meta_item {
101                    syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) if m.path.is_ident("out_event") => {
102                        if let syn::Lit::Str(ref s) = m.lit {
103                            let ident: syn::Type = syn::parse_str(&s.value()).unwrap();
104                            out = quote!{#ident};
105                        }
106                    }
107                    _ => ()
108                }
109            }
110        }
111        out
112    };
113
114    // Build the `where ...` clause of the trait implementation.
115    let where_clause = {
116        let additional = data_struct.fields.iter()
117            .filter(|x| !is_ignored(x))
118            .flat_map(|field| {
119                let ty = &field.ty;
120                vec![
121                    quote!{#ty: #trait_to_impl},
122                    if event_process {
123                        quote!{Self: #net_behv_event_proc<<#ty as #trait_to_impl>::OutEvent>}
124                    } else {
125                        quote!{#out_event: From< <#ty as #trait_to_impl>::OutEvent >}
126                    }
127                ]
128            })
129            .collect::<Vec<_>>();
130
131        if let Some(where_clause) = where_clause {
132            if where_clause.predicates.trailing_punct() {
133                Some(quote!{#where_clause #(#additional),*})
134            } else {
135                Some(quote!{#where_clause, #(#additional),*})
136            }
137        } else {
138            Some(quote!{where #(#additional),*})
139        }
140    };
141
142    // Build the list of statements to put in the body of `addresses_of_peer()`.
143    let addresses_of_peer_stmts = {
144        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
145            if is_ignored(&field) {
146                return None;
147            }
148
149            Some(match field.ident {
150                Some(ref i) => quote!{ out.extend(self.#i.addresses_of_peer(peer_id)); },
151                None => quote!{ out.extend(self.#field_n.addresses_of_peer(peer_id)); },
152            })
153        })
154    };
155
156    // Build the list of statements to put in the body of `inject_connected()`.
157    let inject_connected_stmts = {
158        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
159            if is_ignored(&field) {
160                return None;
161            }
162            Some(match field.ident {
163                Some(ref i) => quote!{ self.#i.inject_connected(peer_id); },
164                None => quote!{ self.#field_n.inject_connected(peer_id); },
165            })
166        })
167    };
168
169    // Build the list of statements to put in the body of `inject_disconnected()`.
170    let inject_disconnected_stmts = {
171        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
172            if is_ignored(&field) {
173                return None;
174            }
175            Some(match field.ident {
176                Some(ref i) => quote!{ self.#i.inject_disconnected(peer_id); },
177                None => quote!{ self.#field_n.inject_disconnected(peer_id); },
178            })
179        })
180    };
181
182    // Build the list of statements to put in the body of `inject_connection_established()`.
183    let inject_connection_established_stmts = {
184        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
185            if is_ignored(&field) {
186                return None;
187            }
188            Some(match field.ident {
189                Some(ref i) => quote!{ self.#i.inject_connection_established(peer_id, connection_id, endpoint); },
190                None => quote!{ self.#field_n.inject_connection_established(peer_id, connection_id, endpoint); },
191            })
192        })
193    };
194
195    // Build the list of statements to put in the body of `inject_address_change()`.
196    let inject_address_change_stmts = {
197        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
198            if is_ignored(&field) {
199                return None;
200            }
201            Some(match field.ident {
202                Some(ref i) => quote!{ self.#i.inject_address_change(peer_id, connection_id, old, new); },
203                None => quote!{ self.#field_n.inject_address_change(peer_id, connection_id, old, new); },
204            })
205        })
206    };
207
208    // Build the list of statements to put in the body of `inject_connection_closed()`.
209    let inject_connection_closed_stmts = {
210        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
211            if is_ignored(&field) {
212                return None;
213            }
214            Some(match field.ident {
215                Some(ref i) => quote!{ self.#i.inject_connection_closed(peer_id, connection_id, endpoint); },
216                None => quote!{ self.#field_n.inject_connection_closed(peer_id, connection_id, endpoint); },
217            })
218        })
219    };
220
221    // Build the list of statements to put in the body of `inject_addr_reach_failure()`.
222    let inject_addr_reach_failure_stmts = {
223        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
224            if is_ignored(&field) {
225                return None;
226            }
227
228            Some(match field.ident {
229                Some(ref i) => quote!{ self.#i.inject_addr_reach_failure(peer_id, addr, error); },
230                None => quote!{ self.#field_n.inject_addr_reach_failure(peer_id, addr, error); },
231            })
232        })
233    };
234
235    // Build the list of statements to put in the body of `inject_dial_failure()`.
236    let inject_dial_failure_stmts = {
237        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
238            if is_ignored(&field) {
239                return None;
240            }
241
242            Some(match field.ident {
243                Some(ref i) => quote!{ self.#i.inject_dial_failure(peer_id); },
244                None => quote!{ self.#field_n.inject_dial_failure(peer_id); },
245            })
246        })
247    };
248
249    // Build the list of statements to put in the body of `inject_new_listen_addr()`.
250    let inject_new_listen_addr_stmts = {
251        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
252            if is_ignored(&field) {
253                return None;
254            }
255
256            Some(match field.ident {
257                Some(ref i) => quote!{ self.#i.inject_new_listen_addr(addr); },
258                None => quote!{ self.#field_n.inject_new_listen_addr(addr); },
259            })
260        })
261    };
262
263    // Build the list of statements to put in the body of `inject_expired_listen_addr()`.
264    let inject_expired_listen_addr_stmts = {
265        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
266            if is_ignored(&field) {
267                return None;
268            }
269
270            Some(match field.ident {
271                Some(ref i) => quote!{ self.#i.inject_expired_listen_addr(addr); },
272                None => quote!{ self.#field_n.inject_expired_listen_addr(addr); },
273            })
274        })
275    };
276
277    // Build the list of statements to put in the body of `inject_new_external_addr()`.
278    let inject_new_external_addr_stmts = {
279        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
280            if is_ignored(&field) {
281                return None;
282            }
283
284            Some(match field.ident {
285                Some(ref i) => quote!{ self.#i.inject_new_external_addr(addr); },
286                None => quote!{ self.#field_n.inject_new_external_addr(addr); },
287            })
288        })
289    };
290
291    // Build the list of statements to put in the body of `inject_listener_error()`.
292    let inject_listener_error_stmts = {
293        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
294            if is_ignored(&field) {
295                return None
296            }
297            Some(match field.ident {
298                Some(ref i) => quote!(self.#i.inject_listener_error(id, err);),
299                None => quote!(self.#field_n.inject_listener_error(id, err);)
300            })
301        })
302    };
303
304    // Build the list of statements to put in the body of `inject_listener_closed()`.
305    let inject_listener_closed_stmts = {
306        data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| {
307            if is_ignored(&field) {
308                return None
309            }
310            Some(match field.ident {
311                Some(ref i) => quote!(self.#i.inject_listener_closed(id, reason);),
312                None => quote!(self.#field_n.inject_listener_closed(id, reason);)
313            })
314        })
315    };
316
317    // Build the list of variants to put in the body of `inject_event()`.
318    //
319    // The event type is a construction of nested `#either_ident`s of the events of the children.
320    // We call `inject_event` on the corresponding child.
321    let inject_node_event_stmts = data_struct.fields.iter().enumerate().filter(|f| !is_ignored(&f.1)).enumerate().map(|(enum_n, (field_n, field))| {
322        let mut elem = if enum_n != 0 {
323            quote!{ #either_ident::Second(ev) }
324        } else {
325            quote!{ ev }
326        };
327
328        for _ in 0 .. data_struct.fields.iter().filter(|f| !is_ignored(f)).count() - 1 - enum_n {
329            elem = quote!{ #either_ident::First(#elem) };
330        }
331
332        Some(match field.ident {
333            Some(ref i) => quote!{ #elem => #trait_to_impl::inject_event(&mut self.#i, peer_id, connection_id, ev) },
334            None => quote!{ #elem => #trait_to_impl::inject_event(&mut self.#field_n, peer_id, connection_id, ev) },
335        })
336    });
337
338    // The `ProtocolsHandler` associated type.
339    let protocols_handler_ty = {
340        let mut ph_ty = None;
341        for field in data_struct.fields.iter() {
342            if is_ignored(&field) {
343                continue;
344            }
345            let ty = &field.ty;
346            let field_info = quote!{ <#ty as #trait_to_impl>::ProtocolsHandler };
347            match ph_ty {
348                Some(ev) => ph_ty = Some(quote!{ #into_proto_select_ident<#ev, #field_info> }),
349                ref mut ev @ None => *ev = Some(field_info),
350            }
351        }
352        ph_ty.unwrap_or(quote!{()})     // TODO: `!` instead
353    };
354
355    // The content of `new_handler()`.
356    // Example output: `self.field1.select(self.field2.select(self.field3))`.
357    let new_handler = {
358        let mut out_handler = None;
359
360        for (field_n, field) in data_struct.fields.iter().enumerate() {
361            if is_ignored(&field) {
362                continue;
363            }
364
365            let field_name = match field.ident {
366                Some(ref i) => quote!{ self.#i },
367                None => quote!{ self.#field_n },
368            };
369
370            let builder = quote! {
371                #field_name.new_handler()
372            };
373
374            match out_handler {
375                Some(h) => out_handler = Some(quote!{ #into_protocols_handler::select(#h, #builder) }),
376                ref mut h @ None => *h = Some(builder),
377            }
378        }
379
380        out_handler.unwrap_or(quote!{()})     // TODO: incorrect
381    };
382
383    // The method to use to poll.
384    // If we find a `#[behaviour(poll_method = "poll")]` attribute on the struct, we call
385    // `self.poll()` at the end of the polling.
386    let poll_method = {
387        let mut poll_method = quote!{std::task::Poll::Pending};
388        for meta_items in ast.attrs.iter().filter_map(get_meta_items) {
389            for meta_item in meta_items {
390                match meta_item {
391                    syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) if m.path.is_ident("poll_method") => {
392                        if let syn::Lit::Str(ref s) = m.lit {
393                            let ident: Ident = syn::parse_str(&s.value()).unwrap();
394                            poll_method = quote!{#name::#ident(self, cx, poll_params)};
395                        }
396                    }
397                    _ => ()
398                }
399            }
400        }
401        poll_method
402    };
403
404    // List of statements to put in `poll()`.
405    //
406    // We poll each child one by one and wrap around the output.
407    let poll_stmts = data_struct.fields.iter().enumerate().filter(|f| !is_ignored(&f.1)).enumerate().map(|(enum_n, (field_n, field))| {
408        let field_name = match field.ident {
409            Some(ref i) => quote!{ self.#i },
410            None => quote!{ self.#field_n },
411        };
412
413        let mut wrapped_event = if enum_n != 0 {
414            quote!{ #either_ident::Second(event) }
415        } else {
416            quote!{ event }
417        };
418        for _ in 0 .. data_struct.fields.iter().filter(|f| !is_ignored(f)).count() - 1 - enum_n {
419            wrapped_event = quote!{ #either_ident::First(#wrapped_event) };
420        }
421
422        let generate_event_match_arm = if event_process {
423            quote! {
424                std::task::Poll::Ready(#network_behaviour_action::GenerateEvent(event)) => {
425                    #net_behv_event_proc::inject_event(self, event)
426                }
427            }
428        } else {
429            quote! {
430                std::task::Poll::Ready(#network_behaviour_action::GenerateEvent(event)) => {
431                    return std::task::Poll::Ready(#network_behaviour_action::GenerateEvent(event.into()))
432                }
433            }
434        };
435
436        Some(quote!{
437            loop {
438                match #trait_to_impl::poll(&mut #field_name, cx, poll_params) {
439                    #generate_event_match_arm
440                    std::task::Poll::Ready(#network_behaviour_action::DialAddress { address }) => {
441                        return std::task::Poll::Ready(#network_behaviour_action::DialAddress { address });
442                    }
443                    std::task::Poll::Ready(#network_behaviour_action::DialPeer { peer_id, condition }) => {
444                        return std::task::Poll::Ready(#network_behaviour_action::DialPeer { peer_id, condition });
445                    }
446                    std::task::Poll::Ready(#network_behaviour_action::NotifyHandler { peer_id, handler, event }) => {
447                        return std::task::Poll::Ready(#network_behaviour_action::NotifyHandler {
448                            peer_id,
449                            handler,
450                            event: #wrapped_event,
451                        });
452                    }
453                    std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address, score }) => {
454                        return std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address, score });
455                    }
456                    std::task::Poll::Pending => break,
457                }
458            }
459        })
460    });
461
462    // Now the magic happens.
463    let final_quote = quote!{
464        impl #impl_generics #trait_to_impl for #name #ty_generics
465        #where_clause
466        {
467            type ProtocolsHandler = #protocols_handler_ty;
468            type OutEvent = #out_event;
469
470            fn new_handler(&mut self) -> Self::ProtocolsHandler {
471                use #into_protocols_handler;
472                #new_handler
473            }
474
475            fn addresses_of_peer(&mut self, peer_id: &#peer_id) -> Vec<#multiaddr> {
476                let mut out = Vec::new();
477                #(#addresses_of_peer_stmts);*
478                out
479            }
480
481            fn inject_connected(&mut self, peer_id: &#peer_id) {
482                #(#inject_connected_stmts);*
483            }
484
485            fn inject_disconnected(&mut self, peer_id: &#peer_id) {
486                #(#inject_disconnected_stmts);*
487            }
488
489            fn inject_connection_established(&mut self, peer_id: &#peer_id, connection_id: &#connection_id, endpoint: &#connected_point) {
490                #(#inject_connection_established_stmts);*
491            }
492
493            fn inject_address_change(&mut self, peer_id: &#peer_id, connection_id: &#connection_id, old: &#connected_point, new: &#connected_point) {
494                #(#inject_address_change_stmts);*
495            }
496
497            fn inject_connection_closed(&mut self, peer_id: &#peer_id, connection_id: &#connection_id, endpoint: &#connected_point) {
498                #(#inject_connection_closed_stmts);*
499            }
500
501            fn inject_addr_reach_failure(&mut self, peer_id: Option<&#peer_id>, addr: &#multiaddr, error: &dyn std::error::Error) {
502                #(#inject_addr_reach_failure_stmts);*
503            }
504
505            fn inject_dial_failure(&mut self, peer_id: &#peer_id) {
506                #(#inject_dial_failure_stmts);*
507            }
508
509            fn inject_new_listen_addr(&mut self, addr: &#multiaddr) {
510                #(#inject_new_listen_addr_stmts);*
511            }
512
513            fn inject_expired_listen_addr(&mut self, addr: &#multiaddr) {
514                #(#inject_expired_listen_addr_stmts);*
515            }
516
517            fn inject_new_external_addr(&mut self, addr: &#multiaddr) {
518                #(#inject_new_external_addr_stmts);*
519            }
520
521            fn inject_listener_error(&mut self, id: #listener_id, err: &(dyn std::error::Error + 'static)) {
522                #(#inject_listener_error_stmts);*
523            }
524
525            fn inject_listener_closed(&mut self, id: #listener_id, reason: std::result::Result<(), &std::io::Error>) {
526                #(#inject_listener_closed_stmts);*
527            }
528
529            fn inject_event(
530                &mut self,
531                peer_id: #peer_id,
532                connection_id: #connection_id,
533                event: <<Self::ProtocolsHandler as #into_protocols_handler>::Handler as #protocols_handler>::OutEvent
534            ) {
535                match event {
536                    #(#inject_node_event_stmts),*
537                }
538            }
539
540            fn poll(&mut self, cx: &mut std::task::Context, poll_params: &mut impl #poll_parameters) -> std::task::Poll<#network_behaviour_action<<<Self::ProtocolsHandler as #into_protocols_handler>::Handler as #protocols_handler>::InEvent, Self::OutEvent>> {
541                use libp2p::futures::prelude::*;
542                #(#poll_stmts)*
543                let f: std::task::Poll<#network_behaviour_action<<<Self::ProtocolsHandler as #into_protocols_handler>::Handler as #protocols_handler>::InEvent, Self::OutEvent>> = #poll_method;
544                f
545            }
546        }
547    };
548
549    final_quote.into()
550}
551
552fn get_meta_items(attr: &syn::Attribute) -> Option<Vec<syn::NestedMeta>> {
553    if attr.path.segments.len() == 1 && attr.path.segments[0].ident == "behaviour" {
554        match attr.parse_meta() {
555            Ok(syn::Meta::List(ref meta)) => Some(meta.nested.iter().cloned().collect()),
556            Ok(_) => None,
557            Err(e) => {
558                eprintln!("error parsing attribute metadata: {}", e);
559                None
560            }
561        }
562    } else {
563        None
564    }
565}
566
567/// Returns true if a field is marked as ignored by the user.
568fn is_ignored(field: &syn::Field) -> bool {
569    for meta_items in field.attrs.iter().filter_map(get_meta_items) {
570        for meta_item in meta_items {
571            match meta_item {
572                syn::NestedMeta::Meta(syn::Meta::Path(ref m)) if m.is_ident("ignore") => {
573                    return true;
574                }
575                _ => ()
576            }
577        }
578    }
579
580    false
581}