tree_sitter_cli/
highlight.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fmt::Write,
4    fs,
5    io::{self, Write as _},
6    path::{self, Path, PathBuf},
7    str,
8    sync::{atomic::AtomicUsize, Arc},
9    time::Instant,
10};
11
12use ansi_colours::{ansi256_from_rgb, rgb_from_ansi256};
13use anstyle::{Ansi256Color, AnsiColor, Color, Effects, RgbColor};
14use anyhow::Result;
15use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
16use serde_json::{json, Value};
17use tree_sitter_highlight::{HighlightConfiguration, HighlightEvent, Highlighter, HtmlRenderer};
18use tree_sitter_loader::Loader;
19
/// Opening of the HTML document written by `highlight`: doctype, `<head>`, the page
/// title and base CSS for the `.line-number` / `.line` table cells.
///
/// `<head>` is left open: `highlight` writes the theme's own `<style>` block after
/// this and before `HTML_BODY_HEADER`.
pub const HTML_HEAD_HEADER: &str = "
<!doctype HTML>
<head>
  <title>Tree-sitter Highlighting</title>
  <style>
    body {
      font-family: monospace
    }
    .line-number {
      user-select: none;
      text-align: right;
      color: rgba(27,31,35,.3);
      padding: 0 10px;
    }
    .line {
      white-space: pre;
    }
  </style>";

/// Closes `<head>` and opens `<body>`; written after the theme-specific styles.
pub const HTML_BODY_HEADER: &str = "
</head>
<body>
";

/// Closes `<body>` after the highlighted table has been written.
pub const HTML_FOOTER: &str = "
</body>
";
47
/// How a single highlight name is rendered, both in a terminal and in HTML.
#[derive(Debug, Default)]
pub struct Style {
    /// Terminal styling (foreground color and bold/italic/underline effects).
    pub ansi: anstyle::Style,
    /// CSS declarations for HTML output; `None` when the theme entry carried no
    /// usable style (and for the default style).
    pub css: Option<String>,
}
53
/// A highlighting theme: highlight names paired with the style used for each.
///
/// The two vectors are parallel: `styles[i]` is the style for `highlight_names[i]`.
/// Both fields are public, so callers that modify them must keep them aligned.
#[derive(Debug)]
pub struct Theme {
    /// Style for each entry of `highlight_names`, at the same index.
    pub styles: Vec<Style>,
    /// Highlight names such as `"function.builtin"`, parallel to `styles`.
    pub highlight_names: Vec<String>,
}
59
/// Serializable wrapper holding a `Theme` under the `theme` key.
///
/// NOTE(review): presumably this is a section of the user's tree-sitter configuration
/// file — confirm against the code that loads it.
#[derive(Default, Deserialize, Serialize)]
pub struct ThemeConfig {
    /// The theme; `Theme::default()` is used when the `theme` key is absent.
    #[serde(default)]
    pub theme: Theme,
}
65
66impl Theme {
67    pub fn load(path: &path::Path) -> io::Result<Self> {
68        let json = fs::read_to_string(path)?;
69        Ok(serde_json::from_str(&json).unwrap_or_default())
70    }
71
72    #[must_use]
73    pub fn default_style(&self) -> Style {
74        Style::default()
75    }
76}
77
78impl<'de> Deserialize<'de> for Theme {
79    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
80    where
81        D: Deserializer<'de>,
82    {
83        let mut styles = Vec::new();
84        let mut highlight_names = Vec::new();
85        if let Ok(colors) = HashMap::<String, Value>::deserialize(deserializer) {
86            highlight_names.reserve(colors.len());
87            styles.reserve(colors.len());
88            for (name, style_value) in colors {
89                let mut style = Style::default();
90                parse_style(&mut style, style_value);
91                highlight_names.push(name);
92                styles.push(style);
93            }
94        }
95        Ok(Self {
96            styles,
97            highlight_names,
98        })
99    }
100}
101
102impl Serialize for Theme {
103    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
104    where
105        S: Serializer,
106    {
107        let mut map = serializer.serialize_map(Some(self.styles.len()))?;
108        for (name, style) in self.highlight_names.iter().zip(&self.styles) {
109            let style = &style.ansi;
110            let color = style.get_fg_color().map(|color| match color {
111                Color::Ansi(color) => match color {
112                    AnsiColor::Black => json!("black"),
113                    AnsiColor::Blue => json!("blue"),
114                    AnsiColor::Cyan => json!("cyan"),
115                    AnsiColor::Green => json!("green"),
116                    AnsiColor::Magenta => json!("purple"),
117                    AnsiColor::Red => json!("red"),
118                    AnsiColor::White => json!("white"),
119                    AnsiColor::Yellow => json!("yellow"),
120                    _ => unreachable!(),
121                },
122                Color::Ansi256(Ansi256Color(n)) => json!(n),
123                Color::Rgb(RgbColor(r, g, b)) => json!(format!("#{r:x?}{g:x?}{b:x?}")),
124            });
125            let effects = style.get_effects();
126            if effects.contains(Effects::BOLD)
127                || effects.contains(Effects::ITALIC)
128                || effects.contains(Effects::UNDERLINE)
129            {
130                let mut style_json = HashMap::new();
131                if let Some(color) = color {
132                    style_json.insert("color", color);
133                }
134                if effects.contains(Effects::BOLD) {
135                    style_json.insert("bold", Value::Bool(true));
136                }
137                if effects.contains(Effects::ITALIC) {
138                    style_json.insert("italic", Value::Bool(true));
139                }
140                if effects.contains(Effects::UNDERLINE) {
141                    style_json.insert("underline", Value::Bool(true));
142                }
143                map.serialize_entry(&name, &style_json)?;
144            } else if let Some(color) = color {
145                map.serialize_entry(&name, &color)?;
146            } else {
147                map.serialize_entry(&name, &Value::Null)?;
148            }
149        }
150        map.end()
151    }
152}
153
154impl Default for Theme {
155    fn default() -> Self {
156        serde_json::from_value(json!({
157            "attribute": {"color": 124, "italic": true},
158            "comment": {"color": 245, "italic": true},
159            "constant": 94,
160            "constant.builtin": {"color": 94, "bold": true},
161            "constructor": 136,
162            "embedded": null,
163            "function": 26,
164            "function.builtin": {"color": 26, "bold": true},
165            "keyword": 56,
166            "module": 136,
167            "number": {"color": 94, "bold": true},
168            "operator": {"color": 239, "bold": true},
169            "property": 124,
170            "property.builtin": {"color": 124, "bold": true},
171            "punctuation": 239,
172            "punctuation.bracket": 239,
173            "punctuation.delimiter": 239,
174            "punctuation.special": 239,
175            "string": 28,
176            "string.special": 30,
177            "tag": 18,
178            "type": 23,
179            "type.builtin": {"color": 23, "bold": true},
180            "variable": 252,
181            "variable.builtin": {"color": 252, "bold": true},
182            "variable.parameter": {"color": 252, "underline": true}
183        }))
184        .unwrap()
185    }
186}
187
188fn parse_style(style: &mut Style, json: Value) {
189    if let Value::Object(entries) = json {
190        for (property_name, value) in entries {
191            match property_name.as_str() {
192                "bold" => {
193                    if value == Value::Bool(true) {
194                        style.ansi = style.ansi.bold();
195                    }
196                }
197                "italic" => {
198                    if value == Value::Bool(true) {
199                        style.ansi = style.ansi.italic();
200                    }
201                }
202                "underline" => {
203                    if value == Value::Bool(true) {
204                        style.ansi = style.ansi.underline();
205                    }
206                }
207                "color" => {
208                    if let Some(color) = parse_color(value) {
209                        style.ansi = style.ansi.fg_color(Some(color));
210                    }
211                }
212                _ => {}
213            }
214        }
215        style.css = Some(style_to_css(style.ansi));
216    } else if let Some(color) = parse_color(json) {
217        style.ansi = style.ansi.fg_color(Some(color));
218        style.css = Some(style_to_css(style.ansi));
219    } else {
220        style.css = None;
221    }
222
223    if let Some(Color::Rgb(RgbColor(red, green, blue))) = style.ansi.get_fg_color() {
224        if !terminal_supports_truecolor() {
225            let ansi256 = Color::Ansi256(Ansi256Color(ansi256_from_rgb((red, green, blue))));
226            style.ansi = style.ansi.fg_color(Some(ansi256));
227        }
228    }
229}
230
231fn parse_color(json: Value) -> Option<Color> {
232    match json {
233        Value::Number(n) => n.as_u64().map(|n| Color::Ansi256(Ansi256Color(n as u8))),
234        Value::String(s) => match s.to_lowercase().as_str() {
235            "black" => Some(Color::Ansi(AnsiColor::Black)),
236            "blue" => Some(Color::Ansi(AnsiColor::Blue)),
237            "cyan" => Some(Color::Ansi(AnsiColor::Cyan)),
238            "green" => Some(Color::Ansi(AnsiColor::Green)),
239            "purple" => Some(Color::Ansi(AnsiColor::Magenta)),
240            "red" => Some(Color::Ansi(AnsiColor::Red)),
241            "white" => Some(Color::Ansi(AnsiColor::White)),
242            "yellow" => Some(Color::Ansi(AnsiColor::Yellow)),
243            s => {
244                if let Some((red, green, blue)) = hex_string_to_rgb(s) {
245                    Some(Color::Rgb(RgbColor(red, green, blue)))
246                } else {
247                    None
248                }
249            }
250        },
251        _ => None,
252    }
253}
254
/// Parses a `#rrggbb` hex color (digits case-insensitive) into its red, green and
/// blue channels.
///
/// Only the six digits after `#` are read; anything after them (for example the
/// alpha byte of `#rrggbbaa`) is ignored. Returns `None` if the `#` is missing, fewer
/// than six digits follow, or any of those six characters is not an ASCII hex digit.
fn hex_string_to_rgb(s: &str) -> Option<(u8, u8, u8)> {
    let digits = s.strip_prefix('#')?;
    // `str::get` returns `None` instead of panicking when a range does not fall on a
    // char boundary (e.g. "#aé0000" from a user theme). The hex-digit check also
    // rejects the sign prefixes that `from_str_radix` would otherwise accept ("+f").
    let channel = |start: usize| {
        digits
            .get(start..start + 2)
            .filter(|pair| pair.bytes().all(|b| b.is_ascii_hexdigit()))
            .and_then(|pair| u8::from_str_radix(pair, 16).ok())
    };
    Some((channel(0)?, channel(2)?, channel(4)?))
}
270
271fn style_to_css(style: anstyle::Style) -> String {
272    let mut result = String::new();
273    let effects = style.get_effects();
274    if effects.contains(Effects::UNDERLINE) {
275        write!(&mut result, "text-decoration: underline;").unwrap();
276    }
277    if effects.contains(Effects::BOLD) {
278        write!(&mut result, "font-weight: bold;").unwrap();
279    }
280    if effects.contains(Effects::ITALIC) {
281        write!(&mut result, "font-style: italic;").unwrap();
282    }
283    if let Some(color) = style.get_fg_color() {
284        write_color(&mut result, color);
285    }
286    result
287}
288
289fn write_color(buffer: &mut String, color: Color) {
290    match color {
291        Color::Ansi(color) => match color {
292            AnsiColor::Black => write!(buffer, "color: black").unwrap(),
293            AnsiColor::Red => write!(buffer, "color: red").unwrap(),
294            AnsiColor::Green => write!(buffer, "color: green").unwrap(),
295            AnsiColor::Yellow => write!(buffer, "color: yellow").unwrap(),
296            AnsiColor::Blue => write!(buffer, "color: blue").unwrap(),
297            AnsiColor::Magenta => write!(buffer, "color: purple").unwrap(),
298            AnsiColor::Cyan => write!(buffer, "color: cyan").unwrap(),
299            AnsiColor::White => write!(buffer, "color: white").unwrap(),
300            _ => unreachable!(),
301        },
302        Color::Ansi256(Ansi256Color(n)) => {
303            let (r, g, b) = rgb_from_ansi256(n);
304            write!(buffer, "color: #{r:02x}{g:02x}{b:02x}").unwrap();
305        }
306        Color::Rgb(RgbColor(r, g, b)) => write!(buffer, "color: #{r:02x}{g:02x}{b:02x}").unwrap(),
307    }
308}
309
/// Reports whether `COLORTERM` is exactly `truecolor` or `24bit`, the only values
/// this crate treats as indicating 24-bit color support. An unset or non-UTF-8
/// variable counts as unsupported.
fn terminal_supports_truecolor() -> bool {
    matches!(
        std::env::var("COLORTERM").as_deref(),
        Ok("truecolor" | "24bit")
    )
}
314
/// Options controlling how `highlight` checks, renders and reports a file.
pub struct HighlightOptions {
    /// Theme whose names and styles are used to render highlights.
    pub theme: Theme,
    /// Before highlighting, print to stderr the capture names that
    /// `HighlightConfiguration::nonconformant_capture_names` reports as non-standard.
    pub check: bool,
    /// File of capture names handed to the `check` step. Each non-blank line that
    /// does not start with `;` contributes the text before its first `;`, trimmed
    /// and with surrounding `"` removed. When `None`, an empty set is passed.
    /// NOTE(review): presumably the empty set makes the check use its built-in
    /// standard list — confirm in tree-sitter-highlight.
    pub captures_path: Option<PathBuf>,
    /// In HTML mode, emit each highlight's CSS as an inline `style` attribute
    /// instead of a `class` attribute built from the highlight name.
    pub inline_styles: bool,
    /// Render an HTML document instead of ANSI-styled terminal output.
    pub html: bool,
    /// When set, `highlight` skips printing the file name and the HTML document to
    /// stdout; highlighting itself still runs.
    pub quiet: bool,
    /// Print the elapsed highlighting and rendering time to stderr.
    pub print_time: bool,
    /// Cancellation flag passed to `Highlighter::highlight`.
    /// NOTE(review): presumably set by another thread to abort a long run — confirm.
    pub cancellation_flag: Arc<AtomicUsize>,
}
325
326pub fn highlight(
327    loader: &Loader,
328    path: &Path,
329    name: &str,
330    config: &HighlightConfiguration,
331    print_name: bool,
332    opts: &HighlightOptions,
333) -> Result<()> {
334    if opts.check {
335        let names = if let Some(path) = opts.captures_path.as_deref() {
336            let file = fs::read_to_string(path)?;
337            let capture_names = file
338                .lines()
339                .filter_map(|line| {
340                    if line.trim().is_empty() || line.trim().starts_with(';') {
341                        return None;
342                    }
343                    line.split(';').next().map(|s| s.trim().trim_matches('"'))
344                })
345                .collect::<HashSet<_>>();
346            config.nonconformant_capture_names(&capture_names)
347        } else {
348            config.nonconformant_capture_names(&HashSet::new())
349        };
350        if names.is_empty() {
351            eprintln!("All highlight captures conform to standards.");
352        } else {
353            eprintln!(
354                "Non-standard highlight {} detected:",
355                if names.len() > 1 {
356                    "captures"
357                } else {
358                    "capture"
359                }
360            );
361            for name in names {
362                eprintln!("* {name}");
363            }
364        }
365    }
366
367    let source = fs::read(path)?;
368    let stdout = io::stdout();
369    let mut stdout = stdout.lock();
370    let time = Instant::now();
371    let mut highlighter = Highlighter::new();
372    let events =
373        highlighter.highlight(config, &source, Some(&opts.cancellation_flag), |string| {
374            loader.highlight_config_for_injection_string(string)
375        })?;
376    let theme = &opts.theme;
377
378    if !opts.quiet && print_name {
379        writeln!(&mut stdout, "{name}")?;
380    }
381
382    if opts.html {
383        if !opts.quiet {
384            writeln!(&mut stdout, "{HTML_HEAD_HEADER}")?;
385            writeln!(&mut stdout, "  <style>")?;
386            let names = theme.highlight_names.iter();
387            let styles = theme.styles.iter();
388            for (name, style) in names.zip(styles) {
389                if let Some(css) = &style.css {
390                    writeln!(&mut stdout, "    .{name} {{ {css}; }}")?;
391                }
392            }
393            writeln!(&mut stdout, "  </style>")?;
394            writeln!(&mut stdout, "{HTML_BODY_HEADER}")?;
395        }
396
397        let mut renderer = HtmlRenderer::new();
398        renderer.render(events, &source, &move |highlight, output| {
399            if opts.inline_styles {
400                output.extend(b"style='");
401                output.extend(
402                    theme.styles[highlight.0]
403                        .css
404                        .as_ref()
405                        .map_or_else(|| "".as_bytes(), |css_style| css_style.as_bytes()),
406                );
407                output.extend(b"'");
408            } else {
409                output.extend(b"class='");
410                let mut parts = theme.highlight_names[highlight.0].split('.').peekable();
411                while let Some(part) = parts.next() {
412                    output.extend(part.as_bytes());
413                    if parts.peek().is_some() {
414                        output.extend(b" ");
415                    }
416                }
417                output.extend(b"'");
418            }
419        })?;
420
421        if !opts.quiet {
422            writeln!(&mut stdout, "<table>")?;
423            for (i, line) in renderer.lines().enumerate() {
424                writeln!(
425                    &mut stdout,
426                    "<tr><td class=line-number>{}</td><td class=line>{line}</td></tr>",
427                    i + 1,
428                )?;
429            }
430            writeln!(&mut stdout, "</table>")?;
431            writeln!(&mut stdout, "{HTML_FOOTER}")?;
432        }
433    } else {
434        let mut style_stack = vec![theme.default_style().ansi];
435        for event in events {
436            match event? {
437                HighlightEvent::HighlightStart(highlight) => {
438                    style_stack.push(theme.styles[highlight.0].ansi);
439                }
440                HighlightEvent::HighlightEnd => {
441                    style_stack.pop();
442                }
443                HighlightEvent::Source { start, end } => {
444                    let style = style_stack.last().unwrap();
445                    write!(&mut stdout, "{style}").unwrap();
446                    stdout.write_all(&source[start..end])?;
447                    write!(&mut stdout, "{style:#}").unwrap();
448                }
449            }
450        }
451    }
452
453    if opts.print_time {
454        eprintln!("Time: {}ms", time.elapsed().as_millis());
455    }
456
457    Ok(())
458}
459
#[cfg(test)]
mod tests {
    use std::{env, ffi::OsString};

    use super::*;

    const JUNGLE_GREEN: &str = "#26A69A";
    const DARK_CYAN: &str = "#00AF87";

    /// Restores `COLORTERM` to the value it had when the guard was created.
    ///
    /// The guard runs on drop, so the environment is reset even when an assertion
    /// panics. `var_os` also preserves a non-UTF-8 original value, which the
    /// previous `env::var`-based restore would have deleted.
    struct ColortermGuard(Option<OsString>);

    impl ColortermGuard {
        fn capture() -> Self {
            Self(env::var_os("COLORTERM"))
        }
    }

    impl Drop for ColortermGuard {
        fn drop(&mut self) {
            match self.0.take() {
                Some(value) => env::set_var("COLORTERM", value),
                None => env::remove_var("COLORTERM"),
            }
        }
    }

    #[test]
    fn test_parse_style() {
        let _restore_colorterm = ColortermGuard::capture();

        let mut style = Style::default();
        assert_eq!(style.ansi.get_fg_color(), None);
        assert_eq!(style.css, None);

        // Without truecolor, #00AF87 is mapped to ANSI-256 color 36. That palette
        // entry is exactly #00af87, so the approximation is lossless.
        env::set_var("COLORTERM", "");
        parse_style(&mut style, Value::String(DARK_CYAN.to_string()));
        assert_eq!(
            style.ansi.get_fg_color(),
            Some(Color::Ansi256(Ansi256Color(36)))
        );
        assert_eq!(style.css, Some("color: #00af87".to_string()));

        // Jungle green has no exact ANSI-256 match; it is kept as RGB when the
        // terminal supports truecolor.
        env::set_var("COLORTERM", "truecolor");
        parse_style(&mut style, Value::String(JUNGLE_GREEN.to_string()));
        assert_eq!(
            style.ansi.get_fg_color(),
            Some(Color::Rgb(RgbColor(38, 166, 154)))
        );
        assert_eq!(style.css, Some("color: #26a69a".to_string()));

        // Without truecolor it is approximated as cadet blue (72), while the CSS
        // keeps the exact color.
        env::set_var("COLORTERM", "");
        parse_style(&mut style, Value::String(JUNGLE_GREEN.to_string()));
        assert_eq!(
            style.ansi.get_fg_color(),
            Some(Color::Ansi256(Ansi256Color(72)))
        );
        assert_eq!(style.css, Some("color: #26a69a".to_string()));
    }
}