use crate::ast::{ContainerKind, Field};
use crate::attr::{Display, Trait};
use crate::scan_expr::scan_expr;
use crate::unraw::{IdentUnraw, MemberUnraw};
use proc_macro2::{Delimiter, TokenStream, TokenTree};
use quote::{format_ident, quote, quote_spanned, ToTokens as _};
use std::collections::{BTreeSet, HashMap};
use std::iter;
use syn::ext::IdentExt;
use syn::parse::discouraged::Speculative;
use syn::parse::{Error, ParseStream, Parser, Result};
use syn::{Expr, Ident, Index, LitStr, Token};
impl Display<'_> {
/// Rewrites shorthand field references in the format string (`{0}`,
/// `{name}`, `{name:?}`, ...) so that they refer to generated named
/// arguments, and records what those references imply.
///
/// On the normal path this replaces `self.fmt` with the rewritten literal and
/// stores `has_bonus_display`, `infinite_recursive`, `implied_bounds`
/// (pairs of field index and formatting trait) and `bindings` (pairs of
/// generated argument name and the expression bound to it) on `self`.
///
/// # Errors
///
/// Returns an error spanned on the first unnamed user argument when the
/// format string refers to an argument by number, at least one field is not
/// a named field, and the user also passed an unnamed argument.
///
/// # Early exits
///
/// If the format string ends right after a `{`, contains a numeric reference
/// that does not fit in `u32`, or has no closing `}` after a reference, the
/// function returns `Ok(())` without touching `self.fmt` or the other
/// outputs (only `requires_fmt_machinery` may already have been set).
/// NOTE(review): presumably this leaves the malformed string for the compiler
/// to diagnose -- confirm with the callers.
pub fn expand_shorthand(&mut self, fields: &[Field], container: ContainerKind) -> Result<()> {
let raw_args = self.args.clone();
// `explicit_named_args` returns `Ok` on every path and consumes all of its
// input, so this `unwrap` is not expected to fail.
let FmtArguments {
named: user_named_args,
first_unnamed,
} = explicit_named_args.parse2(raw_args).unwrap();
// Map every field's member to its position in `fields`. Numeric references
// alongside unnamed user arguments are only tolerated when all fields are
// named fields.
let mut member_index = HashMap::new();
let mut extra_positional_arguments_allowed = true;
for (i, field) in fields.iter().enumerate() {
member_index.insert(&field.member, i);
extra_positional_arguments_allowed &= matches!(&field.member, MemberUnraw::Named(_));
}
let span = self.fmt.span();
let fmt = self.fmt.value();
// `read` is the not-yet-processed tail of the format string; `out`
// accumulates the rewritten string.
let mut read = fmt.as_str();
let mut out = String::new();
let mut has_bonus_display = false;
let mut infinite_recursive = false;
let mut implied_bounds = BTreeSet::new();
let mut bindings = Vec::new();
// Generated argument names that already have an entry in `bindings`.
let mut macro_named_args = BTreeSet::new();
// Any `}` in the literal marks it as needing the formatting machinery, even
// if no `{` is ever found by the loop below.
self.requires_fmt_machinery = self.requires_fmt_machinery || fmt.contains('}');
while let Some(brace) = read.find('{') {
self.requires_fmt_machinery = true;
// Copy everything up to and including the `{`.
out += &read[..brace + 1];
read = &read[brace + 1..];
// `{{` is copied through as-is and not treated as a reference.
if read.starts_with('{') {
out.push('{');
read = &read[1..];
continue;
}
// The string ends immediately after `{`: give up without rewriting.
let next = match read.chars().next() {
Some(next) => next,
None => return Ok(()),
};
let member = match next {
// Reference by number, e.g. `{0}`.
'0'..='9' => {
let int = take_int(&mut read);
if !extra_positional_arguments_allowed {
if let Some(first_unnamed) = &first_unnamed {
let msg = format!("ambiguous reference to positional arguments by number in a {container}; change this to a named argument");
return Err(Error::new_spanned(first_unnamed, msg));
}
}
match int.parse::<u32>() {
Ok(index) => MemberUnraw::Unnamed(Index { index, span }),
// Digits that do not fit in `u32`: give up without rewriting.
Err(_) => return Ok(()),
}
}
// Reference by name, e.g. `{name}`.
'a'..='z' | 'A'..='Z' | '_' => {
// A reference starting with `r#` is not rewritten; the text stays in
// `read` and is copied verbatim by a later iteration or the final
// `out += read`.
if read.starts_with("r#") {
continue;
}
let repr = take_ident(&mut read);
// A lone `_` is copied through unchanged.
if repr == "_" {
out += repr;
continue;
}
let ident = IdentUnraw::new(Ident::new(repr, span));
// The name matches an argument the user passed explicitly, so it is
// copied through unchanged instead of being treated as a field.
if user_named_args.contains(&ident) {
out += repr;
continue;
}
MemberUnraw::Named(ident)
}
// Anything else after `{` (for instance `}` or `:`) is left untouched.
_ => continue,
};
// `read` now starts right after the member; everything up to the next `}`
// is treated as the format spec. No closing brace: give up.
let end_spec = match read.find('}') {
Some(end_spec) => end_spec,
None => return Ok(()),
};
let mut bonus_display = false;
// The last character before `}` selects the formatting trait. An empty
// spec (plain `{member}`) is Display and additionally takes the
// "bonus display" path.
let bound = match read[..end_spec].chars().next_back() {
Some('?') => Trait::Debug,
Some('o') => Trait::Octal,
Some('x') => Trait::LowerHex,
Some('X') => Trait::UpperHex,
Some('p') => Trait::Pointer,
Some('b') => Trait::Binary,
Some('e') => Trait::LowerExp,
Some('E') => Trait::UpperExp,
Some(_) => Trait::Display,
None => {
bonus_display = true;
has_bonus_display = true;
Trait::Display
}
};
// A Display-formatted reference to `self` is flagged as infinitely
// recursive.
infinite_recursive |= member == *"self" && bound == Trait::Display;
if let Some(&field) = member_index.get(&member) {
implied_bounds.insert((field, bound));
} else {
// Not one of the fields: write the member back out and move on.
out += &member.to_string();
continue;
}
// The prefix of the generated argument name depends on how the value will
// be wrapped further below.
let formatvar_prefix = if bonus_display {
"__display"
} else if bound == Trait::Pointer {
"__pointer"
} else {
"__field"
};
let mut formatvar = IdentUnraw::new(match &member {
MemberUnraw::Unnamed(index) => format_ident!("{}{}", formatvar_prefix, index),
MemberUnraw::Named(ident) => {
format_ident!("{}_{}", formatvar_prefix, ident.to_string())
}
});
// Prepend underscores until the generated name no longer collides with a
// user-supplied named argument.
while user_named_args.contains(&formatvar) {
formatvar = IdentUnraw::new(format_ident!("_{}", formatvar.to_string()));
}
out += &formatvar.to_string();
// `insert` returns false when this name was already recorded, in which
// case a binding for it has been pushed by an earlier reference.
if !macro_named_args.insert(formatvar.clone()) {
continue;
}
// Identifier the field value is read from: `_N` for numbered members,
// `ident.to_local()` for named ones.
let binding_value = match &member {
MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
MemberUnraw::Named(ident) => ident.to_local(),
};
let wrapped_binding_value = if bonus_display {
quote_spanned!(span=> #binding_value.as_display())
} else if bound == Trait::Pointer {
quote!(::thiserror::__private::Var(#binding_value))
} else {
binding_value.into_token_stream()
};
bindings.push((formatvar.to_local(), wrapped_binding_value));
}
// Copy whatever follows the last `{` that was processed.
out += read;
self.fmt = LitStr::new(&out, self.fmt.span());
self.has_bonus_display = has_bonus_display;
self.infinite_recursive = infinite_recursive;
self.implied_bounds = implied_bounds;
self.bindings = bindings;
Ok(())
}
}
/// Summary of the arguments that follow a format string, as collected by the
/// `explicit_named_args` parser.
struct FmtArguments {
// Names of all arguments written in `name = value` form.
named: BTreeSet<IdentUnraw>,
// Tokens of the first argument that is not in `name = value` form, if one
// was seen. The fallback parser always leaves this as `None`.
first_unnamed: Option<TokenStream>,
}
#[allow(clippy::unnecessary_wraps)]
fn explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
let ahead = input.fork();
if let Ok(set) = try_explicit_named_args(&ahead) {
input.advance_to(&ahead);
return Ok(set);
}
let ahead = input.fork();
if let Ok(set) = fallback_explicit_named_args(&ahead) {
input.advance_to(&ahead);
return Ok(set);
}
input.parse::<TokenStream>().unwrap();
Ok(FmtArguments {
named: BTreeSet::new(),
first_unnamed: None,
})
}
/// Precise parser for `, arg, name = arg, ...` argument lists.
///
/// Every argument must be preceded by a comma; a trailing comma is accepted.
/// Arguments of the form `ident = expr` (but not `ident == ...`) contribute
/// their identifier to `named`. The tokens of the first other argument are
/// captured as `first_unnamed`. Each argument's expression is parsed as a
/// `syn::Expr` when `is_syn_full()` reports true and that parse succeeds, and
/// is otherwise skipped with `scan_expr`.
///
/// Returns an error if a comma is missing, a named argument cannot be parsed,
/// or `scan_expr` fails.
fn try_explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
    let mut named = BTreeSet::new();
    let mut first_unnamed = None;
    // Result of `is_syn_full`, computed at most once and only if needed.
    let mut syn_full = None;

    while !input.is_empty() {
        input.parse::<Token![,]>()?;
        if input.is_empty() {
            // Trailing comma.
            break;
        }

        let is_named =
            input.peek(Ident::peek_any) && input.peek2(Token![=]) && !input.peek2(Token![==]);
        // For an unnamed argument, remember where its tokens start.
        let begin_unnamed = if is_named {
            let ident: IdentUnraw = input.parse()?;
            input.parse::<Token![=]>()?;
            named.insert(ident);
            None
        } else {
            Some(input.fork())
        };

        let ahead = input.fork();
        let parsed_as_expr =
            *syn_full.get_or_insert_with(is_syn_full) && ahead.parse::<Expr>().is_ok();
        if parsed_as_expr {
            input.advance_to(&ahead);
        } else {
            scan_expr(input)?;
        }

        if first_unnamed.is_none() {
            if let Some(begin) = begin_unnamed {
                first_unnamed = Some(between(&begin, input));
            }
        }
    }

    Ok(FmtArguments {
        named,
        first_unnamed,
    })
}
/// Token-level fallback that only looks for `, ident =` sequences.
///
/// Walks the input one token tree at a time. Wherever a comma is followed by
/// an identifier and a single `=` (not `==`), the identifier is recorded as a
/// named argument; every other token tree is skipped. `first_unnamed` is
/// always `None` in the result.
fn fallback_explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
    let mut named = BTreeSet::new();

    while !input.is_empty() {
        let at_named_arg = input.peek(Token![,])
            && input.peek2(Ident::peek_any)
            && input.peek3(Token![=])
            && !input.peek3(Token![==]);

        if !at_named_arg {
            // Not the start of a named argument: skip one token tree.
            input.parse::<TokenTree>()?;
            continue;
        }

        input.parse::<Token![,]>()?;
        let ident: IdentUnraw = input.parse()?;
        input.parse::<Token![=]>()?;
        named.insert(ident);
    }

    Ok(FmtArguments {
        named,
        first_unnamed: None,
    })
}
fn is_syn_full() -> bool {
let test = quote!({
trait Trait {}
});
match syn::parse2(test) {
Ok(Expr::Verbatim(_)) | Err(_) => false,
Ok(Expr::Block(_)) => true,
Ok(_) => unreachable!(),
}
}
/// Splits the leading run of ASCII digits off the front of `read`.
///
/// Returns the digits (possibly an empty string) and advances `read` past
/// them.
fn take_int<'a>(read: &mut &'a str) -> &'a str {
    let source: &'a str = *read;
    // Byte offset of the first character that is not `0`-`9`; the digits in
    // front of it are all single-byte, so this is also their count.
    let int_len = source
        .find(|ch: char| !ch.is_ascii_digit())
        .unwrap_or(source.len());
    let (int, rest) = source.split_at(int_len);
    *read = rest;
    int
}
/// Splits the leading run of ASCII letters, ASCII digits and underscores off
/// the front of `read`.
///
/// Returns that run (possibly an empty string) and advances `read` past it.
fn take_ident<'a>(read: &mut &'a str) -> &'a str {
    let source: &'a str = *read;
    // Counting bytes is safe here: every byte of a multi-byte character is
    // non-ASCII and therefore ends the run.
    let ident_len = source
        .bytes()
        .take_while(|byte| byte.is_ascii_alphanumeric() || *byte == b'_')
        .count();
    let (ident, rest) = source.split_at(ident_len);
    *read = rest;
    ident
}
/// Collects the token trees lying between the current positions of `begin`
/// and `end`.
///
/// When a token tree would reach past `end`, the walk descends into it if it
/// is a `Delimiter::None` group. Otherwise it stops there, including that
/// tree only if nothing has been collected yet.
fn between<'a>(begin: ParseStream<'a>, end: ParseStream<'a>) -> TokenStream {
    let stop = end.cursor();
    let mut position = begin.cursor();
    let mut collected = TokenStream::new();

    while position < stop {
        let (tree, after) = position.token_tree().unwrap();

        if stop < after {
            // This tree straddles the stop position.
            match position.group(Delimiter::None) {
                Some((inside, _span, _after)) => {
                    position = inside;
                    continue;
                }
                None => {
                    if collected.is_empty() {
                        collected.extend(iter::once(tree));
                    }
                    break;
                }
            }
        }

        collected.extend(iter::once(tree));
        position = after;
    }

    collected
}