use std::cmp::Ordering;
use std::ops::RangeInclusive;
use proc_macro2::{Ident, Literal, Span, TokenStream};
use quote::{quote, quote_spanned, ToTokens, TokenStreamExt};
use syn::parse::{self, Parse, ParseStream};
use syn::{braced, parse_macro_input, token::Brace, Token};
use syn::{Attribute, Error, Expr, Visibility};
use syn::{BinOp, ExprBinary, ExprRange, ExprUnary, RangeLimits, UnOp};
use syn::{ExprGroup, ExprParen};
use syn::{ExprLit, Lit};
/// Function-like procedural macro that defines a bounded integer type.
///
/// The input is parsed as a [`BoundedInteger`]; the output is the type definition followed by
/// its inherent `impl` block and trait impls.
#[proc_macro]
pub fn bounded_integer(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let definition = parse_macro_input!(input as BoundedInteger);
    let mut output = TokenStream::new();
    definition.generate_item(&mut output);
    definition.generate_impl(&mut output);
    proc_macro::TokenStream::from(output)
}
/// Parsed form of a `bounded_integer!` invocation.
///
/// The `#[repr(...)]` attribute is pulled out of `attrs` during parsing and stored in `repr`.
// NOTE(review): every field below appears to be read somewhere in this file, so this allow may
// be stale — confirm before removing.
#[allow(dead_code)]
enum BoundedInteger {
    /// `struct Name { from..to }` — generated as a tuple struct wrapping `repr`.
    Struct {
        /// Outer attributes other than `#[repr(...)]`.
        attrs: Vec<Attribute>,
        /// Representation type named in `#[repr(...)]`.
        repr: Ident,
        vis: Visibility,
        struct_token: Token![struct],
        ident: Ident,
        brace_token: Brace,
        /// Inclusive lower and upper bound expressions; `None` means the bound was omitted.
        range: Box<(Option<Expr>, Option<Expr>)>,
    },
    /// `enum Name { from..to }` — generated as a fieldless enum with one variant per value.
    Enum {
        /// Outer attributes other than `#[repr(...)]`.
        attrs: Vec<Attribute>,
        /// Representation type named in `#[repr(...)]`.
        repr: Ident,
        vis: Visibility,
        enum_token: Token![enum],
        ident: Ident,
        brace_token: Brace,
        /// Bounds evaluated at expansion time, converted to an inclusive range.
        range: RangeInclusive<isize>,
        /// Optional trailing `;` after the braces; re-emitted as-is.
        semi_token: Option<Token![;]>,
    },
}
impl BoundedInteger {
fn generate_item(&self, tokens: &mut TokenStream) {
for attr in self.attrs() {
attr.to_tokens(tokens);
}
tokens.extend(quote! {
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
});
match self {
Self::Struct {
repr,
vis,
struct_token,
ident,
brace_token,
..
} => {
vis.to_tokens(tokens);
struct_token.to_tokens(tokens);
ident.to_tokens(tokens);
tokens.extend(quote_spanned!(brace_token.span=> (#repr)));
Token).to_tokens(tokens);
}
Self::Enum {
repr,
vis,
enum_token,
ident,
brace_token,
range,
semi_token,
..
} => {
tokens.extend(quote!(#[repr(#repr)]));
vis.to_tokens(tokens);
enum_token.to_tokens(tokens);
ident.to_tokens(tokens);
let mut inner_tokens = TokenStream::new();
let mut variants = range.clone().map(enum_variant);
if let Some(first_variant) = variants.next() {
first_variant.to_tokens(&mut inner_tokens);
Token).to_tokens(&mut inner_tokens);
inner_tokens.append(Literal::isize_unsuffixed(*range.start()));
}
for variant in variants {
Token).to_tokens(&mut inner_tokens);
variant.to_tokens(&mut inner_tokens);
}
tokens.extend(quote_spanned!(brace_token.span=> { #inner_tokens }));
semi_token.to_tokens(tokens);
}
}
}
fn generate_consts(&self, tokens: &mut TokenStream) {
let vis = self.vis();
let repr = self.repr();
let (min_value, min, max_value, max);
match self {
Self::Struct { range, .. } => {
min_value = match &range.0 {
Some(from) => from.into_token_stream(),
None => quote!(::core::primitive::#repr::MIN),
};
min = quote!(Self(Self::MIN_VALUE));
max_value = match &range.1 {
Some(to) => to.into_token_stream(),
None => quote!(::core::primitive::#repr::MAX),
};
max = quote!(Self(Self::MAX_VALUE));
}
Self::Enum { range, .. } => {
min_value = Literal::isize_unsuffixed(*range.start()).into_token_stream();
max_value = Literal::isize_unsuffixed(*range.end()).into_token_stream();
let min_variant = enum_variant(*range.start());
let max_variant = enum_variant(*range.end());
min = quote!(Self::#min_variant);
max = quote!(Self::#max_variant);
}
}
tokens.extend(quote! {
#vis const MIN_VALUE: #repr = #min_value;
#vis const MAX_VALUE: #repr = #max_value;
#vis const MIN: Self = #min;
#vis const MAX: Self = #max;
#vis const RANGE: #repr = Self::MAX_VALUE - Self::MIN_VALUE + 1;
});
}
fn generate_base(&self, tokens: &mut TokenStream) {
let vis = self.vis();
let repr = self.repr();
let (get_body, new_body, low_bounded, high_bounded) = match self {
Self::Struct { range, .. } => (
quote!(self.0),
quote!(Self(n)),
range.0.is_some(),
range.1.is_some(),
),
Self::Enum { .. } => (
quote!(self as #repr),
quote!(::core::mem::transmute::<#repr, Self>(n)),
true,
true,
),
};
let low_check = if low_bounded {
quote!(n >= Self::MIN_VALUE)
} else {
quote!(true)
};
let high_check = if high_bounded {
quote!(n <= Self::MAX_VALUE)
} else {
quote!(true)
};
tokens.extend(quote! {
#[must_use]
#vis unsafe fn new_unchecked(n: #repr) -> Self {
#new_body
}
#[must_use]
#vis fn in_range(n: #repr) -> bool {
#low_check && #high_check
}
#[must_use]
#vis fn new(n: #repr) -> Option<Self> {
if Self::in_range(n) {
Some(unsafe { Self::new_unchecked(n) })
} else {
None
}
}
#[must_use]
#vis fn new_saturating(n: #repr) -> Self {
if !(#low_check) {
Self::MIN
} else if !(#high_check) {
Self::MAX
} else {
unsafe { Self::new_unchecked(n) }
}
}
#[must_use]
#vis fn new_wrapping(n: #repr) -> Self {
unsafe {
Self::new_unchecked(
(n + (Self::RANGE - (Self::MIN_VALUE.rem_euclid(Self::RANGE)))).rem_euclid(Self::RANGE)
+ Self::MIN_VALUE
)
}
}
#[must_use]
#vis fn get(self) -> #repr {
#get_body
}
});
}
fn generate_operators(&self, tokens: &mut TokenStream) {
let vis = self.vis();
let repr = self.repr();
tokens.extend(quote! {
#[must_use]
#vis fn abs(self) -> Self {
Self::new(self.get().abs()).expect("Absolute value out of range")
}
#[must_use]
#vis fn pow(self, exp: u32) -> Self {
Self::new(self.get().pow(exp)).expect("Value raised to power out of range")
}
#[must_use]
#vis fn div_euclid(self, rhs: #repr) -> Self {
Self::new(self.get().div_euclid(rhs)).expect("Attempted to divide out of range")
}
#[must_use]
#vis fn rem_euclid(self, rhs: #repr) -> Self {
Self::new(self.get().rem_euclid(rhs))
.expect("Attempted to divide with remainder out of range")
}
});
}
fn generate_ops_traits(&self, tokens: &mut TokenStream) {
let ident = self.ident();
let repr = self.repr();
for op in OPERATORS {
let description = op.description;
if op.bin {
binop_trait_variations(
op.trait_name,
op.method,
ident,
repr,
|trait_name, method| {
quote! {
Self::new(<#repr as ::core::ops::#trait_name>::#method(self.get(), rhs))
.expect(concat!("Attempted to ", #description, " out of range"))
}
},
tokens,
);
binop_trait_variations(
op.trait_name,
op.method,
ident,
ident,
|trait_name, method| {
quote! {
<Self as ::core::ops::#trait_name<#repr>>::#method(self, rhs.get())
}
},
tokens,
);
} else {
let trait_name = Ident::new(op.trait_name, Span::call_site());
let method = Ident::new(op.method, Span::call_site());
unop_trait_variations(
&trait_name,
&method,
ident,
"e! {
Self::new(<#repr as ::core::ops::#trait_name>::#method(self.get()))
.expect(concat!("Attempted to ", #description, " out of range"))
},
tokens,
);
}
}
}
fn generate_checked_operators(&self, tokens: &mut TokenStream) {
let vis = self.vis();
for op in CHECKED_OPERATORS {
let mut rhs_ident_storage = None;
let rhs = op.rhs.map(|name| {
if name == "Self" {
self.repr()
} else {
rhs_ident_storage.get_or_insert_with(|| Ident::new(name, Span::call_site()))
}
});
let rhs_type = rhs.map(|ty| quote!(rhs: #ty,));
let rhs_value = rhs.map(|_| quote!(rhs,));
let checked_name = Ident::new(&format!("checked_{}", op.name), Span::call_site());
let checked_comment = format!("Checked {}.", op.description);
tokens.extend(quote! {
#[doc = #checked_comment]
#[must_use]
#vis fn #checked_name(self, #rhs_type) -> Option<Self> {
self.get().#checked_name(#rhs_value).and_then(Self::new)
}
});
if op.saturating {
let saturating_name =
Ident::new(&format!("saturating_{}", op.name), Span::call_site());
let saturating_comment = format!("Saturing {}.", op.description);
tokens.extend(quote! {
#[doc = #saturating_comment]
#[must_use]
#vis fn #saturating_name(self, #rhs_type) -> Self {
Self::new_saturating(self.get().#saturating_name(#rhs_value))
}
});
}
}
}
fn generate_fmt_traits(&self, tokens: &mut TokenStream) {
let ident = self.ident();
let repr = self.repr();
for &fmt_trait in &[
"Binary", "Display", "LowerExp", "LowerHex", "Octal", "UpperExp", "UpperHex",
] {
let fmt_trait = Ident::new(fmt_trait, Span::call_site());
tokens.extend(quote! {
impl ::core::fmt::#fmt_trait for #ident {
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
<#repr as ::core::fmt::#fmt_trait>::fmt(&self.get(), f)
}
}
});
}
}
fn generate_impl(&self, tokens: &mut TokenStream) {
let mut inner_tokens = TokenStream::new();
self.generate_consts(&mut inner_tokens);
self.generate_base(&mut inner_tokens);
self.generate_operators(&mut inner_tokens);
self.generate_checked_operators(&mut inner_tokens);
let ident = self.ident();
tokens.extend(quote!(impl #ident { #inner_tokens }));
self.generate_ops_traits(tokens);
self.generate_fmt_traits(tokens);
}
fn attrs(&self) -> &Vec<Attribute> {
match self {
Self::Struct { attrs, .. } => attrs,
Self::Enum { attrs, .. } => attrs,
}
}
fn repr(&self) -> &Ident {
match self {
Self::Struct { repr, .. } => repr,
Self::Enum { repr, .. } => repr,
}
}
fn vis(&self) -> &Visibility {
match self {
Self::Struct { vis, .. } => vis,
Self::Enum { vis, .. } => vis,
}
}
fn ident(&self) -> &Ident {
match self {
Self::Struct { ident, .. } => ident,
Self::Enum { ident, .. } => ident,
}
}
}
impl Parse for BoundedInteger {
fn parse(input: ParseStream) -> parse::Result<Self> {
let mut attrs = input.call(Attribute::parse_outer)?;
let repr_pos = attrs
.iter()
.position(|attr| attr.path.is_ident("repr"))
.ok_or_else(|| input.error("no repr attribute on bounded integer"))?;
let repr = attrs.remove(repr_pos).parse_args()?;
let vis: Visibility = input.parse()?;
Ok(if input.peek(Token![struct]) {
let struct_token: Token![struct] = input.parse()?;
let repr: Ident = repr;
let range;
#[allow(clippy::eval_order_dependence)]
let this = Self::Struct {
attrs,
repr,
vis,
struct_token,
ident: input.parse()?,
brace_token: braced!(range in input),
range: {
let range: ExprRange = range.parse()?;
let limits = range.limits;
Box::new((
range.from.map(|from| *from),
range.to.map(|to| match limits {
RangeLimits::HalfOpen(_) => Expr::Verbatim(quote!(#to - 1)),
RangeLimits::Closed(_) => *to,
}),
))
},
};
input.parse::<Option<Token![;]>>()?;
this
} else {
let range_tokens;
#[allow(clippy::eval_order_dependence)]
Self::Enum {
attrs,
repr,
vis,
enum_token: input.parse()?,
ident: input.parse()?,
brace_token: braced!(range_tokens in input),
range: {
let range: ExprRange = range_tokens.parse()?;
let (from, to) = range.from.as_deref()
.zip(range.to.as_deref())
.ok_or_else(|| Error::new_spanned(&range, "the bounds of an enum range must be closed"))?;
let (from, to) = (eval_expr(from)?, eval_expr(to)?);
from..=if let RangeLimits::HalfOpen(_) = range.limits {
to - 1
} else {
to
}
},
semi_token: input.parse()?,
}
})
}
}
fn eval_expr(expr: &Expr) -> syn::Result<isize> {
Ok(match expr {
Expr::Lit(ExprLit { lit, .. }) => match lit {
Lit::Int(int) => int.base10_parse()?,
_ => {
return Err(Error::new_spanned(lit, "literal must be integer"));
}
},
Expr::Unary(ExprUnary { op, expr, .. }) => {
let expr = eval_expr(&expr)?;
match op {
UnOp::Not(_) => !expr,
UnOp::Neg(_) => -expr,
_ => {
return Err(Error::new_spanned(op, "unary operator must be ! or -"));
}
}
}
Expr::Binary(ExprBinary {
left, op, right, ..
}) => {
let left = eval_expr(&left)?;
let right = eval_expr(&right)?;
match op {
BinOp::Add(_) => left + right,
BinOp::Sub(_) => left - right,
BinOp::Mul(_) => left * right,
BinOp::Div(_) => left / right,
BinOp::Rem(_) => left % right,
BinOp::BitXor(_) => left ^ right,
BinOp::BitAnd(_) => left & right,
BinOp::BitOr(_) => left | right,
_ => {
return Err(Error::new_spanned(op, "operator not supported in this context"));
}
}
}
Expr::Group(ExprGroup { expr, .. }) | Expr::Paren(ExprParen { expr, .. }) => {
eval_expr(expr)?
}
_ => return Err(Error::new_spanned(expr, "expected simple expression")),
})
}
/// Returns the identifier of the enum variant that represents the value `i`.
///
/// Positive values become `P<i>` (e.g. `4` → `P4`), zero becomes `Z0`, and negative values
/// become `N<|i|>` (e.g. `-3` → `N3`).
fn enum_variant(i: isize) -> Ident {
    let name = match i.cmp(&0) {
        Ordering::Greater => format!("P{}", i),
        Ordering::Equal => String::from("Z0"),
        Ordering::Less => format!("N{}", i.abs()),
    };
    Ident::new(&name, Span::call_site())
}
/// Primitive methods that get `checked_*` (and optionally `saturating_*`) wrappers on every
/// generated type, in generation order.
#[rustfmt::skip]
const CHECKED_OPERATORS: &[CheckedOperator] = &[
    //                   name          description               rhs type      saturating
    CheckedOperator::new("add",        "integer addition",       Some("Self"), true),
    CheckedOperator::new("sub",        "integer subtraction",    Some("Self"), true),
    CheckedOperator::new("mul",        "integer multiplication", Some("Self"), true),
    CheckedOperator::new("div",        "integer division",       Some("Self"), false),
    CheckedOperator::new("div_euclid", "Euclidean division",     Some("Self"), false),
    CheckedOperator::new("rem",        "integer remainder",      Some("Self"), false),
    CheckedOperator::new("rem_euclid", "Euclidean remainder",    Some("Self"), false),
    CheckedOperator::new("neg",        "negation",               None,         true),
    CheckedOperator::new("abs",        "absolute value",         None,         true),
    CheckedOperator::new("pow",        "exponentiation",         Some("u32"),  true),
];

/// One row of [`CHECKED_OPERATORS`].
struct CheckedOperator {
    /// Method suffix, e.g. `add` for `checked_add` / `saturating_add`.
    name: &'static str,
    /// Phrase inserted into the generated doc comments.
    description: &'static str,
    /// Type of the extra argument: `"Self"` means the repr type, `None` means no argument.
    rhs: Option<&'static str>,
    /// Whether a `saturating_<name>` wrapper is generated as well.
    saturating: bool,
}

impl CheckedOperator {
    /// Builds a table row; `const` so it can appear in [`CHECKED_OPERATORS`].
    const fn new(
        name: &'static str,
        description: &'static str,
        rhs: Option<&'static str>,
        saturating: bool,
    ) -> Self {
        CheckedOperator { name, description, rhs, saturating }
    }
}
/// `core::ops` traits implemented for every generated type, in generation order.
#[rustfmt::skip]
const OPERATORS: &[Operator] = &[
    Operator { trait_name: "Add", method: "add", bin: true,  description: "add" },
    Operator { trait_name: "Sub", method: "sub", bin: true,  description: "subtract" },
    Operator { trait_name: "Mul", method: "mul", bin: true,  description: "multiply" },
    Operator { trait_name: "Div", method: "div", bin: true,  description: "divide" },
    Operator { trait_name: "Rem", method: "rem", bin: true,  description: "take remainder" },
    Operator { trait_name: "Neg", method: "neg", bin: false, description: "negate" },
];

/// One row of [`OPERATORS`].
struct Operator {
    /// Trait name inside `core::ops`, e.g. `Add`.
    trait_name: &'static str,
    /// The trait's method name, e.g. `add`.
    method: &'static str,
    /// Verb used in the "Attempted to ... out of range" panic message.
    description: &'static str,
    /// `true` for binary operators, `false` for unary ones.
    bin: bool,
}
/// Emits every impl of a binary operator trait (e.g. `Add`) for `lhs` with `rhs` on the right.
///
/// The emitted impls are:
/// - `Trait<Rhs> for Lhs`, whose method body is `body(&trait_ident, &method_ident)`;
/// - the three reference combinations (`&Lhs`, `&Rhs`, both), each forwarding to the
///   by-value impl;
/// - `TraitAssign<Rhs>` and `TraitAssign<&Rhs>` for `Lhs`, also built on the by-value impl.
///
/// `trait_name_root` and `method_root` are bare names such as `"Add"` and `"add"`. The
/// `Assign` / `_assign` names are derived from them.
fn binop_trait_variations<B: ToTokens>(
    trait_name_root: &str,
    method_root: &str,
    lhs: &impl ToTokens,
    rhs: &impl ToTokens,
    body: impl FnOnce(&Ident, &Ident) -> B,
    tokens: &mut TokenStream,
) {
    fn call_site_ident(name: &str) -> Ident {
        Ident::new(name, Span::call_site())
    }

    let op_trait = call_site_ident(trait_name_root);
    let op_method = call_site_ident(method_root);
    let by_value_body = body(&op_trait, &op_method);
    let assign_trait = call_site_ident(&format!("{}Assign", trait_name_root));
    let assign_method = call_site_ident(&format!("{}_assign", method_root));

    // By-value impl, then the reference combinations that dereference and forward to it.
    tokens.extend(quote! {
        impl ::core::ops::#op_trait<#rhs> for #lhs {
            type Output = #lhs;
            fn #op_method(self, rhs: #rhs) -> Self::Output {
                #by_value_body
            }
        }
        impl<'a> ::core::ops::#op_trait<#rhs> for &'a #lhs {
            type Output = #lhs;
            fn #op_method(self, rhs: #rhs) -> Self::Output {
                <#lhs as ::core::ops::#op_trait<#rhs>>::#op_method(*self, rhs)
            }
        }
        impl<'b> ::core::ops::#op_trait<&'b #rhs> for #lhs {
            type Output = #lhs;
            fn #op_method(self, rhs: &'b #rhs) -> Self::Output {
                <#lhs as ::core::ops::#op_trait<#rhs>>::#op_method(self, *rhs)
            }
        }
        impl<'b, 'a> ::core::ops::#op_trait<&'b #rhs> for &'a #lhs {
            type Output = #lhs;
            fn #op_method(self, rhs: &'b #rhs) -> Self::Output {
                <#lhs as ::core::ops::#op_trait<#rhs>>::#op_method(*self, *rhs)
            }
        }
    });
    // Compound assignment (`lhs op= rhs` and `lhs op= &rhs`) through the by-value operator.
    tokens.extend(quote! {
        impl ::core::ops::#assign_trait<#rhs> for #lhs {
            fn #assign_method(&mut self, rhs: #rhs) {
                *self = <Self as ::core::ops::#op_trait<#rhs>>::#op_method(*self, rhs);
            }
        }
        impl<'a> ::core::ops::#assign_trait<&'a #rhs> for #lhs {
            fn #assign_method(&mut self, rhs: &'a #rhs) {
                *self = <Self as ::core::ops::#op_trait<#rhs>>::#op_method(*self, *rhs);
            }
        }
    });
}
/// Emits the impls of a unary operator trait (e.g. `Neg`) for `lhs` and `&lhs`.
///
/// The by-value impl uses `body` as its method body. The by-reference impl dereferences and
/// forwards to it.
fn unop_trait_variations(
    trait_name: &impl ToTokens,
    method: &impl ToTokens,
    lhs: &impl ToTokens,
    body: &impl ToTokens,
    tokens: &mut TokenStream,
) {
    let op_path = quote!(::core::ops::#trait_name);
    let by_value = quote! {
        impl #op_path for #lhs {
            type Output = #lhs;
            fn #method(self) -> Self::Output {
                #body
            }
        }
    };
    let by_ref = quote! {
        impl<'a> #op_path for &'a #lhs {
            type Output = #lhs;
            fn #method(self) -> Self::Output {
                <#lhs as #op_path>::#method(*self)
            }
        }
    };
    tokens.extend(by_value);
    tokens.extend(by_ref);
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse2;

    /// Parses `input`, runs the generator `f` on the parsed definition, and asserts that its
    /// output renders to the same string as `expected`.
    fn assert_result(
        f: impl FnOnce(&BoundedInteger, &mut TokenStream),
        input: TokenStream,
        expected: TokenStream,
    ) {
        let mut actual = TokenStream::new();
        let definition = parse2::<BoundedInteger>(input).unwrap();
        f(&definition, &mut actual);
        assert_eq!(actual.to_string(), expected.to_string());
    }

    #[test]
    fn test_tokens() {
        let cases = vec![
            (
                quote! {
                    #[repr(isize)]
                    pub(crate) enum Nibble { -8..6+2 }
                },
                quote! {
                    #[derive(Debug, Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
                    #[repr(isize)]
                    pub(crate) enum Nibble {
                        N8 = -8, N7, N6, N5, N4, N3, N2, N1, Z0, P1, P2, P3, P4, P5, P6, P7
                    }
                },
            ),
            (
                quote! {
                    #[repr(u16)]
                    enum Nibble { 3..=7 };
                },
                quote! {
                    #[derive(Debug, Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
                    #[repr(u16)]
                    enum Nibble {
                        P3 = 3, P4, P5, P6, P7
                    };
                },
            ),
            (
                quote! {
                    #[repr(i8)]
                    pub struct S { -3..2 }
                },
                quote! {
                    #[derive(Debug, Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
                    pub struct S(i8);
                },
            ),
        ];
        for (input, expected) in cases {
            assert_result(BoundedInteger::generate_item, input, expected);
        }
    }
}