sqruff_lib/core/rules/
base.rs

1use std::fmt::{self, Debug};
2use std::ops::Deref;
3
4use std::sync::Arc;
5
6use ahash::{AHashMap, AHashSet};
7use itertools::chain;
8use sqruff_lib_core::dialects::base::Dialect;
9use sqruff_lib_core::dialects::init::DialectKind;
10use sqruff_lib_core::errors::{ErrorStructRule, SQLLintError};
11use sqruff_lib_core::helpers::{Config, IndexMap};
12use sqruff_lib_core::lint_fix::LintFix;
13use sqruff_lib_core::parser::segments::base::{ErasedSegment, Tables};
14use sqruff_lib_core::templaters::base::TemplatedFile;
15use strum_macros::AsRefStr;
16
17use super::context::RuleContext;
18use super::crawlers::{BaseCrawler, Crawler};
19use crate::core::config::{FluffConfig, Value};
20
21pub struct LintResult {
22    pub anchor: Option<ErasedSegment>,
23    pub fixes: Vec<LintFix>,
24    description: Option<String>,
25    source: String,
26}
27
28#[derive(Debug, Clone, PartialEq, Copy, Hash, Eq, AsRefStr)]
29#[strum(serialize_all = "lowercase")]
30pub enum RuleGroups {
31    All,
32    Core,
33    Aliasing,
34    Ambiguous,
35    Capitalisation,
36    Convention,
37    Layout,
38    References,
39    Structure,
40}
41
42impl LintResult {
43    pub fn new(
44        anchor: Option<ErasedSegment>,
45        fixes: Vec<LintFix>,
46        description: Option<String>,
47        source: Option<String>,
48    ) -> Self {
49        // let fixes = fixes.into_iter().filter(|f| !f.is_trivial()).collect();
50
51        LintResult {
52            anchor,
53            fixes,
54            description,
55            source: source.unwrap_or_default(),
56        }
57    }
58
59    pub fn to_linting_error(&self, rule: ErasedRule, fixes: Vec<LintFix>) -> Option<SQLLintError> {
60        let anchor = self.anchor.clone()?;
61
62        let description = self
63            .description
64            .clone()
65            .unwrap_or_else(|| rule.description().to_string());
66
67        let is_fixable = rule.is_fix_compatible();
68
69        SQLLintError::new(description.as_str(), anchor, is_fixable, fixes)
70            .config(|this| {
71                this.rule = Some(ErrorStructRule {
72                    name: rule.name(),
73                    code: rule.code(),
74                })
75            })
76            .into()
77    }
78}
79
80impl Debug for LintResult {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        match &self.anchor {
83            None => write!(f, "LintResult(<empty>)"),
84            Some(anchor) => {
85                let fix_coda = if !self.fixes.is_empty() {
86                    format!("+{}F", self.fixes.len())
87                } else {
88                    "".to_string()
89                };
90
91                match &self.description {
92                    Some(desc) => {
93                        if !self.source.is_empty() {
94                            write!(
95                                f,
96                                "LintResult({} [{}]: {:?}{})",
97                                desc, self.source, anchor, fix_coda
98                            )
99                        } else {
100                            write!(f, "LintResult({}: {:?}{})", desc, anchor, fix_coda)
101                        }
102                    }
103                    None => write!(f, "LintResult({:?}{})", anchor, fix_coda),
104                }
105            }
106        }
107    }
108}
109
110pub trait CloneRule {
111    fn erased(&self) -> ErasedRule;
112}
113
114impl<T: Rule> CloneRule for T {
115    fn erased(&self) -> ErasedRule {
116        dyn_clone::clone(self).erased()
117    }
118}
119
/// The linting phase a rule runs in (see `Rule::lint_phase`).
#[derive(Clone, Debug, PartialEq)]
pub enum LintPhase {
    /// Default phase for rules.
    Main,
    /// Phase selected by rules that override `Rule::lint_phase`.
    Post,
}
125
126pub trait Rule: CloneRule + dyn_clone::DynClone + Debug + 'static + Send + Sync {
127    fn load_from_config(&self, _config: &AHashMap<String, Value>) -> Result<ErasedRule, String>;
128
129    fn lint_phase(&self) -> LintPhase {
130        LintPhase::Main
131    }
132
133    fn name(&self) -> &'static str;
134
135    fn config_ref(&self) -> &'static str {
136        self.name()
137    }
138
139    fn description(&self) -> &'static str;
140
141    fn long_description(&self) -> &'static str;
142
143    /// All the groups this rule belongs to, including 'all' because that is a
144    /// given. There should be no duplicates and 'all' should be the first
145    /// element.
146    fn groups(&self) -> &'static [RuleGroups];
147
148    fn force_enable(&self) -> bool {
149        false
150    }
151
152    /// Returns the set of dialects for which a particular rule should be
153    /// skipped.
154    fn dialect_skip(&self) -> &'static [DialectKind] {
155        &[]
156    }
157
158    fn code(&self) -> &'static str {
159        let name = std::any::type_name::<Self>();
160        name.split("::")
161            .last()
162            .unwrap()
163            .strip_prefix("Rule")
164            .unwrap_or(name)
165    }
166
167    fn eval(&self, rule_cx: &RuleContext) -> Vec<LintResult>;
168
169    fn is_fix_compatible(&self) -> bool {
170        false
171    }
172
173    fn crawl_behaviour(&self) -> Crawler;
174
175    fn crawl(
176        &self,
177        tables: &Tables,
178        dialect: &Dialect,
179        templated_file: &TemplatedFile,
180        tree: ErasedSegment,
181        config: &FluffConfig,
182    ) -> Vec<SQLLintError> {
183        let mut root_context = RuleContext::new(tables, dialect, config, tree.clone());
184        let mut vs = Vec::new();
185
186        // TODO Will to return a note that rules were skipped
187        if self.dialect_skip().contains(&dialect.name) && !self.force_enable() {
188            return Vec::new();
189        }
190
191        self.crawl_behaviour().crawl(&mut root_context, &mut |context| {
192            let resp =
193                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| self.eval(context)));
194
195            let resp = match resp {
196                Ok(t) => t,
197                Err(_) => {
198                    vs.push(SQLLintError::new("Unexpected exception. Could you open an issue at https://github.com/quarylabs/sqruff", tree.clone(), false, vec![]));
199                    Vec::new()
200                }
201            };
202
203            let mut new_lerrs = Vec::new();
204
205            if resp.is_empty() {
206                // Assume this means no problems (also means no memory)
207            } else {
208                for elem in resp {
209                    self.process_lint_result(elem, templated_file, &mut new_lerrs);
210                }
211            }
212
213            // Consume the new results
214            vs.extend(new_lerrs);
215        });
216
217        vs
218    }
219
220    fn process_lint_result(
221        &self,
222        res: LintResult,
223        templated_file: &TemplatedFile,
224        new_lerrs: &mut Vec<SQLLintError>,
225    ) {
226        if res
227            .fixes
228            .iter()
229            .any(|it| it.has_template_conflicts(templated_file))
230        {
231            return;
232        }
233
234        if let Some(lerr) = res.to_linting_error(self.erased(), res.fixes.clone()) {
235            new_lerrs.push(lerr);
236        }
237    }
238}
239
240dyn_clone::clone_trait_object!(Rule);
241
242#[derive(Debug, Clone)]
243pub struct ErasedRule {
244    erased: Arc<dyn Rule>,
245}
246
247impl PartialEq for ErasedRule {
248    fn eq(&self, _other: &Self) -> bool {
249        unimplemented!()
250    }
251}
252
253impl Deref for ErasedRule {
254    type Target = dyn Rule;
255
256    fn deref(&self) -> &Self::Target {
257        self.erased.as_ref()
258    }
259}
260
/// Conversion of a concrete value into its type-erased counterpart.
pub trait Erased {
    /// The type-erased representation produced by [`Erased::erased`].
    type Erased;

    /// Consumes `self` and returns its erased form.
    fn erased(self) -> <Self as Erased>::Erased;
}
266
267impl<T: Rule> Erased for T {
268    type Erased = ErasedRule;
269
270    fn erased(self) -> Self::Erased {
271        ErasedRule {
272            erased: Arc::new(self),
273        }
274    }
275}
276
277pub struct RuleManifest {
278    pub code: &'static str,
279    pub name: &'static str,
280    pub description: &'static str,
281    pub groups: &'static [RuleGroups],
282    pub rule_class: ErasedRule,
283}
284
285#[derive(Clone)]
286pub struct RulePack {
287    pub(crate) rules: Vec<ErasedRule>,
288    _reference_map: AHashMap<&'static str, AHashSet<&'static str>>,
289}
290
291impl RulePack {
292    pub fn rules(&self) -> Vec<ErasedRule> {
293        self.rules.clone()
294    }
295}
296
297pub struct RuleSet {
298    pub(crate) register: IndexMap<&'static str, RuleManifest>,
299}
300
301impl RuleSet {
302    fn rule_reference_map(&self) -> AHashMap<&'static str, AHashSet<&'static str>> {
303        let valid_codes: AHashSet<_> = self.register.keys().copied().collect();
304
305        let reference_map: AHashMap<_, AHashSet<_>> = valid_codes
306            .iter()
307            .map(|&code| (code, AHashSet::from([code])))
308            .collect();
309
310        let name_map = {
311            let mut name_map = AHashMap::new();
312            for manifest in self.register.values() {
313                name_map
314                    .entry(manifest.name)
315                    .or_insert_with(AHashSet::new)
316                    .insert(manifest.code);
317            }
318            name_map
319        };
320
321        let name_collisions: AHashSet<_> = {
322            let name_keys: AHashSet<_> = name_map.keys().copied().collect();
323            name_keys.intersection(&valid_codes).copied().collect()
324        };
325
326        if !name_collisions.is_empty() {
327            tracing::warn!(
328                "The following defined rule names were found which collide with codes. Those \
329                 names will not be available for selection: {name_collisions:?}",
330            );
331        }
332
333        let reference_map: AHashMap<_, _> = chain(name_map, reference_map).collect();
334
335        let mut group_map: AHashMap<_, AHashSet<&'static str>> = AHashMap::new();
336        for manifest in self.register.values() {
337            for group in manifest.groups {
338                let group = group.as_ref();
339                if let Some(codes) = reference_map.get(group) {
340                    tracing::warn!(
341                        "Rule {} defines group '{}' which is already defined as a name or code of \
342                         {:?}. This group will not be available for use as a result of this \
343                         collision.",
344                        manifest.code,
345                        group,
346                        codes
347                    );
348                } else {
349                    group_map
350                        .entry(group)
351                        .or_insert_with(AHashSet::new)
352                        .insert(manifest.code);
353                }
354            }
355        }
356
357        chain(group_map, reference_map).collect()
358    }
359
360    fn expand_rule_refs(
361        &self,
362        glob_list: Vec<String>,
363        reference_map: &AHashMap<&'static str, AHashSet<&'static str>>,
364    ) -> AHashSet<&'static str> {
365        let mut expanded_rule_set = AHashSet::new();
366
367        for r in glob_list {
368            if reference_map.contains_key(r.as_str()) {
369                expanded_rule_set.extend(reference_map[r.as_str()].clone());
370            } else {
371                panic!("Rule {r} not found in rule reference map");
372            }
373        }
374
375        expanded_rule_set
376    }
377
378    pub(crate) fn get_rulepack(&self, config: &FluffConfig) -> RulePack {
379        let reference_map = self.rule_reference_map();
380        let rules = config.get_section("rules");
381        let keylist = self.register.keys();
382        let mut instantiated_rules = Vec::with_capacity(keylist.len());
383
384        let allowlist: Vec<String> = match config.get("rule_allowlist", "core").as_array() {
385            Some(array) => array
386                .iter()
387                .map(|it| it.as_string().unwrap().to_owned())
388                .collect(),
389            None => self.register.keys().map(|it| it.to_string()).collect(),
390        };
391
392        let denylist: Vec<String> = match config.get("rule_denylist", "core").as_array() {
393            Some(array) => array
394                .iter()
395                .map(|it| it.as_string().unwrap().to_owned())
396                .collect(),
397            None => Vec::new(),
398        };
399
400        let expanded_allowlist = self.expand_rule_refs(allowlist, &reference_map);
401        let expanded_denylist = self.expand_rule_refs(denylist, &reference_map);
402
403        let keylist: Vec<_> = keylist
404            .into_iter()
405            .filter(|&&r| expanded_allowlist.contains(r) && !expanded_denylist.contains(r))
406            .collect();
407
408        for code in keylist {
409            let rule = self.register[code].rule_class.clone();
410            let rule_config_ref = rule.config_ref();
411
412            let tmp = AHashMap::new();
413
414            let specific_rule_config = rules
415                .get(rule_config_ref)
416                .and_then(|section| section.as_map())
417                .unwrap_or(&tmp);
418
419            // TODO fail the rulepack if any need unwrapping
420            instantiated_rules.push(rule.load_from_config(specific_rule_config).unwrap());
421        }
422
423        RulePack {
424            rules: instantiated_rules,
425            _reference_map: reference_map,
426        }
427    }
428}