1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#![doc(html_root_url = "http://docs.rs/const-default-derive/0.2.0")]

extern crate proc_macro;

use proc_macro::TokenStream;
use proc_macro2::{Ident, Literal, Span, TokenStream as TokenStream2};
use proc_macro_crate::{crate_name, FoundCrate};
use quote::{quote, quote_spanned};
use syn::{spanned::Spanned, Error};

/// Derive macro that generates a [`ConstDefault`] implementation.
///
/// # Note
///
/// Only `struct` inputs are supported at the moment; any other kind of
/// item is rejected with a compile error.
///
/// # Example
///
/// ## Struct
///
/// ```
/// # use const_default::ConstDefault;
/// #[derive(ConstDefault)]
/// # #[derive(Debug, PartialEq)]
/// pub struct Color {
///     r: u8,
///     g: u8,
///     b: u8,
/// }
///
/// assert_eq!(
///     <Color as ConstDefault>::DEFAULT,
///     Color { r: 0, g: 0, b: 0 },
/// )
/// ```
///
/// ## Tuple Struct
///
/// ```
/// # use const_default::ConstDefault;
/// #[derive(ConstDefault)]
/// # #[derive(Debug, PartialEq)]
/// pub struct Vec3(f32, f32, f32);
///
/// assert_eq!(
///     <Vec3 as ConstDefault>::DEFAULT,
///     Vec3(0.0, 0.0, 0.0),
/// )
/// ```
// NOTE(review): the `const_default` helper attribute is registered here, but
// `derive_default` as written does not inspect it - confirm whether it is
// reserved for future use.
#[proc_macro_derive(ConstDefault, attributes(const_default))]
pub fn derive(input: TokenStream) -> TokenStream {
	// A failure is turned into `compile_error!` tokens so that the user
	// sees a regular diagnostic instead of a proc-macro panic.
	derive_default(input.into())
		.unwrap_or_else(|error| error.to_compile_error())
		.into()
}

/// Expands `#[derive(ConstDefault)]` for a struct definition.
///
/// The generated `impl` carries a `ConstDefault` bound for every field type
/// and defines `DEFAULT` by initializing each field with its own `DEFAULT`.
///
/// # Errors
///
/// Returns an error if the root crate cannot be located, if `input` does not
/// parse as a derive input, or if the input is not a `struct`.
fn derive_default(input: TokenStream2) -> Result<TokenStream2, syn::Error> {
	// Resolve the crate path first so that a lookup failure is reported
	// before any parse error.
	let crate_path = query_crate_ident()?;
	let syn::DeriveInput {
		ident: type_name,
		data,
		mut generics,
		..
	} = syn::parse2::<syn::DeriveInput>(input)?;

	let struct_data = if let syn::Data::Struct(struct_data) = data {
		struct_data
	} else {
		return Err(Error::new(
			Span::call_site(),
			"ConstDefault derive only works on struct types",
		));
	};

	let default_expr = generate_default_impl_struct(&crate_path, &struct_data)?;
	generate_default_impl_where_bounds(&crate_path, &struct_data, &mut generics)?;
	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

	Ok(quote! {
		impl #impl_generics #crate_path::ConstDefault for #type_name #ty_generics #where_clause {
			const DEFAULT: Self = #default_expr;
		}
	})
}

/// Queries the dependencies for the derive root crate name and returns the identifier.
///
/// # Note
///
/// This allows to use crate aliases in `Cargo.toml` files of dependencies.
fn query_crate_ident() -> Result<TokenStream2, syn::Error> {
	let query = crate_name("const-default").map_err(|error| {
		Error::new(
			Span::call_site(),
			format!(
				"could not find root crate for ConstDefault derive: {}",
				error
			),
		)
	})?;
	match query {
		FoundCrate::Itself => Ok(quote! { crate }),
		FoundCrate::Name(name) => {
			let ident = Ident::new(&name, Span::call_site());
			Ok(quote! { ::#ident })
		}
	}
}

/// Builds the `Self { .. }` expression used as the value of `DEFAULT`.
///
/// Every field is initialized with `<FieldType as ConstDefault>::DEFAULT`,
/// and the tokens for each field carry that field's span.
///
/// # Note
///
/// Rust accepts brace syntax for every struct: named fields are addressed
/// by their identifier and tuple-struct fields by their position. This lets
/// a single code path cover both kinds of struct.
///
/// For example `struct Foo(u32)` can be written as `Foo { 0: 42 }`.
fn generate_default_impl_struct(
	crate_ident: &TokenStream2,
	data_struct: &syn::DataStruct,
) -> Result<TokenStream2, syn::Error> {
	let mut initializers = Vec::new();
	for (index, field) in data_struct.fields.iter().enumerate() {
		let span = field.span();
		let ty = &field.ty;
		// Named fields use their identifier; unnamed (tuple) fields fall
		// back to their zero-based position as an unsuffixed literal.
		let member = match &field.ident {
			Some(name) => quote_spanned!(span=> #name),
			None => {
				let position = Literal::usize_unsuffixed(index);
				quote_spanned!(span=> #position)
			}
		};
		initializers.push(quote_spanned!(span=>
			#member: <#ty as #crate_ident::ConstDefault>::DEFAULT
		));
	}
	Ok(quote! {
		Self {
			#( #initializers ),*
		}
	})
}

/// Appends a `FieldType: ConstDefault` predicate to the where clause of
/// `generics` for each field of the struct.
///
/// The where clause is created if `generics` does not have one yet.
fn generate_default_impl_where_bounds(
	crate_ident: &TokenStream2,
	data_struct: &syn::DataStruct,
	generics: &mut syn::Generics,
) -> Result<(), syn::Error> {
	let predicates = &mut generics.make_where_clause().predicates;
	predicates.extend(data_struct.fields.iter().map(
		|field| -> syn::WherePredicate {
			let ty = &field.ty;
			syn::parse_quote!(#ty: #crate_ident::ConstDefault)
		},
	));
	Ok(())
}