From 71f058214927a5784ffdca0d3c1320e904055fdd Mon Sep 17 00:00:00 2001 From: ebembi-crdb Date: Thu, 14 May 2026 19:15:03 +0530 Subject: [PATCH] Add automated SQL command testing for docs (EDUENG-131) Add infrastructure to extract SQL code blocks from markdown documentation and execute them against a CockroachDB cluster to verify correctness. Blocks are classified as executable, expected-error, fragment, or skipped, with skip annotations supported per-block and per-page. Co-Authored-By: Claude Opus 4.6 --- .github/scripts/sql_test/__init__.py | 0 .github/scripts/sql_test/executor.py | 198 ++++++++++++++ .github/scripts/sql_test/extractor.py | 256 ++++++++++++++++++ .github/scripts/sql_test/models.py | 44 ++++ .github/scripts/sql_test/reporter.py | 151 +++++++++++ .github/scripts/sql_test_runner.py | 129 +++++++++ .github/scripts/test_sql_extractor.py | 361 ++++++++++++++++++++++++++ .github/workflows/sql-test.yml | 136 ++++++++++ src/current/Makefile | 14 + 9 files changed, 1289 insertions(+) create mode 100644 .github/scripts/sql_test/__init__.py create mode 100644 .github/scripts/sql_test/executor.py create mode 100644 .github/scripts/sql_test/extractor.py create mode 100644 .github/scripts/sql_test/models.py create mode 100644 .github/scripts/sql_test/reporter.py create mode 100644 .github/scripts/sql_test_runner.py create mode 100644 .github/scripts/test_sql_extractor.py create mode 100644 .github/workflows/sql-test.yml diff --git a/.github/scripts/sql_test/__init__.py b/.github/scripts/sql_test/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/.github/scripts/sql_test/executor.py b/.github/scripts/sql_test/executor.py new file mode 100644 index 00000000000..5c1af9c685b --- /dev/null +++ b/.github/scripts/sql_test/executor.py @@ -0,0 +1,198 @@ +"""Executes SQL blocks against a CockroachDB cluster.""" + +import re +import subprocess +import time +from pathlib import Path +from typing import List + +from .models import BlockType, 
# NOTE: this span relies on the imports at the top of executor.py
# (re, subprocess, time, Path, List, the models, and MOVR_TABLES).

DEFAULT_CONNECTION_URL = "postgresql://root@localhost:26257?sslmode=disable"
STATEMENT_TIMEOUT_S = 30

# Failures of the cockroach subprocess itself (binary missing, timeout, bad
# argument), as opposed to SQL errors reported through its exit code.
_SUBPROCESS_ERRORS = (subprocess.SubprocessError, OSError, ValueError)


def _sanitize_db_name(file_path: str) -> str:
    """Generate a safe database name from a file path.

    Only the file stem is used, so pages with the same file name in
    different directories map to the same database name.
    """
    stem = Path(file_path).stem
    # Replace every non-alphanumeric character with an underscore.
    safe = re.sub(r'[^a-zA-Z0-9]', '_', stem)
    return f"sqltest_{safe}"[:63]  # keep the original 63-character cap


def _uses_movr(blocks: "List[SqlBlock]") -> bool:
    """Return True if any block mentions a MovR table name as a whole word."""
    if not MOVR_TABLES:
        return False
    # One alternation compiled once per call instead of one regex per
    # table per block; names are escaped so they are matched literally.
    names = '|'.join(re.escape(table) for table in sorted(MOVR_TABLES))
    pattern = re.compile(r'\b(?:' + names + r')\b')
    return any(pattern.search(block.raw_content.lower()) for block in blocks)


def _run_sql(connection_url: str, sql: str, timeout: int = STATEMENT_TIMEOUT_S) -> subprocess.CompletedProcess:
    """Execute SQL via a `cockroach sql` subprocess and return the completed process.

    Raises subprocess.TimeoutExpired if the process exceeds `timeout` seconds.
    A non-zero exit code does NOT raise; callers must check `returncode`.
    """
    return subprocess.run(
        ["cockroach", "sql", "--url", connection_url, "--format=table", "-e", sql],
        capture_output=True,
        text=True,
        timeout=timeout,
    )


def _database_url(connection_url: str, db_name: str) -> str:
    """Return `connection_url` with its path replaced by `db_name`.

    The database segment of a URL is not SQL, so the name is percent-encoded
    rather than wrapped in double quotes, and any database already present
    in the URL is replaced instead of being nested under.
    """
    from urllib.parse import quote, urlsplit, urlunsplit

    parts = urlsplit(connection_url)
    return urlunsplit(parts._replace(path="/" + quote(db_name, safe="")))


def execute_page(page_result: "PageResult", connection_url: str = DEFAULT_CONNECTION_URL) -> "PageResult":
    """Execute all SQL blocks for a single page against CockroachDB.

    Creates an isolated database per page, runs the executable and
    expected-error blocks in document order within that database, then
    drops the database (also when execution raises).

    Args:
        page_result: PageResult with extracted blocks (no results yet).
        connection_url: CockroachDB connection URL.

    Returns:
        The same PageResult with results populated. Pages without any
        runnable block are returned unchanged and never touch the cluster.
    """
    runnable = [
        b for b in page_result.blocks
        if b.block_type in (BlockType.EXECUTABLE, BlockType.EXPECTED_ERROR)
    ]
    if not runnable:
        return page_result

    def fail_all(message: str) -> "PageResult":
        # Setup failed: report every runnable block as failed with one reason.
        page_result.results.extend(
            TestResult(block=b, success=False, error_message=message) for b in runnable
        )
        return page_result

    db_name = _sanitize_db_name(page_result.file_path)

    # Create the isolated database. subprocess.run does not raise on a
    # non-zero exit code, so the return code is checked explicitly.
    try:
        created = _run_sql(connection_url, f'CREATE DATABASE IF NOT EXISTS "{db_name}";')
    except _SUBPROCESS_ERRORS as e:
        return fail_all(f"Failed to create test database: {e}")
    if created.returncode != 0:
        detail = (created.stderr or "").strip() or "non-zero exit code"
        return fail_all(f"Failed to create test database: {detail}")

    db_url = _database_url(connection_url, db_name)

    try:
        # Load the MovR dataset only when a block that will actually run needs it.
        if _uses_movr(runnable):
            try:
                init = subprocess.run(
                    ["cockroach", "workload", "init", "movr", db_url],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
            except _SUBPROCESS_ERRORS as e:
                return fail_all(f"Failed to initialize MovR: {e}")
            if init.returncode != 0:
                detail = (init.stderr or "").strip() or "non-zero exit code"
                return fail_all(f"Failed to initialize MovR: {detail}")

        for block in runnable:
            page_result.results.append(_execute_block(block, db_url))
    finally:
        _cleanup_db(connection_url, db_name)

    return page_result


def _execute_block(block: "SqlBlock", db_url: str) -> "TestResult":
    """Execute a single SQL block and return the result.

    Statements run one at a time and stop at the first failure. For an
    EXPECTED_ERROR block a SQL error (non-zero exit code) is a pass; a
    failure to run the subprocess at all is always reported as a failure.

    NOTE(review): every statement is a separate `cockroach sql` process, so
    session state presumably does not carry over between statements of a
    block (SET, open transactions) -- confirm this is acceptable.
    """
    started = time.perf_counter()

    def elapsed_ms() -> float:
        return (time.perf_counter() - started) * 1000

    expects_error = block.block_type == BlockType.EXPECTED_ERROR
    outputs = []

    for stmt in block.cleaned_statements:
        try:
            proc = _run_sql(db_url, stmt)
        except subprocess.TimeoutExpired:
            return TestResult(
                block=block,
                success=False,
                error_message=f"Statement timed out after {STATEMENT_TIMEOUT_S}s: {stmt[:100]}",
                execution_time_ms=elapsed_ms(),
            )
        except _SUBPROCESS_ERRORS as e:
            # Infrastructure problem, not the SQL error the docs promised.
            return TestResult(
                block=block,
                success=False,
                error_message=str(e),
                execution_time_ms=elapsed_ms(),
            )

        if proc.stdout:
            outputs.append(proc.stdout)

        if proc.returncode != 0:
            error_text = proc.stderr.strip() if proc.stderr else "Non-zero exit code"
            if expects_error:
                # Expected error: passing because it did error.
                return TestResult(
                    block=block,
                    success=True,
                    actual_output=error_text,
                    execution_time_ms=elapsed_ms(),
                )
            return TestResult(
                block=block,
                success=False,
                actual_output='\n'.join(outputs),
                error_message=error_text,
                execution_time_ms=elapsed_ms(),
            )

    # Every statement succeeded.
    if expects_error:
        return TestResult(
            block=block,
            success=False,
            actual_output='\n'.join(outputs),
            error_message="Expected an error but all statements succeeded",
            execution_time_ms=elapsed_ms(),
        )

    return TestResult(
        block=block,
        success=True,
        actual_output='\n'.join(outputs),
        execution_time_ms=elapsed_ms(),
    )


def _cleanup_db(connection_url: str, db_name: str) -> None:
    """Drop the test database; failures are ignored (best-effort cleanup)."""
    try:
        _run_sql(connection_url, f'DROP DATABASE IF EXISTS "{db_name}" CASCADE;')
    except _SUBPROCESS_ERRORS:
        pass  # Best-effort cleanup
# --- .github/scripts/sql_test/extractor.py ---------------------------------
# Extracts and classifies SQL code blocks from CockroachDB documentation
# markdown files. Relies on the imports at the top of extractor.py
# (re, Path, List, Optional and the models).

# Tables that indicate MovR dataset usage
MOVR_TABLES = frozenset({
    "users", "vehicles", "rides", "promo_codes",
    "vehicle_location_histories", "user_promo_codes",
})

# Patterns that indicate a block is a fragment (not executable as-is)
FRAGMENT_INDICATORS = [
    re.compile(r'\.\.\.'),                       # Ellipsis (truncated content)
    re.compile(r'<[a-zA-Z_][a-zA-Z0-9_ -]*>'),   # <placeholder> style
    re.compile(r'\{[a-zA-Z_][a-zA-Z0-9_]*\}'),   # {placeholder} style
    re.compile(r'{% remote_include'),            # Liquid remote include
]

# Skip annotation: an HTML comment "<!-- sql-test:skip -->", optionally
# carrying a reason after a colon: "<!-- sql-test:skip: some reason -->".
# Group 1 is the reason (None when absent).
SKIP_COMMENT_RE = re.compile(
    r'<!--\s*sql-test:skip(?::\s*(.*?))?\s*-->'
)


def _has_page_level_skip(content: str) -> bool:
    """Check if the leading `---` frontmatter contains a `sql_test: skip` line."""
    frontmatter_match = re.match(r'^---\s*\n(.*?)\n---', content, re.DOTALL)
    if not frontmatter_match:
        return False
    frontmatter = frontmatter_match.group(1)
    return bool(re.search(r'^\s*sql_test:\s*skip\s*$', frontmatter, re.MULTILINE))


def _clean_sql_lines(raw: str) -> List[str]:
    """Clean raw SQL block content into executable statements.

    Strips the leading '> ' prompt prefix from each line, then closes a
    statement at every line that ends with a semicolon. Trailing content
    without a semicolon is kept as a final statement.

    NOTE(review): the split is line-based, so a ';' at the end of a line
    inside a string or function body also ends a statement.
    """
    unprompted = []
    for line in raw.split('\n'):
        if line.startswith('> '):
            line = line[2:]
        elif line == '>':
            line = ''
        unprompted.append(line)

    text = '\n'.join(unprompted).strip()
    if not text:
        return []

    statements = []
    pending = []
    for line in text.split('\n'):
        pending.append(line)
        if line.rstrip().endswith(';'):
            statements.append('\n'.join(pending).strip())
            pending = []

    # Content without a trailing semicolon (e.g. commands such as \dt).
    tail = '\n'.join(pending).strip()
    if tail:
        statements.append(tail)

    return statements


def _classify_block(
    raw: str,
    statements: List[str],
    expected_output: Optional[str],
    skip_reason: Optional[str],
) -> "BlockType":
    """Classify a SQL block based on its content and context.

    Precedence: SKIPPED (a skip reason is set), FRAGMENT (no statements, a
    fragment indicator in the raw text, or a statement starting with '$'),
    EXPECTED_ERROR (expected output starts with 'ERROR:' or 'pq:'),
    otherwise EXECUTABLE.
    """
    if skip_reason is not None:
        return BlockType.SKIPPED

    # An empty block has nothing to execute.
    if not statements:
        return BlockType.FRAGMENT

    if any(pattern.search(raw) for pattern in FRAGMENT_INDICATORS):
        return BlockType.FRAGMENT

    # A statement starting with $ is a shell command, not SQL.
    if any(stmt.lstrip().startswith('$') for stmt in statements):
        return BlockType.FRAGMENT

    if expected_output and expected_output.strip().startswith(('ERROR:', 'pq:')):
        return BlockType.EXPECTED_ERROR

    return BlockType.EXECUTABLE


def _uses_movr(blocks: "List[SqlBlock]") -> bool:
    """Check if any block references a MovR table name as a whole word."""
    names = '|'.join(re.escape(table) for table in sorted(MOVR_TABLES))
    pattern = re.compile(r'\b(?:' + names + r')\b')
    return any(pattern.search(block.raw_content.lower()) for block in blocks)


def _read_fenced(lines: List[str], start: int):
    """Return (text, end) for a fence body starting at index `start`.

    `end` is the index of the closing '~~~' line, or len(lines) when the
    fence is never closed.
    """
    end = start
    while end < len(lines) and lines[end].strip() != '~~~':
        end += 1
    return '\n'.join(lines[start:end]), end


def extract_blocks(file_path: str, content: Optional[str] = None) -> "PageResult":
    """Extract all `~~~ sql` code blocks from a markdown file.

    A plain `~~~` block that directly follows a SQL block (blank lines
    allowed) is recorded as its expected output. A skip annotation comment
    applies to the next SQL block when only blank lines separate them.

    Args:
        file_path: Path to the markdown file.
        content: Optional file content. If None, reads from file_path.

    Returns:
        PageResult containing all extracted and classified SQL blocks
        (empty when content is None and the file does not exist).
    """
    if content is None:
        path = Path(file_path)
        if not path.exists():
            return PageResult(file_path=file_path)
        content = path.read_text(encoding='utf-8')

    page_result = PageResult(file_path=file_path)
    page_skip = (
        "Page-level sql_test: skip in frontmatter"
        if _has_page_level_skip(content) else None
    )

    # Normalise CRLF so '\r' never leaks into statements.
    lines = content.replace('\r\n', '\n').split('\n')
    annotation = None  # reason from a skip comment awaiting its SQL block
    i = 0

    while i < len(lines):
        line = lines[i]

        marker = SKIP_COMMENT_RE.search(line)
        if marker:
            annotation = marker.group(1) or "Marked with sql-test:skip"
            i += 1
            continue

        stripped = line.strip()
        if stripped != '~~~ sql':
            if stripped:
                # Any other content between annotation and fence cancels it.
                annotation = None
            i += 1
            continue

        raw, sql_end = _read_fenced(lines, i + 1)
        statements = _clean_sql_lines(raw)

        # Look ahead for the expected output block (~~~ without a language tag).
        expected_output = None
        resume = sql_end + 1
        j = resume
        while j < len(lines) and lines[j].strip() == '':
            j += 1
        if j < len(lines) and lines[j].strip() == '~~~':
            expected_output, out_end = _read_fenced(lines, j + 1)
            resume = out_end + 1  # do not re-scan the output block

        skip_reason = annotation or page_skip
        annotation = None

        page_result.blocks.append(SqlBlock(
            file_path=file_path,
            line_number=i + 1,  # 1-indexed line of the '~~~ sql' fence
            raw_content=raw,
            cleaned_statements=statements,
            block_type=_classify_block(raw, statements, expected_output, skip_reason),
            expected_output=expected_output,
            skip_reason=skip_reason,
            block_index=len(page_result.blocks),
        ))
        i = resume

    return page_result


def extract_from_files(file_paths: List[str]) -> "List[PageResult]":
    """Extract SQL blocks from multiple files.

    Args:
        file_paths: List of markdown file paths to process.

    Returns:
        List of PageResult, one per file (only files with blocks).
    """
    pages = (extract_blocks(fp) for fp in file_paths)
    return [page for page in pages if page.blocks]


# --- .github/scripts/sql_test/models.py -------------------------------------
# Data models for SQL testing infrastructure. Relies on the imports at the
# top of models.py (dataclass, field, Enum, List, Optional).

class BlockType(Enum):
    """Classification of a SQL code block."""
    EXECUTABLE = "executable"
    EXPECTED_ERROR = "expected_error"
    FRAGMENT = "fragment"
    SKIPPED = "skipped"


@dataclass
class SqlBlock:
    """A single SQL code block extracted from a markdown file."""
    file_path: str
    line_number: int            # 1-indexed line of the opening fence
    raw_content: str            # text between the fences, unmodified
    cleaned_statements: List[str]
    block_type: BlockType
    expected_output: Optional[str] = None
    skip_reason: Optional[str] = None
    block_index: int = 0        # position of the block within its page


@dataclass
class TestResult:
    """Result of executing a single SQL block."""
    __test__ = False  # the "Test" prefix is a name only; keep pytest from collecting it

    block: SqlBlock
    success: bool
    actual_output: str = ""
    error_message: str = ""
    execution_time_ms: float = 0.0


@dataclass
class PageResult:
    """Aggregated results for all SQL blocks in a single file."""
    file_path: str
    blocks: List[SqlBlock] = field(default_factory=list)
    results: List[TestResult] = field(default_factory=list)
# NOTE: this span relies on the imports at the top of reporter.py
# (sys, List, BlockType, PageResult, TestResult).

def _count_blocks(pages: "List[PageResult]"):
    """Count blocks by type across all pages.

    Returns:
        Tuple (total, executable, skipped, fragments); `executable` counts
        both EXECUTABLE and EXPECTED_ERROR blocks.
    """
    total = 0
    executable = 0
    skipped = 0
    fragments = 0
    for page in pages:
        for block in page.blocks:
            total += 1
            kind = block.block_type
            if kind == BlockType.SKIPPED:
                skipped += 1
            elif kind == BlockType.FRAGMENT:
                fragments += 1
            elif kind in (BlockType.EXECUTABLE, BlockType.EXPECTED_ERROR):
                executable += 1
    return total, executable, skipped, fragments


def print_dry_run(pages: "List[PageResult]", verbose: bool = False) -> None:
    """Print extraction/classification summary without execution results.

    With `verbose`, also lists every block with its type, line number and
    the first 60 characters of its first raw line.
    """
    total, executable, skipped, fragments = _count_blocks(pages)
    expected_errors = sum(
        1 for p in pages for b in p.blocks
        if b.block_type == BlockType.EXPECTED_ERROR
    )
    rule = '=' * 60

    print(f"\n{rule}")
    print("SQL Test Dry Run Summary")
    print(rule)
    print(f"Pages scanned: {len(pages)}")
    print(f"Total SQL blocks: {total}")
    print(f" Executable: {executable}")
    print(f" Expected errors: {expected_errors}")
    print(f" Fragments: {fragments}")
    print(f" Skipped: {skipped}")
    print(f"{rule}\n")

    if not verbose:
        return

    for page in pages:
        print(f"\n--- {page.file_path} ({len(page.blocks)} blocks) ---")
        for block in page.blocks:
            status = block.block_type.value.upper()
            preview = block.raw_content.split('\n')[0][:60]
            print(f" [{status:15s}] line {block.line_number}: {preview}")
            if block.skip_reason:
                print(f" skip reason: {block.skip_reason}")


def print_results(pages: "List[PageResult]") -> None:
    """Print execution results to console, followed by one entry per failure."""
    results = [result for page in pages for result in page.results]
    failures = [result for result in results if not result.success]
    total_tested = len(results)
    total_passed = total_tested - len(failures)

    total, _executable, skipped, fragments = _count_blocks(pages)
    rule = '=' * 60

    print(f"\n{rule}")
    print("SQL Test Results")
    print(rule)
    print(f"Pages tested: {len(pages)}")
    print(f"Total SQL blocks: {total}")
    print(f" Tested: {total_tested}")
    print(f" Passed: {total_passed}")
    print(f" Failed: {len(failures)}")
    print(f" Fragments: {fragments}")
    print(f" Skipped: {skipped}")
    print(rule)

    if not failures:
        print("\nAll tests passed.\n")
        return

    print("\nFailures:\n")
    for result in failures:
        block = result.block
        print(f" FAIL: {block.file_path}:{block.line_number}")
        # Show the first statement as context.
        if block.cleaned_statements:
            print(f" Statement: {block.cleaned_statements[0][:100]}")
        print(f" Error: {result.error_message}")
        print()


def write_github_comment(pages: "List[PageResult]", output_path: str = "sql-test-comment.md") -> None:
    """Write a GitHub PR comment markdown file.

    On failure the comment holds a summary table (first line of each error,
    at most 120 characters) and a collapsible section with the first
    statement and full error of every failing block.

    Args:
        pages: Pages whose `results` have been populated.
        output_path: File to write; overwritten if it exists.
    """
    results = [result for page in pages for result in page.results]
    failures = [result for result in results if not result.success]
    total_tested = len(results)
    total_passed = total_tested - len(failures)

    _total, _executable, skipped, fragments = _count_blocks(pages)

    lines = []
    if not failures:
        lines.append("**SQL Test Check Passed**")
        lines.append("")
        lines.append(f"Tested {total_tested} SQL blocks across {len(pages)} pages. All passed.")
    else:
        lines.append("**SQL Test Check Failed**")
        lines.append("")
        lines.append(f"Found {len(failures)} failure(s) out of {total_tested} tested SQL blocks.")
        lines.append("")
        lines.append("| File | Line | Error |")
        lines.append("|------|------|-------|")
        for result in failures:
            block = result.block
            error_brief = result.error_message.split('\n')[0][:120]
            # A raw '|' would end the table cell early.
            error_brief = error_brief.replace('|', '\\|')
            lines.append(f"| `{block.file_path}` | {block.line_number} | {error_brief} |")
        lines.append("")
        lines.append("<details>")
        lines.append("<summary>Failure details</summary>")
        lines.append("")
        for result in failures:
            block = result.block
            lines.append(f"### `{block.file_path}:{block.line_number}`")
            lines.append("")
            if block.cleaned_statements:
                lines.append("```sql")
                lines.append(block.cleaned_statements[0][:200])
                lines.append("```")
                lines.append("")
            lines.append(f"**Error:** {result.error_message}")
            lines.append("")
        lines.append("</details>")

    lines.append("")
    lines.append(f"**Summary:** {total_tested} tested, {total_passed} passed, {len(failures)} failed, {fragments} fragments, {skipped} skipped")

    # Explicit encoding: error text may contain non-ASCII characters.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines))
# NOTE: this span relies on the imports at the top of sql_test_runner.py
# (argparse, glob, os, sys and the sql_test package functions).

def collect_files(file_args: list = None, version: str = None) -> list:
    """Collect markdown files to test.

    Args:
        file_args: Explicitly provided file paths (may be None or empty).
        version: If set, find all markdown files under src/current/<version>/.

    Returns:
        List of file paths: the sorted version files first, then the
        explicit paths, with exact duplicate paths removed so that no page
        is executed twice.

    Exits with status 1 when the version directory does not exist.
    """
    files = []

    if version:
        # The repo root is two levels above this script (.github/scripts/).
        script_dir = os.path.dirname(os.path.abspath(__file__))
        repo_root = os.path.dirname(os.path.dirname(script_dir))
        version_dir = os.path.join(repo_root, "src", "current", version)
        if not os.path.isdir(version_dir):
            print(f"Error: version directory not found: {version_dir}", file=sys.stderr)
            sys.exit(1)
        pattern = os.path.join(version_dir, "**", "*.md")
        files = sorted(glob.glob(pattern, recursive=True))

    if file_args:
        files.extend(file_args)

    # dict.fromkeys keeps first-seen order while dropping repeats.
    return list(dict.fromkeys(files))


def main():
    """Parse the command line, extract SQL blocks, and optionally execute them.

    Exit status: 1 when no files were given or any block failed, else 0.
    """
    parser = argparse.ArgumentParser(
        description="Test SQL code blocks in CockroachDB documentation."
    )
    parser.add_argument(
        "files", nargs="*", help="Markdown files to test."
    )
    parser.add_argument(
        "--version", type=str, default=None,
        help="Test all files in a version directory (e.g., v25.4)."
    )
    parser.add_argument(
        "--connection-url", type=str, default=DEFAULT_CONNECTION_URL,
        help=f"CockroachDB connection URL (default: {DEFAULT_CONNECTION_URL})."
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="Extract and classify blocks only, no execution."
    )
    parser.add_argument(
        "--verbose", action="store_true",
        help="Show all blocks including skipped and fragments."
    )

    args = parser.parse_args()

    files = collect_files(args.files, args.version)
    if not files:
        print("No files to test. Provide file paths or --version.", file=sys.stderr)
        sys.exit(1)

    # Extraction drops missing files silently; surface them so a mistyped
    # path is not mistaken for a page without SQL.
    for path in files:
        if not os.path.isfile(path):
            print(f"Warning: file not found, skipping: {path}", file=sys.stderr)

    if args.verbose:
        print(f"Scanning {len(files)} file(s)...")

    pages = extract_from_files(files)

    if not pages:
        print("No SQL blocks found in the provided files.")
        sys.exit(0)

    if args.dry_run:
        print_dry_run(pages, verbose=args.verbose)
        sys.exit(0)

    for page in pages:
        if args.verbose:
            executable_count = sum(
                1 for b in page.blocks
                if b.block_type.value in ("executable", "expected_error")
            )
            print(f"Testing {page.file_path} ({executable_count} executable blocks)...")

        execute_page(page, connection_url=args.connection_url)

    has_failures = any(
        not result.success for page in pages for result in page.results
    )

    print_results(pages)

    # Write GitHub comment if in CI
    if os.environ.get("GITHUB_ACTIONS"):
        write_github_comment(pages)

    sys.exit(1 if has_failures else 0)


if __name__ == "__main__":
    main()
self.assertEqual(stmts, ["SELECT 1;"]) + + def test_multiline_statement(self): + raw = "> SELECT\n> id, name\n> FROM users;" + stmts = _clean_sql_lines(raw) + self.assertEqual(len(stmts), 1) + self.assertIn("SELECT", stmts[0]) + self.assertIn("FROM users;", stmts[0]) + + def test_multiple_statements(self): + raw = "> CREATE TABLE t (id INT);\n> INSERT INTO t VALUES (1);" + stmts = _clean_sql_lines(raw) + self.assertEqual(len(stmts), 2) + self.assertIn("CREATE TABLE", stmts[0]) + self.assertIn("INSERT INTO", stmts[1]) + + def test_empty_content(self): + self.assertEqual(_clean_sql_lines(""), []) + self.assertEqual(_clean_sql_lines(" \n "), []) + + +class TestHasPageLevelSkip(unittest.TestCase): + """Tests for page-level skip detection.""" + + def test_detects_skip(self): + content = "---\ntitle: Test\nsql_test: skip\n---\nBody" + self.assertTrue(_has_page_level_skip(content)) + + def test_no_skip(self): + content = "---\ntitle: Test\n---\nBody" + self.assertFalse(_has_page_level_skip(content)) + + def test_no_frontmatter(self): + content = "No frontmatter here\n~~~ sql\nSELECT 1;\n~~~" + self.assertFalse(_has_page_level_skip(content)) + + +class TestExtractBlocks(unittest.TestCase): + """Tests for block extraction and classification.""" + + def test_basic_executable_block(self): + md = """--- +title: Test +--- + +~~~ sql +> SELECT 1; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + block = result.blocks[0] + self.assertEqual(block.block_type, BlockType.EXECUTABLE) + self.assertEqual(block.cleaned_statements, ["SELECT 1;"]) + self.assertEqual(block.line_number, 5) + + def test_block_with_output(self): + md = """~~~ sql +> SELECT 1; +~~~ + +~~~ + ?column? 
++----------+ + 1 +(1 row) +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + block = result.blocks[0] + self.assertEqual(block.block_type, BlockType.EXECUTABLE) + self.assertIsNotNone(block.expected_output) + self.assertIn("?column?", block.expected_output) + + def test_expected_error_pq(self): + md = """~~~ sql +> INSERT INTO t VALUES (1); +~~~ + +~~~ +pq: duplicate key value violates unique constraint +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.EXPECTED_ERROR) + + def test_expected_error_ERROR(self): + md = """~~~ sql +> DROP TABLE nonexistent; +~~~ + +~~~ +ERROR: relation "nonexistent" does not exist +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.EXPECTED_ERROR) + + def test_fragment_with_ellipsis(self): + md = """~~~ sql +> SELECT ...; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.FRAGMENT) + + def test_fragment_with_placeholder(self): + md = """~~~ sql +ALTER ROLE SET copy_from_retries_enabled = true; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.FRAGMENT) + + def test_fragment_with_curly_placeholder(self): + md = """~~~ sql +ALTER ROLE {username} SET copy_from_retries_enabled = true; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.FRAGMENT) + + def test_fragment_with_remote_include(self): + md = """~~~ sql +{% remote_include https://example.com/snippet.sql %} +~~~ +""" + result = extract_blocks("test.md", content=md) + 
self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.FRAGMENT) + + def test_skip_annotation(self): + md = """ +~~~ sql +> SLEECT * FORM users; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + block = result.blocks[0] + self.assertEqual(block.block_type, BlockType.SKIPPED) + self.assertEqual(block.skip_reason, "Demonstrates invalid syntax") + + def test_skip_annotation_no_reason(self): + md = """ +~~~ sql +> SELECT 1; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.SKIPPED) + + def test_page_level_skip(self): + md = """--- +title: Test +sql_test: skip +--- + +~~~ sql +> SELECT 1; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.SKIPPED) + + def test_multiple_blocks_preserve_order(self): + md = """~~~ sql +> CREATE TABLE t (id INT PRIMARY KEY); +~~~ + +~~~ sql +> INSERT INTO t VALUES (1); +~~~ + +~~~ sql +> SELECT * FROM t; +~~~ + +~~~ + id ++----+ + 1 +(1 row) +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 3) + self.assertEqual(result.blocks[0].block_index, 0) + self.assertEqual(result.blocks[1].block_index, 1) + self.assertEqual(result.blocks[2].block_index, 2) + # Only the last block has expected output + self.assertIsNone(result.blocks[0].expected_output) + self.assertIsNone(result.blocks[1].expected_output) + self.assertIsNotNone(result.blocks[2].expected_output) + + def test_ignores_non_sql_blocks(self): + md = """~~~ shell +$ cockroach start --insecure +~~~ + +~~~ sql +> SELECT 1; +~~~ + +~~~ json +{"key": "value"} +~~~ +""" + result = extract_blocks("test.md", content=md) + # Should only extract the sql block, not shell or json + self.assertEqual(len(result.blocks), 1) + 
self.assertEqual(result.blocks[0].cleaned_statements, ["SELECT 1;"]) + + def test_no_sql_blocks(self): + md = """--- +title: No SQL +--- + +This page has no SQL blocks. + +~~~ shell +$ echo hello +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 0) + + def test_sql_without_prompt_prefix(self): + md = """~~~ sql +SELECT 1; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].cleaned_statements, ["SELECT 1;"]) + + def test_mixed_executable_and_fragment(self): + md = """~~~ sql +> SELECT * FROM users; +~~~ + +~~~ sql +> SELECT ... FROM ; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 2) + self.assertEqual(result.blocks[0].block_type, BlockType.EXECUTABLE) + self.assertEqual(result.blocks[1].block_type, BlockType.FRAGMENT) + + def test_block_line_numbers(self): + md = """Line 1 +Line 2 +Line 3 +~~~ sql +> SELECT 1; +~~~ +Line 7 +Line 8 +~~~ sql +> SELECT 2; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 2) + # ~~~ sql is on line 4 (1-indexed) + self.assertEqual(result.blocks[0].line_number, 4) + # ~~~ sql is on line 9 (1-indexed) + self.assertEqual(result.blocks[1].line_number, 9) + + +class TestExtractBlocksFromRealPatterns(unittest.TestCase): + """Tests using patterns found in actual CockroachDB docs.""" + + def test_movr_select_with_output(self): + """Pattern from select-clause.md.""" + md = """{% include_cached copy-clipboard.html %} +~~~ sql +> SELECT id, city, name FROM users LIMIT 10; +~~~ + +~~~ + id | city | name ++--------------------------------------+---------------+------------------+ + 7ae147ae-147a-4000-8000-000000000018 | los angeles | Alfred Garcia +(1 row) +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + block = result.blocks[0] + self.assertEqual(block.block_type, 
BlockType.EXECUTABLE) + self.assertEqual(block.cleaned_statements, ["SELECT id, city, name FROM users LIMIT 10;"]) + self.assertIn("Alfred Garcia", block.expected_output) + + def test_upsert_error_pattern(self): + """Pattern from upsert.md with pq: error output.""" + md = """~~~ sql +> UPSERT INTO unique_test VALUES (4, 1); +~~~ + +~~~ +pq: duplicate key value (b)=(1) violates unique constraint "unique_test_b_key" +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.EXPECTED_ERROR) + + def test_multiline_insert(self): + """Multi-line SQL statement.""" + md = """~~~ sql +> INSERT INTO user_promo_codes (city, user_id, code, "timestamp", usage_count) + VALUES ('new york', '147ae147-ae14-4b00-8000-000000000004', 'promo_code', now(), 1); +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(len(result.blocks[0].cleaned_statements), 1) + self.assertIn("INSERT INTO", result.blocks[0].cleaned_statements[0]) + + +if __name__ == '__main__': + unittest.main() diff --git a/.github/workflows/sql-test.yml b/.github/workflows/sql-test.yml new file mode 100644 index 00000000000..9b9c4b5263e --- /dev/null +++ b/.github/workflows/sql-test.yml @@ -0,0 +1,140 @@ +name: SQL Test Check + +on: + pull_request: + types: [opened, synchronize, reopened] + paths: + - 'src/current/v25.4/**/*.md' + schedule: + # Run nightly at 6am UTC + - cron: '0 6 * * *' + workflow_dispatch: + +jobs: + sql-test: + name: Test SQL code blocks + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + + - name: Install CockroachDB # NOTE(review): the binary is v25.1.2 but the pages under test are v25.4; confirm the older server is intended + run: | + curl https://binaries.cockroachdb.com/cockroach-v25.1.2.linux-amd64.tgz | tar -xz + sudo cp cockroach-v25.1.2.linux-amd64/cockroach /usr/local/bin/ + + - 
name: Start CockroachDB + run: | + cockroach start-single-node --insecure --background --store=type=mem,size=1GiB --listen-addr=localhost:26257 + sleep 5 + cockroach sql --insecure -e "SELECT 1;" + + - name: Get changed files + if: github.event_name == 'pull_request' + id: changed-files + uses: tj-actions/changed-files@cc08e170f4447237bcaf8acaacfa615b9cb86612 # v35 + with: + files: | + src/current/v25.4/**/*.md + separator: ' ' + + - name: Run SQL tests (PR - changed files only) + if: github.event_name == 'pull_request' && steps.changed-files.outputs.any_changed == 'true' + id: sql-test-pr + env: + # File names in a pull request are untrusted: pass them through the environment rather than expanding them into the script text. + # The variable is expanded unquoted below so the space-separated list splits into one argument per file. + CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }} + run: | + echo "Testing changed files..." + python .github/scripts/sql_test_runner.py $CHANGED_FILES + continue-on-error: true + + - name: Run SQL tests (scheduled/manual - all v25.4 files) + if: github.event_name != 'pull_request' + id: sql-test-full + run: | + echo "Testing all v25.4 files..." + python .github/scripts/sql_test_runner.py --version v25.4 + continue-on-error: true # NOTE(review): nothing later checks this step's outcome, so a failing scheduled run still ends green; confirm that is intended + + - name: Post PR comment with failures + if: github.event_name == 'pull_request' && steps.changed-files.outputs.any_changed == 'true' && steps.sql-test-pr.outcome == 'failure' + uses: actions/github-script@v6 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + + let comment = ''; + try { + comment = fs.readFileSync('sql-test-comment.md', 'utf8'); + } catch (error) { + comment = '**SQL Test Check Failed**\n\nSQL test failures were detected, but the detailed report could not be generated.'; + } + + const comments = await github.paginate(github.rest.issues.listComments, { // fetch every page; a single call returns only the first page of comments + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + }); + + const botComment = comments.find(c => + c.user.type === 'Bot' && + c.body.includes('SQL Test Check') + ); + + if (botComment) { + await github.rest.issues.updateComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: 
botComment.id, + body: comment + }); + } else { + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: comment + }); + } + + - name: Post success comment if previously failed + if: github.event_name == 'pull_request' && steps.changed-files.outputs.any_changed == 'true' && steps.sql-test-pr.outcome == 'success' + uses: actions/github-script@v6 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const comments = await github.paginate(github.rest.issues.listComments, { // fetch every page; a single call returns only the first page of comments + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + }); + + const botComment = comments.find(c => + c.user.type === 'Bot' && + c.body.includes('SQL Test Check Failed') + ); + + if (botComment) { + await github.rest.issues.updateComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: botComment.id, + body: '**SQL Test Check Passed**\n\nAll SQL test issues have been resolved.' 
+ }); + } + + - name: Stop CockroachDB + if: always() + run: | + cockroach quit --insecure --host=localhost:26257 || true diff --git a/src/current/Makefile b/src/current/Makefile index f9ee8cddc49..6c3375f60f2 100644 --- a/src/current/Makefile +++ b/src/current/Makefile @@ -85,6 +85,20 @@ linkcheck: cockroachdb-build vale: vale $(subst $(\n), $( ), $(shell git status --porcelain | cut -c 4- | egrep "\.md")) +.PHONY: sql-test +sql-test: + cockroach start-single-node --insecure --background --store=type=mem,size=1GiB --listen-addr=localhost:26257 + sleep 5 + cockroach workload init movr 'postgresql://root@localhost:26257?sslmode=disable' || true + python3 ../../.github/scripts/sql_test_runner.py --version v25.4; \ + EXIT_CODE=$$?; \ + cockroach quit --insecure --host=localhost:26257 || true; \ + exit $$EXIT_CODE + +.PHONY: sql-test-dry-run +sql-test-dry-run: + python3 ../../.github/scripts/sql_test_runner.py --dry-run --verbose --version v25.4 + .PHONY: vendor vendor: gem install bundler