sqruff_lib/rules/structure/
st02.rs

1use ahash::{AHashMap, AHashSet};
2use itertools::{Itertools, chain};
3use smol_str::StrExt;
4use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
5use sqruff_lib_core::lint_fix::LintFix;
6use sqruff_lib_core::parser::segments::base::{ErasedSegment, SegmentBuilder};
7use sqruff_lib_core::utils::functional::segments::Segments;
8
9use crate::core::config::Value;
10use crate::core::rules::base::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
11use crate::core::rules::context::RuleContext;
12use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
13use crate::utils::functional::context::FunctionalContext;
14
/// Lint rule `structure.simple_case` (ST02).
///
/// Stateless marker type; all behaviour lives in its [`Rule`] implementation.
#[derive(Clone, Debug, Default)]
pub struct RuleST02;
17
18impl Rule for RuleST02 {
19    fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
20        Ok(RuleST02.erased())
21    }
22
23    fn name(&self) -> &'static str {
24        "structure.simple_case"
25    }
26
27    fn description(&self) -> &'static str {
28        "Unnecessary 'CASE' statement."
29    }
30
31    fn long_description(&self) -> &'static str {
32        r#"
33**Anti-pattern**
34
35CASE statement returns booleans.
36
37```sql
38select
39    case
40        when fab > 0 then true
41        else false
42    end as is_fab
43from fancy_table
44
45-- This rule can also simplify CASE statements
46-- that aim to fill NULL values.
47
48select
49    case
50        when fab is null then 0
51        else fab
52    end as fab_clean
53from fancy_table
54
55-- This also covers where the case statement
56-- replaces NULL values with NULL values.
57
58select
59    case
60        when fab is null then null
61        else fab
62    end as fab_clean
63from fancy_table
64```
65
66**Best practice**
67
68Reduce to WHEN condition within COALESCE function.
69
70```sql
71select
72    coalesce(fab > 0, false) as is_fab
73from fancy_table
74
75-- To fill NULL values.
76
77select
78    coalesce(fab, 0) as fab_clean
79from fancy_table
80
81-- NULL filling NULL.
82
83select fab as fab_clean
84from fancy_table
85```
86"#
87    }
88
89    fn groups(&self) -> &'static [RuleGroups] {
90        &[RuleGroups::All, RuleGroups::Structure]
91    }
92
93    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
94        if context.segment.segments()[0]
95            .raw()
96            .eq_ignore_ascii_case("CASE")
97        {
98            let children = FunctionalContext::new(context).segment().children(None);
99
100            let when_clauses = children.select(
101                Some(|it: &ErasedSegment| it.is_type(SyntaxKind::WhenClause)),
102                None,
103                None,
104                None,
105            );
106            let else_clauses = children.select(
107                Some(|it: &ErasedSegment| it.is_type(SyntaxKind::ElseClause)),
108                None,
109                None,
110                None,
111            );
112
113            if when_clauses.len() > 1 {
114                return Vec::new();
115            }
116
117            let condition_expression =
118                when_clauses.children(Some(|it| it.is_type(SyntaxKind::Expression)))[0].clone();
119            let then_expression =
120                when_clauses.children(Some(|it| it.is_type(SyntaxKind::Expression)))[1].clone();
121
122            if !else_clauses.is_empty() {
123                if let Some(else_expression) = else_clauses
124                    .children(Some(|it| it.is_type(SyntaxKind::Expression)))
125                    .first()
126                {
127                    let upper_bools = ["TRUE", "FALSE"];
128
129                    let then_expression_upper = then_expression.raw().to_uppercase_smolstr();
130                    let else_expression_upper = else_expression.raw().to_uppercase_smolstr();
131
132                    if upper_bools.contains(&then_expression_upper.as_str())
133                        && upper_bools.contains(&else_expression_upper.as_str())
134                        && then_expression_upper != else_expression_upper
135                    {
136                        let coalesce_arg_1 = condition_expression.clone();
137                        let coalesce_arg_2 =
138                            SegmentBuilder::keyword(context.tables.next_id(), "false");
139                        let preceding_not = then_expression_upper == "FALSE";
140
141                        let fixes = Self::coalesce_fix_list(
142                            context,
143                            coalesce_arg_1,
144                            coalesce_arg_2,
145                            preceding_not,
146                        );
147
148                        return vec![LintResult::new(
149                            condition_expression.into(),
150                            fixes,
151                            "Unnecessary CASE statement. Use COALESCE function instead."
152                                .to_owned()
153                                .into(),
154                            None,
155                        )];
156                    }
157                }
158            }
159
160            let condition_expression_segments_raw: AHashSet<_> = AHashSet::from_iter(
161                condition_expression
162                    .segments()
163                    .iter()
164                    .map(|segment| segment.raw().to_uppercase_smolstr()),
165            );
166
167            if condition_expression_segments_raw.contains("IS")
168                && condition_expression_segments_raw.contains("NULL")
169                && condition_expression_segments_raw
170                    .intersection(&AHashSet::from_iter(["AND".into(), "OR".into()]))
171                    .next()
172                    .is_none()
173            {
174                let is_not_prefix = condition_expression_segments_raw.contains("NOT");
175
176                let tmp = Segments::new(condition_expression.clone(), None)
177                    .children(Some(|it| it.is_type(SyntaxKind::ColumnReference)));
178
179                let Some(column_reference_segment) = tmp.first() else {
180                    return Vec::new();
181                };
182
183                let array_accessor_segment = Segments::new(condition_expression.clone(), None)
184                    .children(Some(|it: &ErasedSegment| {
185                        it.is_type(SyntaxKind::ArrayAccessor)
186                    }))
187                    .first()
188                    .cloned();
189
190                let column_reference_segment_raw_upper = match array_accessor_segment {
191                    Some(array_accessor_segment) => {
192                        column_reference_segment.raw().to_lowercase()
193                            + &array_accessor_segment.raw().to_uppercase()
194                    }
195                    None => column_reference_segment.raw().to_uppercase(),
196                };
197
198                if !else_clauses.is_empty() {
199                    let else_expression = else_clauses
200                        .children(Some(|it| it.is_type(SyntaxKind::Expression)))[0]
201                        .clone();
202
203                    let (coalesce_arg_1, coalesce_arg_2) = if !is_not_prefix
204                        && column_reference_segment_raw_upper
205                            == else_expression.raw().to_uppercase_smolstr()
206                    {
207                        (else_expression, then_expression)
208                    } else if is_not_prefix
209                        && column_reference_segment_raw_upper
210                            == then_expression.raw().to_uppercase_smolstr()
211                    {
212                        (then_expression, else_expression)
213                    } else {
214                        return Vec::new();
215                    };
216
217                    if coalesce_arg_2.raw().eq_ignore_ascii_case("NULL") {
218                        let fixes =
219                            Self::column_only_fix_list(context, column_reference_segment.clone());
220                        return vec![LintResult::new(
221                            condition_expression.into(),
222                            fixes,
223                            Some(String::new()),
224                            None,
225                        )];
226                    }
227
228                    let fixes =
229                        Self::coalesce_fix_list(context, coalesce_arg_1, coalesce_arg_2, false);
230
231                    return vec![LintResult::new(
232                        condition_expression.into(),
233                        fixes,
234                        "Unnecessary CASE statement. Use COALESCE function instead."
235                            .to_owned()
236                            .into(),
237                        None,
238                    )];
239                } else if column_reference_segment
240                    .raw()
241                    .eq_ignore_ascii_case(then_expression.raw())
242                {
243                    let fixes =
244                        Self::column_only_fix_list(context, column_reference_segment.clone());
245
246                    return vec![LintResult::new(
247                        condition_expression.into(),
248                        fixes,
249                        format!(
250                            "Unnecessary CASE statement. Just use column '{}'.",
251                            column_reference_segment.raw()
252                        )
253                        .into(),
254                        None,
255                    )];
256                }
257            }
258
259            Vec::new()
260        } else {
261            Vec::new()
262        }
263    }
264
265    fn crawl_behaviour(&self) -> Crawler {
266        SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::CaseExpression]) }).into()
267    }
268}
269
270impl RuleST02 {
271    fn coalesce_fix_list(
272        context: &RuleContext,
273        coalesce_arg_1: ErasedSegment,
274        coalesce_arg_2: ErasedSegment,
275        preceding_not: bool,
276    ) -> Vec<LintFix> {
277        let mut edits = vec![
278            SegmentBuilder::token(
279                context.tables.next_id(),
280                "coalesce",
281                SyntaxKind::FunctionNameIdentifier,
282            )
283            .finish(),
284            SegmentBuilder::symbol(context.tables.next_id(), "("),
285            coalesce_arg_1,
286            SegmentBuilder::symbol(context.tables.next_id(), ","),
287            SegmentBuilder::whitespace(context.tables.next_id(), " "),
288            coalesce_arg_2,
289            SegmentBuilder::symbol(context.tables.next_id(), ")"),
290        ];
291
292        if preceding_not {
293            edits = chain(
294                [
295                    SegmentBuilder::keyword(context.tables.next_id(), "not"),
296                    SegmentBuilder::whitespace(context.tables.next_id(), " "),
297                ],
298                edits,
299            )
300            .collect_vec();
301        }
302
303        vec![LintFix::replace(context.segment.clone(), edits, None)]
304    }
305
306    fn column_only_fix_list(
307        context: &RuleContext,
308        column_reference_segment: ErasedSegment,
309    ) -> Vec<LintFix> {
310        vec![LintFix::replace(
311            context.segment.clone(),
312            vec![column_reference_segment],
313            None,
314        )]
315    }
316}