tree_sitter_cli/
query_testing.rs1use std::{fs, path::Path, sync::LazyLock};
2
3use anyhow::{anyhow, Result};
4use bstr::{BStr, ByteSlice};
5use regex::Regex;
6use tree_sitter::{Language, Parser, Point};
7
8static CAPTURE_NAME_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new("[\\w_\\-.]+").unwrap());
9
/// A `(row, column)` position used by capture assertions.
///
/// Points are ordered by row first, then by column. Columns produced by [`to_utf8_point`]
/// count grapheme clusters rather than bytes.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Utf8Point {
    /// Row of the position.
    pub row: usize,
    /// Column of the position.
    pub column: usize,
}

impl std::fmt::Display for Utf8Point {
    /// Renders the point as `(row, column)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "({}, {})", self.row, self.column)
    }
}

impl Utf8Point {
    /// Creates a point at the given `row` and `column`.
    #[must_use]
    pub const fn new(row: usize, column: usize) -> Self {
        Utf8Point { row, column }
    }
}
28
29#[must_use]
30pub fn to_utf8_point(point: Point, source: &[u8]) -> Utf8Point {
31 if point.column == 0 {
32 return Utf8Point::new(point.row, 0);
33 }
34
35 let bstr = BStr::new(source);
36 let line = bstr.lines_with_terminator().nth(point.row).unwrap();
37 let mut utf8_column = 0;
38
39 for (_, grapheme_end, _) in line.grapheme_indices() {
40 utf8_column += 1;
41 if grapheme_end >= point.column {
42 break;
43 }
44 }
45
46 Utf8Point {
47 row: point.row,
48 column: utf8_column,
49 }
50}
51
52#[derive(Debug, Eq, PartialEq)]
53pub struct CaptureInfo {
54 pub name: String,
55 pub start: Utf8Point,
56 pub end: Utf8Point,
57}
58
59#[derive(Debug, PartialEq, Eq)]
60pub struct Assertion {
61 pub position: Utf8Point,
62 pub length: usize,
63 pub negative: bool,
64 pub expected_capture_name: String,
65}
66
67impl Assertion {
68 #[must_use]
69 pub const fn new(
70 row: usize,
71 col: usize,
72 length: usize,
73 negative: bool,
74 expected_capture_name: String,
75 ) -> Self {
76 Self {
77 position: Utf8Point::new(row, col),
78 length,
79 negative,
80 expected_capture_name,
81 }
82 }
83}
84
85pub fn parse_position_comments(
89 parser: &mut Parser,
90 language: &Language,
91 source: &[u8],
92) -> Result<Vec<Assertion>> {
93 let mut result = Vec::new();
94 let mut assertion_ranges = Vec::new();
95
96 parser.set_included_ranges(&[]).unwrap();
98 parser.set_language(language).unwrap();
99 let tree = parser.parse(source, None).unwrap();
100
101 let mut ascending = false;
103 let mut cursor = tree.root_node().walk();
104 loop {
105 if ascending {
106 let node = cursor.node();
107
108 if node.kind().to_lowercase().contains("comment") {
110 if let Ok(text) = node.utf8_text(source) {
111 let mut position = node.start_position();
112 if position.row > 0 {
113 let mut has_left_caret = false;
117 let mut has_arrow = false;
118 let mut negative = false;
119 let mut arrow_end = 0;
120 let mut arrow_count = 1;
121 for (i, c) in text.char_indices() {
122 arrow_end = i + 1;
123 if c == '-' && has_left_caret {
124 has_arrow = true;
125 break;
126 }
127 if c == '^' {
128 has_arrow = true;
129 position.column += i;
130 for (_, c) in text[arrow_end..].char_indices() {
132 if c != '^' {
133 arrow_end += arrow_count - 1;
134 break;
135 }
136 arrow_count += 1;
137 }
138 break;
139 }
140 has_left_caret = c == '<';
141 }
142
143 if has_arrow {
145 for (i, c) in text[arrow_end..].char_indices() {
146 if c == '!' {
147 negative = true;
148 arrow_end += i + 1;
149 break;
150 } else if !c.is_whitespace() {
151 break;
152 }
153 }
154 }
155
156 if let (true, Some(mat)) =
159 (has_arrow, CAPTURE_NAME_REGEX.find(&text[arrow_end..]))
160 {
161 assertion_ranges.push((node.start_position(), node.end_position()));
162 result.push(Assertion {
163 position: to_utf8_point(position, source),
164 length: arrow_count,
165 negative,
166 expected_capture_name: mat.as_str().to_string(),
167 });
168 }
169 }
170 }
171 }
172
173 if cursor.goto_next_sibling() {
175 ascending = false;
176 } else if !cursor.goto_parent() {
177 break;
178 }
179 } else if !cursor.goto_first_child() {
180 ascending = true;
181 }
182 }
183
184 let mut i = 0;
188 let lines = source.lines_with_terminator().collect::<Vec<_>>();
189 for assertion in &mut result {
190 let original_position = assertion.position;
191 loop {
192 let on_assertion_line = assertion_ranges[i..]
193 .iter()
194 .any(|(start, _)| start.row == assertion.position.row);
195 let on_empty_line = lines[assertion.position.row].len() <= assertion.position.column;
196 if on_assertion_line || on_empty_line {
197 if assertion.position.row > 0 {
198 assertion.position.row -= 1;
199 } else {
200 return Err(anyhow!(
201 "Error: could not find a line that corresponds to the assertion `{}` located at {original_position}",
202 assertion.expected_capture_name
203 ));
204 }
205 } else {
206 while i < assertion_ranges.len()
207 && assertion_ranges[i].0.row < assertion.position.row
208 {
209 i += 1;
210 }
211 break;
212 }
213 }
214 }
215
216 result.sort_unstable_by_key(|a| a.position);
218
219 Ok(result)
220}
221
222pub fn assert_expected_captures(
223 infos: &[CaptureInfo],
224 path: &Path,
225 parser: &mut Parser,
226 language: &Language,
227) -> Result<usize> {
228 let contents = fs::read_to_string(path)?;
229 let pairs = parse_position_comments(parser, language, contents.as_bytes())?;
230 for assertion in &pairs {
231 if let Some(found) = &infos.iter().find(|p| {
232 assertion.position >= p.start
233 && (assertion.position.row < p.end.row
234 || assertion.position.column + assertion.length - 1 < p.end.column)
235 }) {
236 if assertion.expected_capture_name != found.name && found.name != "name" {
237 return Err(anyhow!(
238 "Assertion failed: at {}, found {}, expected {}",
239 found.start,
240 assertion.expected_capture_name,
241 found.name
242 ));
243 }
244 } else {
245 return Err(anyhow!(
246 "Assertion failed: could not match {} at row {}, column {}",
247 assertion.expected_capture_name,
248 assertion.position.row,
249 assertion.position.column + assertion.length - 1,
250 ));
251 }
252 }
253 Ok(pairs.len())
254}