1use crate::{
2 kw, Block, FunctionAttribute, FunctionAttributes, Mutability, ParameterList, Parameters,
3 SolIdent, Spanned, Stmt, Type, VariableDeclaration, VariableDefinition, Visibility,
4};
5use proc_macro2::Span;
6use std::{
7 fmt,
8 hash::{Hash, Hasher},
9 num::NonZeroU16,
10};
11use syn::{
12 parenthesized,
13 parse::{Parse, ParseStream},
14 token::{Brace, Paren},
15 Attribute, Error, Result, Token,
16};
17
/// A function-like item: a `constructor`, `function`, `fallback`, `receive`,
/// or `modifier` definition (see [`FunctionKind`]).
///
/// Parsed as `<attrs> <kind> [name] [(<parameters>)] <attributes> [returns (...)] <body>`.
#[derive(Clone)]
pub struct ItemFunction {
    /// Outer attributes (`#[...]`) written before the item.
    pub attrs: Vec<Attribute>,
    /// The keyword that introduces the item.
    pub kind: FunctionKind,
    /// The item's name, if one was written.
    ///
    /// Parsed with `SolIdent::parse_opt`, so it may be `None`;
    /// `ItemFunction::name()` panics in that case.
    pub name: Option<SolIdent>,
    /// The parentheses around the parameter list.
    ///
    /// `None` only for a modifier written without parentheses,
    /// e.g. `modifier onlyOwner { ... }`.
    pub paren_token: Option<Paren>,
    /// The parameters declared inside the parentheses.
    pub parameters: ParameterList,
    /// The attributes after the parameter list, such as visibility and mutability.
    pub attributes: FunctionAttributes,
    /// The optional `returns (...)` clause.
    pub returns: Option<Returns>,
    /// The body: either `;` (no implementation) or a `{ ... }` block.
    pub body: FunctionBody,
}
39
40impl fmt::Display for ItemFunction {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 f.write_str(self.kind.as_str())?;
43 if let Some(name) = &self.name {
44 f.write_str(" ")?;
45 name.fmt(f)?;
46 }
47 write!(f, "({})", self.parameters)?;
48
49 if !self.attributes.is_empty() {
50 write!(f, " {}", self.attributes)?;
51 }
52
53 if let Some(returns) = &self.returns {
54 write!(f, " {returns}")?;
55 }
56
57 if !self.body.is_empty() {
58 f.write_str(" ")?;
59 }
60 f.write_str(self.body.as_str())
61 }
62}
63
64impl fmt::Debug for ItemFunction {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 f.debug_struct("ItemFunction")
67 .field("attrs", &self.attrs)
68 .field("kind", &self.kind)
69 .field("name", &self.name)
70 .field("arguments", &self.parameters)
71 .field("attributes", &self.attributes)
72 .field("returns", &self.returns)
73 .field("body", &self.body)
74 .finish()
75 }
76}
77
78impl Parse for ItemFunction {
79 fn parse(input: ParseStream<'_>) -> Result<Self> {
80 let attrs = input.call(Attribute::parse_outer)?;
81 let kind: FunctionKind = input.parse()?;
82 let name = input.call(SolIdent::parse_opt)?;
83
84 let (paren_token, parameters) = if kind.is_modifier() && !input.peek(Paren) {
85 (None, ParameterList::new())
86 } else {
87 let content;
88 (Some(parenthesized!(content in input)), content.parse()?)
89 };
90
91 let attributes = input.parse()?;
92 let returns = input.call(Returns::parse_opt)?;
93 let body = input.parse()?;
94
95 Ok(Self { attrs, kind, name, paren_token, parameters, attributes, returns, body })
96 }
97}
98
99impl Spanned for ItemFunction {
100 fn span(&self) -> Span {
101 if let Some(name) = &self.name {
102 name.span()
103 } else {
104 self.kind.span()
105 }
106 }
107
108 fn set_span(&mut self, span: Span) {
109 self.kind.set_span(span);
110 if let Some(name) = &mut self.name {
111 name.set_span(span);
112 }
113 }
114}
115
116impl ItemFunction {
117 pub fn new(kind: FunctionKind, name: Option<SolIdent>) -> Self {
119 let span = name.as_ref().map_or_else(|| kind.span(), |name| name.span());
120 Self {
121 attrs: Vec::new(),
122 kind,
123 name,
124 paren_token: Some(Paren(span)),
125 parameters: Parameters::new(),
126 attributes: FunctionAttributes::new(),
127 returns: None,
128 body: FunctionBody::Empty(Token),
129 }
130 }
131
132 pub fn new_getter(name: SolIdent, ty: Type) -> Self {
146 let span = name.span();
147 let kind = FunctionKind::new_function(span);
148 let mut function = Self::new(kind, Some(name.clone()));
149
150 function.attributes.0 = vec![
152 FunctionAttribute::Visibility(Visibility::new_public(span)),
153 FunctionAttribute::Mutability(Mutability::new_view(span)),
154 ];
155
156 let mut ty = ty;
159 let mut return_name = None;
160 let mut first = true;
161 loop {
162 match ty {
163 Type::Mapping(map) => {
165 let key = VariableDeclaration::new_with(*map.key, None, map.key_name);
166 function.parameters.push(key);
167 return_name = map.value_name;
168 ty = *map.value;
169 }
170 Type::Array(array) => {
172 let uint256 = Type::Uint(span, NonZeroU16::new(256));
173 function.parameters.push(VariableDeclaration::new(uint256));
174 ty = *array.ty;
175 }
176 _ => {
177 if first {
178 return_name = Some(name);
179 }
180 break;
181 }
182 }
183 first = false;
184 }
185 let mut returns = ParameterList::new();
186 returns.push(VariableDeclaration::new_with(ty, None, return_name));
187 function.returns = Some(Returns::new(span, returns));
188
189 function
190 }
191
192 pub fn from_variable_definition(var: VariableDefinition) -> Self {
200 let mut function = Self::new_getter(var.name, var.ty);
201 function.attrs = var.attrs;
202 function
203 }
204
205 #[track_caller]
212 pub fn name(&self) -> &SolIdent {
213 match &self.name {
214 Some(name) => name,
215 None => panic!("function has no name: {self:?}"),
216 }
217 }
218
219 pub fn is_void(&self) -> bool {
221 match &self.returns {
222 None => true,
223 Some(returns) => returns.returns.is_empty(),
224 }
225 }
226
227 pub fn has_implementation(&self) -> bool {
229 matches!(self.body, FunctionBody::Block(_))
230 }
231
232 pub fn call_type(&self) -> Type {
234 Type::Tuple(self.parameters.types().cloned().collect())
235 }
236
237 pub fn return_type(&self) -> Option<Type> {
239 self.returns.as_ref().map(|returns| Type::Tuple(returns.returns.types().cloned().collect()))
240 }
241
242 pub fn body(&self) -> Option<&[Stmt]> {
244 match &self.body {
245 FunctionBody::Block(block) => Some(&block.stmts),
246 _ => None,
247 }
248 }
249
250 pub fn body_mut(&mut self) -> Option<&mut Vec<Stmt>> {
252 match &mut self.body {
253 FunctionBody::Block(block) => Some(&mut block.stmts),
254 _ => None,
255 }
256 }
257
258 #[allow(clippy::result_large_err)]
259 pub fn into_body(self) -> std::result::Result<Vec<Stmt>, Self> {
260 match self.body {
261 FunctionBody::Block(block) => Ok(block.stmts),
262 _ => Err(self),
263 }
264 }
265}
266
// The keyword that introduces a function-like item. Each variant wraps the
// matching Solidity keyword token: `constructor`, `function`, `fallback`,
// `receive`, or `modifier`.
// NOTE(review): this file calls `as_str`, `is_modifier`, `new_function`,
// `new_modifier`, `span` and `set_span` on `FunctionKind`. Presumably these are
// generated by `kw_enum!` (defined elsewhere). Verify.
kw_enum! {
    pub enum FunctionKind {
        Constructor(kw::constructor),
        Function(kw::function),
        Fallback(kw::fallback),
        Receive(kw::receive),
        Modifier(kw::modifier),
    }
}
277
/// A `returns (...)` clause of a function-like item.
///
/// Equality and hashing consider only the `returns` parameter list.
#[derive(Clone)]
pub struct Returns {
    /// The `returns` keyword.
    pub returns_token: kw::returns,
    /// The parentheses around the returned parameters.
    pub paren_token: Paren,
    /// The returned parameters. The `Parse` impl rejects an empty list.
    pub returns: ParameterList,
}
286
287impl PartialEq for Returns {
288 fn eq(&self, other: &Self) -> bool {
289 self.returns == other.returns
290 }
291}
292
// Marker impl: equality is the `PartialEq` impl for `Returns`, which compares
// only the `returns` parameter list.
impl Eq for Returns {}
294
295impl Hash for Returns {
296 fn hash<H: Hasher>(&self, state: &mut H) {
297 self.returns.hash(state);
298 }
299}
300
301impl fmt::Display for Returns {
302 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303 f.write_str("returns (")?;
304 self.returns.fmt(f)?;
305 f.write_str(")")
306 }
307}
308
309impl fmt::Debug for Returns {
310 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
311 f.debug_tuple("Returns").field(&self.returns).finish()
312 }
313}
314
315impl Parse for Returns {
316 fn parse(input: ParseStream<'_>) -> Result<Self> {
317 let content;
318 let this = Self {
319 returns_token: input.parse()?,
320 paren_token: parenthesized!(content in input),
321 returns: content.parse()?,
322 };
323 if this.returns.is_empty() {
324 Err(Error::new(this.paren_token.span.join(), "expected at least one return type"))
325 } else {
326 Ok(this)
327 }
328 }
329}
330
331impl Spanned for Returns {
332 fn span(&self) -> Span {
333 let span = self.returns_token.span;
334 span.join(self.paren_token.span.join()).unwrap_or(span)
335 }
336
337 fn set_span(&mut self, span: Span) {
338 self.returns_token.span = span;
339 self.paren_token = Paren(span);
340 }
341}
342
343impl Returns {
344 pub fn new(span: Span, returns: ParameterList) -> Self {
345 Self { returns_token: kw::returns(span), paren_token: Paren(span), returns }
346 }
347
348 pub fn parse_opt(input: ParseStream<'_>) -> Result<Option<Self>> {
349 if input.peek(kw::returns) {
350 input.parse().map(Some)
351 } else {
352 Ok(None)
353 }
354 }
355}
356
/// The body of a function-like item.
#[derive(Clone)]
pub enum FunctionBody {
    /// No implementation: the declaration ends with `;`.
    Empty(Token![;]),
    /// An implementation block `{ ... }`.
    Block(Block),
}
365
366impl fmt::Display for FunctionBody {
367 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368 f.write_str(self.as_str())
369 }
370}
371
372impl fmt::Debug for FunctionBody {
373 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
374 f.write_str("FunctionBody::")?;
375 match self {
376 Self::Empty(_) => f.write_str("Empty"),
377 Self::Block(block) => block.fmt(f),
378 }
379 }
380}
381
382impl Parse for FunctionBody {
383 fn parse(input: ParseStream<'_>) -> Result<Self> {
384 let lookahead = input.lookahead1();
385 if lookahead.peek(Brace) {
386 input.parse().map(Self::Block)
387 } else if lookahead.peek(Token![;]) {
388 input.parse().map(Self::Empty)
389 } else {
390 Err(lookahead.error())
391 }
392 }
393}
394
395impl FunctionBody {
396 #[inline]
398 pub fn is_empty(&self) -> bool {
399 matches!(self, Self::Empty(_))
400 }
401
402 #[inline]
404 pub fn as_str(&self) -> &'static str {
405 match self {
406 Self::Empty(_) => ";",
407 Self::Block(_) => "{ <stmts> }",
409 }
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::{
        error::Error,
        io::Write,
        process::{Command, Stdio},
    };
    use syn::parse_quote;

    /// A modifier may omit its parameter list. A parenthesized one is kept.
    #[test]
    fn modifiers() {
        let none: ItemFunction = parse_quote! {
            modifier noParens {
                _;
            }
        };
        let some: ItemFunction = parse_quote! {
            modifier withParens() {
                _;
            }
        };
        assert_eq!(none.kind, FunctionKind::new_modifier(Span::call_site()));
        assert_eq!(none.kind, some.kind);
        assert_eq!(none.paren_token, None);
        assert_eq!(some.paren_token, Some(Default::default()));
    }

    /// Getters generated from state variables match the expected signatures.
    /// When a suitable `solc` is installed, their ABIs are compared as well.
    #[test]
    #[cfg_attr(miri, ignore = "takes too long")]
    fn getters() {
        let run_solc = run_solc();

        macro_rules! test_getters {
            ($($var:literal => $f:literal),* $(,)?) => {
                let vars: &[&str] = &[$($var),*];
                let fns: &[&str] = &[$($f),*];
                for (var, f) in std::iter::zip(vars, fns) {
                    test_getter(var, f, run_solc);
                }
            };
        }

        test_getters! {
            "bool public simple;"
                => "function simple() public view returns (bool simple);",
            "bool public constant simpleConstant = false;"
                => "function simpleConstant() public view returns (bool simpleConstant);",

            "mapping(address => bool) public map;"
                => "function map(address) public view returns (bool);",
            "mapping(address a => bool b) public mapWithNames;"
                => "function mapWithNames(address a) public view returns (bool b);",
            "mapping(uint256 k1 => mapping(uint256 k2 => bool v) ignored) public nested2;"
                => "function nested2(uint256 k1, uint256 k2) public view returns (bool v);",
            "mapping(uint256 k1 => mapping(uint256 k2 => mapping(uint256 k3 => bool v) ignored1) ignored2) public nested3;"
                => "function nested3(uint256 k1, uint256 k2, uint256 k3) public view returns (bool v);",

            "bool[] public boolArray;"
                => "function boolArray(uint256) public view returns(bool);",
            "mapping(bool => bytes2)[] public mapArray;"
                => "function mapArray(uint256, bool) public view returns(bytes2);",
            "mapping(bool => mapping(address => int[])[])[][] public nestedMapArray;"
                => "function nestedMapArray(uint256, uint256, bool, uint256, address, uint256) public view returns(int);",
        }
    }

    /// The version parser accepts release and pre-release output and rejects junk.
    #[test]
    fn solc_version_parsing() {
        assert_eq!(
            parse_solc_version("solc\nVersion: 0.8.26+commit.8a97fa7a.Linux.g++\n"),
            Some((0, 8, 26))
        );
        assert_eq!(
            parse_solc_version("Version: 0.8.27-nightly.2024.6.6+commit.1234abcd"),
            Some((0, 8, 27))
        );
        assert_eq!(parse_solc_version("Version: 0.8"), None);
        assert_eq!(parse_solc_version("no version here"), None);
    }

    /// Asserts that the getter generated from `var_s` has the same `Debug`
    /// output as the function parsed from `fn_s`. If `run_solc` is set, also
    /// checks that solc emits the same ABI for both.
    fn test_getter(var_s: &str, fn_s: &str, run_solc: bool) {
        let var = syn::parse_str::<VariableDefinition>(var_s).unwrap();
        let getter = ItemFunction::from_variable_definition(var);
        let f = syn::parse_str::<ItemFunction>(fn_s).unwrap();
        assert_eq!(format!("{getter:#?}"), format!("{f:#?}"), "{var_s}");

        // NOTE(review): the `simple*` cases skip the solc comparison. Presumably
        // this is because solc's ABI leaves a plain getter's output unnamed,
        // while the hand-written function names it. Confirm.
        if run_solc && !var_s.contains("simple") {
            match (wrap_and_compile(var_s, true), wrap_and_compile(fn_s, false)) {
                (Ok(a), Ok(b)) => {
                    assert_eq!(a.trim(), b.trim(), "\nleft: {var_s:?}\nright: {fn_s:?}")
                }
                (Err(e), _) | (_, Err(e)) => panic!("{e}"),
            }
        }
    }

    /// Returns whether the solc ABI comparison should run. It requires a
    /// `solc` of at least 0.8.18.
    fn run_solc() -> bool {
        // NOTE(review): presumably 0.8.18 is the first version accepting named
        // mapping parameters (`mapping(address a => bool b)`). Confirm.
        let Some(version) = get_solc_version() else { return false };
        version >= (0, 8, 18)
    }

    /// Returns the `(major, minor, patch)` version of the installed `solc`.
    ///
    /// Returns `None` if `solc` cannot be run, exits unsuccessfully, or prints
    /// an unrecognized version. In that case the solc comparison is skipped
    /// rather than failing the test.
    fn get_solc_version() -> Option<(u16, u16, u16)> {
        let output = Command::new("solc").arg("--version").output().ok()?;
        if !output.status.success() {
            return None;
        }
        let stdout = String::from_utf8(output.stdout).ok()?;
        parse_solc_version(&stdout)
    }

    /// Extracts `major.minor.patch` from `solc --version` output. This handles
    /// both `0.8.26+commit...` and pre-release forms such as
    /// `0.8.27-nightly.2024.6.6+commit...`.
    fn parse_solc_version(text: &str) -> Option<(u16, u16, u16)> {
        let start = text.find(": 0.")?;
        let version = &text[start + 2..];
        // Stop at the first character that cannot be part of `x.y.z`
        // (`+` build metadata, `-` pre-release, or whitespace).
        let end = version
            .find(|c: char| !(c.is_ascii_digit() || c == '.'))
            .unwrap_or(version.len());
        let mut parts = version[..end].split('.').map(|part| part.parse::<u16>().ok());
        let major = parts.next()??;
        let minor = parts.next()??;
        let patch = parts.next()??;
        Some((major, minor, patch))
    }

    /// Wraps `s` in a contract, compiles it with `solc --abi`, and returns the
    /// pretty-printed ABI (stdout). On failure, returns solc's stderr as the error.
    ///
    /// `var` selects whether `s` is a state variable (`true`) or a body-less
    /// function declaration (`false`).
    fn wrap_and_compile(s: &str, var: bool) -> std::result::Result<String, Box<dyn Error>> {
        let contract = if var {
            format!("contract C {{ {s} }}")
        } else {
            // A function without a body must be `virtual` and live in an
            // abstract contract.
            format!("abstract contract C {{ {} }}", s.replace("returns", "virtual returns"))
        };
        let mut child = Command::new("solc")
            .args(["--abi", "--pretty-json", "-"])
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()?;
        child.stdin.as_mut().expect("stdin is piped").write_all(contract.as_bytes())?;
        // `wait_with_output` closes stdin before waiting, so solc sees EOF.
        let output = child.wait_with_output()?;
        if output.status.success() {
            String::from_utf8(output.stdout).map_err(Into::into)
        } else {
            Err(String::from_utf8(output.stderr)?.into())
        }
    }
}