sqruff_lib/rules/structure/
st09.rs1use ahash::AHashMap;
2use itertools::Itertools;
3use smol_str::{SmolStr, StrExt};
4use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
5use sqruff_lib_core::lint_fix::LintFix;
6use sqruff_lib_core::parser::segments::base::{ErasedSegment, SegmentBuilder};
7use sqruff_lib_core::parser::segments::from::FromExpressionElementSegment;
8use sqruff_lib_core::parser::segments::join::JoinClauseSegment;
9use sqruff_lib_core::utils::functional::segments::Segments;
10
11use crate::core::config::Value;
12use crate::core::rules::base::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
13use crate::core::rules::context::RuleContext;
14use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
15use crate::utils::functional::context::FunctionalContext;
16
/// Lint rule ST09 (`structure.join_condition_order`).
///
/// Checks that `JOIN ... ON` comparisons between two qualified columns list the
/// tables in the order selected by `preferred_first_table_in_join_clause`.
#[derive(Clone, Debug, Default)]
pub struct RuleST09 {
    /// Which table should be listed first in a join condition; the rule
    /// compares this value against `"earlier"` and `"later"`.
    preferred_first_table_in_join_clause: String,
}
21
22impl Rule for RuleST09 {
23 fn load_from_config(&self, config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
24 Ok(RuleST09 {
25 preferred_first_table_in_join_clause: config["preferred_first_table_in_join_clause"]
26 .as_string()
27 .unwrap()
28 .to_owned(),
29 }
30 .erased())
31 }
32
33 fn name(&self) -> &'static str {
34 "structure.join_condition_order"
35 }
36
37 fn description(&self) -> &'static str {
38 "Joins should list the table referenced earlier/later first."
39 }
40
41 fn long_description(&self) -> &'static str {
42 r#"
43**Anti-pattern**
44
45In this example, the tables that were referenced later are listed first
46and the `preferred_first_table_in_join_clause` configuration
47is set to `earlier`.
48
49```sql
50select
51 foo.a,
52 foo.b,
53 bar.c
54from foo
55left join bar
56 -- This subcondition does not list
57 -- the table referenced earlier first:
58 on bar.a = foo.a
59 -- Neither does this subcondition:
60 and bar.b = foo.b
61```
62
63**Best practice**
64
65List the tables that were referenced earlier first.
66
67```sql
68select
69 foo.a,
70 foo.b,
71 bar.c
72from foo
73left join bar
74 on foo.a = bar.a
75 and foo.b = bar.b
76```
77"#
78 }
79
80 fn groups(&self) -> &'static [RuleGroups] {
81 &[RuleGroups::All, RuleGroups::Structure]
82 }
83
84 fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
85 let mut table_aliases = Vec::new();
86 let children = FunctionalContext::new(context).segment().children(None);
87 let join_clauses =
88 children.recursive_crawl(const { &SyntaxSet::new(&[SyntaxKind::JoinClause]) }, true);
89 let join_on_conditions = join_clauses.children(None).recursive_crawl(
90 const { &SyntaxSet::new(&[SyntaxKind::JoinOnCondition]) },
91 true,
92 );
93
94 if join_on_conditions.is_empty() {
95 return Vec::new();
96 }
97
98 let from_expression_alias = FromExpressionElementSegment(
99 children.recursive_crawl(
100 const { &SyntaxSet::new(&[SyntaxKind::FromExpressionElement]) },
101 true,
102 )[0]
103 .clone(),
104 )
105 .eventual_alias()
106 .ref_str
107 .clone();
108
109 table_aliases.push(from_expression_alias);
110
111 let mut join_clause_aliases = join_clauses
112 .into_iter()
113 .map(|join_clause| {
114 JoinClauseSegment(join_clause)
115 .eventual_aliases()
116 .first()
117 .unwrap()
118 .1
119 .ref_str
120 .clone()
121 })
122 .collect_vec();
123
124 table_aliases.append(&mut join_clause_aliases);
125
126 let table_aliases = table_aliases
127 .iter()
128 .map(|it| it.to_uppercase_smolstr())
129 .collect_vec();
130 let mut conditions = Vec::new();
131
132 let join_on_condition_expressions = join_on_conditions
133 .children(None)
134 .recursive_crawl(const { &SyntaxSet::new(&[SyntaxKind::Expression]) }, true);
135
136 for expression in join_on_condition_expressions {
137 let mut expression_group = Vec::new();
138 for element in Segments::new(expression, None).children(None) {
139 if !matches!(
140 element.get_type(),
141 SyntaxKind::Whitespace | SyntaxKind::Newline
142 ) {
143 expression_group.push(element);
144 }
145 }
146 conditions.push(expression_group);
147 }
148
149 let mut subconditions = Vec::new();
150
151 for expression_group in conditions {
152 subconditions.append(&mut split_list_by_segment_type(
153 expression_group,
154 SyntaxKind::BinaryOperator,
155 vec!["and".into(), "or".into()],
156 ));
157 }
158
159 let column_operator_column_subconditions = subconditions
160 .into_iter()
161 .filter(|it| is_qualified_column_operator_qualified_column_sequence(it))
162 .collect_vec();
163
164 let mut fixes = Vec::new();
165
166 for subcondition in column_operator_column_subconditions {
167 let comparison_operator = subcondition[1].clone();
168 let first_column_reference = subcondition[0].clone();
169 let second_column_reference = subcondition[2].clone();
170 let raw_comparison_operators: Vec<_> = comparison_operator
171 .children(const { &SyntaxSet::new(&[SyntaxKind::RawComparisonOperator]) })
172 .collect();
173 let first_table_seg = first_column_reference
174 .child(
175 const {
176 &SyntaxSet::new(&[
177 SyntaxKind::NakedIdentifier,
178 SyntaxKind::QuotedIdentifier,
179 ])
180 },
181 )
182 .unwrap();
183 let second_table_seg = second_column_reference
184 .child(
185 const {
186 &SyntaxSet::new(&[
187 SyntaxKind::NakedIdentifier,
188 SyntaxKind::QuotedIdentifier,
189 ])
190 },
191 )
192 .unwrap();
193
194 let first_table = first_table_seg.raw().to_uppercase_smolstr();
195 let second_table = second_table_seg.raw().to_uppercase_smolstr();
196
197 let raw_comparison_operator_opposites = |op| match op {
198 "<" => ">",
199 ">" => "<",
200 _ => unimplemented!(),
201 };
202
203 if !table_aliases.contains(&first_table) || !table_aliases.contains(&second_table) {
204 continue;
205 }
206
207 if (table_aliases
208 .iter()
209 .position(|x| x == &first_table)
210 .unwrap()
211 > table_aliases
212 .iter()
213 .position(|x| x == &second_table)
214 .unwrap()
215 && self.preferred_first_table_in_join_clause == "earlier")
216 || (table_aliases
217 .iter()
218 .position(|x| x == &first_table)
219 .unwrap()
220 < table_aliases
221 .iter()
222 .position(|x| x == &second_table)
223 .unwrap()
224 && self.preferred_first_table_in_join_clause == "later")
225 {
226 fixes.push(LintFix::replace(
227 first_column_reference.clone(),
228 vec![second_column_reference.clone()],
229 None,
230 ));
231 fixes.push(LintFix::replace(
232 second_column_reference.clone(),
233 vec![first_column_reference.clone()],
234 None,
235 ));
236
237 if matches!(raw_comparison_operators[0].raw().as_ref(), "<" | ">")
238 && raw_comparison_operators
239 .iter()
240 .map(|it| it.raw())
241 .ne(["<", ">"])
242 {
243 fixes.push(LintFix::replace(
244 raw_comparison_operators[0].clone(),
245 vec![
246 SegmentBuilder::token(
247 context.tables.next_id(),
248 raw_comparison_operator_opposites(
249 raw_comparison_operators[0].raw().as_ref(),
250 ),
251 SyntaxKind::RawComparisonOperator,
252 )
253 .finish(),
254 ],
255 None,
256 ));
257 }
258 }
259 }
260
261 if fixes.is_empty() {
262 return Vec::new();
263 }
264
265 vec![LintResult::new(
266 context.segment.clone().into(),
267 fixes,
268 format!(
269 "Joins should list the table referenced {}",
270 self.preferred_first_table_in_join_clause
271 )
272 .into(),
273 None,
274 )]
275 }
276
277 fn crawl_behaviour(&self) -> Crawler {
278 SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::FromExpression]) }).into()
279 }
280}
281
282fn split_list_by_segment_type(
283 segment_list: Vec<ErasedSegment>,
284 delimiter_type: SyntaxKind,
285 delimiters: Vec<SmolStr>,
286) -> Vec<Vec<ErasedSegment>> {
287 let delimiters = delimiters
288 .into_iter()
289 .map(|it| it.to_uppercase_smolstr())
290 .collect_vec();
291 let mut new_list = Vec::new();
292 let mut sub_list = Vec::new();
293
294 for i in 0..segment_list.len() {
295 if i == segment_list.len() - 1 {
296 sub_list.push(segment_list[i].clone());
297 new_list.push(sub_list.clone());
298 } else if segment_list[i].get_type() == delimiter_type
299 && delimiters.contains(&segment_list[i].raw().to_uppercase_smolstr())
300 {
301 new_list.push(sub_list.clone());
302 sub_list.clear();
303 } else {
304 sub_list.push(segment_list[i].clone());
305 }
306 }
307
308 new_list
309}
310
311fn is_qualified_column_operator_qualified_column_sequence(segment_list: &[ErasedSegment]) -> bool {
312 if segment_list.len() != 3 {
313 return false;
314 }
315
316 if segment_list[0].get_type() == SyntaxKind::ColumnReference
317 && segment_list[0]
318 .direct_descendant_type_set()
319 .contains(SyntaxKind::Dot)
320 && segment_list[1].get_type() == SyntaxKind::ComparisonOperator
321 && segment_list[2].get_type() == SyntaxKind::ColumnReference
322 && segment_list[2]
323 .direct_descendant_type_set()
324 .contains(SyntaxKind::Dot)
325 {
326 return true;
327 }
328
329 false
330}