tree_sitter_cli/
query_testing.rs

1use std::{fs, path::Path, sync::LazyLock};
2
3use anyhow::{anyhow, Result};
4use bstr::{BStr, ByteSlice};
5use regex::Regex;
6use tree_sitter::{Language, Parser, Point};
7
8static CAPTURE_NAME_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new("[\\w_\\-.]+").unwrap());
9
/// A `(row, column)` location in source text.
///
/// Points produced by `to_utf8_point` carry a column counted in grapheme clusters
/// rather than bytes. The derived ordering compares `row` first and then `column`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Utf8Point {
    /// Index of the line.
    pub row: usize,
    /// Index of the column within that line.
    pub column: usize,
}
15
16impl std::fmt::Display for Utf8Point {
17    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
18        write!(f, "({}, {})", self.row, self.column)
19    }
20}
21
22impl Utf8Point {
23    #[must_use]
24    pub const fn new(row: usize, column: usize) -> Self {
25        Self { row, column }
26    }
27}
28
29#[must_use]
30pub fn to_utf8_point(point: Point, source: &[u8]) -> Utf8Point {
31    if point.column == 0 {
32        return Utf8Point::new(point.row, 0);
33    }
34
35    let bstr = BStr::new(source);
36    let line = bstr.lines_with_terminator().nth(point.row).unwrap();
37    let mut utf8_column = 0;
38
39    for (_, grapheme_end, _) in line.grapheme_indices() {
40        utf8_column += 1;
41        if grapheme_end >= point.column {
42            break;
43        }
44    }
45
46    Utf8Point {
47        row: point.row,
48        column: utf8_column,
49    }
50}
51
52#[derive(Debug, Eq, PartialEq)]
53pub struct CaptureInfo {
54    pub name: String,
55    pub start: Utf8Point,
56    pub end: Utf8Point,
57}
58
59#[derive(Debug, PartialEq, Eq)]
60pub struct Assertion {
61    pub position: Utf8Point,
62    pub length: usize,
63    pub negative: bool,
64    pub expected_capture_name: String,
65}
66
67impl Assertion {
68    #[must_use]
69    pub const fn new(
70        row: usize,
71        col: usize,
72        length: usize,
73        negative: bool,
74        expected_capture_name: String,
75    ) -> Self {
76        Self {
77            position: Utf8Point::new(row, col),
78            length,
79            negative,
80            expected_capture_name,
81        }
82    }
83}
84
85/// Parse the given source code, finding all of the comments that contain
86/// highlighting assertions. Return a vector of (position, expected highlight name)
87/// pairs.
88pub fn parse_position_comments(
89    parser: &mut Parser,
90    language: &Language,
91    source: &[u8],
92) -> Result<Vec<Assertion>> {
93    let mut result = Vec::new();
94    let mut assertion_ranges = Vec::new();
95
96    // Parse the code.
97    parser.set_included_ranges(&[]).unwrap();
98    parser.set_language(language).unwrap();
99    let tree = parser.parse(source, None).unwrap();
100
101    // Walk the tree, finding comment nodes that contain assertions.
102    let mut ascending = false;
103    let mut cursor = tree.root_node().walk();
104    loop {
105        if ascending {
106            let node = cursor.node();
107
108            // Find every comment node.
109            if node.kind().to_lowercase().contains("comment") {
110                if let Ok(text) = node.utf8_text(source) {
111                    let mut position = node.start_position();
112                    if position.row > 0 {
113                        // Find the arrow character ("^" or "<-") in the comment. A left arrow
114                        // refers to the column where the comment node starts. An up arrow refers
115                        // to its own column.
116                        let mut has_left_caret = false;
117                        let mut has_arrow = false;
118                        let mut negative = false;
119                        let mut arrow_end = 0;
120                        let mut arrow_count = 1;
121                        for (i, c) in text.char_indices() {
122                            arrow_end = i + 1;
123                            if c == '-' && has_left_caret {
124                                has_arrow = true;
125                                break;
126                            }
127                            if c == '^' {
128                                has_arrow = true;
129                                position.column += i;
130                                // Continue counting remaining arrows and update their end column
131                                for (_, c) in text[arrow_end..].char_indices() {
132                                    if c != '^' {
133                                        arrow_end += arrow_count - 1;
134                                        break;
135                                    }
136                                    arrow_count += 1;
137                                }
138                                break;
139                            }
140                            has_left_caret = c == '<';
141                        }
142
143                        // find any ! after arrows but before capture name
144                        if has_arrow {
145                            for (i, c) in text[arrow_end..].char_indices() {
146                                if c == '!' {
147                                    negative = true;
148                                    arrow_end += i + 1;
149                                    break;
150                                } else if !c.is_whitespace() {
151                                    break;
152                                }
153                            }
154                        }
155
156                        // If the comment node contains an arrow and a highlight name, record the
157                        // highlight name and the position.
158                        if let (true, Some(mat)) =
159                            (has_arrow, CAPTURE_NAME_REGEX.find(&text[arrow_end..]))
160                        {
161                            assertion_ranges.push((node.start_position(), node.end_position()));
162                            result.push(Assertion {
163                                position: to_utf8_point(position, source),
164                                length: arrow_count,
165                                negative,
166                                expected_capture_name: mat.as_str().to_string(),
167                            });
168                        }
169                    }
170                }
171            }
172
173            // Continue walking the tree.
174            if cursor.goto_next_sibling() {
175                ascending = false;
176            } else if !cursor.goto_parent() {
177                break;
178            }
179        } else if !cursor.goto_first_child() {
180            ascending = true;
181        }
182    }
183
184    // Adjust the row number in each assertion's position to refer to the line of
185    // code *above* the assertion. There can be multiple lines of assertion comments and empty
186    // lines, so the positions may have to be decremented by more than one row.
187    let mut i = 0;
188    let lines = source.lines_with_terminator().collect::<Vec<_>>();
189    for assertion in &mut result {
190        let original_position = assertion.position;
191        loop {
192            let on_assertion_line = assertion_ranges[i..]
193                .iter()
194                .any(|(start, _)| start.row == assertion.position.row);
195            let on_empty_line = lines[assertion.position.row].len() <= assertion.position.column;
196            if on_assertion_line || on_empty_line {
197                if assertion.position.row > 0 {
198                    assertion.position.row -= 1;
199                } else {
200                    return Err(anyhow!(
201                        "Error: could not find a line that corresponds to the assertion `{}` located at {original_position}",
202                        assertion.expected_capture_name
203                    ));
204                }
205            } else {
206                while i < assertion_ranges.len()
207                    && assertion_ranges[i].0.row < assertion.position.row
208                {
209                    i += 1;
210                }
211                break;
212            }
213        }
214    }
215
216    // The assertions can end up out of order due to the line adjustments.
217    result.sort_unstable_by_key(|a| a.position);
218
219    Ok(result)
220}
221
222pub fn assert_expected_captures(
223    infos: &[CaptureInfo],
224    path: &Path,
225    parser: &mut Parser,
226    language: &Language,
227) -> Result<usize> {
228    let contents = fs::read_to_string(path)?;
229    let pairs = parse_position_comments(parser, language, contents.as_bytes())?;
230    for assertion in &pairs {
231        if let Some(found) = &infos.iter().find(|p| {
232            assertion.position >= p.start
233                && (assertion.position.row < p.end.row
234                    || assertion.position.column + assertion.length - 1 < p.end.column)
235        }) {
236            if assertion.expected_capture_name != found.name && found.name != "name" {
237                return Err(anyhow!(
238                    "Assertion failed: at {}, found {}, expected {}",
239                    found.start,
240                    assertion.expected_capture_name,
241                    found.name
242                ));
243            }
244        } else {
245            return Err(anyhow!(
246                "Assertion failed: could not match {} at row {}, column {}",
247                assertion.expected_capture_name,
248                assertion.position.row,
249                assertion.position.column + assertion.length - 1,
250            ));
251        }
252    }
253    Ok(pairs.len())
254}