1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
//! `fold_expr_with_context` function, for rewriting exprs with knowledge of their contexts (rvalue
//! / lvalue / mut lvalue).
use rustc_target::spec::abi::Abi;
use smallvec::SmallVec;
use std::rc::Rc;
use syntax::ast::*;
use syntax::mut_visit::{self, MutVisitor};
use syntax::parse::token::{DelimToken, Nonterminal, Token};
use syntax::ptr::P;
use syntax::source_map::{Span, Spanned};
use syntax::tokenstream::{DelimSpan, TokenStream, TokenTree};
use syntax::ThinVec;
use syntax_pos::hygiene::SyntaxContext;

use crate::ast_manip::MutVisit;

// TODO: Check for autoborrow adjustments.  Some method receivers are actually Lvalue / LvalueMut
// contexts, but currently they're all treated as Rvalues.

// TODO: Handle match inputs properly.  The target expression of a match could be any context,
// depending on whether `ref` / `ref mut` appears in any of the patterns.

/// Trait implemented by all AST types, allowing folding over exprs while tracking the context.
///
/// The three methods perform the same in-place traversal of `self`; they differ only in the
/// context (rvalue, immutable lvalue, or mutable lvalue) that `self` is assumed to appear in.
/// Rewrites are delegated to the corresponding methods of the `LRRewrites` value `lr`.
trait LRExpr {
    /// Fold `self` as something appearing in an rvalue context.
    fn fold_rvalue<LR: LRRewrites>(&mut self, lr: &mut LR);
    /// Fold `self` as something appearing in an (immutable) lvalue context.
    fn fold_lvalue<LR: LRRewrites>(&mut self, lr: &mut LR);
    /// Fold `self` as something appearing in a mutable lvalue context.
    fn fold_lvalue_mut<LR: LRRewrites>(&mut self, lr: &mut LR);
}

/// A set of expr rewrites, one for each kind of context where an expr may appear.
///
/// Each method receives the expr by mutable reference and may modify or replace it in place.
trait LRRewrites {
    /// Rewrite an expr that appears in an rvalue context.
    fn fold_rvalue(&mut self, e: &mut P<Expr>);
    /// Rewrite an expr that appears in an (immutable) lvalue context.
    fn fold_lvalue(&mut self, e: &mut P<Expr>);
    /// Rewrite an expr that appears in a mutable lvalue context.
    fn fold_lvalue_mut(&mut self, e: &mut P<Expr>);
}

// Helper macro for generating LRExpr instances.
//
// Invoked inside an `impl LRExpr for ...` block as
// `lr_expr_fn!((self, next(T)) => body)`.  It expands to all three `LRExpr` methods, each with
// the same `body`.  Inside `body`, `next` is a closure taking `&mut T` that calls the fold
// method matching the method being defined (`fold_rvalue`, `fold_lvalue` or
// `fold_lvalue_mut`), so a wrapper type forwards its own context unchanged to its contents.
macro_rules! lr_expr_fn {
    (($slf:ident, $next:ident($T:ty)) => $e:expr) => {
        // `$next` is declared `mut` so the body may call it repeatedly; bodies that only pass
        // it along by value would otherwise trigger an `unused_mut` warning.
        #[allow(unused_mut)]
        fn fold_rvalue<LR: LRRewrites>(&mut $slf, lr: &mut LR) {
            let mut $next = |x: &mut $T| x.fold_rvalue(lr);
            $e
        }

        #[allow(unused_mut)]
        fn fold_lvalue<LR: LRRewrites>(&mut $slf, lr: &mut LR) {
            let mut $next = |x: &mut $T| x.fold_lvalue(lr);
            $e
        }

        #[allow(unused_mut)]
        fn fold_lvalue_mut<LR: LRRewrites>(&mut $slf, lr: &mut LR) {
            let mut $next = |x: &mut $T| x.fold_lvalue_mut(lr);
            $e
        }
    };
}

// A vector is folded by folding each element, front to back, in the vector's own context.
impl<T: LRExpr> LRExpr for Vec<T> {
    lr_expr_fn!((self, visit_elem(T)) => {
        for elem in self.iter_mut() {
            visit_elem(elem);
        }
    });
}

// A `ThinVec` is folded by folding each element, front to back, in the vector's own context.
impl<T: LRExpr> LRExpr for ThinVec<T> {
    lr_expr_fn!((self, visit_elem(T)) => {
        self.iter_mut().for_each(visit_elem)
    });
}

// A `P<T>` pointer is transparent: the pointee is folded in the pointer's own context.
impl<T: LRExpr + 'static> LRExpr for P<T> {
    lr_expr_fn!((self, visit_inner(T)) => {
        let pointee: &mut T = &mut **self;
        visit_inner(pointee);
    });
}

// An `Rc<T>` is folded through `Rc::make_mut`, which yields a uniquely-owned `&mut T`
// (cloning the contents first if the `Rc` is currently shared).
impl<T: LRExpr + Clone> LRExpr for Rc<T> {
    lr_expr_fn!((self, visit_inner(T)) => {
        let unique: &mut T = Rc::make_mut(self);
        visit_inner(unique);
    });
}

// A `Spanned<T>` is folded by folding its `node` field in the same context.
impl<T: LRExpr> LRExpr for Spanned<T> {
    lr_expr_fn!((self, visit_node(T)) => {
        let node = &mut self.node;
        visit_node(node)
    });
}

// An `Option<T>` folds its contents (if any) in the same context; `None` is a no-op.
impl<T: LRExpr> LRExpr for Option<T> {
    lr_expr_fn!((self, visit_inner(T)) => {
        if let Some(inner) = self.as_mut() {
            visit_inner(inner);
        }
    });
}

// A pair is folded component by component, left to right, each in the pair's own context.
impl<A, B> LRExpr for (A, B)
where
    A: LRExpr,
    B: LRExpr,
{
    fn fold_rvalue<LR: LRRewrites>(&mut self, lr: &mut LR) {
        let (first, second) = self;
        first.fold_rvalue(lr);
        second.fold_rvalue(lr);
    }

    fn fold_lvalue<LR: LRRewrites>(&mut self, lr: &mut LR) {
        let (first, second) = self;
        first.fold_lvalue(lr);
        second.fold_lvalue(lr);
    }

    fn fold_lvalue_mut<LR: LRRewrites>(&mut self, lr: &mut LR) {
        let (first, second) = self;
        first.fold_lvalue_mut(lr);
        second.fold_lvalue_mut(lr);
    }
}

// A triple is folded component by component, left to right, each in the triple's own context.
impl<A, B, C> LRExpr for (A, B, C)
where
    A: LRExpr,
    B: LRExpr,
    C: LRExpr,
{
    fn fold_rvalue<LR: LRRewrites>(&mut self, lr: &mut LR) {
        let (first, second, third) = self;
        first.fold_rvalue(lr);
        second.fold_rvalue(lr);
        third.fold_rvalue(lr);
    }

    fn fold_lvalue<LR: LRRewrites>(&mut self, lr: &mut LR) {
        let (first, second, third) = self;
        first.fold_lvalue(lr);
        second.fold_lvalue(lr);
        third.fold_lvalue(lr);
    }

    fn fold_lvalue_mut<LR: LRRewrites>(&mut self, lr: &mut LR) {
        let (first, second, third) = self;
        first.fold_lvalue_mut(lr);
        second.fold_lvalue_mut(lr);
        third.fold_lvalue_mut(lr);
    }
}

// Exprs are rewritten bottom-up: the expr's `node` is folded first (in the same context as
// the expr itself), and only then is the expr handed to the matching `LRRewrites` method.
impl LRExpr for P<Expr> {
    fn fold_rvalue<LR>(&mut self, lr: &mut LR)
    where
        LR: LRRewrites,
    {
        // Children first...
        self.node.fold_rvalue(lr);
        // ...then this expr itself.
        lr.fold_rvalue(self);
    }

    fn fold_lvalue<LR>(&mut self, lr: &mut LR)
    where
        LR: LRRewrites,
    {
        self.node.fold_lvalue(lr);
        lr.fold_lvalue(self);
    }

    fn fold_lvalue_mut<LR>(&mut self, lr: &mut LR)
    where
        LR: LRRewrites,
    {
        self.node.fold_lvalue_mut(lr);
        lr.fold_lvalue_mut(self);
    }
}

include!(concat!(env!("OUT_DIR"), "/lr_expr_gen.inc.rs"));

/// Kinds of contexts where exprs can appear.
///
/// Passed to the callbacks of `fold_expr_with_context` / `fold_exprs_with_context` to tell
/// them which kind of context the expr being rewritten occupies.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Context {
    /// The expr appears in an rvalue context.
    Rvalue,
    /// The expr appears in an (immutable) lvalue context.
    Lvalue,
    /// The expr appears in a mutable lvalue context.
    LvalueMut,
}

/// Adapter that turns a single `(expr, context)` callback into an `LRRewrites` implementation.
struct Rewrites<F>
where
    F: FnMut(&mut P<Expr>, Context),
{
    /// Called once per rewritten expr, together with the context the expr appears in.
    callback: F,
}

// Each rewrite method forwards to the stored callback, tagging the expr with the `Context`
// variant that corresponds to the method.
impl<F: FnMut(&mut P<Expr>, Context)> LRRewrites for Rewrites<F> {
    fn fold_rvalue(&mut self, expr: &mut P<Expr>) {
        let ctx = Context::Rvalue;
        (self.callback)(expr, ctx)
    }

    fn fold_lvalue(&mut self, expr: &mut P<Expr>) {
        let ctx = Context::Lvalue;
        (self.callback)(expr, ctx)
    }

    fn fold_lvalue_mut(&mut self, expr: &mut P<Expr>) {
        let ctx = Context::LvalueMut;
        (self.callback)(expr, ctx)
    }
}

/// Rewrite an `Expr` bottom-up, telling `callback` at each step whether the expr being
/// rewritten sits in an rvalue, (immutable) lvalue, or mutable lvalue context.
///
/// `start` is the context of the outermost expression `e`.
pub fn fold_expr_with_context<F>(e: &mut P<Expr>, start: Context, callback: F)
where
    F: FnMut(&mut P<Expr>, Context),
{
    let mut rewrites = Rewrites { callback };
    // Enter the traversal through the method matching the requested starting context.
    match start {
        Context::Rvalue => e.fold_rvalue(&mut rewrites),
        Context::Lvalue => e.fold_lvalue(&mut rewrites),
        Context::LvalueMut => e.fold_lvalue_mut(&mut rewrites),
    }
}

// MutVisitor for rewriting exprs that aren't children of other exprs.
struct TopExprFolder<F> {
    // Rewrite to apply to each expr that is not nested inside another expr.
    callback: F,
    // Whether the traversal is currently inside the children of some expr.
    in_expr: bool,
}

impl<F> TopExprFolder<F> {
    // Run `callback` with the `in_expr` flag temporarily set to the given value.  The previous
    // flag value is put back once `callback` returns, and `callback`'s result is passed through.
    fn in_expr<G, R>(&mut self, in_expr: bool, callback: G) -> R
    where
        G: FnOnce(&mut Self) -> R,
    {
        let saved = std::mem::replace(&mut self.in_expr, in_expr);
        let result = callback(self);
        self.in_expr = saved;
        result
    }
}

impl<F> MutVisitor for TopExprFolder<F>
where
    F: FnMut(&mut P<Expr>),
{
    // Visit the children of `e` with `in_expr` set, then run the callback on `e` itself, but
    // only when `e` was not reached from inside another expr.
    fn visit_expr(&mut self, e: &mut P<Expr>) {
        self.in_expr(true, |folder| mut_visit::noop_visit_expr(e, folder));
        if self.in_expr {
            // `e` is nested inside another expr; leave it alone.
            return;
        }
        (self.callback)(e);
    }

    // Types, patterns and statements are not exprs but may contain exprs, so the `in_expr`
    // flag is cleared while walking them.

    fn visit_ty(&mut self, ty: &mut P<Ty>) {
        self.in_expr(false, |folder| mut_visit::noop_visit_ty(ty, folder));
    }

    fn visit_pat(&mut self, p: &mut P<Pat>) {
        self.in_expr(false, |folder| mut_visit::noop_visit_pat(p, folder));
    }

    fn flat_map_stmt(&mut self, s: Stmt) -> SmallVec<[Stmt; 1]> {
        self.in_expr(false, |folder| mut_visit::noop_flat_map_stmt(s, folder))
    }
}

// Walk `x` with a `TopExprFolder`, applying `callback` to each expr that is not a child of
// another expr.
fn fold_top_exprs<T: MutVisit, F: FnMut(&mut P<Expr>)>(x: &mut T, callback: F) {
    let mut folder = TopExprFolder {
        in_expr: false,
        callback,
    };
    x.visit(&mut folder);
}

/// Rewrite every expr inside `x`, passing `callback` the context each expr appears in.
///
/// Each expr that is not a child of another expr is taken as the root of a
/// `fold_expr_with_context` traversal that starts in `Context::Rvalue`.
pub fn fold_exprs_with_context<T, F>(x: &mut T, mut callback: F)
where
    T: MutVisit,
    F: FnMut(&mut P<Expr>, Context),
{
    fold_top_exprs(x, |top| {
        // `&mut F` is itself `FnMut`, so the same callback can be lent to every traversal.
        fold_expr_with_context(top, Context::Rvalue, &mut callback)
    })
}