Add some type hints

This commit is contained in:
James Westman 2022-12-25 18:32:23 -06:00
parent b6ee649458
commit 0b7dbaf90d
No known key found for this signature in database
GPG key ID: CE2DBA0ADB654EA6
10 changed files with 193 additions and 173 deletions

View file

@ -51,17 +51,17 @@ class LineType(Enum):
class DecompileCtx:
def __init__(self):
self._result = ""
def __init__(self) -> None:
self._result: str = ""
self.gir = GirContext()
self._indent = 0
self._blocks_need_end = []
self._last_line_type = LineType.NONE
self._indent: int = 0
self._blocks_need_end: T.List[str] = []
self._last_line_type: LineType = LineType.NONE
self.gir.add_namespace(get_namespace("Gtk", "4.0"))
@property
def result(self):
def result(self) -> str:
imports = "\n".join(
[
f"using {ns} {namespace.version};"
@ -70,7 +70,7 @@ class DecompileCtx:
)
return imports + "\n" + self._result
def type_by_cname(self, cname):
def type_by_cname(self, cname: str) -> T.Optional[GirType]:
if type := self.gir.get_type_by_cname(cname):
return type
@ -83,17 +83,19 @@ class DecompileCtx:
except:
pass
def start_block(self):
self._blocks_need_end.append(None)
return None
def end_block(self):
def start_block(self) -> None:
self._blocks_need_end.append("")
def end_block(self) -> None:
if close := self._blocks_need_end.pop():
self.print(close)
def end_block_with(self, text):
def end_block_with(self, text: str) -> None:
self._blocks_need_end[-1] = text
def print(self, line, newline=True):
def print(self, line: str, newline: bool = True) -> None:
if line == "}" or line == "]":
self._indent -= 1
@ -124,7 +126,7 @@ class DecompileCtx:
self._blocks_need_end[-1] = _CLOSING[line[-1]]
self._indent += 1
def print_attribute(self, name, value, type):
def print_attribute(self, name: str, value: str, type: GirType) -> None:
def get_enum_name(value):
for member in type.members.values():
if (
@ -169,13 +171,17 @@ class DecompileCtx:
self.print(f'{name}: "{escape_quote(value)}";')
def _decompile_element(ctx: DecompileCtx, gir, xml):
def _decompile_element(
ctx: DecompileCtx, gir: T.Optional[GirContext], xml: Element
) -> None:
try:
decompiler = _DECOMPILERS.get(xml.tag)
if decompiler is None:
raise UnsupportedError(f"unsupported XML tag: <{xml.tag}>")
args = {canon(name): value for name, value in xml.attrs.items()}
args: T.Dict[str, T.Optional[str]] = {
canon(name): value for name, value in xml.attrs.items()
}
if decompiler._cdata:
if len(xml.children):
args["cdata"] = None
@ -196,7 +202,7 @@ def _decompile_element(ctx: DecompileCtx, gir, xml):
raise UnsupportedError(tag=xml.tag)
def decompile(data):
def decompile(data: str) -> str:
ctx = DecompileCtx()
xml = parse(data)
@ -216,11 +222,11 @@ def truthy(string: str) -> bool:
return string.lower() in ["yes", "true", "t", "y", "1"]
def full_name(gir):
def full_name(gir) -> str:
return gir.name if gir.full_name.startswith("Gtk.") else gir.full_name
def lookup_by_cname(gir, cname: str):
def lookup_by_cname(gir, cname: str) -> T.Optional[GirType]:
if isinstance(gir, GirContext):
return gir.get_type_by_cname(cname)
else:

View file

@ -47,15 +47,15 @@ class CompileError(PrintableError):
def __init__(
self,
message,
start=None,
end=None,
did_you_mean=None,
hints=None,
actions=None,
fatal=False,
references=None,
):
message: str,
start: T.Optional[int] = None,
end: T.Optional[int] = None,
did_you_mean: T.Optional[T.Tuple[str, T.List[str]]] = None,
hints: T.Optional[T.List[str]] = None,
actions: T.Optional[T.List["CodeAction"]] = None,
fatal: bool = False,
references: T.Optional[T.List[ErrorReference]] = None,
) -> None:
super().__init__(message)
self.message = message
@ -69,11 +69,11 @@ class CompileError(PrintableError):
if did_you_mean is not None:
self._did_you_mean(*did_you_mean)
def hint(self, hint: str):
def hint(self, hint: str) -> "CompileError":
self.hints.append(hint)
return self
def _did_you_mean(self, word: str, options: T.List[str]):
def _did_you_mean(self, word: str, options: T.List[str]) -> None:
if word.replace("_", "-") in options:
self.hint(f"use '-', not '_': `{word.replace('_', '-')}`")
return
@ -89,7 +89,9 @@ class CompileError(PrintableError):
self.hint("Did you check your spelling?")
self.hint("Are your dependencies up to date?")
def pretty_print(self, filename, code, stream=sys.stdout):
def pretty_print(self, filename: str, code: str, stream=sys.stdout) -> None:
assert self.start is not None
line_num, col_num = utils.idx_to_pos(self.start + 1, code)
line = code.splitlines(True)[line_num]
@ -130,7 +132,7 @@ class UpgradeWarning(CompileWarning):
class UnexpectedTokenError(CompileError):
def __init__(self, start, end):
def __init__(self, start, end) -> None:
super().__init__("Unexpected tokens", start, end)
@ -145,7 +147,7 @@ class MultipleErrors(PrintableError):
a list and re-thrown using the MultipleErrors exception. It will
pretty-print all of the errors and a count of how many errors there are."""
def __init__(self, errors: T.List[CompileError]):
def __init__(self, errors: T.List[CompileError]) -> None:
super().__init__()
self.errors = errors

View file

@ -33,7 +33,7 @@ _namespace_cache: T.Dict[str, "Namespace"] = {}
_xml_cache = {}
def get_namespace(namespace, version) -> "Namespace":
def get_namespace(namespace: str, version: str) -> "Namespace":
search_paths = GIRepository.Repository.get_search_path()
filename = f"{namespace}-{version}.typelib"
@ -58,10 +58,7 @@ def get_namespace(namespace, version) -> "Namespace":
return _namespace_cache[filename]
def get_xml(namespace, version):
from .main import VERSION
from xml.etree import ElementTree
def get_xml(namespace: str, version: str):
search_paths = []
if data_paths := os.environ.get("XDG_DATA_DIRS"):
@ -90,12 +87,17 @@ def get_xml(namespace, version):
class GirType:
@property
def doc(self):
def doc(self) -> T.Optional[str]:
return None
def assignable_to(self, other: "GirType") -> bool:
raise NotImplementedError()
@property
def name(self) -> str:
"""The GIR name of the type, not including the namespace"""
raise NotImplementedError()
@property
def full_name(self) -> str:
"""The GIR name of the type to use in diagnostics"""
@ -108,7 +110,7 @@ class GirType:
class UncheckedType(GirType):
def __init__(self, name) -> None:
def __init__(self, name: str) -> None:
super().__init__()
self._name = name
@ -136,7 +138,7 @@ class BoolType(BasicType):
name = "bool"
glib_type_name: str = "gboolean"
def assignable_to(self, other) -> bool:
def assignable_to(self, other: GirType) -> bool:
return isinstance(other, BoolType)
@ -144,7 +146,7 @@ class IntType(BasicType):
name = "int"
glib_type_name: str = "gint"
def assignable_to(self, other) -> bool:
def assignable_to(self, other: GirType) -> bool:
return (
isinstance(other, IntType)
or isinstance(other, UIntType)
@ -156,7 +158,7 @@ class UIntType(BasicType):
name = "uint"
glib_type_name: str = "guint"
def assignable_to(self, other) -> bool:
def assignable_to(self, other: GirType) -> bool:
return (
isinstance(other, IntType)
or isinstance(other, UIntType)
@ -168,7 +170,7 @@ class FloatType(BasicType):
name = "float"
glib_type_name: str = "gfloat"
def assignable_to(self, other) -> bool:
def assignable_to(self, other: GirType) -> bool:
return isinstance(other, FloatType)
@ -176,7 +178,7 @@ class StringType(BasicType):
name = "string"
glib_type_name: str = "gchararray"
def assignable_to(self, other) -> bool:
def assignable_to(self, other: GirType) -> bool:
return isinstance(other, StringType)
@ -184,7 +186,7 @@ class TypeType(BasicType):
name = "GType"
glib_type_name: str = "GType"
def assignable_to(self, other) -> bool:
def assignable_to(self, other: GirType) -> bool:
return isinstance(other, TypeType)
@ -208,14 +210,17 @@ _BASIC_TYPES = {
}
TNode = T.TypeVar("TNode", bound="GirNode")
class GirNode:
def __init__(self, container, tl):
def __init__(self, container: T.Optional["GirNode"], tl: typelib.Typelib) -> None:
self.container = container
self.tl = tl
def get_containing(self, container_type):
def get_containing(self, container_type: T.Type[TNode]) -> TNode:
if self.container is None:
return None
raise CompilerBugError()
elif isinstance(self.container, container_type):
return self.container
else:
@ -228,11 +233,11 @@ class GirNode:
return el
@cached_property
def glib_type_name(self):
def glib_type_name(self) -> str:
return self.tl.OBJ_GTYPE_NAME
@cached_property
def full_name(self):
def full_name(self) -> str:
if self.container is None:
return self.name
else:
@ -273,20 +278,16 @@ class GirNode:
return None
@property
def type_name(self):
return self.type.name
@property
def type(self):
def type(self) -> GirType:
raise NotImplementedError()
class Property(GirNode):
def __init__(self, klass, tl: typelib.Typelib):
def __init__(self, klass: T.Union["Class", "Interface"], tl: typelib.Typelib):
super().__init__(klass, tl)
@cached_property
def name(self):
def name(self) -> str:
return self.tl.PROP_NAME
@cached_property
@ -295,24 +296,26 @@ class Property(GirNode):
@cached_property
def signature(self):
return f"{self.type_name} {self.container.name}.{self.name}"
return f"{self.full_name} {self.container.name}.{self.name}"
@property
def writable(self):
def writable(self) -> bool:
return self.tl.PROP_WRITABLE == 1
@property
def construct_only(self):
def construct_only(self) -> bool:
return self.tl.PROP_CONSTRUCT_ONLY == 1
class Parameter(GirNode):
def __init__(self, container: GirNode, tl: typelib.Typelib):
def __init__(self, container: GirNode, tl: typelib.Typelib) -> None:
super().__init__(container, tl)
class Signal(GirNode):
def __init__(self, klass, tl: typelib.Typelib):
def __init__(
self, klass: T.Union["Class", "Interface"], tl: typelib.Typelib
) -> None:
super().__init__(klass, tl)
# if parameters := xml.get_elements('parameters'):
# self.params = [Parameter(self, child) for child in parameters[0].get_elements('parameter')]
@ -328,11 +331,11 @@ class Signal(GirNode):
class Interface(GirNode, GirType):
def __init__(self, ns, tl: typelib.Typelib):
def __init__(self, ns: "Namespace", tl: typelib.Typelib):
super().__init__(ns, tl)
@cached_property
def properties(self):
def properties(self) -> T.Mapping[str, Property]:
n_prerequisites = self.tl.INTERFACE_N_PREREQUISITES
offset = self.tl.header.HEADER_INTERFACE_BLOB_SIZE
offset += (n_prerequisites + n_prerequisites % 2) * 2
@ -345,7 +348,7 @@ class Interface(GirNode, GirType):
return result
@cached_property
def signals(self):
def signals(self) -> T.Mapping[str, Signal]:
n_prerequisites = self.tl.INTERFACE_N_PREREQUISITES
offset = self.tl.header.HEADER_INTERFACE_BLOB_SIZE
offset += (n_prerequisites + n_prerequisites % 2) * 2
@ -362,7 +365,7 @@ class Interface(GirNode, GirType):
return result
@cached_property
def prerequisites(self):
def prerequisites(self) -> T.List["Interface"]:
n_prerequisites = self.tl.INTERFACE_N_PREREQUISITES
result = []
for i in range(n_prerequisites):
@ -370,7 +373,7 @@ class Interface(GirNode, GirType):
result.append(self.get_containing(Repository)._resolve_dir_entry(entry))
return result
def assignable_to(self, other) -> bool:
def assignable_to(self, other: GirType) -> bool:
if self == other:
return True
for pre in self.prerequisites:
@ -380,15 +383,15 @@ class Interface(GirNode, GirType):
class Class(GirNode, GirType):
def __init__(self, ns, tl: typelib.Typelib):
def __init__(self, ns: "Namespace", tl: typelib.Typelib) -> None:
super().__init__(ns, tl)
@property
def abstract(self):
def abstract(self) -> bool:
return self.tl.OBJ_ABSTRACT == 1
@cached_property
def implements(self):
def implements(self) -> T.List[Interface]:
n_interfaces = self.tl.OBJ_N_INTERFACES
result = []
for i in range(n_interfaces):
@ -397,7 +400,7 @@ class Class(GirNode, GirType):
return result
@cached_property
def own_properties(self):
def own_properties(self) -> T.Mapping[str, Property]:
n_interfaces = self.tl.OBJ_N_INTERFACES
offset = self.tl.header.HEADER_OBJECT_BLOB_SIZE
offset += (n_interfaces + n_interfaces % 2) * 2
@ -414,7 +417,7 @@ class Class(GirNode, GirType):
return result
@cached_property
def own_signals(self):
def own_signals(self) -> T.Mapping[str, Signal]:
n_interfaces = self.tl.OBJ_N_INTERFACES
offset = self.tl.header.HEADER_OBJECT_BLOB_SIZE
offset += (n_interfaces + n_interfaces % 2) * 2
@ -433,16 +436,18 @@ class Class(GirNode, GirType):
return result
@cached_property
def parent(self):
def parent(self) -> T.Optional["Class"]:
if entry := self.tl.OBJ_PARENT:
return self.get_containing(Repository)._resolve_dir_entry(entry)
else:
return None
@cached_property
def signature(self):
def signature(self) -> str:
assert self.container is not None
result = f"class {self.container.name}.{self.name}"
if self.parent is not None:
assert self.parent.container is not None
result += f" : {self.parent.container.name}.{self.parent.name}"
if len(self.implements):
result += " implements " + ", ".join(
@ -451,14 +456,14 @@ class Class(GirNode, GirType):
return result
@cached_property
def properties(self):
def properties(self) -> T.Mapping[str, Property]:
return {p.name: p for p in self._enum_properties()}
@cached_property
def signals(self):
def signals(self) -> T.Mapping[str, Signal]:
return {s.name: s for s in self._enum_signals()}
def assignable_to(self, other) -> bool:
def assignable_to(self, other: GirType) -> bool:
if self == other:
return True
elif self.parent and self.parent.assignable_to(other):
@ -470,7 +475,7 @@ class Class(GirNode, GirType):
return False
def _enum_properties(self):
def _enum_properties(self) -> T.Iterable[Property]:
yield from self.own_properties.values()
if self.parent is not None:
@ -479,7 +484,7 @@ class Class(GirNode, GirType):
for impl in self.implements:
yield from impl.properties.values()
def _enum_signals(self):
def _enum_signals(self) -> T.Iterable[Signal]:
yield from self.own_signals.values()
if self.parent is not None:
@ -490,8 +495,8 @@ class Class(GirNode, GirType):
class EnumMember(GirNode):
def __init__(self, ns, tl: typelib.Typelib):
super().__init__(ns, tl)
def __init__(self, enum: "Enumeration", tl: typelib.Typelib) -> None:
super().__init__(enum, tl)
@property
def value(self) -> int:
@ -502,20 +507,20 @@ class EnumMember(GirNode):
return self.tl.VALUE_NAME
@cached_property
def nick(self):
def nick(self) -> str:
return self.name.replace("_", "-")
@property
def c_ident(self):
def c_ident(self) -> str:
return self.tl.attr("c:identifier")
@property
def signature(self):
def signature(self) -> str:
return f"enum member {self.full_name} = {self.value}"
class Enumeration(GirNode, GirType):
def __init__(self, ns, tl: typelib.Typelib):
def __init__(self, ns: "Namespace", tl: typelib.Typelib) -> None:
super().__init__(ns, tl)
@cached_property
@ -530,43 +535,43 @@ class Enumeration(GirNode, GirType):
return members
@property
def signature(self):
def signature(self) -> str:
return f"enum {self.full_name}"
def assignable_to(self, type):
def assignable_to(self, type: GirType) -> bool:
return type == self
class Boxed(GirNode, GirType):
def __init__(self, ns, tl: typelib.Typelib):
def __init__(self, ns: "Namespace", tl: typelib.Typelib) -> None:
super().__init__(ns, tl)
@property
def signature(self):
def signature(self) -> str:
return f"boxed {self.full_name}"
def assignable_to(self, type):
def assignable_to(self, type) -> bool:
return type == self
class Bitfield(Enumeration):
def __init__(self, ns, tl: typelib.Typelib):
def __init__(self, ns: "Namespace", tl: typelib.Typelib) -> None:
super().__init__(ns, tl)
class Namespace(GirNode):
def __init__(self, repo, tl: typelib.Typelib):
def __init__(self, repo: "Repository", tl: typelib.Typelib) -> None:
super().__init__(repo, tl)
self.entries: T.Dict[str, GirNode] = {}
self.entries: T.Dict[str, GirType] = {}
n_local_entries = tl.HEADER_N_ENTRIES
directory = tl.HEADER_DIRECTORY
n_local_entries: int = tl.HEADER_N_ENTRIES
directory: typelib.Typelib = tl.HEADER_DIRECTORY
for i in range(n_local_entries):
entry = directory[i * tl.HEADER_ENTRY_BLOB_SIZE]
entry_name = entry.DIR_ENTRY_NAME
entry_type = entry.DIR_ENTRY_BLOB_TYPE
entry_blob = entry.DIR_ENTRY_OFFSET
entry_name: str = entry.DIR_ENTRY_NAME
entry_type: int = entry.DIR_ENTRY_BLOB_TYPE
entry_blob: typelib.Typelib = entry.DIR_ENTRY_OFFSET
if entry_type == typelib.BLOB_TYPE_ENUM:
self.entries[entry_name] = Enumeration(self, entry_blob)
@ -595,11 +600,11 @@ class Namespace(GirNode):
return self.tl.HEADER_NSVERSION
@property
def signature(self):
def signature(self) -> str:
return f"namespace {self.name} {self.version}"
@cached_property
def classes(self):
def classes(self) -> T.Mapping[str, Class]:
return {
name: entry
for name, entry in self.entries.items()
@ -607,24 +612,25 @@ class Namespace(GirNode):
}
@cached_property
def interfaces(self):
def interfaces(self) -> T.Mapping[str, Interface]:
return {
name: entry
for name, entry in self.entries.items()
if isinstance(entry, Interface)
}
def get_type(self, name):
def get_type(self, name) -> T.Optional[GirType]:
"""Gets a type (class, interface, enum, etc.) from this namespace."""
return self.entries.get(name)
def get_type_by_cname(self, cname: str):
def get_type_by_cname(self, cname: str) -> T.Optional[GirType]:
"""Gets a type from this namespace by its C name."""
for item in self.entries.values():
if hasattr(item, "cname") and item.cname == cname:
return item
return None
def lookup_type(self, type_name: str):
def lookup_type(self, type_name: str) -> T.Optional[GirType]:
"""Looks up a type in the scope of this namespace (including in the
namespace's dependencies)."""
@ -638,7 +644,7 @@ class Namespace(GirNode):
class Repository(GirNode):
def __init__(self, tl: typelib.Typelib):
def __init__(self, tl: typelib.Typelib) -> None:
super().__init__(None, tl)
self.namespace = Namespace(self, tl)
@ -654,10 +660,10 @@ class Repository(GirNode):
else:
self.includes = {}
def get_type(self, name: str, ns: str) -> T.Optional[GirNode]:
def get_type(self, name: str, ns: str) -> T.Optional[GirType]:
return self.lookup_namespace(ns).get_type(name)
def get_type_by_cname(self, name: str) -> T.Optional[GirNode]:
def get_type_by_cname(self, name: str) -> T.Optional[GirType]:
for ns in [self.namespace, *self.includes.values()]:
if type := ns.get_type_by_cname(name):
return type
@ -679,7 +685,7 @@ class Repository(GirNode):
ns = dir_entry.DIR_ENTRY_NAMESPACE
return self.lookup_namespace(ns).get_type(dir_entry.DIR_ENTRY_NAME)
def _resolve_type_id(self, type_id: int):
def _resolve_type_id(self, type_id: int) -> GirType:
if type_id & 0xFFFFFF == 0:
type_id = (type_id >> 27) & 0x1F
# simple type
@ -726,13 +732,13 @@ class GirContext:
self.namespaces[namespace.name] = namespace
def get_type_by_cname(self, name: str) -> T.Optional[GirNode]:
def get_type_by_cname(self, name: str) -> T.Optional[GirType]:
for ns in self.namespaces.values():
if type := ns.get_type_by_cname(name):
return type
return None
def get_type(self, name: str, ns: str) -> T.Optional[GirNode]:
def get_type(self, name: str, ns: str) -> T.Optional[GirType]:
if ns is None and name in _BASIC_TYPES:
return _BASIC_TYPES[name]()
@ -750,7 +756,7 @@ class GirContext:
else:
return None
def validate_ns(self, ns: str):
def validate_ns(self, ns: str) -> None:
"""Raises an exception if there is a problem looking up the given
namespace."""
@ -762,7 +768,7 @@ class GirContext:
did_you_mean=(ns, self.namespaces.keys()),
)
def validate_type(self, name: str, ns: str):
def validate_type(self, name: str, ns: str) -> None:
"""Raises an exception if there is a problem looking up the given type."""
self.validate_ns(ns)

View file

@ -32,7 +32,7 @@ from .utils import Colors
class CouldNotPort:
def __init__(self, message):
def __init__(self, message: str):
self.message = message

View file

@ -31,7 +31,7 @@ def printerr(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs)
def command(json_method):
def command(json_method: str):
def decorator(func):
func._json_method = json_method
return func
@ -40,7 +40,7 @@ def command(json_method):
class OpenFile:
def __init__(self, uri, text, version):
def __init__(self, uri: str, text: str, version: int):
self.uri = uri
self.text = text
self.version = version
@ -81,6 +81,9 @@ class OpenFile:
self.diagnostics.append(e)
def calc_semantic_tokens(self) -> T.List[int]:
if self.ast is None:
return []
tokens = list(self.ast.get_semantic_tokens())
token_lists = [
[
@ -318,9 +321,11 @@ class LanguageServer:
},
)
def _create_diagnostic(self, text, uri, err):
def _create_diagnostic(self, text: str, uri: str, err: CompileError):
message = err.message
assert err.start is not None and err.end is not None
for hint in err.hints:
message += "\nhint: " + hint

View file

@ -82,7 +82,7 @@ class BlueprintApp:
except:
report_bug()
def add_subcommand(self, name, help, func):
def add_subcommand(self, name: str, help: str, func):
parser = self.subparsers.add_parser(name, help=help)
parser.set_defaults(func=func)
return parser

View file

@ -23,6 +23,7 @@ import typing as T
from collections import defaultdict
from enum import Enum
from .ast_utils import AstNode
from .errors import (
assert_true,
@ -64,19 +65,19 @@ class ParseGroup:
be converted to AST nodes by passing the children and key=value pairs to
the AST node constructor."""
def __init__(self, ast_type, start: int):
def __init__(self, ast_type: T.Type[AstNode], start: int):
self.ast_type = ast_type
self.children: T.List[ParseGroup] = []
self.keys: T.Dict[str, T.Any] = {}
self.tokens: T.Dict[str, Token] = {}
self.tokens: T.Dict[str, T.Optional[Token]] = {}
self.start = start
self.end = None
self.end: T.Optional[int] = None
self.incomplete = False
def add_child(self, child):
def add_child(self, child: "ParseGroup"):
self.children.append(child)
def set_val(self, key, val, token):
def set_val(self, key: str, val: T.Any, token: T.Optional[Token]):
assert_true(key not in self.keys)
self.keys[key] = val
@ -105,22 +106,22 @@ class ParseGroup:
class ParseContext:
"""Contains the state of the parser."""
def __init__(self, tokens, index=0):
def __init__(self, tokens: T.List[Token], index=0):
self.tokens = list(tokens)
self.binding_power = 0
self.index = index
self.start = index
self.group = None
self.group_keys = {}
self.group_children = []
self.last_group = None
self.group: T.Optional[ParseGroup] = None
self.group_keys: T.Dict[str, T.Tuple[T.Any, T.Optional[Token]]] = {}
self.group_children: T.List[ParseGroup] = []
self.last_group: T.Optional[ParseGroup] = None
self.group_incomplete = False
self.errors = []
self.warnings = []
self.errors: T.List[CompileError] = []
self.warnings: T.List[CompileWarning] = []
def create_child(self):
def create_child(self) -> "ParseContext":
"""Creates a new ParseContext at this context's position. The new
context will be used to parse one node. If parsing is successful, the
new context will be applied to "self". If parsing fails, the new
@ -131,7 +132,7 @@ class ParseContext:
ctx.binding_power = self.binding_power
return ctx
def apply_child(self, other):
def apply_child(self, other: "ParseContext"):
"""Applies a child context to this context."""
if other.group is not None:
@ -159,12 +160,12 @@ class ParseContext:
elif other.last_group:
self.last_group = other.last_group
def start_group(self, ast_type):
def start_group(self, ast_type: T.Type[AstNode]):
"""Sets this context to have its own match group."""
assert_true(self.group is None)
self.group = ParseGroup(ast_type, self.tokens[self.index].start)
def set_group_val(self, key, value, token):
def set_group_val(self, key: str, value: T.Any, token: T.Optional[Token]):
"""Sets a matched key=value pair on the current match group."""
assert_true(key not in self.group_keys)
self.group_keys[key] = (value, token)
@ -213,7 +214,7 @@ class ParseContext:
else:
self.errors.append(UnexpectedTokenError(start, end))
def is_eof(self) -> Token:
def is_eof(self) -> bool:
return self.index >= len(self.tokens) or self.peek_token().type == TokenType.EOF
@ -237,17 +238,17 @@ class ParseNode:
def _parse(self, ctx: ParseContext) -> bool:
raise NotImplementedError()
def err(self, message):
def err(self, message: str) -> "Err":
"""Causes this ParseNode to raise an exception if it fails to parse.
This prevents the parser from backtracking, so you should understand
what it does and how the parser works before using it."""
return Err(self, message)
def expected(self, expect):
def expected(self, expect) -> "Err":
"""Convenience method for err()."""
return self.err("Expected " + expect)
def warn(self, message):
def warn(self, message) -> "Warning":
"""Causes this ParseNode to emit a warning if it parses successfully."""
return Warning(self, message)
@ -255,11 +256,11 @@ class ParseNode:
class Err(ParseNode):
"""ParseNode that emits a compile error if it fails to parse."""
def __init__(self, child, message):
def __init__(self, child, message: str):
self.child = to_parse_node(child)
self.message = message
def _parse(self, ctx):
def _parse(self, ctx: ParseContext):
if self.child.parse(ctx).failed():
start_idx = ctx.start
while ctx.tokens[start_idx].type in SKIP_TOKENS:
@ -274,11 +275,11 @@ class Err(ParseNode):
class Warning(ParseNode):
"""ParseNode that emits a compile warning if it parses successfully."""
def __init__(self, child, message):
def __init__(self, child, message: str):
self.child = to_parse_node(child)
self.message = message
def _parse(self, ctx):
def _parse(self, ctx: ParseContext):
ctx.skip()
start_idx = ctx.index
if self.child.parse(ctx).succeeded():
@ -295,11 +296,11 @@ class Warning(ParseNode):
class Fail(ParseNode):
"""ParseNode that emits a compile error if it parses successfully."""
def __init__(self, child, message):
def __init__(self, child, message: str):
self.child = to_parse_node(child)
self.message = message
def _parse(self, ctx):
def _parse(self, ctx: ParseContext):
if self.child.parse(ctx).succeeded():
start_idx = ctx.start
while ctx.tokens[start_idx].type in SKIP_TOKENS:
@ -314,7 +315,7 @@ class Fail(ParseNode):
class Group(ParseNode):
"""ParseNode that creates a match group."""
def __init__(self, ast_type, child):
def __init__(self, ast_type: T.Type[AstNode], child):
self.ast_type = ast_type
self.child = to_parse_node(child)
@ -393,7 +394,7 @@ class Until(ParseNode):
self.child = to_parse_node(child)
self.delimiter = to_parse_node(delimiter)
def _parse(self, ctx):
def _parse(self, ctx: ParseContext):
while not self.delimiter.parse(ctx).succeeded():
if ctx.is_eof():
return False
@ -463,7 +464,7 @@ class Eof(ParseNode):
class Match(ParseNode):
"""ParseNode that matches the given literal token."""
def __init__(self, op):
def __init__(self, op: str):
self.op = op
def _parse(self, ctx: ParseContext) -> bool:
@ -482,7 +483,7 @@ class UseIdent(ParseNode):
"""ParseNode that matches any identifier and sets it in a key=value pair on
the containing match group."""
def __init__(self, key):
def __init__(self, key: str):
self.key = key
def _parse(self, ctx: ParseContext):
@ -498,7 +499,7 @@ class UseNumber(ParseNode):
"""ParseNode that matches a number and sets it in a key=value pair on
the containing match group."""
def __init__(self, key):
def __init__(self, key: str):
self.key = key
def _parse(self, ctx: ParseContext):
@ -517,7 +518,7 @@ class UseNumberText(ParseNode):
"""ParseNode that matches a number, but sets its *original text* it in a
key=value pair on the containing match group."""
def __init__(self, key):
def __init__(self, key: str):
self.key = key
def _parse(self, ctx: ParseContext):
@ -533,7 +534,7 @@ class UseQuoted(ParseNode):
"""ParseNode that matches a quoted string and sets it in a key=value pair
on the containing match group."""
def __init__(self, key):
def __init__(self, key: str):
self.key = key
def _parse(self, ctx: ParseContext):
@ -557,7 +558,7 @@ class UseLiteral(ParseNode):
pair on the containing group. Useful for, e.g., property and signal flags:
`Sequence(Keyword("swapped"), UseLiteral("swapped", True))`"""
def __init__(self, key, literal):
def __init__(self, key: str, literal: T.Any):
self.key = key
self.literal = literal
@ -570,7 +571,7 @@ class Keyword(ParseNode):
"""Matches the given identifier and sets it as a named token, with the name
being the identifier itself."""
def __init__(self, kw):
def __init__(self, kw: str):
self.kw = kw
self.set_token = True

View file

@ -22,7 +22,7 @@ import typing as T
import re
from enum import Enum
from .errors import CompileError
from .errors import CompileError, CompilerBugError
class TokenType(Enum):
@ -53,18 +53,18 @@ _TOKENS = [(type, re.compile(regex)) for (type, regex) in _tokens]
class Token:
def __init__(self, type, start, end, string):
def __init__(self, type: TokenType, start: int, end: int, string: str):
self.type = type
self.start = start
self.end = end
self.string = string
def __str__(self):
def __str__(self) -> str:
return self.string[self.start : self.end]
def get_number(self):
def get_number(self) -> T.Union[int, float]:
if self.type != TokenType.NUMBER:
return None
raise CompilerBugError()
string = str(self).replace("_", "")
try:

View file

@ -58,14 +58,14 @@ TYPE_UNICHAR = 21
class Field:
def __init__(self, offset, type, shift=0, mask=None):
def __init__(self, offset: int, type: str, shift=0, mask=None):
self._offset = offset
self._type = type
self._shift = shift
self._mask = (1 << mask) - 1 if mask else None
self._name = f"{offset}__{type}__{shift}__{mask}"
def __get__(self, typelib, _objtype=None):
def __get__(self, typelib: "Typelib", _objtype=None):
if typelib is None:
return self
@ -181,47 +181,47 @@ class Typelib:
VALUE_NAME = Field(0x4, "string")
VALUE_VALUE = Field(0x8, "i32")
def __init__(self, typelib_file, offset):
def __init__(self, typelib_file, offset: int):
self._typelib_file = typelib_file
self._offset = offset
def __getitem__(self, index):
def __getitem__(self, index: int):
return Typelib(self._typelib_file, self._offset + index)
def attr(self, name):
return self.header.attr(self._offset, name)
@property
def header(self):
def header(self) -> "TypelibHeader":
return TypelibHeader(self._typelib_file)
@property
def u8(self):
def u8(self) -> int:
"""Gets the 8-bit unsigned int at this location."""
return self._int(1, False)
@property
def u16(self):
def u16(self) -> int:
"""Gets the 16-bit unsigned int at this location."""
return self._int(2, False)
@property
def u32(self):
def u32(self) -> int:
"""Gets the 32-bit unsigned int at this location."""
return self._int(4, False)
@property
def i8(self):
def i8(self) -> int:
"""Gets the 8-bit unsigned int at this location."""
return self._int(1, True)
@property
def i16(self):
def i16(self) -> int:
"""Gets the 16-bit unsigned int at this location."""
return self._int(2, True)
@property
def i32(self):
def i32(self) -> int:
"""Gets the 32-bit unsigned int at this location."""
return self._int(4, True)
@ -240,7 +240,7 @@ class Typelib:
end += 1
return self._typelib_file[loc:end].decode("utf-8")
def _int(self, size, signed):
def _int(self, size, signed) -> int:
return int.from_bytes(
self._typelib_file[self._offset : self._offset + size], sys.byteorder
)
@ -250,7 +250,7 @@ class TypelibHeader(Typelib):
def __init__(self, typelib_file):
super().__init__(typelib_file, 0)
def dir_entry(self, index):
def dir_entry(self, index) -> T.Optional[Typelib]:
if index == 0:
return None
else:

View file

@ -46,7 +46,7 @@ PARSE_GIR = set(
class Element:
def __init__(self, tag, attrs: T.Dict[str, str]):
def __init__(self, tag: str, attrs: T.Dict[str, str]):
self.tag = tag
self.attrs = attrs
self.children: T.List["Element"] = []
@ -56,10 +56,10 @@ class Element:
def cdata(self):
return "".join(self.cdata_chunks)
def get_elements(self, name) -> T.List["Element"]:
def get_elements(self, name: str) -> T.List["Element"]:
return [child for child in self.children if child.tag == name]
def __getitem__(self, key):
def __getitem__(self, key: str):
return self.attrs.get(key)