sqruff_lib/rules/ambiguous/
am07.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
use ahash::{AHashMap, HashSet, HashSetExt};
use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
use sqruff_lib_core::utils::analysis::query::{Query, Selectable, Source, WildcardInfo};

use crate::core::config::Value;
use crate::core::rules::base::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
use crate::core::rules::context::RuleContext;
use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};

/// Lint rule `ambiguous.set_columns` (AM07).
///
/// Reports set expressions (e.g. `UNION`) whose member queries do not all
/// return the same number of columns. The rule carries no configuration
/// state, hence the unit struct.
#[derive(Debug, Clone)]
pub struct RuleAM07;

impl Rule for RuleAM07 {
    /// Build the rule; the supplied configuration is ignored because this
    /// rule has no options.
    fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
        let rule = RuleAM07;
        Ok(rule.erased())
    }

    /// Dotted rule name.
    fn name(&self) -> &'static str {
        "ambiguous.set_columns"
    }

    /// One-line summary of what the rule checks.
    fn description(&self) -> &'static str {
        "All queries in set expression should return the same number of columns."
    }

    /// Extended Markdown description with an anti-pattern and a best-practice
    /// example.
    fn long_description(&self) -> &'static str {
        r#"
**Anti-pattern**

When writing set expressions, all queries must return the same number of columns.

```sql
WITH cte AS (
    SELECT
        a,
        b
    FROM foo
)
SELECT * FROM cte
UNION
SELECT
    c,
    d,
    e
 FROM t
```

**Best practice**

Always specify columns when writing set queries and ensure that they all seleect same number of columns.

```sql
WITH cte AS (
    SELECT a, b FROM foo
)
SELECT
    a,
    b
FROM cte
UNION
SELECT
    c,
    d
FROM t
```
"#
    }

    /// Rule groups this rule belongs to.
    fn groups(&self) -> &'static [RuleGroups] {
        &[RuleGroups::All, RuleGroups::Ambiguous]
    }

    /// Evaluate one set expression.
    ///
    /// A single result anchored on the set expression is returned when its
    /// selectables produce more than one distinct column count *and* every
    /// wildcard could be resolved; otherwise nothing is reported.
    fn eval(&self, context: RuleContext) -> Vec<LintResult> {
        debug_assert!(context.segment.is_type(SyntaxKind::SetExpression));

        // Walk the parent stack from the innermost parent outward. If the set
        // expression sits inside a WITH statement, analyse from there so the
        // query built below includes that statement; otherwise analyse the
        // set expression itself.
        let root = context
            .parent_stack
            .iter()
            .rev()
            .find(|ancestor| ancestor.is_type(SyntaxKind::WithCompoundStatement))
            .unwrap_or(&context.segment);

        let query: Query<()> = Query::from_segment(root, context.dialect, None);
        let (distinct_counts, wildcards_resolved) = self.get_select_target_counts(query);

        // Nothing to report if all counts agree, or if an unresolved wildcard
        // makes the counts unreliable.
        if distinct_counts.len() <= 1 || !wildcards_resolved {
            return Vec::new();
        }

        let violation = LintResult::new(Some(context.segment.clone()), Vec::new(), None, None);
        vec![violation]
    }

    /// Crawl only set expression segments, with the raw stack provided.
    fn crawl_behaviour(&self) -> Crawler {
        let seeker =
            SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::SetExpression]) });
        seeker.provide_raw_stack().into()
    }
}

impl RuleAM07 {
    /// Given a set expression, get the number of select targets in each query.
    ///
    /// The column count of every selectable is collected into a set. Ideally
    /// the set ends up with exactly one item, showing that all selectables
    /// have the same size. Wildcards (`*`) cannot always be resolved, so the
    /// second element of the returned tuple is `true` only when every
    /// selectable was fully resolved.
    fn get_select_target_counts(&self, query: Query<()>) -> (HashSet<usize>, bool) {
        let mut select_target_counts = HashSet::new();
        let mut resolved_wildcard = true;

        // Clone the list so the `RefCell` borrow is released before the
        // resolution helpers are called with the query.
        let selectables = query.inner.borrow().selectables.clone();
        for selectable in selectables {
            let (count, resolved) = self.resolve_selectable(selectable, query.clone());
            resolved_wildcard &= resolved;
            select_target_counts.insert(count);
        }

        (select_target_counts, resolved_wildcard)
    }

    /// Resolve the number of columns in a single `Selectable`.
    ///
    /// The selectable may or may not have wildcard (`*`) expressions. If it
    /// does, we attempt to resolve them.
    ///
    /// Returns `(column_count, resolved)`. A selectable without select
    /// information cannot be counted and is reported as `(0, false)` rather
    /// than panicking.
    fn resolve_selectable(&self, selectable: Selectable, root_query: Query<()>) -> (usize, bool) {
        let Some(select_info) = selectable.select_info() else {
            return (0, false);
        };

        let wildcard_info = selectable.wildcard_info();

        // Start with the number of non-wildcard columns.
        let mut num_cols = select_info.select_targets.len() - wildcard_info.len();

        // If there are no wildcards, we're done.
        if wildcard_info.is_empty() {
            return (num_cols, true);
        }

        let mut resolved = true;
        // If the set query contains one or more wildcards, attempt to resolve
        // each to a list of select targets that can be counted.
        for wildcard in wildcard_info {
            let (wildcard_cols, wildcard_resolved) =
                self.resolve_selectable_wildcard(wildcard, selectable.clone(), root_query.clone());
            resolved &= wildcard_resolved;
            // Add on the number of columns which the wildcard resolves to.
            num_cols += wildcard_cols;
        }

        (num_cols, resolved)
    }

    /// Attempt to resolve a full query which may contain wildcards.
    ///
    /// NOTE: This requires a `Query` as input rather than just a
    /// `Selectable` and delegates to `resolve_selectable` once a selectable
    /// has been identified.
    ///
    /// This method is *not* called on the initial set expression, as that is
    /// evaluated as a series of selectables. It is only called on subqueries
    /// and CTEs found during the resolution of wildcards.
    ///
    /// Returns `(0, false)` when the query has no selectables.
    fn resolve_wild_query(&self, query: Query<()>) -> (usize, bool) {
        // If one of the source queries for a query within the set is itself a
        // set expression, just use the first selectable. If that first query
        // isn't reflective of the others, that will be caught when that
        // segment is processed.
        //
        // The borrow is confined to this statement so it is released before
        // recursing.
        let first_selectable = query.inner.borrow().selectables.first().cloned();
        match first_selectable {
            Some(selectable) => self.resolve_selectable(selectable, query),
            // Nothing to count, so report the wildcard as unresolved.
            None => (0, false),
        }
    }

    /// Attempt to resolve a single wildcard (`*`) within a `Selectable`.
    ///
    /// Note: This means resolving the number of columns implied by a single
    /// `*`. This method is run once per wildcard if a selectable contains
    /// several.
    ///
    /// Returns `(column_count, resolved)`; `resolved` is `false` whenever any
    /// referenced table could not be traced to a subquery or CTE.
    fn resolve_selectable_wildcard(
        &self,
        wildcard: WildcardInfo,
        selectable: Selectable,
        root_query: Query<()>,
    ) -> (usize, bool) {
        // If there is no table specified, it is likely a subquery so handle
        // that first.
        if wildcard.tables.is_empty() {
            // Crawl the query looking for the subquery, probably in the FROM.
            for source in root_query.crawl_sources(selectable.selectable, false, true) {
                if let Source::Query(query) = source {
                    return self.resolve_wild_query(query);
                }
            }
            return (0, false);
        }

        let mut resolved = true;
        // There might be multiple table references in some wildcard cases.
        let mut num_columns = 0;
        for wildcard_table in wildcard.tables {
            let mut cte_name = wildcard_table.clone();

            // Get the AliasInfo for the table referenced in the wildcard
            // expression.
            if let Some(alias_info) = selectable.find_alias(&wildcard_table) {
                // Crawl inside the FROM expression for something to resolve
                // to. If nothing is found we cannot count the columns, so
                // mark the wildcard unresolved instead of panicking.
                let Some(source) = root_query
                    .crawl_sources(alias_info.from_expression_element, false, true)
                    .into_iter()
                    .next()
                else {
                    resolved = false;
                    continue;
                };

                match source {
                    // A plain table reference: it may name a CTE, looked up
                    // below.
                    Source::TableReference(name) => cte_name = name,
                    Source::Query(query) => {
                        let (cols, query_resolved) = self.resolve_wild_query(query);
                        num_columns += cols;
                        resolved &= query_resolved;
                        continue;
                    }
                }
            }

            match root_query.lookup_cte(&cte_name, true) {
                Some(cte) => {
                    let (cols, cte_resolved) = self.resolve_wild_query(cte);
                    num_columns += cols;
                    resolved &= cte_resolved;
                }
                // Not a known CTE, so its column count cannot be determined
                // from the query alone.
                None => resolved = false,
            }
        }

        (num_columns, resolved)
    }
}