From 0b7dbaf90d17073dd13249393b65dd130c9e5c58 Mon Sep 17 00:00:00 2001 From: James Westman Date: Sun, 25 Dec 2022 18:32:23 -0600 Subject: [PATCH] Add some type hints --- blueprintcompiler/decompiler.py | 42 ++++--- blueprintcompiler/errors.py | 30 ++--- blueprintcompiler/gir.py | 166 +++++++++++++------------- blueprintcompiler/interactive_port.py | 2 +- blueprintcompiler/lsp.py | 11 +- blueprintcompiler/main.py | 2 +- blueprintcompiler/parse_tree.py | 71 +++++------ blueprintcompiler/tokenizer.py | 10 +- blueprintcompiler/typelib.py | 26 ++-- blueprintcompiler/xml_reader.py | 6 +- 10 files changed, 193 insertions(+), 173 deletions(-) diff --git a/blueprintcompiler/decompiler.py b/blueprintcompiler/decompiler.py index 565d420..c068c93 100644 --- a/blueprintcompiler/decompiler.py +++ b/blueprintcompiler/decompiler.py @@ -51,17 +51,17 @@ class LineType(Enum): class DecompileCtx: - def __init__(self): - self._result = "" + def __init__(self) -> None: + self._result: str = "" self.gir = GirContext() - self._indent = 0 - self._blocks_need_end = [] - self._last_line_type = LineType.NONE + self._indent: int = 0 + self._blocks_need_end: T.List[str] = [] + self._last_line_type: LineType = LineType.NONE self.gir.add_namespace(get_namespace("Gtk", "4.0")) @property - def result(self): + def result(self) -> str: imports = "\n".join( [ f"using {ns} {namespace.version};" @@ -70,7 +70,7 @@ class DecompileCtx: ) return imports + "\n" + self._result - def type_by_cname(self, cname): + def type_by_cname(self, cname: str) -> T.Optional[GirType]: if type := self.gir.get_type_by_cname(cname): return type @@ -83,17 +83,19 @@ class DecompileCtx: except: pass - def start_block(self): - self._blocks_need_end.append(None) + return None - def end_block(self): + def start_block(self) -> None: + self._blocks_need_end.append("") + + def end_block(self) -> None: if close := self._blocks_need_end.pop(): self.print(close) - def end_block_with(self, text): + def end_block_with(self, text: str) -> None: self._blocks_need_end[-1] = text - def print(self, line, newline=True): + def print(self, line: str, newline: bool = True) -> None: if line == "}" or line == "]": self._indent -= 1 @@ -124,7 +126,7 @@ class DecompileCtx: self._blocks_need_end[-1] = _CLOSING[line[-1]] self._indent += 1 - def print_attribute(self, name, value, type): + def print_attribute(self, name: str, value: str, type: GirType) -> None: def get_enum_name(value): for member in type.members.values(): if ( @@ -169,13 +171,17 @@ class DecompileCtx: self.print(f'{name}: "{escape_quote(value)}";') -def _decompile_element(ctx: DecompileCtx, gir, xml): +def _decompile_element( + ctx: DecompileCtx, gir: T.Optional[GirContext], xml: Element +) -> None: try: decompiler = _DECOMPILERS.get(xml.tag) if decompiler is None: raise UnsupportedError(f"unsupported XML tag: <{xml.tag}>") - args = {canon(name): value for name, value in xml.attrs.items()} + args: T.Dict[str, T.Optional[str]] = { + canon(name): value for name, value in xml.attrs.items() + } if decompiler._cdata: if len(xml.children): args["cdata"] = None @@ -196,7 +202,7 @@ def _decompile_element(ctx: DecompileCtx, gir, xml): raise UnsupportedError(tag=xml.tag) -def decompile(data): +def decompile(data: str) -> str: ctx = DecompileCtx() xml = parse(data) @@ -216,11 +222,11 @@ def truthy(string: str) -> bool: return string.lower() in ["yes", "true", "t", "y", "1"] -def full_name(gir): +def full_name(gir) -> str: return gir.name if gir.full_name.startswith("Gtk.") else gir.full_name -def lookup_by_cname(gir, cname: str): +def lookup_by_cname(gir, cname: str) -> T.Optional[GirType]: if isinstance(gir, GirContext): return gir.get_type_by_cname(cname) else: diff --git a/blueprintcompiler/errors.py b/blueprintcompiler/errors.py index 01f9066..4a18589 100644 --- a/blueprintcompiler/errors.py +++ b/blueprintcompiler/errors.py @@ -47,15 +47,15 @@ class CompileError(PrintableError): def __init__( self, - message, - start=None, - end=None, - did_you_mean=None, - hints=None, - actions=None, - fatal=False, - references=None, - ): + message: str, + start: T.Optional[int] = None, + end: T.Optional[int] = None, + did_you_mean: T.Optional[T.Tuple[str, T.List[str]]] = None, + hints: T.Optional[T.List[str]] = None, + actions: T.Optional[T.List["CodeAction"]] = None, + fatal: bool = False, + references: T.Optional[T.List[ErrorReference]] = None, + ) -> None: super().__init__(message) self.message = message @@ -69,11 +69,11 @@ class CompileError(PrintableError): if did_you_mean is not None: self._did_you_mean(*did_you_mean) - def hint(self, hint: str): + def hint(self, hint: str) -> "CompileError": self.hints.append(hint) return self - def _did_you_mean(self, word: str, options: T.List[str]): + def _did_you_mean(self, word: str, options: T.List[str]) -> None: if word.replace("_", "-") in options: self.hint(f"use '-', not '_': `{word.replace('_', '-')}`") return @@ -89,7 +89,9 @@ class CompileError(PrintableError): self.hint("Did you check your spelling?") self.hint("Are your dependencies up to date?") - def pretty_print(self, filename, code, stream=sys.stdout): + def pretty_print(self, filename: str, code: str, stream=sys.stdout) -> None: + assert self.start is not None + line_num, col_num = utils.idx_to_pos(self.start + 1, code) line = code.splitlines(True)[line_num] @@ -130,7 +132,7 @@ class UpgradeWarning(CompileWarning): class UnexpectedTokenError(CompileError): - def __init__(self, start, end): + def __init__(self, start, end) -> None: super().__init__("Unexpected tokens", start, end) @@ -145,7 +147,7 @@ class MultipleErrors(PrintableError): a list and re-thrown using the MultipleErrors exception. It will pretty-print all of the errors and a count of how many errors there are.""" - def __init__(self, errors: T.List[CompileError]): + def __init__(self, errors: T.List[CompileError]) -> None: super().__init__() self.errors = errors diff --git a/blueprintcompiler/gir.py b/blueprintcompiler/gir.py index 4ee2b1e..70d8e89 100644 --- a/blueprintcompiler/gir.py +++ b/blueprintcompiler/gir.py @@ -33,7 +33,7 @@ _namespace_cache: T.Dict[str, "Namespace"] = {} _xml_cache = {} -def get_namespace(namespace, version) -> "Namespace": +def get_namespace(namespace: str, version: str) -> "Namespace": search_paths = GIRepository.Repository.get_search_path() filename = f"{namespace}-{version}.typelib" @@ -58,10 +58,7 @@ def get_namespace(namespace, version) -> "Namespace": return _namespace_cache[filename] -def get_xml(namespace, version): - from .main import VERSION - from xml.etree import ElementTree - +def get_xml(namespace: str, version: str): search_paths = [] if data_paths := os.environ.get("XDG_DATA_DIRS"): @@ -90,12 +87,17 @@ def get_xml(namespace, version): class GirType: @property - def doc(self): + def doc(self) -> T.Optional[str]: return None def assignable_to(self, other: "GirType") -> bool: raise NotImplementedError() + @property + def name(self) -> str: + """The GIR name of the type, not including the namespace""" + raise NotImplementedError() + @property def full_name(self) -> str: """The GIR name of the type to use in diagnostics""" @@ -108,7 +110,7 @@ class GirType: class UncheckedType(GirType): - def __init__(self, name) -> None: + def __init__(self, name: str) -> None: super().__init__() self._name = name @@ -136,7 +138,7 @@ class BoolType(BasicType): name = "bool" glib_type_name: str = "gboolean" - def assignable_to(self, other) -> bool: + def assignable_to(self, other: GirType) -> bool: return isinstance(other, BoolType) @@ -144,7 +146,7 @@ class IntType(BasicType): name = "int" glib_type_name: str = "gint" - def assignable_to(self, other) -> bool: + def assignable_to(self, other: GirType) -> bool: return ( isinstance(other, IntType) or isinstance(other, UIntType) @@ -156,7 +158,7 @@ class UIntType(BasicType): name = "uint" glib_type_name: str = "guint" - def assignable_to(self, other) -> bool: + def assignable_to(self, other: GirType) -> bool: return ( isinstance(other, IntType) or isinstance(other, UIntType) @@ -168,7 +170,7 @@ class FloatType(BasicType): name = "float" glib_type_name: str = "gfloat" - def assignable_to(self, other) -> bool: + def assignable_to(self, other: GirType) -> bool: return isinstance(other, FloatType) @@ -176,7 +178,7 @@ class StringType(BasicType): name = "string" glib_type_name: str = "gchararray" - def assignable_to(self, other) -> bool: + def assignable_to(self, other: GirType) -> bool: return isinstance(other, StringType) @@ -184,7 +186,7 @@ class TypeType(BasicType): name = "GType" glib_type_name: str = "GType" - def assignable_to(self, other) -> bool: + def assignable_to(self, other: GirType) -> bool: return isinstance(other, TypeType) @@ -208,14 +210,17 @@ _BASIC_TYPES = { } +TNode = T.TypeVar("TNode", bound="GirNode") + + class GirNode: - def __init__(self, container, tl): + def __init__(self, container: T.Optional["GirNode"], tl: typelib.Typelib) -> None: self.container = container self.tl = tl - def get_containing(self, container_type): + def get_containing(self, container_type: T.Type[TNode]) -> TNode: if self.container is None: - return None + raise CompilerBugError() elif isinstance(self.container, container_type): return self.container else: @@ -228,11 +233,11 @@ class GirNode: return el @cached_property - def glib_type_name(self): + def glib_type_name(self) -> str: return self.tl.OBJ_GTYPE_NAME @cached_property - def full_name(self): + def full_name(self) -> str: if self.container is None: return self.name else: @@ -273,20 +278,16 @@ class GirNode: return None @property - def type_name(self): - return self.type.name - - @property - def type(self): + def type(self) -> GirType: raise NotImplementedError() class Property(GirNode): - def __init__(self, klass, tl: typelib.Typelib): + def __init__(self, klass: T.Union["Class", "Interface"], tl: typelib.Typelib): super().__init__(klass, tl) @cached_property - def name(self): + def name(self) -> str: return self.tl.PROP_NAME @cached_property @@ -295,24 +296,26 @@ class Property(GirNode): @cached_property def signature(self): - return f"{self.type_name} {self.container.name}.{self.name}" + return f"{self.full_name} {self.container.name}.{self.name}" @property - def writable(self): + def writable(self) -> bool: return self.tl.PROP_WRITABLE == 1 @property - def construct_only(self): + def construct_only(self) -> bool: return self.tl.PROP_CONSTRUCT_ONLY == 1 class Parameter(GirNode): - def __init__(self, container: GirNode, tl: typelib.Typelib): + def __init__(self, container: GirNode, tl: typelib.Typelib) -> None: super().__init__(container, tl) class Signal(GirNode): - def __init__(self, klass, tl: typelib.Typelib): + def __init__( + self, klass: T.Union["Class", "Interface"], tl: typelib.Typelib + ) -> None: super().__init__(klass, tl) # if parameters := xml.get_elements('parameters'): # self.params = [Parameter(self, child) for child in parameters[0].get_elements('parameter')] @@ -328,11 +331,11 @@ class Signal(GirNode): class Interface(GirNode, GirType): - def __init__(self, ns, tl: typelib.Typelib): + def __init__(self, ns: "Namespace", tl: typelib.Typelib): super().__init__(ns, tl) @cached_property - def properties(self): + def properties(self) -> T.Mapping[str, Property]: n_prerequisites = self.tl.INTERFACE_N_PREREQUISITES offset = self.tl.header.HEADER_INTERFACE_BLOB_SIZE offset += (n_prerequisites + n_prerequisites % 2) * 2 @@ -345,7 +348,7 @@ class Interface(GirNode, GirType): return result @cached_property - def signals(self): + def signals(self) -> T.Mapping[str, Signal]: n_prerequisites = self.tl.INTERFACE_N_PREREQUISITES offset = self.tl.header.HEADER_INTERFACE_BLOB_SIZE offset += (n_prerequisites + n_prerequisites % 2) * 2 @@ -362,7 +365,7 @@ class Interface(GirNode, GirType): return result @cached_property - def prerequisites(self): + def prerequisites(self) -> T.List["Interface"]: n_prerequisites = self.tl.INTERFACE_N_PREREQUISITES result = [] for i in range(n_prerequisites): @@ -370,7 +373,7 @@ class Interface(GirNode, GirType): result.append(self.get_containing(Repository)._resolve_dir_entry(entry)) return result - def assignable_to(self, other) -> bool: + def assignable_to(self, other: GirType) -> bool: if self == other: return True for pre in self.prerequisites: @@ -380,15 +383,15 @@ class Interface(GirNode, GirType): class Class(GirNode, GirType): - def __init__(self, ns, tl: typelib.Typelib): + def __init__(self, ns: "Namespace", tl: typelib.Typelib) -> None: super().__init__(ns, tl) @property - def abstract(self): + def abstract(self) -> bool: return self.tl.OBJ_ABSTRACT == 1 @cached_property - def implements(self): + def implements(self) -> T.List[Interface]: n_interfaces = self.tl.OBJ_N_INTERFACES result = [] for i in range(n_interfaces): @@ -397,7 +400,7 @@ class Class(GirNode, GirType): return result @cached_property - def own_properties(self): + def own_properties(self) -> T.Mapping[str, Property]: n_interfaces = self.tl.OBJ_N_INTERFACES offset = self.tl.header.HEADER_OBJECT_BLOB_SIZE offset += (n_interfaces + n_interfaces % 2) * 2 @@ -414,7 +417,7 @@ class Class(GirNode, GirType): return result @cached_property - def own_signals(self): + def own_signals(self) -> T.Mapping[str, Signal]: n_interfaces = self.tl.OBJ_N_INTERFACES offset = self.tl.header.HEADER_OBJECT_BLOB_SIZE offset += (n_interfaces + n_interfaces % 2) * 2 @@ -433,16 +436,18 @@ class Class(GirNode, GirType): return result @cached_property - def parent(self): + def parent(self) -> T.Optional["Class"]: if entry := self.tl.OBJ_PARENT: return self.get_containing(Repository)._resolve_dir_entry(entry) else: return None @cached_property - def signature(self): + def signature(self) -> str: + assert self.container is not None result = f"class {self.container.name}.{self.name}" if self.parent is not None: + assert self.parent.container is not None result += f" : {self.parent.container.name}.{self.parent.name}" if len(self.implements): result += " implements " + ", ".join( @@ -451,14 +456,14 @@ class Class(GirNode, GirType): return result @cached_property - def properties(self): + def properties(self) -> T.Mapping[str, Property]: return {p.name: p for p in self._enum_properties()} @cached_property - def signals(self): + def signals(self) -> T.Mapping[str, Signal]: return {s.name: s for s in self._enum_signals()} - def assignable_to(self, other) -> bool: + def assignable_to(self, other: GirType) -> bool: if self == other: return True elif self.parent and self.parent.assignable_to(other): @@ -470,7 +475,7 @@ class Class(GirNode, GirType): return False - def _enum_properties(self): + def _enum_properties(self) -> T.Iterable[Property]: yield from self.own_properties.values() if self.parent is not None: @@ -479,7 +484,7 @@ class Class(GirNode, GirType): for impl in self.implements: yield from impl.properties.values() - def _enum_signals(self): + def _enum_signals(self) -> T.Iterable[Signal]: yield from self.own_signals.values() if self.parent is not None: @@ -490,8 +495,8 @@ class Class(GirNode, GirType): class EnumMember(GirNode): - def __init__(self, ns, tl: typelib.Typelib): - super().__init__(ns, tl) + def __init__(self, enum: "Enumeration", tl: typelib.Typelib) -> None: + super().__init__(enum, tl) @property def value(self) -> int: @@ -502,20 +507,20 @@ class EnumMember(GirNode): return self.tl.VALUE_NAME @cached_property - def nick(self): + def nick(self) -> str: return self.name.replace("_", "-") @property - def c_ident(self): + def c_ident(self) -> str: return self.tl.attr("c:identifier") @property - def signature(self): + def signature(self) -> str: return f"enum member {self.full_name} = {self.value}" class Enumeration(GirNode, GirType): - def __init__(self, ns, tl: typelib.Typelib): + def __init__(self, ns: "Namespace", tl: typelib.Typelib) -> None: super().__init__(ns, tl) @cached_property @@ -530,43 +535,43 @@ class Enumeration(GirNode, GirType): return members @property - def signature(self): + def signature(self) -> str: return f"enum {self.full_name}" - def assignable_to(self, type): + def assignable_to(self, type: GirType) -> bool: return type == self class Boxed(GirNode, GirType): - def __init__(self, ns, tl: typelib.Typelib): + def __init__(self, ns: "Namespace", tl: typelib.Typelib) -> None: super().__init__(ns, tl) @property - def signature(self): + def signature(self) -> str: return f"boxed {self.full_name}" - def assignable_to(self, type): + def assignable_to(self, type) -> bool: return type == self class Bitfield(Enumeration): - def __init__(self, ns, tl: typelib.Typelib): + def __init__(self, ns: "Namespace", tl: typelib.Typelib) -> None: super().__init__(ns, tl) class Namespace(GirNode): - def __init__(self, repo, tl: typelib.Typelib): + def __init__(self, repo: "Repository", tl: typelib.Typelib) -> None: super().__init__(repo, tl) - self.entries: T.Dict[str, GirNode] = {} + self.entries: T.Dict[str, GirType] = {} - n_local_entries = tl.HEADER_N_ENTRIES - directory = tl.HEADER_DIRECTORY + n_local_entries: int = tl.HEADER_N_ENTRIES + directory: typelib.Typelib = tl.HEADER_DIRECTORY for i in range(n_local_entries): entry = directory[i * tl.HEADER_ENTRY_BLOB_SIZE] - entry_name = entry.DIR_ENTRY_NAME - entry_type = entry.DIR_ENTRY_BLOB_TYPE - entry_blob = entry.DIR_ENTRY_OFFSET + entry_name: str = entry.DIR_ENTRY_NAME + entry_type: int = entry.DIR_ENTRY_BLOB_TYPE + entry_blob: typelib.Typelib = entry.DIR_ENTRY_OFFSET if entry_type == typelib.BLOB_TYPE_ENUM: self.entries[entry_name] = Enumeration(self, entry_blob) @@ -595,11 +600,11 @@ class Namespace(GirNode): return self.tl.HEADER_NSVERSION @property - def signature(self): + def signature(self) -> str: return f"namespace {self.name} {self.version}" @cached_property - def classes(self): + def classes(self) -> T.Mapping[str, Class]: return { name: entry for name, entry in self.entries.items() @@ -607,24 +612,25 @@ class Namespace(GirNode): } @cached_property - def interfaces(self): + def interfaces(self) -> T.Mapping[str, Interface]: return { name: entry for name, entry in self.entries.items() if isinstance(entry, Interface) } - def get_type(self, name): + def get_type(self, name) -> T.Optional[GirType]: """Gets a type (class, interface, enum, etc.) from this namespace.""" return self.entries.get(name) - def get_type_by_cname(self, cname: str): + def get_type_by_cname(self, cname: str) -> T.Optional[GirType]: """Gets a type from this namespace by its C name.""" for item in self.entries.values(): if hasattr(item, "cname") and item.cname == cname: return item + return None - def lookup_type(self, type_name: str): + def lookup_type(self, type_name: str) -> T.Optional[GirType]: """Looks up a type in the scope of this namespace (including in the namespace's dependencies).""" @@ -638,7 +644,7 @@ class Namespace(GirNode): class Repository(GirNode): - def __init__(self, tl: typelib.Typelib): + def __init__(self, tl: typelib.Typelib) -> None: super().__init__(None, tl) self.namespace = Namespace(self, tl) @@ -654,10 +660,10 @@ class Repository(GirNode): else: self.includes = {} - def get_type(self, name: str, ns: str) -> T.Optional[GirNode]: + def get_type(self, name: str, ns: str) -> T.Optional[GirType]: return self.lookup_namespace(ns).get_type(name) - def get_type_by_cname(self, name: str) -> T.Optional[GirNode]: + def get_type_by_cname(self, name: str) -> T.Optional[GirType]: for ns in [self.namespace, *self.includes.values()]: if type := ns.get_type_by_cname(name): return type @@ -679,7 +685,7 @@ class Repository(GirNode): ns = dir_entry.DIR_ENTRY_NAMESPACE return self.lookup_namespace(ns).get_type(dir_entry.DIR_ENTRY_NAME) - def _resolve_type_id(self, type_id: int): + def _resolve_type_id(self, type_id: int) -> GirType: if type_id & 0xFFFFFF == 0: type_id = (type_id >> 27) & 0x1F # simple type @@ -726,13 +732,13 @@ class GirContext: self.namespaces[namespace.name] = namespace - def get_type_by_cname(self, name: str) -> T.Optional[GirNode]: + def get_type_by_cname(self, name: str) -> T.Optional[GirType]: for ns in self.namespaces.values(): if type := ns.get_type_by_cname(name): return type return None - def get_type(self, name: str, ns: str) -> T.Optional[GirNode]: + def get_type(self, name: str, ns: str) -> T.Optional[GirType]: if ns is None and name in _BASIC_TYPES: return _BASIC_TYPES[name]() @@ -750,7 +756,7 @@ class GirContext: else: return None - def validate_ns(self, ns: str): + def validate_ns(self, ns: str) -> None: """Raises an exception if there is a problem looking up the given namespace.""" @@ -762,7 +768,7 @@ class GirContext: did_you_mean=(ns, self.namespaces.keys()), ) - def validate_type(self, name: str, ns: str): + def validate_type(self, name: str, ns: str) -> None: """Raises an exception if there is a problem looking up the given type.""" self.validate_ns(ns) diff --git a/blueprintcompiler/interactive_port.py b/blueprintcompiler/interactive_port.py index ddb5e28..ffc4292 100644 --- a/blueprintcompiler/interactive_port.py +++ b/blueprintcompiler/interactive_port.py @@ -32,7 +32,7 @@ from .utils import Colors class CouldNotPort: - def __init__(self, message): + def __init__(self, message: str): self.message = message diff --git a/blueprintcompiler/lsp.py b/blueprintcompiler/lsp.py index 890eff0..dd12905 100644 --- a/blueprintcompiler/lsp.py +++ b/blueprintcompiler/lsp.py @@ -31,7 +31,7 @@ def printerr(*args, **kwargs): print(*args, file=sys.stderr, **kwargs) -def command(json_method): +def command(json_method: str): def decorator(func): func._json_method = json_method return func @@ -40,7 +40,7 @@ def command(json_method): class OpenFile: - def __init__(self, uri, text, version): + def __init__(self, uri: str, text: str, version: int): self.uri = uri self.text = text self.version = version @@ -81,6 +81,9 @@ class OpenFile: self.diagnostics.append(e) def calc_semantic_tokens(self) -> T.List[int]: + if self.ast is None: + return [] + tokens = list(self.ast.get_semantic_tokens()) token_lists = [ [ @@ -318,9 +321,11 @@ class LanguageServer: }, ) - def _create_diagnostic(self, text, uri, err): + def _create_diagnostic(self, text: str, uri: str, err: CompileError): message = err.message + assert err.start is not None and err.end is not None + for hint in err.hints: message += "\nhint: " + hint diff --git a/blueprintcompiler/main.py b/blueprintcompiler/main.py index 345f430..6127630 100644 --- a/blueprintcompiler/main.py +++ b/blueprintcompiler/main.py @@ -82,7 +82,7 @@ class BlueprintApp: except: report_bug() - def add_subcommand(self, name, help, func): + def add_subcommand(self, name: str, help: str, func): parser = self.subparsers.add_parser(name, help=help) parser.set_defaults(func=func) return parser diff --git a/blueprintcompiler/parse_tree.py b/blueprintcompiler/parse_tree.py index ef8586b..670c72e 100644 --- a/blueprintcompiler/parse_tree.py +++ b/blueprintcompiler/parse_tree.py @@ -23,6 +23,7 @@ import typing as T from collections import defaultdict from enum import Enum +from .ast_utils import AstNode from .errors import ( assert_true, @@ -64,19 +65,19 @@ class ParseGroup: be converted to AST nodes by passing the children and key=value pairs to the AST node constructor.""" - def __init__(self, ast_type, start: int): + def __init__(self, ast_type: T.Type[AstNode], start: int): self.ast_type = ast_type self.children: T.List[ParseGroup] = [] self.keys: T.Dict[str, T.Any] = {} - self.tokens: T.Dict[str, Token] = {} + self.tokens: T.Dict[str, T.Optional[Token]] = {} self.start = start - self.end = None + self.end: T.Optional[int] = None self.incomplete = False - def add_child(self, child): + def add_child(self, child: "ParseGroup"): self.children.append(child) - def set_val(self, key, val, token): + def set_val(self, key: str, val: T.Any, token: T.Optional[Token]): assert_true(key not in self.keys) self.keys[key] = val @@ -105,22 +106,22 @@ class ParseGroup: class ParseContext: """Contains the state of the parser.""" - def __init__(self, tokens, index=0): + def __init__(self, tokens: T.List[Token], index=0): self.tokens = list(tokens) self.binding_power = 0 self.index = index self.start = index - self.group = None - self.group_keys = {} - self.group_children = [] - self.last_group = None + self.group: T.Optional[ParseGroup] = None + self.group_keys: T.Dict[str, T.Tuple[T.Any, T.Optional[Token]]] = {} + self.group_children: T.List[ParseGroup] = [] + self.last_group: T.Optional[ParseGroup] = None self.group_incomplete = False - self.errors = [] - self.warnings = [] + self.errors: T.List[CompileError] = [] + self.warnings: T.List[CompileWarning] = [] - def create_child(self): + def create_child(self) -> "ParseContext": """Creates a new ParseContext at this context's position. The new context will be used to parse one node. If parsing is successful, the new context will be applied to "self". If parsing fails, the new @@ -131,7 +132,7 @@ class ParseContext: ctx.binding_power = self.binding_power return ctx - def apply_child(self, other): + def apply_child(self, other: "ParseContext"): """Applies a child context to this context.""" if other.group is not None: @@ -159,12 +160,12 @@ class ParseContext: elif other.last_group: self.last_group = other.last_group - def start_group(self, ast_type): + def start_group(self, ast_type: T.Type[AstNode]): """Sets this context to have its own match group.""" assert_true(self.group is None) self.group = ParseGroup(ast_type, self.tokens[self.index].start) - def set_group_val(self, key, value, token): + def set_group_val(self, key: str, value: T.Any, token: T.Optional[Token]): """Sets a matched key=value pair on the current match group.""" assert_true(key not in self.group_keys) self.group_keys[key] = (value, token) @@ -213,7 +214,7 @@ class ParseContext: else: self.errors.append(UnexpectedTokenError(start, end)) - def is_eof(self) -> Token: + def is_eof(self) -> bool: return self.index >= len(self.tokens) or self.peek_token().type == TokenType.EOF @@ -237,17 +238,17 @@ class ParseNode: def _parse(self, ctx: ParseContext) -> bool: raise NotImplementedError() - def err(self, message): + def err(self, message: str) -> "Err": """Causes this ParseNode to raise an exception if it fails to parse. This prevents the parser from backtracking, so you should understand what it does and how the parser works before using it.""" return Err(self, message) - def expected(self, expect): + def expected(self, expect) -> "Err": """Convenience method for err().""" return self.err("Expected " + expect) - def warn(self, message): + def warn(self, message) -> "Warning": """Causes this ParseNode to emit a warning if it parses successfully.""" return Warning(self, message) @@ -255,11 +256,11 @@ class ParseNode: class Err(ParseNode): """ParseNode that emits a compile error if it fails to parse.""" - def __init__(self, child, message): + def __init__(self, child, message: str): self.child = to_parse_node(child) self.message = message - def _parse(self, ctx): + def _parse(self, ctx: ParseContext): if self.child.parse(ctx).failed(): start_idx = ctx.start while ctx.tokens[start_idx].type in SKIP_TOKENS: @@ -274,11 +275,11 @@ class Err(ParseNode): class Warning(ParseNode): """ParseNode that emits a compile warning if it parses successfully.""" - def __init__(self, child, message): + def __init__(self, child, message: str): self.child = to_parse_node(child) self.message = message - def _parse(self, ctx): + def _parse(self, ctx: ParseContext): ctx.skip() start_idx = ctx.index if self.child.parse(ctx).succeeded(): @@ -295,11 +296,11 @@ class Warning(ParseNode): class Fail(ParseNode): """ParseNode that emits a compile error if it parses successfully.""" - def __init__(self, child, message): + def __init__(self, child, message: str): self.child = to_parse_node(child) self.message = message - def _parse(self, ctx): + def _parse(self, ctx: ParseContext): if self.child.parse(ctx).succeeded(): start_idx = ctx.start while ctx.tokens[start_idx].type in SKIP_TOKENS: @@ -314,7 +315,7 @@ class Fail(ParseNode): class Group(ParseNode): """ParseNode that creates a match group.""" - def __init__(self, ast_type, child): + def __init__(self, ast_type: T.Type[AstNode], child): self.ast_type = ast_type self.child = to_parse_node(child) @@ -393,7 +394,7 @@ class Until(ParseNode): self.child = to_parse_node(child) self.delimiter = to_parse_node(delimiter) - def _parse(self, ctx): + def _parse(self, ctx: ParseContext): while not self.delimiter.parse(ctx).succeeded(): if ctx.is_eof(): return False @@ -463,7 +464,7 @@ class Eof(ParseNode): class Match(ParseNode): """ParseNode that matches the given literal token.""" - def __init__(self, op): + def __init__(self, op: str): self.op = op def _parse(self, ctx: ParseContext) -> bool: @@ -482,7 +483,7 @@ class UseIdent(ParseNode): """ParseNode that matches any identifier and sets it in a key=value pair on the containing match group.""" - def __init__(self, key): + def __init__(self, key: str): self.key = key def _parse(self, ctx: ParseContext): @@ -498,7 +499,7 @@ class UseNumber(ParseNode): """ParseNode that matches a number and sets it in a key=value pair on the containing match group.""" - def __init__(self, key): + def __init__(self, key: str): self.key = key def _parse(self, ctx: ParseContext): @@ -517,7 +518,7 @@ class UseNumberText(ParseNode): """ParseNode that matches a number, but sets its *original text* it in a key=value pair on the containing match group.""" - def __init__(self, key): + def __init__(self, key: str): self.key = key def _parse(self, ctx: ParseContext): @@ -533,7 +534,7 @@ class UseQuoted(ParseNode): """ParseNode that matches a quoted string and sets it in a key=value pair on the containing match group.""" - def __init__(self, key): + def __init__(self, key: str): self.key = key def _parse(self, ctx: ParseContext): @@ -557,7 +558,7 @@ class UseLiteral(ParseNode): pair on the containing group. Useful for, e.g., property and signal flags: `Sequence(Keyword("swapped"), UseLiteral("swapped", True))`""" - def __init__(self, key, literal): + def __init__(self, key: str, literal: T.Any): self.key = key self.literal = literal @@ -570,7 +571,7 @@ class Keyword(ParseNode): """Matches the given identifier and sets it as a named token, with the name being the identifier itself.""" - def __init__(self, kw): + def __init__(self, kw: str): self.kw = kw self.set_token = True diff --git a/blueprintcompiler/tokenizer.py b/blueprintcompiler/tokenizer.py index 516bc0b..170316c 100644 --- a/blueprintcompiler/tokenizer.py +++ b/blueprintcompiler/tokenizer.py @@ -22,7 +22,7 @@ import typing as T import re from enum import Enum -from .errors import CompileError +from .errors import CompileError, CompilerBugError class TokenType(Enum): @@ -53,18 +53,18 @@ _TOKENS = [(type, re.compile(regex)) for (type, regex) in _tokens] class Token: - def __init__(self, type, start, end, string): + def __init__(self, type: TokenType, start: int, end: int, string: str): self.type = type self.start = start self.end = end self.string = string - def __str__(self): + def __str__(self) -> str: return self.string[self.start : self.end] - def get_number(self): + def get_number(self) -> T.Union[int, float]: if self.type != TokenType.NUMBER: - return None + raise CompilerBugError() string = str(self).replace("_", "") try: diff --git a/blueprintcompiler/typelib.py b/blueprintcompiler/typelib.py index 88e7b57..48ec416 100644 --- a/blueprintcompiler/typelib.py +++ b/blueprintcompiler/typelib.py @@ -58,14 +58,14 @@ TYPE_UNICHAR = 21 class Field: - def __init__(self, offset, type, shift=0, mask=None): + def __init__(self, offset: int, type: str, shift=0, mask=None): self._offset = offset self._type = type self._shift = shift self._mask = (1 << mask) - 1 if mask else None self._name = f"{offset}__{type}__{shift}__{mask}" - def __get__(self, typelib, _objtype=None): + def __get__(self, typelib: "Typelib", _objtype=None): if typelib is None: return self @@ -181,47 +181,47 @@ class Typelib: VALUE_NAME = Field(0x4, "string") VALUE_VALUE = Field(0x8, "i32") - def __init__(self, typelib_file, offset): + def __init__(self, typelib_file, offset: int): self._typelib_file = typelib_file self._offset = offset - def __getitem__(self, index): + def __getitem__(self, index: int): return Typelib(self._typelib_file, self._offset + index) def attr(self, name): return self.header.attr(self._offset, name) @property - def header(self): + def header(self) -> "TypelibHeader": return TypelibHeader(self._typelib_file) @property - def u8(self): + def u8(self) -> int: """Gets the 8-bit unsigned int at this location.""" return self._int(1, False) @property - def u16(self): + def u16(self) -> int: """Gets the 16-bit unsigned int at this location.""" return self._int(2, False) @property - def u32(self): + def u32(self) -> int: """Gets the 32-bit unsigned int at this location.""" return self._int(4, False) @property - def i8(self): + def i8(self) -> int: """Gets the 8-bit unsigned int at this location.""" return self._int(1, True) @property - def i16(self): + def i16(self) -> int: """Gets the 16-bit unsigned int at this location.""" return self._int(2, True) @property - def i32(self): + def i32(self) -> int: """Gets the 32-bit unsigned int at this location.""" return self._int(4, True) @@ -240,7 +240,7 @@ class Typelib: end += 1 return self._typelib_file[loc:end].decode("utf-8") - def _int(self, size, signed): + def _int(self, size, signed) -> int: return int.from_bytes( self._typelib_file[self._offset : self._offset + size], sys.byteorder ) @@ -250,7 +250,7 @@ class TypelibHeader(Typelib): def __init__(self, typelib_file): super().__init__(typelib_file, 0) - def dir_entry(self, index): + def dir_entry(self, index) -> T.Optional[Typelib]: if index == 0: return None else: diff --git a/blueprintcompiler/xml_reader.py b/blueprintcompiler/xml_reader.py index c0552f5..5e31773 100644 --- a/blueprintcompiler/xml_reader.py +++ b/blueprintcompiler/xml_reader.py @@ -46,7 +46,7 @@ PARSE_GIR = set( class Element: - def __init__(self, tag, attrs: T.Dict[str, str]): + def __init__(self, tag: str, attrs: T.Dict[str, str]): self.tag = tag self.attrs = attrs self.children: T.List["Element"] = [] @@ -56,10 +56,10 @@ class Element: def cdata(self): return "".join(self.cdata_chunks) - def get_elements(self, name) -> T.List["Element"]: + def get_elements(self, name: str) -> T.List["Element"]: return [child for child in self.children if child.tag == name] - def __getitem__(self, key): + def __getitem__(self, key: str): return self.attrs.get(key)