mirror of
https://gitlab.gnome.org/jwestman/blueprint-compiler.git
synced 2025-05-04 15:59:08 -04:00
- Added tests for more error messages - Test the "go to reference" feature at every character index of every test case - Delete unused code and imports - Fix some bugs I found along the way
336 lines
9.9 KiB
Python
336 lines
9.9 KiB
Python
# ast_utils.py
|
|
#
|
|
# Copyright 2021 James Westman <james@jwestman.net>
|
|
#
|
|
# This file is free software; you can redistribute it and/or modify it
|
|
# under the terms of the GNU Lesser General Public License as
|
|
# published by the Free Software Foundation; either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
#
|
|
# This file is distributed in the hope that it will be useful, but
|
|
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
# Lesser General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Lesser General Public
|
|
# License along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
#
|
|
# SPDX-License-Identifier: LGPL-3.0-or-later
|
|
|
|
import typing as T
|
|
from collections import ChainMap, defaultdict
|
|
from functools import cached_property
|
|
|
|
from .errors import *
|
|
from .lsp_utils import DocumentSymbol, LocationLink, SemanticToken
|
|
from .tokenizer import Range
|
|
|
|
# Generic placeholder for a node class; used to type the by-class lookups
# Children.__getitem__ and AstNode.parent_by_type.
TType = T.TypeVar("TType")
|
|
|
|
|
|
class Children:
    """Allows accessing children by type using array syntax.

    ``children[SomeClass]`` returns a list of every child that is an instance
    of ``SomeClass``. ``children[i]`` returns the child at position ``i``, or
    ``None`` when ``i`` is out of range in either direction.
    """

    def __init__(self, children):
        # The underlying sequence of child nodes; stored as given, not copied.
        self._children = children

    def __iter__(self) -> T.Iterator["AstNode"]:
        """Iterate over the children in their stored order."""
        return iter(self._children)

    @T.overload
    def __getitem__(self, key: T.Type[TType]) -> T.List[TType]: ...

    @T.overload
    def __getitem__(self, key: int) -> "AstNode": ...

    def __getitem__(self, key):
        """Look up children by position (``int``) or by class.

        Returns:
            For an ``int`` key, the child at that index, or ``None`` if the
            index is out of range. For any other key, a list of the children
            for which ``isinstance(child, key)`` is true (possibly empty).
        """
        if isinstance(key, int):
            # Treat every out-of-range index the same way. Previously only
            # indices past the end returned None, while a negative index
            # before the start raised IndexError.
            try:
                return self._children[key]
            except IndexError:
                return None

        return [child for child in self._children if isinstance(child, key)]
|
|
|
|
|
|
class Ranges:
    """Looks up token ranges by name, or the span between two named ranges."""

    def __init__(self, ranges: T.Dict[str, Range]):
        # Mapping of range name -> Range; stored as given, not copied.
        self._ranges = ranges

    def __getitem__(self, key: T.Union[str, tuple[str, str]]) -> T.Optional[Range]:
        """Return the range for ``key``.

        A string key returns the range stored under that name, or ``None`` if
        there is none. A ``(start, end)`` tuple returns ``Range.join`` applied
        to the two named ranges (either of which may be ``None``). Any other
        key yields ``None``.
        """
        if isinstance(key, str):
            return self._ranges.get(key)

        if isinstance(key, tuple):
            first_name, last_name = key
            lookup = self._ranges.get
            return Range.join(lookup(first_name), lookup(last_name))

        # Keys of any other type (for example None) have no associated range.
        return None
|
|
|
|
|
|
# Type of a context value: links the class used as a key in Ctx.__getitem__,
# Context and context() to the type of the value that is returned.
TCtx = T.TypeVar("TCtx")
|
|
# Type of the class-level attribute objects that AstNode._attrs_by_type
# searches for.
TAttr = T.TypeVar("TAttr")
|
|
|
|
|
|
class Ctx:
    """Allows accessing values from higher in the syntax tree."""

    def __init__(self, node: "AstNode") -> None:
        # The node at which lookups start.
        self.node = node

    def __getitem__(self, key: T.Type[TCtx]) -> T.Optional[TCtx]:
        """Return the nearest context value registered for ``key``.

        The node's own ``Context`` attributes are checked first; if none has
        ``key`` as its type, the lookup is delegated to the parent node's
        ``context``. Returns ``None`` once a node without a parent is reached
        without finding a match.
        """
        node = self.node
        for attr_name, descriptor in node._attrs_by_type(Context):
            if descriptor.type == key:
                return getattr(node, attr_name)

        parent = node.parent
        return None if parent is None else parent.context[key]
|
|
|
|
|
|
class AstNode:
    """Base class for nodes in the abstract syntax tree."""

    completers: T.List = []
    # Subclasses get their own list from __init_subclass__. The base class
    # needs a default too, otherwise collecting errors on a plain AstNode
    # raises AttributeError.
    validators: T.List = []
    # Per-class cache used by _attrs_by_type: attribute type -> [(name, attr)].
    attrs_by_type: T.Dict[T.Type, T.List] = {}

    def __init__(self, group, children, tokens, incomplete=False):
        """Create a node and make it the parent of ``children``.

        Args:
            group: The parse group this node was built from. This class reads
                its ``ranges``, ``tokens``, ``start``, ``end`` and ``text``.
            children: The child nodes, wrapped in a ``Children`` container.
            tokens: Mapping of token name to value for this node.
            incomplete: Whether the node was only partially parsed.
        """
        self.group = group
        self.children = Children(children)
        # Subscripting with a token name that is not present gives None
        # instead of raising KeyError.
        self.tokens = ChainMap(tokens, defaultdict(lambda: None))
        self.incomplete = incomplete

        self.parent = None
        for child in self.children:
            child.parent = self

    def __init_subclass__(cls):
        """Give every subclass its own completers, validators and attr cache.

        Validators are the class attributes that carry a ``_validator``
        marker, which the ``validate`` decorator sets on its wrapper.
        """
        cls.completers = []
        cls.validators = [
            getattr(cls, f) for f in dir(cls) if hasattr(getattr(cls, f), "_validator")
        ]
        cls.attrs_by_type = {}

    @cached_property
    def context(self):
        """A ``Ctx`` for looking up context values from this node upwards."""
        return Ctx(self)

    @cached_property
    def ranges(self):
        """A ``Ranges`` lookup over the named ranges of this node's group."""
        return Ranges(self.group.ranges)

    @cached_property
    def root(self):
        """The topmost ancestor of this node (the node itself if it has no parent)."""
        if self.parent is None:
            return self
        return self.parent.root

    @property
    def range(self) -> Range:
        """The range covered by this node's group."""
        return Range(self.group.start, self.group.end, self.group.text)

    def parent_by_type(self, type: T.Type[TType]) -> TType:
        """Return the nearest ancestor that is an instance of ``type``.

        Raises:
            CompilerBugError: If no ancestor of that type exists.
        """
        if self.parent is None:
            raise CompilerBugError()
        if isinstance(self.parent, type):
            return self.parent
        return self.parent.parent_by_type(type)

    @cached_property
    def errors(self):
        """Problems reported for this subtree that are not ``CompileWarning``s."""
        return [
            error
            for error in self._get_errors()
            if not isinstance(error, CompileWarning)
        ]

    @cached_property
    def warnings(self):
        """Problems reported for this subtree that are ``CompileWarning``s."""
        # NOTE: this runs the validators again, independently of ``errors``.
        return [
            warning
            for warning in self._get_errors()
            if isinstance(warning, CompileWarning)
        ]

    def _get_errors(self):
        """Run this node's validators, then the children's, yielding each problem.

        A fatal error stops validation of this node: no further validators
        run and its children are not visited.
        """
        for validator in self.validators:
            try:
                validator(self)
            except CompileError as e:
                yield e
                if e.fatal:
                    return
            except MultipleErrors as e:
                for error in e.errors:
                    yield error
                    if error.fatal:
                        return

        for child in self.children:
            yield from child._get_errors()

    def _attrs_by_type(self, attr_type: T.Type[TAttr]) -> T.List[T.Tuple[str, TAttr]]:
        """Return ``(name, attribute)`` pairs for class attributes of ``attr_type``.

        The attributes are read from the class, not the instance, and the
        result is cached in ``attrs_by_type`` after the first lookup.
        """
        if attr_type not in self.attrs_by_type:
            owner = type(self)
            matches = []
            for name in dir(owner):
                item = getattr(owner, name)
                if isinstance(item, attr_type):
                    matches.append((name, item))
            self.attrs_by_type[attr_type] = matches
        return self.attrs_by_type[attr_type]

    def get_docs(self, idx: int) -> T.Optional[str]:
        """Return documentation for the position ``idx``, if any.

        Lookup order: a ``Docs`` attribute of this node whose token contains
        ``idx``; then the children whose range contains ``idx``; finally the
        first ``Docs`` attribute of this node that has no token name.
        """
        for name, attr in self._attrs_by_type(Docs):
            if attr.token_name:
                token = self.group.tokens.get(attr.token_name)
                if token and token.start <= idx < token.end:
                    return getattr(self, name)

        for child in self.children:
            if idx in child.range:
                if docs := child.get_docs(idx):
                    return docs

        for name, attr in self._attrs_by_type(Docs):
            if not attr.token_name:
                return getattr(self, name)

        return None

    def get_semantic_tokens(self) -> T.Iterator[SemanticToken]:
        """Yield the semantic tokens of all children, in order."""
        for child in self.children:
            yield from child.get_semantic_tokens()

    def get_reference(self, idx: int) -> T.Optional[LocationLink]:
        """Return the first reference a child reports for position ``idx``.

        Only children whose range contains ``idx`` are asked. Returns ``None``
        if none of them reports a reference.
        """
        for child in self.children:
            if idx in child.range:
                if ref := child.get_reference(idx):
                    return ref
        return None

    @property
    def document_symbol(self) -> T.Optional[DocumentSymbol]:
        """The symbol for this node; the base class has none."""
        return None

    def get_document_symbols(self) -> T.List[DocumentSymbol]:
        """Collect the document symbols below this node.

        A child with a symbol contributes that symbol, with the symbols found
        below the child attached as its ``children``. A child without a symbol
        contributes the symbols found below it directly.
        """
        result = []
        for child in self.children:
            if s := child.document_symbol:
                s.children = child.get_document_symbols()
                result.append(s)
            else:
                result.extend(child.get_document_symbols())
        return result

    def validate_unique_in_parent(
        self, error: str, check: T.Optional[T.Callable[["AstNode"], bool]] = None
    ):
        """Raise if an earlier sibling of the same class conflicts with this node.

        Args:
            error: Message for the raised ``CompileError``.
            check: Optional predicate called with each earlier sibling of the
                same class; only siblings for which it returns true count as
                conflicts. When omitted, any such sibling is a conflict.

        Raises:
            CompileError: On the first conflicting sibling, with a reference
                to that sibling's range.
        """
        # A node without a parent has no siblings, so it is trivially unique.
        if self.parent is None:
            return

        for child in self.parent.children:
            # Only siblings that come before this node are considered.
            if child is self:
                break

            if type(child) is type(self):
                if check is None or check(child):
                    raise CompileError(
                        error,
                        references=[
                            ErrorReference(
                                child.range,
                                "previous declaration was here",
                            )
                        ],
                    )
|
|
|
|
|
|
def validate(
    token_name: T.Optional[str] = None,
    end_token_name: T.Optional[str] = None,
    skip_incomplete: bool = False,
):
    """Decorator for functions that validate an AST node. Exceptions raised
    during validation are marked with range information from the tokens.

    Args:
        token_name: Name of the range where a reported error starts.
        end_token_name: Name of the range where a reported error ends.
        skip_incomplete: If true, the validator is not run at all on nodes
            whose ``incomplete`` flag is set.

    The wrapper carries a ``_validator`` attribute set to ``True`` so that
    ``AstNode.__init_subclass__`` can find it. Errors that already have a
    range are left alone. Others get the join of the two named ranges, or
    the whole node's range if that join is empty.
    """
    # Imported here because the module only imports cached_property from
    # functools at the top level.
    from functools import wraps

    def decorator(func):
        # Keep the validator's own name and docstring on the wrapper.
        @wraps(func)
        def inner(self: AstNode):
            if skip_incomplete and self.incomplete:
                return

            def fill_error(e: CompileError):
                if e.range is None:
                    e.range = (
                        Range.join(
                            self.ranges[token_name],
                            self.ranges[end_token_name],
                        )
                        or self.range
                    )

            try:
                func(self)
            except CompileError as e:
                # If the node is only partially complete, then an error must
                # have already been reported at the parsing stage
                if self.incomplete:
                    return

                fill_error(e)

                # Re-raise the exception
                raise
            except MultipleErrors as e:
                if self.incomplete:
                    return

                for error in e.errors:
                    fill_error(error)

                raise

        inner._validator = True
        return inner

    return decorator
|
|
|
|
|
|
class Docs:
    """Descriptor that wraps a function producing documentation for a node.

    ``token_name`` optionally names the token the documentation belongs to.
    """

    def __init__(self, func, token_name=None):
        self.token_name = token_name
        self.func = func

    def __get__(self, instance, owner):
        # Access through the class returns the descriptor itself, so code
        # that inspects class attributes can still reach ``token_name``.
        # Access through an instance calls the wrapped function on it.
        return self if instance is None else self.func(instance)
|
|
|
|
|
|
def docs(*args, **kwargs):
    """Decorator for functions that return documentation for tokens.

    The decorated function is wrapped in a ``Docs`` descriptor. Any arguments
    given here are passed to ``Docs`` after the function itself.
    """

    def wrap(func):
        return Docs(func, *args, **kwargs)

    return wrap
|
|
|
|
|
|
class Context:
    """Descriptor that computes a context value for a node and caches it.

    ``type`` is the class that identifies the context value and ``func``
    builds the value from a node.
    """

    def __init__(self, type: T.Type[TCtx], func: T.Callable[[AstNode], TCtx]) -> None:
        self.type = type
        self.func = func

    def __get__(self, instance, owner):
        """Return the context value for ``instance``, building it on first use.

        Access through the class returns the descriptor itself. The value is
        cached on the instance under ``"_context_" + type.__name__``.
        """
        if instance is None:
            return self

        cache_attr = "_context_" + self.type.__name__
        ctx = getattr(instance, cache_attr, None)
        # Compare with None instead of testing truthiness. A falsy context
        # value (such as an empty container) was previously treated as
        # missing, so it was rebuilt and replaced on every access.
        if ctx is not None:
            return ctx

        ctx = self.func(instance)
        setattr(instance, cache_attr, ctx)
        return ctx
|
|
|
|
|
|
def context(type: T.Type[TCtx]):
    """Decorator for functions that return a context object.

    The decorated function is wrapped in a ``Context`` descriptor registered
    under ``type``. ``Ctx`` looks these descriptors up on a node and then on
    its ancestors, so nodes lower in the tree can read the value.
    """

    def wrap(func: T.Callable[[AstNode], TCtx]) -> Context:
        return Context(type, func)

    return wrap
|