use std::fmt::{self, Debug, Display, Formatter};
use std::ops::{Deref, Range};
use std::rc::Rc;
use std::sync::Arc;
use ecow::{eco_format, eco_vec, EcoString, EcoVec};
use crate::ast::AstNode;
use crate::{FileId, Span, SyntaxKind};
/// A node in the untyped syntax tree.
///
/// This is a thin wrapper around the private [`Repr`] enum, which decides
/// whether the node is a leaf, an inner node with children, or an error.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct SyntaxNode(Repr);
/// The three internal representations a syntax node can have.
#[derive(Clone, Eq, PartialEq, Hash)]
enum Repr {
/// A leaf node, stored inline.
Leaf(LeafNode),
/// An inner node that owns children, stored behind an `Arc`.
Inner(Arc<InnerNode>),
/// An error node, stored behind an `Arc`.
Error(Arc<ErrorNode>),
}
impl SyntaxNode {
    /// Creates a new leaf node of the given kind holding `text`.
    pub fn leaf(kind: SyntaxKind, text: impl Into<EcoString>) -> Self {
        Self(Repr::Leaf(LeafNode::new(kind, text)))
    }

    /// Creates a new inner node of the given kind that owns `children`.
    pub fn inner(kind: SyntaxKind, children: Vec<SyntaxNode>) -> Self {
        Self(Repr::Inner(Arc::new(InnerNode::new(kind, children))))
    }

    /// Creates a new error node carrying `error` and the offending `text`.
    pub fn error(error: SyntaxError, text: impl Into<EcoString>) -> Self {
        Self(Repr::Error(Arc::new(ErrorNode::new(error, text))))
    }

    /// Creates an empty, detached leaf node of the given kind.
    ///
    /// # Panics
    /// Panics if `kind` is `SyntaxKind::Error`.
    #[track_caller]
    pub const fn placeholder(kind: SyntaxKind) -> Self {
        if matches!(kind, SyntaxKind::Error) {
            panic!("cannot create error placeholder");
        }
        Self(Repr::Leaf(LeafNode {
            kind,
            text: EcoString::new(),
            span: Span::detached(),
        }))
    }

    /// The kind of the node. Error nodes always report `SyntaxKind::Error`.
    pub fn kind(&self) -> SyntaxKind {
        match &self.0 {
            Repr::Leaf(leaf) => leaf.kind,
            Repr::Inner(inner) => inner.kind,
            Repr::Error(_) => SyntaxKind::Error,
        }
    }

    /// Whether the node's length is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The length of the node's text.
    ///
    /// For leaf and error nodes this is the length of their own text; for
    /// inner nodes it is the cached length stored on the node.
    pub fn len(&self) -> usize {
        match &self.0 {
            Repr::Leaf(leaf) => leaf.len(),
            Repr::Inner(inner) => inner.len,
            Repr::Error(node) => node.len(),
        }
    }

    /// The span of the node. For error nodes this is the span of the error.
    pub fn span(&self) -> Span {
        match &self.0 {
            Repr::Leaf(leaf) => leaf.span,
            Repr::Inner(inner) => inner.span,
            Repr::Error(node) => node.error.span,
        }
    }

    /// The text stored directly on the node.
    ///
    /// Inner nodes store no text of their own, so this returns a reference to
    /// an empty string for them.
    pub fn text(&self) -> &EcoString {
        static EMPTY: EcoString = EcoString::new();
        match &self.0 {
            Repr::Leaf(leaf) => &leaf.text,
            Repr::Inner(_) => &EMPTY,
            Repr::Error(node) => &node.text,
        }
    }

    /// Consumes the node and returns its full text.
    ///
    /// For leaf and error nodes this is their own text. For inner nodes it is
    /// the concatenation, in order, of the text of all leaf and error
    /// descendants.
    pub fn into_text(self) -> EcoString {
        match self.0 {
            Repr::Leaf(leaf) => leaf.text,
            Repr::Inner(inner) => {
                // Append into a single buffer by reference instead of cloning
                // every child and building one temporary string per level.
                let mut text = EcoString::new();
                for child in &inner.children {
                    child.push_text(&mut text);
                }
                text
            }
            Repr::Error(node) => node.text.clone(),
        }
    }

    /// Appends the text of this subtree to `out` without cloning any node.
    fn push_text(&self, out: &mut EcoString) {
        match &self.0 {
            Repr::Leaf(leaf) => out.push_str(&leaf.text),
            Repr::Inner(inner) => {
                for child in &inner.children {
                    child.push_text(out);
                }
            }
            Repr::Error(node) => out.push_str(&node.text),
        }
    }

    /// An iterator over the node's direct children.
    ///
    /// Leaf and error nodes yield nothing.
    pub fn children(&self) -> std::slice::Iter<'_, SyntaxNode> {
        match &self.0 {
            Repr::Leaf(_) | Repr::Error(_) => [].iter(),
            Repr::Inner(inner) => inner.children.iter(),
        }
    }

    /// Whether the node can be cast to the typed AST node `T`.
    pub fn is<'a, T: AstNode<'a>>(&'a self) -> bool {
        self.cast::<T>().is_some()
    }

    /// Tries to convert the node to the typed AST node `T`.
    pub fn cast<'a, T: AstNode<'a>>(&'a self) -> Option<T> {
        T::from_untyped(self)
    }

    /// Casts the first direct child that can be cast to `T`.
    pub fn cast_first_match<'a, T: AstNode<'a>>(&'a self) -> Option<T> {
        self.children().find_map(Self::cast)
    }

    /// Casts the last direct child that can be cast to `T`.
    pub fn cast_last_match<'a, T: AstNode<'a>>(&'a self) -> Option<T> {
        self.children().rev().find_map(Self::cast)
    }

    /// Whether the node is an error node or has the cached `erroneous` flag
    /// set (for inner nodes).
    pub fn erroneous(&self) -> bool {
        match &self.0 {
            Repr::Leaf(_) => false,
            Repr::Inner(inner) => inner.erroneous,
            Repr::Error(_) => true,
        }
    }

    /// Collects the errors of this node and all of its descendants.
    pub fn errors(&self) -> Vec<SyntaxError> {
        // Fast path: nothing to collect in a subtree without errors.
        if !self.erroneous() {
            return vec![];
        }
        if let Repr::Error(node) = &self.0 {
            vec![node.error.clone()]
        } else {
            self.children()
                .filter(|node| node.erroneous())
                .flat_map(|node| node.errors())
                .collect()
        }
    }

    /// Adds a hint to the node's error. Does nothing if this is not an error
    /// node.
    pub fn hint(&mut self, hint: impl Into<EcoString>) {
        if let Repr::Error(node) = &mut self.0 {
            Arc::make_mut(node).hint(hint);
        }
    }

    /// Assigns the given span to this node and, for inner nodes, to all of
    /// its descendants.
    pub fn synthesize(&mut self, span: Span) {
        match &mut self.0 {
            Repr::Leaf(leaf) => leaf.span = span,
            Repr::Inner(inner) => Arc::make_mut(inner).synthesize(span),
            Repr::Error(node) => Arc::make_mut(node).error.span = span,
        }
    }

    /// Whether the two nodes are equal when their spans are ignored.
    ///
    /// Nodes of different representations are never equal.
    pub fn spanless_eq(&self, other: &Self) -> bool {
        match (&self.0, &other.0) {
            (Repr::Leaf(a), Repr::Leaf(b)) => a.spanless_eq(b),
            (Repr::Inner(a), Repr::Inner(b)) => a.spanless_eq(b),
            (Repr::Error(a), Repr::Error(b)) => a.spanless_eq(b),
            _ => false,
        }
    }
}
impl SyntaxNode {
#[track_caller]
pub(super) fn convert_to_kind(&mut self, kind: SyntaxKind) {
debug_assert!(!kind.is_error());
match &mut self.0 {
Repr::Leaf(leaf) => leaf.kind = kind,
Repr::Inner(inner) => Arc::make_mut(inner).kind = kind,
Repr::Error(_) => panic!("cannot convert error"),
}
}
pub(super) fn convert_to_error(&mut self, message: impl Into<EcoString>) {
if !self.kind().is_error() {
let text = std::mem::take(self).into_text();
*self = SyntaxNode::error(SyntaxError::new(message), text);
}
}
pub(super) fn expected(&mut self, expected: &str) {
let kind = self.kind();
self.convert_to_error(eco_format!("expected {expected}, found {}", kind.name()));
if kind.is_keyword() && matches!(expected, "identifier" | "pattern") {
self.hint(eco_format!(
"keyword `{text}` is not allowed as an identifier; try `{text}_` instead",
text = self.text(),
));
}
}
pub(super) fn unexpected(&mut self) {
self.convert_to_error(eco_format!("unexpected {}", self.kind().name()));
}
pub(super) fn numberize(
&mut self,
id: FileId,
within: Range<u64>,
) -> NumberingResult {
if within.start >= within.end {
return Err(Unnumberable);
}
let mid = Span::new(id, (within.start + within.end) / 2).unwrap();
match &mut self.0 {
Repr::Leaf(leaf) => leaf.span = mid,
Repr::Inner(inner) => Arc::make_mut(inner).numberize(id, None, within)?,
Repr::Error(node) => Arc::make_mut(node).error.span = mid,
}
Ok(())
}
pub(super) fn is_leaf(&self) -> bool {
matches!(self.0, Repr::Leaf(_))
}
pub(super) fn descendants(&self) -> usize {
match &self.0 {
Repr::Leaf(_) | Repr::Error(_) => 1,
Repr::Inner(inner) => inner.descendants,
}
}
pub(super) fn children_mut(&mut self) -> &mut [SyntaxNode] {
match &mut self.0 {
Repr::Leaf(_) | Repr::Error(_) => &mut [],
Repr::Inner(inner) => &mut Arc::make_mut(inner).children,
}
}
pub(super) fn replace_children(
&mut self,
range: Range<usize>,
replacement: Vec<SyntaxNode>,
) -> NumberingResult {
if let Repr::Inner(inner) = &mut self.0 {
Arc::make_mut(inner).replace_children(range, replacement)?;
}
Ok(())
}
pub(super) fn update_parent(
&mut self,
prev_len: usize,
new_len: usize,
prev_descendants: usize,
new_descendants: usize,
) {
if let Repr::Inner(inner) = &mut self.0 {
Arc::make_mut(inner).update_parent(
prev_len,
new_len,
prev_descendants,
new_descendants,
);
}
}
pub(super) fn upper(&self) -> u64 {
match &self.0 {
Repr::Leaf(leaf) => leaf.span.number() + 1,
Repr::Inner(inner) => inner.upper,
Repr::Error(node) => node.error.span.number() + 1,
}
}
}
impl Debug for SyntaxNode {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match &self.0 {
Repr::Leaf(leaf) => leaf.fmt(f),
Repr::Inner(inner) => inner.fmt(f),
Repr::Error(node) => node.fmt(f),
}
}
}
impl Default for SyntaxNode {
fn default() -> Self {
Self::leaf(SyntaxKind::End, EcoString::new())
}
}
/// A leaf node in the untyped syntax tree.
#[derive(Clone, Eq, PartialEq, Hash)]
struct LeafNode {
/// The kind of the node; `LeafNode::new` debug-asserts it is not an error kind.
kind: SyntaxKind,
/// The source text of the node.
text: EcoString,
/// The node's span.
span: Span,
}
impl LeafNode {
    /// Creates a new, detached leaf node.
    ///
    /// In debug builds, asserts that `kind` is not an error kind.
    #[track_caller]
    fn new(kind: SyntaxKind, text: impl Into<EcoString>) -> Self {
        debug_assert!(!kind.is_error());
        let text = text.into();
        let span = Span::detached();
        Self { kind, text, span }
    }

    /// The length of the node's text.
    fn len(&self) -> usize {
        self.text.len()
    }

    /// Whether the two leaves have the same kind and text, ignoring spans.
    fn spanless_eq(&self, other: &Self) -> bool {
        let Self { kind, text, span: _ } = other;
        self.kind == *kind && self.text == *text
    }
}
/// Formats the leaf as `Kind: "text"`.
impl Debug for LeafNode {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let Self { kind, text, .. } = self;
        write!(f, "{kind:?}: {text:?}")
    }
}
/// An inner node in the untyped syntax tree.
#[derive(Clone, Eq, PartialEq, Hash)]
struct InnerNode {
/// The kind of the node; `InnerNode::new` debug-asserts it is not an error kind.
kind: SyntaxKind,
/// Cached total length: the sum of the lengths of all children.
len: usize,
/// The node's span.
span: Span,
/// Cached number of nodes in this subtree, counting this node itself.
descendants: usize,
/// Cached flag: whether any child is erroneous.
erroneous: bool,
/// The upper bound of the span number interval assigned to this subtree
/// (set from `within.end` during numbering).
upper: u64,
/// The direct children of this node.
children: Vec<SyntaxNode>,
}
impl InnerNode {
/// Creates a new, detached inner node with the given kind and children.
///
/// The cached `len`, `descendants` and `erroneous` values are computed from
/// the children. In debug builds, asserts that `kind` is not an error kind.
#[track_caller]
fn new(kind: SyntaxKind, children: Vec<SyntaxNode>) -> Self {
debug_assert!(!kind.is_error());
let mut len = 0;
// Starts at one to count this node itself.
let mut descendants = 1;
let mut erroneous = false;
for child in &children {
len += child.len();
descendants += child.descendants();
erroneous |= child.erroneous();
}
Self {
kind,
len,
span: Span::detached(),
descendants,
erroneous,
upper: 0,
children,
}
}
/// Assigns `span` to this node and, recursively, to all of its children.
/// The upper bound is set to the span's number.
fn synthesize(&mut self, span: Span) {
self.span = span;
self.upper = span.number();
for child in &mut self.children {
child.synthesize(span);
}
}
/// Assigns span numbers from the interval `within` to this subtree.
///
/// If `range` is `None`, this node and all of its children are numbered.
/// If `range` is `Some`, only the children in that index range are
/// numbered and this node's own span is left untouched.
///
/// Returns `Err(Unnumberable)` if the interval is too small.
fn numberize(
&mut self,
id: FileId,
range: Option<Range<usize>>,
within: Range<u64>,
) -> NumberingResult {
// Determine how many nodes will be numbered. An empty child range means
// there is nothing to do.
let descendants = match &range {
Some(range) if range.is_empty() => return Ok(()),
Some(range) => self.children[range.clone()]
.iter()
.map(SyntaxNode::descendants)
.sum::<usize>(),
None => self.descendants,
};
// Determine the distance between two neighbouring assigned numbers.
// First try a stride that only uses half of the interval, leaving the
// rest unassigned; if that rounds down to zero, fall back to a denser
// stride.
// NOTE(review): the fallback divides by `self.descendants` (the whole
// subtree), not by the local `descendants` being numbered — confirm this
// is intended.
let space = within.end - within.start;
let mut stride = space / (2 * descendants as u64);
if stride == 0 {
stride = space / self.descendants as u64;
if stride == 0 {
return Err(Unnumberable);
}
}
// Number this node itself, but only when numbering the whole subtree.
let mut start = within.start;
if range.is_none() {
let end = start + stride;
self.span = Span::new(id, (start + end) / 2).unwrap();
self.upper = within.end;
start = end;
}
// Number the selected children; each gets a sub-interval proportional to
// its number of descendants.
let len = self.children.len();
for child in &mut self.children[range.unwrap_or(0..len)] {
let end = start + child.descendants() as u64 * stride;
child.numberize(id, start..end)?;
start = end;
}
Ok(())
}
/// Whether the two inner nodes are equal when spans are ignored.
///
/// Compares kind and cached metadata first, then the children pairwise.
fn spanless_eq(&self, other: &Self) -> bool {
self.kind == other.kind
&& self.len == other.len
&& self.descendants == other.descendants
&& self.erroneous == other.erroneous
&& self.children.len() == other.children.len()
&& self
.children
.iter()
.zip(&other.children)
.all(|(a, b)| a.spanless_eq(b))
}
/// Replaces the children in `range` with `replacement`, updates the cached
/// metadata, and renumbers the spans of the affected children.
///
/// Returns `Err(Unnumberable)` if this node's span has no file id or if no
/// renumbering fits, even when all children are included.
fn replace_children(
&mut self,
mut range: Range<usize>,
replacement: Vec<SyntaxNode>,
) -> NumberingResult {
// Without a file id, no new spans can be created.
let Some(id) = self.span.id() else { return Err(Unnumberable) };
let mut replacement_range = 0..replacement.len();
// Trim off the common prefix: leading nodes that are spanless-equal stay
// in place and keep their spans.
while range.start < range.end
&& replacement_range.start < replacement_range.end
&& self.children[range.start]
.spanless_eq(&replacement[replacement_range.start])
{
range.start += 1;
replacement_range.start += 1;
}
// Trim off the common suffix in the same way.
while range.start < range.end
&& replacement_range.start < replacement_range.end
&& self.children[range.end - 1]
.spanless_eq(&replacement[replacement_range.end - 1])
{
range.end -= 1;
replacement_range.end -= 1;
}
let mut replacement_vec = replacement;
let replacement = &replacement_vec[replacement_range.clone()];
let superseded = &self.children[range.clone()];
// Compute the new length: add the new nodes, subtract the superseded ones.
self.len = self.len + replacement.iter().map(SyntaxNode::len).sum::<usize>()
- superseded.iter().map(SyntaxNode::len).sum::<usize>();
// Compute the new number of descendants in the same way.
self.descendants = self.descendants
+ replacement.iter().map(SyntaxNode::descendants).sum::<usize>()
- superseded.iter().map(SyntaxNode::descendants).sum::<usize>();
// Determine whether the node is still erroneous after the replacement:
// either a new node is erroneous, or a non-superseded node is.
// NOTE(review): as parenthesized, this evaluates as
// `new_erroneous || (self.erroneous && prefix_erroneous) || suffix_erroneous`;
// the `self.erroneous &&` guard only short-circuits the prefix scan.
self.erroneous = replacement.iter().any(SyntaxNode::erroneous)
|| (self.erroneous
&& (self.children[..range.start].iter().any(SyntaxNode::erroneous))
|| self.children[range.end..].iter().any(SyntaxNode::erroneous));
// Perform the replacement and make `range` cover the inserted nodes.
self.children
.splice(range.clone(), replacement_vec.drain(replacement_range.clone()));
range.end = range.start + replacement_range.len();
// Renumber the new children. Retry with a window that grows by powers of
// two on both sides until the numbering fits.
let mut left = 0;
let mut right = 0;
let max_left = range.start;
let max_right = self.children.len() - range.end;
loop {
let renumber = range.start - left..range.end + right;
// The lowest assignable number is the upper bound of the child right
// before the window, or this node's own span number plus one if the
// window starts at the first child.
let start_number = renumber
.start
.checked_sub(1)
.and_then(|i| self.children.get(i))
.map_or(self.span.number() + 1, |child| child.upper());
// The exclusive upper limit is the span number of the first child after
// the window, or this node's upper bound if the window reaches the end.
let end_number = self
.children
.get(renumber.end)
.map_or(self.upper, |next| next.span().number());
// Try to renumber within these bounds.
let within = start_number..end_number;
if self.numberize(id, Some(renumber), within).is_ok() {
return Ok(());
}
// If it did not work even with all children included, give up.
if left == max_left && right == max_right {
return Err(Unnumberable);
}
// Widen the window to both sides, clamped to the available children.
left = (left + 1).next_power_of_two().min(max_left);
right = (right + 1).next_power_of_two().min(max_right);
}
}
/// Updates the cached metadata after a child changed from `prev_len` /
/// `prev_descendants` to `new_len` / `new_descendants`. The `erroneous`
/// flag is recomputed from all children.
fn update_parent(
&mut self,
prev_len: usize,
new_len: usize,
prev_descendants: usize,
new_descendants: usize,
) {
// Add before subtracting so the unsigned intermediate cannot underflow.
self.len = self.len + new_len - prev_len;
self.descendants = self.descendants + new_descendants - prev_descendants;
self.erroneous = self.children.iter().any(SyntaxNode::erroneous);
}
}
/// Formats the node as `Kind: len`, followed by the list of children if
/// there are any.
impl Debug for InnerNode {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        write!(f, "{:?}: {}", self.kind, self.len)?;
        if self.children.is_empty() {
            return Ok(());
        }
        f.write_str(" ")?;
        f.debug_list().entries(&self.children).finish()
    }
}
/// An error node in the untyped syntax tree.
#[derive(Clone, Eq, PartialEq, Hash)]
struct ErrorNode {
/// The source text of the node.
text: EcoString,
/// The syntax error attached to the node (also holds the node's span).
error: SyntaxError,
}
impl ErrorNode {
    /// Creates a new error node from an error and the text it covers.
    fn new(error: SyntaxError, text: impl Into<EcoString>) -> Self {
        let text = text.into();
        Self { text, error }
    }

    /// The length of the node's text.
    fn len(&self) -> usize {
        self.text.len()
    }

    /// Appends a hint to the node's error.
    fn hint(&mut self, hint: impl Into<EcoString>) {
        let hint: EcoString = hint.into();
        self.error.hints.push(hint);
    }

    /// Whether the two error nodes have the same text and the same error,
    /// ignoring spans.
    fn spanless_eq(&self, other: &Self) -> bool {
        if self.text != other.text {
            return false;
        }
        self.error.spanless_eq(&other.error)
    }
}
/// Formats the node as `Error: "text" (message)`.
impl Debug for ErrorNode {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let Self { text, error } = self;
        write!(f, "Error: {text:?} ({})", error.message)
    }
}
/// A syntactical error.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct SyntaxError {
/// The span of the node the error is attached to.
pub span: Span,
/// The error message.
pub message: EcoString,
/// Additional hints attached to the error.
pub hints: EcoVec<EcoString>,
}
impl SyntaxError {
pub fn new(message: impl Into<EcoString>) -> Self {
Self {
span: Span::detached(),
message: message.into(),
hints: eco_vec![],
}
}
fn spanless_eq(&self, other: &Self) -> bool {
self.message == other.message && self.hints == other.hints
}
}
/// A syntax node in a context.
///
/// Wraps a borrowed [`SyntaxNode`] together with a link to its parent, its
/// index among the parent's children, and its offset in the text.
#[derive(Clone)]
pub struct LinkedNode<'a> {
/// The underlying syntax node.
node: &'a SyntaxNode,
/// The parent of this node, or `None` for the root.
parent: Option<Rc<Self>>,
/// The index of this node in its parent's children.
index: usize,
/// The offset at which this node starts, in the same unit as
/// `SyntaxNode::len`.
offset: usize,
}
impl<'a> LinkedNode<'a> {
    /// Starts a new traversal at the given root node.
    pub fn new(root: &'a SyntaxNode) -> Self {
        Self { node: root, parent: None, index: 0, offset: 0 }
    }

    /// The underlying syntax node.
    pub fn get(&self) -> &'a SyntaxNode {
        self.node
    }

    /// The index of this node in its parent's children.
    pub fn index(&self) -> usize {
        self.index
    }

    /// The offset at which this node starts.
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// The range from this node's offset to its offset plus its length.
    pub fn range(&self) -> Range<usize> {
        let start = self.offset;
        let end = start + self.node.len();
        start..end
    }

    /// An iterator over this node's children as linked nodes.
    pub fn children(&self) -> LinkedChildren<'a> {
        let front = self.offset;
        let back = front + self.len();
        LinkedChildren {
            parent: Rc::new(self.clone()),
            iter: self.node.children().enumerate(),
            front,
            back,
        }
    }

    /// Finds the node with the given span in this subtree, if any.
    pub fn find(&self, span: Span) -> Option<LinkedNode<'a>> {
        if self.span() == span {
            return Some(self.clone());
        }

        // Only inner nodes have descendants that could match.
        let Repr::Inner(inner) = &self.0 else { return None };

        // Bail out early if the target number lies below this node's own
        // number.
        if span.number() < inner.span.number() {
            return None;
        }

        let mut children = self.children().peekable();
        while let Some(child) = children.next() {
            // Only descend into a child if it is the last one or if the next
            // sibling's number lies above the target number.
            let worth_searching = match children.peek() {
                Some(next) => next.span().number() > span.number(),
                None => true,
            };
            if worth_searching {
                if let Some(found) = child.find(span) {
                    return Some(found);
                }
            }
        }

        None
    }
}
/// Access to parents and siblings.
impl<'a> LinkedNode<'a> {
    /// This node's parent, or `None` for the root.
    pub fn parent(&self) -> Option<&Self> {
        self.parent.as_deref()
    }

    /// The closest previous sibling that is not trivia.
    pub fn prev_sibling(&self) -> Option<Self> {
        let parent = self.parent()?;
        let mut index = self.index;
        let mut offset = self.offset;
        loop {
            // Step one sibling to the left; stop at the first child.
            index = index.checked_sub(1)?;
            let node = parent.node.children().nth(index)?;
            offset -= node.len();
            if !node.kind().is_trivia() {
                return Some(Self { node, parent: self.parent.clone(), index, offset });
            }
        }
    }

    /// The closest following sibling that is not trivia.
    pub fn next_sibling(&self) -> Option<Self> {
        let parent = self.parent()?;
        let mut index = self.index;
        let mut offset = self.offset;
        let mut current = self.node;
        loop {
            // Step one sibling to the right; stop after the last child.
            index = index.checked_add(1)?;
            let node = parent.node.children().nth(index)?;
            offset += current.len();
            if !node.kind().is_trivia() {
                return Some(Self { node, parent: self.parent.clone(), index, offset });
            }
            current = node;
        }
    }

    /// The kind of this node's parent.
    pub fn parent_kind(&self) -> Option<SyntaxKind> {
        self.parent().map(|parent| parent.node.kind())
    }

    /// The kind of this node's closest previous non-trivia sibling.
    pub fn prev_sibling_kind(&self) -> Option<SyntaxKind> {
        self.prev_sibling().map(|sibling| sibling.node.kind())
    }

    /// The kind of this node's closest following non-trivia sibling.
    pub fn next_sibling_kind(&self) -> Option<SyntaxKind> {
        self.next_sibling().map(|sibling| sibling.node.kind())
    }
}
/// Selects which side of a cursor position a lookup should consider.
///
/// Used by `LinkedNode::leaf_at` to choose between the leaf before and the
/// leaf after the cursor. As a fieldless enum it is `Copy` and comparable, so
/// callers can reuse and compare values without cloning.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum Side {
    /// Look at the leaf before the cursor.
    Before,
    /// Look at the leaf after the cursor.
    After,
}
/// Access to leaves.
impl<'a> LinkedNode<'a> {
    /// The rightmost non-trivia leaf before this node.
    pub fn prev_leaf(&self) -> Option<Self> {
        let mut candidate = self.prev_sibling();
        while let Some(sibling) = candidate {
            if let Some(leaf) = sibling.rightmost_leaf() {
                return Some(leaf);
            }
            candidate = sibling.prev_sibling();
        }
        // No earlier sibling had a suitable leaf: continue from the parent.
        self.parent()?.prev_leaf()
    }

    /// The leftmost non-trivia, non-error leaf in this subtree.
    pub fn leftmost_leaf(&self) -> Option<Self> {
        let kind = self.kind();
        if self.is_leaf() && !(kind.is_trivia() || kind.is_error()) {
            return Some(self.clone());
        }
        self.children().find_map(|child| child.leftmost_leaf())
    }

    /// The leaf immediately before the given cursor position.
    fn leaf_before(&self, cursor: usize) -> Option<Self> {
        let count = self.node.children().len();
        if count == 0 && cursor <= self.offset + self.len() {
            return Some(self.clone());
        }

        let mut offset = self.offset;
        for (i, child) in self.children().enumerate() {
            let end = offset + child.len();
            let inside = offset < cursor && cursor <= end;
            // The last child is also searched when the cursor sits exactly
            // at its start.
            let at_start_of_last = offset == cursor && i + 1 == count;
            if inside || at_start_of_last {
                return child.leaf_before(cursor);
            }
            offset = end;
        }

        None
    }

    /// The leaf immediately after the given cursor position.
    fn leaf_after(&self, cursor: usize) -> Option<Self> {
        let childless = self.node.children().len() == 0;
        if childless && cursor < self.offset + self.len() {
            return Some(self.clone());
        }

        let mut offset = self.offset;
        for child in self.children() {
            let end = offset + child.len();
            if (offset..end).contains(&cursor) {
                return child.leaf_after(cursor);
            }
            offset = end;
        }

        None
    }

    /// The leaf at the given cursor position, on the given side of it.
    pub fn leaf_at(&self, cursor: usize, side: Side) -> Option<Self> {
        match side {
            Side::After => self.leaf_after(cursor),
            Side::Before => self.leaf_before(cursor),
        }
    }

    /// The rightmost non-trivia leaf in this subtree.
    pub fn rightmost_leaf(&self) -> Option<Self> {
        if self.is_leaf() && !self.kind().is_trivia() {
            return Some(self.clone());
        }
        self.children().rev().find_map(|child| child.rightmost_leaf())
    }

    /// The leftmost non-trivia leaf after this node.
    pub fn next_leaf(&self) -> Option<Self> {
        let mut candidate = self.next_sibling();
        while let Some(sibling) = candidate {
            if let Some(leaf) = sibling.leftmost_leaf() {
                return Some(leaf);
            }
            candidate = sibling.next_sibling();
        }
        // No later sibling had a suitable leaf: continue from the parent.
        self.parent()?.next_leaf()
    }
}
/// A linked node dereferences to the syntax node it wraps.
impl Deref for LinkedNode<'_> {
    type Target = SyntaxNode;

    fn deref(&self) -> &Self::Target {
        self.node
    }
}
/// Debug output is forwarded to the wrapped syntax node.
impl Debug for LinkedNode<'_> {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        Debug::fmt(self.node, f)
    }
}
/// An iterator over the children of a linked node.
pub struct LinkedChildren<'a> {
/// The shared parent that every yielded child links back to.
parent: Rc<LinkedNode<'a>>,
/// The underlying children, paired with their indices.
iter: std::iter::Enumerate<std::slice::Iter<'a, SyntaxNode>>,
/// The offset of the next child yielded from the front.
front: usize,
/// The end offset of the next child yielded from the back.
back: usize,
}
impl<'a> Iterator for LinkedChildren<'a> {
type Item = LinkedNode<'a>;
fn next(&mut self) -> Option<Self::Item> {
self.iter.next().map(|(index, node)| {
let offset = self.front;
self.front += node.len();
LinkedNode {
node,
parent: Some(self.parent.clone()),
index,
offset,
}
})
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.iter.size_hint()
}
}
impl DoubleEndedIterator for LinkedChildren<'_> {
fn next_back(&mut self) -> Option<Self::Item> {
self.iter.next_back().map(|(index, node)| {
self.back -= node.len();
LinkedNode {
node,
parent: Some(self.parent.clone()),
index,
offset: self.back,
}
})
}
}
// The length is exact because `size_hint` forwards to the enumerated slice iterator.
impl ExactSizeIterator for LinkedChildren<'_> {}
/// Result of numbering a node within an interval.
pub(super) type NumberingResult = Result<(), Unnumberable>;

/// Indicates that a node cannot be numbered within a given interval.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(super) struct Unnumberable;

impl Display for Unnumberable {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        // Formatting a `str` pads it according to the formatter's options.
        Display::fmt("cannot number within this interval", f)
    }
}

impl std::error::Error for Unnumberable {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Source;

    /// Looks up a leaf on both sides of a cursor and steps to its previous
    /// sibling.
    #[test]
    fn test_linked_node() {
        let source = Source::detached("#set text(12pt, red)");
        let root = LinkedNode::new(source.root());

        // Find "text" with Before.
        let before = root.leaf_at(7, Side::Before).unwrap();
        assert_eq!(before.offset(), 5);
        assert_eq!(before.text(), "text");

        // Find "text" with After.
        let after = root.leaf_at(7, Side::After).unwrap();
        assert_eq!(after.offset(), 5);
        assert_eq!(after.text(), "text");

        // Go back to "set"; the space in between is skipped.
        let set = after.prev_sibling().unwrap();
        assert_eq!(set.offset(), 1);
        assert_eq!(set.text(), "set");
    }

    /// Checks `prev_leaf` and `next_leaf` around trivia.
    #[test]
    fn test_linked_node_non_trivia_leaf() {
        let source = Source::detached("#set fun(12pt, red)");
        let root = LinkedNode::new(source.root());
        let fun = root.leaf_at(6, Side::Before).unwrap();
        let set = fun.prev_leaf().unwrap();
        assert_eq!(fun.text(), "fun");
        assert_eq!(set.text(), "set");

        // Position 9 with Before: the space between "=" and "10".
        let source = Source::detached("#let x = 10");
        let root = LinkedNode::new(source.root());
        let space = root.leaf_at(9, Side::Before).unwrap();
        let eq = space.prev_leaf().unwrap();
        let ten = space.next_leaf().unwrap();
        assert_eq!(eq.text(), "=");
        assert_eq!(space.text(), " ");
        assert_eq!(ten.text(), "10");

        // Position 9 with After: the "10" itself, which has no next leaf.
        let source = Source::detached("#let x = 10");
        let root = LinkedNode::new(source.root());
        let ten = root.leaf_at(9, Side::After).unwrap();
        let eq = ten.prev_leaf().unwrap();
        assert!(ten.next_leaf().is_none());
        assert_eq!(eq.text(), "=");
        assert_eq!(ten.text(), "10");
    }
}