From bf6972d7c26e952e4cfe9a757a2b65f2265bbfe0 Mon Sep 17 00:00:00 2001 From: Ming Jer Lee Date: Tue, 14 Apr 2026 17:12:56 -0400 Subject: [PATCH 1/3] feat: emit where_filter edges for WHERE clause columns (Gap 8) Add WHERE filter lineage tracking: columns referenced in WHERE clauses now produce where_filter edges to all non-star output columns. Subquery columns within WHERE are excluded from the outer query's predicates. - Add WherePredicateInfo dataclass and where_predicates to QueryUnit - Add is_where_filter and where_condition fields to ColumnEdge - Extract WHERE column refs in query parser (skipping subquery subtrees) - Create where_filter edges in lineage builder - Preserve WHERE metadata in pipeline cross-query edge copies - Fix trace_forward BFS to treat nodes as terminals when all outgoing targets are already visited (prevents cycles from breaking traversal) --- src/clgraph/lineage_builder.py | 40 ++++ src/clgraph/lineage_tracer.py | 8 +- src/clgraph/models.py | 16 ++ src/clgraph/pipeline_lineage_builder.py | 3 + src/clgraph/query_parser.py | 26 +++ tests/test_lineage.py | 10 +- tests/test_where_filter_lineage.py | 295 ++++++++++++++++++++++++ 7 files changed, 393 insertions(+), 5 deletions(-) create mode 100644 tests/test_where_filter_lineage.py diff --git a/src/clgraph/lineage_builder.py b/src/clgraph/lineage_builder.py index 1e190e2..9408f52 100644 --- a/src/clgraph/lineage_builder.py +++ b/src/clgraph/lineage_builder.py @@ -177,6 +177,10 @@ def _process_unit(self, unit: QueryUnit): if unit.join_predicates: self._create_join_predicate_edges(unit, output_cols) + # 10. Create where filter edges + if unit.where_predicates: + self._create_where_filter_edges(unit, output_cols) + def _create_window_function_edges(self, unit: QueryUnit, output_cols: List[Dict]): """ Create edges for columns used in window functions. @@ -488,6 +492,42 @@ def _resolve_join_predicate_column( return None + def _create_where_filter_edges(self, unit: QueryUnit, output_cols: List[Dict]): + """ + Create edges for columns used in WHERE clauses. + + WHERE clauses filter rows, so all referenced columns affect every + non-star output column. Creates where_filter edges from each WHERE + column to each non-star output column. + + Args: + unit: The query unit with where_predicates + output_cols: The output columns of this unit + """ + for pred in unit.where_predicates: + for table_ref, col_name in pred.columns: + source_node = self._resolve_join_predicate_column(unit, table_ref, col_name) + if not source_node: + continue + for col_info in output_cols: + if col_info.get("is_star"): + continue + node_key = get_node_key(unit, col_info) + output_node = self.lineage_graph.nodes.get(node_key) + if not output_node: + continue + edge = ColumnEdge( + from_node=source_node, + to_node=output_node, + edge_type="where_filter", + transformation="where_filter", + context="WHERE", + expression=pred.condition_sql, + is_where_filter=True, + where_condition=pred.condition_sql, + ) + self.lineage_graph.add_edge(edge) + def _create_qualify_edges(self, unit: QueryUnit, output_cols: List[Dict]): """ Create edges for columns used in QUALIFY clause. diff --git a/src/clgraph/lineage_tracer.py b/src/clgraph/lineage_tracer.py index ad7109e..15a5c3f 100644 --- a/src/clgraph/lineage_tracer.py +++ b/src/clgraph/lineage_tracer.py @@ -223,8 +223,14 @@ def trace_forward( if not outgoing: descendants.append(current) else: + has_unvisited = False for edge in outgoing: - queue.append(edge.to_node) + if edge.to_node.full_name not in visited: + has_unvisited = True + queue.append(edge.to_node) + # If all outgoing targets are already visited, treat as terminal + if not has_unvisited: + descendants.append(current) return descendants diff --git a/src/clgraph/models.py b/src/clgraph/models.py index f096de8..ae4eada 100644 --- a/src/clgraph/models.py +++ b/src/clgraph/models.py @@ -187,6 +187,14 @@ class JoinPredicateInfo: right_table: Optional[str] # Name/alias of the joined (right-side) table +@dataclass +class WherePredicateInfo: + """WHERE clause predicate for column lineage tracking.""" + + condition_sql: str + columns: List[Tuple[Optional[str], str]] + + @dataclass class QueryUnit: """ @@ -285,6 +293,10 @@ class QueryUnit: # Stores info about JOIN ON clause columns for predicate lineage edges join_predicates: List["JoinPredicateInfo"] = field(default_factory=list) + # WHERE predicate metadata + # Stores info about WHERE clause columns for filter lineage edges + where_predicates: List["WherePredicateInfo"] = field(default_factory=list) + # Metadata depth: int = 0 # Nesting depth (0 = main query) order: int = 0 # Topological order for CTEs @@ -645,6 +657,10 @@ class ColumnEdge: join_condition: Optional[str] = None # Raw SQL of the ON clause join_side: Optional[str] = None # "left" or "right" (which side of the join this column is on) + # ─── WHERE Filter Metadata ─── + is_where_filter: bool = False # True if this edge is from a WHERE clause + where_condition: Optional[str] = None # Raw SQL of the WHERE condition + # ─── Self-Reference / Pipeline Ordering Metadata ─── statement_order: Optional[int] = None # Topological sort index of the query edge_role: Optional[str] = None # "prior_state_read", "cross_query_self_ref", or None diff --git a/src/clgraph/pipeline_lineage_builder.py b/src/clgraph/pipeline_lineage_builder.py index 55191af..39d2cb0 100644 --- a/src/clgraph/pipeline_lineage_builder.py +++ b/src/clgraph/pipeline_lineage_builder.py @@ -470,6 +470,9 @@ def _add_query_edges( is_join_predicate=getattr(edge, "is_join_predicate", False), join_condition=getattr(edge, "join_condition", None), join_side=getattr(edge, "join_side", None), + # Preserve WHERE filter metadata + is_where_filter=getattr(edge, "is_where_filter", False), + where_condition=getattr(edge, "where_condition", None), # Preserve complex aggregate metadata aggregate_spec=getattr(edge, "aggregate_spec", None), ) diff --git a/src/clgraph/query_parser.py b/src/clgraph/query_parser.py index 9c23860..7933847 100644 --- a/src/clgraph/query_parser.py +++ b/src/clgraph/query_parser.py @@ -19,6 +19,7 @@ TVFInfo, TVFType, ValuesInfo, + WherePredicateInfo, ) # ============================================================================ @@ -199,6 +200,17 @@ def _parse_select_unit( if where_clause: self._parse_where_subqueries(where_clause, unit, depth) + # 4b. Extract WHERE clause column refs for filter lineage + if where_clause: + where_cols = self._extract_where_columns(where_clause.this) + if where_cols: + unit.where_predicates.append( + WherePredicateInfo( + condition_sql=where_clause.this.sql(), + columns=where_cols, + ) + ) + # 5. Parse HAVING clause (may contain subqueries) having_clause = select_node.args.get("having") if having_clause: @@ -1551,6 +1563,20 @@ def _get_join_right_table(self, join: exp.Join, unit: QueryUnit) -> Optional[str return None + def _extract_where_columns(self, condition: exp.Expression): + """Extract column refs from WHERE condition, skipping exp.Subquery subtrees.""" + subquery_columns: set = set() + for subq in condition.find_all(exp.Subquery): + for col in subq.find_all(exp.Column): + subquery_columns.add(id(col)) + + columns = [] + for col in condition.find_all(exp.Column): + if id(col) not in subquery_columns: + table_ref = col.table if col.table else None + columns.append((table_ref, col.name)) + return columns + def _parse_where_subqueries( self, where_node: exp.Expression, parent_unit: QueryUnit, depth: int ): diff --git a/tests/test_lineage.py b/tests/test_lineage.py index 771ef8e..9ca2192 100644 --- a/tests/test_lineage.py +++ b/tests/test_lineage.py @@ -1238,11 +1238,13 @@ def test_simplified_multiple_ctes(self): graph = builder.build() simplified = graph.to_simplified() - # Original should have 8 nodes (2 input + 2 step1 + 2 step2 + 2 output) - assert len(graph.nodes) == 8 + # Original should have 9 nodes (3 input + 2 step1 + 2 step2 + 2 output) + # orders.status is an input node from the WHERE clause filter lineage + assert len(graph.nodes) == 9 - # Simplified should only have 4 nodes (2 input + 2 output) - assert len(simplified.nodes) == 4 + # Simplified should have 5 nodes (3 input + 2 output) + # orders.status is included as a WHERE filter input + assert len(simplified.nodes) == 5 # Check edges trace through both CTEs edge_pairs = {(e.from_node.full_name, e.to_node.full_name) for e in simplified.edges} diff --git a/tests/test_where_filter_lineage.py b/tests/test_where_filter_lineage.py new file mode 100644 index 0000000..0757f5e --- /dev/null +++ b/tests/test_where_filter_lineage.py @@ -0,0 +1,295 @@ +""" +Test suite for Gap 8: WHERE Filter Lineage. + +Tests cover: +1. Simple WHERE col = 'value' produces is_where_filter=True edges to all non-star output columns +2. Compound WHERE with multiple column refs produces filter edges for all column refs +3. WHERE x IN (SELECT y FROM other) — subquery columns are NOT in outer where_predicates +4. SELECT * FROM t WHERE t.id = 1 produces NO where_filter edges (no non-star output columns) +5. where_condition attribute contains the clause SQL without the WHERE keyword +6. ColumnEdge has is_where_filter and where_condition as actual dataclass fields + +Total: 6 test cases +""" + +import dataclasses + +import pytest + +from clgraph import RecursiveLineageBuilder +from clgraph.models import ColumnEdge + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _edges_dict(graph): + """Build a dict keyed by (from_full_name, to_full_name) -> edge.""" + return {(e.from_node.full_name, e.to_node.full_name): e for e in graph.edges} + + +def _where_filter_edges(graph): + """Return only edges with is_where_filter=True.""" + return [e for e in graph.edges if e.is_where_filter] + + +def _where_filter_edges_to(graph, target_full_name): + """Return where_filter edges targeting a specific output column.""" + return [e for e in graph.edges if e.is_where_filter and e.to_node.full_name == target_full_name] + + +def _where_filter_sources_to(graph, target_full_name): + """Return set of from_node.full_name for where_filter edges to a target.""" + return {e.from_node.full_name for e in _where_filter_edges_to(graph, target_full_name)} + + +# ============================================================================ +# Test 1: Simple WHERE col = 'value' +# ============================================================================ + + +class TestSimpleWhereFilter: + """Simple WHERE col = 'value' produces is_where_filter=True edges to all non-star outputs.""" + + SQL = """ + SELECT t.id, t.name, t.city + FROM my_table t + WHERE t.status = 'active' + """ + + def test_where_filter_edges_exist(self): + """WHERE t.status = 'active' produces filter edges to all output columns.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + filter_edges = _where_filter_edges(graph) + assert len(filter_edges) > 0, "Should have where_filter edges" + + def test_where_filter_edge_to_each_output(self): + """Each non-star output column gets a where_filter edge from t.status.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + for output_col in ["output.id", "output.name", "output.city"]: + sources = _where_filter_sources_to(graph, output_col) + assert "my_table.status" in sources, ( + f"my_table.status should have where_filter edge to {output_col}" + ) + + def test_where_filter_edge_is_not_join_predicate(self): + """WHERE filter edges should NOT be marked as join predicates.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + for edge in _where_filter_edges(graph): + assert not edge.is_join_predicate, ( + f"WHERE filter edge should not be is_join_predicate: {edge}" + ) + + def test_where_filter_edge_type(self): + """WHERE filter edges have edge_type='where_filter'.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + for edge in _where_filter_edges(graph): + assert edge.edge_type == "where_filter", ( + f"Expected edge_type='where_filter', got '{edge.edge_type}'" + ) + + +# ============================================================================ +# Test 2: Compound WHERE with multiple column refs +# ============================================================================ + + +class TestCompoundWhereFilter: + """Compound WHERE with OR/AND produces filter edges for all column refs.""" + + SQL = """ + SELECT s.id, s.name, s.city + FROM staging s + LEFT JOIN dim_customer t ON s.id = t.id + WHERE t.id IS NULL OR (t.name <> s.name OR t.city <> s.city) + """ + + def test_all_where_columns_produce_filter_edges(self): + """All columns referenced in the WHERE clause produce filter edges.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + # Collect all from_node column names from where_filter edges + filter_edge_sources = {e.from_node.column_name for e in _where_filter_edges(graph)} + + # t.id, t.name, t.city, s.name, s.city should all appear as filter sources + expected_cols = {"id", "name", "city"} + assert expected_cols.issubset(filter_edge_sources), ( + f"Expected WHERE columns {expected_cols} in filter sources, got {filter_edge_sources}" + ) + + def test_filter_edges_target_non_star_outputs(self): + """Filter edges target all non-star output columns.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + filter_edges = _where_filter_edges(graph) + target_names = {e.to_node.full_name for e in filter_edges} + + for expected in ["output.id", "output.name", "output.city"]: + assert expected in target_names, ( + f"Expected filter edge targeting {expected}, got targets: {target_names}" + ) + + +# ============================================================================ +# Test 3: WHERE with subquery — subquery columns excluded +# ============================================================================ + + +class TestWhereSubqueryExclusion: + """WHERE x IN (SELECT y FROM other) — subquery columns are NOT in outer where_predicates.""" + + SQL = """ + SELECT t.id, t.name + FROM my_table t + WHERE t.status = 'active' AND t.id IN (SELECT o.id FROM other_table o) + """ + + def test_outer_where_column_produces_filter_edges(self): + """t.status from the outer WHERE produces filter edges.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + sources = _where_filter_sources_to(graph, "output.id") + assert "my_table.status" in sources, ( + "my_table.status should have where_filter edge to output.id" + ) + + def test_subquery_columns_not_in_outer_filter_edges(self): + """o.id from the subquery should NOT appear as a where_filter source in the outer query.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + filter_edges = _where_filter_edges(graph) + # Filter edges from the outer query should not reference other_table.id + outer_filter_sources = {e.from_node.full_name for e in filter_edges} + assert "other_table.id" not in outer_filter_sources, ( + "Subquery column other_table.id should NOT appear as outer where_filter source" + ) + + def test_outer_where_also_includes_t_id(self): + """t.id from the outer WHERE (in the IN clause) also produces filter edges.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + sources = _where_filter_sources_to(graph, "output.name") + assert "my_table.id" in sources, ( + "my_table.id (from IN clause) should have where_filter edge to output.name" + ) + + +# ============================================================================ +# Test 4: SELECT * — no where_filter edges +# ============================================================================ + + +class TestSelectStarNoFilterEdges: + """SELECT * FROM t WHERE t.id = 1 produces NO where_filter edges.""" + + SQL = """ + SELECT * FROM my_table t WHERE t.id = 1 + """ + + def test_no_where_filter_edges_for_star(self): + """Star outputs do not receive where_filter edges.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + filter_edges = _where_filter_edges(graph) + assert len(filter_edges) == 0, ( + f"SELECT * should produce no where_filter edges, got {len(filter_edges)}" + ) + + +# ============================================================================ +# Test 5: where_condition contains clause SQL without WHERE keyword +# ============================================================================ + + +class TestWhereConditionContent: + """where_condition attribute contains the clause SQL without the WHERE keyword.""" + + SQL = """ + SELECT t.id, t.name + FROM my_table t + WHERE t.status = 'active' + """ + + def test_where_condition_does_not_contain_where_keyword(self): + """where_condition should not start with 'WHERE'.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + filter_edges = _where_filter_edges(graph) + assert len(filter_edges) > 0, "Should have filter edges" + for edge in filter_edges: + assert edge.where_condition is not None, "where_condition should not be None" + assert not edge.where_condition.strip().upper().startswith("WHERE"), ( + f"where_condition should not start with WHERE keyword: {edge.where_condition}" + ) + + def test_where_condition_contains_status_reference(self): + """where_condition should contain a reference to the status column.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + filter_edges = _where_filter_edges(graph) + assert len(filter_edges) > 0, "Should have filter edges" + # At least one edge should reference 'status' in its condition + conditions = {edge.where_condition for edge in filter_edges} + assert any("status" in c.lower() for c in conditions if c), ( + f"where_condition should reference 'status': {conditions}" + ) + + +# ============================================================================ +# Test 6: ColumnEdge has is_where_filter and where_condition as dataclass fields +# ============================================================================ + + +class TestColumnEdgeDataclassFields: + """ColumnEdge has is_where_filter and where_condition as actual dataclass fields.""" + + def test_is_where_filter_is_dataclass_field(self): + """is_where_filter should be a proper dataclass field on ColumnEdge.""" + field_names = {f.name for f in dataclasses.fields(ColumnEdge)} + assert "is_where_filter" in field_names, ( + f"is_where_filter should be a dataclass field, found: {field_names}" + ) + + def test_where_condition_is_dataclass_field(self): + """where_condition should be a proper dataclass field on ColumnEdge.""" + field_names = {f.name for f in dataclasses.fields(ColumnEdge)} + assert "where_condition" in field_names, ( + f"where_condition should be a dataclass field, found: {field_names}" + ) + + def test_default_values(self): + """Default values: is_where_filter=False, where_condition=None.""" + field_map = {f.name: f for f in dataclasses.fields(ColumnEdge)} + + is_where_filter_field = field_map.get("is_where_filter") + assert is_where_filter_field is not None + assert is_where_filter_field.default is False, ( + f"is_where_filter default should be False, got {is_where_filter_field.default}" + ) + + where_condition_field = field_map.get("where_condition") + assert where_condition_field is not None + assert where_condition_field.default is None, ( + f"where_condition default should be None, got {where_condition_field.default}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) From 67859c7d31c9fb83090a8d5b104f9f9d4de39a6c Mon Sep 17 00:00:00 2001 From: Ming Jer Lee Date: Tue, 14 Apr 2026 17:16:55 -0400 Subject: [PATCH 2/3] feat: struct dot-access fallback for unresolvable table refs (Gap 1) When sqlglot parses `after.id` as Column(table="after", name="id") and "after" cannot be resolved as a table, alias, or unit in scope, the new struct fallback emits a lineage edge with nested_path=".id" and access_type="struct", using the first base table from the dependency chain as the source table. Includes recursive base table resolution for CDC-like subquery patterns. --- src/clgraph/lineage_builder.py | 34 ++++ src/clgraph/trace_strategies.py | 26 +++ tests/test_struct_dot_access.py | 284 ++++++++++++++++++++++++++++++++ 3 files changed, 344 insertions(+) create mode 100644 tests/test_struct_dot_access.py diff --git a/src/clgraph/lineage_builder.py b/src/clgraph/lineage_builder.py index 9408f52..d8006a3 100644 --- a/src/clgraph/lineage_builder.py +++ b/src/clgraph/lineage_builder.py @@ -632,6 +632,13 @@ def _resolve_qualify_column(self, unit: QueryUnit, col_ref: str) -> Optional[Col if node: return node + # Struct dot-access fallback (Gap 1): if table_ref is unresolvable, + # treat as struct field access on a column from the first base table + if table_ref and self._is_unresolvable_struct_ref(unit, table_ref): + fallback_tables = self._collect_base_tables_recursive(unit) + fallback_table = fallback_tables[0] if fallback_tables else table_ref + return find_or_create_table_column_node(self.lineage_graph, fallback_table, table_ref) + # Fallback: use table_ref directly if provided if table_ref: return find_or_create_table_column_node(self.lineage_graph, table_ref, col_name) @@ -912,8 +919,35 @@ def _trace_column_dependencies(self, unit: QueryUnit, output_node: ColumnNode, c self._resolve_base_table_name, self._find_column_in_unit, self._get_default_from_table, + self._is_unresolvable_struct_ref, + self._collect_base_tables_recursive, ) + # ─── Struct Dot-Access Fallback (Gap 1) ─── + + _MAX_STRUCT_RESOLVE_DEPTH = 4 + + def _is_unresolvable_struct_ref(self, unit: QueryUnit, table_ref: str) -> bool: + """Return True when table_ref cannot be resolved as a table/alias/unit in scope.""" + if table_ref in unit.alias_mapping: + return False + if table_ref in unit.depends_on_tables: + return False + if self._resolve_source_unit(unit, table_ref): + return False + return True + + def _collect_base_tables_recursive(self, unit: QueryUnit, depth: int = 0) -> List[str]: + """Walk unit dependency chain to find ultimate base tables.""" + if depth > self._MAX_STRUCT_RESOLVE_DEPTH: + return [] + result = list(unit.depends_on_tables) + for dep_unit_id in unit.depends_on_units: + dep_unit = self.unit_graph.units.get(dep_unit_id) + if dep_unit: + result = [*result, *self._collect_base_tables_recursive(dep_unit, depth + 1)] + return result + def _resolve_source_unit( self, current_unit: QueryUnit, table_ref: Optional[str] ) -> Optional[QueryUnit]: diff --git a/src/clgraph/trace_strategies.py b/src/clgraph/trace_strategies.py index d9fc65a..6f6dc5c 100644 --- a/src/clgraph/trace_strategies.py +++ b/src/clgraph/trace_strategies.py @@ -239,6 +239,8 @@ def trace_regular_columns( resolve_base_table_name: Callable, find_column_in_unit: Callable, get_default_from_table: Callable, + is_unresolvable_struct_ref: Callable = None, # (unit, table_ref) -> bool + collect_base_tables: Callable = None, # (unit) -> List[str] ) -> None: """Handle regular column references with UNNEST, TVF, VALUES sub-cases.""" for source_ref in source_columns: @@ -386,3 +388,27 @@ def trace_regular_columns( aggregate_spec=aggregate_spec, ) graph.add_edge(edge) + elif ( + effective_table_ref + and is_unresolvable_struct_ref + and is_unresolvable_struct_ref(unit, effective_table_ref) + ): + # Struct dot-access fallback: "after.id" where "after" is not a + # table/alias/unit — treat as struct field access on a column + # named "after" from the first resolvable base table. + fallback_tables = collect_base_tables(unit) if collect_base_tables else [] + fallback_table = fallback_tables[0] if fallback_tables else effective_table_ref + source_node = find_or_create_table_column_node( + graph, fallback_table, effective_table_ref + ) + edge = ColumnEdge( + from_node=source_node, + to_node=output_node, + edge_type=col_info["type"], + transformation=col_info["type"], + context=unit.unit_type.value, + expression=col_info["expression"], + nested_path=f".{col_name}", + access_type="struct", + ) + graph.add_edge(edge) diff --git a/tests/test_struct_dot_access.py b/tests/test_struct_dot_access.py new file mode 100644 index 0000000..2e73c39 --- /dev/null +++ b/tests/test_struct_dot_access.py @@ -0,0 +1,284 @@ +""" +Test suite for Gap 1: Struct Dot-Access Fallback. + +When sqlglot parses `after.id`, it becomes Column(table="after", name="id"). +If "after" cannot be resolved as a table/alias/unit in scope, the struct fallback +should emit a lineage edge with nested_path=".id" and access_type="struct", +using the first base table from the dependency chain as the source table. + +Tests cover: +1. Simple struct dot-access: SELECT after.id AS id FROM raw_table +2. Multiple struct fields: after.name, after.city, after.email +3. CDC-like subquery pattern: recursive base table resolution +4. Bracket notation regression: items[0].product_id still works +5. Multi-table JOIN with struct ref: fallback uses first base table +6. Empty fallback_tables case: effective_table_ref used as table name + +Total: 6 test cases +""" + +import pytest + +from clgraph import RecursiveLineageBuilder + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _edges_dict(graph): + """Build a dict keyed by (from_full_name, to_full_name) -> edge.""" + return {(e.from_node.full_name, e.to_node.full_name): e for e in graph.edges} + + +def _struct_edges(graph): + """Return only edges with access_type='struct'.""" + return [e for e in graph.edges if e.access_type == "struct"] + + +def _struct_edges_to(graph, target_full_name): + """Return struct edges targeting a specific output column.""" + return [ + e + for e in graph.edges + if e.access_type == "struct" and e.to_node.full_name == target_full_name + ] + + +# ============================================================================ +# Test 1: Simple struct dot-access +# ============================================================================ + + +class TestSimpleStructDotAccess: + """SELECT after.id AS id FROM raw_table — struct fallback emits edge.""" + + SQL = "SELECT after.id AS id FROM raw_table" + + def test_struct_edge_exists(self): + """after.id should produce a struct edge.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + struct_edges = _struct_edges(graph) + assert len(struct_edges) > 0, ( + f"Expected struct edges for after.id, got none. " + f"All edges: {[(e.from_node.full_name, e.to_node.full_name, e.access_type) for e in graph.edges]}" + ) + + def test_struct_edge_nested_path(self): + """Struct edge should have nested_path='.id'.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + edges = _struct_edges_to(graph, "output.id") + assert len(edges) == 1, f"Expected 1 struct edge to output.id, got {len(edges)}" + assert edges[0].nested_path == ".id" + + def test_struct_edge_from_node_column_name(self): + """Struct edge from_node.column_name should be 'after' (the struct column).""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + edges = _struct_edges_to(graph, "output.id") + assert len(edges) == 1 + assert edges[0].from_node.column_name == "after" + + def test_struct_edge_from_node_table_name(self): + """Struct edge from_node.table_name should be 'raw_table' (the base table).""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + edges = _struct_edges_to(graph, "output.id") + assert len(edges) == 1 + assert edges[0].from_node.table_name == "raw_table" + + +# ============================================================================ +# Test 2: Multiple struct fields +# ============================================================================ + + +class TestMultipleStructFields: + """Multiple struct field accesses all emit struct edges.""" + + SQL = """ + SELECT after.name AS name, after.city AS city, after.email AS email + FROM raw_table + """ + + def test_all_fields_produce_struct_edges(self): + """after.name, after.city, after.email all produce struct edges.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + struct_edges = _struct_edges(graph) + assert len(struct_edges) == 3, ( + f"Expected 3 struct edges, got {len(struct_edges)}. " + f"All edges: {[(e.from_node.full_name, e.to_node.full_name, e.access_type) for e in graph.edges]}" + ) + + def test_each_field_has_correct_nested_path(self): + """Each struct edge has the correct nested_path.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + for output_col, expected_path in [ + ("output.name", ".name"), + ("output.city", ".city"), + ("output.email", ".email"), + ]: + edges = _struct_edges_to(graph, output_col) + assert len(edges) == 1, f"Expected 1 struct edge to {output_col}, got {len(edges)}" + assert edges[0].nested_path == expected_path, ( + f"Expected nested_path='{expected_path}' for {output_col}, " + f"got '{edges[0].nested_path}'" + ) + + def test_all_from_nodes_reference_raw_table(self): + """All struct edge from_nodes should reference raw_table.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + for edge in _struct_edges(graph): + assert edge.from_node.table_name == "raw_table", ( + f"Expected table_name='raw_table', got '{edge.from_node.table_name}'" + ) + + +# ============================================================================ +# Test 3: CDC-like subquery pattern (recursive base table resolution) +# ============================================================================ + + +class TestCDCSubqueryPattern: + """CDC pattern: SELECT after.id, after.name FROM (SELECT * FROM raw_customer_cdc).""" + + SQL = """ + SELECT after.id, after.name + FROM (SELECT * FROM raw_customer_cdc) sub + """ + + def test_struct_edges_exist(self): + """Struct edges should exist for after.id and after.name.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + struct_edges = _struct_edges(graph) + assert len(struct_edges) >= 2, ( + f"Expected at least 2 struct edges, got {len(struct_edges)}. " + f"All edges: {[(e.from_node.full_name, e.to_node.full_name, e.access_type) for e in graph.edges]}" + ) + + def test_from_node_resolves_to_base_table(self): + """from_node.table_name should resolve to raw_customer_cdc (ultimate base table).""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + struct_edges = _struct_edges(graph) + table_names = {e.from_node.table_name for e in struct_edges} + assert "raw_customer_cdc" in table_names, ( + f"Expected raw_customer_cdc in table names, got {table_names}" + ) + + +# ============================================================================ +# Test 4: Bracket notation regression +# ============================================================================ + + +class TestBracketNotationRegression: + """Bracket notation items[0].product_id should still work (existing behavior).""" + + SQL = "SELECT items[0].product_id AS first_product FROM orders" + + def test_bracket_notation_still_works(self): + """items[0].product_id should produce a nested edge with mixed access_type.""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + nested_edges = [e for e in graph.edges if e.nested_path] + assert len(nested_edges) > 0, "Bracket notation should still produce nested edges" + + edge = nested_edges[0] + assert edge.nested_path == "[0].product_id" + assert edge.access_type == "mixed" + assert edge.from_node.column_name == "items" + + +# ============================================================================ +# Test 5: Multi-table JOIN with struct ref +# ============================================================================ + + +class TestMultiTableJoinStructRef: + """Struct ref in a JOIN context uses first base table as fallback.""" + + SQL = """ + SELECT after.id AS id, b.name + FROM raw_table a + INNER JOIN lookup_table b ON a.key = b.key + """ + + def test_struct_edge_uses_first_base_table(self): + """after.id struct fallback should use first base table (raw_table).""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + edges = _struct_edges_to(graph, "output.id") + assert len(edges) == 1, ( + f"Expected 1 struct edge to output.id, got {len(edges)}. " + f"All edges: {[(e.from_node.full_name, e.to_node.full_name, e.access_type) for e in graph.edges]}" + ) + assert edges[0].from_node.table_name == "raw_table", ( + f"Expected fallback to raw_table, got '{edges[0].from_node.table_name}'" + ) + + def test_normal_column_still_resolves(self): + """b.name should resolve normally (not as struct).""" + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + edges = _edges_dict(graph) + assert ("lookup_table.name", "output.name") in edges, ( + f"Expected lookup_table.name -> output.name edge. Available edges: {list(edges.keys())}" + ) + + +# ============================================================================ +# Test 6: Empty fallback_tables case +# ============================================================================ + + +class TestEmptyFallbackTables: + """When no base tables exist, effective_table_ref is used as table name.""" + + # This is an edge case: a query where the struct ref has no resolvable base tables. + # In practice this is rare, but the fallback should still produce an edge. + SQL = "SELECT after.id AS id FROM after" + + def test_fallback_uses_effective_table_ref(self): + """When 'after' is both the table name and the struct ref, it resolves normally. + + This test verifies that when a table literally named 'after' exists, + no struct fallback is triggered (it resolves as a normal table ref). + """ + builder = RecursiveLineageBuilder(self.SQL, dialect="bigquery") + graph = builder.build() + + # 'after' is a real table name here, so it resolves normally + edges = _edges_dict(graph) + assert ("after.id", "output.id") in edges, ( + f"Expected after.id -> output.id edge. Available: {list(edges.keys())}" + ) + + # Should NOT produce struct edges since 'after' is a real table + struct_edges = _struct_edges(graph) + assert len(struct_edges) == 0, ( + f"Should not produce struct edges when 'after' is a real table, " + f"got {len(struct_edges)} struct edges" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) From 622e651b0149ddcc498bb11032bdb0e2e7e61e2f Mon Sep 17 00:00:00 2001 From: Ming Jer Lee Date: Tue, 14 Apr 2026 17:20:05 -0400 Subject: [PATCH 3/3] feat: promote qualify metadata from subquery-based dedup WHERE pattern (Gap 2) Detect the common dedup pattern (ROW_NUMBER() OVER (...) AS rn in subquery + WHERE rn = 1 in outer query) and promote it to qualify_info on the outer unit. Supports EQ, LTE, LT comparisons against ranking functions (ROW_NUMBER, RANK, DENSE_RANK, NTILE). Adds ranking_window_columns to QueryUnit model for cross-unit metadata propagation. --- src/clgraph/models.py | 5 + src/clgraph/query_parser.py | 62 +++++++++ tests/test_subquery_dedup_qualify.py | 189 +++++++++++++++++++++++++++ 3 files changed, 256 insertions(+) create mode 100644 tests/test_subquery_dedup_qualify.py diff --git a/src/clgraph/models.py b/src/clgraph/models.py index ae4eada..2cd9e14 100644 --- a/src/clgraph/models.py +++ b/src/clgraph/models.py @@ -297,6 +297,11 @@ class QueryUnit: # Stores info about WHERE clause columns for filter lineage edges where_predicates: List["WherePredicateInfo"] = field(default_factory=list) + # Ranking window columns for dedup qualify promotion (Gap 2) + # Maps alias -> {function, partition_by, order_by} for ranking functions + # Example: {'rn': {'function': 'ROW_NUMBER', 'partition_by': ['id'], 'order_by': [...]}} + ranking_window_columns: Dict[str, Dict[str, Any]] = field(default_factory=dict) + # Metadata depth: int = 0 # Nesting depth (0 = main query) order: int = 0 # Topological order for CTEs diff --git a/src/clgraph/query_parser.py b/src/clgraph/query_parser.py index 7933847..3c1aae2 100644 --- a/src/clgraph/query_parser.py +++ b/src/clgraph/query_parser.py @@ -211,6 +211,9 @@ def _parse_select_unit( ) ) + # 4c. Promote dedup qualify info from WHERE (Gap 2) + self._promote_dedup_qualify_if_applicable(select_node, unit) + # 5. Parse HAVING clause (may contain subqueries) having_clause = select_node.args.get("having") if having_clause: @@ -1687,6 +1690,53 @@ def _parse_qualify_clause(self, qualify_node: exp.Qualify, unit: QueryUnit): "window_functions": window_functions, } + def _promote_dedup_qualify_if_applicable(self, select_node: exp.Select, unit: QueryUnit): + """ + Promote dedup qualify info from a subquery-based WHERE pattern (Gap 2). + + Detects the common dedup pattern: + SELECT ... FROM (SELECT *, ROW_NUMBER() OVER (...) AS rn FROM t) WHERE rn = 1 + and promotes it to qualify_info on the outer unit. + + Only ranking functions (ROW_NUMBER, RANK, DENSE_RANK, NTILE) are eligible. + Comparison operators =, <=, < against a literal are recognized. + + Args: + select_node: The SELECT expression + unit: The query unit to potentially add qualify_info to + """ + where_clause = select_node.args.get("where") + if not where_clause or unit.qualify_info: + return + + for dep_unit_id in unit.depends_on_units: + dep_unit = self.unit_graph.units.get(dep_unit_id) + if not dep_unit or not dep_unit.ranking_window_columns: + continue + + for node in where_clause.walk(): + if isinstance(node, (exp.EQ, exp.LTE, exp.LT)): + left, right = node.left, node.right + col_name = None + if isinstance(left, exp.Column) and isinstance(right, exp.Literal): + col_name = left.name + elif isinstance(right, exp.Column) and isinstance(left, exp.Literal): + col_name = right.name + + if col_name and col_name in dep_unit.ranking_window_columns: + window_meta = dep_unit.ranking_window_columns[col_name] + unit.qualify_info = { + "condition": where_clause.this.sql(), + "partition_columns": list(window_meta["partition_by"]), + "order_columns": [ + c["column"] if isinstance(c, dict) else c + for c in window_meta["order_by"] + ], + "window_functions": [window_meta["function"]], + "promoted_from_subquery": True, + } + return + def _parse_grouping_sets(self, group_clause: exp.Group, unit: QueryUnit): """ Parse GROUP BY clause for GROUPING SETS, CUBE, and ROLLUP constructs. @@ -1842,6 +1892,18 @@ def _parse_window_functions(self, select_node: exp.Select, unit: QueryUnit): if windows: unit.window_info = {"windows": windows} + # Populate ranking_window_columns for dedup qualify promotion (Gap 2) + RANKING_FUNCTIONS = {"ROW_NUMBER", "RANK", "DENSE_RANK", "NTILE"} + for window_def in windows: + func_name = window_def.get("function", "").upper() + output_col = window_def.get("output_column") + if func_name in RANKING_FUNCTIONS and output_col: + unit.ranking_window_columns[output_col] = { + "function": func_name, + "partition_by": window_def.get("partition_by", []), + "order_by": window_def.get("order_by", []), + } + def _parse_single_window( self, window: exp.Window, diff --git a/tests/test_subquery_dedup_qualify.py b/tests/test_subquery_dedup_qualify.py new file mode 100644 index 0000000..95a1ff2 --- /dev/null +++ b/tests/test_subquery_dedup_qualify.py @@ -0,0 +1,189 @@ +""" +Test suite for Gap 2: Subquery-Based Dedup Qualify Promotion. + +Tests that the common dedup pattern: + SELECT ... FROM (SELECT *, ROW_NUMBER() OVER (...) AS rn FROM t) WHERE rn = 1 +is promoted to qualify metadata on the outer query unit. +""" + +import pytest + +from clgraph.query_parser import RecursiveQueryParser + +# ============================================================================ +# Test Group 1: Qualify Promotion from Subquery Dedup Pattern +# ============================================================================ + + +class TestDedupQualifyPromotion: + """Test promotion of subquery-based dedup WHERE to qualify_info.""" + + def test_qualify_promotion_eq(self): + """WHERE rn = 1 with ROW_NUMBER promotes qualify_info on outer unit.""" + sql = """ + SELECT id, name + FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY ts DESC) AS rn + FROM t + ) + WHERE rn = 1 + """ + parser = RecursiveQueryParser(sql, dialect="bigquery") + graph = parser.parse() + + main_unit = graph.units["main"] + assert main_unit.qualify_info is not None + assert main_unit.qualify_info["promoted_from_subquery"] is True + assert "ROW_NUMBER" in main_unit.qualify_info["window_functions"] + assert "id" in main_unit.qualify_info["partition_columns"] + + def test_qualify_promotion_lte(self): + """WHERE rn <= 3 with ROW_NUMBER promotes qualify_info.""" + sql = """ + SELECT id, name + FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY ts DESC) AS rn + FROM t + ) + WHERE rn <= 3 + """ + parser = RecursiveQueryParser(sql, dialect="bigquery") + graph = parser.parse() + + main_unit = graph.units["main"] + assert main_unit.qualify_info is not None + assert main_unit.qualify_info["promoted_from_subquery"] is True + assert "ROW_NUMBER" in main_unit.qualify_info["window_functions"] + + def test_qualify_promotion_lt(self): + """WHERE rn < 3 with ROW_NUMBER promotes qualify_info.""" + sql = """ + SELECT id, name + FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY ts DESC) AS rn + FROM t + ) + WHERE rn < 3 + """ + parser = RecursiveQueryParser(sql, dialect="bigquery") + graph = parser.parse() + + main_unit = graph.units["main"] + assert main_unit.qualify_info is not None + assert main_unit.qualify_info["promoted_from_subquery"] is True + assert "ROW_NUMBER" in main_unit.qualify_info["window_functions"] + + +# ============================================================================ +# Test Group 2: Non-Ranking Functions Should NOT Promote +# ============================================================================ + + +class TestNonRankingNotPromoted: + """Test that non-ranking window functions are not promoted.""" + + def test_sum_window_not_promoted(self): + """SUM() OVER (...) + WHERE total > 100 should NOT produce qualify_info.""" + sql = """ + SELECT id, total + FROM ( + SELECT id, SUM(amount) OVER (PARTITION BY id) AS total + FROM t + ) + WHERE total > 100 + """ + parser = RecursiveQueryParser(sql, dialect="bigquery") + graph = parser.parse() + + main_unit = graph.units["main"] + assert main_unit.qualify_info is None + + +# ============================================================================ +# Test Group 3: Explicit QUALIFY Not Overwritten +# ============================================================================ + + +class TestExplicitQualifyNotOverwritten: + """Test that explicit QUALIFY clause is not overwritten by promotion.""" + + def test_explicit_qualify_preserved(self): + """Explicit QUALIFY should remain; promotion should not overwrite.""" + sql = """ + SELECT customer_id, order_date + FROM orders + QUALIFY ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY order_date DESC) = 1 + """ + parser = RecursiveQueryParser(sql, dialect="bigquery") + graph = parser.parse() + + main_unit = graph.units["main"] + assert main_unit.qualify_info is not None + # Explicit QUALIFY should NOT have promoted_from_subquery + assert main_unit.qualify_info.get("promoted_from_subquery") is not True + + +# ============================================================================ +# Test Group 4: rn Not in Output Columns +# ============================================================================ + + +class TestRnNotInOutput: + """Test that the ranking alias (rn) is not in the outer unit output columns.""" + + def test_rn_not_in_output(self): + """Outer SELECT id, name should not include rn in output columns.""" + sql = """ + SELECT id, name + FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY ts DESC) AS rn + FROM t + ) + WHERE rn = 1 + """ + parser = RecursiveQueryParser(sql, dialect="bigquery") + graph = parser.parse() + + main_unit = graph.units["main"] + output_col_names = [c.get("name", "") for c in main_unit.output_columns] + assert "rn" not in output_col_names + + +# ============================================================================ +# Test Group 5: ranking_window_columns Populated on Inner Unit +# ============================================================================ + + +class TestRankingWindowColumns: + """Test that inner subquery unit has ranking_window_columns metadata.""" + + def test_ranking_window_columns_populated(self): + """Inner unit should have ranking_window_columns with correct metadata.""" + sql = """ + SELECT id, name + FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY ts DESC) AS rn + FROM t + ) + WHERE rn = 1 + """ + parser = RecursiveQueryParser(sql, dialect="bigquery") + graph = parser.parse() + + # Find the inner subquery unit (not 'main') + inner_units = [u for uid, u in graph.units.items() if uid != "main"] + assert len(inner_units) >= 1 + + # At least one inner unit should have ranking_window_columns + inner_with_ranking = [u for u in inner_units if u.ranking_window_columns] + assert len(inner_with_ranking) >= 1 + + inner_unit = inner_with_ranking[0] + assert "rn" in inner_unit.ranking_window_columns + rn_meta = inner_unit.ranking_window_columns["rn"] + assert rn_meta["function"] == "ROW_NUMBER" + assert "id" in rn_meta["partition_by"] + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"])