sqruff_lib/rules/structure/
st06.rs

1use std::iter::zip;
2
3use ahash::AHashMap;
4use itertools::{Itertools, enumerate};
5use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
6use sqruff_lib_core::lint_fix::LintFix;
7use sqruff_lib_core::parser::segments::base::ErasedSegment;
8
9use crate::core::config::Value;
10use crate::core::rules::base::{CloneRule, ErasedRule, LintResult, Rule, RuleGroups};
11use crate::core::rules::context::RuleContext;
12use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
13
/// Lint rule `structure.column_order` (ST06).
///
/// The rule carries no state or configuration; all of its behaviour lives in
/// its `Rule` implementation.
#[derive(Debug, Clone)]
pub struct RuleST06;
16
17impl Rule for RuleST06 {
18    fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
19        Ok(RuleST06.erased())
20    }
21
22    fn name(&self) -> &'static str {
23        "structure.column_order"
24    }
25
26    fn description(&self) -> &'static str {
27        "Select wildcards then simple targets before calculations and aggregates."
28    }
29
30    fn long_description(&self) -> &'static str {
31        r"
32**Anti-pattern**
33
34```sql
35select
36    a,
37    *,
38    row_number() over (partition by id order by date) as y,
39    b
40from x
41```
42
43**Best practice**
44
45Order `select` targets in ascending complexity
46
47```sql
48select
49    *,
50    a,
51    b,
52    row_number() over (partition by id order by date) as y
53from x
54```"
55    }
56
57    fn groups(&self) -> &'static [RuleGroups] {
58        &[RuleGroups::All, RuleGroups::Structure]
59    }
60
61    fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
62        let mut violation_exists = false;
63
64        static SELECT_ELEMENT_ORDER_PREFERENCE: &[&[Validate]] = &[
65            &[Validate::Types(
66                const { SyntaxSet::new(&[SyntaxKind::WildcardExpression]) },
67            )],
68            &[
69                Validate::Types(
70                    const { SyntaxSet::new(&[SyntaxKind::ObjectReference, SyntaxKind::ColumnReference]) },
71                ),
72                Validate::Types(const { SyntaxSet::new(&[SyntaxKind::Literal]) }),
73                Validate::Types(const { SyntaxSet::new(&[SyntaxKind::CastExpression]) }),
74                Validate::Function { name: "cast" },
75                Validate::Expression {
76                    child_typ: SyntaxKind::CastExpression,
77                },
78            ],
79        ];
80
81        if context.parent_stack.len() >= 2
82            && matches!(
83                context.parent_stack[context.parent_stack.len() - 2].get_type(),
84                SyntaxKind::InsertStatement | SyntaxKind::SetExpression
85            )
86        {
87            return Vec::new();
88        }
89
90        if context.parent_stack.len() >= 3
91            && matches!(
92                context.parent_stack[context.parent_stack.len() - 3].get_type(),
93                SyntaxKind::InsertStatement | SyntaxKind::SetExpression
94            )
95            && context.parent_stack[context.parent_stack.len() - 2].get_type()
96                == SyntaxKind::WithCompoundStatement
97        {
98            return Vec::new();
99        }
100
101        if context.parent_stack.len() >= 3
102            && matches!(
103                context.parent_stack[context.parent_stack.len() - 3].get_type(),
104                SyntaxKind::CreateTableStatement | SyntaxKind::MergeStatement
105            )
106        {
107            return Vec::new();
108        }
109
110        if context.parent_stack.len() >= 4
111            && matches!(
112                context.parent_stack[context.parent_stack.len() - 4].get_type(),
113                SyntaxKind::CreateTableStatement | SyntaxKind::MergeStatement
114            )
115            && context.parent_stack[context.parent_stack.len() - 2].get_type()
116                == SyntaxKind::WithCompoundStatement
117        {
118            return Vec::new();
119        }
120
121        let select_clause_segment = context.segment.clone();
122        let select_target_elements: Vec<_> = select_clause_segment
123            .children(const { &SyntaxSet::new(&[SyntaxKind::SelectClauseElement]) })
124            .collect();
125
126        if select_target_elements.is_empty() {
127            return Vec::new();
128        }
129
130        let mut seen_band_elements: Vec<Vec<ErasedSegment>> = SELECT_ELEMENT_ORDER_PREFERENCE
131            .iter()
132            .map(|_| Vec::new())
133            .collect();
134        seen_band_elements.push(Vec::new());
135
136        for &segment in &select_target_elements {
137            let mut current_element_band: Option<usize> = None;
138
139            for (i, band) in enumerate(SELECT_ELEMENT_ORDER_PREFERENCE) {
140                for e in *band {
141                    match e {
142                        Validate::Types(types) => {
143                            if segment.child(types).is_some() {
144                                validate(
145                                    i,
146                                    segment.clone(),
147                                    &mut current_element_band,
148                                    &mut violation_exists,
149                                    &mut seen_band_elements,
150                                );
151                            }
152                        }
153                        Validate::Function { name } => {
154                            (|| {
155                                let function = segment
156                                    .child(const { &SyntaxSet::new(&[SyntaxKind::Function]) })?;
157                                let function_name = function.child(
158                                    const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) },
159                                )?;
160                                if function_name.raw() == *name {
161                                    validate(
162                                        i,
163                                        segment.clone(),
164                                        &mut current_element_band,
165                                        &mut violation_exists,
166                                        &mut seen_band_elements,
167                                    );
168                                }
169
170                                Some(())
171                            })();
172                        }
173                        Validate::Expression { child_typ } => {
174                            (|| {
175                                let expression = segment
176                                    .child(const { &SyntaxSet::new(&[SyntaxKind::Expression]) })?;
177                                if expression.child(&SyntaxSet::new(&[*child_typ])).is_some()
178                                    && matches!(
179                                        expression.segments()[0].get_type(),
180                                        SyntaxKind::ColumnReference
181                                            | SyntaxKind::ObjectReference
182                                            | SyntaxKind::Literal
183                                            | SyntaxKind::CastExpression
184                                    )
185                                    && expression.segments().len() == 2
186                                    || expression.segments().len() == 1
187                                {
188                                    validate(
189                                        i,
190                                        segment.clone(),
191                                        &mut current_element_band,
192                                        &mut violation_exists,
193                                        &mut seen_band_elements,
194                                    );
195                                }
196
197                                Some(())
198                            })();
199                        }
200                    }
201                }
202            }
203
204            if current_element_band.is_none() {
205                seen_band_elements.last_mut().unwrap().push(segment.clone());
206            }
207        }
208
209        if violation_exists {
210            if context
211                .parent_stack
212                .last()
213                .is_some_and(implicit_column_references)
214            {
215                return vec![LintResult::new(
216                    select_clause_segment.into(),
217                    Vec::new(),
218                    None,
219                    None,
220                )];
221            }
222
223            let ordered_select_target_elements =
224                seen_band_elements.into_iter().flatten().collect_vec();
225
226            let fixes = zip(select_target_elements, ordered_select_target_elements)
227                .filter_map(
228                    |(initial_select_target_element, replace_select_target_element)| {
229                        (initial_select_target_element != &replace_select_target_element).then(
230                            || {
231                                LintFix::replace(
232                                    initial_select_target_element.clone(),
233                                    vec![replace_select_target_element],
234                                    None,
235                                )
236                            },
237                        )
238                    },
239                )
240                .collect_vec();
241
242            return vec![LintResult::new(
243                select_clause_segment.into(),
244                fixes,
245                None,
246                None,
247            )];
248        }
249
250        Vec::new()
251    }
252
253    fn is_fix_compatible(&self) -> bool {
254        true
255    }
256
257    fn crawl_behaviour(&self) -> Crawler {
258        SegmentSeekerCrawler::new(const { SyntaxSet::new(&[SyntaxKind::SelectClause]) }).into()
259    }
260}
261
262enum Validate {
263    Types(SyntaxSet),
264    Function { name: &'static str },
265    Expression { child_typ: SyntaxKind },
266}
267
268fn validate(
269    i: usize,
270    segment: ErasedSegment,
271    current_element_band: &mut Option<usize>,
272    violation_exists: &mut bool,
273    seen_band_elements: &mut [Vec<ErasedSegment>],
274) {
275    if seen_band_elements[i + 1..] != vec![Vec::new(); seen_band_elements[i + 1..].len()] {
276        *violation_exists = true;
277    }
278
279    *current_element_band = Some(1);
280    seen_band_elements[i].push(segment);
281}
282
283fn implicit_column_references(segment: &ErasedSegment) -> bool {
284    if !matches!(
285        segment.get_type(),
286        SyntaxKind::WithingroupClause | SyntaxKind::WindowSpecification
287    ) {
288        if matches!(
289            segment.get_type(),
290            SyntaxKind::GroupbyClause | SyntaxKind::OrderbyClause
291        ) {
292            for seg in segment.segments() {
293                if seg.is_type(SyntaxKind::NumericLiteral) {
294                    return true;
295                }
296            }
297        } else {
298            for seg in segment.segments() {
299                if implicit_column_references(seg) {
300                    return true;
301                }
302            }
303        }
304    }
305
306    false
307}