sqruff_lib/utils/reflow/
depth_map.rs

1use std::iter::zip;
2
3use ahash::{AHashMap, AHashSet};
4use nohash_hasher::{IntMap, IntSet};
5use sqruff_lib_core::dialects::syntax::SyntaxSet;
6use sqruff_lib_core::parser::segments::base::{ErasedSegment, PathStep};
7
8/// An element of the stack_positions property of DepthInfo.
9#[derive(Debug, PartialEq, Eq, Clone)]
10pub struct StackPosition {
11    pub idx: usize,
12    pub len: usize,
13    pub type_: Option<StackPositionType>,
14}
15
/// Classification of a path step relative to the code indices of its parent.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StackPositionType {
    /// The step's `code_idxs` holds exactly one entry.
    Solo,
    /// The step's `idx` equals the first entry of `code_idxs`.
    Start,
    /// The step's `idx` equals the last entry of `code_idxs`.
    End,
}
22
23impl StackPosition {
24    /// Interpret a path step for stack_positions.
25    fn stack_pos_interpreter(path_step: &PathStep) -> Option<StackPositionType> {
26        if path_step.code_idxs.is_empty() {
27            None
28        } else if path_step.code_idxs.len() == 1 {
29            Some(StackPositionType::Solo)
30        } else if path_step.idx == *path_step.code_idxs.first().unwrap() {
31            Some(StackPositionType::Start)
32        } else if path_step.idx == *path_step.code_idxs.last().unwrap() {
33            Some(StackPositionType::End)
34        } else {
35            None
36        }
37    }
38
39    /// Interpret a PathStep to construct a StackPosition
40    fn from_path_step(path_step: &PathStep) -> StackPosition {
41        StackPosition {
42            idx: path_step.idx,
43            len: path_step.len,
44            type_: StackPosition::stack_pos_interpreter(path_step),
45        }
46    }
47}
48
49pub struct DepthMap {
50    depth_info: AHashMap<u32, DepthInfo>,
51}
52
53impl DepthMap {
54    fn new<'a>(raws_with_stack: impl Iterator<Item = &'a (ErasedSegment, Vec<PathStep>)>) -> Self {
55        let depth_info = raws_with_stack
56            .into_iter()
57            .map(|(raw, stack)| (raw.id(), DepthInfo::from_stack(stack)))
58            .collect();
59        Self { depth_info }
60    }
61
62    pub fn get_depth_info(&self, seg: &ErasedSegment) -> DepthInfo {
63        self.depth_info[&seg.id()].clone()
64    }
65
66    pub fn copy_depth_info(
67        &mut self,
68        anchor: &ErasedSegment,
69        new_segment: &ErasedSegment,
70        trim: u32,
71    ) {
72        self.depth_info.insert(
73            new_segment.id(),
74            self.get_depth_info(anchor).trim(trim.try_into().unwrap()),
75        );
76    }
77
78    pub fn from_parent(parent: &ErasedSegment) -> Self {
79        Self::new(parent.raw_segments_with_ancestors().iter())
80    }
81
82    pub fn from_raws_and_root(
83        raw_segments: impl Iterator<Item = ErasedSegment>,
84        root_segment: &ErasedSegment,
85    ) -> DepthMap {
86        let depth_info = raw_segments
87            .into_iter()
88            .map(|raw| {
89                let stack = root_segment.path_to(&raw);
90                (raw.id(), DepthInfo::from_stack(&stack))
91            })
92            .collect();
93
94        DepthMap { depth_info }
95    }
96}
97
98/// An object to hold the depth information for a specific raw segment.
99#[derive(Debug, PartialEq, Eq, Clone)]
100pub struct DepthInfo {
101    pub stack_depth: usize,
102    pub stack_hashes: Vec<u64>,
103    /// This is a convenience cache to speed up operations.
104    pub stack_hash_set: IntSet<u64>,
105    pub stack_class_types: Vec<SyntaxSet>,
106    pub stack_positions: IntMap<u64, StackPosition>,
107}
108
109impl DepthInfo {
110    fn from_stack(stack: &[PathStep]) -> DepthInfo {
111        let stack_hashes: Vec<u64> = stack.iter().map(|ps| ps.segment.hash_value()).collect();
112        let stack_hash_set: IntSet<u64> = IntSet::from_iter(stack_hashes.clone());
113
114        let stack_class_types = stack
115            .iter()
116            .map(|ps| ps.segment.class_types().clone())
117            .collect();
118
119        let stack_positions: IntMap<u64, StackPosition> = zip(stack_hashes.iter(), stack.iter())
120            .map(|(&hash, path)| (hash, StackPosition::from_path_step(path)))
121            .collect();
122
123        DepthInfo {
124            stack_depth: stack_hashes.len(),
125            stack_hashes,
126            stack_hash_set,
127            stack_class_types,
128            stack_positions,
129        }
130    }
131
132    pub fn trim(self, amount: usize) -> DepthInfo {
133        // Return a DepthInfo object with some amount trimmed.
134        if amount == 0 {
135            // The trivial case.
136            return self;
137        }
138
139        let slice_set: IntSet<_> = IntSet::from_iter(
140            self.stack_hashes[self.stack_hashes.len() - amount..]
141                .iter()
142                .copied(),
143        );
144
145        let new_hash_set: IntSet<_> = self
146            .stack_hash_set
147            .difference(&slice_set)
148            .copied()
149            .collect();
150
151        let stack_positions = self
152            .stack_positions
153            .into_iter()
154            .filter(|(hash, _)| new_hash_set.contains(hash))
155            .collect();
156
157        DepthInfo {
158            stack_depth: self.stack_depth - amount,
159            stack_hashes: self.stack_hashes[..self.stack_hashes.len() - amount].to_vec(),
160            stack_hash_set: new_hash_set,
161            stack_class_types: self.stack_class_types[..self.stack_class_types.len() - amount]
162                .to_vec(),
163            stack_positions,
164        }
165    }
166
167    pub fn common_with(&self, other: &DepthInfo) -> Vec<u64> {
168        // Get the common depth and hashes with the other.
169        // We use AHashSet intersection because it's efficient and hashes should be
170        // unique.
171
172        let common_hashes: AHashSet<_> = self
173            .stack_hash_set
174            .intersection(&other.stack_hash_set)
175            .copied()
176            .collect();
177
178        // We should expect there to be _at least_ one common ancestor, because
179        // they should share the same file segment. If that's not the case we
180        // should error because it's likely a bug or programming error.
181        assert!(
182            !common_hashes.is_empty(),
183            "DepthInfo comparison shares no common ancestor!"
184        );
185
186        let common_depth = common_hashes.len();
187        self.stack_hashes
188            .iter()
189            .take(common_depth)
190            .copied()
191            .collect()
192    }
193}