sqruff_lib/rules/convention/cv05.rs

1use std::borrow::Cow;
2
3use ahash::AHashMap;
4use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
5use sqruff_lib_core::parser::segments::base::{ErasedSegment, SegmentBuilder};
6use sqruff_lib_core::utils::functional::segments::Segments;
7
8use crate::core::config::Value;
9use crate::core::rules::base::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
10use crate::core::rules::context::RuleContext;
11use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
12use crate::utils::reflow::sequence::{Filter, ReflowSequence, TargetSide};
13
/// One element of the replacement that is substituted for an `=` / `!=` / `<>`
/// operator compared against `NULL`.
///
/// Items are turned into real segments by the rule's `eval` step.
#[derive(Debug)]
enum CorrectionListItem {
    /// A whitespace separator between two keywords.
    WhitespaceSegment,
    /// A keyword carrying its exact raw text (for example `IS` or `not`).
    KeywordSegment(String),
}

/// Ordered replacement for the offending operator, in emission order.
type CorrectionList = Vec<CorrectionListItem>;
21
/// Lint rule CV05 (`convention.is_null`): flags `=`, `!=` and `<>` comparisons
/// whose right-hand operand is a `NULL` literal and proposes `IS` / `IS NOT`.
///
/// The rule has no configurable state, so it is a unit struct.
#[derive(Clone, Debug, Default)]
pub struct RuleCV05;
24
25fn create_base_is_null_sequence(is_upper: bool, operator_raw: Cow<str>) -> CorrectionList {
26    let is_seg = CorrectionListItem::KeywordSegment(if is_upper { "IS" } else { "is" }.to_string());
27    let not_seg =
28        CorrectionListItem::KeywordSegment(if is_upper { "NOT" } else { "not" }.to_string());
29
30    if operator_raw == "=" {
31        vec![is_seg]
32    } else {
33        vec![is_seg, CorrectionListItem::WhitespaceSegment, not_seg]
34    }
35}
36
37impl Rule for RuleCV05 {
38    fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
39        Ok(RuleCV05.erased())
40    }
41
42    fn name(&self) -> &'static str {
43        "convention.is_null"
44    }
45
46    fn description(&self) -> &'static str {
47        "Relational operators should not be used to check for NULL values."
48    }
49
50    fn long_description(&self) -> &'static str {
51        r#"
52**Anti-pattern**
53
54In this example, the `=` operator is used to check for `NULL` values.
55
56```sql
57SELECT
58    a
59FROM foo
60WHERE a = NULL
61```
62
63**Best practice**
64
65Use `IS` or `IS NOT` to check for `NULL` values.
66
67```sql
68SELECT
69    a
70FROM foo
71WHERE a IS NULL
72```
73"#
74    }
75
76    fn groups(&self) -> &'static [RuleGroups] {
77        &[RuleGroups::All, RuleGroups::Core, RuleGroups::Convention]
78    }
79
80    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
81        if context.parent_stack.len() >= 2 {
82            for type_str in [
83                SyntaxKind::SetClauseList,
84                SyntaxKind::ExecuteScriptStatement,
85                SyntaxKind::OptionsSegment,
86            ] {
87                if context.parent_stack[context.parent_stack.len() - 2].is_type(type_str) {
88                    return Vec::new();
89                }
90            }
91        }
92
93        if !context.parent_stack.is_empty() {
94            for type_str in [
95                SyntaxKind::SetClauseList,
96                SyntaxKind::ExecuteScriptStatement,
97                SyntaxKind::AssignmentOperator,
98            ] {
99                if context.parent_stack[context.parent_stack.len() - 1].is_type(type_str) {
100                    return Vec::new();
101                }
102            }
103        }
104
105        if !context.parent_stack.is_empty()
106            && context.parent_stack[context.parent_stack.len() - 1]
107                .is_type(SyntaxKind::ExclusionConstraintElement)
108        {
109            return Vec::new();
110        }
111
112        let raw_consist = context.segment.raw();
113        if !["=", "!=", "<>"].contains(&raw_consist.as_str()) {
114            return Vec::new();
115        }
116
117        let segment = context.parent_stack.last().unwrap().segments().to_vec();
118
119        let siblings = Segments::from_vec(segment, None);
120        let after_op_list =
121            siblings.select::<fn(&ErasedSegment) -> bool>(None, None, Some(&context.segment), None);
122
123        let next_code = after_op_list.find_first(Some(|sp: &ErasedSegment| sp.is_code()));
124
125        if !next_code.all(Some(|it| it.is_type(SyntaxKind::NullLiteral))) {
126            return Vec::new();
127        }
128
129        let sub_seg = next_code.get(0, None);
130        let edit = create_base_is_null_sequence(
131            sub_seg.as_ref().unwrap().raw().starts_with('N'),
132            context.segment.raw().as_str().into(),
133        );
134
135        let mut seg = Vec::with_capacity(edit.len());
136
137        for item in edit {
138            match item {
139                CorrectionListItem::KeywordSegment(keyword) => {
140                    seg.push(SegmentBuilder::keyword(context.tables.next_id(), &keyword));
141                }
142                CorrectionListItem::WhitespaceSegment => {
143                    seg.push(SegmentBuilder::whitespace(context.tables.next_id(), " "));
144                }
145            };
146        }
147
148        let fixes = ReflowSequence::from_around_target(
149            &context.segment,
150            context.parent_stack[0].clone(),
151            TargetSide::Both,
152            context.config,
153        )
154        .replace(context.segment.clone(), &seg)
155        .respace(context.tables, false, Filter::All)
156        .fixes();
157
158        vec![LintResult::new(
159            Some(context.segment.clone()),
160            fixes,
161            None,
162            None,
163        )]
164    }
165
166    fn crawl_behaviour(&self) -> Crawler {
167        SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::ComparisonOperator]) })
168            .into()
169    }
170}