sqruff_lib/rules/structure/
st07.rs

1use ahash::AHashMap;
2use itertools::Itertools;
3use smol_str::{SmolStr, ToSmolStr};
4use sqruff_lib_core::dialects::init::DialectKind;
5use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
6use sqruff_lib_core::lint_fix::LintFix;
7use sqruff_lib_core::parser::segments::base::{ErasedSegment, SegmentBuilder, Tables};
8use sqruff_lib_core::utils::analysis::select::get_select_statement_info;
9use sqruff_lib_core::utils::functional::segments::Segments;
10
11use crate::core::config::Value;
12use crate::core::rules::base::{CloneRule, ErasedRule, LintResult, Rule, RuleGroups};
13use crate::core::rules::context::RuleContext;
14use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
15use crate::utils::functional::context::FunctionalContext;
16
/// Lint rule ST07 (`structure.using`): reports joins written with `USING`
/// and, where possible, rewrites them to an explicit `ON` condition.
///
/// The rule carries no configuration or state, hence the unit struct.
#[derive(Debug, Default, Clone)]
pub struct RuleST07;
19
20impl Rule for RuleST07 {
21    fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
22        Ok(RuleST07.erased())
23    }
24
25    fn name(&self) -> &'static str {
26        "structure.using"
27    }
28
29    fn description(&self) -> &'static str {
30        "Prefer specifying join keys instead of using ``USING``."
31    }
32
33    fn long_description(&self) -> &'static str {
34        r"
35**Anti-pattern**
36
37```sql
38SELECT
39    table_a.field_1,
40    table_b.field_2
41FROM
42    table_a
43INNER JOIN table_b USING (id)
44```
45
46**Best practice**
47
48Specify the keys directly
49
50```sql
51SELECT
52    table_a.field_1,
53    table_b.field_2
54FROM
55    table_a
56INNER JOIN table_b
57    ON table_a.id = table_b.id
58```"
59    }
60
61    fn groups(&self) -> &'static [RuleGroups] {
62        &[RuleGroups::All, RuleGroups::Structure]
63    }
64
65    fn dialect_skip(&self) -> &'static [DialectKind] {
66        &[DialectKind::Clickhouse]
67    }
68
69    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
70        let functional_context = FunctionalContext::new(context);
71        let segment = functional_context.segment();
72        let parent_stack = functional_context.parent_stack();
73
74        let usings = segment.children(Some(|it: &ErasedSegment| it.is_keyword("using")));
75        let using_anchor = usings.first();
76
77        let Some(using_anchor) = using_anchor else {
78            return Vec::new();
79        };
80
81        let unfixable_result = LintResult::new(
82            using_anchor.clone().into(),
83            Vec::new(),
84            Some("Found USING statement. Expected only ON statements.".into()),
85            None,
86        );
87
88        let tables_in_join = parent_stack
89            .last()
90            .unwrap()
91            .segments()
92            .iter()
93            .filter(|it| {
94                matches!(
95                    it.get_type(),
96                    SyntaxKind::JoinClause | SyntaxKind::FromExpressionElement
97                )
98            })
99            .cloned()
100            .collect_vec();
101
102        if segment.get(0, None) != tables_in_join.get(1).cloned() {
103            return vec![unfixable_result];
104        }
105
106        let stmts = parent_stack.find_last(Some(|it: &ErasedSegment| {
107            it.is_type(SyntaxKind::SelectStatement)
108        }));
109        let parent_select = stmts.first();
110
111        let Some(parent_select) = parent_select else {
112            return vec![unfixable_result];
113        };
114
115        let select_info = get_select_statement_info(parent_select, context.dialect.into(), true);
116        let mut table_aliases =
117            select_info.map_or(Vec::new(), |select_info| select_info.table_aliases);
118        table_aliases.retain(|it| !it.ref_str.is_empty());
119
120        if table_aliases.len() < 2 {
121            return vec![unfixable_result];
122        }
123
124        let (to_delete, insert_after_anchor) = extract_deletion_sequence_and_anchor(&segment);
125
126        let [table_a, table_b, ..] = &table_aliases[..] else {
127            unreachable!()
128        };
129
130        let mut edit_segments = vec![
131            SegmentBuilder::keyword(context.tables.next_id(), "ON"),
132            SegmentBuilder::whitespace(context.tables.next_id(), " "),
133        ];
134
135        edit_segments.append(&mut generate_join_conditions(
136            context.tables,
137            context.dialect.name,
138            &table_a.ref_str,
139            &table_b.ref_str,
140            extract_cols_from_using(segment, using_anchor),
141        ));
142
143        let mut fixes = Vec::with_capacity(1 + to_delete.len());
144
145        fixes.push(LintFix::create_before(insert_after_anchor, edit_segments));
146        fixes.extend(to_delete.into_iter().map(LintFix::delete));
147
148        vec![LintResult::new(
149            using_anchor.clone().into(),
150            fixes,
151            None,
152            None,
153        )]
154    }
155
156    fn is_fix_compatible(&self) -> bool {
157        true
158    }
159
160    fn crawl_behaviour(&self) -> Crawler {
161        SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::JoinClause]) }).into()
162    }
163}
164
165fn extract_cols_from_using(join_clause: Segments, using_segs: &ErasedSegment) -> Vec<SmolStr> {
166    join_clause
167        .children(None)
168        .select(
169            Some(|it: &ErasedSegment| it.is_type(SyntaxKind::Bracketed)),
170            None,
171            Some(using_segs),
172            None,
173        )
174        .find_first::<fn(&ErasedSegment) -> bool>(None)
175        .children(Some(|it: &ErasedSegment| {
176            it.is_type(SyntaxKind::Identifier) || it.is_type(SyntaxKind::NakedIdentifier)
177        }))
178        .into_iter()
179        .map(|it| it.raw().to_smolstr())
180        .collect()
181}
182
183fn generate_join_conditions(
184    tables: &Tables,
185    dialect: DialectKind,
186    table_a_ref: &str,
187    table_b_ref: &str,
188    columns: Vec<SmolStr>,
189) -> Vec<ErasedSegment> {
190    let mut edit_segments = Vec::new();
191
192    for col in columns {
193        edit_segments.extend_from_slice(&[
194            create_col_reference(tables, dialect, table_a_ref, &col),
195            SegmentBuilder::whitespace(tables.next_id(), " "),
196            SegmentBuilder::token(tables.next_id(), "=", SyntaxKind::Symbol).finish(),
197            SegmentBuilder::whitespace(tables.next_id(), " "),
198            create_col_reference(tables, dialect, table_b_ref, &col),
199            SegmentBuilder::whitespace(tables.next_id(), " "),
200            SegmentBuilder::keyword(tables.next_id(), "AND"),
201            SegmentBuilder::whitespace(tables.next_id(), " "),
202        ]);
203    }
204
205    edit_segments
206        .get(..edit_segments.len().saturating_sub(3))
207        .map_or(Vec::new(), ToOwned::to_owned)
208        .clone()
209}
210
211fn extract_deletion_sequence_and_anchor(
212    join_clause: &Segments,
213) -> (Vec<ErasedSegment>, ErasedSegment) {
214    let mut insert_anchor = None;
215    let mut to_delete = Vec::new();
216
217    for seg in join_clause.children(None) {
218        if seg.raw().eq_ignore_ascii_case("USING") {
219            to_delete.push(seg.clone());
220            continue;
221        }
222
223        if to_delete.is_empty() {
224            continue;
225        }
226
227        if to_delete.last().unwrap().is_type(SyntaxKind::Bracketed) {
228            insert_anchor = Some(seg);
229            break;
230        }
231
232        to_delete.push(seg);
233    }
234
235    (to_delete, insert_anchor.unwrap())
236}
237
238fn create_col_reference(
239    tables: &Tables,
240    dialect: DialectKind,
241    table_ref: &str,
242    column_name: &str,
243) -> ErasedSegment {
244    SegmentBuilder::node(
245        tables.next_id(),
246        SyntaxKind::ColumnReference,
247        dialect,
248        vec![
249            SegmentBuilder::token(tables.next_id(), table_ref, SyntaxKind::NakedIdentifier)
250                .finish(),
251            SegmentBuilder::symbol(tables.next_id(), "."),
252            SegmentBuilder::token(tables.next_id(), column_name, SyntaxKind::NakedIdentifier)
253                .finish(),
254        ],
255    )
256    .finish()
257}