Use the new Range class in more places

This commit is contained in:
James Westman 2023-07-25 20:01:41 -05:00
parent 56274d7c1f
commit 3bcc9f4cbd
10 changed files with 91 additions and 90 deletions

View file

@ -229,20 +229,23 @@ class AstNode:
error,
references=[
ErrorReference(
child.group.start,
child.group.end,
child.range,
"previous declaration was here",
)
],
)
def validate(token_name=None, end_token_name=None, skip_incomplete=False):
def validate(
token_name: T.Optional[str] = None,
end_token_name: T.Optional[str] = None,
skip_incomplete: bool = False,
):
"""Decorator for functions that validate an AST node. Exceptions raised
during validation are marked with range information from the tokens."""
def decorator(func):
def inner(self):
def inner(self: AstNode):
if skip_incomplete and self.incomplete:
return
@ -254,22 +257,14 @@ def validate(token_name=None, end_token_name=None, skip_incomplete=False):
if self.incomplete:
return
# This mess of code sets the error's start and end positions
# from the tokens passed to the decorator, if they have not
# already been set
if e.start is None:
if token := self.group.tokens.get(token_name):
e.start = token.start
else:
e.start = self.group.start
if e.end is None:
if token := self.group.tokens.get(end_token_name):
e.end = token.end
elif token := self.group.tokens.get(token_name):
e.end = token.end
else:
e.end = self.group.end
if e.range is None:
e.range = (
Range.join(
self.ranges[token_name],
self.ranges[end_token_name],
)
or self.range
)
# Re-raise the exception
raise e

View file

@ -23,6 +23,7 @@ import typing as T
from dataclasses import dataclass
from . import utils
from .tokenizer import Range
from .utils import Colors
@ -36,8 +37,7 @@ class PrintableError(Exception):
@dataclass
class ErrorReference:
start: int
end: int
range: Range
message: str
@ -50,8 +50,7 @@ class CompileError(PrintableError):
def __init__(
self,
message: str,
start: T.Optional[int] = None,
end: T.Optional[int] = None,
range: T.Optional[Range] = None,
did_you_mean: T.Optional[T.Tuple[str, T.List[str]]] = None,
hints: T.Optional[T.List[str]] = None,
actions: T.Optional[T.List["CodeAction"]] = None,
@ -61,8 +60,7 @@ class CompileError(PrintableError):
super().__init__(message)
self.message = message
self.start = start
self.end = end
self.range = range
self.hints = hints or []
self.actions = actions or []
self.references = references or []
@ -92,9 +90,9 @@ class CompileError(PrintableError):
self.hint("Are your dependencies up to date?")
def pretty_print(self, filename: str, code: str, stream=sys.stdout) -> None:
assert self.start is not None
assert self.range is not None
line_num, col_num = utils.idx_to_pos(self.start + 1, code)
line_num, col_num = utils.idx_to_pos(self.range.start + 1, code)
line = code.splitlines(True)[line_num]
# Display 1-based line numbers
@ -110,7 +108,7 @@ at {filename} line {line_num} column {col_num}:
stream.write(f"{Colors.FAINT}hint: {hint}{Colors.CLEAR}\n")
for ref in self.references:
line_num, col_num = utils.idx_to_pos(ref.start + 1, code)
line_num, col_num = utils.idx_to_pos(ref.range.start + 1, code)
line = code.splitlines(True)[line_num]
line_num += 1
@ -138,14 +136,15 @@ class UpgradeWarning(CompileWarning):
class UnexpectedTokenError(CompileError):
def __init__(self, start, end) -> None:
super().__init__("Unexpected tokens", start, end)
def __init__(self, range: Range) -> None:
super().__init__("Unexpected tokens", range)
@dataclass
class CodeAction:
title: str
replace_with: str
edit_range: T.Optional[Range] = None
class MultipleErrors(PrintableError):

View file

@ -70,8 +70,7 @@ class ScopeCtx:
):
raise CompileError(
f"Duplicate object ID '{obj.tokens['id']}'",
token.start,
token.end,
token.range,
)
passed[obj.tokens["id"]] = obj

View file

@ -62,8 +62,7 @@ class UI(AstNode):
else:
gir_ctx.not_found_namespaces.add(i.namespace)
except CompileError as e:
e.start = i.group.tokens["namespace"].start
e.end = i.group.tokens["version"].end
e.range = i.range
self._gir_errors.append(e)
return gir_ctx

View file

@ -24,10 +24,12 @@ import traceback
import typing as T
from . import decompiler, parser, tokenizer, utils, xml_reader
from .ast_utils import AstNode
from .completions import complete
from .errors import CompileError, MultipleErrors, PrintableError
from .errors import CompileError, MultipleErrors
from .lsp_utils import *
from .outputs.xml import XmlOutput
from .tokenizer import Token
def printerr(*args, **kwargs):
@ -43,16 +45,16 @@ def command(json_method: str):
class OpenFile:
def __init__(self, uri: str, text: str, version: int):
def __init__(self, uri: str, text: str, version: int) -> None:
self.uri = uri
self.text = text
self.version = version
self.ast = None
self.tokens = None
self.ast: T.Optional[AstNode] = None
self.tokens: T.Optional[list[Token]] = None
self._update()
def apply_changes(self, changes):
def apply_changes(self, changes) -> None:
for change in changes:
if "range" not in change:
self.text = change["text"]
@ -70,8 +72,8 @@ class OpenFile:
self.text = self.text[:start] + change["text"] + self.text[end:]
self._update()
def _update(self):
self.diagnostics = []
def _update(self) -> None:
self.diagnostics: list[CompileError] = []
try:
self.tokens = tokenizer.tokenize(self.text)
self.ast, errors, warnings = parser.parse(self.tokens)
@ -327,31 +329,32 @@ class LanguageServer:
def code_actions(self, id, params):
open_file = self._open_files[params["textDocument"]["uri"]]
range_start = utils.pos_to_idx(
range = Range(
utils.pos_to_idx(
params["range"]["start"]["line"],
params["range"]["start"]["character"],
open_file.text,
)
range_end = utils.pos_to_idx(
),
utils.pos_to_idx(
params["range"]["end"]["line"],
params["range"]["end"]["character"],
open_file.text,
),
open_file.text,
)
actions = [
{
"title": action.title,
"kind": "quickfix",
"diagnostics": [
self._create_diagnostic(open_file.text, open_file.uri, diagnostic)
],
"diagnostics": [self._create_diagnostic(open_file.uri, diagnostic)],
"edit": {
"changes": {
open_file.uri: [
{
"range": utils.idxs_to_range(
diagnostic.start, diagnostic.end, open_file.text
),
"range": action.edit_range.to_json()
if action.edit_range
else diagnostic.range.to_json(),
"newText": action.replace_with,
}
]
@ -359,7 +362,7 @@ class LanguageServer:
},
}
for diagnostic in open_file.diagnostics
if not (diagnostic.end < range_start or diagnostic.start > range_end)
if range.overlaps(diagnostic.range)
for action in diagnostic.actions
]
@ -374,14 +377,8 @@ class LanguageServer:
result = {
"name": symbol.name,
"kind": symbol.kind,
"range": utils.idxs_to_range(
symbol.range.start, symbol.range.end, open_file.text
),
"selectionRange": utils.idxs_to_range(
symbol.selection_range.start,
symbol.selection_range.end,
open_file.text,
),
"range": symbol.range.to_json(),
"selectionRange": symbol.selection_range.to_json(),
"children": [to_json(child) for child in symbol.children],
}
if symbol.detail is not None:
@ -411,22 +408,22 @@ class LanguageServer:
{
"uri": open_file.uri,
"diagnostics": [
self._create_diagnostic(open_file.text, open_file.uri, err)
self._create_diagnostic(open_file.uri, err)
for err in open_file.diagnostics
],
},
)
def _create_diagnostic(self, text: str, uri: str, err: CompileError):
def _create_diagnostic(self, uri: str, err: CompileError):
message = err.message
assert err.start is not None and err.end is not None
assert err.range is not None
for hint in err.hints:
message += "\nhint: " + hint
result = {
"range": utils.idxs_to_range(err.start, err.end, text),
"range": err.range.to_json(),
"message": message,
"severity": DiagnosticSeverity.Warning
if isinstance(err, CompileWarning)
@ -441,7 +438,7 @@ class LanguageServer:
{
"location": {
"uri": uri,
"range": utils.idxs_to_range(ref.start, ref.end, text),
"range": ref.range.to_json(),
},
"message": ref.message,
}

View file

@ -24,8 +24,8 @@ import os
import sys
import typing as T
from . import decompiler, interactive_port, parser, tokenizer
from .errors import CompilerBugError, MultipleErrors, PrintableError, report_bug
from . import interactive_port, parser, tokenizer
from .errors import CompilerBugError, CompileError, PrintableError, report_bug
from .gir import add_typelib_search_path
from .lsp import LanguageServer
from .outputs import XmlOutput
@ -157,7 +157,7 @@ class BlueprintApp:
def cmd_port(self, opts):
interactive_port.run(opts)
def _compile(self, data: str) -> T.Tuple[str, T.List[PrintableError]]:
def _compile(self, data: str) -> T.Tuple[str, T.List[CompileError]]:
tokens = tokenizer.tokenize(data)
ast, errors, warnings = parser.parse(tokens)

View file

@ -224,11 +224,11 @@ class ParseContext:
if (
len(self.errors)
and isinstance((err := self.errors[-1]), UnexpectedTokenError)
and err.end == start
and err.range.end == start
):
err.end = end
err.range.end = end
else:
self.errors.append(UnexpectedTokenError(start, end))
self.errors.append(UnexpectedTokenError(Range(start, end, self.text)))
def is_eof(self) -> bool:
return self.index >= len(self.tokens) or self.peek_token().type == TokenType.EOF
@ -281,10 +281,11 @@ class Err(ParseNode):
start_idx = ctx.start
while ctx.tokens[start_idx].type in SKIP_TOKENS:
start_idx += 1
start_token = ctx.tokens[start_idx]
end_token = ctx.tokens[ctx.index]
raise CompileError(self.message, start_token.start, end_token.end)
raise CompileError(
self.message, Range(start_token.start, start_token.start, ctx.text)
)
return True
@ -324,7 +325,9 @@ class Fail(ParseNode):
start_token = ctx.tokens[start_idx]
end_token = ctx.tokens[ctx.index]
raise CompileError(self.message, start_token.start, end_token.end)
raise CompileError(
self.message, Range.join(start_token.range, end_token.range)
)
return True
@ -373,7 +376,7 @@ class Statement(ParseNode):
token = ctx.peek_token()
if str(token) != ";":
ctx.errors.append(CompileError("Expected `;`", token.start, token.end))
ctx.errors.append(CompileError("Expected `;`", token.range))
else:
ctx.next_token()
return True

View file

@ -26,7 +26,7 @@ from .tokenizer import TokenType
def parse(
tokens: T.List[Token],
) -> T.Tuple[T.Optional[UI], T.Optional[MultipleErrors], T.List[PrintableError]]:
) -> T.Tuple[T.Optional[UI], T.Optional[MultipleErrors], T.List[CompileError]]:
"""Parses a list of tokens into an abstract syntax tree."""
try:

View file

@ -23,7 +23,6 @@ import typing as T
from dataclasses import dataclass
from enum import Enum
from .errors import CompileError, CompilerBugError
from . import utils
@ -69,6 +68,8 @@ class Token:
return Range(self.start, self.end, self.string)
def get_number(self) -> T.Union[int, float]:
from .errors import CompileError, CompilerBugError
if self.type != TokenType.NUMBER:
raise CompilerBugError()
@ -81,12 +82,12 @@ class Token:
else:
return int(string)
except:
raise CompileError(
f"{str(self)} is not a valid number literal", self.start, self.end
)
raise CompileError(f"{str(self)} is not a valid number literal", self.range)
def _tokenize(ui_ml: str):
from .errors import CompileError
i = 0
while i < len(ui_ml):
matched = False
@ -101,7 +102,8 @@ def _tokenize(ui_ml: str):
if not matched:
raise CompileError(
"Could not determine what kind of syntax is meant here", i, i
"Could not determine what kind of syntax is meant here",
Range(i, i, ui_ml),
)
yield Token(TokenType.EOF, i, i, ui_ml)
@ -117,6 +119,10 @@ class Range:
end: int
original_text: str
@property
def length(self) -> int:
return self.end - self.start
@property
def text(self) -> str:
return self.original_text[self.start : self.end]
@ -137,3 +143,6 @@ class Range:
def to_json(self):
return utils.idxs_to_range(self.start, self.end, self.original_text)
def overlaps(self, other: "Range") -> bool:
return not (self.end < other.start or self.start > other.end)

View file

@ -113,9 +113,9 @@ class TestSamples(unittest.TestCase):
raise MultipleErrors(warnings)
except PrintableError as e:
def error_str(error):
line, col = utils.idx_to_pos(error.start + 1, blueprint)
len = error.end - error.start
def error_str(error: CompileError):
line, col = utils.idx_to_pos(error.range.start + 1, blueprint)
len = error.range.length
return ",".join([str(line + 1), str(col), str(len), error.message])
if isinstance(e, CompileError):