// sqruff_lib/rules/ambiguous/am07.rs

1use ahash::{AHashMap, HashSet, HashSetExt};
2use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
3use sqruff_lib_core::utils::analysis::query::{Query, Selectable, Source, WildcardInfo};
4
5use crate::core::config::Value;
6use crate::core::rules::base::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
7use crate::core::rules::context::RuleContext;
8use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
9
/// Lint rule AM07 (`ambiguous.set_columns`): every query in a set expression
/// should return the same number of columns.
#[derive(Clone, Debug)]
pub struct RuleAM07;
12
13impl Rule for RuleAM07 {
14    fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
15        Ok(RuleAM07.erased())
16    }
17
18    fn name(&self) -> &'static str {
19        "ambiguous.set_columns"
20    }
21
22    fn description(&self) -> &'static str {
23        "All queries in set expression should return the same number of columns."
24    }
25
26    fn long_description(&self) -> &'static str {
27        r#"
28**Anti-pattern**
29
30When writing set expressions, all queries must return the same number of columns.
31
32```sql
33WITH cte AS (
34    SELECT
35        a,
36        b
37    FROM foo
38)
39SELECT * FROM cte
40UNION
41SELECT
42    c,
43    d,
44    e
45 FROM t
46```
47
48**Best practice**
49
50Always specify columns when writing set queries and ensure that they all seleect same number of columns.
51
52```sql
53WITH cte AS (
54    SELECT a, b FROM foo
55)
56SELECT
57    a,
58    b
59FROM cte
60UNION
61SELECT
62    c,
63    d
64FROM t
65```
66"#
67    }
68
69    fn groups(&self) -> &'static [RuleGroups] {
70        &[RuleGroups::All, RuleGroups::Ambiguous]
71    }
72
73    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
74        debug_assert!(context.segment.is_type(SyntaxKind::SetExpression));
75
76        let mut root = &context.segment;
77
78        // Is the parent of the set expression a WITH expression?
79        // NOTE: Backward slice to work outward.
80        for parent in context.parent_stack.iter().rev() {
81            if parent.is_type(SyntaxKind::WithCompoundStatement) {
82                root = parent;
83                break;
84            }
85        }
86
87        let query: Query<()> = Query::from_segment(root, context.dialect, None);
88        let (set_segment_select_sizes, resolve_wildcard) = self.get_select_target_counts(query);
89
90        // if queries had different select target counts and all wildcards had been
91        // resolved; fail
92        if set_segment_select_sizes.len() > 1 && resolve_wildcard {
93            vec![LintResult::new(
94                Some(context.segment.clone()),
95                vec![],
96                None,
97                None,
98            )]
99        } else {
100            vec![]
101        }
102    }
103
104    fn crawl_behaviour(&self) -> Crawler {
105        SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::SetExpression]) })
106            .provide_raw_stack()
107            .into()
108    }
109}
110
111impl RuleAM07 {
112    /// Given a set expression, get the number of select targets in each query.
113    ///
114    /// We keep track of the number of columns in each selectable using a
115    /// ``set``. Ideally at the end there is only one item in the set,
116    /// showing that all selectables have the same size. Importantly we
117    /// can't guarantee that we can always resolve any wildcards (*), so
118    /// we also return a flag to indicate whether any present have been
119    /// fully resolved.
120    fn get_select_target_counts(&self, query: Query<()>) -> (HashSet<usize>, bool) {
121        let mut select_target_counts = HashSet::new();
122        let mut resolved_wildcard = true;
123
124        let selectables = query.inner.borrow().selectables.clone();
125        for selectable in selectables {
126            let (cnt, res) = self.resolve_selectable(selectable.clone(), query.clone());
127            if !res {
128                resolved_wildcard = false;
129            }
130            select_target_counts.insert(cnt);
131        }
132
133        (select_target_counts, resolved_wildcard)
134    }
135
136    /// Resolve the number of columns in a single Selectable.
137    ///
138    /// The selectable may opr may not have (*) wildcard expressions. If it
139    /// does, we attempt to resolve them.
140    fn resolve_selectable(&self, selectable: Selectable, root_query: Query<()>) -> (usize, bool) {
141        debug_assert!(selectable.select_info().is_some());
142
143        let wildcard_info = selectable.wildcard_info();
144
145        // Start with the number of non-wildcard columns.
146        let mut num_cols =
147            selectable.select_info().unwrap().select_targets.len() - wildcard_info.len();
148
149        // If there are no wildcards, we're done.
150        if wildcard_info.is_empty() {
151            return (num_cols, true);
152        }
153
154        let mut resolved = true;
155        // If the set query contains one or more wildcards, attempt to resolve it to a
156        // list of select targets that can be counted.
157        for wildcard in wildcard_info {
158            let (_cols, _resolved) =
159                self.resolve_selectable_wildcard(wildcard, selectable.clone(), root_query.clone());
160            resolved = resolved && _resolved;
161            // Add on the number of columns which the wildcard resolves to.
162            num_cols += _cols;
163        }
164
165        (num_cols, resolved)
166    }
167
168    /// Attempt to resolve a full query which may contain wildcards.
169    ///
170    /// NOTE: This requires a ``Query`` as input rather than just a
171    /// ``Selectable`` and will delegate to ``__resolve_selectable``
172    /// once any Selectables have been identified.
173    ///
174    /// This method is *not* called on the initial set expression as
175    /// that is evaluated as a series of Selectables. This method is
176    /// only called on any subqueries (which may themselves be SELECT,
177    /// WITH or set expressions) found during the resolution of any
178    /// wildcards.
179    fn resolve_wild_query(&self, query: Query<()>) -> (usize, bool) {
180        // if one of the source queries for a query within the set is a
181        // set expression, just use the first query. If that first query isn't
182        // reflective of the others, that will be caught when that segment
183        // is processed. We'll know if we're in a set based on whether there
184        // is more than one selectable. i.e. Just take the first selectable.
185        let selectable = query.inner.borrow().selectables[0].clone();
186        self.resolve_selectable(selectable, query.clone())
187    }
188
189    /// Attempt to resolve a single wildcard (*) within a Selectable.
190    ///
191    /// Note: This means resolving the number of columns implied by
192    /// a single *. This method would be run multiple times if there
193    /// are multiple wildcards in a single selectable.
194    fn resolve_selectable_wildcard(
195        &self,
196        wildcard: WildcardInfo,
197        selectable: Selectable,
198        root_query: Query<()>,
199    ) -> (usize, bool) {
200        let mut resolved = true;
201
202        // If there is no table specified, it is likely a subquery so handle that first.
203        if wildcard.tables.is_empty() {
204            // Crawl the query looking for the subquery, problem in the FROM.
205            for source in root_query.crawl_sources(selectable.selectable, false, true) {
206                if let Source::Query(query) = source {
207                    return self.resolve_wild_query(query);
208                }
209            }
210            return (0, false);
211        }
212
213        // There might be multiple tables references in some wildcard cases.
214        let mut num_columns = 0;
215        for wildcard_table in wildcard.tables {
216            let mut cte_name = wildcard_table.clone();
217
218            // Get the AliasInfo for the table referenced in the wildcard expression.
219            let alias_info = selectable.find_alias(&wildcard_table);
220            if let Some(alias_info) = alias_info {
221                let select_info_target = root_query
222                    .crawl_sources(alias_info.from_expression_element, false, true)
223                    .into_iter()
224                    .next()
225                    .unwrap();
226
227                match select_info_target {
228                    Source::TableReference(name) => {
229                        cte_name = name;
230                    }
231                    Source::Query(query) => {
232                        let (_cols, _resolved) = self.resolve_wild_query(query);
233                        num_columns += _cols;
234                        resolved = resolved && _resolved;
235                        continue;
236                    }
237                }
238            }
239
240            let cte = root_query.lookup_cte(&cte_name, true);
241            if let Some(cte) = cte {
242                let (cols, _resolved) = self.resolve_wild_query(cte);
243                num_columns += cols;
244                resolved = resolved && _resolved;
245            } else {
246                resolved = false;
247            }
248        }
249        (num_columns, resolved)
250    }
251}