sqruff_lib/rules/ambiguous/
am05.rs

1use std::str::FromStr;
2
3use ahash::AHashMap;
4use smol_str::StrExt;
5use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
6use sqruff_lib_core::lint_fix::LintFix;
7use sqruff_lib_core::parser::segments::base::SegmentBuilder;
8use strum_macros::{AsRefStr, EnumString};
9
10use crate::core::config::Value;
11use crate::core::rules::base::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
12use crate::core::rules::context::RuleContext;
13use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
14
15#[derive(Clone, Debug)]
16pub struct RuleAM05 {
17    fully_qualify_join_types: JoinType,
18}
19
20#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash, AsRefStr, EnumString)]
21#[strum(serialize_all = "lowercase")]
22enum JoinType {
23    Inner,
24    Outer,
25    Both,
26}
27
28impl Default for RuleAM05 {
29    fn default() -> Self {
30        Self {
31            fully_qualify_join_types: JoinType::Inner,
32        }
33    }
34}
35
36impl Rule for RuleAM05 {
37    fn load_from_config(&self, config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
38        let fully_qualify_join_types = config["fully_qualify_join_types"].as_string();
39        // TODO We will need a more complete story for all the config parsing
40        match fully_qualify_join_types {
41            None => Err("Rule AM05 expects a `fully_qualify_join_types` array".to_string()),
42            Some(join_type) => {
43                let join_type = JoinType::from_str(join_type).map_err(|_| {
44                    format!(
45                        "Rule AM05 expects a `fully_qualify_join_types` array of valid join \
46                         types. Got: {}",
47                        join_type
48                    )
49                })?;
50                Ok(RuleAM05 {
51                    fully_qualify_join_types: join_type,
52                }
53                .erased())
54            }
55        }
56    }
57
58    fn name(&self) -> &'static str {
59        "ambiguous.join"
60    }
61
62    fn description(&self) -> &'static str {
63        "Join clauses should be fully qualified."
64    }
65
66    fn long_description(&self) -> &'static str {
67        r#"
68**Anti-pattern**
69
70In this example, `UNION DISTINCT` should be preferred over `UNION`, because explicit is better than implicit.
71
72
73```sql
74SELECT a, b FROM table_1
75UNION
76SELECT a, b FROM table_2
77```
78
79**Best practice**
80
81Specify `DISTINCT` or `ALL` after `UNION` (note that `DISTINCT` is the default behavior).
82
83```sql
84SELECT a, b FROM table_1
85UNION DISTINCT
86SELECT a, b FROM table_2
87```
88"#
89    }
90
91    fn groups(&self) -> &'static [RuleGroups] {
92        &[RuleGroups::All, RuleGroups::Ambiguous]
93    }
94
95    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
96        assert!(context.segment.is_type(SyntaxKind::JoinClause));
97
98        let join_clause_keywords = context
99            .segment
100            .segments()
101            .iter()
102            .filter(|segment| segment.is_type(SyntaxKind::Keyword))
103            .collect::<Vec<_>>();
104
105        // Identify LEFT/RIGHT/OUTER JOIN and if the next keyword is JOIN.
106        if (self.fully_qualify_join_types == JoinType::Outer
107            || self.fully_qualify_join_types == JoinType::Both)
108            && ["RIGHT", "LEFT", "FULL"].contains(
109                &join_clause_keywords[0]
110                    .raw()
111                    .to_uppercase_smolstr()
112                    .as_str(),
113            )
114            && join_clause_keywords[1].raw().eq_ignore_ascii_case("JOIN")
115        {
116            let outer_keyword = if join_clause_keywords[1].raw() == "JOIN" {
117                "OUTER"
118            } else {
119                "outer"
120            };
121            return vec![LintResult::new(
122                context.segment.segments()[0].clone().into(),
123                vec![LintFix::create_after(
124                    context.segment.segments()[0].clone(),
125                    vec![
126                        SegmentBuilder::whitespace(context.tables.next_id(), " "),
127                        SegmentBuilder::keyword(context.tables.next_id(), outer_keyword),
128                    ],
129                    None,
130                )],
131                None,
132                None,
133            )];
134        };
135
136        // Fully qualifying inner joins
137        if (self.fully_qualify_join_types == JoinType::Inner
138            || self.fully_qualify_join_types == JoinType::Both)
139            && join_clause_keywords[0].raw().eq_ignore_ascii_case("JOIN")
140        {
141            let inner_keyword = if join_clause_keywords[0].raw() == "JOIN" {
142                "INNER"
143            } else {
144                "inner"
145            };
146            return vec![LintResult::new(
147                context.segment.segments()[0].clone().into(),
148                vec![LintFix::create_before(
149                    context.segment.segments()[0].clone(),
150                    vec![
151                        SegmentBuilder::keyword(context.tables.next_id(), inner_keyword),
152                        SegmentBuilder::whitespace(context.tables.next_id(), " "),
153                    ],
154                )],
155                None,
156                None,
157            )];
158        }
159        vec![]
160    }
161
162    fn is_fix_compatible(&self) -> bool {
163        true
164    }
165
166    fn crawl_behaviour(&self) -> Crawler {
167        SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::JoinClause]) }).into()
168    }
169}