Clean up AST code

This commit is contained in:
James Westman 2021-10-31 16:44:34 -05:00
parent d7a8a21b8e
commit dc7c0cabd8
No known key found for this signature in database
GPG key ID: CE2DBA0ADB654EA6
5 changed files with 171 additions and 412 deletions

View file

@ -17,7 +17,99 @@
#
# SPDX-License-Identifier: LGPL-3.0-or-later
import typing as T
from collections import ChainMap, defaultdict
from .errors import *
from .utils import lazy_prop
from .xml_emitter import XmlEmitter
class Children:
""" Allows accessing children by type using array syntax. """
def __init__(self, children):
self._children = children
def __iter__(self):
return iter(self._children)
def __getitem__(self, key):
return [child for child in self._children if isinstance(child, key)]
class AstNode:
""" Base class for nodes in the abstract syntax tree. """
completers: T.List = []
def __init__(self, group, children, tokens, incomplete=False):
self.group = group
self.children = Children(children)
self.tokens = ChainMap(tokens, defaultdict(lambda: None))
self.incomplete = incomplete
self.parent = None
for child in self.children:
child.parent = self
def __init_subclass__(cls):
cls.completers = []
@property
def root(self):
if self.parent is None:
return self
else:
return self.parent.root
@lazy_prop
def errors(self):
return list(self._get_errors())
def _get_errors(self):
for name, attr in self._attrs_by_type(Validator):
try:
getattr(self, name)
except AlreadyCaughtError:
pass
except CompileError as e:
yield e
for child in self.children:
yield from child._get_errors()
def _attrs_by_type(self, attr_type):
for name in dir(type(self)):
item = getattr(type(self), name)
if isinstance(item, attr_type):
yield name, item
def generate(self) -> str:
""" Generates an XML string from the node. """
xml = XmlEmitter()
self.emit_xml(xml)
return xml.result
def emit_xml(self, xml: XmlEmitter):
""" Emits the XML representation of this AST node to the XmlEmitter. """
raise NotImplementedError()
def get_docs(self, idx: int) -> T.Optional[str]:
for name, attr in self._attrs_by_type(Docs):
if attr.token_name:
token = self.group.tokens.get(attr.token_name)
if token and token.start <= idx < token.end:
return getattr(self, name)
else:
return getattr(self, name)
for child in self.children:
if child.group.start <= idx < child.group.end:
docs = child.get_docs(idx)
if docs is not None:
return docs
return None
class Validator:
def __init__(self, func, token_name=None, end_token_name=None):