seq_macro/
lib.rs

1//! [![github]](https://github.com/dtolnay/seq-macro) [![crates-io]](https://crates.io/crates/seq-macro) [![docs-rs]](https://docs.rs/seq-macro)
2//!
3//! [github]: https://img.shields.io/badge/github-8da0cb?style=for-the-badge&labelColor=555555&logo=github
4//! [crates-io]: https://img.shields.io/badge/crates.io-fc8d62?style=for-the-badge&labelColor=555555&logo=rust
5//! [docs-rs]: https://img.shields.io/badge/docs.rs-66c2a5?style=for-the-badge&labelColor=555555&logo=docs.rs
6//!
7//! <br>
8//!
9//! # Imagine for-loops in a macro
10//!
11//! This crate provides a `seq!` macro to repeat a fragment of source code and
12//! substitute into each repetition a sequential numeric counter.
13//!
14//! ```
15//! use seq_macro::seq;
16//!
17//! fn main() {
18//!     let tuple = (1000, 100, 10);
19//!     let mut sum = 0;
20//!
21//!     // Expands to:
22//!     //
23//!     //     sum += tuple.0;
24//!     //     sum += tuple.1;
25//!     //     sum += tuple.2;
26//!     //
27//!     // This cannot be written using an ordinary for-loop because elements of
28//!     // a tuple can only be accessed by their integer literal index, not by a
29//!     // variable.
30//!     seq!(N in 0..=2 {
31//!         sum += tuple.N;
32//!     });
33//!
34//!     assert_eq!(sum, 1110);
35//! }
36//! ```
37//!
38//! - If the input tokens contain a section surrounded by `#(` ... `)*` then
39//!   only that part is repeated.
40//!
41//! - The numeric counter can be pasted onto the end of some prefix to form
42//!   sequential identifiers.
43//!
44//! ```
45//! use seq_macro::seq;
46//!
47//! seq!(N in 64..=127 {
48//!     #[derive(Debug)]
49//!     enum Demo {
50//!         // Expands to Variant64, Variant65, ...
//!         #(
52//!             Variant~N,
53//!         )*
54//!     }
55//! });
56//!
57//! fn main() {
58//!     assert_eq!("Variant99", format!("{:?}", Demo::Variant99));
59//! }
60//! ```
61//!
62//! - Byte and character ranges are supported: `b'a'..=b'z'`, `'a'..='z'`.
63//!
64//! - If the range bounds are written in binary, octal, hex, or with zero
65//!   padding, those features are preserved in any generated tokens.
66//!
67//! ```
68//! use seq_macro::seq;
69//!
70//! seq!(P in 0x000..=0x00F {
71//!     // expands to structs Pin000, ..., Pin009, Pin00A, ..., Pin00F
72//!     struct Pin~P;
73//! });
74//! ```
75
// Crate-level rustdoc configuration.
#![doc(html_root_url = "https://docs.rs/seq-macro/0.3.6")]
// Clippy lints that are silenced for the whole crate.
#![allow(
    clippy::cast_lossless,
    clippy::cast_possible_truncation,
    clippy::derive_partial_eq_without_eq,
    clippy::into_iter_without_iter,
    clippy::let_underscore_untyped,
    clippy::needless_doctest_main,
    clippy::single_match_else,
    clippy::wildcard_imports
)]
87
88mod parse;
89
90use crate::parse::*;
91use proc_macro::{Delimiter, Group, Ident, Literal, Span, TokenStream, TokenTree};
92use std::char;
93use std::iter::{self, FromIterator};
94
95#[proc_macro]
96pub fn seq(input: TokenStream) -> TokenStream {
97    match seq_impl(input) {
98        Ok(expanded) => expanded,
99        Err(error) => error.into_compile_error(),
100    }
101}
102
/// The counter range a `seq!` invocation iterates over, as returned by
/// `validate_range`.
///
/// Iterating over `&Range` yields one `Splice` per counter value.
struct Range {
    /// First counter value (a code point when `kind` is `Kind::Char`).
    begin: u64,
    /// Upper bound; part of the sequence only when `inclusive` is set.
    end: u64,
    /// Whether `end` itself is iterated (`..=`) or excluded (`..`).
    inclusive: bool,
    /// Selects integer iteration or `char` iteration.
    kind: Kind,
    /// Text appended after the digits of each generated integer literal.
    suffix: String,
    /// Minimum digit count; shorter values are zero-padded to this width.
    width: usize,
    /// Base used when formatting each counter value.
    radix: Radix,
}
112
/// One bound of the range as produced by `require_value`; a pair of these is
/// handed to `validate_range` to build a `Range`.
///
/// NOTE(review): the field meanings below mirror the same-named fields of
/// `Range`; the exact parsing rules live in the `parse` module — confirm there.
struct Value {
    /// Numeric value of the bound.
    int: u64,
    /// Whether the bound was an integer, byte, or character.
    kind: Kind,
    /// Presumably the literal's type suffix, if any — TODO confirm in `parse`.
    suffix: String,
    /// Presumably the digit count to preserve zero padding — TODO confirm in `parse`.
    width: usize,
    /// Base the bound was written in.
    radix: Radix,
    /// Presumably the source location of the bound, for diagnostics — TODO confirm in `parse`.
    span: Span,
}
121
/// A single counter value together with the formatting settings of the
/// `Range` it came from; this is what gets substituted into the body.
struct Splice<'a> {
    /// The counter value for this repetition.
    int: u64,
    /// Copied from the originating `Range`.
    kind: Kind,
    /// Borrowed from the originating `Range`, hence the lifetime.
    suffix: &'a str,
    /// Zero-padding width, copied from the originating `Range`.
    width: usize,
    /// Formatting base, copied from the originating `Range`.
    radix: Radix,
}
129
/// What sort of values the range counts over.
#[derive(Copy, Clone, PartialEq)]
enum Kind {
    /// Plain integers.
    Int,
    /// Byte values, e.g. `b'a'..=b'z'`. Iterated and emitted as literals the
    /// same way as `Int`, but pasted into identifiers as a character.
    Byte,
    /// Characters, e.g. `'a'..='z'`. Iterated as a `char` range and emitted
    /// as character literals.
    Char,
}
136
/// Base in which counter values are formatted when substituted.
#[derive(Copy, Clone, PartialEq)]
enum Radix {
    /// Base 2; literals get a `0b` prefix.
    Binary,
    /// Base 8; literals get a `0o` prefix.
    Octal,
    /// Base 10; literals get no prefix.
    Decimal,
    /// Base 16 with lowercase digits; literals get a `0x` prefix.
    LowerHex,
    /// Base 16 with uppercase digits; literals get a `0x` prefix.
    UpperHex,
}
145
146impl<'a> IntoIterator for &'a Range {
147    type Item = Splice<'a>;
148    type IntoIter = Box<dyn Iterator<Item = Splice<'a>> + 'a>;
149
150    fn into_iter(self) -> Self::IntoIter {
151        let splice = move |int| Splice {
152            int,
153            kind: self.kind,
154            suffix: &self.suffix,
155            width: self.width,
156            radix: self.radix,
157        };
158        match self.kind {
159            Kind::Int | Kind::Byte => {
160                if self.inclusive {
161                    Box::new((self.begin..=self.end).map(splice))
162                } else {
163                    Box::new((self.begin..self.end).map(splice))
164                }
165            }
166            Kind::Char => {
167                let begin = char::from_u32(self.begin as u32).unwrap();
168                let end = char::from_u32(self.end as u32).unwrap();
169                let int = |ch| u64::from(u32::from(ch));
170                if self.inclusive {
171                    Box::new((begin..=end).map(int).map(splice))
172                } else {
173                    Box::new((begin..end).map(int).map(splice))
174                }
175            }
176        }
177    }
178}
179
180fn seq_impl(input: TokenStream) -> Result<TokenStream, SyntaxError> {
181    let mut iter = input.into_iter();
182    let var = require_ident(&mut iter)?;
183    require_keyword(&mut iter, "in")?;
184    let begin = require_value(&mut iter)?;
185    require_punct(&mut iter, '.')?;
186    require_punct(&mut iter, '.')?;
187    let inclusive = require_if_punct(&mut iter, '=')?;
188    let end = require_value(&mut iter)?;
189    let body = require_braces(&mut iter)?;
190    require_end(&mut iter)?;
191
192    let range = validate_range(begin, end, inclusive)?;
193
194    let mut found_repetition = false;
195    let expanded = expand_repetitions(&var, &range, body.clone(), &mut found_repetition);
196    if found_repetition {
197        Ok(expanded)
198    } else {
199        // If no `#(...)*`, repeat the entire body.
200        Ok(repeat(&var, &range, &body))
201    }
202}
203
204fn repeat(var: &Ident, range: &Range, body: &TokenStream) -> TokenStream {
205    let mut repeated = TokenStream::new();
206    for value in range {
207        repeated.extend(substitute_value(var, &value, body.clone()));
208    }
209    repeated
210}
211
212fn substitute_value(var: &Ident, splice: &Splice, body: TokenStream) -> TokenStream {
213    let mut tokens = Vec::from_iter(body);
214
215    let mut i = 0;
216    while i < tokens.len() {
217        // Substitute our variable by itself, e.g. `N`.
218        let replace = match &tokens[i] {
219            TokenTree::Ident(ident) => ident.to_string() == var.to_string(),
220            _ => false,
221        };
222        if replace {
223            let original_span = tokens[i].span();
224            let mut literal = splice.literal();
225            literal.set_span(original_span);
226            tokens[i] = TokenTree::Literal(literal);
227            i += 1;
228            continue;
229        }
230
231        // Substitute our variable concatenated onto some prefix, `Prefix~N`.
232        if i + 3 <= tokens.len() {
233            let prefix = match &tokens[i..i + 3] {
234                [first, TokenTree::Punct(tilde), TokenTree::Ident(ident)]
235                    if tilde.as_char() == '~' && ident.to_string() == var.to_string() =>
236                {
237                    match first {
238                        TokenTree::Ident(ident) => Some(ident.clone()),
239                        TokenTree::Group(group) => {
240                            let mut iter = group.stream().into_iter().fuse();
241                            match (iter.next(), iter.next()) {
242                                (Some(TokenTree::Ident(ident)), None) => Some(ident),
243                                _ => None,
244                            }
245                        }
246                        _ => None,
247                    }
248                }
249                _ => None,
250            };
251            if let Some(prefix) = prefix {
252                let number = match splice.kind {
253                    Kind::Int => match splice.radix {
254                        Radix::Binary => format!("{0:01$b}", splice.int, splice.width),
255                        Radix::Octal => format!("{0:01$o}", splice.int, splice.width),
256                        Radix::Decimal => format!("{0:01$}", splice.int, splice.width),
257                        Radix::LowerHex => format!("{0:01$x}", splice.int, splice.width),
258                        Radix::UpperHex => format!("{0:01$X}", splice.int, splice.width),
259                    },
260                    Kind::Byte | Kind::Char => {
261                        char::from_u32(splice.int as u32).unwrap().to_string()
262                    }
263                };
264                let concat = format!("{}{}", prefix, number);
265                let ident = Ident::new(&concat, prefix.span());
266                tokens.splice(i..i + 3, iter::once(TokenTree::Ident(ident)));
267                i += 1;
268                continue;
269            }
270        }
271
272        // Recursively substitute content nested in a group.
273        if let TokenTree::Group(group) = &mut tokens[i] {
274            let original_span = group.span();
275            let content = substitute_value(var, splice, group.stream());
276            *group = Group::new(group.delimiter(), content);
277            group.set_span(original_span);
278        }
279
280        i += 1;
281    }
282
283    TokenStream::from_iter(tokens)
284}
285
/// Recognises a `#( ... )*` repetition section.
///
/// `tokens` must be a window of exactly three tokens (this panics
/// otherwise). Returns the contents of the parenthesised group when the
/// window is `#`, a parenthesis-delimited group, then `*`; returns `None`
/// for anything else.
fn enter_repetition(tokens: &[TokenTree]) -> Option<TokenStream> {
    assert!(tokens.len() == 3);
    match tokens {
        [TokenTree::Punct(pound), TokenTree::Group(group), TokenTree::Punct(star)]
            if pound.as_char() == '#'
                && star.as_char() == '*'
                && group.delimiter() == Delimiter::Parenthesis =>
        {
            Some(group.stream())
        }
        _ => None,
    }
}
303
304fn expand_repetitions(
305    var: &Ident,
306    range: &Range,
307    body: TokenStream,
308    found_repetition: &mut bool,
309) -> TokenStream {
310    let mut tokens = Vec::from_iter(body);
311
312    // Look for `#(...)*`.
313    let mut i = 0;
314    while i < tokens.len() {
315        if let TokenTree::Group(group) = &mut tokens[i] {
316            let content = expand_repetitions(var, range, group.stream(), found_repetition);
317            let original_span = group.span();
318            *group = Group::new(group.delimiter(), content);
319            group.set_span(original_span);
320            i += 1;
321            continue;
322        }
323        if i + 3 > tokens.len() {
324            i += 1;
325            continue;
326        }
327        let template = match enter_repetition(&tokens[i..i + 3]) {
328            Some(template) => template,
329            None => {
330                i += 1;
331                continue;
332            }
333        };
334        *found_repetition = true;
335        let mut repeated = Vec::new();
336        for value in range {
337            repeated.extend(substitute_value(var, &value, template.clone()));
338        }
339        let repeated_len = repeated.len();
340        tokens.splice(i..i + 3, repeated);
341        i += repeated_len;
342    }
343
344    TokenStream::from_iter(tokens)
345}
346
347impl Splice<'_> {
348    fn literal(&self) -> Literal {
349        match self.kind {
350            Kind::Int | Kind::Byte => {
351                let repr = match self.radix {
352                    Radix::Binary => format!("0b{0:02$b}{1}", self.int, self.suffix, self.width),
353                    Radix::Octal => format!("0o{0:02$o}{1}", self.int, self.suffix, self.width),
354                    Radix::Decimal => format!("{0:02$}{1}", self.int, self.suffix, self.width),
355                    Radix::LowerHex => format!("0x{0:02$x}{1}", self.int, self.suffix, self.width),
356                    Radix::UpperHex => format!("0x{0:02$X}{1}", self.int, self.suffix, self.width),
357                };
358                let tokens = repr.parse::<TokenStream>().unwrap();
359                let mut iter = tokens.into_iter();
360                let literal = match iter.next() {
361                    Some(TokenTree::Literal(literal)) => literal,
362                    _ => unreachable!(),
363                };
364                assert!(iter.next().is_none());
365                literal
366            }
367            Kind::Char => {
368                let ch = char::from_u32(self.int as u32).unwrap();
369                Literal::character(ch)
370            }
371        }
372    }
373}