sqruff_lib/rules/capitalisation/
cp01.rs

1use ahash::{AHashMap, AHashSet};
2use itertools::Itertools;
3use regex::Regex;
4use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
5use sqruff_lib_core::helpers::capitalize;
6use sqruff_lib_core::lint_fix::LintFix;
7use sqruff_lib_core::parser::segments::base::ErasedSegment;
8
9use crate::core::config::Value;
10use crate::core::rules::base::{Erased, ErasedRule, LintPhase, LintResult, Rule, RuleGroups};
11use crate::core::rules::context::RuleContext;
12use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
13
/// Returns `true` when `character` has distinct lower- and upper-case forms,
/// i.e. when changing its case would actually change the text.
fn is_capitalizable(character: char) -> bool {
    let lower = character.to_lowercase();
    let upper = character.to_uppercase();
    !lower.eq(upper)
}
17
18#[derive(Debug, Clone)]
19pub struct RuleCP01 {
20    pub(crate) capitalisation_policy: String,
21    pub(crate) ignore_words: Vec<String>,
22    pub(crate) ignore_words_regex: Vec<Regex>,
23    pub(crate) cap_policy_name: String,
24    pub(crate) skip_literals: bool,
25    pub(crate) exclude_parent_types: &'static [SyntaxKind],
26    pub(crate) description_elem: &'static str,
27}
28
29impl Default for RuleCP01 {
30    fn default() -> Self {
31        Self {
32            capitalisation_policy: "consistent".into(),
33            cap_policy_name: "capitalisation_policy".into(),
34            skip_literals: true,
35            exclude_parent_types: &[
36                SyntaxKind::DataType,
37                SyntaxKind::DatetimeTypeIdentifier,
38                SyntaxKind::PrimitiveType,
39                SyntaxKind::NakedIdentifier,
40            ],
41            description_elem: "Keywords",
42            ignore_words: Vec::new(),
43            ignore_words_regex: Vec::new(),
44        }
45    }
46}
47
48impl Rule for RuleCP01 {
49    fn load_from_config(&self, config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
50        Ok(RuleCP01 {
51            capitalisation_policy: config["capitalisation_policy"].as_string().unwrap().into(),
52            ignore_words: config["ignore_words"]
53                .map(|it| {
54                    it.as_array()
55                        .unwrap()
56                        .iter()
57                        .map(|it| it.as_string().unwrap().to_lowercase())
58                        .collect()
59                })
60                .unwrap_or_default(),
61            ignore_words_regex: config["ignore_words_regex"]
62                .map(|it| {
63                    it.as_array()
64                        .unwrap()
65                        .iter()
66                        .map(|it| Regex::new(it.as_string().unwrap()).unwrap())
67                        .collect()
68                })
69                .unwrap_or_default(),
70            ..Default::default()
71        }
72        .erased())
73    }
74
75    fn lint_phase(&self) -> LintPhase {
76        LintPhase::Post
77    }
78
79    fn name(&self) -> &'static str {
80        "capitalisation.keywords"
81    }
82
83    fn description(&self) -> &'static str {
84        "Inconsistent capitalisation of keywords."
85    }
86
87    fn long_description(&self) -> &'static str {
88        r#"
89**Anti-pattern**
90
91In this example, select is in lower-case whereas `FROM` is in upper-case.
92
93```sql
94select
95    a
96FROM foo
97```
98
99**Best practice**
100
101Make all keywords either in upper-case or in lower-case.
102
103```sql
104SELECT
105    a
106FROM foo
107
108-- Also good
109
110select
111    a
112from foo
113```
114"#
115    }
116
117    fn groups(&self) -> &'static [RuleGroups] {
118        &[
119            RuleGroups::All,
120            RuleGroups::Core,
121            RuleGroups::Capitalisation,
122        ]
123    }
124
125    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
126        let parent = context.parent_stack.last().unwrap();
127
128        if self
129            .ignore_words
130            .contains(&context.segment.raw().to_lowercase())
131        {
132            return Vec::new();
133        }
134
135        if self
136            .ignore_words_regex
137            .iter()
138            .any(|regex| regex.is_match(context.segment.raw().as_ref()))
139        {
140            return Vec::new();
141        }
142
143        if (self.skip_literals && context.segment.is_type(SyntaxKind::Literal))
144            || !self.exclude_parent_types.is_empty()
145                && self
146                    .exclude_parent_types
147                    .iter()
148                    .any(|&it| parent.is_type(it))
149        {
150            return vec![LintResult::new(None, Vec::new(), None, None)];
151        }
152
153        if parent.get_type() == SyntaxKind::FunctionName && parent.segments().len() != 1 {
154            return vec![LintResult::new(None, Vec::new(), None, None)];
155        }
156
157        vec![handle_segment(
158            self.description_elem,
159            &self.capitalisation_policy,
160            &self.cap_policy_name,
161            context.segment.clone(),
162            context,
163        )]
164    }
165
166    fn is_fix_compatible(&self) -> bool {
167        true
168    }
169
170    fn crawl_behaviour(&self) -> Crawler {
171        SegmentSeekerCrawler::new(
172            const {
173                SyntaxSet::new(&[
174                    SyntaxKind::Keyword,
175                    SyntaxKind::BinaryOperator,
176                    SyntaxKind::DatePart,
177                ])
178            },
179        )
180        .into()
181    }
182}
183
184#[derive(Clone, Default)]
185struct RefutedCases(AHashSet<&'static str>);
186
/// The first still-viable case recorded by the most recent segment that did
/// not need fixing under the `"consistent"` policy. It lives on the rule context.
#[derive(Clone)]
struct LatestPossibleCase(String);
189
190pub fn handle_segment(
191    description_elem: &str,
192    extended_capitalisation_policy: &str,
193    cap_policy_name: &str,
194    seg: ErasedSegment,
195    context: &RuleContext,
196) -> LintResult {
197    if seg.raw().is_empty() || seg.is_templated() {
198        return LintResult::new(None, Vec::new(), None, None);
199    }
200
201    let mut refuted_cases = context.try_get::<RefutedCases>().unwrap_or_default().0;
202
203    let mut first_letter_is_lowercase = false;
204    for ch in seg.raw().chars() {
205        if is_capitalizable(ch) {
206            first_letter_is_lowercase = Some(ch).into_iter().ne(ch.to_uppercase());
207            break;
208        }
209        first_letter_is_lowercase = false;
210    }
211
212    if first_letter_is_lowercase {
213        refuted_cases.extend(["upper", "capitalise", "pascal"]);
214        if seg.raw().as_str() != seg.raw().to_lowercase() {
215            refuted_cases.insert("lower");
216        }
217    } else {
218        refuted_cases.insert("lower");
219
220        let segment_raw = seg.raw();
221        if segment_raw.as_str() != segment_raw.to_uppercase() {
222            refuted_cases.insert("upper");
223        }
224        if segment_raw.as_str()
225            != segment_raw
226                .to_uppercase()
227                .chars()
228                .next()
229                .unwrap()
230                .to_string()
231                + segment_raw[1..].to_lowercase().as_str()
232        {
233            refuted_cases.insert("capitalise");
234        }
235        if !segment_raw.chars().all(|c| c.is_alphanumeric()) {
236            refuted_cases.insert("pascal");
237        }
238    }
239
240    context.set(RefutedCases(refuted_cases.clone()));
241
242    let concrete_policy = if extended_capitalisation_policy == "consistent" {
243        let cap_policy_opts = match cap_policy_name {
244            "capitalisation_policy" => ["upper", "lower", "capitalise"].as_slice(),
245            "extended_capitalisation_policy" => {
246                ["upper", "lower", "pascal", "capitalise"].as_slice()
247            }
248            _ => unimplemented!("Unknown capitalisation policy name: {cap_policy_name}"),
249        };
250
251        let possible_cases = cap_policy_opts
252            .iter()
253            .filter(|&it| !refuted_cases.contains(it))
254            .collect_vec();
255
256        if !possible_cases.is_empty() {
257            context.set(LatestPossibleCase(possible_cases[0].to_string()));
258            return LintResult::new(None, Vec::new(), None, None);
259        } else {
260            context
261                .try_get::<LatestPossibleCase>()
262                .unwrap_or_else(|| LatestPossibleCase("upper".into()))
263                .0
264        }
265    } else {
266        extended_capitalisation_policy.to_string()
267    };
268
269    let concrete_policy = concrete_policy.as_str();
270
271    let mut fixed_raw = seg.raw().to_string();
272    fixed_raw = match concrete_policy {
273        "upper" => fixed_raw.to_uppercase(),
274        "lower" => fixed_raw.to_lowercase(),
275        "capitalise" => capitalize(&fixed_raw),
276        "pascal" => {
277            let re = lazy_regex::regex!(r"([^a-zA-Z0-9]+|^)([a-zA-Z0-9])([a-zA-Z0-9]*)");
278            re.replace_all(&fixed_raw, |caps: &regex::Captures| {
279                let mut replacement_string = String::from(&caps[1]);
280                let capitalized = caps[2].to_uppercase();
281                replacement_string.push_str(&capitalized);
282                replacement_string.push_str(&caps[3]);
283                replacement_string
284            })
285            .into()
286        }
287        _ => fixed_raw,
288    };
289
290    if fixed_raw == seg.raw().as_str() {
291        LintResult::new(None, Vec::new(), None, None)
292    } else {
293        let consistency = if extended_capitalisation_policy == "consistent" {
294            "consistently "
295        } else {
296            ""
297        };
298        let policy = match concrete_policy {
299            concrete_policy @ ("upper" | "lower") => format!("{} case.", concrete_policy),
300            "capitalise" => "capitalised.".to_string(),
301            "pascal" => "pascal case.".to_string(),
302            _ => "".to_string(),
303        };
304
305        LintResult::new(
306            seg.clone().into(),
307            vec![LintFix::replace(
308                seg.clone(),
309                vec![seg.edit(context.tables.next_id(), fixed_raw.to_string().into(), None)],
310                None,
311            )],
312            format!("{description_elem} must be {consistency}{policy}").into(),
313            None,
314        )
315    }
316}