sqruff_lib/core/rules/
crawlers.rs1use enum_dispatch::enum_dispatch;
2use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
3use sqruff_lib_core::parser::segments::base::ErasedSegment;
4
5use crate::core::rules::context::RuleContext;
6
7#[enum_dispatch]
8pub trait BaseCrawler {
9 fn works_on_unparsable(&self) -> bool {
10 false
11 }
12
13 fn passes_filter(&self, segment: &ErasedSegment) -> bool {
14 self.works_on_unparsable() || !segment.is_type(SyntaxKind::Unparsable)
15 }
16
17 fn crawl<'a>(&self, context: &mut RuleContext<'a>, f: &mut impl FnMut(&RuleContext<'a>));
18}
19
20#[enum_dispatch(BaseCrawler)]
21pub enum Crawler {
22 RootOnlyCrawler,
23 SegmentSeekerCrawler,
24 TokenSeekerCrawler,
25}
26
27#[derive(Debug, Default, Clone)]
32pub struct RootOnlyCrawler;
33
34impl BaseCrawler for RootOnlyCrawler {
35 fn crawl<'a>(&self, context: &mut RuleContext<'a>, f: &mut impl FnMut(&RuleContext<'a>)) {
36 if self.passes_filter(&context.segment) {
37 f(context);
38 }
39 }
40}
41
42pub struct SegmentSeekerCrawler {
43 types: SyntaxSet,
44 provide_raw_stack: bool,
45 allow_recurse: bool,
46}
47
48impl SegmentSeekerCrawler {
49 pub fn new(types: SyntaxSet) -> Self {
50 Self {
51 types,
52 provide_raw_stack: false,
53 allow_recurse: true,
54 }
55 }
56
57 pub fn disallow_recurse(mut self) -> Self {
58 self.allow_recurse = false;
59 self
60 }
61
62 pub fn provide_raw_stack(mut self) -> Self {
63 self.provide_raw_stack = true;
64 self
65 }
66
67 fn is_self_match(&self, segment: &ErasedSegment) -> bool {
68 self.types.contains(segment.get_type())
69 }
70}
71
72impl BaseCrawler for SegmentSeekerCrawler {
73 fn crawl<'a>(&self, context: &mut RuleContext<'a>, f: &mut impl FnMut(&RuleContext<'a>)) {
74 let mut self_match = false;
75
76 if self.is_self_match(&context.segment) {
77 self_match = true;
78 f(context);
79 }
80
81 if context.segment.segments().is_empty() || (self_match && !self.allow_recurse) {
82 return;
83 }
84
85 if !self.types.intersects(context.segment.descendant_type_set()) {
86 if self.provide_raw_stack {
87 let raw_segments = context.segment.get_raw_segments();
88 context.raw_stack.extend(raw_segments);
89 }
90
91 return;
92 }
93
94 let segment = context.segment.clone();
95 context.parent_stack.push(segment.clone());
96 for (idx, child) in segment.segments().iter().enumerate() {
97 context.segment = child.clone();
98 context.segment_idx = idx;
99 let checkpoint = context.checkpoint();
100 self.crawl(context, f);
101 context.restore(checkpoint);
102 }
103 }
104}
105
106pub struct TokenSeekerCrawler;
107
108impl BaseCrawler for TokenSeekerCrawler {
109 fn crawl<'a>(&self, context: &mut RuleContext<'a>, f: &mut impl FnMut(&RuleContext<'a>)) {
110 if context.segment.segments().is_empty() {
111 f(context);
112 }
113
114 let segment = context.segment.clone();
115 context.parent_stack.push(segment.clone());
116 for (idx, child) in segment.segments().iter().enumerate() {
117 context.segment = child.clone();
118 context.segment_idx = idx;
119
120 let checkpoint = context.checkpoint();
121 self.crawl(context, f);
122 context.restore(checkpoint);
123 }
124 }
125}