sqruff_lib/rules/ambiguous/
am07.rs1use ahash::{AHashMap, HashSet, HashSetExt};
2use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
3use sqruff_lib_core::utils::analysis::query::{Query, Selectable, Source, WildcardInfo};
4
5use crate::core::config::Value;
6use crate::core::rules::base::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
7use crate::core::rules::context::RuleContext;
8use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
9
/// Lint rule AM07 (`ambiguous.set_columns`).
///
/// The rule carries no configuration or state, so it is a plain unit struct.
#[derive(Debug, Clone)]
pub struct RuleAM07;
12
13impl Rule for RuleAM07 {
14 fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
15 Ok(RuleAM07.erased())
16 }
17
18 fn name(&self) -> &'static str {
19 "ambiguous.set_columns"
20 }
21
22 fn description(&self) -> &'static str {
23 "All queries in set expression should return the same number of columns."
24 }
25
26 fn long_description(&self) -> &'static str {
27 r#"
28**Anti-pattern**
29
30When writing set expressions, all queries must return the same number of columns.
31
32```sql
33WITH cte AS (
34 SELECT
35 a,
36 b
37 FROM foo
38)
39SELECT * FROM cte
40UNION
41SELECT
42 c,
43 d,
44 e
45 FROM t
46```
47
48**Best practice**
49
50Always specify columns when writing set queries and ensure that they all seleect same number of columns.
51
52```sql
53WITH cte AS (
54 SELECT a, b FROM foo
55)
56SELECT
57 a,
58 b
59FROM cte
60UNION
61SELECT
62 c,
63 d
64FROM t
65```
66"#
67 }
68
69 fn groups(&self) -> &'static [RuleGroups] {
70 &[RuleGroups::All, RuleGroups::Ambiguous]
71 }
72
73 fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
74 debug_assert!(context.segment.is_type(SyntaxKind::SetExpression));
75
76 let mut root = &context.segment;
77
78 for parent in context.parent_stack.iter().rev() {
81 if parent.is_type(SyntaxKind::WithCompoundStatement) {
82 root = parent;
83 break;
84 }
85 }
86
87 let query: Query<()> = Query::from_segment(root, context.dialect, None);
88 let (set_segment_select_sizes, resolve_wildcard) = self.get_select_target_counts(query);
89
90 if set_segment_select_sizes.len() > 1 && resolve_wildcard {
93 vec![LintResult::new(
94 Some(context.segment.clone()),
95 vec![],
96 None,
97 None,
98 )]
99 } else {
100 vec![]
101 }
102 }
103
104 fn crawl_behaviour(&self) -> Crawler {
105 SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::SetExpression]) })
106 .provide_raw_stack()
107 .into()
108 }
109}
110
111impl RuleAM07 {
112 fn get_select_target_counts(&self, query: Query<()>) -> (HashSet<usize>, bool) {
121 let mut select_target_counts = HashSet::new();
122 let mut resolved_wildcard = true;
123
124 let selectables = query.inner.borrow().selectables.clone();
125 for selectable in selectables {
126 let (cnt, res) = self.resolve_selectable(selectable.clone(), query.clone());
127 if !res {
128 resolved_wildcard = false;
129 }
130 select_target_counts.insert(cnt);
131 }
132
133 (select_target_counts, resolved_wildcard)
134 }
135
136 fn resolve_selectable(&self, selectable: Selectable, root_query: Query<()>) -> (usize, bool) {
141 debug_assert!(selectable.select_info().is_some());
142
143 let wildcard_info = selectable.wildcard_info();
144
145 let mut num_cols =
147 selectable.select_info().unwrap().select_targets.len() - wildcard_info.len();
148
149 if wildcard_info.is_empty() {
151 return (num_cols, true);
152 }
153
154 let mut resolved = true;
155 for wildcard in wildcard_info {
158 let (_cols, _resolved) =
159 self.resolve_selectable_wildcard(wildcard, selectable.clone(), root_query.clone());
160 resolved = resolved && _resolved;
161 num_cols += _cols;
163 }
164
165 (num_cols, resolved)
166 }
167
168 fn resolve_wild_query(&self, query: Query<()>) -> (usize, bool) {
180 let selectable = query.inner.borrow().selectables[0].clone();
186 self.resolve_selectable(selectable, query.clone())
187 }
188
189 fn resolve_selectable_wildcard(
195 &self,
196 wildcard: WildcardInfo,
197 selectable: Selectable,
198 root_query: Query<()>,
199 ) -> (usize, bool) {
200 let mut resolved = true;
201
202 if wildcard.tables.is_empty() {
204 for source in root_query.crawl_sources(selectable.selectable, false, true) {
206 if let Source::Query(query) = source {
207 return self.resolve_wild_query(query);
208 }
209 }
210 return (0, false);
211 }
212
213 let mut num_columns = 0;
215 for wildcard_table in wildcard.tables {
216 let mut cte_name = wildcard_table.clone();
217
218 let alias_info = selectable.find_alias(&wildcard_table);
220 if let Some(alias_info) = alias_info {
221 let select_info_target = root_query
222 .crawl_sources(alias_info.from_expression_element, false, true)
223 .into_iter()
224 .next()
225 .unwrap();
226
227 match select_info_target {
228 Source::TableReference(name) => {
229 cte_name = name;
230 }
231 Source::Query(query) => {
232 let (_cols, _resolved) = self.resolve_wild_query(query);
233 num_columns += _cols;
234 resolved = resolved && _resolved;
235 continue;
236 }
237 }
238 }
239
240 let cte = root_query.lookup_cte(&cte_name, true);
241 if let Some(cte) = cte {
242 let (cols, _resolved) = self.resolve_wild_query(cte);
243 num_columns += cols;
244 resolved = resolved && _resolved;
245 } else {
246 resolved = false;
247 }
248 }
249 (num_columns, resolved)
250 }
251}