From 3bcc9f4cbd2158fa7b90fb3280dbb1b7e03ddc33 Mon Sep 17 00:00:00 2001 From: James Westman Date: Tue, 25 Jul 2023 20:01:41 -0500 Subject: [PATCH] Use the new Range class in more places --- blueprintcompiler/ast_utils.py | 35 ++++++-------- blueprintcompiler/errors.py | 21 ++++---- blueprintcompiler/language/contexts.py | 3 +- blueprintcompiler/language/ui.py | 3 +- blueprintcompiler/lsp.py | 67 ++++++++++++-------------- blueprintcompiler/main.py | 6 +-- blueprintcompiler/parse_tree.py | 19 +++++--- blueprintcompiler/parser.py | 2 +- blueprintcompiler/tokenizer.py | 19 ++++++-- tests/test_samples.py | 6 +-- 10 files changed, 91 insertions(+), 90 deletions(-) diff --git a/blueprintcompiler/ast_utils.py b/blueprintcompiler/ast_utils.py index 81958aa..56501e7 100644 --- a/blueprintcompiler/ast_utils.py +++ b/blueprintcompiler/ast_utils.py @@ -229,20 +229,23 @@ class AstNode: error, references=[ ErrorReference( - child.group.start, - child.group.end, + child.range, "previous declaration was here", ) ], ) -def validate(token_name=None, end_token_name=None, skip_incomplete=False): +def validate( + token_name: T.Optional[str] = None, + end_token_name: T.Optional[str] = None, + skip_incomplete: bool = False, +): """Decorator for functions that validate an AST node. Exceptions raised during validation are marked with range information from the tokens.""" def decorator(func): - def inner(self): + def inner(self: AstNode): if skip_incomplete and self.incomplete: return @@ -254,22 +257,14 @@ def validate(token_name=None, end_token_name=None, skip_incomplete=False): if self.incomplete: return - # This mess of code sets the error's start and end positions - # from the tokens passed to the decorator, if they have not - # already been set - if e.start is None: - if token := self.group.tokens.get(token_name): - e.start = token.start - else: - e.start = self.group.start - - if e.end is None: - if token := self.group.tokens.get(end_token_name): - e.end = token.end - elif token := self.group.tokens.get(token_name): - e.end = token.end - else: - e.end = self.group.end + if e.range is None: + e.range = ( + Range.join( + self.ranges[token_name], + self.ranges[end_token_name], + ) + or self.range + ) # Re-raise the exception raise e diff --git a/blueprintcompiler/errors.py b/blueprintcompiler/errors.py index e89ec31..773122a 100644 --- a/blueprintcompiler/errors.py +++ b/blueprintcompiler/errors.py @@ -23,6 +23,7 @@ import typing as T from dataclasses import dataclass from . import utils +from .tokenizer import Range from .utils import Colors @@ -36,8 +37,7 @@ class PrintableError(Exception): @dataclass class ErrorReference: - start: int - end: int + range: Range message: str @@ -50,8 +50,7 @@ class CompileError(PrintableError): def __init__( self, message: str, - start: T.Optional[int] = None, - end: T.Optional[int] = None, + range: T.Optional[Range] = None, did_you_mean: T.Optional[T.Tuple[str, T.List[str]]] = None, hints: T.Optional[T.List[str]] = None, actions: T.Optional[T.List["CodeAction"]] = None, @@ -61,8 +60,7 @@ class CompileError(PrintableError): super().__init__(message) self.message = message - self.start = start - self.end = end + self.range = range self.hints = hints or [] self.actions = actions or [] self.references = references or [] @@ -92,9 +90,9 @@ class CompileError(PrintableError): self.hint("Are your dependencies up to date?") def pretty_print(self, filename: str, code: str, stream=sys.stdout) -> None: - assert self.start is not None + assert self.range is not None - line_num, col_num = utils.idx_to_pos(self.start + 1, code) + line_num, col_num = utils.idx_to_pos(self.range.start + 1, code) line = code.splitlines(True)[line_num] # Display 1-based line numbers @@ -110,7 +108,7 @@ at {filename} line {line_num} column {col_num}: stream.write(f"{Colors.FAINT}hint: {hint}{Colors.CLEAR}\n") for ref in self.references: - line_num, col_num = utils.idx_to_pos(ref.start + 1, code) + line_num, col_num = utils.idx_to_pos(ref.range.start + 1, code) line = code.splitlines(True)[line_num] line_num += 1 @@ -138,14 +136,15 @@ class UpgradeWarning(CompileWarning): class UnexpectedTokenError(CompileError): - def __init__(self, start, end) -> None: - super().__init__("Unexpected tokens", start, end) + def __init__(self, range: Range) -> None: + super().__init__("Unexpected tokens", range) @dataclass class CodeAction: title: str replace_with: str + edit_range: T.Optional[Range] = None class MultipleErrors(PrintableError): diff --git a/blueprintcompiler/language/contexts.py b/blueprintcompiler/language/contexts.py index 2f8e22e..c5e97b3 100644 --- a/blueprintcompiler/language/contexts.py +++ b/blueprintcompiler/language/contexts.py @@ -70,8 +70,7 @@ class ScopeCtx: ): raise CompileError( f"Duplicate object ID '{obj.tokens['id']}'", - token.start, - token.end, + token.range, ) passed[obj.tokens["id"]] = obj diff --git a/blueprintcompiler/language/ui.py b/blueprintcompiler/language/ui.py index 1b7e6e9..3ce23da 100644 --- a/blueprintcompiler/language/ui.py +++ b/blueprintcompiler/language/ui.py @@ -62,8 +62,7 @@ class UI(AstNode): else: gir_ctx.not_found_namespaces.add(i.namespace) except CompileError as e: - e.start = i.group.tokens["namespace"].start - e.end = i.group.tokens["version"].end + e.range = i.range self._gir_errors.append(e) return gir_ctx diff --git a/blueprintcompiler/lsp.py b/blueprintcompiler/lsp.py index a7e5f9b..19c02e5 100644 --- a/blueprintcompiler/lsp.py +++ b/blueprintcompiler/lsp.py @@ -24,10 +24,12 @@ import traceback import typing as T from . import decompiler, parser, tokenizer, utils, xml_reader +from .ast_utils import AstNode from .completions import complete -from .errors import CompileError, MultipleErrors, PrintableError +from .errors import CompileError, MultipleErrors from .lsp_utils import * from .outputs.xml import XmlOutput +from .tokenizer import Token def printerr(*args, **kwargs): @@ -43,16 +45,16 @@ def command(json_method: str): class OpenFile: - def __init__(self, uri: str, text: str, version: int): + def __init__(self, uri: str, text: str, version: int) -> None: self.uri = uri self.text = text self.version = version - self.ast = None - self.tokens = None + self.ast: T.Optional[AstNode] = None + self.tokens: T.Optional[list[Token]] = None self._update() - def apply_changes(self, changes): + def apply_changes(self, changes) -> None: for change in changes: if "range" not in change: self.text = change["text"] @@ -70,8 +72,8 @@ class OpenFile: self.text = self.text[:start] + change["text"] + self.text[end:] self._update() - def _update(self): - self.diagnostics = [] + def _update(self) -> None: + self.diagnostics: list[CompileError] = [] try: self.tokens = tokenizer.tokenize(self.text) self.ast, errors, warnings = parser.parse(self.tokens) @@ -327,14 +329,17 @@ class LanguageServer: def code_actions(self, id, params): open_file = self._open_files[params["textDocument"]["uri"]] - range_start = utils.pos_to_idx( - params["range"]["start"]["line"], - params["range"]["start"]["character"], - open_file.text, - ) - range_end = utils.pos_to_idx( - params["range"]["end"]["line"], - params["range"]["end"]["character"], + range = Range( + utils.pos_to_idx( + params["range"]["start"]["line"], + params["range"]["start"]["character"], + open_file.text, + ), + utils.pos_to_idx( + params["range"]["end"]["line"], + params["range"]["end"]["character"], + open_file.text, + ), open_file.text, ) @@ -342,16 +347,14 @@ class LanguageServer: { "title": action.title, "kind": "quickfix", - "diagnostics": [ - self._create_diagnostic(open_file.text, open_file.uri, diagnostic) - ], + "diagnostics": [self._create_diagnostic(open_file.uri, diagnostic)], "edit": { "changes": { open_file.uri: [ { - "range": utils.idxs_to_range( - diagnostic.start, diagnostic.end, open_file.text - ), + "range": action.edit_range.to_json() + if action.edit_range + else diagnostic.range.to_json(), "newText": action.replace_with, } ] @@ -359,7 +362,7 @@ class LanguageServer: }, } for diagnostic in open_file.diagnostics - if not (diagnostic.end < range_start or diagnostic.start > range_end) + if range.overlaps(diagnostic.range) for action in diagnostic.actions ] @@ -374,14 +377,8 @@ class LanguageServer: result = { "name": symbol.name, "kind": symbol.kind, - "range": utils.idxs_to_range( - symbol.range.start, symbol.range.end, open_file.text - ), - "selectionRange": utils.idxs_to_range( - symbol.selection_range.start, - symbol.selection_range.end, - open_file.text, - ), + "range": symbol.range.to_json(), + "selectionRange": symbol.selection_range.to_json(), "children": [to_json(child) for child in symbol.children], } if symbol.detail is not None: @@ -411,22 +408,22 @@ class LanguageServer: { "uri": open_file.uri, "diagnostics": [ - self._create_diagnostic(open_file.text, open_file.uri, err) + self._create_diagnostic(open_file.uri, err) for err in open_file.diagnostics ], }, ) - def _create_diagnostic(self, text: str, uri: str, err: CompileError): + def _create_diagnostic(self, uri: str, err: CompileError): message = err.message - assert err.start is not None and err.end is not None + assert err.range is not None for hint in err.hints: message += "\nhint: " + hint result = { - "range": utils.idxs_to_range(err.start, err.end, text), + "range": err.range.to_json(), "message": message, "severity": DiagnosticSeverity.Warning if isinstance(err, CompileWarning) @@ -441,7 +438,7 @@ class LanguageServer: { "location": { "uri": uri, - "range": utils.idxs_to_range(ref.start, ref.end, text), + "range": ref.range.to_json(), }, "message": ref.message, } diff --git a/blueprintcompiler/main.py b/blueprintcompiler/main.py index db9fb65..416db47 100644 --- a/blueprintcompiler/main.py +++ b/blueprintcompiler/main.py @@ -24,8 +24,8 @@ import os import sys import typing as T -from . import decompiler, interactive_port, parser, tokenizer -from .errors import CompilerBugError, MultipleErrors, PrintableError, report_bug +from . import interactive_port, parser, tokenizer +from .errors import CompilerBugError, CompileError, PrintableError, report_bug from .gir import add_typelib_search_path from .lsp import LanguageServer from .outputs import XmlOutput @@ -157,7 +157,7 @@ class BlueprintApp: def cmd_port(self, opts): interactive_port.run(opts) - def _compile(self, data: str) -> T.Tuple[str, T.List[PrintableError]]: + def _compile(self, data: str) -> T.Tuple[str, T.List[CompileError]]: tokens = tokenizer.tokenize(data) ast, errors, warnings = parser.parse(tokens) diff --git a/blueprintcompiler/parse_tree.py b/blueprintcompiler/parse_tree.py index a8efddb..8f3ef31 100644 --- a/blueprintcompiler/parse_tree.py +++ b/blueprintcompiler/parse_tree.py @@ -224,11 +224,11 @@ class ParseContext: if ( len(self.errors) and isinstance((err := self.errors[-1]), UnexpectedTokenError) - and err.end == start + and err.range.end == start ): - err.end = end + err.range.end = end else: - self.errors.append(UnexpectedTokenError(start, end)) + self.errors.append(UnexpectedTokenError(Range(start, end, self.text))) def is_eof(self) -> bool: return self.index >= len(self.tokens) or self.peek_token().type == TokenType.EOF @@ -281,10 +281,11 @@ class Err(ParseNode): start_idx = ctx.start while ctx.tokens[start_idx].type in SKIP_TOKENS: start_idx += 1 - start_token = ctx.tokens[start_idx] - end_token = ctx.tokens[ctx.index] - raise CompileError(self.message, start_token.start, end_token.end) + + raise CompileError( + self.message, Range(start_token.start, start_token.start, ctx.text) + ) return True @@ -324,7 +325,9 @@ class Fail(ParseNode): start_token = ctx.tokens[start_idx] end_token = ctx.tokens[ctx.index] - raise CompileError(self.message, start_token.start, end_token.end) + raise CompileError( + self.message, Range.join(start_token.range, end_token.range) + ) return True @@ -373,7 +376,7 @@ class Statement(ParseNode): token = ctx.peek_token() if str(token) != ";": - ctx.errors.append(CompileError("Expected `;`", token.start, token.end)) + ctx.errors.append(CompileError("Expected `;`", token.range)) else: ctx.next_token() return True diff --git a/blueprintcompiler/parser.py b/blueprintcompiler/parser.py index a9cc0ae..89e1533 100644 --- a/blueprintcompiler/parser.py +++ b/blueprintcompiler/parser.py @@ -26,7 +26,7 @@ from .tokenizer import TokenType def parse( tokens: T.List[Token], -) -> T.Tuple[T.Optional[UI], T.Optional[MultipleErrors], T.List[PrintableError]]: +) -> T.Tuple[T.Optional[UI], T.Optional[MultipleErrors], T.List[CompileError]]: """Parses a list of tokens into an abstract syntax tree.""" try: diff --git a/blueprintcompiler/tokenizer.py b/blueprintcompiler/tokenizer.py index e1066ac..1ab6def 100644 --- a/blueprintcompiler/tokenizer.py +++ b/blueprintcompiler/tokenizer.py @@ -23,7 +23,6 @@ import typing as T from dataclasses import dataclass from enum import Enum -from .errors import CompileError, CompilerBugError from . import utils @@ -69,6 +68,8 @@ class Token: return Range(self.start, self.end, self.string) def get_number(self) -> T.Union[int, float]: + from .errors import CompileError, CompilerBugError + if self.type != TokenType.NUMBER: raise CompilerBugError() @@ -81,12 +82,12 @@ class Token: else: return int(string) except: - raise CompileError( - f"{str(self)} is not a valid number literal", self.start, self.end - ) + raise CompileError(f"{str(self)} is not a valid number literal", self.range) def _tokenize(ui_ml: str): + from .errors import CompileError + i = 0 while i < len(ui_ml): matched = False @@ -101,7 +102,8 @@ def _tokenize(ui_ml: str): if not matched: raise CompileError( - "Could not determine what kind of syntax is meant here", i, i + "Could not determine what kind of syntax is meant here", + Range(i, i, ui_ml), ) yield Token(TokenType.EOF, i, i, ui_ml) @@ -117,6 +119,10 @@ class Range: end: int original_text: str + @property + def length(self) -> int: + return self.end - self.start + @property def text(self) -> str: return self.original_text[self.start : self.end] @@ -137,3 +143,6 @@ class Range: def to_json(self): return utils.idxs_to_range(self.start, self.end, self.original_text) + + def overlaps(self, other: "Range") -> bool: + return not (self.end < other.start or self.start > other.end) diff --git a/tests/test_samples.py b/tests/test_samples.py index 9ad8cc0..9d891f6 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -113,9 +113,9 @@ class TestSamples(unittest.TestCase): raise MultipleErrors(warnings) except PrintableError as e: - def error_str(error): - line, col = utils.idx_to_pos(error.start + 1, blueprint) - len = error.end - error.start + def error_str(error: CompileError): + line, col = utils.idx_to_pos(error.range.start + 1, blueprint) + len = error.range.length return ",".join([str(line + 1), str(col), str(len), error.message]) if isinstance(e, CompileError):