syn_solidity/item/function.rs

1use crate::{
2    kw, Block, FunctionAttribute, FunctionAttributes, Mutability, ParameterList, Parameters,
3    SolIdent, Spanned, Stmt, Type, VariableDeclaration, VariableDefinition, Visibility,
4};
5use proc_macro2::Span;
6use std::{
7    fmt,
8    hash::{Hash, Hasher},
9    num::NonZeroU16,
10};
11use syn::{
12    parenthesized,
13    parse::{Parse, ParseStream},
14    token::{Brace, Paren},
15    Attribute, Error, Result, Token,
16};
17
/// A function, constructor, fallback, receive, or modifier definition:
/// `function helloWorld() external pure returns(string memory);`.
///
/// Solidity reference:
/// <https://docs.soliditylang.org/en/latest/grammar.html#a4.SolidityParser.functionDefinition>
#[derive(Clone)]
pub struct ItemFunction {
    /// The `syn` attributes of the function.
    pub attrs: Vec<Attribute>,
    /// The keyword that introduces the item (`constructor`, `function`, `fallback`,
    /// `receive` or `modifier`).
    pub kind: FunctionKind,
    /// The name of the item. It is parsed with `SolIdent::parse_opt`, so it is `None` when
    /// no identifier follows the keyword (e.g. `constructor(...)`).
    pub name: Option<SolIdent>,
    /// Parens are optional for modifiers:
    /// <https://docs.soliditylang.org/en/latest/grammar.html#a4.SolidityParser.modifierDefinition>
    ///
    /// The parser only produces `None` for a modifier written without a parameter list.
    pub paren_token: Option<Paren>,
    /// The parameters declared between the parentheses; empty when `paren_token` is `None`.
    pub parameters: ParameterList,
    /// The Solidity attributes of the function.
    pub attributes: FunctionAttributes,
    /// The optional return types of the function.
    pub returns: Option<Returns>,
    /// The body: either a bare `;` (no implementation) or a `{ ... }` block.
    pub body: FunctionBody,
}
39
40impl fmt::Display for ItemFunction {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.write_str(self.kind.as_str())?;
43        if let Some(name) = &self.name {
44            f.write_str(" ")?;
45            name.fmt(f)?;
46        }
47        write!(f, "({})", self.parameters)?;
48
49        if !self.attributes.is_empty() {
50            write!(f, " {}", self.attributes)?;
51        }
52
53        if let Some(returns) = &self.returns {
54            write!(f, " {returns}")?;
55        }
56
57        if !self.body.is_empty() {
58            f.write_str(" ")?;
59        }
60        f.write_str(self.body.as_str())
61    }
62}
63
64impl fmt::Debug for ItemFunction {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        f.debug_struct("ItemFunction")
67            .field("attrs", &self.attrs)
68            .field("kind", &self.kind)
69            .field("name", &self.name)
70            .field("arguments", &self.parameters)
71            .field("attributes", &self.attributes)
72            .field("returns", &self.returns)
73            .field("body", &self.body)
74            .finish()
75    }
76}
77
78impl Parse for ItemFunction {
79    fn parse(input: ParseStream<'_>) -> Result<Self> {
80        let attrs = input.call(Attribute::parse_outer)?;
81        let kind: FunctionKind = input.parse()?;
82        let name = input.call(SolIdent::parse_opt)?;
83
84        let (paren_token, parameters) = if kind.is_modifier() && !input.peek(Paren) {
85            (None, ParameterList::new())
86        } else {
87            let content;
88            (Some(parenthesized!(content in input)), content.parse()?)
89        };
90
91        let attributes = input.parse()?;
92        let returns = input.call(Returns::parse_opt)?;
93        let body = input.parse()?;
94
95        Ok(Self { attrs, kind, name, paren_token, parameters, attributes, returns, body })
96    }
97}
98
99impl Spanned for ItemFunction {
100    fn span(&self) -> Span {
101        if let Some(name) = &self.name {
102            name.span()
103        } else {
104            self.kind.span()
105        }
106    }
107
108    fn set_span(&mut self, span: Span) {
109        self.kind.set_span(span);
110        if let Some(name) = &mut self.name {
111            name.set_span(span);
112        }
113    }
114}
115
116impl ItemFunction {
117    /// Create a new function of the given kind.
118    pub fn new(kind: FunctionKind, name: Option<SolIdent>) -> Self {
119        let span = name.as_ref().map_or_else(|| kind.span(), |name| name.span());
120        Self {
121            attrs: Vec::new(),
122            kind,
123            name,
124            paren_token: Some(Paren(span)),
125            parameters: Parameters::new(),
126            attributes: FunctionAttributes::new(),
127            returns: None,
128            body: FunctionBody::Empty(Token![;](span)),
129        }
130    }
131
132    /// Create a new function with the given name and arguments.
133    ///
134    /// Note that:
135    /// - the type is not validated
136    /// - structs/array of structs in return position are not expanded
137    /// - the body is not set
138    ///
139    /// The attributes are set to `public view`.
140    ///
141    /// See [the Solidity documentation][ref] for more details on how getters
142    /// are generated.
143    ///
144    /// [ref]: https://docs.soliditylang.org/en/latest/contracts.html#getter-functions
145    pub fn new_getter(name: SolIdent, ty: Type) -> Self {
146        let span = name.span();
147        let kind = FunctionKind::new_function(span);
148        let mut function = Self::new(kind, Some(name.clone()));
149
150        // `public view`
151        function.attributes.0 = vec![
152            FunctionAttribute::Visibility(Visibility::new_public(span)),
153            FunctionAttribute::Mutability(Mutability::new_view(span)),
154        ];
155
156        // Recurse into mappings and arrays to generate arguments and the return type.
157        // If the return type is simple, the return value name is set to the variable name.
158        let mut ty = ty;
159        let mut return_name = None;
160        let mut first = true;
161        loop {
162            match ty {
163                // mapping(k => v) -> arguments += k, ty = v
164                Type::Mapping(map) => {
165                    let key = VariableDeclaration::new_with(*map.key, None, map.key_name);
166                    function.parameters.push(key);
167                    return_name = map.value_name;
168                    ty = *map.value;
169                }
170                // inner[] -> arguments += uint256, ty = inner
171                Type::Array(array) => {
172                    let uint256 = Type::Uint(span, NonZeroU16::new(256));
173                    function.parameters.push(VariableDeclaration::new(uint256));
174                    ty = *array.ty;
175                }
176                _ => {
177                    if first {
178                        return_name = Some(name);
179                    }
180                    break;
181                }
182            }
183            first = false;
184        }
185        let mut returns = ParameterList::new();
186        returns.push(VariableDeclaration::new_with(ty, None, return_name));
187        function.returns = Some(Returns::new(span, returns));
188
189        function
190    }
191
192    /// Creates a new function from a variable definition.
193    ///
194    /// The function will have the same name and the variable type's will be the
195    /// return type. The variable attributes are ignored, and instead will
196    /// always generate `public returns`.
197    ///
198    /// See [`new_getter`](Self::new_getter) for more details.
199    pub fn from_variable_definition(var: VariableDefinition) -> Self {
200        let mut function = Self::new_getter(var.name, var.ty);
201        function.attrs = var.attrs;
202        function
203    }
204
205    /// Returns the name of the function.
206    ///
207    /// # Panics
208    ///
209    /// Panics if the function has no name. This is the case when `kind` is not
210    /// `Function`.
211    #[track_caller]
212    pub fn name(&self) -> &SolIdent {
213        match &self.name {
214            Some(name) => name,
215            None => panic!("function has no name: {self:?}"),
216        }
217    }
218
219    /// Returns true if the function returns nothing.
220    pub fn is_void(&self) -> bool {
221        match &self.returns {
222            None => true,
223            Some(returns) => returns.returns.is_empty(),
224        }
225    }
226
227    /// Returns true if the function has a body.
228    pub fn has_implementation(&self) -> bool {
229        matches!(self.body, FunctionBody::Block(_))
230    }
231
232    /// Returns the function's arguments tuple type.
233    pub fn call_type(&self) -> Type {
234        Type::Tuple(self.parameters.types().cloned().collect())
235    }
236
237    /// Returns the function's return tuple type.
238    pub fn return_type(&self) -> Option<Type> {
239        self.returns.as_ref().map(|returns| Type::Tuple(returns.returns.types().cloned().collect()))
240    }
241
242    /// Returns a reference to the function's body, if any.
243    pub fn body(&self) -> Option<&[Stmt]> {
244        match &self.body {
245            FunctionBody::Block(block) => Some(&block.stmts),
246            _ => None,
247        }
248    }
249
250    /// Returns a mutable reference to the function's body, if any.
251    pub fn body_mut(&mut self) -> Option<&mut Vec<Stmt>> {
252        match &mut self.body {
253            FunctionBody::Block(block) => Some(&mut block.stmts),
254            _ => None,
255        }
256    }
257
258    #[allow(clippy::result_large_err)]
259    pub fn into_body(self) -> std::result::Result<Vec<Stmt>, Self> {
260        match self.body {
261            FunctionBody::Block(block) => Ok(block.stmts),
262            _ => Err(self),
263        }
264    }
265}
266
kw_enum! {
    /// The kind of function: the keyword that introduces a function-like item
    /// (`constructor`, `function`, `fallback`, `receive` or `modifier`).
    pub enum FunctionKind {
        Constructor(kw::constructor),
        Function(kw::function),
        Fallback(kw::fallback),
        Receive(kw::receive),
        Modifier(kw::modifier),
    }
}
277
/// The `returns` attribute of a function.
///
/// Equality and hashing only consider the returned parameter list, not the keyword or
/// parenthesis tokens.
#[derive(Clone)]
pub struct Returns {
    /// The `returns` keyword.
    pub returns_token: kw::returns,
    /// The parentheses surrounding the return parameters.
    pub paren_token: Paren,
    /// The returns of the function. This cannot be parsed empty.
    pub returns: ParameterList,
}
286
287impl PartialEq for Returns {
288    fn eq(&self, other: &Self) -> bool {
289        self.returns == other.returns
290    }
291}
292
293impl Eq for Returns {}
294
295impl Hash for Returns {
296    fn hash<H: Hasher>(&self, state: &mut H) {
297        self.returns.hash(state);
298    }
299}
300
301impl fmt::Display for Returns {
302    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303        f.write_str("returns (")?;
304        self.returns.fmt(f)?;
305        f.write_str(")")
306    }
307}
308
309impl fmt::Debug for Returns {
310    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
311        f.debug_tuple("Returns").field(&self.returns).finish()
312    }
313}
314
315impl Parse for Returns {
316    fn parse(input: ParseStream<'_>) -> Result<Self> {
317        let content;
318        let this = Self {
319            returns_token: input.parse()?,
320            paren_token: parenthesized!(content in input),
321            returns: content.parse()?,
322        };
323        if this.returns.is_empty() {
324            Err(Error::new(this.paren_token.span.join(), "expected at least one return type"))
325        } else {
326            Ok(this)
327        }
328    }
329}
330
331impl Spanned for Returns {
332    fn span(&self) -> Span {
333        let span = self.returns_token.span;
334        span.join(self.paren_token.span.join()).unwrap_or(span)
335    }
336
337    fn set_span(&mut self, span: Span) {
338        self.returns_token.span = span;
339        self.paren_token = Paren(span);
340    }
341}
342
343impl Returns {
344    pub fn new(span: Span, returns: ParameterList) -> Self {
345        Self { returns_token: kw::returns(span), paren_token: Paren(span), returns }
346    }
347
348    pub fn parse_opt(input: ParseStream<'_>) -> Result<Option<Self>> {
349        if input.peek(kw::returns) {
350            input.parse().map(Some)
351        } else {
352            Ok(None)
353        }
354    }
355}
356
/// The body of a function.
#[derive(Clone)]
pub enum FunctionBody {
    /// A function without implementation, terminated by `;`.
    Empty(Token![;]),
    /// A function body delimited by curly braces.
    Block(Block),
}
365
366impl fmt::Display for FunctionBody {
367    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368        f.write_str(self.as_str())
369    }
370}
371
372impl fmt::Debug for FunctionBody {
373    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
374        f.write_str("FunctionBody::")?;
375        match self {
376            Self::Empty(_) => f.write_str("Empty"),
377            Self::Block(block) => block.fmt(f),
378        }
379    }
380}
381
382impl Parse for FunctionBody {
383    fn parse(input: ParseStream<'_>) -> Result<Self> {
384        let lookahead = input.lookahead1();
385        if lookahead.peek(Brace) {
386            input.parse().map(Self::Block)
387        } else if lookahead.peek(Token![;]) {
388            input.parse().map(Self::Empty)
389        } else {
390            Err(lookahead.error())
391        }
392    }
393}
394
395impl FunctionBody {
396    /// Returns `true` if the function body is empty.
397    #[inline]
398    pub fn is_empty(&self) -> bool {
399        matches!(self, Self::Empty(_))
400    }
401
402    /// Returns a string representation of the function body.
403    #[inline]
404    pub fn as_str(&self) -> &'static str {
405        match self {
406            Self::Empty(_) => ";",
407            // TODO: fmt::Display for Stmt
408            Self::Block(_) => "{ <stmts> }",
409        }
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::{
        error::Error,
        io::Write,
        process::{Command, Stdio},
    };
    use syn::parse_quote;

    #[test]
    fn modifiers() {
        let none: ItemFunction = parse_quote! {
            modifier noParens {
                _;
            }
        };
        let some: ItemFunction = parse_quote! {
            modifier withParens() {
                _;
            }
        };
        assert_eq!(none.kind, FunctionKind::new_modifier(Span::call_site()));
        assert_eq!(none.kind, some.kind);
        assert_eq!(none.paren_token, None);
        assert_eq!(some.paren_token, Some(Default::default()));
    }

    #[test]
    #[cfg_attr(miri, ignore = "takes too long")]
    fn getters() {
        let run_solc = run_solc();

        // Checks every `variable => expected getter` pair with `test_getter`.
        macro_rules! test_getters {
            ($($var:literal => $f:literal),* $(,)?) => {
                $(test_getter($var, $f, run_solc);)*
            };
        }

        test_getters! {
            "bool public simple;"
                => "function simple() public view returns (bool simple);",
            "bool public constant simpleConstant = false;"
                => "function simpleConstant() public view returns (bool simpleConstant);",

            "mapping(address => bool) public map;"
                => "function map(address) public view returns (bool);",
            "mapping(address a => bool b) public mapWithNames;"
                => "function mapWithNames(address a) public view returns (bool b);",
            "mapping(uint256 k1 => mapping(uint256 k2 => bool v) ignored) public nested2;"
                => "function nested2(uint256 k1, uint256 k2) public view returns (bool v);",
            "mapping(uint256 k1 => mapping(uint256 k2 => mapping(uint256 k3 => bool v) ignored1) ignored2) public nested3;"
                => "function nested3(uint256 k1, uint256 k2, uint256 k3) public view returns (bool v);",

            "bool[] public boolArray;"
                => "function boolArray(uint256) public view returns(bool);",
            "mapping(bool => bytes2)[] public mapArray;"
                => "function mapArray(uint256, bool) public view returns(bytes2);",
            "mapping(bool => mapping(address => int[])[])[][] public nestedMapArray;"
                => "function nestedMapArray(uint256, uint256, bool, uint256, address, uint256) public view returns(int);",
        }
    }

    #[test]
    fn solc_version_parsing() {
        assert_eq!(
            parse_solc_version(
                "solc, the solidity compiler commandline interface\nVersion: 0.8.19+commit.7dd6d404.Linux.g++\n"
            ),
            Some((0, 8, 19))
        );
        // Pre-release builds carry a `-tag` before the build metadata.
        assert_eq!(
            parse_solc_version("Version: 0.8.21-nightly.2023.6.5+commit.0123abcd.Linux.g++"),
            Some((0, 8, 21))
        );
        assert_eq!(parse_solc_version("Version: 0.8+commit.0123abcd"), None);
        assert_eq!(parse_solc_version("not solc output"), None);
    }

    /// Checks that the getter generated from the variable definition `var_s` matches the
    /// function parsed from `fn_s` (compared through their `Debug` output). When `run_solc`
    /// is set, it also checks that solc produces the same ABI for both.
    fn test_getter(var_s: &str, fn_s: &str, run_solc: bool) {
        let var = syn::parse_str::<VariableDefinition>(var_s).unwrap();
        let getter = ItemFunction::from_variable_definition(var);
        let f = syn::parse_str::<ItemFunction>(fn_s).unwrap();
        assert_eq!(format!("{getter:#?}"), format!("{f:#?}"), "{var_s}");

        // Test that the ABIs are the same.
        // Skip `simple` getters since the return type will have a different ABI because Solc
        // doesn't populate the field.
        if run_solc && !var_s.contains("simple") {
            match (wrap_and_compile(var_s, true), wrap_and_compile(fn_s, false)) {
                (Ok(a), Ok(b)) => {
                    assert_eq!(a.trim(), b.trim(), "\nleft:  {var_s:?}\nright: {fn_s:?}")
                }
                (Err(e), _) | (_, Err(e)) => panic!("{e}"),
            }
        }
    }

    /// Whether a usable `solc` is installed.
    fn run_solc() -> bool {
        let Some(v) = get_solc_version() else { return false };
        // Named keys in mappings: https://soliditylang.org/blog/2023/02/01/solidity-0.8.18-release-announcement/
        v >= (0, 8, 18)
    }

    /// Runs `solc --version` and returns its `(major, minor, patch)` version, or `None` if
    /// solc is missing, fails, or prints something unexpected.
    fn get_solc_version() -> Option<(u16, u16, u16)> {
        let output = Command::new("solc").arg("--version").output().ok()?;
        if !output.status.success() {
            return None;
        }
        let stdout = String::from_utf8(output.stdout).ok()?;
        parse_solc_version(&stdout)
    }

    /// Extracts `(major, minor, patch)` from `solc --version` output, such as
    /// `Version: 0.8.19+commit.7dd6d404.Linux.g++` or
    /// `Version: 0.8.21-nightly.2023.6.5+commit.…`.
    ///
    /// Returns `None`, rather than panicking, when the output is not in that format.
    fn parse_solc_version(stdout: &str) -> Option<(u16, u16, u16)> {
        let start = stdout.find(": 0.")?;
        let version = &stdout[start + 2..];
        // The numeric part ends at the build metadata (`+...`) or a pre-release tag (`-...`).
        let end = version.find(|c: char| c == '+' || c == '-')?;

        let mut parts = version[..end].split('.').map(|part| part.parse::<u16>().ok());
        let major = parts.next()??;
        let minor = parts.next()??;
        let patch = parts.next()??;
        Some((major, minor, patch))
    }

    /// Wraps `s` in a contract and returns the ABI JSON that `solc --abi` produces for it.
    /// On failure it returns solc's stderr as the error.
    ///
    /// Function declarations (`var == false`) are placed in an abstract contract and marked
    /// `virtual` so that they may lack an implementation.
    fn wrap_and_compile(s: &str, var: bool) -> std::result::Result<String, Box<dyn Error>> {
        let contract = if var {
            format!("contract C {{ {s} }}")
        } else {
            format!("abstract contract C {{ {} }}", s.replace("returns", "virtual returns"))
        };
        let mut cmd = Command::new("solc")
            .args(["--abi", "--pretty-json", "-"])
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()?;
        cmd.stdin.as_mut().unwrap().write_all(contract.as_bytes())?;
        // `wait_with_output` closes stdin before waiting, so solc sees EOF.
        let output = cmd.wait_with_output()?;
        if output.status.success() {
            String::from_utf8(output.stdout).map_err(Into::into)
        } else {
            Err(String::from_utf8(output.stderr)?.into())
        }
    }
}