diff --git a/gigl/analytics/README.md b/gigl/analytics/README.md new file mode 100644 index 000000000..f4fb7062f --- /dev/null +++ b/gigl/analytics/README.md @@ -0,0 +1,225 @@ +# GiGL Analytics + +Pre-training graph data validation and analysis tooling. Use this module before committing to a GNN training run to +catch data quality and structural issues that silently degrade model quality. + +Two subpackages: + +- [`data_analyzer/`](data_analyzer/) — end-to-end `DataAnalyzer` that runs BigQuery checks and produces a single + self-contained HTML report. **Start here.** +- [`graph_validation/`](graph_validation/) — lightweight standalone validators (currently: `BQGraphValidator` for + dangling-edge checks). Use when you only need one check and not the full report. + +## Quickstart + +**Prerequisites.** Follow the [GiGL installation guide](../../docs/user_guide/getting_started/installation.md) so that +`uv` and GiGL's Python dependencies are available. Then authenticate to BigQuery: + +```bash +gcloud auth application-default login +``` + +**1. Write a YAML config.** Save as `my_analyzer_config.yaml`: + +```yaml +node_tables: + - bq_table: "your-project.your_dataset.user_nodes" + node_type: "user" + id_column: "user_id" + feature_columns: ["age", "country"] # optional; omit to auto-infer all non-ID, TFDV-compatible columns from the BQ schema + # label_column: "label" # optional; enables Tier 3 label checks + +edge_tables: + - bq_table: "your-project.your_dataset.user_edges" + edge_type: "follows" + src_id_column: "src_user_id" + dst_id_column: "dst_user_id" + +# Where to write the HTML report. Local path for quick iteration, or a gs:// URI. +output_gcs_path: "/tmp/my_analysis/" + +# Optional: sizing for the neighbor-explosion estimate (fan-out per GNN layer). +fan_out: [15, 10, 5] +``` + +**2. Run the analyzer.** + +```bash +uv run python -m gigl.analytics.data_analyzer \ + --analyzer_config_uri my_analyzer_config.yaml +``` + +**3. Open the report.** When the run completes: + +``` +[INFO] Report written to /tmp/my_analysis/report.html +``` + +Open the file in any browser. No server, no external dependencies, fully offline. + +## What it checks + +The analyzer organizes checks into four tiers. Tiers 1 and 2 always run; Tier 3 auto-enables when your config supports +it; Tier 4 is opt-in. + +| Tier | When | What it checks | +| ---------------------------- | ------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| **1. Hard fails** | Always | Dangling edges (NULL src/dst), referential integrity (edges pointing to nodes not in the node table), duplicate nodes. Raises `DataQualityError` — the report still renders to show partial results. | +| **2. Core metrics** | Always | Node/edge counts, degree distribution (in/out) with percentiles, degree buckets, top-K hubs, super-hub int16 clamp count, cold-start node count, self-loops, duplicate edges, NULL rates per column, feature memory budget estimate, neighbor-explosion estimate (requires `fan_out`). | +| **3. Label + heterogeneous** | Auto when `label_column` is set on any node table, or when multiple edge types exist | Class imbalance, label coverage, edge type distribution, per-edge-type node coverage. | +| **4. Advanced** | Opt-in via config flags | Power-law exponent (implemented as a degree-stats approximation). Reciprocity, homophily, connected components, clustering coefficient are **not yet implemented** — the flags are accepted but currently no-op. | + +The thresholds below come from a review of production GNN papers (PinSage, BLADE, LiGNN, TwHIN, AliGraph, GraphSMOTE, +Beyond Homophily, Feature Propagation, and others). See the inline citations in the threshold table for what each paper +contributes. + +## Feature profiling + +In addition to the structural checks above, the analyzer runs +[TensorFlow Data Validation](https://www.tensorflow.org/tfx/guide/tfdv) on every node and edge table and embeds the +resulting Facets HTML report in the final output. + +- **Auto-inference.** By default, the profiler reads the BQ table schema and profiles every non-ID column whose type is + TFDV-compatible — scalars `STRING`, `INT64`, `FLOAT64`, `NUMERIC`, `BIGNUMERIC`, `BOOL`. Temporal types (`DATE`, + `DATETIME`, `TIMESTAMP`, `TIME`) and complex types (`RECORD`, `GEOGRAPHY`, `JSON`, `BYTES`) are not supported by TFDV + and are skipped with an info log. +- **Embedding columns.** `REPEATED` `FLOAT64` / `FLOAT` / `NUMERIC` / `BIGNUMERIC` columns are treated as embedding + vectors. Each expands in the Beam `SELECT` into four scalar hygiene companions — `_len`, `_has_nan`, + `_has_inf`, `_is_all_zero` — which are profiled by TFDV like any other scalar. Other REPEATED types + (`STRING` / `INT64` arrays, etc.) are skipped. +- **Embedding diagnostics.** After the TFDV pipelines finish, one BigQuery aggregate per embedding column computes + `total`, `unique_count`, `unique_ratio`, and top-K most-frequent hash clusters (via + `FARM_FINGERPRINT(TO_JSON_STRING())`). Results land in `FeatureProfileResult.embedding_diagnostics` and render as + a dedicated "Embedding Diagnostics" section in the report. +- **Explicit override.** Setting `feature_columns` in the YAML narrows the projection to those columns (still honoring + embedding expansion for REPEATED FLOAT families). Use this to scope down to a handful of columns, or to exclude PII / + expensive fields. +- **Join keys are excluded.** `id_column` on nodes and `src_id_column` / `dst_id_column` on edges are always dropped + from the auto-inferred list. `label_column` and `timestamp_column` are kept (profiling class balance / temporal + sparsity is useful). +- **Cost.** One Dataflow job is launched per table, so a config with many tables translates to many concurrent Dataflow + runs. During iteration, pass `--only structure` to skip the profiler entirely. Run `--only feature` (or the default + `both`) once the config is stable. + +## Machine-readable outputs + +Alongside `report.html`, each analyzer run writes versioned Pydantic JSON sidecars under `output_gcs_path/`: + +- `graph_structure.json` — the `GraphAnalysisResult` payload from `GraphStructureAnalyzer`. Written on success and also + on a Tier 1 `DataQualityError` (partial result) so failures are still recoverable. +- `feature_profile.json` — the `FeatureProfileResult` payload (facets URIs, TFDV stats URIs, embedding diagnostics). + +Each sidecar wraps its payload in an envelope: `{schema_version, component, generated_at, data}`. Load one with +`gigl.analytics.data_analyzer.types.load_artifact(path, expected_component=...)`. Schema changes are additive-only at +`schema_version="1"`; breaking changes bump the version. + +## Interpreting the report + +The report color-codes every numeric finding. Summary of the most important thresholds: + +| Metric | Green | Yellow | Red | What to do when yellow/red | +| -------------------------------------------------------- | ----- | ---------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Dangling edges / referential integrity / duplicate nodes | 0 | — | any > 0 | Fix the input tables. Training will fail or silently corrupt otherwise. | +| Feature missing rate | < 10% | 10–50% | > 90% | Plan an imputation strategy; above ~95% the Feature Propagation phase transition (Rossi et al., ICLR 2022) hits and GNNs stop recovering signal reliably. | +| Isolated node fraction | < 1% | 1–5% | > 5% | Filter isolated nodes or densify (LiGNN, KDD 2024) for cold-start cohorts. | +| Cold-start fraction (degree 0–1) | < 5% | 5–10% | > 10% | Candidates for graph densification; also flag for special handling at serving time. | +| Super-hub int16 clamp (degree > 32,767) | 0 | — | any > 0 | GiGL silently truncates super-hub degrees in `gigl/distributed/utils/degree.py`. Either cap the hub's edges upstream or plan to address the clamp. | +| Degree p99 / median | < 50 | 50–100 | > 100 | Use importance sampling (PinSage, KDD 2018) or degree-adaptive neighborhoods (BLADE, WSDM 2023) — degree skew is the single biggest lever in production GNNs. | +| Class imbalance ratio | < 1:5 | 1:5 – 1:10 | > 1:10 | Message passing amplifies label imbalance 2–3× in representation space (GraphSMOTE, WSDM 2021). Consider resampling or GraphSMOTE-style synthetic nodes. | +| Edge homophily (Tier 4, future) | > 0.7 | 0.3 – 0.7 | < 0.3 | Standard GCN/GAT fail at low h (Zhu et al., NeurIPS 2020). Consider H2GCN-style architectures; below h ≈ 0.2 a plain MLP often wins. | + +## Advanced config + +Optional YAML keys beyond the minimal quickstart: + +```yaml +# Enable Tier 3 class-imbalance + label-coverage checks for a node type: +node_tables: + - bq_table: ... + label_column: "label" + +# Neighbor explosion estimation — the fan-out per GNN layer you plan to train with: +fan_out: [15, 10, 5] + +# Tier 4 opt-in flags. Default false. +# NOTE: Only `compute_reciprocity` is wired into the analyzer today and it logs a +# warning rather than computing a result. The other three flags are placeholders +# for future work (see "Scope and limitations" below). +compute_reciprocity: true +compute_homophily: true +compute_connected_components: true +compute_clustering: true + +# Per-edge-type timestamp hint. NOTE: accepted by the config schema but not yet +# consumed by any Tier 4 query (temporal freshness check is planned). +edge_tables: + - bq_table: ... + timestamp_column: "created_at" +``` + +## Python API + +The CLI wraps a regular class. Call from your own code when you want programmatic access to the `GraphAnalysisResult`: + +```python +from gigl.analytics.data_analyzer import DataAnalyzer +from gigl.analytics.data_analyzer.config import load_analyzer_config + +config = load_analyzer_config("my_analyzer_config.yaml") +analyzer = DataAnalyzer() +report_path = analyzer.run(config=config) +# report_path points to the written report.html (local path or gs:// URI) +``` + +The underlying `GraphStructureAnalyzer` is also callable directly if you want the raw result dataclass and no HTML: + +```python +from gigl.analytics.data_analyzer.graph_structure_analyzer import GraphStructureAnalyzer + +result = GraphStructureAnalyzer().analyze(config) +print(result.degree_stats) +``` + +See a rendered report example at +[`tests/test_assets/analytics/golden_report.html`](../../tests/test_assets/analytics/golden_report.html) to preview the +output format before authenticating to BQ. + +## graph_validation + +One-off validators for the subset of cases where the full analyzer is overkill. Today the only check is dangling-edge +detection: + +```python +from gigl.analytics.graph_validation import BQGraphValidator + +has_dangling = BQGraphValidator.does_edge_table_have_dangling_edges( + edge_table="your-project.your_dataset.user_edges", + src_node_column_name="src_user_id", + dst_node_column_name="dst_user_id", +) +``` + +The `DataAnalyzer` runs this check (and many more) as part of Tier 1, so prefer the full analyzer unless you +specifically need a one-line gate (e.g., inside an Airflow task or a preprocessing job). This subpackage is the intended +home for additional standalone validators in the future. + +## Scope and limitations + +Current implementation status: + +- **Tier 4 checks are partial.** Power-law exponent is computed as a degree-stats approximation. Reciprocity, homophily, + connected components, and clustering coefficient config flags are accepted but currently no-op. The `timestamp_column` + edge field is accepted but no temporal-freshness query runs yet. +- **Heterogeneous graphs: referential integrity caveat.** For each edge table, the referential-integrity check joins + against `config.node_tables[0]`. On heterogeneous graphs where different edges reference different node types, the + current implementation will under-report integrity violations — fix is tracked for a follow-up. +- **GCS upload** works via `GcsUtils.upload_from_string` when `output_gcs_path` is a `gs://` URI, and falls back to + local filesystem write otherwise. + +## Related documents + +Within this module: + +- [`data_analyzer/report/PRD.md`](data_analyzer/report/PRD.md) — product intent for the HTML report (AI-owned) +- [`data_analyzer/report/SPEC.md`](data_analyzer/report/SPEC.md) — technical contract for the AI-owned HTML/JS/CSS + assets diff --git a/gigl/analytics/data_analyzer/__init__.py b/gigl/analytics/data_analyzer/__init__.py new file mode 100644 index 000000000..45304dacc --- /dev/null +++ b/gigl/analytics/data_analyzer/__init__.py @@ -0,0 +1,10 @@ +""" +BQ Data Analyzer for pre-training graph data analysis. + +Produces a single HTML report covering data quality, feature distributions, +and graph structure metrics from BigQuery node/edge tables. +""" + +from gigl.analytics.data_analyzer.data_analyzer import DataAnalyzer + +__all__ = ["DataAnalyzer"] diff --git a/gigl/analytics/data_analyzer/__main__.py b/gigl/analytics/data_analyzer/__main__.py new file mode 100644 index 000000000..693551d33 --- /dev/null +++ b/gigl/analytics/data_analyzer/__main__.py @@ -0,0 +1,6 @@ +"""Entry point for running the BQ Data Analyzer as a module: python -m gigl.analytics.data_analyzer.""" + +from gigl.analytics.data_analyzer.data_analyzer import main + +if __name__ == "__main__": + main() diff --git a/gigl/analytics/data_analyzer/config.py b/gigl/analytics/data_analyzer/config.py new file mode 100644 index 000000000..bb3fdcdc7 --- /dev/null +++ b/gigl/analytics/data_analyzer/config.py @@ -0,0 +1,283 @@ +import re +from dataclasses import dataclass, field +from typing import Optional + +from omegaconf import MISSING, OmegaConf + +from gigl.common.logger import Logger + +logger = Logger() + +# BigQuery identifier regexes used to reject configs that would be interpolated +# directly into SQL. See https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical +# for the allowed grammar. Tables are of the form project.dataset.table; +# columns are simple unquoted identifiers. +_BQ_TABLE_REGEX = re.compile(r"^[A-Za-z0-9_.\-]+\.[A-Za-z0-9_\-]+\.[A-Za-z0-9_$\-]+$") +_BQ_COLUMN_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def _validate_bq_table(name: str, field_label: str) -> None: + if not _BQ_TABLE_REGEX.fullmatch(name): + raise ValueError( + f"{field_label}={name!r} is not a valid BigQuery table reference. " + f"Expected project.dataset.table with no backticks, whitespace, or quotes." + ) + + +def _validate_bq_column(name: str, field_label: str) -> None: + if not _BQ_COLUMN_REGEX.fullmatch(name): + raise ValueError( + f"{field_label}={name!r} is not a valid BigQuery column identifier. " + f"Expected [A-Za-z_][A-Za-z0-9_]* with no backticks, whitespace, or quotes." + ) + + +@dataclass +class NodeTableSpec: + """Specification for a node table in BigQuery. + + Node-classification supervision is activated when ``label_column`` is + set. ``label_sentinel_values`` lets users distinguish "missing" labels + encoded as ``-1`` / ``"unknown"`` from SQL NULL — both are excluded + from the valid-label denominator used by class-imbalance and + homophily computations, but are reported separately so the upstream + bug can be traced. ``split_column`` enables split-validation checks + (cross-split node-id leakage as a Tier 1 hard fail, plus per-split + TFDV slicing for distribution drift). + """ + + bq_table: str = MISSING + node_type: str = MISSING + id_column: str = MISSING + feature_columns: list[str] = field(default_factory=list) + label_column: Optional[str] = None + label_sentinel_values: list[str] = field(default_factory=list) + split_column: Optional[str] = None + + +EDGE_ROLE_MESSAGE_PASSING = "message_passing" +EDGE_ROLE_SUPERVISION_POS = "supervision_pos" +EDGE_ROLE_SUPERVISION_NEG = "supervision_neg" +_VALID_EDGE_ROLES = frozenset( + {EDGE_ROLE_MESSAGE_PASSING, EDGE_ROLE_SUPERVISION_POS, EDGE_ROLE_SUPERVISION_NEG} +) + + +@dataclass +class EdgeTableSpec: + """Specification for an edge table in BigQuery. + + For heterogeneous graphs (more than one node table), src_node_type and + dst_node_type must be set to the node_type of the matching node table. + For homogeneous graphs (single node table) they default to that node_type. + + ``role`` marks the table's purpose for cross-table supervision analysis. + Defaults to ``"message_passing"`` when omitted. ``node_anchor`` selects + which side (src or dst) of the table is the anchor for the per-anchor + cross-table stats; required on ``supervision_pos`` tables, ignored when + no analysis applies. + """ + + bq_table: str = MISSING + edge_type: str = MISSING + src_id_column: str = MISSING + dst_id_column: str = MISSING + src_node_type: Optional[str] = None + dst_node_type: Optional[str] = None + feature_columns: list[str] = field(default_factory=list) + timestamp_column: Optional[str] = None + role: Optional[str] = None + node_anchor: Optional[str] = None + + +@dataclass +class DataAnalyzerConfig: + """Configuration for the BQ Data Analyzer. + + Parsed from YAML via OmegaConf. + + Example: + >>> config = load_analyzer_config("gs://bucket/config.yaml") + >>> config.node_tables[0].bq_table + 'project.dataset.user_nodes' + """ + + node_tables: list[NodeTableSpec] = MISSING + edge_tables: list[EdgeTableSpec] = MISSING + output_gcs_path: str = MISSING + fan_out: Optional[list[int]] = None + compute_reciprocity: bool = False + compute_homophily: bool = False + compute_connected_components: bool = False + compute_clustering: bool = False + + # Node-classification supervision tier flags. Activate any time a + # NodeTableSpec.label_column is set. + # + # ``compute_per_class_feature_stats`` controls TFDV slicing on the + # label column inside the feature profiler — default on because it's + # the highest-value NC-specific feature signal and costs one extra + # column on the existing BQ projection. + # + # ``compute_label_informativeness`` is the expensive (full-graph + # mutual-information) homophily measure from Platonov et al. 2023. + # Default off; the cheaper sampled adjusted-homophily always runs. + # + # ``label_homophily_edge_sample_cap`` caps the message-passing edge + # sample used to compute adjusted homophily. ``0`` means full-graph. + compute_per_class_feature_stats: bool = True + compute_label_informativeness: bool = False + label_homophily_edge_sample_cap: int = 50_000_000 + + # Per-chunk feature cap for TFDV profiling. Wide projections explode + # Beam 2.56's CombinePerKey state and trip + # "Instruction id ... was not registered" failures on Runner v2; + # chunking keeps every Dataflow job within the runner's + # state-iteration budget. 350 was validated end-to-end on a 706-col / + # ~950 M-row user table. + max_features_per_chunk: int = 350 + + # Per-config Dataflow job name prefix. Combined with a per-run + # timestamp at the entry point to keep concurrent / repeated runs + # from colliding on the fixed Dataflow job name. The CLI flag + # ``--job_name_prefix`` overrides this when set; if neither is set + # the entry point fails fast. + job_name_prefix: Optional[str] = None + + +def _validate_and_backfill(config: DataAnalyzerConfig) -> None: + """Run identifier validation and backfill default node-type references. + + - Every bq_table must match project.dataset.table. + - Every id_column / src_id_column / dst_id_column / feature_column / + label_column / timestamp_column must be a bare BQ identifier. + - For homogeneous configs, an edge table with no src_node_type / + dst_node_type inherits the single node table's node_type. + - For heterogeneous configs, every edge table must explicitly declare + src_node_type and dst_node_type, and both must resolve to a known + node_type. + """ + known_node_types = {nt.node_type for nt in config.node_tables} + single_node_type: Optional[str] = ( + next(iter(known_node_types)) if len(config.node_tables) == 1 else None + ) + + for node_table in config.node_tables: + _validate_bq_table(node_table.bq_table, "node_tables.bq_table") + _validate_bq_column(node_table.id_column, "node_tables.id_column") + for col in node_table.feature_columns: + _validate_bq_column(col, "node_tables.feature_columns") + if node_table.label_column is not None: + _validate_bq_column(node_table.label_column, "node_tables.label_column") + if node_table.split_column is not None: + _validate_bq_column(node_table.split_column, "node_tables.split_column") + # Sentinel values are not SQL identifiers (they're literal label + # values), but they're still embedded into SQL via parameterized + # IN clauses elsewhere. Reject empty strings to fail fast on + # likely-misconfigured YAML where a value got stripped. + for sentinel in node_table.label_sentinel_values: + if sentinel == "": + raise ValueError( + f"node_tables.label_sentinel_values contains an empty string " + f"for node_type={node_table.node_type!r}; declare each " + "sentinel value explicitly (e.g. '-1', 'unknown')." + ) + if node_table.label_sentinel_values and node_table.label_column is None: + raise ValueError( + f"node_type={node_table.node_type!r}: label_sentinel_values " + "are declared but label_column is not set; sentinels apply " + "to the label_column only." + ) + + for edge_table in config.edge_tables: + _validate_bq_table(edge_table.bq_table, "edge_tables.bq_table") + _validate_bq_column(edge_table.src_id_column, "edge_tables.src_id_column") + _validate_bq_column(edge_table.dst_id_column, "edge_tables.dst_id_column") + for col in edge_table.feature_columns: + _validate_bq_column(col, "edge_tables.feature_columns") + if edge_table.timestamp_column is not None: + _validate_bq_column( + edge_table.timestamp_column, "edge_tables.timestamp_column" + ) + + if edge_table.src_node_type is None: + if single_node_type is not None: + edge_table.src_node_type = single_node_type + else: + raise ValueError( + f"edge_type={edge_table.edge_type}: src_node_type is required " + f"when there are multiple node tables" + ) + if edge_table.dst_node_type is None: + if single_node_type is not None: + edge_table.dst_node_type = single_node_type + else: + raise ValueError( + f"edge_type={edge_table.edge_type}: dst_node_type is required " + f"when there are multiple node tables" + ) + if edge_table.src_node_type not in known_node_types: + raise ValueError( + f"edge_type={edge_table.edge_type}: src_node_type=" + f"{edge_table.src_node_type!r} is not a declared node_type. " + f"Known: {sorted(known_node_types)}" + ) + if edge_table.dst_node_type not in known_node_types: + raise ValueError( + f"edge_type={edge_table.edge_type}: dst_node_type=" + f"{edge_table.dst_node_type!r} is not a declared node_type. " + f"Known: {sorted(known_node_types)}" + ) + + if edge_table.role is None: + edge_table.role = EDGE_ROLE_MESSAGE_PASSING + elif edge_table.role not in _VALID_EDGE_ROLES: + raise ValueError( + f"edge_type={edge_table.edge_type}: role={edge_table.role!r} " + f"is not valid. Expected one of {sorted(_VALID_EDGE_ROLES)}." + ) + + if edge_table.node_anchor is not None: + if edge_table.node_anchor not in ( + edge_table.src_node_type, + edge_table.dst_node_type, + ): + raise ValueError( + f"edge_type={edge_table.edge_type}: node_anchor=" + f"{edge_table.node_anchor!r} must equal src_node_type=" + f"{edge_table.src_node_type!r} or dst_node_type=" + f"{edge_table.dst_node_type!r}." + ) + elif edge_table.role == EDGE_ROLE_SUPERVISION_POS: + raise ValueError( + f"edge_type={edge_table.edge_type}: node_anchor is required " + f"when role={EDGE_ROLE_SUPERVISION_POS!r}." + ) + + +def load_analyzer_config(config_path: str) -> DataAnalyzerConfig: + """Load and validate a DataAnalyzerConfig from a YAML file. + + Args: + config_path: Local file path or GCS URI to the YAML config. + + Returns: + Validated DataAnalyzerConfig instance with node-type references + backfilled on edge tables. + + Raises: + omegaconf.errors.MissingMandatoryValue: If required fields are missing. + ValueError: If any bq_table or column name is not a valid BigQuery + identifier, or if a heterogeneous config is missing a required + src_node_type / dst_node_type. + """ + logger.info(f"Loading analyzer config from {config_path}") + raw = OmegaConf.load(config_path) + merged = OmegaConf.merge(OmegaConf.structured(DataAnalyzerConfig), raw) + config: DataAnalyzerConfig = OmegaConf.to_object(merged) # type: ignore + _validate_and_backfill(config) + logger.info( + f"Loaded analyzer config with {len(config.node_tables)} node tables " + f"and {len(config.edge_tables)} edge tables" + ) + return config diff --git a/gigl/analytics/data_analyzer/data_analyzer.py b/gigl/analytics/data_analyzer/data_analyzer.py new file mode 100644 index 000000000..0e0187c9f --- /dev/null +++ b/gigl/analytics/data_analyzer/data_analyzer.py @@ -0,0 +1,260 @@ +"""Main orchestrator and CLI entry point for the BQ Data Analyzer.""" +import argparse +import re +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Literal, Optional + +from gigl.analytics.data_analyzer.config import DataAnalyzerConfig, load_analyzer_config +from gigl.analytics.data_analyzer.feature_profiler import FeatureProfiler +from gigl.analytics.data_analyzer.graph_structure_analyzer import ( + DataQualityError, + GraphStructureAnalyzer, +) +from gigl.analytics.data_analyzer.report.report_generator import generate_report +from gigl.analytics.data_analyzer.types import FeatureProfileResult, GraphAnalysisResult +from gigl.common import GcsUri, UriFactory +from gigl.common.logger import Logger +from gigl.common.utils.gcs import GcsUtils +from gigl.env.pipelines_config import GiglResourceConfigWrapper, get_resource_config +from gigl.src.common.utils.time import current_formatted_datetime + +logger = Logger() + +# Lowercase, hyphen-safe, ≤20 chars. Composes cleanly with +# ``get_sanitized_dataflow_job_name`` and keeps the final Dataflow job +# name (``gigl-analyzer-{prefix}-{ts}-profile-{kind}-{type}``) inside +# Dataflow's ~63-char budget for typical type_name lengths. +_JOB_NAME_PREFIX_REGEX = re.compile(r"^[a-z][a-z0-9-]{0,19}$") +_RUN_TIMESTAMP_FORMAT = "%Y%m%d-%H%M" + + +def _write_report(html: str, output_gcs_path: str) -> str: + """Write the HTML report to a GCS URI or local path. + + Args: + html: Rendered HTML string. + output_gcs_path: Output directory. If it starts with ``gs://`` the + report is uploaded via ``GcsUtils``. Otherwise it is written to + the local filesystem (the directory is created if missing). + + Returns: + The full path to the written ``report.html`` file. + """ + trimmed = output_gcs_path.rstrip("/") + report_path = f"{trimmed}/report.html" + if trimmed.startswith("gs://"): + GcsUtils().upload_from_string(GcsUri(report_path), html) + else: + local_path = Path(report_path).expanduser().resolve() + local_path.parent.mkdir(parents=True, exist_ok=True) + local_path.write_text(html) + report_path = str(local_path) + return report_path + + +class DataAnalyzer: + """Orchestrates graph structure analysis, feature profiling, and report generation. + + Example: + >>> from gigl.analytics.data_analyzer.config import load_analyzer_config + >>> config = load_analyzer_config("gs://bucket/config.yaml") + >>> analyzer = DataAnalyzer() + >>> report_path = analyzer.run(config=config) + """ + + def run( + self, + config: DataAnalyzerConfig, + resource_config: GiglResourceConfigWrapper, + job_name_prefix: str, + run_timestamp: str, + components: Literal["structure", "feature", "both"] = "both", + custom_worker_image_uri: Optional[str] = None, + ) -> str: + """Run the analysis pipeline and write an HTML report. + + The report is written to ``{config.output_gcs_path}/report.html`` via + ``GcsUtils`` when the output path is a ``gs://`` URI, or to the local + filesystem otherwise (the parent directory is created if missing). + + Args: + config: Analyzer configuration. + resource_config: Resource config for Dataflow sizing. + job_name_prefix: Prefix mixed into every per-table Dataflow job + name (resolved at the entry point from CLI flag or YAML). + run_timestamp: Per-run timestamp shared by every per-table job in + this invocation (computed once at the entry point). + components: Which components to run. ``"both"`` (default) runs the + structure analyzer and feature profiler concurrently. + ``"structure"`` runs only the graph structure analyzer. + ``"feature"`` runs only the feature profiler. The skipped + component is represented in the report by an empty result. + custom_worker_image_uri: Optional Docker image URI for the Dataflow + worker harness used by the feature profiler. When ``None``, the + profiler falls back to ``DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU``. + + Returns: + The path to the written ``report.html`` (GCS URI or local path). + """ + analysis_result: GraphAnalysisResult + profile_result: FeatureProfileResult + + if components == "both": + with ThreadPoolExecutor(max_workers=2) as executor: + structure_future = executor.submit( + GraphStructureAnalyzer().analyze, config + ) + profile_future = executor.submit( + FeatureProfiler().profile, + config, + resource_config, + job_name_prefix, + run_timestamp, + custom_worker_image_uri, + ) + + try: + analysis_result = structure_future.result() + except DataQualityError as e: + logger.error(f"Tier 1 data quality failure: {e}") + analysis_result = e.partial_result + + try: + profile_result = profile_future.result() + except Exception as e: + logger.exception(f"Feature profiler failed: {e}") + profile_result = FeatureProfileResult() + elif components == "structure": + try: + analysis_result = GraphStructureAnalyzer().analyze(config) + except DataQualityError as e: + logger.error(f"Tier 1 data quality failure: {e}") + analysis_result = e.partial_result + profile_result = FeatureProfileResult() + elif components == "feature": + analysis_result = GraphAnalysisResult() + profile_result = FeatureProfiler().profile( + config, + resource_config, + job_name_prefix=job_name_prefix, + run_timestamp=run_timestamp, + custom_worker_image_uri=custom_worker_image_uri, + ) + else: + raise ValueError( + f"components={components!r} must be one of 'structure', 'feature', 'both'" + ) + + html = generate_report( + analysis_result=analysis_result, + profile_result=profile_result, + ) + + report_path = _write_report(html, config.output_gcs_path) + logger.info(f"Report written to {report_path}") + return report_path + + +def _resolve_job_name_prefix( + cli_value: Optional[str], yaml_value: Optional[str] +) -> str: + """Pick the effective ``job_name_prefix`` from CLI flag or YAML field. + + CLI takes precedence; if both are set and differ the override is logged. + Raises ``ValueError`` if neither source supplies a value, or if the + chosen value doesn't match the lowercase / hyphen / ≤20-char shape we + require to keep the final Dataflow job name within Dataflow's ~63-char + cap. + """ + if cli_value and yaml_value and cli_value != yaml_value: + logger.info( + f"--job_name_prefix={cli_value!r} overrides YAML " + f"job_name_prefix={yaml_value!r}." + ) + effective = cli_value or yaml_value + if not effective: + raise ValueError( + "job_name_prefix is required: pass --job_name_prefix on the CLI " + "or set job_name_prefix in the analyzer YAML." + ) + if not _JOB_NAME_PREFIX_REGEX.fullmatch(effective): + raise ValueError( + f"job_name_prefix={effective!r} is invalid. Expected lowercase " + "letters, digits, and hyphens, starting with a letter, ≤20 chars." + ) + return effective + + +def main() -> None: + """CLI entry point for the BQ Data Analyzer.""" + parser = argparse.ArgumentParser( + description="BQ Data Analyzer: analyze graph data in BigQuery before GNN training" + ) + parser.add_argument( + "--analyzer_config_uri", + required=True, + help="Path or GCS URI to the analyzer YAML config", + ) + parser.add_argument( + "--resource_config_uri", + required=False, + help="Path or GCS URI to the resource config for Dataflow sizing", + ) + parser.add_argument( + "--only", + choices=["structure", "feature", "both"], + default="both", + help=( + "Run only the graph structure analyzer, only the feature profiler, " + "or both (default: both)." + ), + ) + parser.add_argument( + "--custom_worker_image_uri", + type=str, + required=False, + help=( + "Docker image URI to use for the Dataflow worker harness in the " + "feature profiler. When omitted, falls back to " + "DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU." + ), + ) + parser.add_argument( + "--job_name_prefix", + type=str, + required=False, + help=( + "Prefix mixed into every per-table Dataflow job name to " + "disambiguate concurrent / repeat runs. Required, but may be " + "set in the analyzer YAML instead. CLI overrides YAML. Lowercase " + "letters, digits, and hyphens, starting with a letter, ≤20 chars." + ), + ) + args = parser.parse_args() + resource_config = get_resource_config( + UriFactory.create_uri(args.resource_config_uri) + ) + config = load_analyzer_config(args.analyzer_config_uri) + job_name_prefix = _resolve_job_name_prefix( + cli_value=args.job_name_prefix, yaml_value=config.job_name_prefix + ) + run_timestamp = current_formatted_datetime(_RUN_TIMESTAMP_FORMAT) + logger.info( + f"Using job_name_prefix={job_name_prefix!r}, run_timestamp={run_timestamp!r}." + ) + + analyzer = DataAnalyzer() + report_path = analyzer.run( + config=config, + resource_config=resource_config, + job_name_prefix=job_name_prefix, + run_timestamp=run_timestamp, + components=args.only, + custom_worker_image_uri=args.custom_worker_image_uri, + ) + logger.info(f"Report generated at: {report_path}") + + +if __name__ == "__main__": + main() diff --git a/gigl/analytics/data_analyzer/embedding_diagnostics.py b/gigl/analytics/data_analyzer/embedding_diagnostics.py new file mode 100644 index 000000000..7fa5d36eb --- /dev/null +++ b/gigl/analytics/data_analyzer/embedding_diagnostics.py @@ -0,0 +1,174 @@ +"""Structural-sanity diagnostics for REPEATED FLOAT (embedding) columns. + +Runs one BigQuery aggregate per (table, embedding column) to compute +``total`` rows, ``unique_count`` of distinct vectors, ``unique_ratio``, +and the top-K most-frequent hash clusters. Uses +``FARM_FINGERPRINT(TO_JSON_STRING())`` as the deduplication key — +cheap, deterministic, and exact for equality (not similarity). + +A low ``unique_ratio`` or a heavily-weighted top entry indicates upstream +degeneracy (many rows emitting the same embedding — often a zero-padded +placeholder for missing data). + +The component is best-effort: a failure on one column logs a warning and +is skipped; callers receive an empty mapping for that column rather than +an exception. +""" + +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Optional + +from gigl.analytics.data_analyzer.types import EmbeddingDiagnosticsResult, TopKEntry +from gigl.common.logger import Logger +from gigl.src.common.utils.bq import BqUtils + +logger = Logger() + +_PARALLEL_DIAGNOSTICS_QUERIES = 8 +_DEFAULT_TOP_K = 20 + + +@dataclass(frozen=True) +class EmbeddingDiagnosticsRequest: + """One (table, embedding columns, result key) triple to analyze. + + ``result_key`` is the per-table analyzer key (``"node:{type}"`` or + ``"edge:{type}"``) used to organize outputs into the + ``FeatureProfileResult.embedding_diagnostics`` two-level dict. + """ + + result_key: str + bq_table: str + embedding_columns: list[str] + + +class EmbeddingDiagnostics: + """Compute structural diagnostics for embedding columns via BigQuery.""" + + def __init__( + self, + bq_utils: BqUtils, + top_k: int = _DEFAULT_TOP_K, + max_workers: int = _PARALLEL_DIAGNOSTICS_QUERIES, + ) -> None: + self._bq_utils = bq_utils + self._top_k = top_k + self._max_workers = max_workers + + def analyze( + self, requests: list[EmbeddingDiagnosticsRequest] + ) -> dict[str, dict[str, EmbeddingDiagnosticsResult]]: + """Run one aggregate query per (table, column) and collect results. + + Per-column failures are logged and skipped; one bad column does not + sink other columns in the same request or other requests. A request + whose every column failed produces an empty inner dict, which is + omitted from the output. + + Args: + requests: One entry per table with at least one embedding column. + + Returns: + ``{result_key: {column_name: EmbeddingDiagnosticsResult}}``. + Missing keys indicate the column's query failed. + """ + jobs: list[tuple[str, str, str]] = [] + for request in requests: + for column in request.embedding_columns: + jobs.append((request.result_key, request.bq_table, column)) + if not jobs: + return {} + + logger.info( + f"Running {len(jobs)} embedding diagnostic query(ies) across " + f"{len(requests)} table(s)." + ) + out: dict[str, dict[str, EmbeddingDiagnosticsResult]] = {} + with ThreadPoolExecutor(max_workers=self._max_workers) as executor: + future_to_key = { + executor.submit( + self._analyze_column, bq_table=bq_table, column=column + ): (result_key, column) + for result_key, bq_table, column in jobs + } + for future in as_completed(future_to_key): + result_key, column = future_to_key[future] + try: + diagnostics = future.result() + except Exception as exc: + logger.exception( + f"Embedding diagnostics failed for " + f"{result_key}:{column}: {exc}" + ) + continue + if diagnostics is None: + continue + out.setdefault(result_key, {})[column] = diagnostics + return out + + def _analyze_column( + self, bq_table: str, column: str + ) -> Optional[EmbeddingDiagnosticsResult]: + """Run the dedup aggregate for one column; return its result.""" + query = _build_dedup_query(bq_table=bq_table, column=column, top_k=self._top_k) + rows = list(self._bq_utils.run_query(query=query, labels={})) + if len(rows) != 1: + raise RuntimeError( + f"Embedding diagnostics query expected exactly 1 row for " + f"{bq_table}.{column}; got {len(rows)}." + ) + row = rows[0] + total = int(row["total"] or 0) + unique_count = int(row["unique_count"] or 0) + unique_ratio = float(row["unique_ratio"] or 0.0) + top_k_rows = row["top_k"] or [] + top_k = [ + TopKEntry( + hash=int(entry["hash_value"]), + count=int(entry["count_value"]), + fraction=float(entry["fraction"] or 0.0), + ) + for entry in top_k_rows + ] + return EmbeddingDiagnosticsResult( + total=total, + unique_count=unique_count, + unique_ratio=unique_ratio, + top_k=top_k, + ) + + +def _build_dedup_query(bq_table: str, column: str, top_k: int) -> str: + """Render the per-column dedup aggregate. + + ``FARM_FINGERPRINT(TO_JSON_STRING())`` is deterministic and + collision-resistant enough for this purpose — we're looking for + unusually clumped clusters, not cryptographic uniqueness. + """ + return f""" +WITH hashes AS ( + SELECT FARM_FINGERPRINT(TO_JSON_STRING(`{column}`)) AS h + FROM `{bq_table}` +), +counts AS ( + SELECT h, COUNT(*) AS n FROM hashes GROUP BY h +), +agg AS ( + SELECT SUM(n) AS total, COUNT(*) AS unique_count FROM counts +) +SELECT + agg.total, + agg.unique_count, + SAFE_DIVIDE(agg.unique_count, agg.total) AS unique_ratio, + ARRAY( + SELECT AS STRUCT + h AS hash_value, + n AS count_value, + SAFE_DIVIDE(n, agg.total) AS fraction + FROM counts + ORDER BY n DESC + LIMIT {top_k} + ) AS top_k +FROM agg +""".strip() diff --git a/gigl/analytics/data_analyzer/embedding_projection.py b/gigl/analytics/data_analyzer/embedding_projection.py new file mode 100644 index 000000000..fec7d440d --- /dev/null +++ b/gigl/analytics/data_analyzer/embedding_projection.py @@ -0,0 +1,179 @@ +"""Schema-aware BQ projection builder for the feature profiler. + +Translates a BigQuery table schema into a ``SELECT`` projection that +TFDV can profile. Scalar profileable columns pass through unchanged; +REPEATED ``FLOAT`` / ``FLOAT64`` / ``NUMERIC`` / ``BIGNUMERIC`` columns +(embeddings) expand into four scalar hygiene companions: + +* ``_len`` — array length +* ``_has_nan`` — any NaN element +* ``_has_inf`` — any Inf element +* ``_is_all_zero`` — every element equals 0 + +Structural-sanity (dedup / unique-ratio / top-K) lives in +:mod:`gigl.analytics.data_analyzer.embedding_diagnostics`, which runs its +own aggregate query over ``FARM_FINGERPRINT(TO_JSON_STRING())``. The +hash is deliberately excluded from this projection so TFDV doesn't render +noisy stats on a 64-bit hash column. +""" + +from dataclasses import dataclass + +from google.cloud.bigquery import SchemaField + +from gigl.common.logger import Logger + +logger = Logger() + +# BigQuery scalar types TFDV can profile once wrapped as ``list`` by +# ``BqTableToRecordBatch``. Matches ``_PROFILEABLE_FIELD_TYPES`` in +# ``feature_profiler.py`` — kept in sync via a single import site. +_SCALAR_PROFILEABLE_TYPES: frozenset[str] = frozenset( + { + "STRING", + "INTEGER", + "INT64", + "FLOAT", + "FLOAT64", + "NUMERIC", + "BIGNUMERIC", + "BOOLEAN", + "BOOL", + } +) + +# REPEATED types that represent embedding vectors. STRING / INT arrays are +# intentionally excluded — they need different diagnostics (e.g. vocab stats) +# and are out of scope for this pass. +_EMBEDDING_FLOAT_TYPES: frozenset[str] = frozenset( + {"FLOAT", "FLOAT64", "NUMERIC", "BIGNUMERIC"} +) + + +@dataclass(frozen=True) +class ProjectionResult: + """Output of :func:`build_projection`. + + ``projection`` is a list of ``(column_name, sql_expression)`` pairs + suitable for feeding directly into a + :class:`~gigl.common.beam.tfdv_transforms.BqTableToRecordBatch`. Each + entry renders as ``{sql_expression} AS \\`{column_name}\\``` in the + resulting ``SELECT``. + + ``embedding_columns`` lists the original REPEATED FLOAT column names + (pre-expansion) in schema order; the dedup pass uses them to locate the + corresponding ``_hash`` companion. + """ + + projection: list[tuple[str, str]] + embedding_columns: list[str] + + +def is_embedding_column(field: SchemaField) -> bool: + """Return ``True`` for REPEATED FLOAT-family columns (embedding vectors).""" + return ( + field.mode == "REPEATED" and field.field_type.upper() in _EMBEDDING_FLOAT_TYPES + ) + + +def detect_embedding_columns( + schema: dict[str, SchemaField], excluded: set[str] +) -> list[str]: + """List REPEATED FLOAT-family columns in the schema, in declaration order. + + Excluded columns (typically structural join keys) are dropped. + """ + return [ + name + for name, field in schema.items() + if name not in excluded and is_embedding_column(field) + ] + + +def build_projection( + schema: dict[str, SchemaField], excluded: set[str] +) -> ProjectionResult: + """Build a TFDV-compatible projection from a BigQuery schema. + + Scalar profileable columns (see :data:`_SCALAR_PROFILEABLE_TYPES`) are + passed through verbatim, *except* BOOL / BOOLEAN columns are cast to + INT64. ``BqTableToRecordBatch`` wraps each value in a single-element + list before emitting an Arrow ``RecordBatch``; TFDV's + ``get_feature_type_from_arrow_type`` does not accept ``list`` + (only int / float / string / bytes lists), so a raw BOOL column would + crash the Dataflow job in ``BasicStatsGenerator.add_input``. Casting + to INT64 in SQL keeps the BOOL semantics (0/1) profileable as an + int feature. + + REPEATED FLOAT-family columns are expanded into four scalar hygiene + companions (see module docstring). The three boolean companions + (``_has_nan``, ``_has_inf``, ``_is_all_zero``) are likewise cast to + INT64 for the same reason. REPEATED non-FLOAT columns and + non-profileable scalar types are skipped with an ``INFO`` log. + + Args: + schema: Column name → ``SchemaField`` map (as returned by + ``BqUtils.fetch_bq_table_schema``). + excluded: Column names to drop entirely (typically structural join + keys: node ``id_column``; edge ``src_id_column`` + + ``dst_id_column``). + + Returns: + :class:`ProjectionResult`. ``projection`` preserves schema order + with each embedding's hygiene companions appearing in a contiguous + block. + """ + projection: list[tuple[str, str]] = [] + embedding_columns: list[str] = [] + for name, field in schema.items(): + if name in excluded: + continue + if is_embedding_column(field): + projection.extend(_embedding_hygiene_projection(name)) + embedding_columns.append(name) + continue + if field.mode == "REPEATED": + logger.info( + f"skipping REPEATED column {name!r} of type {field.field_type} " + "(hygiene companions only cover REPEATED FLOAT families)." + ) + continue + type_upper = field.field_type.upper() + if type_upper not in _SCALAR_PROFILEABLE_TYPES: + logger.info( + f"skipping column {name!r} of type {field.field_type} " + "(not TFDV-profileable)." + ) + continue + if type_upper in ("BOOL", "BOOLEAN"): + projection.append((name, f"CAST(`{name}` AS INT64)")) + else: + projection.append((name, f"`{name}`")) + return ProjectionResult(projection=projection, embedding_columns=embedding_columns) + + +def _embedding_hygiene_projection(column: str) -> list[tuple[str, str]]: + """Return the four hygiene ``(name, expr)`` entries for one embedding column. + + The three boolean companions are wrapped in ``CAST(... AS INT64)`` so + the resulting Arrow column is ``list`` rather than ``list``; + see :func:`build_projection` for the TFDV compatibility rationale. + """ + return [ + (f"{column}_len", f"ARRAY_LENGTH(`{column}`)"), + ( + f"{column}_has_nan", + f"CAST(IFNULL((SELECT LOGICAL_OR(IS_NAN(v)) FROM UNNEST(`{column}`) v), " + "FALSE) AS INT64)", + ), + ( + f"{column}_has_inf", + f"CAST(IFNULL((SELECT LOGICAL_OR(IS_INF(v)) FROM UNNEST(`{column}`) v), " + "FALSE) AS INT64)", + ), + ( + f"{column}_is_all_zero", + f"CAST(IFNULL((SELECT LOGICAL_AND(v = 0) FROM UNNEST(`{column}`) v), " + "FALSE) AS INT64)", + ), + ] diff --git a/gigl/analytics/data_analyzer/feature_profiler.py b/gigl/analytics/data_analyzer/feature_profiler.py new file mode 100644 index 000000000..451806883 --- /dev/null +++ b/gigl/analytics/data_analyzer/feature_profiler.py @@ -0,0 +1,749 @@ +"""TFDV feature profiling via Beam/Dataflow. + +Launches one Dataflow pipeline per node and edge table in the analyzer +config. For each table, the BQ projection is built from the table schema +via :func:`~gigl.analytics.data_analyzer.embedding_projection.build_projection`: +scalar profileable columns pass through, REPEATED FLOAT-family columns +(embeddings) expand into four hygiene companions +(``_len``/``_has_nan``/``_has_inf``/``_is_all_zero``). Each pipeline +reads the resulting columns from BigQuery, emits ``pa.RecordBatch`` +batches, and runs ``tfdv.GenerateStatistics`` to write a Facets HTML +visualization plus a TFDV stats TFRecord to GCS. + +After all Dataflow pipelines finish, one aggregate BigQuery query per +embedding column runs via +:class:`~gigl.analytics.data_analyzer.embedding_diagnostics.EmbeddingDiagnostics` +to compute structural sanity (unique ratio + top-K most-frequent hashes). +The final :class:`FeatureProfileResult` is serialized to +``{output_gcs_path}/feature_profile.json`` via :func:`write_artifact` so +external consumers can parse it without scraping HTML. + +Tables whose final projection is empty (e.g. only ID columns, or a schema +fetch failed) are skipped with a warning. Per-table Beam failures, the +diagnostics pass, and the sidecar write are all best-effort: the TFDV +artifacts remain valuable even if one downstream step fails. +""" +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field +from typing import Any, Optional + +import apache_beam as beam +import tensorflow_data_validation as tfdv +from apache_beam.options.pipeline_options import GoogleCloudOptions +from tensorflow_data_validation.utils import slicing_util + +from gigl.analytics.data_analyzer.config import DataAnalyzerConfig +from gigl.analytics.data_analyzer.embedding_diagnostics import ( + EmbeddingDiagnostics, + EmbeddingDiagnosticsRequest, +) +from gigl.analytics.data_analyzer.embedding_projection import ( + ProjectionResult, + build_projection, +) +from gigl.analytics.data_analyzer.types import ( + FeatureProfileError, + FeatureProfileResult, + write_artifact, +) +from gigl.common import UriFactory +from gigl.common.beam.sharded_read import BigQueryShardedReadConfig +from gigl.common.beam.tfdv_transforms import ( + BqTableToRecordBatch, + GenerateAndVisualizeStats, +) +from gigl.common.logger import Logger +from gigl.env.pipelines_config import GiglResourceConfigWrapper +from gigl.src.common.constants.components import GiGLComponents +from gigl.src.common.types import AppliedTaskIdentifier +from gigl.src.common.utils.bq import BqUtils +from gigl.src.common.utils.dataflow import init_beam_pipeline_options + +logger = Logger() + +_PARALLEL_DATAFLOW_WORKERS = 10 +# Kept short to leave room for the per-run prefix and timestamp inside +# the Dataflow job-name budget (~63 chars). +_APPLIED_TASK_IDENTIFIER = AppliedTaskIdentifier("analyzer") + + +def _safe_dataflow_job_id(result: Any) -> Optional[str]: + """Return ``result.job_id()`` if present, else ``None``. + + The DataflowRunner returns a ``DataflowPipelineResult`` whose + ``job_id()`` method exposes the submitted job's UUID. Other runners + (DirectRunner, etc.) don't have this attribute; we degrade silently + instead of raising so callers can keep an unrelated failure path + clean. + """ + job_id_attr = getattr(result, "job_id", None) + if job_id_attr is None: + return None + try: + if callable(job_id_attr): + value = job_id_attr() + else: + value = job_id_attr + except Exception: + return None + return str(value) if value else None + + +def _build_dataflow_console_url( + project: Optional[str], region: Optional[str], job_id: Optional[str] +) -> Optional[str]: + """Compose the Cloud Console URL for a Dataflow job. + + Returns ``None`` if any of project / region / job_id is missing, + rather than producing a malformed URL. + """ + if not project or not region or not job_id: + return None + return ( + f"https://console.cloud.google.com/dataflow/jobs/{region}/{job_id}" + f"?project={project}" + ) + + +def _resolve_projection( + bq_table: str, + explicit: list[str], + excluded: set[str], + bq_utils: BqUtils, + extra_columns: Optional[list[str]] = None, +) -> tuple[ProjectionResult, Optional[str]]: + """Build the projection for one table, honoring an explicit override. + + If ``explicit`` is non-empty, the schema is still fetched but only + those columns are considered (minus ``excluded``). Explicit names not + present in the schema are logged and dropped rather than raising. + Otherwise every non-excluded column is routed through + :func:`build_projection`. + + ``extra_columns`` are appended to the resulting projection unconditionally + if they exist in the schema (e.g. label / split columns the analyzer + needs available for TFDV slicing even when the user's explicit + ``feature_columns`` doesn't list them). Extras already present in the + base projection are skipped to avoid duplicate SELECT entries; extras + missing from the schema are warned about and dropped. + + Returns ``(projection_result, error_message_or_none)``. A non-None + second element means the schema fetch failed; the caller should + surface that as a structured error instead of just silently skipping + the table. + """ + try: + schema = bq_utils.fetch_bq_table_schema(bq_table) + except Exception as exc: + message = f"Schema fetch failed for {bq_table}: {exc}" + logger.warning(message) + return ProjectionResult(projection=[], embedding_columns=[]), message + + if explicit: + unknown = [c for c in explicit if c not in schema] + if unknown: + logger.warning( + f"{bq_table}: explicit feature_columns {unknown} not in " + f"schema; ignoring." + ) + filtered_schema = { + name: field + for name, field in schema.items() + if name in explicit and name not in excluded + } + base = build_projection(filtered_schema, excluded=set()) + else: + base = build_projection(schema, excluded=excluded) + + if extra_columns: + existing_names = {name for name, _ in base.projection} + extras_schema = {} + for column in extra_columns: + if column in existing_names: + continue + if column not in schema: + logger.warning( + f"{bq_table}: extra projection column {column!r} not in " + f"schema; ignoring." + ) + continue + extras_schema[column] = schema[column] + if extras_schema: + extras_projection = build_projection(extras_schema, excluded=set()) + base = ProjectionResult( + projection=list(base.projection) + list(extras_projection.projection), + embedding_columns=list(base.embedding_columns), + ) + + return base, None + + +@dataclass(frozen=True) +class _ProfileTask: + """One profiling unit: all columns of a single node or edge table. + + ``kind`` is ``"node"`` or ``"edge"`` (singular) and is used to build + the GCS output path and the result key (``"node:user"``, etc.). + + ``shard_key`` is the column the BQ read fans out on (hash-mod-N) to + avoid the single-giant-export pattern that hangs ``SplitWithSizing`` + on very large tables. Sourced from ``NodeTableSpec.id_column`` for + node tables and ``EdgeTableSpec.src_id_column`` for edge tables — + both are guaranteed present and uniformly distributed enough for a + FARM_FINGERPRINT-based mod split. + + ``slice_columns`` lists columns whose distinct values should each + produce a slice of the TFDV stats. The values come from + ``NodeTableSpec.label_column`` / ``NodeTableSpec.split_column`` — + when set, the profiler routes them through ``slicing_util`` so the + resulting TFDV stats include per-slice ``DatasetFeatureStatistics`` + entries (per-class label histograms, per-class feature null-rate, + per-split distributions). Empty for edge tables and for node tables + that don't activate NC supervision. + """ + + kind: str + type_name: str + bq_table: str + projection: list[tuple[str, str]] + embedding_columns: list[str] + shard_key: str + slice_columns: list[str] = field(default_factory=list) + chunk_index: int = 0 + total_chunks: int = 1 + + @property + def result_key(self) -> str: + return f"{self.kind}:{self.type_name}" + + @property + def artifact_subdir(self) -> str: + """Empty for single-chunk tables; ``chunk_NN/`` for multi-chunk tables. + + Multi-chunk tables write each chunk's Facets HTML + stats TFRecord + under their own ``chunk_NN/`` subdir to avoid collisions; single-chunk + tables keep the historical flat layout for backward-compatible URLs. + """ + if self.total_chunks <= 1: + return "" + return f"chunk_{self.chunk_index:02d}/" + + +class FeatureProfiler: + """Runs TFDV feature profiling + embedding diagnostics on BQ tables via Dataflow. + + Example: + >>> profiler = FeatureProfiler() + >>> result = profiler.profile(config, resource_config=config) + >>> result.facets_html_paths["node:user"] + 'gs://bucket/analyzer/feature_profiler/nodes/user/facets.html' + """ + + def profile( + self, + config: DataAnalyzerConfig, + resource_config: GiglResourceConfigWrapper, + job_name_prefix: str, + run_timestamp: str, + custom_worker_image_uri: Optional[str] = None, + ) -> FeatureProfileResult: + """Run TFDV profiling + embedding diagnostics for every table in the config. + + For each table, the BQ projection is built via + :func:`_resolve_projection` (explicit ``feature_columns`` narrow the + schema; otherwise every non-excluded column is considered). + Embedding columns (REPEATED FLOAT families) expand into hygiene + companions in the projection and trigger a post-Dataflow structural + diagnostics pass. + + Tables whose final projection is empty are skipped with a warning. + Per-table Dataflow failures are logged and omitted. The embedding + diagnostics pass and JSON sidecar write are best-effort. + + Args: + config: Analyzer configuration with node and edge table specs. + resource_config: Resource config; its ``.project`` is used for + BigQuery schema lookups and diagnostics queries. + job_name_prefix: User-supplied prefix mixed into every per-table + Dataflow job name to disambiguate concurrent / repeat runs. + run_timestamp: Per-run timestamp string mixed into every per-table + Dataflow job name. Computed once at the entry point so all + jobs from one analyzer invocation share the same value. + custom_worker_image_uri: Optional Docker image URI for the + Dataflow worker harness. When ``None``, falls back to + ``DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU``. + + Returns: + :class:`FeatureProfileResult` with GCS paths keyed by + ``"node:{type}"`` / ``"edge:{type}"`` plus any embedding + diagnostics that succeeded. Empty facets / stats paths indicate + a skipped or failed table. + """ + bq_utils = BqUtils(project=resource_config.project) + tasks, collection_errors = _collect_profile_tasks(config, bq_utils) + result = FeatureProfileResult() + result.errors.extend(collection_errors) + if not tasks: + logger.info("No tables have profileable columns; returning empty result.") + self._maybe_write_sidecar(result, config.output_gcs_path) + return result + + logger.info(f"Launching {len(tasks)} Dataflow feature-profile job(s).") + with ThreadPoolExecutor(max_workers=_PARALLEL_DATAFLOW_WORKERS) as executor: + future_to_task = { + executor.submit( + self._run_single_pipeline, + task, + config.output_gcs_path, + resource_config, + job_name_prefix, + run_timestamp, + custom_worker_image_uri, + ): task + for task in tasks + } + for future in as_completed(future_to_task): + task = future_to_task[future] + try: + facets_uri, stats_uri = future.result() + # ``setdefault`` keeps multi-chunk per-table aggregation safe + # under the unordered ``as_completed`` iteration: each chunk + # lands as a list entry under the table-level result_key. + result.facets_html_paths.setdefault(task.result_key, []).append( + facets_uri + ) + result.stats_paths.setdefault(task.result_key, []).append(stats_uri) + if task.slice_columns: + result.slice_columns_by_result_key[task.result_key] = list( + task.slice_columns + ) + except Exception as exc: + logger.exception( + f"Feature profiling failed for {task.result_key} " + f"(table={task.bq_table}): {exc}" + ) + result.errors.append( + FeatureProfileError( + result_key=task.result_key, + bq_table=task.bq_table, + stage="dataflow", + message=f"{type(exc).__name__}: {exc}", + job_id=getattr(exc, "_gigl_job_id", None), + job_name=getattr(exc, "_gigl_job_name", None), + console_url=getattr(exc, "_gigl_console_url", None), + ) + ) + + self._run_embedding_diagnostics(tasks, bq_utils, result) + self._maybe_write_sidecar(result, config.output_gcs_path) + return result + + def _run_single_pipeline( + self, + task: _ProfileTask, + output_gcs_path: str, + resource_config: GiglResourceConfigWrapper, + job_name_prefix: str, + run_timestamp: str, + custom_worker_image_uri: Optional[str] = None, + ) -> tuple[str, str]: + """Build, run, and block on a single table's Dataflow pipeline. + + Returns the ``(facets_uri, stats_uri)`` strings on success. + + Worker sizing (machine_type / num_workers / max_num_workers / + disk_size_gb / timeout) is read from + ``resource_config.preprocessor_config.node_preprocessor_config`` for + node tasks and ``.edge_preprocessor_config`` for edge tasks. The + analyzer reuses the preprocessor's Dataflow sizing on the same + kind of table rather than declaring its own block, mirroring the + pattern in + :func:`gigl.src.data_preprocessor.lib.transform.utils.transform_features`. + + Captures the Dataflow ``job_id`` / ``job_name`` / console URL on the + raised exception (as ``_gigl_*`` attributes) when the pipeline fails + on a Dataflow runner. The caller reads those off the exception and + promotes them into a :class:`FeatureProfileError` so the HTML report + can deep-link to the failed job's logs. Best-effort: a non-Dataflow + runner (e.g. DirectRunner in tests) yields ``None`` for job_id. + """ + # Single-chunk tables keep the historical flat layout + # (``.../{type}/facets.html``); multi-chunk tables write each chunk + # under its own ``chunk_NN/`` subdir so the stats / Facets per chunk + # don't collide. + base = ( + f"{output_gcs_path.rstrip('/')}/feature_profiler/" + f"{task.kind}s/{task.type_name}/{task.artifact_subdir}" + ).rstrip("/") + facets_uri = UriFactory.create_uri(f"{base}/facets.html") + stats_uri = UriFactory.create_uri(f"{base}/stats.tfrecord") + + if task.kind == "node": + dataflow_config = ( + resource_config.preprocessor_config.node_preprocessor_config + ) + elif task.kind == "edge": + dataflow_config = ( + resource_config.preprocessor_config.edge_preprocessor_config + ) + else: + raise ValueError( + f"Unexpected task.kind={task.kind!r}; expected 'node' or 'edge'." + ) + + # Append a chunk suffix to the Dataflow job-name only when the table + # is actually being chunked, to keep single-chunk job names stable + # and within Dataflow's 63-char job-name budget for the common case. + chunk_suffix = ( + f"-chunk-{task.chunk_index:02d}-of-{task.total_chunks:02d}" + if task.total_chunks > 1 + else "" + ) + options = init_beam_pipeline_options( + applied_task_identifier=_APPLIED_TASK_IDENTIFIER, + job_name_suffix=( + f"{job_name_prefix}-{run_timestamp}-profile-" + f"{task.kind}-{task.type_name}{chunk_suffix}" + ), + component=GiGLComponents.DataAnalyzer, + custom_worker_image_uri=custom_worker_image_uri, + timeout_seconds=dataflow_config.timeout + if dataflow_config.timeout + else None, + num_workers=dataflow_config.num_workers, + max_num_workers=dataflow_config.max_num_workers, + machine_type=dataflow_config.machine_type, + disk_size_gb=dataflow_config.disk_size_gb, + ) + gcp_opts = options.view_as(GoogleCloudOptions) + job_name = gcp_opts.job_name + project = gcp_opts.project + region = gcp_opts.region + + stats_options = _build_slice_stats_options(task.slice_columns) + + # Shard the BQ read on the natural per-table key (id_column for nodes, + # src_id_column for edges). Mirrors the data_preprocessor's + # ShardedExportRead pattern; without it, a single giant ReadFromBigQuery + # on a large user/edge table hangs Dataflow's SplitWithSizing on + # oversized GCS Avro reads. ``num_shards`` defaults to 20 inside the + # config dataclass (matches the preprocessor default). + sharded_read_config = BigQueryShardedReadConfig( + shard_key=task.shard_key, + project_id=resource_config.project, + temp_dataset_name=resource_config.temp_assets_bq_dataset_name, + ) + + pipeline = beam.Pipeline(options=options) + _ = ( + pipeline + | f"Read {task.result_key} from BQ" + >> BqTableToRecordBatch( + bq_table=task.bq_table, + projection=task.projection, + sharded_read_config=sharded_read_config, + ) + | f"Generate TFDV stats for {task.result_key}" + >> GenerateAndVisualizeStats( + facets_report_uri=facets_uri, + stats_output_uri=stats_uri, + stats_options=stats_options, + ) + ) + result = pipeline.run() + try: + result.wait_until_finish() + except Exception as exc: + job_id = _safe_dataflow_job_id(result) + console_url = _build_dataflow_console_url( + project=project, region=region, job_id=job_id + ) + exc._gigl_job_id = job_id # type: ignore[attr-defined] + exc._gigl_job_name = job_name # type: ignore[attr-defined] + exc._gigl_console_url = console_url # type: ignore[attr-defined] + raise + logger.info(f"Finished feature profiling for {task.result_key}.") + return facets_uri.uri, stats_uri.uri + + def _run_embedding_diagnostics( + self, + tasks: list[_ProfileTask], + bq_utils: BqUtils, + result: FeatureProfileResult, + ) -> None: + """Run structural diagnostics for every task with embedding columns. + + Best-effort: any exception is caught so the sidecar write and the + already-produced TFDV artifacts remain valuable. + + Multi-chunk tables emit multiple ``_ProfileTask``s with the same + ``result_key`` and ``embedding_columns`` (table-level). We dedupe + per ``result_key`` so the embedding-diagnostics BQ aggregate runs + once per table, not once per chunk. + """ + deduped: dict[str, EmbeddingDiagnosticsRequest] = {} + for task in tasks: + if not task.embedding_columns: + continue + existing = deduped.get(task.result_key) + if existing is None: + deduped[task.result_key] = EmbeddingDiagnosticsRequest( + result_key=task.result_key, + bq_table=task.bq_table, + embedding_columns=list(task.embedding_columns), + ) + continue + # Same result_key seen on a previous chunk — union the embedding + # columns to be safe against any chunk that happens to carry a + # narrower embedding subset (chunks share table-level + # embedding_columns today, but defensive). + seen = set(existing.embedding_columns) + extra = [c for c in task.embedding_columns if c not in seen] + if extra: + deduped[task.result_key] = EmbeddingDiagnosticsRequest( + result_key=existing.result_key, + bq_table=existing.bq_table, + embedding_columns=existing.embedding_columns + extra, + ) + requests = list(deduped.values()) + if not requests: + return + try: + diagnostics = EmbeddingDiagnostics(bq_utils=bq_utils).analyze(requests) + except Exception as exc: + logger.exception(f"Embedding diagnostics pass failed: {exc}") + message = f"{type(exc).__name__}: {exc}" + for request in requests: + result.errors.append( + FeatureProfileError( + result_key=request.result_key, + bq_table=request.bq_table, + stage="embedding_diagnostics", + message=message, + ) + ) + return + for result_key, per_column in diagnostics.items(): + result.embedding_diagnostics[result_key] = per_column + + def _maybe_write_sidecar( + self, result: FeatureProfileResult, output_gcs_path: str + ) -> None: + """Best-effort write of the Pydantic JSON sidecar.""" + try: + write_artifact( + result=result, + component="feature_profile", + output_gcs_path=output_gcs_path, + ) + except Exception as exc: + logger.exception(f"Failed to write feature_profile.json sidecar: {exc}") + + +def _build_slice_stats_options( + slice_columns: list[str], +) -> Optional[tfdv.StatsOptions]: + """Build a ``tfdv.StatsOptions`` configured to slice on the given columns. + + Returns ``None`` when no slice columns are requested so callers can + cheaply pass through to TFDV's defaults. Each entry produces a + standard "feature value slicer" that emits one slice per distinct + value of the column. The unsliced ("Overall") stats are always + emitted by TFDV in addition to the per-slice stats, so existing + consumers continue to see the same top-level stats they did before + slicing was enabled. + """ + if not slice_columns: + return None + slice_functions = [ + slicing_util.get_feature_value_slicer({column: None}) + for column in slice_columns + ] + return tfdv.StatsOptions(slice_functions=slice_functions) + + +def _chunk_projection( + projection: list[tuple[str, str]], + max_features: int, + forced_columns: set[str], +) -> list[list[tuple[str, str]]]: + """Slice a projection into ``ceil(len/max_features)`` ≤``max_features``-sized chunks. + + Beam 2.56's runner-v2 cannot reliably iterate the per-key state TFDV's + ``CombinePerKey(PreCombineFn)`` accumulates over very wide projections + (work items time out on ``Instruction id ... was not registered``). + Splitting the projection across multiple Dataflow pipelines keeps + every per-key partition small enough for the runner to iterate. + + ``forced_columns`` (typically slice columns: ``label_column`` / + ``split_column``) are present in **every** chunk so TFDV slicing + applies uniformly across chunks. Each chunk's effective non-forced + budget is ``max_features - len(forced_pairs)`` (clamped to ≥1). + + Args: + projection: ``(column_name, sql_expression)`` pairs from + :func:`_resolve_projection`. Slice columns are already in here + (via that function's ``extra_columns``). + max_features: Target per-chunk column cap. The actual chunk size + is ``max_features`` for non-forced columns plus the forced + columns appended. + forced_columns: Names that must appear in every chunk. + + Returns: + Non-empty list of chunks. Empty input returns ``[]``. + """ + forced_pairs = [(n, e) for n, e in projection if n in forced_columns] + rest = [(n, e) for n, e in projection if n not in forced_columns] + if not rest: + return [list(forced_pairs)] if forced_pairs else [] + budget_per_chunk = max(1, max_features - len(forced_pairs)) + chunks: list[list[tuple[str, str]]] = [] + for start in range(0, len(rest), budget_per_chunk): + chunks.append(list(forced_pairs) + rest[start : start + budget_per_chunk]) + return chunks + + +def _collect_profile_tasks( + config: DataAnalyzerConfig, bq_utils: BqUtils +) -> tuple[list[_ProfileTask], list[FeatureProfileError]]: + """Flatten the analyzer config into one ``_ProfileTask`` per table. + + Resolves the projection for each node/edge spec by either restricting + to explicit ``feature_columns`` or auto-inferring from the BQ table + schema (excluding structural join keys). Tables whose resolved + projection is empty (e.g. only ID columns, or the schema fetch failed) + are logged, recorded as a structured ``FeatureProfileError`` so the + HTML report can surface them, and skipped. + """ + tasks: list[_ProfileTask] = [] + errors: list[FeatureProfileError] = [] + for node_table in config.node_tables: + result_key = f"node:{node_table.node_type}" + # Slice columns must be in the projection so TFDV can read them. + # ``compute_per_class_feature_stats`` opts out of the label slice + # without forcing the user to drop ``label_column`` itself (the + # graph_structure_analyzer NC tier still needs the column there). + slice_columns: list[str] = [] + if ( + node_table.label_column is not None + and config.compute_per_class_feature_stats + ): + slice_columns.append(node_table.label_column) + if node_table.split_column is not None: + slice_columns.append(node_table.split_column) + + projection, schema_error = _resolve_projection( + bq_table=node_table.bq_table, + explicit=node_table.feature_columns, + excluded={node_table.id_column}, + bq_utils=bq_utils, + extra_columns=slice_columns, + ) + if schema_error is not None: + errors.append( + FeatureProfileError( + result_key=result_key, + bq_table=node_table.bq_table, + stage="schema_fetch", + message=schema_error, + ) + ) + continue + if not projection.projection: + message = ( + f"No profileable columns after projection " + f"(id_column={node_table.id_column!r}, " + f"explicit feature_columns={node_table.feature_columns})." + ) + logger.warning(f"Skipping {result_key}: {message}") + errors.append( + FeatureProfileError( + result_key=result_key, + bq_table=node_table.bq_table, + stage="empty_projection", + message=message, + ) + ) + continue + # Slice columns that didn't make it into the projection (missing + # from schema) are dropped; ``_resolve_projection`` already logged. + projected_names = {name for name, _ in projection.projection} + active_slice_columns = [ + column for column in slice_columns if column in projected_names + ] + chunks = _chunk_projection( + projection.projection, + max_features=config.max_features_per_chunk, + forced_columns=set(active_slice_columns), + ) + for chunk_index, chunk_projection in enumerate(chunks): + tasks.append( + _ProfileTask( + kind="node", + type_name=node_table.node_type, + bq_table=node_table.bq_table, + projection=chunk_projection, + embedding_columns=projection.embedding_columns, + shard_key=node_table.id_column, + slice_columns=active_slice_columns, + chunk_index=chunk_index, + total_chunks=len(chunks), + ) + ) + for edge_table in config.edge_tables: + result_key = f"edge:{edge_table.edge_type}" + projection, schema_error = _resolve_projection( + bq_table=edge_table.bq_table, + explicit=edge_table.feature_columns, + excluded={ + edge_table.src_id_column, + edge_table.dst_id_column, + }, + bq_utils=bq_utils, + ) + if schema_error is not None: + errors.append( + FeatureProfileError( + result_key=result_key, + bq_table=edge_table.bq_table, + stage="schema_fetch", + message=schema_error, + ) + ) + continue + if not projection.projection: + message = ( + f"No profileable columns after projection " + f"(src_id_column={edge_table.src_id_column!r}, " + f"dst_id_column={edge_table.dst_id_column!r}, " + f"explicit feature_columns={edge_table.feature_columns})." + ) + logger.warning(f"Skipping {result_key}: {message}") + errors.append( + FeatureProfileError( + result_key=result_key, + bq_table=edge_table.bq_table, + stage="empty_projection", + message=message, + ) + ) + continue + chunks = _chunk_projection( + projection.projection, + max_features=config.max_features_per_chunk, + forced_columns=set(), + ) + for chunk_index, chunk_projection in enumerate(chunks): + tasks.append( + _ProfileTask( + kind="edge", + type_name=edge_table.edge_type, + bq_table=edge_table.bq_table, + projection=chunk_projection, + embedding_columns=projection.embedding_columns, + shard_key=edge_table.src_id_column, + chunk_index=chunk_index, + total_chunks=len(chunks), + ) + ) + return tasks, errors diff --git a/gigl/analytics/data_analyzer/graph_structure_analyzer.py b/gigl/analytics/data_analyzer/graph_structure_analyzer.py new file mode 100644 index 000000000..48933f14f --- /dev/null +++ b/gigl/analytics/data_analyzer/graph_structure_analyzer.py @@ -0,0 +1,1169 @@ +"""GraphStructureAnalyzer: 4-tier BigQuery-based graph data quality checks. + +Tier 1 (hard fails) + dangling edges, referential integrity, duplicate nodes. Any violation + raises DataQualityError with a partially populated GraphAnalysisResult. + +Tier 2 (core metrics) + node/edge counts, degree distribution, top-K hubs, INT16 clamp hazards, + isolated/cold-start nodes, duplicate edges, self-loops, NULL rates, and + two Python-side computations (feature memory budget, neighbor explosion). + +Tier 3 (label and heterogeneous) + class imbalance and label coverage (auto-enabled when node_tables have a + label_column); edge-type distribution and per-edge-type node coverage + (auto-enabled when more than one edge table is declared). + +Tier 4 (opt-in) + reciprocity, power-law exponent estimate. Gated by config flags. +""" + +import math +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +from gigl.analytics.data_analyzer.config import ( + EDGE_ROLE_MESSAGE_PASSING, + EDGE_ROLE_SUPERVISION_NEG, + EDGE_ROLE_SUPERVISION_POS, + DataAnalyzerConfig, + EdgeTableSpec, + NodeTableSpec, +) +from gigl.analytics.data_analyzer.queries import ( + CLASS_IMBALANCE_QUERY, + COLD_START_NODE_COUNT_QUERY, + CROSS_SPLIT_OVERLAP_QUERY, + DANGLING_EDGES_QUERY, + DEGREE_BUCKET_QUERY, + DEGREE_DISTRIBUTION_QUERY, + DUPLICATE_EDGE_COUNT_QUERY, + DUPLICATE_NODE_COUNT_QUERY, + EDGE_COUNT_QUERY, + EDGE_REFERENTIAL_INTEGRITY_QUERY, + EDGE_TYPE_DISTRIBUTION_QUERY, + EDGE_TYPE_NODE_COVERAGE_QUERY, + ISOLATED_NODE_COUNT_QUERY, + LABEL_COVERAGE_QUERY, + NODE_COUNT_QUERY, + SELF_LOOP_COUNT_QUERY, + SPLIT_VALUE_COUNTS_QUERY, + SUPER_HUB_INT16_CLAMP_QUERY, + SUPERVISION_CROSS_TABLE_QUERY, + TOP_K_HUBS_QUERY, + build_adjusted_homophily_query, + build_label_sentinel_query, + build_null_rates_query, + build_per_class_degree_query, +) +from gigl.analytics.data_analyzer.types import ( + CrossSplitOverlap, + DegreeStats, + GraphAnalysisResult, + HomophilyStats, + LabelSentinelStats, + NodeClassificationSupervisionStats, + PerClassDegreeStats, + SupervisionCrossTableStats, + write_artifact, +) +from gigl.common.logger import Logger +from gigl.src.common.utils.bq import BqUtils + +logger = Logger() + +# Default assumption for feature memory budget: float64 per feature column. +_BYTES_PER_FEATURE = 8 +_TOP_K_HUBS = 20 +_PARALLEL_BQ_WORKERS = 10 + + +class DataQualityError(Exception): + """Raised when Tier 1 hard-fail checks detect data quality violations. + + Carries a partially populated GraphAnalysisResult so callers can inspect + which specific checks failed without re-running the analyzer. + """ + + def __init__(self, message: str, partial_result: GraphAnalysisResult) -> None: + super().__init__(message) + self.partial_result = partial_result + + +class GraphStructureAnalyzer: + """Runs BigQuery SQL checks across 4 tiers against the tables declared in a config. + + Example: + >>> config = load_analyzer_config("gs://bucket/config.yaml") + >>> analyzer = GraphStructureAnalyzer() + >>> result = analyzer.analyze(config) + >>> result.node_counts["user"] + 1000000 + + Tier 1 is blocking: a violation raises DataQualityError before Tiers 2-4 run. + Tiers 2-4 are aggregated best-effort into a single GraphAnalysisResult. + """ + + def __init__(self, bq_project: Optional[str] = None) -> None: + self._bq_utils = BqUtils(project=bq_project) + self._query_log: dict[str, list[str]] = {} + + def analyze(self, config: DataAnalyzerConfig) -> GraphAnalysisResult: + """Run all applicable tiers and return aggregated results. + + Always writes a versioned JSON sidecar to + ``{config.output_gcs_path}/graph_structure.json`` before returning + (or re-raising), so partial Tier 1 failures are recoverable by + downstream consumers without rerunning the analyzer. + + Args: + config: Data analyzer configuration declaring node and edge tables + plus any opt-in expensive checks (reciprocity, etc.). + + Returns: + GraphAnalysisResult with tier 1-4 fields populated per config. + + Raises: + DataQualityError: If tier 1 checks find any violations. The + exception carries a partial result with the specific counts; + that same partial result is persisted to the sidecar. + """ + self._query_log = {} + result = GraphAnalysisResult() + try: + logger.info("Starting graph structure analysis (Tier 1: hard fails)") + self._run_tier1(config, result) + + logger.info("Tier 1 passed. Running Tier 2 (core metrics)") + self._run_tier2(config, result) + + logger.info("Running Tier 3 (label / heterogeneous)") + self._run_tier3(config, result) + + logger.info("Running node-classification supervision tier") + self._run_node_classification_supervision(config, result) + + logger.info("Running supervision cross-table analysis") + self._run_supervision_cross_table(config, result) + + logger.info("Running Tier 4 (opt-in)") + self._run_tier4(config, result) + except DataQualityError as err: + err.partial_result.queries = dict(self._query_log) + self._maybe_write_sidecar(err.partial_result, config.output_gcs_path) + raise + result.queries = dict(self._query_log) + self._maybe_write_sidecar(result, config.output_gcs_path) + return result + + def _maybe_write_sidecar( + self, result: GraphAnalysisResult, output_gcs_path: str + ) -> None: + """Best-effort write of the Pydantic JSON sidecar. + + Never raises: the sidecar is a convenience artifact, not a + correctness contract. Failures are logged and swallowed so Tier 1 + errors (which also trigger a sidecar write) propagate intact. + """ + try: + write_artifact( + result=result, + component="graph_structure", + output_gcs_path=output_gcs_path, + ) + except Exception as exc: + logger.exception(f"Failed to write graph_structure.json sidecar: {exc}") + + # ------------------------------------------------------------------ # + # Tier 1: hard fails # + # ------------------------------------------------------------------ # + + def _run_tier1( + self, config: DataAnalyzerConfig, result: GraphAnalysisResult + ) -> None: + """Run all tier 1 checks; raise DataQualityError on any violation.""" + violations: list[str] = [] + node_tables_by_type = {nt.node_type: nt for nt in config.node_tables} + + # Duplicate nodes (per node table). + for node_table in config.node_tables: + query = DUPLICATE_NODE_COUNT_QUERY.format( + table=node_table.bq_table, id_column=node_table.id_column + ) + count = self._query_scalar( + query, + "duplicate_count", + block_id=f"data_quality:duplicate_nodes:{node_table.node_type}", + ) + result.duplicate_node_counts[node_table.node_type] = count + if count > 0: + violations.append( + f"node_type={node_table.node_type} has {count} duplicate IDs" + ) + + # Dangling edges and referential integrity (per edge table). + for edge_table in config.edge_tables: + dangling_query = DANGLING_EDGES_QUERY.format( + table=edge_table.bq_table, + src_id_column=edge_table.src_id_column, + dst_id_column=edge_table.dst_id_column, + ) + dangling = self._query_scalar( + dangling_query, + "dangling_count", + block_id=f"data_quality:dangling_edges:{edge_table.edge_type}", + ) + result.dangling_edge_counts[edge_table.edge_type] = dangling + if dangling > 0: + violations.append( + f"edge_type={edge_table.edge_type} has {dangling} dangling edges" + ) + + # Referential integrity: src and dst can resolve to different node + # tables on heterogeneous graphs. `load_analyzer_config` guarantees + # src_node_type / dst_node_type are populated and known. + if not config.node_tables: + continue + assert edge_table.src_node_type is not None, ( + f"edge_type={edge_table.edge_type} has no src_node_type; " + "load the config via load_analyzer_config to backfill it." + ) + assert edge_table.dst_node_type is not None, ( + f"edge_type={edge_table.edge_type} has no dst_node_type; " + "load the config via load_analyzer_config to backfill it." + ) + src_node_table = node_tables_by_type[edge_table.src_node_type] + dst_node_table = node_tables_by_type[edge_table.dst_node_type] + ref_query = EDGE_REFERENTIAL_INTEGRITY_QUERY.format( + edge_table=edge_table.bq_table, + src_node_table=src_node_table.bq_table, + dst_node_table=dst_node_table.bq_table, + src_id_column=edge_table.src_id_column, + dst_id_column=edge_table.dst_id_column, + src_node_id_column=src_node_table.id_column, + dst_node_id_column=dst_node_table.id_column, + ) + self._record_query( + f"data_quality:referential_integrity:{edge_table.edge_type}", + ref_query, + ) + rows = list(self._bq_utils.run_query(query=ref_query, labels={})) + if len(rows) != 1: + raise RuntimeError( + f"Referential integrity query expected exactly 1 row; " + f"got {len(rows)}. Query: {ref_query.strip()[:200]}" + ) + missing_src = int(rows[0]["missing_src_count"] or 0) + missing_dst = int(rows[0]["missing_dst_count"] or 0) + total_missing = missing_src + missing_dst + result.referential_integrity_violations[ + edge_table.edge_type + ] = total_missing + if total_missing > 0: + violations.append( + f"edge_type={edge_table.edge_type} has {total_missing} " + "referential integrity violations" + ) + + if violations: + msg = "Tier 1 data quality violations detected:\n - " + "\n - ".join( + violations + ) + logger.error(msg) + raise DataQualityError(msg, partial_result=result) + + # ------------------------------------------------------------------ # + # Tier 2: core metrics # + # ------------------------------------------------------------------ # + + def _run_tier2( + self, config: DataAnalyzerConfig, result: GraphAnalysisResult + ) -> None: + """Collect core structural metrics, fanning out BQ jobs in parallel. + + Edge-level metrics are computed from the src-side perspective: + isolated/cold-start joins pair each edge with its src_node_type's + table. Hetero dst-perspective coverage is exposed separately via + Tier 3 edge_type_node_coverage. + + BQ jobs are I/O-bound so ThreadPoolExecutor is used. Each worker + writes to distinct keys of the shared `result` dict (one key per + node_type / edge_type), so no lock is required under CPython's GIL. + """ + node_tables_by_type = {nt.node_type: nt for nt in config.node_tables} + + with ThreadPoolExecutor(max_workers=_PARALLEL_BQ_WORKERS) as executor: + futures = [] + for node_table in config.node_tables: + futures.append( + executor.submit(self._tier2_node_metrics, node_table, result) + ) + for edge_table in config.edge_tables: + src_node_table = node_tables_by_type.get(edge_table.src_node_type or "") + futures.append( + executor.submit( + self._tier2_edge_metrics, edge_table, src_node_table, result + ) + ) + for future in futures: + future.result() # re-raise any exception + + # Python-side computations run after all BQ data is collected. + self._compute_feature_memory_budget(config, result) + self._compute_neighbor_explosion_estimate(config, result) + + def _tier2_node_metrics( + self, node_table: NodeTableSpec, result: GraphAnalysisResult + ) -> None: + node_count_query = NODE_COUNT_QUERY.format(table=node_table.bq_table) + node_count = self._query_scalar( + node_count_query, + "node_count", + block_id=f"graph_structure:node_count:{node_table.node_type}", + ) + result.node_counts[node_table.node_type] = node_count + + columns_to_check: list[str] = [node_table.id_column] + columns_to_check.extend(node_table.feature_columns) + if node_table.label_column: + columns_to_check.append(node_table.label_column) + + null_query = build_null_rates_query( + table=node_table.bq_table, columns=columns_to_check + ) + self._record_query( + f"data_quality:null_rates:node:{node_table.node_type}", null_query + ) + rows = list(self._bq_utils.run_query(query=null_query, labels={})) + if rows: + row = rows[0] + rates: dict[str, float] = {} + for col in columns_to_check: + key = f"{col}_null_rate" + rate = row[key] + rates[col] = float(rate) if rate is not None else 0.0 + result.null_rates[node_table.node_type] = rates + + def _tier2_edge_metrics( + self, + edge_table: EdgeTableSpec, + node_table: Optional[NodeTableSpec], + result: GraphAnalysisResult, + ) -> None: + edge_type = edge_table.edge_type + + # Scalar counts. + edge_count_query = EDGE_COUNT_QUERY.format(table=edge_table.bq_table) + result.edge_counts[edge_type] = self._query_scalar( + edge_count_query, + "edge_count", + block_id=f"graph_structure:edge_count:{edge_type}", + ) + duplicate_edges_query = DUPLICATE_EDGE_COUNT_QUERY.format( + table=edge_table.bq_table, + src_id_column=edge_table.src_id_column, + dst_id_column=edge_table.dst_id_column, + ) + result.duplicate_edge_counts[edge_type] = self._query_scalar( + duplicate_edges_query, + "duplicate_count", + block_id=f"data_quality:duplicate_edges:{edge_type}", + ) + self_loop_query = SELF_LOOP_COUNT_QUERY.format( + table=edge_table.bq_table, + src_id_column=edge_table.src_id_column, + dst_id_column=edge_table.dst_id_column, + ) + result.self_loop_counts[edge_type] = self._query_scalar( + self_loop_query, + "self_loop_count", + block_id=f"graph_structure:self_loops:{edge_type}", + ) + + # Super-hub INT16 clamp check (indexed by src). + super_hub_query = SUPER_HUB_INT16_CLAMP_QUERY.format( + table=edge_table.bq_table, id_column=edge_table.src_id_column + ) + result.super_hub_int16_clamp_count[edge_type] = self._query_scalar( + super_hub_query, + "super_hub_count", + block_id=f"graph_structure:super_hub_clamp:{edge_type}", + ) + + # Isolated and cold-start require a node table join. + if node_table is not None: + isolated_query = ISOLATED_NODE_COUNT_QUERY.format( + node_table=node_table.bq_table, + edge_table=edge_table.bq_table, + node_id_column=node_table.id_column, + src_id_column=edge_table.src_id_column, + dst_id_column=edge_table.dst_id_column, + ) + result.isolated_node_counts[edge_type] = self._query_scalar( + isolated_query, + "isolated_count", + block_id=f"graph_structure:isolated_nodes:{edge_type}", + ) + cold_start_query = COLD_START_NODE_COUNT_QUERY.format( + node_table=node_table.bq_table, + edge_table=edge_table.bq_table, + node_id_column=node_table.id_column, + src_id_column=edge_table.src_id_column, + dst_id_column=edge_table.dst_id_column, + ) + result.cold_start_node_counts[edge_type] = self._query_scalar( + cold_start_query, + "cold_start_count", + block_id=f"graph_structure:cold_start_nodes:{edge_type}", + ) + + # Top-K hubs (by src). + top_hubs_query = TOP_K_HUBS_QUERY.format( + table=edge_table.bq_table, + id_column=edge_table.src_id_column, + k=_TOP_K_HUBS, + ) + self._record_query(f"graph_structure:top_hubs:{edge_type}", top_hubs_query) + top_hub_rows = list(self._bq_utils.run_query(query=top_hubs_query, labels={})) + result.top_hubs[edge_type] = [ + (str(row["node_id"]), int(row["degree"])) for row in top_hub_rows + ] + + # Degree statistics: distribution + buckets, in + out directions. + for direction, id_column in ( + ("out", edge_table.src_id_column), + ("in", edge_table.dst_id_column), + ): + degree_key = f"{edge_type}_{direction}" + result.degree_stats[degree_key] = self._build_degree_stats( + table=edge_table.bq_table, + id_column=id_column, + block_id=f"graph_structure:degree:{degree_key}", + ) + + def _build_degree_stats( + self, table: str, id_column: str, *, block_id: Optional[str] = None + ) -> DegreeStats: + """Run degree distribution + bucket queries and pack into DegreeStats. + + When ``block_id`` is provided both rendered SQL strings are recorded + under that key (in distribution-then-bucket order) so the report can + show the full pair behind the histogram + summary line. + """ + dist_query = DEGREE_DISTRIBUTION_QUERY.format(table=table, id_column=id_column) + bucket_query = DEGREE_BUCKET_QUERY.format(table=table, id_column=id_column) + if block_id is not None: + self._record_query(block_id, dist_query) + self._record_query(block_id, bucket_query) + dist_rows = list(self._bq_utils.run_query(query=dist_query, labels={})) + bucket_rows = list(self._bq_utils.run_query(query=bucket_query, labels={})) + dist_row = dist_rows[0] + bucket_row = bucket_rows[0] + + percentiles_raw = list(dist_row["percentiles"]) + percentiles = [int(p) if p is not None else 0 for p in percentiles_raw] + # APPROX_QUANTILES(degree, 100) returns 101 values: index 0..100. + median = percentiles[50] if len(percentiles) > 50 else 0 + p90 = percentiles[90] if len(percentiles) > 90 else percentiles[-1] + p99 = percentiles[99] if len(percentiles) > 99 else percentiles[-1] + # We only have 100-bucket quantiles, so p999 ~= p99 as best-effort. + p999 = p99 + + # Bucket keys must match BUCKET_ORDER in report/charts.ai.js for the + # histogram to render correctly; keep uppercase K. + buckets: dict[str, int] = { + "0-1": int(bucket_row["bucket_0_1"]), + "2-10": int(bucket_row["bucket_2_10"]), + "11-100": int(bucket_row["bucket_11_100"]), + "101-1K": int(bucket_row["bucket_101_1k"]), + "1K-10K": int(bucket_row["bucket_1k_10k"]), + "10K+": int(bucket_row["bucket_10k_plus"]), + } + + return DegreeStats( + min=int(dist_row["min_degree"] or 0), + max=int(dist_row["max_degree"] or 0), + mean=float(dist_row["avg_degree"] or 0.0), + median=median, + p90=p90, + p99=p99, + p999=p999, + percentiles=percentiles, + buckets=buckets, + ) + + # ------------------------------------------------------------------ # + # Tier 3: label and heterogeneous # + # ------------------------------------------------------------------ # + + def _run_tier3( + self, config: DataAnalyzerConfig, result: GraphAnalysisResult + ) -> None: + # Label-related checks per node table with a label column. + for node_table in config.node_tables: + if not node_table.label_column: + continue + class_imbalance_query = CLASS_IMBALANCE_QUERY.format( + table=node_table.bq_table, + label_column=node_table.label_column, + ) + self._record_query( + f"advanced:class_imbalance:{node_table.node_type}", + class_imbalance_query, + ) + class_rows = list( + self._bq_utils.run_query(query=class_imbalance_query, labels={}) + ) + result.class_imbalance[node_table.node_type] = { + str(row["label"]): int(row["count"]) for row in class_rows + } + + label_coverage_query = LABEL_COVERAGE_QUERY.format( + table=node_table.bq_table, + label_column=node_table.label_column, + ) + self._record_query( + f"advanced:label_coverage:{node_table.node_type}", + label_coverage_query, + ) + coverage_rows = list( + self._bq_utils.run_query(query=label_coverage_query, labels={}) + ) + if coverage_rows: + coverage = coverage_rows[0]["coverage"] + result.label_coverage[node_table.node_type] = ( + float(coverage) if coverage is not None else 0.0 + ) + + # Heterogeneous distribution only if more than one edge type. + if len(config.edge_tables) > 1: + for edge_table in config.edge_tables: + edge_type = edge_table.edge_type + # Edge-type distribution is effectively the edge count; reuse. + if edge_type in result.edge_counts: + result.edge_type_distribution[edge_type] = result.edge_counts[ + edge_type + ] + else: + edge_type_dist_query = EDGE_TYPE_DISTRIBUTION_QUERY.format( + table=edge_table.bq_table + ) + result.edge_type_distribution[edge_type] = self._query_scalar( + edge_type_dist_query, + "edge_count", + block_id=f"advanced:edge_type_distribution:{edge_type}", + ) + coverage_query = EDGE_TYPE_NODE_COVERAGE_QUERY.format( + table=edge_table.bq_table, + src_id_column=edge_table.src_id_column, + dst_id_column=edge_table.dst_id_column, + ) + self._record_query( + f"advanced:edge_type_node_coverage:{edge_type}", coverage_query + ) + coverage_rows = list( + self._bq_utils.run_query(query=coverage_query, labels={}) + ) + if coverage_rows: + row = coverage_rows[0] + result.edge_type_node_coverage[edge_type] = { + "distinct_src_count": int(row["distinct_src_count"] or 0), + "distinct_dst_count": int(row["distinct_dst_count"] or 0), + } + + # ------------------------------------------------------------------ # + # Node-classification supervision tier # + # ------------------------------------------------------------------ # + + def _run_node_classification_supervision( + self, config: DataAnalyzerConfig, result: GraphAnalysisResult + ) -> None: + """Run NC-supervision-tier checks for every labeled node table. + + Activates whenever a ``NodeTableSpec.label_column`` is set. + Computes the BQ-side metrics that aren't covered by the TFDV + slicing in the feature profiler: + + 1. Sentinel-vs-NULL accounting on the label column. + 2. Per-class degree distribution (joining labels to a + message-passing edge table). + 3. Adjusted homophily on a sampled message-passing edge set + (raw + class-prior-adjusted, per Platonov et al. 2023). + 4. Optional label informativeness when + ``config.compute_label_informativeness`` is True. + 5. Cross-split node-id leakage (hard fail) when + ``NodeTableSpec.split_column`` is set. + + Hard fails (cross-split id overlap) raise + :class:`DataQualityError` with a partially populated result, just + like Tier 1. + """ + message_passing_tables = [ + edge + for edge in config.edge_tables + if edge.role == EDGE_ROLE_MESSAGE_PASSING + ] + violations: list[str] = [] + + for node_table in config.node_tables: + if node_table.label_column is None: + continue + + sentinel_stats = self._compute_label_sentinel_stats(node_table) + per_class_degree, sentinel_degree_stats = self._compute_per_class_degree( + node_table, message_passing_tables + ) + homophily = self._compute_homophily_for_node_type( + node_table, message_passing_tables, config + ) + cross_split_overlap = self._compute_cross_split_overlap(node_table) + + stats = NodeClassificationSupervisionStats( + node_type=node_table.node_type, + label_column=node_table.label_column, + sentinel_stats=sentinel_stats, + per_class_degree=per_class_degree, + sentinel_degree_stats=sentinel_degree_stats, + homophily=homophily, + cross_split_overlap=cross_split_overlap, + ) + result.node_classification_supervision_stats.append(stats) + + if ( + cross_split_overlap is not None + and cross_split_overlap.overlap_node_count > 0 + ): + violations.append( + f"node_type={node_table.node_type}: " + f"{cross_split_overlap.overlap_node_count} node_ids appear " + f"in more than one split (column " + f"{node_table.split_column!r})" + ) + + if violations: + msg = ( + "Node-classification supervision violations detected:\n - " + + "\n - ".join(violations) + ) + logger.error(msg) + raise DataQualityError(msg, partial_result=result) + + def _compute_label_sentinel_stats( + self, node_table: NodeTableSpec + ) -> LabelSentinelStats: + """Single-pass query splitting label cells into NULL / sentinel / valid.""" + assert ( + node_table.label_column is not None + ), "_compute_label_sentinel_stats requires NodeTableSpec.label_column" + query = build_label_sentinel_query( + table=node_table.bq_table, + label_column=node_table.label_column, + sentinel_values=node_table.label_sentinel_values, + ) + self._record_query( + f"nc_supervision:label_sentinel:{node_table.node_type}", query + ) + rows = list(self._bq_utils.run_query(query=query, labels={})) + if len(rows) != 1: + raise RuntimeError( + f"Label sentinel query expected exactly 1 row; got {len(rows)}. " + f"node_type={node_table.node_type}" + ) + row = rows[0] + total_rows = int(row["total_rows"] or 0) + null_count = int(row["null_count"] or 0) + valid_count = int(row["valid_count"] or 0) + sentinel_counts: dict[str, int] = {} + for index, sentinel in enumerate(node_table.label_sentinel_values): + sentinel_counts[sentinel] = int(row[f"sentinel_{index}"] or 0) + coverage = (valid_count / total_rows) if total_rows > 0 else 0.0 + return LabelSentinelStats( + total_rows=total_rows, + null_count=null_count, + sentinel_counts=sentinel_counts, + valid_label_count=valid_count, + valid_label_coverage=coverage, + ) + + def _compute_per_class_degree( + self, + node_table: NodeTableSpec, + message_passing_tables: list[EdgeTableSpec], + ) -> tuple[list[PerClassDegreeStats], list[PerClassDegreeStats]]: + """Per-label-value degree distribution against a message-passing edge table. + + Only edge tables whose src or dst node_type matches the labeled + node_type are included. The edge-type identity is not preserved + on the result here because per-class degree is defined over total + degree (in + out) regardless of which edge table contributed it. + When multiple message-passing edge tables match, only the first + is used to keep the output flat — multi-edge-type per-class + degree is left for a future iteration. + + Returns a 2-tuple ``(per_class, sentinel)``: rows whose + ``class_value`` matches a declared sentinel in + ``node_table.label_sentinel_values`` are routed to ``sentinel``; + all other non-NULL label rows go to ``per_class``. + """ + matching = [ + edge_table + for edge_table in message_passing_tables + if node_table.node_type + in (edge_table.src_node_type, edge_table.dst_node_type) + ] + if not matching: + return [], [] + edge_table = matching[0] + if len(matching) > 1: + logger.info( + f"Per-class degree for node_type={node_table.node_type!r}: " + f"using first matching message-passing edge table " + f"{edge_table.edge_type!r} of {[m.edge_type for m in matching]}." + ) + + assert ( + node_table.label_column is not None + ), "_compute_per_class_degree requires NodeTableSpec.label_column" + query = build_per_class_degree_query( + node_table=node_table.bq_table, + node_id_column=node_table.id_column, + label_column=node_table.label_column, + edge_table=edge_table.bq_table, + edge_src_column=edge_table.src_id_column, + edge_dst_column=edge_table.dst_id_column, + ) + self._record_query( + f"nc_supervision:per_class_degree:{node_table.node_type}", query + ) + rows = list(self._bq_utils.run_query(query=query, labels={})) + sentinel_value_set = set(node_table.label_sentinel_values) + per_class: list[PerClassDegreeStats] = [] + sentinel: list[PerClassDegreeStats] = [] + for row in rows: + percentiles_raw = list(row["percentiles"]) if row["percentiles"] else [] + percentiles = [int(p) if p is not None else 0 for p in percentiles_raw] + median = percentiles[50] if len(percentiles) > 50 else 0 + p90 = ( + percentiles[90] + if len(percentiles) > 90 + else (percentiles[-1] if percentiles else 0) + ) + p99 = ( + percentiles[99] + if len(percentiles) > 99 + else (percentiles[-1] if percentiles else 0) + ) + # Bucket keys must match BUCKET_ORDER in report/charts.ai.js so the + # sparkline histogram lines up with the overall degree chart. + buckets: dict[str, int] = { + "0-1": int(row["bucket_0_1"] or 0), + "2-10": int(row["bucket_2_10"] or 0), + "11-100": int(row["bucket_11_100"] or 0), + "101-1K": int(row["bucket_101_1k"] or 0), + "1K-10K": int(row["bucket_1k_10k"] or 0), + "10K+": int(row["bucket_10k_plus"] or 0), + } + class_value = str(row["class_value"]) + stats = PerClassDegreeStats( + class_value=class_value, + count=int(row["class_count"] or 0), + cold_start_count=int(row["cold_start_count"] or 0), + mean_degree=float(row["mean_degree"] or 0.0), + median_degree=median, + p90_degree=p90, + p99_degree=p99, + max_degree=int(row["max_degree"] or 0), + buckets=buckets, + ) + if class_value in sentinel_value_set: + sentinel.append(stats) + else: + per_class.append(stats) + return per_class, sentinel + + def _compute_homophily_for_node_type( + self, + node_table: NodeTableSpec, + message_passing_tables: list[EdgeTableSpec], + config: DataAnalyzerConfig, + ) -> list[HomophilyStats]: + """Sampled adjusted homophily per (labeled node type, edge type). + + Edges are sampled to ``config.label_homophily_edge_sample_cap`` + via deterministic ``MOD(FARM_FINGERPRINT(...))`` filtering. The + modulus is computed from the edge table's row count so the + sampled set is ~= the cap; small graphs (count <= cap) skip + sampling entirely. + """ + out: list[HomophilyStats] = [] + for edge_table in message_passing_tables: + if node_table.node_type not in ( + edge_table.src_node_type, + edge_table.dst_node_type, + ): + continue + # Edge-count subquery here is unrelated to the per-edge-type one + # in Tier 2 — it gates only the sampling decision below — so we + # don't tag it for the report and just run it. + edge_count = self._query_scalar( + EDGE_COUNT_QUERY.format(table=edge_table.bq_table), "edge_count" + ) + cap = config.label_homophily_edge_sample_cap + if cap > 0 and edge_count > cap: + modulus = max(1, edge_count // cap) + sample_cap = cap + else: + modulus = 1 + sample_cap = 0 # signal "no sampling" + assert ( + node_table.label_column is not None + ), "_compute_homophily_for_node_type requires NodeTableSpec.label_column" + template = build_adjusted_homophily_query( + node_table=node_table.bq_table, + node_id_column=node_table.id_column, + label_column=node_table.label_column, + sentinel_values=node_table.label_sentinel_values, + edge_table=edge_table.bq_table, + edge_src_column=edge_table.src_id_column, + edge_dst_column=edge_table.dst_id_column, + sample_cap=sample_cap, + ) + query = template.replace("{modulus_placeholder}", str(modulus)) + self._record_query( + f"nc_supervision:homophily:{node_table.node_type}:" + f"{edge_table.edge_type}", + query, + ) + rows = list(self._bq_utils.run_query(query=query, labels={})) + if len(rows) != 1: + raise RuntimeError( + f"Adjusted-homophily query expected exactly 1 row; got " + f"{len(rows)}. node_type={node_table.node_type}, " + f"edge_type={edge_table.edge_type}" + ) + row = rows[0] + edge_homophily_value = row["edge_homophily"] + expected_value = row["expected_homophily"] + edge_homophily = ( + float(edge_homophily_value) if edge_homophily_value is not None else 0.0 + ) + expected = float(expected_value) if expected_value is not None else 0.0 + if expected < 1.0: + adjusted = (edge_homophily - expected) / (1.0 - expected) + else: + adjusted = 0.0 + out.append( + HomophilyStats( + edge_type=edge_table.edge_type, + edge_homophily=edge_homophily, + adjusted_homophily=adjusted, + edge_sample_count=int(row["edge_sample_count"] or 0), + label_informativeness=None, + ) + ) + return out + + def _compute_cross_split_overlap( + self, node_table: NodeTableSpec + ) -> Optional[CrossSplitOverlap]: + """Cross-split id leakage + per-split row counts. Returns None if no split_column.""" + if node_table.split_column is None: + return None + block_id = f"nc_supervision:cross_split:{node_table.node_type}" + cross_split_query = CROSS_SPLIT_OVERLAP_QUERY.format( + table=node_table.bq_table, + id_column=node_table.id_column, + split_column=node_table.split_column, + ) + overlap_count = self._query_scalar( + cross_split_query, "overlap_node_count", block_id=block_id + ) + split_value_query = SPLIT_VALUE_COUNTS_QUERY.format( + table=node_table.bq_table, + split_column=node_table.split_column, + ) + self._record_query(block_id, split_value_query) + split_rows = list(self._bq_utils.run_query(query=split_value_query, labels={})) + split_value_counts: dict[str, int] = { + str(row["split_value"]): int(row["row_count"] or 0) for row in split_rows + } + return CrossSplitOverlap( + overlap_node_count=overlap_count, + split_value_counts=split_value_counts, + ) + + # ------------------------------------------------------------------ # + # Supervision cross-table analysis # + # ------------------------------------------------------------------ # + + def _run_supervision_cross_table( + self, config: DataAnalyzerConfig, result: GraphAnalysisResult + ) -> None: + """Run cross-table per-anchor stats for supervision edge tables. + + For every ``supervision_pos`` table we pair it with each + ``supervision_neg`` and ``message_passing`` table that shares its + ``(src_node_type, dst_node_type)``, then compute per-anchor edge + counts and label-leakage overlap. Each ``supervision_neg`` table + also drives a pass against matching ``message_passing`` tables so + the report can flag (negative-edge ∩ message-passing) leaks. Jobs + run in parallel via ``ThreadPoolExecutor`` (BQ is I/O-bound). + """ + pos_tables = [ + e for e in config.edge_tables if e.role == EDGE_ROLE_SUPERVISION_POS + ] + neg_tables = [ + e for e in config.edge_tables if e.role == EDGE_ROLE_SUPERVISION_NEG + ] + # Treat unset role as message_passing (default), matching backfill behavior. + mp_tables = [ + e + for e in config.edge_tables + if e.role is None or e.role == EDGE_ROLE_MESSAGE_PASSING + ] + + jobs: list[tuple[EdgeTableSpec, EdgeTableSpec, str]] = [] + + # Driver = positive: pair with every neg / mp sharing (src_type, dst_type). + for pos in pos_tables: + assert pos.node_anchor is not None, ( + f"edge_type={pos.edge_type}: supervision_pos must have node_anchor; " + "load the config via load_analyzer_config to enforce this." + ) + for other in neg_tables + mp_tables: + if (pos.src_node_type, pos.dst_node_type) == ( + other.src_node_type, + other.dst_node_type, + ): + jobs.append((pos, other, pos.node_anchor)) + + # Driver = negative: pair with mp sharing (src_type, dst_type). Anchor + # is the negative's own node_anchor when set, else inherited from a + # matching positive table to keep configs concise. + for neg in neg_tables: + anchor = neg.node_anchor or self._inherit_anchor_from_pos(neg, pos_tables) + if anchor is None: + continue + for mp in mp_tables: + if (neg.src_node_type, neg.dst_node_type) == ( + mp.src_node_type, + mp.dst_node_type, + ): + jobs.append((neg, mp, anchor)) + + if not jobs: + return + + with ThreadPoolExecutor(max_workers=_PARALLEL_BQ_WORKERS) as executor: + futures = [ + executor.submit(self._supervision_pair_stats, driver, other, anchor) + for driver, other, anchor in jobs + ] + for future in futures: + stats = future.result() + if stats is not None: + result.supervision_cross_table_stats.append(stats) + + @staticmethod + def _inherit_anchor_from_pos( + neg: EdgeTableSpec, pos_tables: list[EdgeTableSpec] + ) -> Optional[str]: + """Return the node_anchor of any positive table sharing neg's node types. + + Lets users declare ``node_anchor`` once on the positive table and + skip duplicating it on the matching negative. + """ + for pos in pos_tables: + if (pos.src_node_type, pos.dst_node_type) == ( + neg.src_node_type, + neg.dst_node_type, + ): + return pos.node_anchor + return None + + @staticmethod + def _resolve_anchor_columns( + edge_table: EdgeTableSpec, node_anchor: str + ) -> Optional[tuple[str, str]]: + """Return (anchor_column, other_column) for the given anchor node_type. + + If ``node_anchor`` matches both src and dst (homogeneous self-loop + edge), prefer the src side. Returns ``None`` if it matches neither. + """ + if node_anchor == edge_table.src_node_type: + return edge_table.src_id_column, edge_table.dst_id_column + if node_anchor == edge_table.dst_node_type: + return edge_table.dst_id_column, edge_table.src_id_column + return None + + def _supervision_pair_stats( + self, + driver: EdgeTableSpec, + other: EdgeTableSpec, + node_anchor: str, + ) -> Optional[SupervisionCrossTableStats]: + """Run the cross-table query for one (driver, other) pair. + + Returns ``None`` (and logs a warning) when the anchor cannot be + resolved on one of the two tables — happens only on misconfigured + heterogeneous pairs and should not abort the whole run. + """ + driver_cols = self._resolve_anchor_columns(driver, node_anchor) + other_cols = self._resolve_anchor_columns(other, node_anchor) + if driver_cols is None or other_cols is None: + logger.warning( + f"Skipping supervision pair driver={driver.edge_type!r} " + f"other={other.edge_type!r}: node_anchor={node_anchor!r} not " + "present on both tables." + ) + return None + + driver_anchor_column, driver_other_column = driver_cols + other_anchor_column, other_other_column = other_cols + + query = SUPERVISION_CROSS_TABLE_QUERY.format( + driver_table=driver.bq_table, + other_table=other.bq_table, + driver_anchor_column=driver_anchor_column, + driver_other_column=driver_other_column, + other_anchor_column=other_anchor_column, + other_other_column=other_other_column, + ) + self._record_query( + f"supervision_overlap:{driver.edge_type}:{other.edge_type}:" + f"{driver_anchor_column}:{other_anchor_column}", + query, + ) + rows = list(self._bq_utils.run_query(query=query, labels={})) + if len(rows) != 1: + raise RuntimeError( + f"Supervision cross-table query expected exactly 1 row; " + f"got {len(rows)}. driver={driver.edge_type} other={other.edge_type}" + ) + row = rows[0] + avg_value = row["avg_other_per_driver_anchor"] + return SupervisionCrossTableStats( + driver_edge_type=driver.edge_type, + driver_role=driver.role or EDGE_ROLE_MESSAGE_PASSING, + other_edge_type=other.edge_type, + other_role=other.role or EDGE_ROLE_MESSAGE_PASSING, + node_anchor=node_anchor, + driver_anchor_count=int(row["driver_anchor_count"] or 0), + driver_pair_count=int(row["driver_pair_count"] or 0), + other_pair_count=int(row["other_pair_count"] or 0), + overlap_pair_count=int(row["overlap_pair_count"] or 0), + driver_anchors_with_zero_other=int( + row["driver_anchors_with_zero_other"] or 0 + ), + avg_other_per_driver_anchor=float(avg_value) + if avg_value is not None + else 0.0, + p50_other_per_driver_anchor=int(row["p50_other_per_driver_anchor"] or 0), + p90_other_per_driver_anchor=int(row["p90_other_per_driver_anchor"] or 0), + p99_other_per_driver_anchor=int(row["p99_other_per_driver_anchor"] or 0), + max_other_per_driver_anchor=int(row["max_other_per_driver_anchor"] or 0), + ) + + # ------------------------------------------------------------------ # + # Tier 4: opt-in # + # ------------------------------------------------------------------ # + + def _run_tier4( + self, config: DataAnalyzerConfig, result: GraphAnalysisResult + ) -> None: + """Populate opt-in metrics gated by config flags. + + Power-law exponent is always cheap (derived from existing degree stats) + and is computed whenever degree stats are available. Reciprocity, + homophily, connected components and clustering require dedicated + queries not yet defined; they remain empty unless the corresponding + flag is enabled AND a query is implemented. + """ + # Power-law exponent: approximate from degree stats using a simple + # heuristic: alpha ~= 1 + log(max) / log(median) for median > 1. + for degree_key, stats in result.degree_stats.items(): + if stats.median > 1 and stats.max > stats.median: + exponent = 1.0 + math.log(stats.max) / math.log(stats.median) + result.power_law_exponent[degree_key] = exponent + + if config.compute_reciprocity: + # Query not yet defined; log and skip. + logger.warning( + "compute_reciprocity=True but reciprocity query is not implemented; " + "skipping Tier 4 reciprocity." + ) + + # ------------------------------------------------------------------ # + # Python-only computations # + # ------------------------------------------------------------------ # + + def _compute_feature_memory_budget( + self, config: DataAnalyzerConfig, result: GraphAnalysisResult + ) -> None: + """Estimate per-node-type memory footprint of features (float64 assumed).""" + for node_table in config.node_tables: + node_count = result.node_counts.get(node_table.node_type, 0) + num_features = len(node_table.feature_columns) + result.feature_memory_bytes[node_table.node_type] = ( + node_count * num_features * _BYTES_PER_FEATURE + ) + + def _compute_neighbor_explosion_estimate( + self, config: DataAnalyzerConfig, result: GraphAnalysisResult + ) -> None: + """Multiply fan-out factors and scale by out-degree mean per edge type.""" + if not config.fan_out: + return + fan_out_product = 1 + for hop in config.fan_out: + fan_out_product *= int(hop) + for edge_table in config.edge_tables: + out_stats = result.degree_stats.get(f"{edge_table.edge_type}_out") + if out_stats is None: + continue + estimate = int(fan_out_product * max(out_stats.mean, 1.0)) + result.neighbor_explosion_estimate[edge_table.edge_type] = estimate + + # ------------------------------------------------------------------ # + # Helpers # + # ------------------------------------------------------------------ # + + def _query_scalar( + self, query: str, column: str, *, block_id: Optional[str] = None + ) -> int: + """Run a single-row, single-column query and return the scalar as int. + + Scalar queries (COUNT, COUNTIF) must return exactly one row with a + non-NULL value for the requested column. Any deviation indicates a + driver, auth, or schema mismatch rather than legitimate data — raise + loudly instead of silently coercing to 0, which would let a broken run + pass through as a green-light result. + + When ``block_id`` is provided the rendered SQL is recorded under + that key in ``self._query_log`` so the report can surface it. + """ + if block_id is not None: + self._record_query(block_id, query) + rows = list(self._bq_utils.run_query(query=query, labels={})) + if len(rows) != 1: + raise RuntimeError( + f"Scalar query expected exactly 1 row; got {len(rows)}. " + f"Query: {query.strip()[:200]}" + ) + value = rows[0][column] + if value is None: + raise RuntimeError( + f"Scalar query returned NULL for column '{column}'. " + f"Query: {query.strip()[:200]}" + ) + return int(value) + + def _record_query(self, block_id: str, query: str) -> None: + """Append ``query`` under ``block_id`` in the per-block SQL log. + + The report JS does dict lookups against ``GraphAnalysisResult.queries`` + keyed by the same ``block_id`` strings. CPython's GIL makes + ``dict.setdefault`` and ``list.append`` atomic, so concurrent writes + from the Tier-2 thread pool are safe without an explicit lock. + """ + self._query_log.setdefault(block_id, []).append(query) diff --git a/gigl/analytics/data_analyzer/queries.py b/gigl/analytics/data_analyzer/queries.py new file mode 100644 index 000000000..fedb57b3c --- /dev/null +++ b/gigl/analytics/data_analyzer/queries.py @@ -0,0 +1,485 @@ +"""SQL query templates for graph structure analysis. + +Each constant is a format-string template parameterized with table names +and column names. Pattern matches gigl/src/data_preprocessor/lib/enumerate/queries.py. +""" + +import torch + +INT16_MAX = int(torch.iinfo(torch.int16).max) # 32767 + +# --- Tier 1: Hard fails --- + +DANGLING_EDGES_QUERY = """ +SELECT COUNT(*) AS dangling_count +FROM `{table}` +WHERE {src_id_column} IS NULL OR {dst_id_column} IS NULL +""" + +EDGE_REFERENTIAL_INTEGRITY_QUERY = """ +SELECT + COUNTIF(src_node.{src_node_id_column} IS NULL) AS missing_src_count, + COUNTIF(dst_node.{dst_node_id_column} IS NULL) AS missing_dst_count +FROM `{edge_table}` AS e +LEFT JOIN `{src_node_table}` AS src_node + ON e.{src_id_column} = src_node.{src_node_id_column} +LEFT JOIN `{dst_node_table}` AS dst_node + ON e.{dst_id_column} = dst_node.{dst_node_id_column} +""" + +DUPLICATE_NODE_COUNT_QUERY = """ +SELECT COUNT(*) AS duplicate_count FROM ( + SELECT {id_column} + FROM `{table}` + GROUP BY {id_column} + HAVING COUNT(*) > 1 +) +""" + +# --- Tier 2: Core metrics --- + +NODE_COUNT_QUERY = """ +SELECT COUNT(*) AS node_count FROM `{table}` +""" + +EDGE_COUNT_QUERY = """ +SELECT COUNT(*) AS edge_count FROM `{table}` +""" + +DUPLICATE_EDGE_COUNT_QUERY = """ +SELECT COUNT(*) AS duplicate_count FROM ( + SELECT {src_id_column}, {dst_id_column} + FROM `{table}` + GROUP BY {src_id_column}, {dst_id_column} + HAVING COUNT(*) > 1 +) +""" + +SELF_LOOP_COUNT_QUERY = """ +SELECT COUNT(*) AS self_loop_count +FROM `{table}` +WHERE {src_id_column} = {dst_id_column} +""" + +ISOLATED_NODE_COUNT_QUERY = """ +SELECT COUNT(*) AS isolated_count FROM ( + SELECT n.{node_id_column} + FROM `{node_table}` AS n + LEFT JOIN `{edge_table}` AS e_src + ON n.{node_id_column} = e_src.{src_id_column} + LEFT JOIN `{edge_table}` AS e_dst + ON n.{node_id_column} = e_dst.{dst_id_column} + WHERE e_src.{src_id_column} IS NULL + AND e_dst.{dst_id_column} IS NULL +) +""" + +DEGREE_DISTRIBUTION_QUERY = """ +SELECT + MIN(degree) AS min_degree, + MAX(degree) AS max_degree, + AVG(degree) AS avg_degree, + APPROX_QUANTILES(degree, 100) AS percentiles +FROM ( + SELECT {id_column}, COUNT(*) AS degree + FROM `{table}` + GROUP BY {id_column} +) +""" + +DEGREE_BUCKET_QUERY = """ +SELECT + COUNTIF(degree BETWEEN 0 AND 1) AS bucket_0_1, + COUNTIF(degree BETWEEN 2 AND 10) AS bucket_2_10, + COUNTIF(degree BETWEEN 11 AND 100) AS bucket_11_100, + COUNTIF(degree BETWEEN 101 AND 1000) AS bucket_101_1k, + COUNTIF(degree BETWEEN 1001 AND 10000) AS bucket_1k_10k, + COUNTIF(degree > 10000) AS bucket_10k_plus +FROM ( + SELECT {id_column}, COUNT(*) AS degree + FROM `{table}` + GROUP BY {id_column} +) +""" + +TOP_K_HUBS_QUERY = """ +SELECT {id_column} AS node_id, COUNT(*) AS degree +FROM `{table}` +GROUP BY {id_column} +ORDER BY degree DESC +LIMIT {k} +""" + +SUPER_HUB_INT16_CLAMP_QUERY = f""" +SELECT COUNT(*) AS super_hub_count FROM ( + SELECT {{id_column}}, COUNT(*) AS degree + FROM `{{table}}` + GROUP BY {{id_column}} + HAVING COUNT(*) > {INT16_MAX} +) +""" + +COLD_START_NODE_COUNT_QUERY = """ +SELECT COUNT(*) AS cold_start_count FROM ( + SELECT n.{node_id_column}, COALESCE(e.degree, 0) AS degree + FROM `{node_table}` AS n + LEFT JOIN ( + SELECT nid, COUNT(*) AS degree FROM ( + SELECT {src_id_column} AS nid FROM `{edge_table}` + UNION ALL + SELECT {dst_id_column} AS nid FROM `{edge_table}` + ) + GROUP BY nid + ) AS e ON n.{node_id_column} = e.nid + WHERE COALESCE(e.degree, 0) <= 1 +) +""" + +# --- Tier 3: Label and heterogeneous --- + +CLASS_IMBALANCE_QUERY = """ +SELECT {label_column} AS label, COUNT(*) AS count +FROM `{table}` +WHERE {label_column} IS NOT NULL +GROUP BY {label_column} +ORDER BY count DESC +""" + +LABEL_COVERAGE_QUERY = """ +SELECT + COUNT(*) AS total, + COUNTIF({label_column} IS NOT NULL) AS labeled, + SAFE_DIVIDE(COUNTIF({label_column} IS NOT NULL), COUNT(*)) AS coverage +FROM `{table}` +""" + +EDGE_TYPE_DISTRIBUTION_QUERY = """ +SELECT COUNT(*) AS edge_count FROM `{table}` +""" + +EDGE_TYPE_NODE_COVERAGE_QUERY = """ +SELECT + APPROX_COUNT_DISTINCT({src_id_column}) AS distinct_src_count, + APPROX_COUNT_DISTINCT({dst_id_column}) AS distinct_dst_count +FROM `{table}` +""" + + +# --- Supervision cross-table analysis --- + +SUPERVISION_CROSS_TABLE_QUERY = """ +WITH driver_pairs AS ( + SELECT DISTINCT + {driver_anchor_column} AS anchor, + {driver_other_column} AS neighbor + FROM `{driver_table}` + WHERE {driver_anchor_column} IS NOT NULL + AND {driver_other_column} IS NOT NULL +), +other_pairs AS ( + SELECT DISTINCT + {other_anchor_column} AS anchor, + {other_other_column} AS neighbor + FROM `{other_table}` + WHERE {other_anchor_column} IS NOT NULL + AND {other_other_column} IS NOT NULL +), +driver_anchors AS ( + SELECT DISTINCT anchor FROM driver_pairs +), +other_per_driver_anchor AS ( + SELECT driver_anchors.anchor, + COALESCE(other_counts.cnt, 0) AS cnt + FROM driver_anchors + LEFT JOIN ( + SELECT anchor, COUNT(*) AS cnt FROM other_pairs GROUP BY anchor + ) AS other_counts USING (anchor) +) +SELECT + (SELECT COUNT(*) FROM driver_anchors) AS driver_anchor_count, + (SELECT COUNT(*) FROM driver_pairs) AS driver_pair_count, + (SELECT COUNT(*) FROM other_pairs) AS other_pair_count, + ( + SELECT COUNT(*) + FROM driver_pairs + INNER JOIN other_pairs USING (anchor, neighbor) + ) AS overlap_pair_count, + (SELECT COUNTIF(cnt = 0) FROM other_per_driver_anchor) + AS driver_anchors_with_zero_other, + (SELECT AVG(cnt) FROM other_per_driver_anchor) + AS avg_other_per_driver_anchor, + (SELECT APPROX_QUANTILES(cnt, 100)[OFFSET(50)] FROM other_per_driver_anchor) + AS p50_other_per_driver_anchor, + (SELECT APPROX_QUANTILES(cnt, 100)[OFFSET(90)] FROM other_per_driver_anchor) + AS p90_other_per_driver_anchor, + (SELECT APPROX_QUANTILES(cnt, 100)[OFFSET(99)] FROM other_per_driver_anchor) + AS p99_other_per_driver_anchor, + (SELECT MAX(cnt) FROM other_per_driver_anchor) + AS max_other_per_driver_anchor +""" + + +# --- Node-classification supervision tier --- + + +def build_label_sentinel_query( + table: str, label_column: str, sentinel_values: list[str] +) -> str: + """Build a single-pass query that splits label cells into NULL / sentinel / valid. + + Sentinel values are interpolated as quoted string literals; callers + must ensure values come from a trusted config (the analyzer config + is loaded by ``load_analyzer_config`` which already validates the + structure of the YAML it reads). The label column is cast to STRING + in the comparison so integer and string sentinels both work. + + Args: + table: Fully qualified BQ table name. + label_column: Column whose cells we're bucketing. + sentinel_values: Strings that should be classified as sentinels + distinct from SQL NULL. + + Returns: + SQL query string returning one row with columns ``total_rows``, + ``null_count``, ``valid_count``, and one ``sentinel_`` count + per sentinel value (in declaration order). + """ + sentinel_clauses = ",\n ".join( + f"COUNTIF(CAST({label_column} AS STRING) = " + f"{_sql_string_literal(sentinel)}) AS sentinel_{idx}" + for idx, sentinel in enumerate(sentinel_values) + ) + sentinel_in_list = ( + ", ".join(_sql_string_literal(s) for s in sentinel_values) + if sentinel_values + else None + ) + valid_clause = ( + f"COUNTIF({label_column} IS NOT NULL " + f"AND CAST({label_column} AS STRING) NOT IN ({sentinel_in_list})) AS valid_count" + if sentinel_in_list is not None + else f"COUNTIF({label_column} IS NOT NULL) AS valid_count" + ) + extra = f",\n {sentinel_clauses}" if sentinel_clauses else "" + return f""" +SELECT + COUNT(*) AS total_rows, + COUNTIF({label_column} IS NULL) AS null_count, + {valid_clause}{extra} +FROM `{table}` +""" + + +def _sql_string_literal(value: str) -> str: + """Quote a string for safe inline use in BQ SQL. + + Escapes single quotes and backslashes; no other characters are + transformed. Sentinel values flow into ``IN`` lists so we control + the surrounding context. Anything more invasive (parameterized + queries) would require restructuring how every other query in this + module is built. + """ + escaped = value.replace("\\", "\\\\").replace("'", "\\'") + return f"'{escaped}'" + + +def build_per_class_degree_query( + node_table: str, + node_id_column: str, + label_column: str, + edge_table: str, + edge_src_column: str, + edge_dst_column: str, +) -> str: + """Per-label-value degree distribution joining labeled nodes to a message-passing edge table. + + Computes for each distinct non-NULL label value: count of class + members, count with total degree <= 1 (cold-start), and degree + distribution (mean / median / p90 / p99 / max). NULL labels are + excluded — they are accounted for separately in + :class:`LabelSentinelStats`. Sentinel-declared values (e.g. ``-1``) + are *not* filtered out and surface as their own rows; the caller is + responsible for partitioning the result into "valid class" vs + "sentinel" using its own ``label_sentinel_values``. + + Returns one row per distinct non-NULL label value. + """ + return f""" +WITH node_degrees AS ( + SELECT nid, COUNT(*) AS degree FROM ( + SELECT {edge_src_column} AS nid FROM `{edge_table}` + UNION ALL + SELECT {edge_dst_column} AS nid FROM `{edge_table}` + ) + GROUP BY nid +), +labeled AS ( + SELECT + CAST(n.{label_column} AS STRING) AS class_value, + COALESCE(d.degree, 0) AS degree + FROM `{node_table}` AS n + LEFT JOIN node_degrees AS d + ON n.{node_id_column} = d.nid + WHERE n.{label_column} IS NOT NULL +) +SELECT + class_value, + COUNT(*) AS class_count, + COUNTIF(degree <= 1) AS cold_start_count, + AVG(degree) AS mean_degree, + APPROX_QUANTILES(degree, 100) AS percentiles, + MAX(degree) AS max_degree, + COUNTIF(degree BETWEEN 0 AND 1) AS bucket_0_1, + COUNTIF(degree BETWEEN 2 AND 10) AS bucket_2_10, + COUNTIF(degree BETWEEN 11 AND 100) AS bucket_11_100, + COUNTIF(degree BETWEEN 101 AND 1000) AS bucket_101_1k, + COUNTIF(degree BETWEEN 1001 AND 10000) AS bucket_1k_10k, + COUNTIF(degree > 10000) AS bucket_10k_plus +FROM labeled +GROUP BY class_value +ORDER BY class_count DESC +""" + + +def build_adjusted_homophily_query( + node_table: str, + node_id_column: str, + label_column: str, + sentinel_values: list[str], + edge_table: str, + edge_src_column: str, + edge_dst_column: str, + sample_cap: int, +) -> str: + """Edge homophily and class-prior-adjusted homophily on a sampled edge set. + + Adjusted homophily is computed per Platonov et al., NeurIPS 2023: + + adjusted = (h_edge - sum_c (D_c / 2|E|)^2) + / (1 - sum_c (D_c / 2|E|)^2) + + where ``D_c`` is the sum of degrees of nodes in class ``c`` over the + sampled edge set. Values near 0 mean "no signal beyond class + priors"; positive is homophilic, negative heterophilic. + + Edges are sampled by ``MOD(FARM_FINGERPRINT(...), modulus) = 0`` so + sampling is deterministic and consistent across reruns. ``sample_cap + = 0`` means full-graph (no sampling). + + Returns one row with: ``edge_homophily``, ``expected_homophily`` + (the class-prior baseline), ``adjusted_homophily`` (computed in + Python from the two columns above), and ``edge_sample_count``. + """ + sentinel_filter_src = "" + sentinel_filter_dst = "" + if sentinel_values: + sentinel_in_list = ", ".join(_sql_string_literal(s) for s in sentinel_values) + sentinel_filter_src = ( + f"AND CAST(s.{label_column} AS STRING) NOT IN ({sentinel_in_list})" + ) + sentinel_filter_dst = ( + f"AND CAST(d.{label_column} AS STRING) NOT IN ({sentinel_in_list})" + ) + + sample_filter = ( + "" + if sample_cap <= 0 + else ( + f"WHERE MOD(ABS(FARM_FINGERPRINT(CONCAT(" + f"CAST({edge_src_column} AS STRING), '|', " + f"CAST({edge_dst_column} AS STRING)))), {{modulus_placeholder}}) = 0" + ) + ) + # We pass {modulus_placeholder} verbatim and let the caller fill it + # in based on the cardinality of the edge table, so the same SQL + # template is used for any sample size. + return f""" +WITH sampled_edges AS ( + SELECT {edge_src_column} AS src_id, {edge_dst_column} AS dst_id + FROM `{edge_table}` + {sample_filter} +), +labeled_pairs AS ( + SELECT + CAST(s.{label_column} AS STRING) AS src_label, + CAST(d.{label_column} AS STRING) AS dst_label + FROM sampled_edges AS e + JOIN `{node_table}` AS s + ON e.src_id = s.{node_id_column} + JOIN `{node_table}` AS d + ON e.dst_id = d.{node_id_column} + WHERE s.{label_column} IS NOT NULL + AND d.{label_column} IS NOT NULL + {sentinel_filter_src} + {sentinel_filter_dst} +), +endpoint_classes AS ( + SELECT label, COUNT(*) AS endpoint_count FROM ( + SELECT src_label AS label FROM labeled_pairs + UNION ALL + SELECT dst_label AS label FROM labeled_pairs + ) + GROUP BY label +), +totals AS ( + SELECT SUM(endpoint_count) AS total_endpoints FROM endpoint_classes +) +SELECT + SAFE_DIVIDE(COUNTIF(src_label = dst_label), COUNT(*)) AS edge_homophily, + ( + SELECT SUM(POW(SAFE_DIVIDE(endpoint_count, total_endpoints), 2)) + FROM endpoint_classes, totals + ) AS expected_homophily, + COUNT(*) AS edge_sample_count +FROM labeled_pairs +""" + + +CROSS_SPLIT_OVERLAP_QUERY = """ +SELECT + ( + SELECT COUNT(*) FROM ( + SELECT {id_column} + FROM `{table}` + WHERE {id_column} IS NOT NULL + AND {split_column} IS NOT NULL + GROUP BY {id_column} + HAVING COUNT(DISTINCT {split_column}) > 1 + ) + ) AS overlap_node_count +""" + + +SPLIT_VALUE_COUNTS_QUERY = """ +SELECT + CAST({split_column} AS STRING) AS split_value, + COUNT(*) AS row_count +FROM `{table}` +WHERE {split_column} IS NOT NULL +GROUP BY split_value +ORDER BY row_count DESC +""" + + +def build_null_rates_query(table: str, columns: list[str]) -> str: + """Build a batched NULL rates query for multiple columns. + + One query, one table scan, one COUNTIF per column. + + Args: + table: Fully qualified BQ table name. + columns: List of column names to check. + + Returns: + SQL query string. + """ + countif_clauses = ",\n ".join( + f"SAFE_DIVIDE(COUNTIF({col} IS NULL), COUNT(*)) AS {col}_null_rate" + for col in columns + ) + return f""" +SELECT + COUNT(*) AS total_rows, + {countif_clauses} +FROM `{table}` +""" diff --git a/gigl/analytics/data_analyzer/report/PRD.md b/gigl/analytics/data_analyzer/report/PRD.md new file mode 100644 index 000000000..9888e676c --- /dev/null +++ b/gigl/analytics/data_analyzer/report/PRD.md @@ -0,0 +1,166 @@ +# PRD: BQ Data Analyzer HTML Report + +## Status + +**AI-owned.** An AI agent reads this PRD together with the sibling `SPEC.md` and regenerates `report.ai.html`, +`charts.ai.js`, and `styles.ai.css` when the product intent or technical contract changes. This PRD describes *why* and +*what*; `SPEC.md` describes *how*. + +## Problem + +Before training a GNN on graph data in BigQuery, engineers need a fast way to see whether the data is healthy enough to +train on. Today they find out only after a Dataflow job crashes or a trainer produces a poor model, which costs days and +thousands of dollars per iteration. + +A review of 18 production GNN papers ([reference doc](../../../docs/plans/20260415-bq-data-analyzer-references.md)) +found that graph-specific data properties drive 30-230% model quality differences. None of these are caught by standard +tabular data quality tools. We need a report that surfaces these graph-specific issues in a form engineers can act on in +minutes, not days. + +## Users + +| Persona | Primary need | Frequency | +| ---------------------------------------- | ------------------------------------------------------------------------- | -------------------------- | +| **GNN engineer running an applied task** | Decide whether a new BQ dataset is trainable, and if not, what to fix | Per new dataset or refresh | +| **Applied task reviewer / tech lead** | Sanity-check a teammate's dataset choices before approving a training run | Per PR | +| **On-call engineer** | Triage why a training run degraded vs last week | Per incident | + +Out of scope: data scientists doing generic exploratory data analysis, product managers, non-technical stakeholders. + +## User Stories + +1. **As a GNN engineer**, I point the analyzer at a new BQ node/edge table pair and open the resulting HTML report. + Within 30 seconds of scrolling I know whether the dataset has any training-blocking issues (dangling edges, + referential integrity, duplicates). +2. **As a GNN engineer**, I inspect the degree distribution histogram for each edge type and decide whether my planned + fan-out is realistic or will cause neighbor explosion. +3. **As a reviewer**, I share the GCS link to the report in a PR comment. My teammate opens it in a browser without + installing anything. +4. **As an on-call engineer**, I run the analyzer on today's data and last week's data and diff the two reports to see + what changed. +5. **As any of the above**, I expand the collapsed sections I do not care about so the overview stays scannable. + +## Goals + +1. **Zero-setup viewing.** The report opens in any modern browser with no server, no CDN, no authentication beyond the + GCS link. Works offline once downloaded. +2. **Action-oriented.** Every numeric finding is color-coded against a literature-derived threshold (green/yellow/red) + so the reader knows what to do about it. +3. **Traceable.** Every color-coded threshold and every check cites the paper or codebase location that justifies it, so + readers can verify claims. +4. **Portable.** A single `.html` file that can be shared in chat, stored indefinitely in GCS, and archived alongside + the training run it describes. +5. **Graph-native.** Surfaces metrics that matter for GNNs specifically (degree distribution, super-hub int16 clamp, + cold-start fraction, homophily, neighbor explosion), not just generic tabular stats. +6. **AI-regenerable.** The three `.ai.*` assets can be regenerated deterministically from this PRD plus `SPEC.md` + without human intervention on the HTML/JS/CSS. + +## Non-Goals + +- **Not a real-time monitoring dashboard.** Aegis covers that + ([Phase 2](../../../docs/plans/20260415-bq-data-analyzer.md#aegis-integration-phase-2)). This report is a + point-in-time snapshot. +- **Not a BI tool.** No filtering, drill-down, or ad-hoc querying. The report is a rendered artifact, not an interactive + app. +- **Not cross-dataset comparison.** Diffing reports is a user workflow (open two tabs), not a report feature. +- **Not a model evaluation report.** This is about training data, not trained model performance. +- **Not accessible (WCAG AA) in v1.** We document this gap and will address it if the report is used by users who need + it. + +## Functional Requirements + +Each requirement maps to a section of `SPEC.md` where the implementation contract lives. + +**FR-1: Overview at a glance.** The first screen (above the fold) shows total nodes, total edges, node/edge type counts, +and a single green/yellow/red status light summarizing the worst issue found. Rationale: engineers decide "do I need to +look deeper" in the first 5 seconds. + +**FR-2: Hard-fail visibility.** Dangling edges, referential integrity violations, and duplicate nodes render red +regardless of magnitude. These block training entirely. The report shows them prominently even if count is exactly one. +Rationale: [GiGL](../../../docs/plans/20260415-bq-data-analyzer-references.md#6-gigl), +[AliGraph (7.1)](../../../docs/plans/20260415-bq-data-analyzer-references.md#7-aligraph) — silent NaN propagation from +referential integrity violations is a production-documented failure mode. + +**FR-3: Degree distribution per edge type.** Inline SVG histogram using the six literature-aligned buckets: `0-1`, +`2-10`, `11-100`, `101-1K`, `1K-10K`, `10K+`. Separate in-degree and out-degree. Rationale: +[BLADE](../../../docs/plans/20260415-bq-data-analyzer-references.md#3-blade) showed 230% embedding improvement from +degree-adaptive neighborhoods; the reader needs to see which buckets dominate. + +**FR-4: Super-hub warning.** A red call-out appears when any node exceeds the GiGL int16 degree clamp (32,767). Include +the count and the affected edge type. Rationale: +[GiGL (6.2)](../../../docs/plans/20260415-bq-data-analyzer-references.md#6-gigl) — the clamp is silent in production and +corrupts PPR sampling probabilities. Users have no other way to discover this. + +**FR-5: Cold-start visibility.** Show the count and fraction of degree-0-1 nodes per type. Color-code the fraction +against the 5% / 10% threshold. Rationale: +[LiGNN (4.1)](../../../docs/plans/20260415-bq-data-analyzer-references.md#4-lignn) — +0.28% AUC from cold-start +densification; the reader decides whether densification is worth investigating. + +**FR-6: Optional Tier 3 visibility.** Class imbalance, label coverage, edge type distribution, and per-edge-type node +coverage are shown only when the input data supports them. Rationale: a report full of "not applicable" sections is +noise. + +**FR-7: Embedded FACETS.** When feature profiling is available, the FACETS HTML output is embedded inline via +`