Merge branch 'lsp-improvements' into 'main'

Add features to the language server

See merge request jwestman/blueprint-compiler!134
This commit is contained in:
James Westman 2023-07-26 01:09:54 +00:00
commit de2e7a5a5f
32 changed files with 671 additions and 117 deletions

View file

@ -22,7 +22,8 @@ from collections import ChainMap, defaultdict
from functools import cached_property
from .errors import *
from .lsp_utils import SemanticToken
from .lsp_utils import DocumentSymbol, LocationLink, SemanticToken
from .tokenizer import Range
TType = T.TypeVar("TType")
@ -54,6 +55,18 @@ class Children:
return [child for child in self._children if isinstance(child, key)]
class Ranges:
def __init__(self, ranges: T.Dict[str, Range]):
self._ranges = ranges
def __getitem__(self, key: T.Union[str, tuple[str, str]]) -> T.Optional[Range]:
if isinstance(key, str):
return self._ranges.get(key)
elif isinstance(key, tuple):
start, end = key
return Range.join(self._ranges.get(start), self._ranges.get(end))
TCtx = T.TypeVar("TCtx")
TAttr = T.TypeVar("TAttr")
@ -102,6 +115,10 @@ class AstNode:
def context(self):
return Ctx(self)
@cached_property
def ranges(self):
return Ranges(self.group.ranges)
@cached_property
def root(self):
if self.parent is None:
@ -109,6 +126,10 @@ class AstNode:
else:
return self.parent.root
@property
def range(self):
return Range(self.group.start, self.group.end, self.group.text)
def parent_by_type(self, type: T.Type[TType]) -> TType:
if self.parent is None:
raise CompilerBugError()
@ -164,9 +185,8 @@ class AstNode:
return getattr(self, name)
for child in self.children:
if child.group.start <= idx < child.group.end:
docs = child.get_docs(idx)
if docs is not None:
if idx in child.range:
if docs := child.get_docs(idx):
return docs
return None
@ -175,6 +195,27 @@ class AstNode:
for child in self.children:
yield from child.get_semantic_tokens()
def get_reference(self, idx: int) -> T.Optional[LocationLink]:
for child in self.children:
if idx in child.range:
if ref := child.get_reference(idx):
return ref
return None
@property
def document_symbol(self) -> T.Optional[DocumentSymbol]:
return None
def get_document_symbols(self) -> T.List[DocumentSymbol]:
result = []
for child in self.children:
if s := child.document_symbol:
s.children = child.get_document_symbols()
result.append(s)
else:
result.extend(child.get_document_symbols())
return result
def validate_unique_in_parent(
self, error: str, check: T.Optional[T.Callable[["AstNode"], bool]] = None
):
@ -188,20 +229,23 @@ class AstNode:
error,
references=[
ErrorReference(
child.group.start,
child.group.end,
child.range,
"previous declaration was here",
)
],
)
def validate(token_name=None, end_token_name=None, skip_incomplete=False):
def validate(
token_name: T.Optional[str] = None,
end_token_name: T.Optional[str] = None,
skip_incomplete: bool = False,
):
"""Decorator for functions that validate an AST node. Exceptions raised
during validation are marked with range information from the tokens."""
def decorator(func):
def inner(self):
def inner(self: AstNode):
if skip_incomplete and self.incomplete:
return
@ -213,22 +257,14 @@ def validate(token_name=None, end_token_name=None, skip_incomplete=False):
if self.incomplete:
return
# This mess of code sets the error's start and end positions
# from the tokens passed to the decorator, if they have not
# already been set
if e.start is None:
if token := self.group.tokens.get(token_name):
e.start = token.start
else:
e.start = self.group.start
if e.end is None:
if token := self.group.tokens.get(end_token_name):
e.end = token.end
elif token := self.group.tokens.get(token_name):
e.end = token.end
else:
e.end = self.group.end
if e.range is None:
e.range = (
Range.join(
self.ranges[token_name],
self.ranges[end_token_name],
)
or self.range
)
# Re-raise the exception
raise e

View file

@ -158,7 +158,7 @@ def signal_completer(ast_node, match_variables):
yield Completion(
signal,
CompletionItemKind.Property,
snippet=f"{signal} => ${{1:{name}_{signal.replace('-', '_')}}}()$0;",
snippet=f"{signal} => \$${{1:{name}_{signal.replace('-', '_')}}}()$0;",
)

View file

@ -23,6 +23,7 @@ import typing as T
from dataclasses import dataclass
from . import utils
from .tokenizer import Range
from .utils import Colors
@ -36,8 +37,7 @@ class PrintableError(Exception):
@dataclass
class ErrorReference:
start: int
end: int
range: Range
message: str
@ -50,8 +50,7 @@ class CompileError(PrintableError):
def __init__(
self,
message: str,
start: T.Optional[int] = None,
end: T.Optional[int] = None,
range: T.Optional[Range] = None,
did_you_mean: T.Optional[T.Tuple[str, T.List[str]]] = None,
hints: T.Optional[T.List[str]] = None,
actions: T.Optional[T.List["CodeAction"]] = None,
@ -61,8 +60,7 @@ class CompileError(PrintableError):
super().__init__(message)
self.message = message
self.start = start
self.end = end
self.range = range
self.hints = hints or []
self.actions = actions or []
self.references = references or []
@ -92,9 +90,9 @@ class CompileError(PrintableError):
self.hint("Are your dependencies up to date?")
def pretty_print(self, filename: str, code: str, stream=sys.stdout) -> None:
assert self.start is not None
assert self.range is not None
line_num, col_num = utils.idx_to_pos(self.start + 1, code)
line_num, col_num = utils.idx_to_pos(self.range.start + 1, code)
line = code.splitlines(True)[line_num]
# Display 1-based line numbers
@ -110,7 +108,7 @@ at {filename} line {line_num} column {col_num}:
stream.write(f"{Colors.FAINT}hint: {hint}{Colors.CLEAR}\n")
for ref in self.references:
line_num, col_num = utils.idx_to_pos(ref.start + 1, code)
line_num, col_num = utils.idx_to_pos(ref.range.start + 1, code)
line = code.splitlines(True)[line_num]
line_num += 1
@ -138,14 +136,15 @@ class UpgradeWarning(CompileWarning):
class UnexpectedTokenError(CompileError):
def __init__(self, start, end) -> None:
super().__init__("Unexpected tokens", start, end)
def __init__(self, range: Range) -> None:
super().__init__("Unexpected tokens", range)
@dataclass
class CodeAction:
title: str
replace_with: str
edit_range: T.Optional[Range] = None
class MultipleErrors(PrintableError):

View file

@ -18,7 +18,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import os
import sys
import typing as T
from functools import cached_property
@ -29,6 +28,7 @@ from gi.repository import GIRepository # type: ignore
from . import typelib, xml_reader
from .errors import CompileError, CompilerBugError
from .lsp_utils import CodeAction
_namespace_cache: T.Dict[str, "Namespace"] = {}
_xml_cache = {}
@ -65,6 +65,27 @@ def get_namespace(namespace: str, version: str) -> "Namespace":
return _namespace_cache[filename]
_available_namespaces: list[tuple[str, str]] = []
def get_available_namespaces() -> T.List[T.Tuple[str, str]]:
if len(_available_namespaces):
return _available_namespaces
search_paths: list[str] = [
*GIRepository.Repository.get_search_path(),
*_user_search_paths,
]
for search_path in search_paths:
for filename in os.listdir(search_path):
if filename.endswith(".typelib"):
namespace, version = filename.removesuffix(".typelib").rsplit("-", 1)
_available_namespaces.append((namespace, version))
return _available_namespaces
def get_xml(namespace: str, version: str):
search_paths = []
@ -1011,9 +1032,11 @@ class GirContext:
ns = ns or "Gtk"
if ns not in self.namespaces and ns not in self.not_found_namespaces:
all_available = list(set(ns for ns, _version in get_available_namespaces()))
raise CompileError(
f"Namespace {ns} was not imported",
did_you_mean=(ns, self.namespaces.keys()),
did_you_mean=(ns, all_available),
)
def validate_type(self, name: str, ns: str) -> None:

View file

@ -35,6 +35,16 @@ class AdwBreakpointCondition(AstNode):
def condition(self) -> str:
return self.tokens["condition"]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"condition",
SymbolKind.Property,
self.range,
self.group.tokens["kw"].range,
self.condition,
)
@docs("kw")
def keyword_docs(self):
klass = self.root.gir.get_type("Breakpoint", "Adw")
@ -93,6 +103,16 @@ class AdwBreakpointSetter(AstNode):
else:
return None
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
f"{self.object_id}.{self.property_name}",
SymbolKind.Property,
self.range,
self.group.tokens["object"].range,
self.value.range.text,
)
@context(ValueTypeCtx)
def value_type(self) -> ValueTypeCtx:
if self.gir_property is not None:
@ -147,12 +167,25 @@ class AdwBreakpointSetter(AstNode):
class AdwBreakpointSetters(AstNode):
grammar = ["setters", Match("{").expected(), Until(AdwBreakpointSetter, "}")]
grammar = [
Keyword("setters"),
Match("{").expected(),
Until(AdwBreakpointSetter, "}"),
]
@property
def setters(self) -> T.List[AdwBreakpointSetter]:
return self.children[AdwBreakpointSetter]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"setters",
SymbolKind.Struct,
self.range,
self.group.tokens["setters"].range,
)
@validate()
def container_is_breakpoint(self):
validate_parent_type(self, "Adw", "Breakpoint", "breakpoint setters")

View file

@ -84,6 +84,16 @@ class ExtAdwMessageDialogResponse(AstNode):
def value(self) -> StringValue:
return self.children[0]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.id,
SymbolKind.Field,
self.range,
self.group.tokens["id"].range,
self.value.range.text,
)
@context(ValueTypeCtx)
def value_type(self) -> ValueTypeCtx:
return ValueTypeCtx(StringType())
@ -108,6 +118,15 @@ class ExtAdwMessageDialog(AstNode):
def responses(self) -> T.List[ExtAdwMessageDialogResponse]:
return self.children
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"responses",
SymbolKind.Array,
self.range,
self.group.tokens["responses"].range,
)
@validate("responses")
def container_is_message_dialog(self):
validate_parent_type(self, "Adw", "MessageDialog", "responses")

View file

@ -46,7 +46,15 @@ from ..gir import (
IntType,
StringType,
)
from ..lsp_utils import Completion, CompletionItemKind, SemanticToken, SemanticTokenType
from ..lsp_utils import (
Completion,
CompletionItemKind,
DocumentSymbol,
LocationLink,
SemanticToken,
SemanticTokenType,
SymbolKind,
)
from ..parse_tree import *
OBJECT_CONTENT_HOOKS = AnyOf()

View file

@ -70,8 +70,7 @@ class ScopeCtx:
):
raise CompileError(
f"Duplicate object ID '{obj.tokens['id']}'",
token.start,
token.end,
token.range,
)
passed[obj.tokens["id"]] = obj

View file

@ -21,6 +21,9 @@
import typing as T
from functools import cached_property
from blueprintcompiler.errors import T
from blueprintcompiler.lsp_utils import DocumentSymbol
from .common import *
from .response_id import ExtResponse
from .types import ClassName, ConcreteClassName
@ -59,8 +62,20 @@ class Object(AstNode):
def signature(self) -> str:
if self.id:
return f"{self.class_name.gir_type.full_name} {self.id}"
elif t := self.class_name.gir_type:
return f"{t.full_name}"
else:
return f"{self.class_name.gir_type.full_name}"
return f"{self.class_name.as_string}"
@property
def document_symbol(self) -> T.Optional[DocumentSymbol]:
return DocumentSymbol(
self.class_name.as_string,
SymbolKind.Object,
self.range,
self.children[ClassName][0].range,
self.id,
)
@property
def gir_class(self) -> GirType:

View file

@ -47,6 +47,21 @@ class Property(AstNode):
else:
return None
@property
def document_symbol(self) -> DocumentSymbol:
if isinstance(self.value, ObjectValue):
detail = None
else:
detail = self.value.range.text
return DocumentSymbol(
self.name,
SymbolKind.Property,
self.range,
self.group.tokens["name"].range,
detail,
)
@validate()
def binding_valid(self):
if (

View file

@ -51,12 +51,14 @@ class Signal(AstNode):
]
),
"=>",
Mark("detail_start"),
Optional(["$", UseLiteral("extern", True)]),
UseIdent("handler").expected("the name of a function to handle the signal"),
Match("(").expected("argument list"),
Optional(UseIdent("object")).expected("object identifier"),
Match(")").expected(),
ZeroOrMore(SignalFlag),
Mark("detail_end"),
)
@property
@ -105,6 +107,16 @@ class Signal(AstNode):
def gir_class(self):
return self.parent.parent.gir_class
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.full_name,
SymbolKind.Event,
self.range,
self.group.tokens["name"].range,
self.ranges["detail_start", "detail_end"].text,
)
@validate("handler")
def old_extern(self):
if not self.tokens["extern"]:

View file

@ -139,6 +139,16 @@ class A11yProperty(BaseAttribute):
def value_type(self) -> ValueTypeCtx:
return ValueTypeCtx(get_types(self.root.gir).get(self.tokens["name"]))
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.name,
SymbolKind.Field,
self.range,
self.group.tokens["name"].range,
self.value.range.text,
)
@validate("name")
def is_valid_property(self):
types = get_types(self.root.gir)
@ -172,6 +182,15 @@ class ExtAccessibility(AstNode):
def properties(self) -> T.List[A11yProperty]:
return self.children[A11yProperty]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"accessibility",
SymbolKind.Struct,
self.range,
self.group.tokens["accessibility"].range,
)
@validate("accessibility")
def container_is_widget(self):
validate_parent_type(self, "Gtk", "Widget", "accessibility properties")

View file

@ -31,13 +31,23 @@ class Item(AstNode):
]
@property
def name(self) -> str:
def name(self) -> T.Optional[str]:
return self.tokens["name"]
@property
def value(self) -> StringValue:
return self.children[StringValue][0]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.value.range.text,
SymbolKind.String,
self.range,
self.value.range,
self.name,
)
@validate("name")
def unique_in_parent(self):
if self.name is not None:
@ -54,6 +64,15 @@ class ExtComboBoxItems(AstNode):
"]",
]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"items",
SymbolKind.Array,
self.range,
self.group.tokens["items"].range,
)
@validate("items")
def container_is_combo_box_text(self):
validate_parent_type(self, "Gtk", "ComboBoxText", "combo box items")

View file

@ -23,6 +23,15 @@ from .gobject_object import ObjectContent, validate_parent_type
class Filters(AstNode):
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.tokens["tag_name"],
SymbolKind.Array,
self.range,
self.group.tokens[self.tokens["tag_name"]].range,
)
@validate()
def container_is_file_filter(self):
validate_parent_type(self, "Gtk", "FileFilter", "file filter properties")
@ -46,6 +55,15 @@ class FilterString(AstNode):
def item(self) -> str:
return self.tokens["name"]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.item,
SymbolKind.String,
self.range,
self.group.tokens["name"].range,
)
@validate()
def unique_in_parent(self):
self.validate_unique_in_parent(

View file

@ -36,6 +36,16 @@ class LayoutProperty(AstNode):
def value(self) -> Value:
return self.children[Value][0]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.name,
SymbolKind.Field,
self.range,
self.group.tokens["name"].range,
self.value.range.text,
)
@context(ValueTypeCtx)
def value_type(self) -> ValueTypeCtx:
# there isn't really a way to validate these
@ -56,6 +66,15 @@ class ExtLayout(AstNode):
Until(LayoutProperty, "}"),
)
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"layout",
SymbolKind.Struct,
self.range,
self.group.tokens["layout"].range,
)
@validate("layout")
def container_is_widget(self):
validate_parent_type(self, "Gtk", "Widget", "layout properties")

View file

@ -1,5 +1,9 @@
import typing as T
from blueprintcompiler.errors import T
from blueprintcompiler.lsp_utils import DocumentSymbol
from ..ast_utils import AstNode, validate
from ..parse_tree import Keyword
from .common import *
from .contexts import ScopeCtx
from .gobject_object import ObjectContent, validate_parent_type
@ -17,6 +21,15 @@ class ExtListItemFactory(AstNode):
def signature(self) -> str:
return f"template {self.gir_class.full_name}"
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.signature,
SymbolKind.Object,
self.range,
self.group.tokens["id"].range,
)
@property
def type_name(self) -> T.Optional[TypeName]:
if len(self.children[TypeName]) == 1:

View file

@ -42,6 +42,16 @@ class Menu(AstNode):
else:
return "Gio.Menu"
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.tokens["tag"],
SymbolKind.Object,
self.range,
self.group.tokens[self.tokens["tag"]].range,
self.id,
)
@property
def tag(self) -> str:
return self.tokens["tag"]
@ -72,6 +82,18 @@ class MenuAttribute(AstNode):
def value(self) -> StringValue:
return self.children[StringValue][0]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.name,
SymbolKind.Field,
self.range,
self.group.tokens["name"].range
if self.group.tokens["name"]
else self.range,
self.value.range.text,
)
@context(ValueTypeCtx)
def value_type(self) -> ValueTypeCtx:
return ValueTypeCtx(None)
@ -98,7 +120,7 @@ menu_attribute = Group(
menu_section = Group(
Menu,
[
"section",
Keyword("section"),
UseLiteral("tag", "section"),
Optional(UseIdent("id")),
Match("{").expected(),
@ -109,7 +131,7 @@ menu_section = Group(
menu_submenu = Group(
Menu,
[
"submenu",
Keyword("submenu"),
UseLiteral("tag", "submenu"),
Optional(UseIdent("id")),
Match("{").expected(),
@ -120,7 +142,7 @@ menu_submenu = Group(
menu_item = Group(
Menu,
[
"item",
Keyword("item"),
UseLiteral("tag", "item"),
Match("{").expected(),
Until(menu_attribute, "}"),
@ -130,7 +152,7 @@ menu_item = Group(
menu_item_shorthand = Group(
Menu,
[
"item",
Keyword("item"),
UseLiteral("tag", "item"),
"(",
Group(

View file

@ -58,6 +58,24 @@ class ExtScaleMark(AstNode):
else:
return None
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
str(self.value),
SymbolKind.Field,
self.range,
self.group.tokens["mark"].range,
self.label.string if self.label else None,
)
def get_semantic_tokens(self) -> T.Iterator[SemanticToken]:
if range := self.ranges["position"]:
yield SemanticToken(
range.start,
range.end,
SemanticTokenType.EnumMember,
)
@docs("position")
def position_docs(self) -> T.Optional[str]:
if member := self.root.gir.get_type("PositionType", "Gtk").members.get(
@ -88,6 +106,15 @@ class ExtScaleMarks(AstNode):
def marks(self) -> T.List[ExtScaleMark]:
return self.children
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"marks",
SymbolKind.Array,
self.range,
self.group.tokens["marks"].range,
)
@validate("marks")
def container_is_size_group(self):
validate_parent_type(self, "Gtk", "Scale", "scale marks")

View file

@ -30,6 +30,15 @@ class Widget(AstNode):
def name(self) -> str:
return self.tokens["name"]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.name,
SymbolKind.Field,
self.range,
self.group.tokens["name"].range,
)
@validate("name")
def obj_widget(self):
object = self.context[ScopeCtx].objects.get(self.tokens["name"])
@ -62,6 +71,15 @@ class ExtSizeGroupWidgets(AstNode):
"]",
]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"widgets",
SymbolKind.Array,
self.range,
self.group.tokens["widgets"].range,
)
@validate("widgets")
def container_is_size_group(self):
validate_parent_type(self, "Gtk", "SizeGroup", "size group properties")

View file

@ -30,6 +30,15 @@ class Item(AstNode):
def child(self) -> StringValue:
return self.children[StringValue][0]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.child.range.text,
SymbolKind.String,
self.range,
self.range,
)
class ExtStringListStrings(AstNode):
grammar = [
@ -39,6 +48,15 @@ class ExtStringListStrings(AstNode):
"]",
]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"strings",
SymbolKind.Array,
self.range,
self.group.tokens["strings"].range,
)
@validate("items")
def container_is_string_list(self):
validate_parent_type(self, "Gtk", "StringList", "StringList items")

View file

@ -29,6 +29,15 @@ class StyleClass(AstNode):
def name(self) -> str:
return self.tokens["name"]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.name,
SymbolKind.String,
self.range,
self.range,
)
@validate("name")
def unique_in_parent(self):
self.validate_unique_in_parent(
@ -44,6 +53,15 @@ class ExtStyles(AstNode):
"]",
]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"styles",
SymbolKind.Array,
self.range,
self.group.tokens["styles"].range,
)
@validate("styles")
def container_is_widget(self):
validate_parent_type(self, "Gtk", "Widget", "style classes")

View file

@ -46,10 +46,19 @@ class Template(Object):
@property
def signature(self) -> str:
if self.parent_type:
return f"template {self.gir_class.full_name} : {self.parent_type.gir_type.full_name}"
if self.parent_type and self.parent_type.gir_type:
return f"template {self.class_name.as_string} : {self.parent_type.gir_type.full_name}"
else:
return f"template {self.gir_class.full_name}"
return f"template {self.class_name.as_string}"
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.signature,
SymbolKind.Object,
self.range,
self.group.tokens["id"].range,
)
@property
def gir_class(self) -> GirType:

View file

@ -55,7 +55,16 @@ class TypeName(AstNode):
@validate("namespace")
def gir_ns_exists(self):
if not self.tokens["extern"]:
try:
self.root.gir.validate_ns(self.tokens["namespace"])
except CompileError as e:
ns = self.tokens["namespace"]
e.actions = [
self.root.import_code_action(n, version)
for n, version in gir.get_available_namespaces()
if n == ns
]
raise e
@validate()
def deprecated(self) -> None:

View file

@ -62,8 +62,7 @@ class UI(AstNode):
else:
gir_ctx.not_found_namespaces.add(i.namespace)
except CompileError as e:
e.start = i.group.tokens["namespace"].start
e.end = i.group.tokens["version"].end
e.range = i.range
self._gir_errors.append(e)
return gir_ctx
@ -100,6 +99,18 @@ class UI(AstNode):
and self.template.class_name.glib_type_name == id
)
def import_code_action(self, ns: str, version: str) -> CodeAction:
if len(self.children[Import]):
pos = self.children[Import][-1].range.end
else:
pos = self.children[GtkDirective][0].range.end
return CodeAction(
f"Import {ns} {version}",
f"\nusing {ns} {version};",
Range(pos, pos, self.group.text),
)
@context(ScopeCtx)
def scope_ctx(self) -> ScopeCtx:
return ScopeCtx(node=self)

View file

@ -339,6 +339,16 @@ class IdentLiteral(AstNode):
token = self.group.tokens["value"]
yield SemanticToken(token.start, token.end, SemanticTokenType.EnumMember)
def get_reference(self, _idx: int) -> T.Optional[LocationLink]:
ref = self.context[ScopeCtx].objects.get(self.ident)
if ref is None and self.root.is_legacy_template(self.ident):
ref = self.root.template
if ref:
return LocationLink(self.range, ref.range, ref.ranges["id"])
else:
return None
class Literal(AstNode):
grammar = AnyOf(

View file

@ -24,10 +24,12 @@ import traceback
import typing as T
from . import decompiler, parser, tokenizer, utils, xml_reader
from .ast_utils import AstNode
from .completions import complete
from .errors import CompileError, MultipleErrors, PrintableError
from .errors import CompileError, MultipleErrors
from .lsp_utils import *
from .outputs.xml import XmlOutput
from .tokenizer import Token
def printerr(*args, **kwargs):
@ -43,16 +45,16 @@ def command(json_method: str):
class OpenFile:
def __init__(self, uri: str, text: str, version: int):
def __init__(self, uri: str, text: str, version: int) -> None:
self.uri = uri
self.text = text
self.version = version
self.ast = None
self.tokens = None
self.ast: T.Optional[AstNode] = None
self.tokens: T.Optional[list[Token]] = None
self._update()
def apply_changes(self, changes):
def apply_changes(self, changes) -> None:
for change in changes:
if "range" not in change:
self.text = change["text"]
@ -70,8 +72,8 @@ class OpenFile:
self.text = self.text[:start] + change["text"] + self.text[end:]
self._update()
def _update(self):
self.diagnostics = []
def _update(self) -> None:
self.diagnostics: list[CompileError] = []
try:
self.tokens = tokenizer.tokenize(self.text)
self.ast, errors, warnings = parser.parse(self.tokens)
@ -203,6 +205,8 @@ class LanguageServer:
"completionProvider": {},
"codeActionProvider": {},
"hoverProvider": True,
"documentSymbolProvider": True,
"definitionProvider": True,
},
"serverInfo": {
"name": "Blueprint",
@ -325,31 +329,32 @@ class LanguageServer:
def code_actions(self, id, params):
open_file = self._open_files[params["textDocument"]["uri"]]
range_start = utils.pos_to_idx(
range = Range(
utils.pos_to_idx(
params["range"]["start"]["line"],
params["range"]["start"]["character"],
open_file.text,
)
range_end = utils.pos_to_idx(
),
utils.pos_to_idx(
params["range"]["end"]["line"],
params["range"]["end"]["character"],
open_file.text,
),
open_file.text,
)
actions = [
{
"title": action.title,
"kind": "quickfix",
"diagnostics": [
self._create_diagnostic(open_file.text, open_file.uri, diagnostic)
],
"diagnostics": [self._create_diagnostic(open_file.uri, diagnostic)],
"edit": {
"changes": {
open_file.uri: [
{
"range": utils.idxs_to_range(
diagnostic.start, diagnostic.end, open_file.text
),
"range": action.edit_range.to_json()
if action.edit_range
else diagnostic.range.to_json(),
"newText": action.replace_with,
}
]
@ -357,34 +362,68 @@ class LanguageServer:
},
}
for diagnostic in open_file.diagnostics
if not (diagnostic.end < range_start or diagnostic.start > range_end)
if range.overlaps(diagnostic.range)
for action in diagnostic.actions
]
self._send_response(id, actions)
@command("textDocument/documentSymbol")
def document_symbols(self, id, params):
open_file = self._open_files[params["textDocument"]["uri"]]
symbols = open_file.ast.get_document_symbols()
def to_json(symbol: DocumentSymbol):
result = {
"name": symbol.name,
"kind": symbol.kind,
"range": symbol.range.to_json(),
"selectionRange": symbol.selection_range.to_json(),
"children": [to_json(child) for child in symbol.children],
}
if symbol.detail is not None:
result["detail"] = symbol.detail
return result
self._send_response(id, [to_json(symbol) for symbol in symbols])
@command("textDocument/definition")
def definition(self, id, params):
open_file = self._open_files[params["textDocument"]["uri"]]
idx = utils.pos_to_idx(
params["position"]["line"], params["position"]["character"], open_file.text
)
definition = open_file.ast.get_reference(idx)
if definition is None:
self._send_response(id, None)
else:
self._send_response(
id,
definition.to_json(open_file.uri),
)
def _send_file_updates(self, open_file: OpenFile):
self._send_notification(
"textDocument/publishDiagnostics",
{
"uri": open_file.uri,
"diagnostics": [
self._create_diagnostic(open_file.text, open_file.uri, err)
self._create_diagnostic(open_file.uri, err)
for err in open_file.diagnostics
],
},
)
def _create_diagnostic(self, text: str, uri: str, err: CompileError):
def _create_diagnostic(self, uri: str, err: CompileError):
message = err.message
assert err.start is not None and err.end is not None
assert err.range is not None
for hint in err.hints:
message += "\nhint: " + hint
result = {
"range": utils.idxs_to_range(err.start, err.end, text),
"range": err.range.to_json(),
"message": message,
"severity": DiagnosticSeverity.Warning
if isinstance(err, CompileWarning)
@ -399,7 +438,7 @@ class LanguageServer:
{
"location": {
"uri": uri,
"range": utils.idxs_to_range(ref.start, ref.end, text),
"range": ref.range.to_json(),
},
"message": ref.message,
}

View file

@ -20,9 +20,10 @@
import enum
import typing as T
from dataclasses import dataclass
from dataclasses import dataclass, field
from .errors import *
from .tokenizer import Range
from .utils import *
@ -129,3 +130,57 @@ class SemanticToken:
start: int
end: int
type: SemanticTokenType
class SymbolKind(enum.IntEnum):
File = 1
Module = 2
Namespace = 3
Package = 4
Class = 5
Method = 6
Property = 7
Field = 8
Constructor = 9
Enum = 10
Interface = 11
Function = 12
Variable = 13
Constant = 14
String = 15
Number = 16
Boolean = 17
Array = 18
Object = 19
Key = 20
Null = 21
EnumMember = 22
Struct = 23
Event = 24
Operator = 25
TypeParameter = 26
@dataclass
class DocumentSymbol:
name: str
kind: SymbolKind
range: Range
selection_range: Range
detail: T.Optional[str] = None
children: T.List["DocumentSymbol"] = field(default_factory=list)
@dataclass
class LocationLink:
origin_selection_range: Range
target_range: Range
target_selection_range: Range
def to_json(self, target_uri: str):
return {
"originSelectionRange": self.origin_selection_range.to_json(),
"targetUri": target_uri,
"targetRange": self.target_range.to_json(),
"targetSelectionRange": self.target_selection_range.to_json(),
}

View file

@ -24,8 +24,8 @@ import os
import sys
import typing as T
from . import decompiler, interactive_port, parser, tokenizer
from .errors import CompilerBugError, MultipleErrors, PrintableError, report_bug
from . import interactive_port, parser, tokenizer
from .errors import CompileError, CompilerBugError, PrintableError, report_bug
from .gir import add_typelib_search_path
from .lsp import LanguageServer
from .outputs import XmlOutput
@ -157,7 +157,7 @@ class BlueprintApp:
def cmd_port(self, opts):
interactive_port.run(opts)
def _compile(self, data: str) -> T.Tuple[str, T.List[PrintableError]]:
def _compile(self, data: str) -> T.Tuple[str, T.List[CompileError]]:
tokens = tokenizer.tokenize(data)
ast, errors, warnings = parser.parse(tokens)

View file

@ -20,7 +20,6 @@
""" Utilities for parsing an AST from a token stream. """
import typing as T
from collections import defaultdict
from enum import Enum
from .ast_utils import AstNode
@ -31,7 +30,7 @@ from .errors import (
UnexpectedTokenError,
assert_true,
)
from .tokenizer import Token, TokenType
from .tokenizer import Range, Token, TokenType
SKIP_TOKENS = [TokenType.COMMENT, TokenType.WHITESPACE]
@ -63,14 +62,16 @@ class ParseGroup:
be converted to AST nodes by passing the children and key=value pairs to
the AST node constructor."""
def __init__(self, ast_type: T.Type[AstNode], start: int):
def __init__(self, ast_type: T.Type[AstNode], start: int, text: str):
self.ast_type = ast_type
self.children: T.List[ParseGroup] = []
self.keys: T.Dict[str, T.Any] = {}
self.tokens: T.Dict[str, T.Optional[Token]] = {}
self.ranges: T.Dict[str, Range] = {}
self.start = start
self.end: T.Optional[int] = None
self.incomplete = False
self.text = text
def add_child(self, child: "ParseGroup"):
self.children.append(child)
@ -80,6 +81,12 @@ class ParseGroup:
self.keys[key] = val
self.tokens[key] = token
if token:
self.set_range(key, token.range)
def set_range(self, key: str, range: Range):
assert_true(key not in self.ranges)
self.ranges[key] = range
def to_ast(self):
"""Creates an AST node from the match group."""
@ -104,8 +111,9 @@ class ParseGroup:
class ParseContext:
"""Contains the state of the parser."""
def __init__(self, tokens: T.List[Token], index=0):
def __init__(self, tokens: T.List[Token], text: str, index=0):
self.tokens = tokens
self.text = text
self.binding_power = 0
self.index = index
@ -113,6 +121,7 @@ class ParseContext:
self.group: T.Optional[ParseGroup] = None
self.group_keys: T.Dict[str, T.Tuple[T.Any, T.Optional[Token]]] = {}
self.group_children: T.List[ParseGroup] = []
self.group_ranges: T.Dict[str, Range] = {}
self.last_group: T.Optional[ParseGroup] = None
self.group_incomplete = False
@ -124,7 +133,7 @@ class ParseContext:
context will be used to parse one node. If parsing is successful, the
new context will be applied to "self". If parsing fails, the new
context will be discarded."""
ctx = ParseContext(self.tokens, self.index)
ctx = ParseContext(self.tokens, self.text, self.index)
ctx.errors = self.errors
ctx.warnings = self.warnings
ctx.binding_power = self.binding_power
@ -140,6 +149,8 @@ class ParseContext:
other.group.set_val(key, val, token)
for child in other.group_children:
other.group.add_child(child)
for key, range in other.group_ranges.items():
other.group.set_range(key, range)
other.group.end = other.tokens[other.index - 1].end
other.group.incomplete = other.group_incomplete
self.group_children.append(other.group)
@ -148,6 +159,7 @@ class ParseContext:
# its matched values
self.group_keys = {**self.group_keys, **other.group_keys}
self.group_children += other.group_children
self.group_ranges = {**self.group_ranges, **other.group_ranges}
self.group_incomplete |= other.group_incomplete
self.index = other.index
@ -161,13 +173,19 @@ class ParseContext:
def start_group(self, ast_type: T.Type[AstNode]):
"""Sets this context to have its own match group."""
assert_true(self.group is None)
self.group = ParseGroup(ast_type, self.tokens[self.index].start)
self.group = ParseGroup(ast_type, self.tokens[self.index].start, self.text)
def set_group_val(self, key: str, value: T.Any, token: T.Optional[Token]):
"""Sets a matched key=value pair on the current match group."""
assert_true(key not in self.group_keys)
self.group_keys[key] = (value, token)
def set_mark(self, key: str):
"""Sets a zero-length range on the current match group at the current position."""
self.group_ranges[key] = Range(
self.tokens[self.index].start, self.tokens[self.index].start, self.text
)
def set_group_incomplete(self):
"""Marks the current match group as incomplete (it could not be fully
parsed, but the parser recovered)."""
@ -206,11 +224,11 @@ class ParseContext:
if (
len(self.errors)
and isinstance((err := self.errors[-1]), UnexpectedTokenError)
and err.end == start
and err.range.end == start
):
err.end = end
err.range.end = end
else:
self.errors.append(UnexpectedTokenError(start, end))
self.errors.append(UnexpectedTokenError(Range(start, end, self.text)))
def is_eof(self) -> bool:
return self.index >= len(self.tokens) or self.peek_token().type == TokenType.EOF
@ -263,10 +281,11 @@ class Err(ParseNode):
start_idx = ctx.start
while ctx.tokens[start_idx].type in SKIP_TOKENS:
start_idx += 1
start_token = ctx.tokens[start_idx]
end_token = ctx.tokens[ctx.index]
raise CompileError(self.message, start_token.start, end_token.end)
raise CompileError(
self.message, Range(start_token.start, start_token.start, ctx.text)
)
return True
@ -306,7 +325,9 @@ class Fail(ParseNode):
start_token = ctx.tokens[start_idx]
end_token = ctx.tokens[ctx.index]
raise CompileError(self.message, start_token.start, end_token.end)
raise CompileError(
self.message, Range.join(start_token.range, end_token.range)
)
return True
@ -355,7 +376,7 @@ class Statement(ParseNode):
token = ctx.peek_token()
if str(token) != ";":
ctx.errors.append(CompileError("Expected `;`", token.start, token.end))
ctx.errors.append(CompileError("Expected `;`", token.range))
else:
ctx.next_token()
return True
@ -604,6 +625,15 @@ class Keyword(ParseNode):
return str(token) == self.kw
class Mark(ParseNode):
def __init__(self, key: str):
self.key = key
def _parse(self, ctx: ParseContext):
ctx.set_mark(self.key)
return True
def to_parse_node(value) -> ParseNode:
if isinstance(value, str):
return Match(value)

View file

@ -26,11 +26,12 @@ from .tokenizer import TokenType
def parse(
tokens: T.List[Token],
) -> T.Tuple[T.Optional[UI], T.Optional[MultipleErrors], T.List[PrintableError]]:
) -> T.Tuple[T.Optional[UI], T.Optional[MultipleErrors], T.List[CompileError]]:
"""Parses a list of tokens into an abstract syntax tree."""
try:
ctx = ParseContext(tokens)
original_text = tokens[0].string if len(tokens) else ""
ctx = ParseContext(tokens, original_text)
AnyOf(UI).parse(ctx)
ast_node = ctx.last_group.to_ast() if ctx.last_group else None

View file

@ -20,9 +20,10 @@
import re
import typing as T
from dataclasses import dataclass
from enum import Enum
from .errors import CompileError, CompilerBugError
from . import utils
class TokenType(Enum):
@ -62,7 +63,13 @@ class Token:
def __str__(self) -> str:
return self.string[self.start : self.end]
@property
def range(self) -> "Range":
return Range(self.start, self.end, self.string)
def get_number(self) -> T.Union[int, float]:
from .errors import CompileError, CompilerBugError
if self.type != TokenType.NUMBER:
raise CompilerBugError()
@ -75,12 +82,12 @@ class Token:
else:
return int(string)
except:
raise CompileError(
f"{str(self)} is not a valid number literal", self.start, self.end
)
raise CompileError(f"{str(self)} is not a valid number literal", self.range)
def _tokenize(ui_ml: str):
from .errors import CompileError
i = 0
while i < len(ui_ml):
matched = False
@ -95,7 +102,8 @@ def _tokenize(ui_ml: str):
if not matched:
raise CompileError(
"Could not determine what kind of syntax is meant here", i, i
"Could not determine what kind of syntax is meant here",
Range(i, i, ui_ml),
)
yield Token(TokenType.EOF, i, i, ui_ml)
@ -103,3 +111,38 @@ def _tokenize(ui_ml: str):
def tokenize(data: str) -> T.List[Token]:
return list(_tokenize(data))
@dataclass
class Range:
start: int
end: int
original_text: str
@property
def length(self) -> int:
return self.end - self.start
@property
def text(self) -> str:
return self.original_text[self.start : self.end]
@staticmethod
def join(a: T.Optional["Range"], b: T.Optional["Range"]) -> T.Optional["Range"]:
if a is None:
return b
if b is None:
return a
return Range(min(a.start, b.start), max(a.end, b.end), a.original_text)
def __contains__(self, other: T.Union[int, "Range"]) -> bool:
if isinstance(other, int):
return self.start <= other <= self.end
else:
return self.start <= other.start and self.end >= other.end
def to_json(self):
return utils.idxs_to_range(self.start, self.end, self.original_text)
def overlaps(self, other: "Range") -> bool:
return not (self.end < other.start or self.start > other.end)

View file

@ -40,13 +40,12 @@ from blueprintcompiler.tokenizer import Token, TokenType, tokenize
class TestSamples(unittest.TestCase):
def assert_docs_dont_crash(self, text, ast):
def assert_ast_doesnt_crash(self, text, tokens, ast):
for i in range(len(text)):
ast.get_docs(i)
def assert_completions_dont_crash(self, text, ast, tokens):
for i in range(len(text)):
list(complete(ast, tokens, i))
ast.get_document_symbols()
def assert_sample(self, name, skip_run=False):
print(f'assert_sample("{name}", skip_run={skip_run})')
@ -79,8 +78,7 @@ class TestSamples(unittest.TestCase):
print("\n".join(diff))
raise AssertionError()
self.assert_docs_dont_crash(blueprint, ast)
self.assert_completions_dont_crash(blueprint, ast, tokens)
self.assert_ast_doesnt_crash(blueprint, tokens, ast)
except PrintableError as e: # pragma: no cover
e.pretty_print(name + ".blp", blueprint)
raise AssertionError()
@ -105,8 +103,7 @@ class TestSamples(unittest.TestCase):
ast, errors, warnings = parser.parse(tokens)
if ast is not None:
self.assert_docs_dont_crash(blueprint, ast)
self.assert_completions_dont_crash(blueprint, ast, tokens)
self.assert_ast_doesnt_crash(blueprint, tokens, ast)
if errors:
raise errors
@ -116,9 +113,9 @@ class TestSamples(unittest.TestCase):
raise MultipleErrors(warnings)
except PrintableError as e:
def error_str(error):
line, col = utils.idx_to_pos(error.start + 1, blueprint)
len = error.end - error.start
def error_str(error: CompileError):
line, col = utils.idx_to_pos(error.range.start + 1, blueprint)
len = error.range.length
return ",".join([str(line + 1), str(col), str(len), error.message])
if isinstance(e, CompileError):