sqruff_lib/rules/convention/
cv04.rs

1use ahash::AHashMap;
2use itertools::Itertools;
3use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
4use sqruff_lib_core::lint_fix::LintFix;
5use sqruff_lib_core::parser::segments::base::{ErasedSegment, SegmentBuilder};
6
7use crate::core::config::Value;
8use crate::core::rules::base::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
9use crate::core::rules::context::RuleContext;
10use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
11use crate::utils::functional::context::FunctionalContext;
12
/// Lint rule `convention.count_rows` (CV04): enforce one spelling of
/// "count the number of rows" — `count(*)`, `count(1)` or `count(0)`.
///
/// With both flags left `false` (the `Default`), `count(*)` is the expected form.
#[derive(Clone, Debug, Default)]
pub struct RuleCV04 {
    /// Prefer `count(1)`. Takes precedence over `prefer_count_0` when both are set.
    pub prefer_count_1: bool,
    /// Prefer `count(0)`. Ignored when `prefer_count_1` is also set.
    pub prefer_count_0: bool,
}
18
19impl Rule for RuleCV04 {
20    fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
21        Ok(RuleCV04 {
22            prefer_count_1: _config
23                .get("prefer_count_1")
24                .unwrap_or(&Value::Bool(false))
25                .as_bool()
26                .unwrap(),
27            prefer_count_0: _config
28                .get("prefer_count_0")
29                .unwrap_or(&Value::Bool(false))
30                .as_bool()
31                .unwrap(),
32        }
33        .erased())
34    }
35
36    fn name(&self) -> &'static str {
37        "convention.count_rows"
38    }
39
40    fn description(&self) -> &'static str {
41        "Use consistent syntax to express \"count number of rows\"."
42    }
43
44    fn long_description(&self) -> &'static str {
45        r#"
46**Anti-pattern**
47
48In this example, `count(1)` is used to count the number of rows in a table.
49
50```sql
51select
52    count(1)
53from table_a
54```
55
56**Best practice**
57
58Use count(*) unless specified otherwise by config prefer_count_1, or prefer_count_0 as preferred.
59
60```sql
61select
62    count(*)
63from table_a
64```
65"#
66    }
67
68    fn groups(&self) -> &'static [RuleGroups] {
69        &[RuleGroups::All, RuleGroups::Core, RuleGroups::Convention]
70    }
71
72    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
73        let Some(function_name) = context
74            .segment
75            .child(const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) })
76        else {
77            return Vec::new();
78        };
79
80        if function_name.raw().eq_ignore_ascii_case("COUNT") {
81            let f_content = FunctionalContext::new(context)
82                .segment()
83                .children(Some(|it: &ErasedSegment| it.is_type(SyntaxKind::Bracketed)))
84                .children(Some(|it: &ErasedSegment| {
85                    !it.is_meta()
86                        && !matches!(
87                            it.get_type(),
88                            SyntaxKind::StartBracket
89                                | SyntaxKind::EndBracket
90                                | SyntaxKind::Whitespace
91                                | SyntaxKind::Newline
92                        )
93                }));
94
95            if f_content.len() != 1 {
96                return Vec::new();
97            }
98
99            let preferred = if self.prefer_count_1 {
100                "1"
101            } else if self.prefer_count_0 {
102                "0"
103            } else {
104                "*"
105            };
106
107            if f_content[0].is_type(SyntaxKind::Star)
108                && (self.prefer_count_0 || self.prefer_count_1)
109            {
110                let new_segment =
111                    SegmentBuilder::token(context.tables.next_id(), preferred, SyntaxKind::Literal)
112                        .finish();
113                return vec![LintResult::new(
114                    context.segment.clone().into(),
115                    vec![LintFix::replace(
116                        f_content[0].clone(),
117                        vec![new_segment],
118                        None,
119                    )],
120                    None,
121                    None,
122                )];
123            }
124
125            if f_content[0].is_type(SyntaxKind::Expression) {
126                let expression_content = f_content[0]
127                    .segments()
128                    .iter()
129                    .filter(|it| !it.is_meta())
130                    .collect_vec();
131
132                let raw = expression_content[0].raw();
133                if expression_content.len() == 1
134                    && matches!(
135                        expression_content[0].get_type(),
136                        SyntaxKind::NumericLiteral | SyntaxKind::Literal
137                    )
138                    && (raw == "0" || raw == "1")
139                    && raw != preferred
140                {
141                    let first_expression = expression_content[0].clone();
142                    let first_expression_raw = first_expression.raw();
143
144                    return vec![LintResult::new(
145                        context.segment.clone().into(),
146                        vec![LintFix::replace(
147                            first_expression.clone(),
148                            vec![
149                                first_expression.edit(
150                                    context.tables.next_id(),
151                                    first_expression
152                                        .raw()
153                                        .replace(first_expression_raw.as_str(), preferred)
154                                        .into(),
155                                    None,
156                                ),
157                            ],
158                            None,
159                        )],
160                        None,
161                        None,
162                    )];
163                }
164            }
165        }
166
167        Vec::new()
168    }
169
170    fn crawl_behaviour(&self) -> Crawler {
171        SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::Function]) }).into()
172    }
173}