Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions src/clgraph/lineage_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ def _process_unit(self, unit: QueryUnit):
if unit.join_predicates:
self._create_join_predicate_edges(unit, output_cols)

# 10. Create where filter edges
if unit.where_predicates:
self._create_where_filter_edges(unit, output_cols)

def _create_window_function_edges(self, unit: QueryUnit, output_cols: List[Dict]):
"""
Create edges for columns used in window functions.
Expand Down Expand Up @@ -488,6 +492,42 @@ def _resolve_join_predicate_column(

return None

def _create_where_filter_edges(self, unit: QueryUnit, output_cols: List[Dict]):
"""
Create edges for columns used in WHERE clauses.

WHERE clauses filter rows, so all referenced columns affect every
non-star output column. Creates where_filter edges from each WHERE
column to each non-star output column.

Args:
unit: The query unit with where_predicates
output_cols: The output columns of this unit
"""
for pred in unit.where_predicates:
for table_ref, col_name in pred.columns:
source_node = self._resolve_join_predicate_column(unit, table_ref, col_name)
if not source_node:
continue
for col_info in output_cols:
if col_info.get("is_star"):
continue
node_key = get_node_key(unit, col_info)
output_node = self.lineage_graph.nodes.get(node_key)
if not output_node:
continue
edge = ColumnEdge(
from_node=source_node,
to_node=output_node,
edge_type="where_filter",
transformation="where_filter",
context="WHERE",
expression=pred.condition_sql,
is_where_filter=True,
where_condition=pred.condition_sql,
)
self.lineage_graph.add_edge(edge)

def _create_qualify_edges(self, unit: QueryUnit, output_cols: List[Dict]):
"""
Create edges for columns used in QUALIFY clause.
Expand Down Expand Up @@ -592,6 +632,13 @@ def _resolve_qualify_column(self, unit: QueryUnit, col_ref: str) -> Optional[Col
if node:
return node

# Struct dot-access fallback (Gap 1): if table_ref is unresolvable,
# treat as struct field access on a column from the first base table
if table_ref and self._is_unresolvable_struct_ref(unit, table_ref):
fallback_tables = self._collect_base_tables_recursive(unit)
fallback_table = fallback_tables[0] if fallback_tables else table_ref
return find_or_create_table_column_node(self.lineage_graph, fallback_table, table_ref)

# Fallback: use table_ref directly if provided
if table_ref:
return find_or_create_table_column_node(self.lineage_graph, table_ref, col_name)
Expand Down Expand Up @@ -872,8 +919,35 @@ def _trace_column_dependencies(self, unit: QueryUnit, output_node: ColumnNode, c
self._resolve_base_table_name,
self._find_column_in_unit,
self._get_default_from_table,
self._is_unresolvable_struct_ref,
self._collect_base_tables_recursive,
)

# ─── Struct Dot-Access Fallback (Gap 1) ───

_MAX_STRUCT_RESOLVE_DEPTH = 4

def _is_unresolvable_struct_ref(self, unit: QueryUnit, table_ref: str) -> bool:
"""Return True when table_ref cannot be resolved as a table/alias/unit in scope."""
if table_ref in unit.alias_mapping:
return False
if table_ref in unit.depends_on_tables:
return False
if self._resolve_source_unit(unit, table_ref):
return False
return True

def _collect_base_tables_recursive(self, unit: QueryUnit, depth: int = 0) -> List[str]:
"""Walk unit dependency chain to find ultimate base tables."""
if depth > self._MAX_STRUCT_RESOLVE_DEPTH:
return []
result = list(unit.depends_on_tables)
for dep_unit_id in unit.depends_on_units:
dep_unit = self.unit_graph.units.get(dep_unit_id)
if dep_unit:
result = [*result, *self._collect_base_tables_recursive(dep_unit, depth + 1)]
return result

def _resolve_source_unit(
self, current_unit: QueryUnit, table_ref: Optional[str]
) -> Optional[QueryUnit]:
Expand Down
8 changes: 7 additions & 1 deletion src/clgraph/lineage_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,14 @@ def trace_forward(
if not outgoing:
descendants.append(current)
else:
has_unvisited = False
for edge in outgoing:
queue.append(edge.to_node)
if edge.to_node.full_name not in visited:
has_unvisited = True
queue.append(edge.to_node)
# If all outgoing targets are already visited, treat as terminal
if not has_unvisited:
descendants.append(current)

return descendants

Expand Down
21 changes: 21 additions & 0 deletions src/clgraph/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,14 @@ class JoinPredicateInfo:
right_table: Optional[str] # Name/alias of the joined (right-side) table


@dataclass
class WherePredicateInfo:
"""WHERE clause predicate for column lineage tracking."""

condition_sql: str
columns: List[Tuple[Optional[str], str]]


@dataclass
class QueryUnit:
"""
Expand Down Expand Up @@ -285,6 +293,15 @@ class QueryUnit:
# Stores info about JOIN ON clause columns for predicate lineage edges
join_predicates: List["JoinPredicateInfo"] = field(default_factory=list)

# WHERE predicate metadata
# Stores info about WHERE clause columns for filter lineage edges
where_predicates: List["WherePredicateInfo"] = field(default_factory=list)

# Ranking window columns for dedup qualify promotion (Gap 2)
# Maps alias -> {function, partition_by, order_by} for ranking functions
# Example: {'rn': {'function': 'ROW_NUMBER', 'partition_by': ['id'], 'order_by': [...]}}
ranking_window_columns: Dict[str, Dict[str, Any]] = field(default_factory=dict)

# Metadata
depth: int = 0 # Nesting depth (0 = main query)
order: int = 0 # Topological order for CTEs
Expand Down Expand Up @@ -645,6 +662,10 @@ class ColumnEdge:
join_condition: Optional[str] = None # Raw SQL of the ON clause
join_side: Optional[str] = None # "left" or "right" (which side of the join this column is on)

# ─── WHERE Filter Metadata ───
is_where_filter: bool = False # True if this edge is from a WHERE clause
where_condition: Optional[str] = None # Raw SQL of the WHERE condition

# ─── Self-Reference / Pipeline Ordering Metadata ───
statement_order: Optional[int] = None # Topological sort index of the query
edge_role: Optional[str] = None # "prior_state_read", "cross_query_self_ref", or None
Expand Down
3 changes: 3 additions & 0 deletions src/clgraph/pipeline_lineage_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,9 @@ def _add_query_edges(
is_join_predicate=getattr(edge, "is_join_predicate", False),
join_condition=getattr(edge, "join_condition", None),
join_side=getattr(edge, "join_side", None),
# Preserve WHERE filter metadata
is_where_filter=getattr(edge, "is_where_filter", False),
where_condition=getattr(edge, "where_condition", None),
# Preserve complex aggregate metadata
aggregate_spec=getattr(edge, "aggregate_spec", None),
)
Expand Down
88 changes: 88 additions & 0 deletions src/clgraph/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
TVFInfo,
TVFType,
ValuesInfo,
WherePredicateInfo,
)

# ============================================================================
Expand Down Expand Up @@ -199,6 +200,20 @@ def _parse_select_unit(
if where_clause:
self._parse_where_subqueries(where_clause, unit, depth)

# 4b. Extract WHERE clause column refs for filter lineage
if where_clause:
where_cols = self._extract_where_columns(where_clause.this)
if where_cols:
unit.where_predicates.append(
WherePredicateInfo(
condition_sql=where_clause.this.sql(),
columns=where_cols,
)
)

# 4c. Promote dedup qualify info from WHERE (Gap 2)
self._promote_dedup_qualify_if_applicable(select_node, unit)

# 5. Parse HAVING clause (may contain subqueries)
having_clause = select_node.args.get("having")
if having_clause:
Expand Down Expand Up @@ -1551,6 +1566,20 @@ def _get_join_right_table(self, join: exp.Join, unit: QueryUnit) -> Optional[str

return None

def _extract_where_columns(self, condition: exp.Expression):
"""Extract column refs from WHERE condition, skipping exp.Subquery subtrees."""
subquery_columns: set = set()
for subq in condition.find_all(exp.Subquery):
for col in subq.find_all(exp.Column):
subquery_columns.add(id(col))

columns = []
for col in condition.find_all(exp.Column):
if id(col) not in subquery_columns:
table_ref = col.table if col.table else None
columns.append((table_ref, col.name))
return columns

def _parse_where_subqueries(
self, where_node: exp.Expression, parent_unit: QueryUnit, depth: int
):
Expand Down Expand Up @@ -1661,6 +1690,53 @@ def _parse_qualify_clause(self, qualify_node: exp.Qualify, unit: QueryUnit):
"window_functions": window_functions,
}

def _promote_dedup_qualify_if_applicable(self, select_node: exp.Select, unit: QueryUnit):
"""
Promote dedup qualify info from a subquery-based WHERE pattern (Gap 2).

Detects the common dedup pattern:
SELECT ... FROM (SELECT *, ROW_NUMBER() OVER (...) AS rn FROM t) WHERE rn = 1
and promotes it to qualify_info on the outer unit.

Only ranking functions (ROW_NUMBER, RANK, DENSE_RANK, NTILE) are eligible.
Comparison operators =, <=, < against a literal are recognized.

Args:
select_node: The SELECT expression
unit: The query unit to potentially add qualify_info to
"""
where_clause = select_node.args.get("where")
if not where_clause or unit.qualify_info:
return

for dep_unit_id in unit.depends_on_units:
dep_unit = self.unit_graph.units.get(dep_unit_id)
if not dep_unit or not dep_unit.ranking_window_columns:
continue

for node in where_clause.walk():
if isinstance(node, (exp.EQ, exp.LTE, exp.LT)):
left, right = node.left, node.right
col_name = None
if isinstance(left, exp.Column) and isinstance(right, exp.Literal):
col_name = left.name
elif isinstance(right, exp.Column) and isinstance(left, exp.Literal):
col_name = right.name

if col_name and col_name in dep_unit.ranking_window_columns:
window_meta = dep_unit.ranking_window_columns[col_name]
unit.qualify_info = {
"condition": where_clause.this.sql(),
"partition_columns": list(window_meta["partition_by"]),
"order_columns": [
c["column"] if isinstance(c, dict) else c
for c in window_meta["order_by"]
],
"window_functions": [window_meta["function"]],
"promoted_from_subquery": True,
}
return

def _parse_grouping_sets(self, group_clause: exp.Group, unit: QueryUnit):
"""
Parse GROUP BY clause for GROUPING SETS, CUBE, and ROLLUP constructs.
Expand Down Expand Up @@ -1816,6 +1892,18 @@ def _parse_window_functions(self, select_node: exp.Select, unit: QueryUnit):
if windows:
unit.window_info = {"windows": windows}

# Populate ranking_window_columns for dedup qualify promotion (Gap 2)
RANKING_FUNCTIONS = {"ROW_NUMBER", "RANK", "DENSE_RANK", "NTILE"}
for window_def in windows:
func_name = window_def.get("function", "").upper()
output_col = window_def.get("output_column")
if func_name in RANKING_FUNCTIONS and output_col:
unit.ranking_window_columns[output_col] = {
"function": func_name,
"partition_by": window_def.get("partition_by", []),
"order_by": window_def.get("order_by", []),
}

def _parse_single_window(
self,
window: exp.Window,
Expand Down
26 changes: 26 additions & 0 deletions src/clgraph/trace_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ def trace_regular_columns(
resolve_base_table_name: Callable,
find_column_in_unit: Callable,
get_default_from_table: Callable,
is_unresolvable_struct_ref: Callable = None, # (unit, table_ref) -> bool
collect_base_tables: Callable = None, # (unit) -> List[str]
) -> None:
"""Handle regular column references with UNNEST, TVF, VALUES sub-cases."""
for source_ref in source_columns:
Expand Down Expand Up @@ -386,3 +388,27 @@ def trace_regular_columns(
aggregate_spec=aggregate_spec,
)
graph.add_edge(edge)
elif (
effective_table_ref
and is_unresolvable_struct_ref
and is_unresolvable_struct_ref(unit, effective_table_ref)
):
# Struct dot-access fallback: "after.id" where "after" is not a
# table/alias/unit — treat as struct field access on a column
# named "after" from the first resolvable base table.
fallback_tables = collect_base_tables(unit) if collect_base_tables else []
fallback_table = fallback_tables[0] if fallback_tables else effective_table_ref
source_node = find_or_create_table_column_node(
graph, fallback_table, effective_table_ref
)
edge = ColumnEdge(
from_node=source_node,
to_node=output_node,
edge_type=col_info["type"],
transformation=col_info["type"],
context=unit.unit_type.value,
expression=col_info["expression"],
nested_path=f".{col_name}",
access_type="struct",
)
graph.add_edge(edge)
10 changes: 6 additions & 4 deletions tests/test_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,11 +1238,13 @@ def test_simplified_multiple_ctes(self):
graph = builder.build()
simplified = graph.to_simplified()

# Original should have 8 nodes (2 input + 2 step1 + 2 step2 + 2 output)
assert len(graph.nodes) == 8
# Original should have 9 nodes (3 input + 2 step1 + 2 step2 + 2 output)
# orders.status is an input node from the WHERE clause filter lineage
assert len(graph.nodes) == 9

# Simplified should only have 4 nodes (2 input + 2 output)
assert len(simplified.nodes) == 4
# Simplified should have 5 nodes (3 input + 2 output)
# orders.status is included as a WHERE filter input
assert len(simplified.nodes) == 5

# Check edges trace through both CTEs
edge_pairs = {(e.from_node.full_name, e.to_node.full_name) for e in simplified.edges}
Expand Down
Loading
Loading