Merge branch 'lsp-improvements' into 'main'

Add features to the language server

See merge request jwestman/blueprint-compiler!134
Commit de2e7a5a5f by James Westman, 2023-07-26 01:09:54 +00:00
32 changed files with 671 additions and 117 deletions

View file

@ -22,7 +22,8 @@ from collections import ChainMap, defaultdict
from functools import cached_property from functools import cached_property
from .errors import * from .errors import *
from .lsp_utils import SemanticToken from .lsp_utils import DocumentSymbol, LocationLink, SemanticToken
from .tokenizer import Range
TType = T.TypeVar("TType") TType = T.TypeVar("TType")
@ -54,6 +55,18 @@ class Children:
return [child for child in self._children if isinstance(child, key)] return [child for child in self._children if isinstance(child, key)]
class Ranges:
def __init__(self, ranges: T.Dict[str, Range]):
self._ranges = ranges
def __getitem__(self, key: T.Union[str, tuple[str, str]]) -> T.Optional[Range]:
if isinstance(key, str):
return self._ranges.get(key)
elif isinstance(key, tuple):
start, end = key
return Range.join(self._ranges.get(start), self._ranges.get(end))
TCtx = T.TypeVar("TCtx") TCtx = T.TypeVar("TCtx")
TAttr = T.TypeVar("TAttr") TAttr = T.TypeVar("TAttr")
@ -102,6 +115,10 @@ class AstNode:
def context(self): def context(self):
return Ctx(self) return Ctx(self)
@cached_property
def ranges(self):
return Ranges(self.group.ranges)
@cached_property @cached_property
def root(self): def root(self):
if self.parent is None: if self.parent is None:
@ -109,6 +126,10 @@ class AstNode:
else: else:
return self.parent.root return self.parent.root
@property
def range(self):
return Range(self.group.start, self.group.end, self.group.text)
def parent_by_type(self, type: T.Type[TType]) -> TType: def parent_by_type(self, type: T.Type[TType]) -> TType:
if self.parent is None: if self.parent is None:
raise CompilerBugError() raise CompilerBugError()
@ -164,9 +185,8 @@ class AstNode:
return getattr(self, name) return getattr(self, name)
for child in self.children: for child in self.children:
if child.group.start <= idx < child.group.end: if idx in child.range:
docs = child.get_docs(idx) if docs := child.get_docs(idx):
if docs is not None:
return docs return docs
return None return None
@ -175,6 +195,27 @@ class AstNode:
for child in self.children: for child in self.children:
yield from child.get_semantic_tokens() yield from child.get_semantic_tokens()
def get_reference(self, idx: int) -> T.Optional[LocationLink]:
for child in self.children:
if idx in child.range:
if ref := child.get_reference(idx):
return ref
return None
@property
def document_symbol(self) -> T.Optional[DocumentSymbol]:
return None
def get_document_symbols(self) -> T.List[DocumentSymbol]:
result = []
for child in self.children:
if s := child.document_symbol:
s.children = child.get_document_symbols()
result.append(s)
else:
result.extend(child.get_document_symbols())
return result
def validate_unique_in_parent( def validate_unique_in_parent(
self, error: str, check: T.Optional[T.Callable[["AstNode"], bool]] = None self, error: str, check: T.Optional[T.Callable[["AstNode"], bool]] = None
): ):
@ -188,20 +229,23 @@ class AstNode:
error, error,
references=[ references=[
ErrorReference( ErrorReference(
child.group.start, child.range,
child.group.end,
"previous declaration was here", "previous declaration was here",
) )
], ],
) )
def validate(token_name=None, end_token_name=None, skip_incomplete=False): def validate(
token_name: T.Optional[str] = None,
end_token_name: T.Optional[str] = None,
skip_incomplete: bool = False,
):
"""Decorator for functions that validate an AST node. Exceptions raised """Decorator for functions that validate an AST node. Exceptions raised
during validation are marked with range information from the tokens.""" during validation are marked with range information from the tokens."""
def decorator(func): def decorator(func):
def inner(self): def inner(self: AstNode):
if skip_incomplete and self.incomplete: if skip_incomplete and self.incomplete:
return return
@ -213,22 +257,14 @@ def validate(token_name=None, end_token_name=None, skip_incomplete=False):
if self.incomplete: if self.incomplete:
return return
# This mess of code sets the error's start and end positions if e.range is None:
# from the tokens passed to the decorator, if they have not e.range = (
# already been set Range.join(
if e.start is None: self.ranges[token_name],
if token := self.group.tokens.get(token_name): self.ranges[end_token_name],
e.start = token.start )
else: or self.range
e.start = self.group.start )
if e.end is None:
if token := self.group.tokens.get(end_token_name):
e.end = token.end
elif token := self.group.tokens.get(token_name):
e.end = token.end
else:
e.end = self.group.end
# Re-raise the exception # Re-raise the exception
raise e raise e
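
The Ranges helper added above can be indexed with a single key or with a (start, end) tuple, in which case the two named ranges are joined into one span. A minimal sketch of that behavior; it assumes this file is blueprintcompiler/ast_utils.py and uses the Range type added to the tokenizer later in this merge request:

    from blueprintcompiler.ast_utils import Ranges
    from blueprintcompiler.tokenizer import Range

    text = "signal => $handler();"
    ranges = Ranges({"kw": Range(0, 6, text), "detail_end": Range(10, 21, text)})

    print(ranges["kw"].text)                # "signal"
    print(ranges["kw", "detail_end"].text)  # joined span: "signal => $handler();"
    print(ranges["missing"])                # None for unknown keys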

View file

@ -158,7 +158,7 @@ def signal_completer(ast_node, match_variables):
yield Completion( yield Completion(
signal, signal,
CompletionItemKind.Property, CompletionItemKind.Property,
snippet=f"{signal} => ${{1:{name}_{signal.replace('-', '_')}}}()$0;", snippet=f"{signal} => \$${{1:{name}_{signal.replace('-', '_')}}}()$0;",
) )
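
For reference, here is what the escaped snippet string above expands to, using hypothetical values ("clicked" for the signal, "button" for the object name):

    signal, name = "clicked", "button"
    snippet = f"{signal} => \$${{1:{name}_{signal.replace('-', '_')}}}()$0;"
    print(snippet)  # clicked => \$${1:button_clicked}()$0;
    # In LSP snippet syntax, \$ is an escaped literal dollar sign, so the editor
    # inserts `clicked => $button_clicked();` with the handler name as tab stop 1.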

View file

@ -23,6 +23,7 @@ import typing as T
from dataclasses import dataclass from dataclasses import dataclass
from . import utils from . import utils
from .tokenizer import Range
from .utils import Colors from .utils import Colors
@ -36,8 +37,7 @@ class PrintableError(Exception):
@dataclass @dataclass
class ErrorReference: class ErrorReference:
start: int range: Range
end: int
message: str message: str
@ -50,8 +50,7 @@ class CompileError(PrintableError):
def __init__( def __init__(
self, self,
message: str, message: str,
start: T.Optional[int] = None, range: T.Optional[Range] = None,
end: T.Optional[int] = None,
did_you_mean: T.Optional[T.Tuple[str, T.List[str]]] = None, did_you_mean: T.Optional[T.Tuple[str, T.List[str]]] = None,
hints: T.Optional[T.List[str]] = None, hints: T.Optional[T.List[str]] = None,
actions: T.Optional[T.List["CodeAction"]] = None, actions: T.Optional[T.List["CodeAction"]] = None,
@ -61,8 +60,7 @@ class CompileError(PrintableError):
super().__init__(message) super().__init__(message)
self.message = message self.message = message
self.start = start self.range = range
self.end = end
self.hints = hints or [] self.hints = hints or []
self.actions = actions or [] self.actions = actions or []
self.references = references or [] self.references = references or []
@ -92,9 +90,9 @@ class CompileError(PrintableError):
self.hint("Are your dependencies up to date?") self.hint("Are your dependencies up to date?")
def pretty_print(self, filename: str, code: str, stream=sys.stdout) -> None: def pretty_print(self, filename: str, code: str, stream=sys.stdout) -> None:
assert self.start is not None assert self.range is not None
line_num, col_num = utils.idx_to_pos(self.start + 1, code) line_num, col_num = utils.idx_to_pos(self.range.start + 1, code)
line = code.splitlines(True)[line_num] line = code.splitlines(True)[line_num]
# Display 1-based line numbers # Display 1-based line numbers
@ -110,7 +108,7 @@ at {filename} line {line_num} column {col_num}:
stream.write(f"{Colors.FAINT}hint: {hint}{Colors.CLEAR}\n") stream.write(f"{Colors.FAINT}hint: {hint}{Colors.CLEAR}\n")
for ref in self.references: for ref in self.references:
line_num, col_num = utils.idx_to_pos(ref.start + 1, code) line_num, col_num = utils.idx_to_pos(ref.range.start + 1, code)
line = code.splitlines(True)[line_num] line = code.splitlines(True)[line_num]
line_num += 1 line_num += 1
@ -138,14 +136,15 @@ class UpgradeWarning(CompileWarning):
class UnexpectedTokenError(CompileError): class UnexpectedTokenError(CompileError):
def __init__(self, start, end) -> None: def __init__(self, range: Range) -> None:
super().__init__("Unexpected tokens", start, end) super().__init__("Unexpected tokens", range)
@dataclass @dataclass
class CodeAction: class CodeAction:
title: str title: str
replace_with: str replace_with: str
edit_range: T.Optional[Range] = None
class MultipleErrors(PrintableError): class MultipleErrors(PrintableError):
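
CompileError now carries a single Range value instead of separate start/end offsets. A small construction sketch, assuming the blueprintcompiler.errors and blueprintcompiler.tokenizer module paths; the message and source text are made up for illustration:

    from blueprintcompiler.errors import CompileError
    from blueprintcompiler.tokenizer import Range

    code = 'using Gtk 4.0;\n\nLabel {\n  labell: "hi";\n}\n'
    start = code.index("labell")
    err = CompileError(
        "Class Label does not have a property called labell",
        Range(start, start + len("labell"), code),
        hints=["Did you mean `label`?"],
    )
    err.pretty_print("example.blp", code)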

View file

@ -18,7 +18,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later # SPDX-License-Identifier: LGPL-3.0-or-later
import os import os
import sys
import typing as T import typing as T
from functools import cached_property from functools import cached_property
@ -29,6 +28,7 @@ from gi.repository import GIRepository # type: ignore
from . import typelib, xml_reader from . import typelib, xml_reader
from .errors import CompileError, CompilerBugError from .errors import CompileError, CompilerBugError
from .lsp_utils import CodeAction
_namespace_cache: T.Dict[str, "Namespace"] = {} _namespace_cache: T.Dict[str, "Namespace"] = {}
_xml_cache = {} _xml_cache = {}
@ -65,6 +65,27 @@ def get_namespace(namespace: str, version: str) -> "Namespace":
return _namespace_cache[filename] return _namespace_cache[filename]
_available_namespaces: list[tuple[str, str]] = []
def get_available_namespaces() -> T.List[T.Tuple[str, str]]:
if len(_available_namespaces):
return _available_namespaces
search_paths: list[str] = [
*GIRepository.Repository.get_search_path(),
*_user_search_paths,
]
for search_path in search_paths:
for filename in os.listdir(search_path):
if filename.endswith(".typelib"):
namespace, version = filename.removesuffix(".typelib").rsplit("-", 1)
_available_namespaces.append((namespace, version))
return _available_namespaces
def get_xml(namespace: str, version: str): def get_xml(namespace: str, version: str):
search_paths = [] search_paths = []
@ -1011,9 +1032,11 @@ class GirContext:
ns = ns or "Gtk" ns = ns or "Gtk"
if ns not in self.namespaces and ns not in self.not_found_namespaces: if ns not in self.namespaces and ns not in self.not_found_namespaces:
all_available = list(set(ns for ns, _version in get_available_namespaces()))
raise CompileError( raise CompileError(
f"Namespace {ns} was not imported", f"Namespace {ns} was not imported",
did_you_mean=(ns, self.namespaces.keys()), did_you_mean=(ns, all_available),
) )
def validate_type(self, name: str, ns: str) -> None: def validate_type(self, name: str, ns: str) -> None:
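
With this change, the "namespace was not imported" error draws its suggestions from every installed typelib rather than only the namespaces already imported in the document. A quick, system-dependent way to see what get_available_namespaces() discovers (requires PyGObject and installed introspection data):

    from blueprintcompiler import gir

    namespaces = gir.get_available_namespaces()
    print(len(namespaces), "typelibs found")
    print(sorted(namespaces)[:5])  # e.g. [('Adw', '1'), ('Gdk', '4.0'), ...] on a typical GNOME system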

View file

@ -35,6 +35,16 @@ class AdwBreakpointCondition(AstNode):
def condition(self) -> str: def condition(self) -> str:
return self.tokens["condition"] return self.tokens["condition"]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"condition",
SymbolKind.Property,
self.range,
self.group.tokens["kw"].range,
self.condition,
)
@docs("kw") @docs("kw")
def keyword_docs(self): def keyword_docs(self):
klass = self.root.gir.get_type("Breakpoint", "Adw") klass = self.root.gir.get_type("Breakpoint", "Adw")
@ -93,6 +103,16 @@ class AdwBreakpointSetter(AstNode):
else: else:
return None return None
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
f"{self.object_id}.{self.property_name}",
SymbolKind.Property,
self.range,
self.group.tokens["object"].range,
self.value.range.text,
)
@context(ValueTypeCtx) @context(ValueTypeCtx)
def value_type(self) -> ValueTypeCtx: def value_type(self) -> ValueTypeCtx:
if self.gir_property is not None: if self.gir_property is not None:
@ -147,12 +167,25 @@ class AdwBreakpointSetter(AstNode):
class AdwBreakpointSetters(AstNode): class AdwBreakpointSetters(AstNode):
grammar = ["setters", Match("{").expected(), Until(AdwBreakpointSetter, "}")] grammar = [
Keyword("setters"),
Match("{").expected(),
Until(AdwBreakpointSetter, "}"),
]
@property @property
def setters(self) -> T.List[AdwBreakpointSetter]: def setters(self) -> T.List[AdwBreakpointSetter]:
return self.children[AdwBreakpointSetter] return self.children[AdwBreakpointSetter]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"setters",
SymbolKind.Struct,
self.range,
self.group.tokens["setters"].range,
)
@validate() @validate()
def container_is_breakpoint(self): def container_is_breakpoint(self):
validate_parent_type(self, "Adw", "Breakpoint", "breakpoint setters") validate_parent_type(self, "Adw", "Breakpoint", "breakpoint setters")

View file

@ -84,6 +84,16 @@ class ExtAdwMessageDialogResponse(AstNode):
def value(self) -> StringValue: def value(self) -> StringValue:
return self.children[0] return self.children[0]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.id,
SymbolKind.Field,
self.range,
self.group.tokens["id"].range,
self.value.range.text,
)
@context(ValueTypeCtx) @context(ValueTypeCtx)
def value_type(self) -> ValueTypeCtx: def value_type(self) -> ValueTypeCtx:
return ValueTypeCtx(StringType()) return ValueTypeCtx(StringType())
@ -108,6 +118,15 @@ class ExtAdwMessageDialog(AstNode):
def responses(self) -> T.List[ExtAdwMessageDialogResponse]: def responses(self) -> T.List[ExtAdwMessageDialogResponse]:
return self.children return self.children
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"responses",
SymbolKind.Array,
self.range,
self.group.tokens["responses"].range,
)
@validate("responses") @validate("responses")
def container_is_message_dialog(self): def container_is_message_dialog(self):
validate_parent_type(self, "Adw", "MessageDialog", "responses") validate_parent_type(self, "Adw", "MessageDialog", "responses")

View file

@ -46,7 +46,15 @@ from ..gir import (
IntType, IntType,
StringType, StringType,
) )
from ..lsp_utils import Completion, CompletionItemKind, SemanticToken, SemanticTokenType from ..lsp_utils import (
Completion,
CompletionItemKind,
DocumentSymbol,
LocationLink,
SemanticToken,
SemanticTokenType,
SymbolKind,
)
from ..parse_tree import * from ..parse_tree import *
OBJECT_CONTENT_HOOKS = AnyOf() OBJECT_CONTENT_HOOKS = AnyOf()

View file

@ -70,8 +70,7 @@ class ScopeCtx:
): ):
raise CompileError( raise CompileError(
f"Duplicate object ID '{obj.tokens['id']}'", f"Duplicate object ID '{obj.tokens['id']}'",
token.start, token.range,
token.end,
) )
passed[obj.tokens["id"]] = obj passed[obj.tokens["id"]] = obj

View file

@ -21,6 +21,9 @@
import typing as T import typing as T
from functools import cached_property from functools import cached_property
from blueprintcompiler.errors import T
from blueprintcompiler.lsp_utils import DocumentSymbol
from .common import * from .common import *
from .response_id import ExtResponse from .response_id import ExtResponse
from .types import ClassName, ConcreteClassName from .types import ClassName, ConcreteClassName
@ -59,8 +62,20 @@ class Object(AstNode):
def signature(self) -> str: def signature(self) -> str:
if self.id: if self.id:
return f"{self.class_name.gir_type.full_name} {self.id}" return f"{self.class_name.gir_type.full_name} {self.id}"
elif t := self.class_name.gir_type:
return f"{t.full_name}"
else: else:
return f"{self.class_name.gir_type.full_name}" return f"{self.class_name.as_string}"
@property
def document_symbol(self) -> T.Optional[DocumentSymbol]:
return DocumentSymbol(
self.class_name.as_string,
SymbolKind.Object,
self.range,
self.children[ClassName][0].range,
self.id,
)
@property @property
def gir_class(self) -> GirType: def gir_class(self) -> GirType:

View file

@ -47,6 +47,21 @@ class Property(AstNode):
else: else:
return None return None
@property
def document_symbol(self) -> DocumentSymbol:
if isinstance(self.value, ObjectValue):
detail = None
else:
detail = self.value.range.text
return DocumentSymbol(
self.name,
SymbolKind.Property,
self.range,
self.group.tokens["name"].range,
detail,
)
@validate() @validate()
def binding_valid(self): def binding_valid(self):
if ( if (

View file

@ -51,12 +51,14 @@ class Signal(AstNode):
] ]
), ),
"=>", "=>",
Mark("detail_start"),
Optional(["$", UseLiteral("extern", True)]), Optional(["$", UseLiteral("extern", True)]),
UseIdent("handler").expected("the name of a function to handle the signal"), UseIdent("handler").expected("the name of a function to handle the signal"),
Match("(").expected("argument list"), Match("(").expected("argument list"),
Optional(UseIdent("object")).expected("object identifier"), Optional(UseIdent("object")).expected("object identifier"),
Match(")").expected(), Match(")").expected(),
ZeroOrMore(SignalFlag), ZeroOrMore(SignalFlag),
Mark("detail_end"),
) )
@property @property
@ -105,6 +107,16 @@ class Signal(AstNode):
def gir_class(self): def gir_class(self):
return self.parent.parent.gir_class return self.parent.parent.gir_class
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.full_name,
SymbolKind.Event,
self.range,
self.group.tokens["name"].range,
self.ranges["detail_start", "detail_end"].text,
)
@validate("handler") @validate("handler")
def old_extern(self): def old_extern(self):
if not self.tokens["extern"]: if not self.tokens["extern"]:

View file

@ -139,6 +139,16 @@ class A11yProperty(BaseAttribute):
def value_type(self) -> ValueTypeCtx: def value_type(self) -> ValueTypeCtx:
return ValueTypeCtx(get_types(self.root.gir).get(self.tokens["name"])) return ValueTypeCtx(get_types(self.root.gir).get(self.tokens["name"]))
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.name,
SymbolKind.Field,
self.range,
self.group.tokens["name"].range,
self.value.range.text,
)
@validate("name") @validate("name")
def is_valid_property(self): def is_valid_property(self):
types = get_types(self.root.gir) types = get_types(self.root.gir)
@ -172,6 +182,15 @@ class ExtAccessibility(AstNode):
def properties(self) -> T.List[A11yProperty]: def properties(self) -> T.List[A11yProperty]:
return self.children[A11yProperty] return self.children[A11yProperty]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"accessibility",
SymbolKind.Struct,
self.range,
self.group.tokens["accessibility"].range,
)
@validate("accessibility") @validate("accessibility")
def container_is_widget(self): def container_is_widget(self):
validate_parent_type(self, "Gtk", "Widget", "accessibility properties") validate_parent_type(self, "Gtk", "Widget", "accessibility properties")

View file

@ -31,13 +31,23 @@ class Item(AstNode):
] ]
@property @property
def name(self) -> str: def name(self) -> T.Optional[str]:
return self.tokens["name"] return self.tokens["name"]
@property @property
def value(self) -> StringValue: def value(self) -> StringValue:
return self.children[StringValue][0] return self.children[StringValue][0]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.value.range.text,
SymbolKind.String,
self.range,
self.value.range,
self.name,
)
@validate("name") @validate("name")
def unique_in_parent(self): def unique_in_parent(self):
if self.name is not None: if self.name is not None:
@ -54,6 +64,15 @@ class ExtComboBoxItems(AstNode):
"]", "]",
] ]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"items",
SymbolKind.Array,
self.range,
self.group.tokens["items"].range,
)
@validate("items") @validate("items")
def container_is_combo_box_text(self): def container_is_combo_box_text(self):
validate_parent_type(self, "Gtk", "ComboBoxText", "combo box items") validate_parent_type(self, "Gtk", "ComboBoxText", "combo box items")

View file

@ -23,6 +23,15 @@ from .gobject_object import ObjectContent, validate_parent_type
class Filters(AstNode): class Filters(AstNode):
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.tokens["tag_name"],
SymbolKind.Array,
self.range,
self.group.tokens[self.tokens["tag_name"]].range,
)
@validate() @validate()
def container_is_file_filter(self): def container_is_file_filter(self):
validate_parent_type(self, "Gtk", "FileFilter", "file filter properties") validate_parent_type(self, "Gtk", "FileFilter", "file filter properties")
@ -46,6 +55,15 @@ class FilterString(AstNode):
def item(self) -> str: def item(self) -> str:
return self.tokens["name"] return self.tokens["name"]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.item,
SymbolKind.String,
self.range,
self.group.tokens["name"].range,
)
@validate() @validate()
def unique_in_parent(self): def unique_in_parent(self):
self.validate_unique_in_parent( self.validate_unique_in_parent(

View file

@ -36,6 +36,16 @@ class LayoutProperty(AstNode):
def value(self) -> Value: def value(self) -> Value:
return self.children[Value][0] return self.children[Value][0]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.name,
SymbolKind.Field,
self.range,
self.group.tokens["name"].range,
self.value.range.text,
)
@context(ValueTypeCtx) @context(ValueTypeCtx)
def value_type(self) -> ValueTypeCtx: def value_type(self) -> ValueTypeCtx:
# there isn't really a way to validate these # there isn't really a way to validate these
@ -56,6 +66,15 @@ class ExtLayout(AstNode):
Until(LayoutProperty, "}"), Until(LayoutProperty, "}"),
) )
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"layout",
SymbolKind.Struct,
self.range,
self.group.tokens["layout"].range,
)
@validate("layout") @validate("layout")
def container_is_widget(self): def container_is_widget(self):
validate_parent_type(self, "Gtk", "Widget", "layout properties") validate_parent_type(self, "Gtk", "Widget", "layout properties")

View file

@ -1,5 +1,9 @@
import typing as T
from blueprintcompiler.errors import T
from blueprintcompiler.lsp_utils import DocumentSymbol
from ..ast_utils import AstNode, validate from ..ast_utils import AstNode, validate
from ..parse_tree import Keyword
from .common import * from .common import *
from .contexts import ScopeCtx from .contexts import ScopeCtx
from .gobject_object import ObjectContent, validate_parent_type from .gobject_object import ObjectContent, validate_parent_type
@ -17,6 +21,15 @@ class ExtListItemFactory(AstNode):
def signature(self) -> str: def signature(self) -> str:
return f"template {self.gir_class.full_name}" return f"template {self.gir_class.full_name}"
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.signature,
SymbolKind.Object,
self.range,
self.group.tokens["id"].range,
)
@property @property
def type_name(self) -> T.Optional[TypeName]: def type_name(self) -> T.Optional[TypeName]:
if len(self.children[TypeName]) == 1: if len(self.children[TypeName]) == 1:

View file

@ -42,6 +42,16 @@ class Menu(AstNode):
else: else:
return "Gio.Menu" return "Gio.Menu"
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.tokens["tag"],
SymbolKind.Object,
self.range,
self.group.tokens[self.tokens["tag"]].range,
self.id,
)
@property @property
def tag(self) -> str: def tag(self) -> str:
return self.tokens["tag"] return self.tokens["tag"]
@ -72,6 +82,18 @@ class MenuAttribute(AstNode):
def value(self) -> StringValue: def value(self) -> StringValue:
return self.children[StringValue][0] return self.children[StringValue][0]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.name,
SymbolKind.Field,
self.range,
self.group.tokens["name"].range
if self.group.tokens["name"]
else self.range,
self.value.range.text,
)
@context(ValueTypeCtx) @context(ValueTypeCtx)
def value_type(self) -> ValueTypeCtx: def value_type(self) -> ValueTypeCtx:
return ValueTypeCtx(None) return ValueTypeCtx(None)
@ -98,7 +120,7 @@ menu_attribute = Group(
menu_section = Group( menu_section = Group(
Menu, Menu,
[ [
"section", Keyword("section"),
UseLiteral("tag", "section"), UseLiteral("tag", "section"),
Optional(UseIdent("id")), Optional(UseIdent("id")),
Match("{").expected(), Match("{").expected(),
@ -109,7 +131,7 @@ menu_section = Group(
menu_submenu = Group( menu_submenu = Group(
Menu, Menu,
[ [
"submenu", Keyword("submenu"),
UseLiteral("tag", "submenu"), UseLiteral("tag", "submenu"),
Optional(UseIdent("id")), Optional(UseIdent("id")),
Match("{").expected(), Match("{").expected(),
@ -120,7 +142,7 @@ menu_submenu = Group(
menu_item = Group( menu_item = Group(
Menu, Menu,
[ [
"item", Keyword("item"),
UseLiteral("tag", "item"), UseLiteral("tag", "item"),
Match("{").expected(), Match("{").expected(),
Until(menu_attribute, "}"), Until(menu_attribute, "}"),
@ -130,7 +152,7 @@ menu_item = Group(
menu_item_shorthand = Group( menu_item_shorthand = Group(
Menu, Menu,
[ [
"item", Keyword("item"),
UseLiteral("tag", "item"), UseLiteral("tag", "item"),
"(", "(",
Group( Group(

View file

@ -58,6 +58,24 @@ class ExtScaleMark(AstNode):
else: else:
return None return None
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
str(self.value),
SymbolKind.Field,
self.range,
self.group.tokens["mark"].range,
self.label.string if self.label else None,
)
def get_semantic_tokens(self) -> T.Iterator[SemanticToken]:
if range := self.ranges["position"]:
yield SemanticToken(
range.start,
range.end,
SemanticTokenType.EnumMember,
)
@docs("position") @docs("position")
def position_docs(self) -> T.Optional[str]: def position_docs(self) -> T.Optional[str]:
if member := self.root.gir.get_type("PositionType", "Gtk").members.get( if member := self.root.gir.get_type("PositionType", "Gtk").members.get(
@ -88,6 +106,15 @@ class ExtScaleMarks(AstNode):
def marks(self) -> T.List[ExtScaleMark]: def marks(self) -> T.List[ExtScaleMark]:
return self.children return self.children
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"marks",
SymbolKind.Array,
self.range,
self.group.tokens["marks"].range,
)
@validate("marks") @validate("marks")
def container_is_size_group(self): def container_is_size_group(self):
validate_parent_type(self, "Gtk", "Scale", "scale marks") validate_parent_type(self, "Gtk", "Scale", "scale marks")

View file

@ -30,6 +30,15 @@ class Widget(AstNode):
def name(self) -> str: def name(self) -> str:
return self.tokens["name"] return self.tokens["name"]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.name,
SymbolKind.Field,
self.range,
self.group.tokens["name"].range,
)
@validate("name") @validate("name")
def obj_widget(self): def obj_widget(self):
object = self.context[ScopeCtx].objects.get(self.tokens["name"]) object = self.context[ScopeCtx].objects.get(self.tokens["name"])
@ -62,6 +71,15 @@ class ExtSizeGroupWidgets(AstNode):
"]", "]",
] ]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"widgets",
SymbolKind.Array,
self.range,
self.group.tokens["widgets"].range,
)
@validate("widgets") @validate("widgets")
def container_is_size_group(self): def container_is_size_group(self):
validate_parent_type(self, "Gtk", "SizeGroup", "size group properties") validate_parent_type(self, "Gtk", "SizeGroup", "size group properties")

View file

@ -30,6 +30,15 @@ class Item(AstNode):
def child(self) -> StringValue: def child(self) -> StringValue:
return self.children[StringValue][0] return self.children[StringValue][0]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.child.range.text,
SymbolKind.String,
self.range,
self.range,
)
class ExtStringListStrings(AstNode): class ExtStringListStrings(AstNode):
grammar = [ grammar = [
@ -39,6 +48,15 @@ class ExtStringListStrings(AstNode):
"]", "]",
] ]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"strings",
SymbolKind.Array,
self.range,
self.group.tokens["strings"].range,
)
@validate("items") @validate("items")
def container_is_string_list(self): def container_is_string_list(self):
validate_parent_type(self, "Gtk", "StringList", "StringList items") validate_parent_type(self, "Gtk", "StringList", "StringList items")

View file

@ -29,6 +29,15 @@ class StyleClass(AstNode):
def name(self) -> str: def name(self) -> str:
return self.tokens["name"] return self.tokens["name"]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.name,
SymbolKind.String,
self.range,
self.range,
)
@validate("name") @validate("name")
def unique_in_parent(self): def unique_in_parent(self):
self.validate_unique_in_parent( self.validate_unique_in_parent(
@ -44,6 +53,15 @@ class ExtStyles(AstNode):
"]", "]",
] ]
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
"styles",
SymbolKind.Array,
self.range,
self.group.tokens["styles"].range,
)
@validate("styles") @validate("styles")
def container_is_widget(self): def container_is_widget(self):
validate_parent_type(self, "Gtk", "Widget", "style classes") validate_parent_type(self, "Gtk", "Widget", "style classes")

View file

@ -46,10 +46,19 @@ class Template(Object):
@property @property
def signature(self) -> str: def signature(self) -> str:
if self.parent_type: if self.parent_type and self.parent_type.gir_type:
return f"template {self.gir_class.full_name} : {self.parent_type.gir_type.full_name}" return f"template {self.class_name.as_string} : {self.parent_type.gir_type.full_name}"
else: else:
return f"template {self.gir_class.full_name}" return f"template {self.class_name.as_string}"
@property
def document_symbol(self) -> DocumentSymbol:
return DocumentSymbol(
self.signature,
SymbolKind.Object,
self.range,
self.group.tokens["id"].range,
)
@property @property
def gir_class(self) -> GirType: def gir_class(self) -> GirType:

View file

@ -55,7 +55,16 @@ class TypeName(AstNode):
@validate("namespace") @validate("namespace")
def gir_ns_exists(self): def gir_ns_exists(self):
if not self.tokens["extern"]: if not self.tokens["extern"]:
self.root.gir.validate_ns(self.tokens["namespace"]) try:
self.root.gir.validate_ns(self.tokens["namespace"])
except CompileError as e:
ns = self.tokens["namespace"]
e.actions = [
self.root.import_code_action(n, version)
for n, version in gir.get_available_namespaces()
if n == ns
]
raise e
@validate() @validate()
def deprecated(self) -> None: def deprecated(self) -> None:

View file

@ -62,8 +62,7 @@ class UI(AstNode):
else: else:
gir_ctx.not_found_namespaces.add(i.namespace) gir_ctx.not_found_namespaces.add(i.namespace)
except CompileError as e: except CompileError as e:
e.start = i.group.tokens["namespace"].start e.range = i.range
e.end = i.group.tokens["version"].end
self._gir_errors.append(e) self._gir_errors.append(e)
return gir_ctx return gir_ctx
@ -100,6 +99,18 @@ class UI(AstNode):
and self.template.class_name.glib_type_name == id and self.template.class_name.glib_type_name == id
) )
def import_code_action(self, ns: str, version: str) -> CodeAction:
if len(self.children[Import]):
pos = self.children[Import][-1].range.end
else:
pos = self.children[GtkDirective][0].range.end
return CodeAction(
f"Import {ns} {version}",
f"\nusing {ns} {version};",
Range(pos, pos, self.group.text),
)
@context(ScopeCtx) @context(ScopeCtx)
def scope_ctx(self) -> ScopeCtx: def scope_ctx(self) -> ScopeCtx:
return ScopeCtx(node=self) return ScopeCtx(node=self)
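
import_code_action builds a quick fix whose edit is a zero-length Range placed right after the last import (or after the GTK directive when there are no imports), so applying it inserts a new `using` line. A rough sketch of applying such an action by hand; the document text and positions here are hypothetical:

    from blueprintcompiler.errors import CodeAction
    from blueprintcompiler.tokenizer import Range

    text = "using Gtk 4.0;\n\nAdw.StatusPage {}\n"
    pos = text.index(";") + 1  # pretend this is the end of the last import
    action = CodeAction("Import Adw 1", "\nusing Adw 1;", Range(pos, pos, text))

    patched = text[: action.edit_range.start] + action.replace_with + text[action.edit_range.end :]
    print(patched)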

View file

@ -339,6 +339,16 @@ class IdentLiteral(AstNode):
token = self.group.tokens["value"] token = self.group.tokens["value"]
yield SemanticToken(token.start, token.end, SemanticTokenType.EnumMember) yield SemanticToken(token.start, token.end, SemanticTokenType.EnumMember)
def get_reference(self, _idx: int) -> T.Optional[LocationLink]:
ref = self.context[ScopeCtx].objects.get(self.ident)
if ref is None and self.root.is_legacy_template(self.ident):
ref = self.root.template
if ref:
return LocationLink(self.range, ref.range, ref.ranges["id"])
else:
return None
class Literal(AstNode): class Literal(AstNode):
grammar = AnyOf( grammar = AnyOf(

View file

@ -24,10 +24,12 @@ import traceback
import typing as T import typing as T
from . import decompiler, parser, tokenizer, utils, xml_reader from . import decompiler, parser, tokenizer, utils, xml_reader
from .ast_utils import AstNode
from .completions import complete from .completions import complete
from .errors import CompileError, MultipleErrors, PrintableError from .errors import CompileError, MultipleErrors
from .lsp_utils import * from .lsp_utils import *
from .outputs.xml import XmlOutput from .outputs.xml import XmlOutput
from .tokenizer import Token
def printerr(*args, **kwargs): def printerr(*args, **kwargs):
@ -43,16 +45,16 @@ def command(json_method: str):
class OpenFile: class OpenFile:
def __init__(self, uri: str, text: str, version: int): def __init__(self, uri: str, text: str, version: int) -> None:
self.uri = uri self.uri = uri
self.text = text self.text = text
self.version = version self.version = version
self.ast = None self.ast: T.Optional[AstNode] = None
self.tokens = None self.tokens: T.Optional[list[Token]] = None
self._update() self._update()
def apply_changes(self, changes): def apply_changes(self, changes) -> None:
for change in changes: for change in changes:
if "range" not in change: if "range" not in change:
self.text = change["text"] self.text = change["text"]
@ -70,8 +72,8 @@ class OpenFile:
self.text = self.text[:start] + change["text"] + self.text[end:] self.text = self.text[:start] + change["text"] + self.text[end:]
self._update() self._update()
def _update(self): def _update(self) -> None:
self.diagnostics = [] self.diagnostics: list[CompileError] = []
try: try:
self.tokens = tokenizer.tokenize(self.text) self.tokens = tokenizer.tokenize(self.text)
self.ast, errors, warnings = parser.parse(self.tokens) self.ast, errors, warnings = parser.parse(self.tokens)
@ -203,6 +205,8 @@ class LanguageServer:
"completionProvider": {}, "completionProvider": {},
"codeActionProvider": {}, "codeActionProvider": {},
"hoverProvider": True, "hoverProvider": True,
"documentSymbolProvider": True,
"definitionProvider": True,
}, },
"serverInfo": { "serverInfo": {
"name": "Blueprint", "name": "Blueprint",
@ -325,14 +329,17 @@ class LanguageServer:
def code_actions(self, id, params): def code_actions(self, id, params):
open_file = self._open_files[params["textDocument"]["uri"]] open_file = self._open_files[params["textDocument"]["uri"]]
range_start = utils.pos_to_idx( range = Range(
params["range"]["start"]["line"], utils.pos_to_idx(
params["range"]["start"]["character"], params["range"]["start"]["line"],
open_file.text, params["range"]["start"]["character"],
) open_file.text,
range_end = utils.pos_to_idx( ),
params["range"]["end"]["line"], utils.pos_to_idx(
params["range"]["end"]["character"], params["range"]["end"]["line"],
params["range"]["end"]["character"],
open_file.text,
),
open_file.text, open_file.text,
) )
@ -340,16 +347,14 @@ class LanguageServer:
{ {
"title": action.title, "title": action.title,
"kind": "quickfix", "kind": "quickfix",
"diagnostics": [ "diagnostics": [self._create_diagnostic(open_file.uri, diagnostic)],
self._create_diagnostic(open_file.text, open_file.uri, diagnostic)
],
"edit": { "edit": {
"changes": { "changes": {
open_file.uri: [ open_file.uri: [
{ {
"range": utils.idxs_to_range( "range": action.edit_range.to_json()
diagnostic.start, diagnostic.end, open_file.text if action.edit_range
), else diagnostic.range.to_json(),
"newText": action.replace_with, "newText": action.replace_with,
} }
] ]
@ -357,34 +362,68 @@ class LanguageServer:
}, },
} }
for diagnostic in open_file.diagnostics for diagnostic in open_file.diagnostics
if not (diagnostic.end < range_start or diagnostic.start > range_end) if range.overlaps(diagnostic.range)
for action in diagnostic.actions for action in diagnostic.actions
] ]
self._send_response(id, actions) self._send_response(id, actions)
@command("textDocument/documentSymbol")
def document_symbols(self, id, params):
open_file = self._open_files[params["textDocument"]["uri"]]
symbols = open_file.ast.get_document_symbols()
def to_json(symbol: DocumentSymbol):
result = {
"name": symbol.name,
"kind": symbol.kind,
"range": symbol.range.to_json(),
"selectionRange": symbol.selection_range.to_json(),
"children": [to_json(child) for child in symbol.children],
}
if symbol.detail is not None:
result["detail"] = symbol.detail
return result
self._send_response(id, [to_json(symbol) for symbol in symbols])
@command("textDocument/definition")
def definition(self, id, params):
open_file = self._open_files[params["textDocument"]["uri"]]
idx = utils.pos_to_idx(
params["position"]["line"], params["position"]["character"], open_file.text
)
definition = open_file.ast.get_reference(idx)
if definition is None:
self._send_response(id, None)
else:
self._send_response(
id,
definition.to_json(open_file.uri),
)
def _send_file_updates(self, open_file: OpenFile): def _send_file_updates(self, open_file: OpenFile):
self._send_notification( self._send_notification(
"textDocument/publishDiagnostics", "textDocument/publishDiagnostics",
{ {
"uri": open_file.uri, "uri": open_file.uri,
"diagnostics": [ "diagnostics": [
self._create_diagnostic(open_file.text, open_file.uri, err) self._create_diagnostic(open_file.uri, err)
for err in open_file.diagnostics for err in open_file.diagnostics
], ],
}, },
) )
def _create_diagnostic(self, text: str, uri: str, err: CompileError): def _create_diagnostic(self, uri: str, err: CompileError):
message = err.message message = err.message
assert err.start is not None and err.end is not None assert err.range is not None
for hint in err.hints: for hint in err.hints:
message += "\nhint: " + hint message += "\nhint: " + hint
result = { result = {
"range": utils.idxs_to_range(err.start, err.end, text), "range": err.range.to_json(),
"message": message, "message": message,
"severity": DiagnosticSeverity.Warning "severity": DiagnosticSeverity.Warning
if isinstance(err, CompileWarning) if isinstance(err, CompileWarning)
@ -399,7 +438,7 @@ class LanguageServer:
{ {
"location": { "location": {
"uri": uri, "uri": uri,
"range": utils.idxs_to_range(ref.start, ref.end, text), "range": ref.range.to_json(),
}, },
"message": ref.message, "message": ref.message,
} }
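
Taken together, the two new handlers mostly boil down to calling get_document_symbols() and get_reference() on the parsed AST. A rough end-to-end sketch outside the JSON-RPC plumbing; it assumes blueprintcompiler and PyGObject are importable, and the .blp source is invented for illustration:

    from blueprintcompiler import parser, tokenizer

    text = "using Gtk 4.0;\n\nAdjustment adj {}\n\nScale {\n  adjustment: adj;\n}\n"
    ast, errors, warnings = parser.parse(tokenizer.tokenize(text))

    # documentSymbolProvider: a nested outline of the file
    def dump(symbols, depth=0):
        for sym in symbols:
            detail = f": {sym.detail}" if sym.detail else ""
            print("  " * depth + f"{sym.name} [{sym.kind.name}]{detail}")
            dump(sym.children, depth + 1)

    dump(ast.get_document_symbols())

    # definitionProvider: resolve the `adj` reference back to the Adjustment object
    link = ast.get_reference(text.rindex("adj;"))
    if link is not None:
        print(link.to_json("file:///example.blp"))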

View file

@ -20,9 +20,10 @@
import enum import enum
import typing as T import typing as T
from dataclasses import dataclass from dataclasses import dataclass, field
from .errors import * from .errors import *
from .tokenizer import Range
from .utils import * from .utils import *
@ -129,3 +130,57 @@ class SemanticToken:
start: int start: int
end: int end: int
type: SemanticTokenType type: SemanticTokenType
class SymbolKind(enum.IntEnum):
File = 1
Module = 2
Namespace = 3
Package = 4
Class = 5
Method = 6
Property = 7
Field = 8
Constructor = 9
Enum = 10
Interface = 11
Function = 12
Variable = 13
Constant = 14
String = 15
Number = 16
Boolean = 17
Array = 18
Object = 19
Key = 20
Null = 21
EnumMember = 22
Struct = 23
Event = 24
Operator = 25
TypeParameter = 26
@dataclass
class DocumentSymbol:
name: str
kind: SymbolKind
range: Range
selection_range: Range
detail: T.Optional[str] = None
children: T.List["DocumentSymbol"] = field(default_factory=list)
@dataclass
class LocationLink:
origin_selection_range: Range
target_range: Range
target_selection_range: Range
def to_json(self, target_uri: str):
return {
"originSelectionRange": self.origin_selection_range.to_json(),
"targetUri": target_uri,
"targetRange": self.target_range.to_json(),
"targetSelectionRange": self.target_selection_range.to_json(),
}
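
The two dataclasses above mirror the LSP DocumentSymbol and LocationLink shapes, with Range handling the index-to-line/column conversion. A small construction sketch over a throwaway string:

    from blueprintcompiler.lsp_utils import DocumentSymbol, LocationLink, SymbolKind
    from blueprintcompiler.tokenizer import Range

    text = "Box box {}"
    whole = Range(0, 10, text)   # the whole object
    ident = Range(4, 7, text)    # the "box" identifier

    symbol = DocumentSymbol("Box", SymbolKind.Object, whole, ident, detail="box")
    link = LocationLink(ident, whole, ident)

    print(symbol.children)                      # [] thanks to field(default_factory=list)
    print(link.to_json("file:///example.blp"))  # originSelectionRange / targetUri / targetRange / ...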

View file

@ -24,8 +24,8 @@ import os
import sys import sys
import typing as T import typing as T
from . import decompiler, interactive_port, parser, tokenizer from . import interactive_port, parser, tokenizer
from .errors import CompilerBugError, MultipleErrors, PrintableError, report_bug from .errors import CompileError, CompilerBugError, PrintableError, report_bug
from .gir import add_typelib_search_path from .gir import add_typelib_search_path
from .lsp import LanguageServer from .lsp import LanguageServer
from .outputs import XmlOutput from .outputs import XmlOutput
@ -157,7 +157,7 @@ class BlueprintApp:
def cmd_port(self, opts): def cmd_port(self, opts):
interactive_port.run(opts) interactive_port.run(opts)
def _compile(self, data: str) -> T.Tuple[str, T.List[PrintableError]]: def _compile(self, data: str) -> T.Tuple[str, T.List[CompileError]]:
tokens = tokenizer.tokenize(data) tokens = tokenizer.tokenize(data)
ast, errors, warnings = parser.parse(tokens) ast, errors, warnings = parser.parse(tokens)

View file

@ -20,7 +20,6 @@
""" Utilities for parsing an AST from a token stream. """ """ Utilities for parsing an AST from a token stream. """
import typing as T import typing as T
from collections import defaultdict
from enum import Enum from enum import Enum
from .ast_utils import AstNode from .ast_utils import AstNode
@ -31,7 +30,7 @@ from .errors import (
UnexpectedTokenError, UnexpectedTokenError,
assert_true, assert_true,
) )
from .tokenizer import Token, TokenType from .tokenizer import Range, Token, TokenType
SKIP_TOKENS = [TokenType.COMMENT, TokenType.WHITESPACE] SKIP_TOKENS = [TokenType.COMMENT, TokenType.WHITESPACE]
@ -63,14 +62,16 @@ class ParseGroup:
be converted to AST nodes by passing the children and key=value pairs to be converted to AST nodes by passing the children and key=value pairs to
the AST node constructor.""" the AST node constructor."""
def __init__(self, ast_type: T.Type[AstNode], start: int): def __init__(self, ast_type: T.Type[AstNode], start: int, text: str):
self.ast_type = ast_type self.ast_type = ast_type
self.children: T.List[ParseGroup] = [] self.children: T.List[ParseGroup] = []
self.keys: T.Dict[str, T.Any] = {} self.keys: T.Dict[str, T.Any] = {}
self.tokens: T.Dict[str, T.Optional[Token]] = {} self.tokens: T.Dict[str, T.Optional[Token]] = {}
self.ranges: T.Dict[str, Range] = {}
self.start = start self.start = start
self.end: T.Optional[int] = None self.end: T.Optional[int] = None
self.incomplete = False self.incomplete = False
self.text = text
def add_child(self, child: "ParseGroup"): def add_child(self, child: "ParseGroup"):
self.children.append(child) self.children.append(child)
@ -80,6 +81,12 @@ class ParseGroup:
self.keys[key] = val self.keys[key] = val
self.tokens[key] = token self.tokens[key] = token
if token:
self.set_range(key, token.range)
def set_range(self, key: str, range: Range):
assert_true(key not in self.ranges)
self.ranges[key] = range
def to_ast(self): def to_ast(self):
"""Creates an AST node from the match group.""" """Creates an AST node from the match group."""
@ -104,8 +111,9 @@ class ParseGroup:
class ParseContext: class ParseContext:
"""Contains the state of the parser.""" """Contains the state of the parser."""
def __init__(self, tokens: T.List[Token], index=0): def __init__(self, tokens: T.List[Token], text: str, index=0):
self.tokens = tokens self.tokens = tokens
self.text = text
self.binding_power = 0 self.binding_power = 0
self.index = index self.index = index
@ -113,6 +121,7 @@ class ParseContext:
self.group: T.Optional[ParseGroup] = None self.group: T.Optional[ParseGroup] = None
self.group_keys: T.Dict[str, T.Tuple[T.Any, T.Optional[Token]]] = {} self.group_keys: T.Dict[str, T.Tuple[T.Any, T.Optional[Token]]] = {}
self.group_children: T.List[ParseGroup] = [] self.group_children: T.List[ParseGroup] = []
self.group_ranges: T.Dict[str, Range] = {}
self.last_group: T.Optional[ParseGroup] = None self.last_group: T.Optional[ParseGroup] = None
self.group_incomplete = False self.group_incomplete = False
@ -124,7 +133,7 @@ class ParseContext:
context will be used to parse one node. If parsing is successful, the context will be used to parse one node. If parsing is successful, the
new context will be applied to "self". If parsing fails, the new new context will be applied to "self". If parsing fails, the new
context will be discarded.""" context will be discarded."""
ctx = ParseContext(self.tokens, self.index) ctx = ParseContext(self.tokens, self.text, self.index)
ctx.errors = self.errors ctx.errors = self.errors
ctx.warnings = self.warnings ctx.warnings = self.warnings
ctx.binding_power = self.binding_power ctx.binding_power = self.binding_power
@ -140,6 +149,8 @@ class ParseContext:
other.group.set_val(key, val, token) other.group.set_val(key, val, token)
for child in other.group_children: for child in other.group_children:
other.group.add_child(child) other.group.add_child(child)
for key, range in other.group_ranges.items():
other.group.set_range(key, range)
other.group.end = other.tokens[other.index - 1].end other.group.end = other.tokens[other.index - 1].end
other.group.incomplete = other.group_incomplete other.group.incomplete = other.group_incomplete
self.group_children.append(other.group) self.group_children.append(other.group)
@ -148,6 +159,7 @@ class ParseContext:
# its matched values # its matched values
self.group_keys = {**self.group_keys, **other.group_keys} self.group_keys = {**self.group_keys, **other.group_keys}
self.group_children += other.group_children self.group_children += other.group_children
self.group_ranges = {**self.group_ranges, **other.group_ranges}
self.group_incomplete |= other.group_incomplete self.group_incomplete |= other.group_incomplete
self.index = other.index self.index = other.index
@ -161,13 +173,19 @@ class ParseContext:
def start_group(self, ast_type: T.Type[AstNode]): def start_group(self, ast_type: T.Type[AstNode]):
"""Sets this context to have its own match group.""" """Sets this context to have its own match group."""
assert_true(self.group is None) assert_true(self.group is None)
self.group = ParseGroup(ast_type, self.tokens[self.index].start) self.group = ParseGroup(ast_type, self.tokens[self.index].start, self.text)
def set_group_val(self, key: str, value: T.Any, token: T.Optional[Token]): def set_group_val(self, key: str, value: T.Any, token: T.Optional[Token]):
"""Sets a matched key=value pair on the current match group.""" """Sets a matched key=value pair on the current match group."""
assert_true(key not in self.group_keys) assert_true(key not in self.group_keys)
self.group_keys[key] = (value, token) self.group_keys[key] = (value, token)
def set_mark(self, key: str):
"""Sets a zero-length range on the current match group at the current position."""
self.group_ranges[key] = Range(
self.tokens[self.index].start, self.tokens[self.index].start, self.text
)
def set_group_incomplete(self): def set_group_incomplete(self):
"""Marks the current match group as incomplete (it could not be fully """Marks the current match group as incomplete (it could not be fully
parsed, but the parser recovered).""" parsed, but the parser recovered)."""
@ -206,11 +224,11 @@ class ParseContext:
if ( if (
len(self.errors) len(self.errors)
and isinstance((err := self.errors[-1]), UnexpectedTokenError) and isinstance((err := self.errors[-1]), UnexpectedTokenError)
and err.end == start and err.range.end == start
): ):
err.end = end err.range.end = end
else: else:
self.errors.append(UnexpectedTokenError(start, end)) self.errors.append(UnexpectedTokenError(Range(start, end, self.text)))
def is_eof(self) -> bool: def is_eof(self) -> bool:
return self.index >= len(self.tokens) or self.peek_token().type == TokenType.EOF return self.index >= len(self.tokens) or self.peek_token().type == TokenType.EOF
@ -263,10 +281,11 @@ class Err(ParseNode):
start_idx = ctx.start start_idx = ctx.start
while ctx.tokens[start_idx].type in SKIP_TOKENS: while ctx.tokens[start_idx].type in SKIP_TOKENS:
start_idx += 1 start_idx += 1
start_token = ctx.tokens[start_idx] start_token = ctx.tokens[start_idx]
end_token = ctx.tokens[ctx.index]
raise CompileError(self.message, start_token.start, end_token.end) raise CompileError(
self.message, Range(start_token.start, start_token.start, ctx.text)
)
return True return True
@ -306,7 +325,9 @@ class Fail(ParseNode):
start_token = ctx.tokens[start_idx] start_token = ctx.tokens[start_idx]
end_token = ctx.tokens[ctx.index] end_token = ctx.tokens[ctx.index]
raise CompileError(self.message, start_token.start, end_token.end) raise CompileError(
self.message, Range.join(start_token.range, end_token.range)
)
return True return True
@ -355,7 +376,7 @@ class Statement(ParseNode):
token = ctx.peek_token() token = ctx.peek_token()
if str(token) != ";": if str(token) != ";":
ctx.errors.append(CompileError("Expected `;`", token.start, token.end)) ctx.errors.append(CompileError("Expected `;`", token.range))
else: else:
ctx.next_token() ctx.next_token()
return True return True
@ -604,6 +625,15 @@ class Keyword(ParseNode):
return str(token) == self.kw return str(token) == self.kw
class Mark(ParseNode):
def __init__(self, key: str):
self.key = key
def _parse(self, ctx: ParseContext):
ctx.set_mark(self.key)
return True
def to_parse_node(value) -> ParseNode: def to_parse_node(value) -> ParseNode:
if isinstance(value, str): if isinstance(value, str):
return Match(value) return Match(value)
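
Mark consumes no tokens; it records a zero-length Range at the current position under a key, and an AST node can later join two marks with self.ranges["a", "b"], as the Signal grammar earlier in this merge request does with detail_start/detail_end. A schematic grammar (not part of the MR) showing the pattern:

    from blueprintcompiler.ast_utils import AstNode
    from blueprintcompiler.parse_tree import Group, Keyword, Mark, UseIdent

    class ExampleNode(AstNode):
        @property
        def detail(self) -> str:
            # Source text between the two marks
            return self.ranges["detail_start", "detail_end"].text

    example_grammar = Group(
        ExampleNode,
        [
            Keyword("example"),
            Mark("detail_start"),
            UseIdent("name"),
            Mark("detail_end"),
        ],
    )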

View file

@ -26,11 +26,12 @@ from .tokenizer import TokenType
def parse( def parse(
tokens: T.List[Token], tokens: T.List[Token],
) -> T.Tuple[T.Optional[UI], T.Optional[MultipleErrors], T.List[PrintableError]]: ) -> T.Tuple[T.Optional[UI], T.Optional[MultipleErrors], T.List[CompileError]]:
"""Parses a list of tokens into an abstract syntax tree.""" """Parses a list of tokens into an abstract syntax tree."""
try: try:
ctx = ParseContext(tokens) original_text = tokens[0].string if len(tokens) else ""
ctx = ParseContext(tokens, original_text)
AnyOf(UI).parse(ctx) AnyOf(UI).parse(ctx)
ast_node = ctx.last_group.to_ast() if ctx.last_group else None ast_node = ctx.last_group.to_ast() if ctx.last_group else None

View file

@ -20,9 +20,10 @@
import re import re
import typing as T import typing as T
from dataclasses import dataclass
from enum import Enum from enum import Enum
from .errors import CompileError, CompilerBugError from . import utils
class TokenType(Enum): class TokenType(Enum):
@ -62,7 +63,13 @@ class Token:
def __str__(self) -> str: def __str__(self) -> str:
return self.string[self.start : self.end] return self.string[self.start : self.end]
@property
def range(self) -> "Range":
return Range(self.start, self.end, self.string)
def get_number(self) -> T.Union[int, float]: def get_number(self) -> T.Union[int, float]:
from .errors import CompileError, CompilerBugError
if self.type != TokenType.NUMBER: if self.type != TokenType.NUMBER:
raise CompilerBugError() raise CompilerBugError()
@ -75,12 +82,12 @@ class Token:
else: else:
return int(string) return int(string)
except: except:
raise CompileError( raise CompileError(f"{str(self)} is not a valid number literal", self.range)
f"{str(self)} is not a valid number literal", self.start, self.end
)
def _tokenize(ui_ml: str): def _tokenize(ui_ml: str):
from .errors import CompileError
i = 0 i = 0
while i < len(ui_ml): while i < len(ui_ml):
matched = False matched = False
@ -95,7 +102,8 @@ def _tokenize(ui_ml: str):
if not matched: if not matched:
raise CompileError( raise CompileError(
"Could not determine what kind of syntax is meant here", i, i "Could not determine what kind of syntax is meant here",
Range(i, i, ui_ml),
) )
yield Token(TokenType.EOF, i, i, ui_ml) yield Token(TokenType.EOF, i, i, ui_ml)
@ -103,3 +111,38 @@ def _tokenize(ui_ml: str):
def tokenize(data: str) -> T.List[Token]: def tokenize(data: str) -> T.List[Token]:
return list(_tokenize(data)) return list(_tokenize(data))
@dataclass
class Range:
start: int
end: int
original_text: str
@property
def length(self) -> int:
return self.end - self.start
@property
def text(self) -> str:
return self.original_text[self.start : self.end]
@staticmethod
def join(a: T.Optional["Range"], b: T.Optional["Range"]) -> T.Optional["Range"]:
if a is None:
return b
if b is None:
return a
return Range(min(a.start, b.start), max(a.end, b.end), a.original_text)
def __contains__(self, other: T.Union[int, "Range"]) -> bool:
if isinstance(other, int):
return self.start <= other <= self.end
else:
return self.start <= other.start and self.end >= other.end
def to_json(self):
return utils.idxs_to_range(self.start, self.end, self.original_text)
def overlaps(self, other: "Range") -> bool:
return not (self.end < other.start or self.start > other.end)
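
The new Range type keeps a reference to the original source text, so spans can be sliced, joined, tested for containment or overlap, and converted to LSP line/character positions. A short sketch:

    from blueprintcompiler.tokenizer import Range

    text = "using Gtk 4.0;"
    using = Range(0, 5, text)  # "using"
    gtk = Range(6, 9, text)    # "Gtk"

    both = Range.join(using, gtk)
    print(both.text, both.length)  # "using Gtk" 9
    print(7 in gtk, gtk in both)   # True True
    print(using.overlaps(gtk))     # False, the spans do not touch
    print(both.to_json())          # LSP-style {"start": {...}, "end": {...}}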

View file

@ -40,13 +40,12 @@ from blueprintcompiler.tokenizer import Token, TokenType, tokenize
class TestSamples(unittest.TestCase): class TestSamples(unittest.TestCase):
def assert_docs_dont_crash(self, text, ast): def assert_ast_doesnt_crash(self, text, tokens, ast):
for i in range(len(text)): for i in range(len(text)):
ast.get_docs(i) ast.get_docs(i)
def assert_completions_dont_crash(self, text, ast, tokens):
for i in range(len(text)): for i in range(len(text)):
list(complete(ast, tokens, i)) list(complete(ast, tokens, i))
ast.get_document_symbols()
def assert_sample(self, name, skip_run=False): def assert_sample(self, name, skip_run=False):
print(f'assert_sample("{name}", skip_run={skip_run})') print(f'assert_sample("{name}", skip_run={skip_run})')
@ -79,8 +78,7 @@ class TestSamples(unittest.TestCase):
print("\n".join(diff)) print("\n".join(diff))
raise AssertionError() raise AssertionError()
self.assert_docs_dont_crash(blueprint, ast) self.assert_ast_doesnt_crash(blueprint, tokens, ast)
self.assert_completions_dont_crash(blueprint, ast, tokens)
except PrintableError as e: # pragma: no cover except PrintableError as e: # pragma: no cover
e.pretty_print(name + ".blp", blueprint) e.pretty_print(name + ".blp", blueprint)
raise AssertionError() raise AssertionError()
@ -105,8 +103,7 @@ class TestSamples(unittest.TestCase):
ast, errors, warnings = parser.parse(tokens) ast, errors, warnings = parser.parse(tokens)
if ast is not None: if ast is not None:
self.assert_docs_dont_crash(blueprint, ast) self.assert_ast_doesnt_crash(blueprint, tokens, ast)
self.assert_completions_dont_crash(blueprint, ast, tokens)
if errors: if errors:
raise errors raise errors
@ -116,9 +113,9 @@ class TestSamples(unittest.TestCase):
raise MultipleErrors(warnings) raise MultipleErrors(warnings)
except PrintableError as e: except PrintableError as e:
def error_str(error): def error_str(error: CompileError):
line, col = utils.idx_to_pos(error.start + 1, blueprint) line, col = utils.idx_to_pos(error.range.start + 1, blueprint)
len = error.end - error.start len = error.range.length
return ",".join([str(line + 1), str(col), str(len), error.message]) return ",".join([str(line + 1), str(col), str(len), error.message])
if isinstance(e, CompileError): if isinstance(e, CompileError):