Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 64 additions & 16 deletions src/cfengine_cli/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import json
import itertools
import tree_sitter_cfengine as tscfengine
from dataclasses import dataclass
from tree_sitter import Language, Parser
from cfbs.validate import validate_config
from cfbs.cfbs_config import CFBSConfig
Expand All @@ -26,6 +27,39 @@
)


@dataclass
class _State:
block_type: str | None = None # "bundle" | "body" | "promise" | None
promise_type: str | None = None # "vars" | "files" | "classes" | ... | None
attribute_name: str | None = None # "if" | "string" | "slist" | ... | None

def update(self, node) -> "_State":
"""Updates and returns the state that should apply to the children of `node`."""
if node.type == "bundle_block":
return _State(block_type="bundle")
if node.type == "body_block":
return _State(block_type="body")
if node.type == "promise_block":
return _State(block_type="promise")
if node.type == "bundle_section":
for child in node.children:
if child.type == "promise_guard":
return _State(
block_type=self.block_type,
promise_type=_text(child)[:-1], # strip trailing ':'
)
return _State(block_type=self.block_type)
if node.type == "attribute":
for child in node.children:
if child.type == "attribute_name":
return _State(
block_type=self.block_type,
promise_type=self.promise_type,
attribute_name=_text(child),
)
return self


def lint_cfbs_json(filename) -> int:
assert os.path.isfile(filename)
assert filename.endswith("cfbs.json")
Expand Down Expand Up @@ -93,16 +127,9 @@ def _find_node_type(filename, lines, node, node_type):
return matches


def _find_nodes(filename, lines, node):
matches = []
visitor = lambda x: matches.append(x)
_walk_generic(filename, lines, node, visitor)
return matches


def _single_node_checks(filename, lines, node, user_definition, strict):
"""Things which can be checked by only looking at one node,
not needing to recurse into children."""
def _node_checks(filename, lines, node, user_definition, strict, state: _State):
"""Checks we run on each node in the syntax tree,
utilizes state for checks which require context."""
line = node.range.start_point[0] + 1
column = node.range.start_point[1] + 1
if node.type == "attribute_name" and _text(node) == "ifvarclass":
Expand Down Expand Up @@ -133,7 +160,6 @@ def _single_node_checks(filename, lines, node, user_definition, strict):
f"Error: Undefined promise type '{promise_type}' at {filename}:{line}:{column}"
)
return 1

if node.type == "bundle_block_name":
if _text(node) != _text(node).lower():
_highlight_range(node, lines)
Expand All @@ -156,6 +182,16 @@ def _single_node_checks(filename, lines, node, user_definition, strict):
)
return 1
if node.type == "calling_identifier":
if (
strict
and _text(node) in user_definition.get("all_bundle_names", set())
and state.promise_type in user_definition.get("custom_promise_types", set())
):
_highlight_range(node, lines)
print(
f"Error: Call to bundle '{_text(node)}' inside custom promise: '{state.promise_type}' at {filename}:{line}:{column}"
)
return 1
if strict and (
_text(node)
not in BUILTIN_FUNCTIONS.union(
Expand All @@ -171,6 +207,22 @@ def _single_node_checks(filename, lines, node, user_definition, strict):
return 0


def _stateful_walk(
filename, lines, node, user_definition, strict, state: _State | None = None
) -> int:
if state is None:
state = _State()

errors = _node_checks(filename, lines, node, user_definition, strict, state)

child_state = state.update(node)
for child in node.children:
errors += _stateful_walk(
filename, lines, child, user_definition, strict, child_state
)
return errors


def _walk(filename, lines, node, user_definition=None, strict=True) -> int:
if user_definition is None:
user_definition = {}
Expand All @@ -187,11 +239,7 @@ def _walk(filename, lines, node, user_definition=None, strict=True) -> int:
line = node.range.start_point[0] + 1
column = node.range.start_point[1] + 1

errors = 0
for node in _find_nodes(filename, lines, node):
errors += _single_node_checks(filename, lines, node, user_definition, strict)

return errors
return _stateful_walk(filename, lines, node, user_definition, strict)


def _parse_user_definition(filename, lines, root_node):
Expand Down
Loading