ast: Separate validation from properties

This commit is contained in:
James Westman 2021-10-31 22:51:48 -05:00
parent bfd9daf6a9
commit 53ad4ec69d
No known key found for this signature in database
GPG key ID: CE2DBA0ADB654EA6
3 changed files with 125 additions and 113 deletions

View file

@ -31,22 +31,40 @@ from .xml_emitter import XmlEmitter
class UI(AstNode): class UI(AstNode):
""" The AST node for the entire file """ """ The AST node for the entire file """
@validate() @property
def gir(self): def gir(self):
gir = GirContext() gir = GirContext()
self._gir_errors = []
try:
gir.add_namespace(self.children[GtkDirective][0].gir_namespace)
except CompileError as e:
self._gir_errors.append(e)
gir.add_namespace(self.children[GtkDirective][0].gir_namespace)
for i in self.children[Import]: for i in self.children[Import]:
gir.add_namespace(i.gir_namespace) try:
gir.add_namespace(i.gir_namespace)
except CompileError as e:
self._gir_errors.append(e)
return gir return gir
@validate()
def gir_errors(self):
# make sure gir is loaded
self.gir
if len(self._gir_errors):
raise MultipleErrors(self._gir_errors)
@validate() @validate()
def at_most_one_template(self): def at_most_one_template(self):
if len(self.children[Template]) > 1: if len(self.children[Template]) > 1:
raise CompileError(f"Only one template may be defined per file, but this file contains {len(self.templates)}", raise CompileError(f"Only one template may be defined per file, but this file contains {len(self.templates)}",
self.children[Template][1].group.start) self.children[Template][1].group.start)
def emit_xml(self, xml: XmlEmitter): def emit_xml(self, xml: XmlEmitter):
xml.start_tag("interface") xml.start_tag("interface")
for x in self.children: for x in self.children:
@ -56,10 +74,8 @@ class UI(AstNode):
class GtkDirective(AstNode): class GtkDirective(AstNode):
@validate("version") @validate("version")
def gir_namespace(self): def gtk_version(self):
if self.tokens["version"] in ["4.0"]: if self.tokens["version"] not in ["4.0"]:
return get_namespace("Gtk", self.tokens["version"])
else:
err = CompileError("Only GTK 4 is supported") err = CompileError("Only GTK 4 is supported")
if self.version.startswith("4"): if self.version.startswith("4"):
err.hint("Expected the GIR version, not an exact version number. Use `using Gtk 4.0;`.") err.hint("Expected the GIR version, not an exact version number. Use `using Gtk 4.0;`.")
@ -67,6 +83,12 @@ class GtkDirective(AstNode):
err.hint("Expected `using Gtk 4.0;`") err.hint("Expected `using Gtk 4.0;`")
raise err raise err
@property
def gir_namespace(self):
return get_namespace("Gtk", self.tokens["version"])
def emit_xml(self, xml: XmlEmitter): def emit_xml(self, xml: XmlEmitter):
xml.put_self_closing("requires", lib="gtk", version=self.tokens["version"]) xml.put_self_closing("requires", lib="gtk", version=self.tokens["version"])
@ -79,9 +101,13 @@ class Import(AstNode):
class Template(AstNode): class Template(AstNode):
@validate("namespace", "class_name") @validate("namespace", "class_name")
def gir_parent(self): def gir_parent_exists(self):
if not self.tokens["ignore_gir"]: if not self.tokens["ignore_gir"]:
return self.root.gir.get_class(self.tokens["class_name"], self.tokens["namespace"]) self.root.gir.validate_class(self.tokens["class_name"], self.tokens["namespace"])
@property
def gir_parent(self):
return self.root.gir.get_class(self.tokens["class_name"], self.tokens["namespace"])
@docs("namespace") @docs("namespace")
@ -106,6 +132,11 @@ class Template(AstNode):
class Object(AstNode): class Object(AstNode):
@validate("namespace", "class_name") @validate("namespace", "class_name")
def gir_class_exists(self):
if not self.tokens["ignore_gir"]:
self.root.gir.validate_class(self.tokens["class_name"], self.tokens["namespace"])
@property
def gir_class(self): def gir_class(self):
if not self.tokens["ignore_gir"]: if not self.tokens["ignore_gir"]:
return self.root.gir.get_class(self.tokens["class_name"], self.tokens["namespace"]) return self.root.gir.get_class(self.tokens["class_name"], self.tokens["namespace"])
@ -141,7 +172,7 @@ class Child(AstNode):
class ObjectContent(AstNode): class ObjectContent(AstNode):
@validate() @property
def gir_class(self): def gir_class(self):
if isinstance(self.parent, Template): if isinstance(self.parent, Template):
return self.parent.gir_parent return self.parent.gir_parent
@ -164,12 +195,13 @@ class ObjectContent(AstNode):
class Property(AstNode): class Property(AstNode):
@validate() @property
def gir_property(self): def gir_property(self):
if self.gir_class is not None: if self.gir_class is not None:
return self.gir_class.properties.get(self.tokens["name"]) return self.gir_class.properties.get(self.tokens["name"])
@validate()
@property
def gir_class(self): def gir_class(self):
parent = self.parent.parent parent = self.parent.parent
if isinstance(parent, Template): if isinstance(parent, Template):
@ -179,6 +211,7 @@ class Property(AstNode):
else: else:
raise CompilerBugError() raise CompilerBugError()
@validate("name") @validate("name")
def property_exists(self): def property_exists(self):
if self.gir_class is None: if self.gir_class is None:
@ -233,12 +266,13 @@ class Property(AstNode):
class Signal(AstNode): class Signal(AstNode):
@validate() @property
def gir_signal(self): def gir_signal(self):
if self.gir_class is not None: if self.gir_class is not None:
return self.gir_class.signals.get(self.name) return self.gir_class.signals.get(self.tokens["name"])
@validate()
@property
def gir_class(self): def gir_class(self):
parent = self.parent.parent parent = self.parent.parent
if isinstance(parent, Template): if isinstance(parent, Template):
@ -248,6 +282,7 @@ class Signal(AstNode):
else: else:
raise CompilerBugError() raise CompilerBugError()
@validate("name") @validate("name")
def signal_exists(self): def signal_exists(self):
if self.gir_class is None: if self.gir_class is None:
@ -262,7 +297,7 @@ class Signal(AstNode):
if self.gir_signal is None: if self.gir_signal is None:
raise CompileError( raise CompileError(
f"Class {self.gir_class.full_name} does not contain a signal called {self.name}", f"Class {self.gir_class.full_name} does not contain a signal called {self.tokens['name']}",
did_you_mean=(self.tokens["name"], self.gir_class.signals.keys()) did_you_mean=(self.tokens["name"], self.gir_class.signals.keys())
) )

View file

@ -52,6 +52,7 @@ class AstNode:
def __init_subclass__(cls): def __init_subclass__(cls):
cls.completers = [] cls.completers = []
cls.validators = [getattr(cls, f) for f in dir(cls) if hasattr(getattr(cls, f), "_validator")]
@property @property
@ -66,9 +67,9 @@ class AstNode:
return list(self._get_errors()) return list(self._get_errors())
def _get_errors(self): def _get_errors(self):
for name, attr in self._attrs_by_type(Validator): for validator in self.validators:
try: try:
getattr(self, name) validator(self)
except AlreadyCaughtError: except AlreadyCaughtError:
pass pass
except CompileError as e: except CompileError as e:
@ -111,70 +112,40 @@ class AstNode:
return None return None
class Validator: def validate(token_name=None, end_token_name=None, skip_incomplete=False):
def __init__(self, func, token_name=None, end_token_name=None):
self.func = func
self.token_name = token_name
self.end_token_name = end_token_name
def __get__(self, instance, owner):
if instance is None:
return self
key = "_validation_result_" + self.func.__name__
if key + "_err" in instance.__dict__:
# If the validator has failed before, raise a generic Exception.
# We want anything that depends on this validation result to
# fail, but not report the exception twice.
raise AlreadyCaughtError()
if key not in instance.__dict__:
try:
instance.__dict__[key] = self.func(instance)
except CompileError as e:
# Mark the validator as already failed so we don't print the
# same message again
instance.__dict__[key + "_err"] = True
# If the node is only partially complete, then an error must
# have already been reported at the parsing stage
if instance.incomplete:
return None
# This mess of code sets the error's start and end positions
# from the tokens passed to the decorator, if they have not
# already been set
if self.token_name is not None and e.start is None:
group = instance.group.tokens.get(self.token_name)
if self.end_token_name is not None and group is None:
group = instance.group.tokens[self.end_token_name]
e.start = group.start
if (self.token_name is not None or self.end_token_name is not None) and e.end is None:
e.end = instance.group.tokens[self.end_token_name or self.token_name].end
# Re-raise the exception
raise e
except Exception as e:
# If the node is only partially complete, then an error must
# have already been reported at the parsing stage
if instance.incomplete:
return None
else:
raise e
# Return the validation result (which other validators, or the code
# generation phase, might depend on)
return instance.__dict__[key]
def validate(*args, **kwargs):
""" Decorator for functions that validate an AST node. Exceptions raised """ Decorator for functions that validate an AST node. Exceptions raised
during validation are marked with range information from the tokens. Also during validation are marked with range information from the tokens. Also
creates a cached property out of the function. """ creates a cached property out of the function. """
def decorator(func): def decorator(func):
return Validator(func, *args, **kwargs) def inner(self):
if skip_incomplete and self.incomplete:
return
try:
func(self)
except CompileError as e:
# If the node is only partially complete, then an error must
# have already been reported at the parsing stage
if self.incomplete:
return
# This mess of code sets the error's start and end positions
# from the tokens passed to the decorator, if they have not
# already been set
if token_name is not None and e.start is None:
group = self.group.tokens.get(token_name)
if end_token_name is not None and group is None:
group = self.group.tokens[end_token_name]
e.start = group.start
if (token_name is not None or end_token_name is not None) and e.end is None:
e.end = self.group.tokens[end_token_name or token_name].end
# Re-raise the exception
raise e
inner._validator = True
return inner
return decorator return decorator

View file

@ -203,6 +203,12 @@ class Namespace(GirNode):
def signature(self): def signature(self):
return f"namespace {self.name} {self.version}" return f"namespace {self.name} {self.version}"
def get_type(self, name):
""" Gets a type (class, interface, enum, etc.) from this namespace. """
return self.classes.get(name) or self.interfaces.get(name)
def lookup_class(self, name: str): def lookup_class(self, name: str):
if "." in name: if "." in name:
ns, cls = name.split(".") ns, cls = name.split(".")
@ -254,43 +260,43 @@ class GirContext:
self.namespaces[namespace.name] = namespace self.namespaces[namespace.name] = namespace
def get_class(self, name: str, ns:str=None) -> Class: def get_type(self, name: str, ns: str) -> GirNode:
if ns is None: ns = ns or "Gtk"
options = [namespace.classes[name]
for namespace in self.namespaces.values()
if name in namespace.classes]
if len(options) == 1: if ns not in self.namespaces:
return options[0] return None
elif len(options) == 0:
raise CompileError(
f"No imported namespace contains a class called {name}",
hints=[
"Did you forget to import a namespace?",
"Did you check your spelling?",
"Are your dependencies up to date?",
],
)
else:
raise CompileError(
f"Class name {name} is ambiguous",
hints=[
f"Specify the namespace, e.g. `{options[0].ns.name}.{name}`",
f"Namespaces with a class named {name}: {', '.join([cls.ns.name for cls in options])}",
],
)
else: return self.namespaces[ns].get_type(name)
if ns not in self.namespaces:
raise CompileError(
f"Namespace `{ns}` was not imported.",
did_you_mean=(ns, self.namespaces.keys()),
)
if name not in self.namespaces[ns].classes:
raise CompileError(
f"Namespace {ns} does not contain a class called {name}.",
did_you_mean=(name, self.namespaces[ns].classes.keys()),
)
return self.namespaces[ns].classes[name] def get_class(self, name: str, ns: str) -> T.Optional[Class]:
type = self.get_type(name, ns)
if isinstance(type, Class):
return type
def validate_class(self, name: str, ns: str) -> Class:
""" Raises an exception if there is a problem looking up the given
class (it doesn't exist, it isn't a class, etc.) """
ns = ns or "Gtk"
if ns not in self.namespaces:
raise CompileError(
f"Namespace `{ns}` was not imported.",
did_you_mean=(ns, self.namespaces.keys()),
)
type = self.get_type(name, ns)
if type is None:
raise CompileError(
f"Namespace {ns} does not contain a class called {name}.",
did_you_mean=(name, self.namespaces[ns].classes.keys()),
)
elif not isinstance(type, Class):
raise CompileError(
f"{ns}.{name} is not a class.",
did_you_mean=(name, self.namespaces[ns].classes.keys()),
)