deref_derive/
lib.rs

1//! A tiny crate that provides `#[derive(Deref)]` and `#[derive(DerefMut)]`.
2//!
//! While it is unidiomatic to implement [`Deref`](std::ops::Deref) for wrapper types,
//! it can be useful and sees widespread use in the community. Therefore, this crate
5//! provides a macro to derive [`Deref`](std::ops::Deref) and [`DerefMut`](std::ops::DerefMut)
6//! for you to help reduce boilerplate.
7
8/// Used to derive [`Deref`](std::ops::Deref) for a struct.
9///
10/// # Example
11/// If have a struct with only one field, you can derive `Deref` for it.
12/// ```rust
13/// # use deref_derive::Deref;
14/// #[derive(Default, Deref)]
15/// struct Foo {
16///     field: String,
17/// }
18///
19/// assert_eq!(Foo::default().len(), 0);
20/// ```
21/// If you have a struct with multiple fields, you will have to use the `deref` attribute.
22/// ```rust
23/// # use deref_derive::Deref;
24/// #[derive(Default, Deref)]
25/// struct Foo {
26///    #[deref]
27///    field: u32,
28///    other_field: String,
29/// }
30///
31/// assert_eq!(*Foo::default(), 0);
32/// ```
33/// Tuple structs are also supported.
34/// ```rust
35/// # use deref_derive::{Deref, DerefMut};
36/// #[derive(Default, Deref, DerefMut)]
37/// struct Foo(u32, #[deref] String);
38///
39/// let mut foo = Foo::default();
40/// *foo = "bar".to_string();
41/// foo.push('!');
42///
43/// assert_eq!(*foo, "bar!");
44#[proc_macro_derive(Deref, attributes(deref))]
45pub fn derive_deref(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
46    let input = syn::parse_macro_input!(input as syn::DeriveInput);
47    let ident = input.ident;
48
49    let target = DerefTarget::get(&input.data);
50    let target_ty = target.ty;
51    let target_field = target.field;
52
53    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
54
55    let expanded = quote::quote! {
56        #[automatically_derived]
57        impl #impl_generics ::std::ops::Deref for #ident #ty_generics #where_clause {
58            type Target = #target_ty;
59
60            #[inline(always)]
61            fn deref(&self) -> &Self::Target {
62                &self.#target_field
63            }
64        }
65    };
66
67    proc_macro::TokenStream::from(expanded)
68}
69
70/// Used to derive [`DerefMut`](std::ops::DerefMut) for a struct.
71///
72/// For examples, see [`Deref`].
73#[proc_macro_derive(DerefMut, attributes(deref))]
74pub fn derive_deref_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
75    let input = syn::parse_macro_input!(input as syn::DeriveInput);
76    let ident = input.ident;
77
78    let target = DerefTarget::get(&input.data);
79    let target_field = target.field;
80
81    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
82
83    let expanded = quote::quote! {
84        #[automatically_derived]
85        impl #impl_generics ::std::ops::DerefMut for #ident #ty_generics #where_clause {
86            #[inline(always)]
87            fn deref_mut(&mut self) -> &mut Self::Target {
88                &mut self.#target_field
89            }
90        }
91    };
92
93    proc_macro::TokenStream::from(expanded)
94}
95
96struct DerefTarget {
97    ty: syn::Type,
98    field: proc_macro2::TokenStream,
99    has_attr: bool,
100}
101
102impl DerefTarget {
103    const ATTR_NAME: &'static str = "deref";
104
105    fn has_attr(attrs: &[syn::Attribute]) -> bool {
106        attrs.iter().any(|attr| attr.path.is_ident(Self::ATTR_NAME))
107    }
108
109    fn get_target(mut targets: impl ExactSizeIterator<Item = Self>) -> Self {
110        if targets.len() == 1 {
111            targets.next().unwrap()
112        } else {
113            let targets = targets.filter(|target| target.has_attr).collect::<Vec<_>>();
114
115            if targets.len() == 1 {
116                targets.into_iter().next().unwrap()
117            } else {
118                panic!("expected exactly one field with #[deref] attribute");
119            }
120        }
121    }
122
123    fn get(data: &syn::Data) -> Self {
124        match data {
125            syn::Data::Struct(data) => match data.fields {
126                syn::Fields::Named(ref fields) => {
127                    let fields = fields.named.iter().map(|f| {
128                        let ty = f.ty.clone();
129                        let field = f.ident.clone().unwrap();
130                        let has_attr = Self::has_attr(&f.attrs);
131
132                        Self {
133                            ty,
134                            field: quote::quote!(#field),
135                            has_attr,
136                        }
137                    });
138
139                    Self::get_target(fields)
140                }
141                syn::Fields::Unnamed(ref fields) => {
142                    let fields = fields.unnamed.iter().enumerate().map(|(i, f)| {
143                        let ty = f.ty.clone();
144                        let field = syn::Index::from(i);
145                        let has_attr = Self::has_attr(&f.attrs);
146
147                        Self {
148                            ty,
149                            field: quote::quote!(#field),
150                            has_attr,
151                        }
152                    });
153
154                    Self::get_target(fields)
155                }
156                syn::Fields::Unit => {
157                    panic!("cannot be derived for unit structs")
158                }
159            },
160            _ => unimplemented!("can only be derived for structs"),
161        }
162    }
163}