sqruff_lib/rules/structure/
st09.rs

1use ahash::AHashMap;
2use itertools::Itertools;
3use smol_str::{SmolStr, StrExt};
4use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
5use sqruff_lib_core::lint_fix::LintFix;
6use sqruff_lib_core::parser::segments::base::{ErasedSegment, SegmentBuilder};
7use sqruff_lib_core::parser::segments::from::FromExpressionElementSegment;
8use sqruff_lib_core::parser::segments::join::JoinClauseSegment;
9use sqruff_lib_core::utils::functional::segments::Segments;
10
11use crate::core::config::Value;
12use crate::core::rules::base::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
13use crate::core::rules::context::RuleContext;
14use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
15use crate::utils::functional::context::FunctionalContext;
16
/// Lint rule ST09 (`structure.join_condition_order`).
///
/// The only state is the `preferred_first_table_in_join_clause` option, which is
/// filled in from the rule configuration and is expected to be `"earlier"` or
/// `"later"`.
#[derive(Clone, Debug, Default)]
pub struct RuleST09 {
    /// Which table of a `table.column <op> table.column` join subcondition should be
    /// written first, as read from configuration.
    preferred_first_table_in_join_clause: String,
}
21
22impl Rule for RuleST09 {
23    fn load_from_config(&self, config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
24        Ok(RuleST09 {
25            preferred_first_table_in_join_clause: config["preferred_first_table_in_join_clause"]
26                .as_string()
27                .unwrap()
28                .to_owned(),
29        }
30        .erased())
31    }
32
33    fn name(&self) -> &'static str {
34        "structure.join_condition_order"
35    }
36
37    fn description(&self) -> &'static str {
38        "Joins should list the table referenced earlier/later first."
39    }
40
41    fn long_description(&self) -> &'static str {
42        r#"
43**Anti-pattern**
44
45In this example, the tables that were referenced later are listed first
46and the `preferred_first_table_in_join_clause` configuration
47is set to `earlier`.
48
49```sql
50select
51    foo.a,
52    foo.b,
53    bar.c
54from foo
55left join bar
56    -- This subcondition does not list
57    -- the table referenced earlier first:
58    on bar.a = foo.a
59    -- Neither does this subcondition:
60    and bar.b = foo.b
61```
62
63**Best practice**
64
65List the tables that were referenced earlier first.
66
67```sql
68select
69    foo.a,
70    foo.b,
71    bar.c
72from foo
73left join bar
74    on foo.a = bar.a
75    and foo.b = bar.b
76```
77"#
78    }
79
80    fn groups(&self) -> &'static [RuleGroups] {
81        &[RuleGroups::All, RuleGroups::Structure]
82    }
83
84    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
85        let mut table_aliases = Vec::new();
86        let children = FunctionalContext::new(context).segment().children(None);
87        let join_clauses =
88            children.recursive_crawl(const { &SyntaxSet::new(&[SyntaxKind::JoinClause]) }, true);
89        let join_on_conditions = join_clauses.children(None).recursive_crawl(
90            const { &SyntaxSet::new(&[SyntaxKind::JoinOnCondition]) },
91            true,
92        );
93
94        if join_on_conditions.is_empty() {
95            return Vec::new();
96        }
97
98        let from_expression_alias = FromExpressionElementSegment(
99            children.recursive_crawl(
100                const { &SyntaxSet::new(&[SyntaxKind::FromExpressionElement]) },
101                true,
102            )[0]
103            .clone(),
104        )
105        .eventual_alias()
106        .ref_str
107        .clone();
108
109        table_aliases.push(from_expression_alias);
110
111        let mut join_clause_aliases = join_clauses
112            .into_iter()
113            .map(|join_clause| {
114                JoinClauseSegment(join_clause)
115                    .eventual_aliases()
116                    .first()
117                    .unwrap()
118                    .1
119                    .ref_str
120                    .clone()
121            })
122            .collect_vec();
123
124        table_aliases.append(&mut join_clause_aliases);
125
126        let table_aliases = table_aliases
127            .iter()
128            .map(|it| it.to_uppercase_smolstr())
129            .collect_vec();
130        let mut conditions = Vec::new();
131
132        let join_on_condition_expressions = join_on_conditions
133            .children(None)
134            .recursive_crawl(const { &SyntaxSet::new(&[SyntaxKind::Expression]) }, true);
135
136        for expression in join_on_condition_expressions {
137            let mut expression_group = Vec::new();
138            for element in Segments::new(expression, None).children(None) {
139                if !matches!(
140                    element.get_type(),
141                    SyntaxKind::Whitespace | SyntaxKind::Newline
142                ) {
143                    expression_group.push(element);
144                }
145            }
146            conditions.push(expression_group);
147        }
148
149        let mut subconditions = Vec::new();
150
151        for expression_group in conditions {
152            subconditions.append(&mut split_list_by_segment_type(
153                expression_group,
154                SyntaxKind::BinaryOperator,
155                vec!["and".into(), "or".into()],
156            ));
157        }
158
159        let column_operator_column_subconditions = subconditions
160            .into_iter()
161            .filter(|it| is_qualified_column_operator_qualified_column_sequence(it))
162            .collect_vec();
163
164        let mut fixes = Vec::new();
165
166        for subcondition in column_operator_column_subconditions {
167            let comparison_operator = subcondition[1].clone();
168            let first_column_reference = subcondition[0].clone();
169            let second_column_reference = subcondition[2].clone();
170            let raw_comparison_operators: Vec<_> = comparison_operator
171                .children(const { &SyntaxSet::new(&[SyntaxKind::RawComparisonOperator]) })
172                .collect();
173            let first_table_seg = first_column_reference
174                .child(
175                    const {
176                        &SyntaxSet::new(&[
177                            SyntaxKind::NakedIdentifier,
178                            SyntaxKind::QuotedIdentifier,
179                        ])
180                    },
181                )
182                .unwrap();
183            let second_table_seg = second_column_reference
184                .child(
185                    const {
186                        &SyntaxSet::new(&[
187                            SyntaxKind::NakedIdentifier,
188                            SyntaxKind::QuotedIdentifier,
189                        ])
190                    },
191                )
192                .unwrap();
193
194            let first_table = first_table_seg.raw().to_uppercase_smolstr();
195            let second_table = second_table_seg.raw().to_uppercase_smolstr();
196
197            let raw_comparison_operator_opposites = |op| match op {
198                "<" => ">",
199                ">" => "<",
200                _ => unimplemented!(),
201            };
202
203            if !table_aliases.contains(&first_table) || !table_aliases.contains(&second_table) {
204                continue;
205            }
206
207            if (table_aliases
208                .iter()
209                .position(|x| x == &first_table)
210                .unwrap()
211                > table_aliases
212                    .iter()
213                    .position(|x| x == &second_table)
214                    .unwrap()
215                && self.preferred_first_table_in_join_clause == "earlier")
216                || (table_aliases
217                    .iter()
218                    .position(|x| x == &first_table)
219                    .unwrap()
220                    < table_aliases
221                        .iter()
222                        .position(|x| x == &second_table)
223                        .unwrap()
224                    && self.preferred_first_table_in_join_clause == "later")
225            {
226                fixes.push(LintFix::replace(
227                    first_column_reference.clone(),
228                    vec![second_column_reference.clone()],
229                    None,
230                ));
231                fixes.push(LintFix::replace(
232                    second_column_reference.clone(),
233                    vec![first_column_reference.clone()],
234                    None,
235                ));
236
237                if matches!(raw_comparison_operators[0].raw().as_ref(), "<" | ">")
238                    && raw_comparison_operators
239                        .iter()
240                        .map(|it| it.raw())
241                        .ne(["<", ">"])
242                {
243                    fixes.push(LintFix::replace(
244                        raw_comparison_operators[0].clone(),
245                        vec![
246                            SegmentBuilder::token(
247                                context.tables.next_id(),
248                                raw_comparison_operator_opposites(
249                                    raw_comparison_operators[0].raw().as_ref(),
250                                ),
251                                SyntaxKind::RawComparisonOperator,
252                            )
253                            .finish(),
254                        ],
255                        None,
256                    ));
257                }
258            }
259        }
260
261        if fixes.is_empty() {
262            return Vec::new();
263        }
264
265        vec![LintResult::new(
266            context.segment.clone().into(),
267            fixes,
268            format!(
269                "Joins should list the table referenced {}",
270                self.preferred_first_table_in_join_clause
271            )
272            .into(),
273            None,
274        )]
275    }
276
277    fn crawl_behaviour(&self) -> Crawler {
278        SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::FromExpression]) }).into()
279    }
280}
281
282fn split_list_by_segment_type(
283    segment_list: Vec<ErasedSegment>,
284    delimiter_type: SyntaxKind,
285    delimiters: Vec<SmolStr>,
286) -> Vec<Vec<ErasedSegment>> {
287    let delimiters = delimiters
288        .into_iter()
289        .map(|it| it.to_uppercase_smolstr())
290        .collect_vec();
291    let mut new_list = Vec::new();
292    let mut sub_list = Vec::new();
293
294    for i in 0..segment_list.len() {
295        if i == segment_list.len() - 1 {
296            sub_list.push(segment_list[i].clone());
297            new_list.push(sub_list.clone());
298        } else if segment_list[i].get_type() == delimiter_type
299            && delimiters.contains(&segment_list[i].raw().to_uppercase_smolstr())
300        {
301            new_list.push(sub_list.clone());
302            sub_list.clear();
303        } else {
304            sub_list.push(segment_list[i].clone());
305        }
306    }
307
308    new_list
309}
310
311fn is_qualified_column_operator_qualified_column_sequence(segment_list: &[ErasedSegment]) -> bool {
312    if segment_list.len() != 3 {
313        return false;
314    }
315
316    if segment_list[0].get_type() == SyntaxKind::ColumnReference
317        && segment_list[0]
318            .direct_descendant_type_set()
319            .contains(SyntaxKind::Dot)
320        && segment_list[1].get_type() == SyntaxKind::ComparisonOperator
321        && segment_list[2].get_type() == SyntaxKind::ColumnReference
322        && segment_list[2]
323            .direct_descendant_type_set()
324            .contains(SyntaxKind::Dot)
325    {
326        return true;
327    }
328
329    false
330}