Use the new Range class in more places

This commit is contained in:
James Westman 2023-07-25 20:01:41 -05:00
parent 56274d7c1f
commit 3bcc9f4cbd
10 changed files with 91 additions and 90 deletions

View file

@ -229,20 +229,23 @@ class AstNode:
error, error,
references=[ references=[
ErrorReference( ErrorReference(
child.group.start, child.range,
child.group.end,
"previous declaration was here", "previous declaration was here",
) )
], ],
) )
def validate(token_name=None, end_token_name=None, skip_incomplete=False): def validate(
token_name: T.Optional[str] = None,
end_token_name: T.Optional[str] = None,
skip_incomplete: bool = False,
):
"""Decorator for functions that validate an AST node. Exceptions raised """Decorator for functions that validate an AST node. Exceptions raised
during validation are marked with range information from the tokens.""" during validation are marked with range information from the tokens."""
def decorator(func): def decorator(func):
def inner(self): def inner(self: AstNode):
if skip_incomplete and self.incomplete: if skip_incomplete and self.incomplete:
return return
@ -254,22 +257,14 @@ def validate(token_name=None, end_token_name=None, skip_incomplete=False):
if self.incomplete: if self.incomplete:
return return
# This mess of code sets the error's start and end positions if e.range is None:
# from the tokens passed to the decorator, if they have not e.range = (
# already been set Range.join(
if e.start is None: self.ranges[token_name],
if token := self.group.tokens.get(token_name): self.ranges[end_token_name],
e.start = token.start )
else: or self.range
e.start = self.group.start )
if e.end is None:
if token := self.group.tokens.get(end_token_name):
e.end = token.end
elif token := self.group.tokens.get(token_name):
e.end = token.end
else:
e.end = self.group.end
# Re-raise the exception # Re-raise the exception
raise e raise e

View file

@ -23,6 +23,7 @@ import typing as T
from dataclasses import dataclass from dataclasses import dataclass
from . import utils from . import utils
from .tokenizer import Range
from .utils import Colors from .utils import Colors
@ -36,8 +37,7 @@ class PrintableError(Exception):
@dataclass @dataclass
class ErrorReference: class ErrorReference:
start: int range: Range
end: int
message: str message: str
@ -50,8 +50,7 @@ class CompileError(PrintableError):
def __init__( def __init__(
self, self,
message: str, message: str,
start: T.Optional[int] = None, range: T.Optional[Range] = None,
end: T.Optional[int] = None,
did_you_mean: T.Optional[T.Tuple[str, T.List[str]]] = None, did_you_mean: T.Optional[T.Tuple[str, T.List[str]]] = None,
hints: T.Optional[T.List[str]] = None, hints: T.Optional[T.List[str]] = None,
actions: T.Optional[T.List["CodeAction"]] = None, actions: T.Optional[T.List["CodeAction"]] = None,
@ -61,8 +60,7 @@ class CompileError(PrintableError):
super().__init__(message) super().__init__(message)
self.message = message self.message = message
self.start = start self.range = range
self.end = end
self.hints = hints or [] self.hints = hints or []
self.actions = actions or [] self.actions = actions or []
self.references = references or [] self.references = references or []
@ -92,9 +90,9 @@ class CompileError(PrintableError):
self.hint("Are your dependencies up to date?") self.hint("Are your dependencies up to date?")
def pretty_print(self, filename: str, code: str, stream=sys.stdout) -> None: def pretty_print(self, filename: str, code: str, stream=sys.stdout) -> None:
assert self.start is not None assert self.range is not None
line_num, col_num = utils.idx_to_pos(self.start + 1, code) line_num, col_num = utils.idx_to_pos(self.range.start + 1, code)
line = code.splitlines(True)[line_num] line = code.splitlines(True)[line_num]
# Display 1-based line numbers # Display 1-based line numbers
@ -110,7 +108,7 @@ at {filename} line {line_num} column {col_num}:
stream.write(f"{Colors.FAINT}hint: {hint}{Colors.CLEAR}\n") stream.write(f"{Colors.FAINT}hint: {hint}{Colors.CLEAR}\n")
for ref in self.references: for ref in self.references:
line_num, col_num = utils.idx_to_pos(ref.start + 1, code) line_num, col_num = utils.idx_to_pos(ref.range.start + 1, code)
line = code.splitlines(True)[line_num] line = code.splitlines(True)[line_num]
line_num += 1 line_num += 1
@ -138,14 +136,15 @@ class UpgradeWarning(CompileWarning):
class UnexpectedTokenError(CompileError): class UnexpectedTokenError(CompileError):
def __init__(self, start, end) -> None: def __init__(self, range: Range) -> None:
super().__init__("Unexpected tokens", start, end) super().__init__("Unexpected tokens", range)
@dataclass @dataclass
class CodeAction: class CodeAction:
title: str title: str
replace_with: str replace_with: str
edit_range: T.Optional[Range] = None
class MultipleErrors(PrintableError): class MultipleErrors(PrintableError):

View file

@ -70,8 +70,7 @@ class ScopeCtx:
): ):
raise CompileError( raise CompileError(
f"Duplicate object ID '{obj.tokens['id']}'", f"Duplicate object ID '{obj.tokens['id']}'",
token.start, token.range,
token.end,
) )
passed[obj.tokens["id"]] = obj passed[obj.tokens["id"]] = obj

View file

@ -62,8 +62,7 @@ class UI(AstNode):
else: else:
gir_ctx.not_found_namespaces.add(i.namespace) gir_ctx.not_found_namespaces.add(i.namespace)
except CompileError as e: except CompileError as e:
e.start = i.group.tokens["namespace"].start e.range = i.range
e.end = i.group.tokens["version"].end
self._gir_errors.append(e) self._gir_errors.append(e)
return gir_ctx return gir_ctx

View file

@ -24,10 +24,12 @@ import traceback
import typing as T import typing as T
from . import decompiler, parser, tokenizer, utils, xml_reader from . import decompiler, parser, tokenizer, utils, xml_reader
from .ast_utils import AstNode
from .completions import complete from .completions import complete
from .errors import CompileError, MultipleErrors, PrintableError from .errors import CompileError, MultipleErrors
from .lsp_utils import * from .lsp_utils import *
from .outputs.xml import XmlOutput from .outputs.xml import XmlOutput
from .tokenizer import Token
def printerr(*args, **kwargs): def printerr(*args, **kwargs):
@ -43,16 +45,16 @@ def command(json_method: str):
class OpenFile: class OpenFile:
def __init__(self, uri: str, text: str, version: int): def __init__(self, uri: str, text: str, version: int) -> None:
self.uri = uri self.uri = uri
self.text = text self.text = text
self.version = version self.version = version
self.ast = None self.ast: T.Optional[AstNode] = None
self.tokens = None self.tokens: T.Optional[list[Token]] = None
self._update() self._update()
def apply_changes(self, changes): def apply_changes(self, changes) -> None:
for change in changes: for change in changes:
if "range" not in change: if "range" not in change:
self.text = change["text"] self.text = change["text"]
@ -70,8 +72,8 @@ class OpenFile:
self.text = self.text[:start] + change["text"] + self.text[end:] self.text = self.text[:start] + change["text"] + self.text[end:]
self._update() self._update()
def _update(self): def _update(self) -> None:
self.diagnostics = [] self.diagnostics: list[CompileError] = []
try: try:
self.tokens = tokenizer.tokenize(self.text) self.tokens = tokenizer.tokenize(self.text)
self.ast, errors, warnings = parser.parse(self.tokens) self.ast, errors, warnings = parser.parse(self.tokens)
@ -327,31 +329,32 @@ class LanguageServer:
def code_actions(self, id, params): def code_actions(self, id, params):
open_file = self._open_files[params["textDocument"]["uri"]] open_file = self._open_files[params["textDocument"]["uri"]]
range_start = utils.pos_to_idx( range = Range(
utils.pos_to_idx(
params["range"]["start"]["line"], params["range"]["start"]["line"],
params["range"]["start"]["character"], params["range"]["start"]["character"],
open_file.text, open_file.text,
) ),
range_end = utils.pos_to_idx( utils.pos_to_idx(
params["range"]["end"]["line"], params["range"]["end"]["line"],
params["range"]["end"]["character"], params["range"]["end"]["character"],
open_file.text, open_file.text,
),
open_file.text,
) )
actions = [ actions = [
{ {
"title": action.title, "title": action.title,
"kind": "quickfix", "kind": "quickfix",
"diagnostics": [ "diagnostics": [self._create_diagnostic(open_file.uri, diagnostic)],
self._create_diagnostic(open_file.text, open_file.uri, diagnostic)
],
"edit": { "edit": {
"changes": { "changes": {
open_file.uri: [ open_file.uri: [
{ {
"range": utils.idxs_to_range( "range": action.edit_range.to_json()
diagnostic.start, diagnostic.end, open_file.text if action.edit_range
), else diagnostic.range.to_json(),
"newText": action.replace_with, "newText": action.replace_with,
} }
] ]
@ -359,7 +362,7 @@ class LanguageServer:
}, },
} }
for diagnostic in open_file.diagnostics for diagnostic in open_file.diagnostics
if not (diagnostic.end < range_start or diagnostic.start > range_end) if range.overlaps(diagnostic.range)
for action in diagnostic.actions for action in diagnostic.actions
] ]
@ -374,14 +377,8 @@ class LanguageServer:
result = { result = {
"name": symbol.name, "name": symbol.name,
"kind": symbol.kind, "kind": symbol.kind,
"range": utils.idxs_to_range( "range": symbol.range.to_json(),
symbol.range.start, symbol.range.end, open_file.text "selectionRange": symbol.selection_range.to_json(),
),
"selectionRange": utils.idxs_to_range(
symbol.selection_range.start,
symbol.selection_range.end,
open_file.text,
),
"children": [to_json(child) for child in symbol.children], "children": [to_json(child) for child in symbol.children],
} }
if symbol.detail is not None: if symbol.detail is not None:
@ -411,22 +408,22 @@ class LanguageServer:
{ {
"uri": open_file.uri, "uri": open_file.uri,
"diagnostics": [ "diagnostics": [
self._create_diagnostic(open_file.text, open_file.uri, err) self._create_diagnostic(open_file.uri, err)
for err in open_file.diagnostics for err in open_file.diagnostics
], ],
}, },
) )
def _create_diagnostic(self, text: str, uri: str, err: CompileError): def _create_diagnostic(self, uri: str, err: CompileError):
message = err.message message = err.message
assert err.start is not None and err.end is not None assert err.range is not None
for hint in err.hints: for hint in err.hints:
message += "\nhint: " + hint message += "\nhint: " + hint
result = { result = {
"range": utils.idxs_to_range(err.start, err.end, text), "range": err.range.to_json(),
"message": message, "message": message,
"severity": DiagnosticSeverity.Warning "severity": DiagnosticSeverity.Warning
if isinstance(err, CompileWarning) if isinstance(err, CompileWarning)
@ -441,7 +438,7 @@ class LanguageServer:
{ {
"location": { "location": {
"uri": uri, "uri": uri,
"range": utils.idxs_to_range(ref.start, ref.end, text), "range": ref.range.to_json(),
}, },
"message": ref.message, "message": ref.message,
} }

View file

@ -24,8 +24,8 @@ import os
import sys import sys
import typing as T import typing as T
from . import decompiler, interactive_port, parser, tokenizer from . import interactive_port, parser, tokenizer
from .errors import CompilerBugError, MultipleErrors, PrintableError, report_bug from .errors import CompilerBugError, CompileError, PrintableError, report_bug
from .gir import add_typelib_search_path from .gir import add_typelib_search_path
from .lsp import LanguageServer from .lsp import LanguageServer
from .outputs import XmlOutput from .outputs import XmlOutput
@ -157,7 +157,7 @@ class BlueprintApp:
def cmd_port(self, opts): def cmd_port(self, opts):
interactive_port.run(opts) interactive_port.run(opts)
def _compile(self, data: str) -> T.Tuple[str, T.List[PrintableError]]: def _compile(self, data: str) -> T.Tuple[str, T.List[CompileError]]:
tokens = tokenizer.tokenize(data) tokens = tokenizer.tokenize(data)
ast, errors, warnings = parser.parse(tokens) ast, errors, warnings = parser.parse(tokens)

View file

@ -224,11 +224,11 @@ class ParseContext:
if ( if (
len(self.errors) len(self.errors)
and isinstance((err := self.errors[-1]), UnexpectedTokenError) and isinstance((err := self.errors[-1]), UnexpectedTokenError)
and err.end == start and err.range.end == start
): ):
err.end = end err.range.end = end
else: else:
self.errors.append(UnexpectedTokenError(start, end)) self.errors.append(UnexpectedTokenError(Range(start, end, self.text)))
def is_eof(self) -> bool: def is_eof(self) -> bool:
return self.index >= len(self.tokens) or self.peek_token().type == TokenType.EOF return self.index >= len(self.tokens) or self.peek_token().type == TokenType.EOF
@ -281,10 +281,11 @@ class Err(ParseNode):
start_idx = ctx.start start_idx = ctx.start
while ctx.tokens[start_idx].type in SKIP_TOKENS: while ctx.tokens[start_idx].type in SKIP_TOKENS:
start_idx += 1 start_idx += 1
start_token = ctx.tokens[start_idx] start_token = ctx.tokens[start_idx]
end_token = ctx.tokens[ctx.index]
raise CompileError(self.message, start_token.start, end_token.end) raise CompileError(
self.message, Range(start_token.start, start_token.start, ctx.text)
)
return True return True
@ -324,7 +325,9 @@ class Fail(ParseNode):
start_token = ctx.tokens[start_idx] start_token = ctx.tokens[start_idx]
end_token = ctx.tokens[ctx.index] end_token = ctx.tokens[ctx.index]
raise CompileError(self.message, start_token.start, end_token.end) raise CompileError(
self.message, Range.join(start_token.range, end_token.range)
)
return True return True
@ -373,7 +376,7 @@ class Statement(ParseNode):
token = ctx.peek_token() token = ctx.peek_token()
if str(token) != ";": if str(token) != ";":
ctx.errors.append(CompileError("Expected `;`", token.start, token.end)) ctx.errors.append(CompileError("Expected `;`", token.range))
else: else:
ctx.next_token() ctx.next_token()
return True return True

View file

@ -26,7 +26,7 @@ from .tokenizer import TokenType
def parse( def parse(
tokens: T.List[Token], tokens: T.List[Token],
) -> T.Tuple[T.Optional[UI], T.Optional[MultipleErrors], T.List[PrintableError]]: ) -> T.Tuple[T.Optional[UI], T.Optional[MultipleErrors], T.List[CompileError]]:
"""Parses a list of tokens into an abstract syntax tree.""" """Parses a list of tokens into an abstract syntax tree."""
try: try:

View file

@ -23,7 +23,6 @@ import typing as T
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from .errors import CompileError, CompilerBugError
from . import utils from . import utils
@ -69,6 +68,8 @@ class Token:
return Range(self.start, self.end, self.string) return Range(self.start, self.end, self.string)
def get_number(self) -> T.Union[int, float]: def get_number(self) -> T.Union[int, float]:
from .errors import CompileError, CompilerBugError
if self.type != TokenType.NUMBER: if self.type != TokenType.NUMBER:
raise CompilerBugError() raise CompilerBugError()
@ -81,12 +82,12 @@ class Token:
else: else:
return int(string) return int(string)
except: except:
raise CompileError( raise CompileError(f"{str(self)} is not a valid number literal", self.range)
f"{str(self)} is not a valid number literal", self.start, self.end
)
def _tokenize(ui_ml: str): def _tokenize(ui_ml: str):
from .errors import CompileError
i = 0 i = 0
while i < len(ui_ml): while i < len(ui_ml):
matched = False matched = False
@ -101,7 +102,8 @@ def _tokenize(ui_ml: str):
if not matched: if not matched:
raise CompileError( raise CompileError(
"Could not determine what kind of syntax is meant here", i, i "Could not determine what kind of syntax is meant here",
Range(i, i, ui_ml),
) )
yield Token(TokenType.EOF, i, i, ui_ml) yield Token(TokenType.EOF, i, i, ui_ml)
@ -117,6 +119,10 @@ class Range:
end: int end: int
original_text: str original_text: str
@property
def length(self) -> int:
return self.end - self.start
@property @property
def text(self) -> str: def text(self) -> str:
return self.original_text[self.start : self.end] return self.original_text[self.start : self.end]
@ -137,3 +143,6 @@ class Range:
def to_json(self): def to_json(self):
return utils.idxs_to_range(self.start, self.end, self.original_text) return utils.idxs_to_range(self.start, self.end, self.original_text)
def overlaps(self, other: "Range") -> bool:
return not (self.end < other.start or self.start > other.end)

View file

@ -113,9 +113,9 @@ class TestSamples(unittest.TestCase):
raise MultipleErrors(warnings) raise MultipleErrors(warnings)
except PrintableError as e: except PrintableError as e:
def error_str(error): def error_str(error: CompileError):
line, col = utils.idx_to_pos(error.start + 1, blueprint) line, col = utils.idx_to_pos(error.range.start + 1, blueprint)
len = error.end - error.start len = error.range.length
return ",".join([str(line + 1), str(col), str(len), error.message]) return ",".join([str(line + 1), str(col), str(len), error.message])
if isinstance(e, CompileError): if isinstance(e, CompileError):