diff --git a/src/cfengine_cli/lint.py b/src/cfengine_cli/lint.py index c397170..a686f69 100644 --- a/src/cfengine_cli/lint.py +++ b/src/cfengine_cli/lint.py @@ -14,6 +14,7 @@ import json import itertools import tree_sitter_cfengine as tscfengine +from dataclasses import dataclass from tree_sitter import Language, Parser from cfbs.validate import validate_config from cfbs.cfbs_config import CFBSConfig @@ -26,6 +27,39 @@ ) +@dataclass +class _State: + block_type: str | None = None # "bundle" | "body" | "promise" | None + promise_type: str | None = None # "vars" | "files" | "classes" | ... | None + attribute_name: str | None = None # "if" | "string" | "slist" | ... | None + + def update(self, node) -> "_State": + """Updates and returns the state that should apply to the children of `node`.""" + if node.type == "bundle_block": + return _State(block_type="bundle") + if node.type == "body_block": + return _State(block_type="body") + if node.type == "promise_block": + return _State(block_type="promise") + if node.type == "bundle_section": + for child in node.children: + if child.type == "promise_guard": + return _State( + block_type=self.block_type, + promise_type=_text(child)[:-1], # strip trailing ':' + ) + return _State(block_type=self.block_type) + if node.type == "attribute": + for child in node.children: + if child.type == "attribute_name": + return _State( + block_type=self.block_type, + promise_type=self.promise_type, + attribute_name=_text(child), + ) + return self + + def lint_cfbs_json(filename) -> int: assert os.path.isfile(filename) assert filename.endswith("cfbs.json") @@ -93,16 +127,9 @@ def _find_node_type(filename, lines, node, node_type): return matches -def _find_nodes(filename, lines, node): - matches = [] - visitor = lambda x: matches.append(x) - _walk_generic(filename, lines, node, visitor) - return matches - - -def _single_node_checks(filename, lines, node, user_definition, strict): - """Things which can be checked by only looking at one node, - not needing to recurse into children.""" +def _node_checks(filename, lines, node, user_definition, strict, state: _State): + """Checks we run on each node in the syntax tree, + utilizes state for checks which require context.""" line = node.range.start_point[0] + 1 column = node.range.start_point[1] + 1 if node.type == "attribute_name" and _text(node) == "ifvarclass": @@ -133,7 +160,6 @@ def _single_node_checks(filename, lines, node, user_definition, strict): f"Error: Undefined promise type '{promise_type}' at {filename}:{line}:{column}" ) return 1 - if node.type == "bundle_block_name": if _text(node) != _text(node).lower(): _highlight_range(node, lines) @@ -156,6 +182,16 @@ def _single_node_checks(filename, lines, node, user_definition, strict): ) return 1 if node.type == "calling_identifier": + if ( + strict + and _text(node) in user_definition.get("all_bundle_names", set()) + and state.promise_type in user_definition.get("custom_promise_types", set()) + ): + _highlight_range(node, lines) + print( + f"Error: Call to bundle '{_text(node)}' inside custom promise: '{state.promise_type}' at {filename}:{line}:{column}" + ) + return 1 if strict and ( _text(node) not in BUILTIN_FUNCTIONS.union( @@ -171,6 +207,22 @@ def _single_node_checks(filename, lines, node, user_definition, strict): return 0 +def _stateful_walk( + filename, lines, node, user_definition, strict, state: _State | None = None +) -> int: + if state is None: + state = _State() + + errors = _node_checks(filename, lines, node, user_definition, strict, state) + + child_state = state.update(node) + for child in node.children: + errors += _stateful_walk( + filename, lines, child, user_definition, strict, child_state + ) + return errors + + def _walk(filename, lines, node, user_definition=None, strict=True) -> int: if user_definition is None: user_definition = {} @@ -187,11 +239,7 @@ def _walk(filename, lines, node, user_definition=None, strict=True) -> int: line = node.range.start_point[0] + 1 column = node.range.start_point[1] + 1 - errors = 0 - for node in _find_nodes(filename, lines, node): - errors += _single_node_checks(filename, lines, node, user_definition, strict) - - return errors + return _stateful_walk(filename, lines, node, user_definition, strict) def _parse_user_definition(filename, lines, root_node):