From 36c33d93788c077254560b075e57d385795e29d3 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 28 Mar 2026 07:07:28 +0000 Subject: [PATCH] Add scope graph + bidirectional type checking for SQL analysis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement a new SQL analysis system based on two concepts from programming language theory: 1. Scope graphs (internal/analysis/scope/): Model SQL name resolution as path-finding in a labeled graph. Each scope contains declarations (columns, tables, aliases) connected by edges (PARENT, ALIAS, LATERAL, OUTER). This handles joins, subqueries, CTEs, and aliases compositionally — the resolution algorithm doesn't change, only the graph structure does. 2. Bidirectional type checking (internal/analysis/typecheck/): Type information flows in two directions — synthesis (bottom-up: "what type does this expression have?") and checking (top-down: "does this match the expected type?"). This naturally handles parameter type inference: when $1 appears in `WHERE age > $1`, checking mode infers $1's type from the column's type. 3. SQL analyzer (internal/analysis/sqlanalyze/): Combines both systems by walking the sqlc AST, building scope graphs from FROM/JOIN/CTE clauses, and running bidirectional type checking on expressions. Supports SELECT, INSERT, UPDATE, DELETE with parameter inference, output column resolution, and engine-specific operator rules for PostgreSQL and MySQL. 
https://claude.ai/code/session_01VFJemaXKRZ2NfxYkpwXSbD --- internal/analysis/scope/resolve.go | 274 ++++ internal/analysis/scope/scope.go | 214 ++++ internal/analysis/scope/scope_test.go | 306 +++++ internal/analysis/sqlanalyze/analyzer.go | 1122 +++++++++++++++++ internal/analysis/sqlanalyze/analyzer_test.go | 579 +++++++++ internal/analysis/typecheck/expr.go | 139 ++ internal/analysis/typecheck/rules.go | 166 +++ internal/analysis/typecheck/typecheck_test.go | 370 ++++++ internal/analysis/typecheck/types.go | 286 +++++ 9 files changed, 3456 insertions(+) create mode 100644 internal/analysis/scope/resolve.go create mode 100644 internal/analysis/scope/scope.go create mode 100644 internal/analysis/scope/scope_test.go create mode 100644 internal/analysis/sqlanalyze/analyzer.go create mode 100644 internal/analysis/sqlanalyze/analyzer_test.go create mode 100644 internal/analysis/typecheck/expr.go create mode 100644 internal/analysis/typecheck/rules.go create mode 100644 internal/analysis/typecheck/typecheck_test.go create mode 100644 internal/analysis/typecheck/types.go diff --git a/internal/analysis/scope/resolve.go b/internal/analysis/scope/resolve.go new file mode 100644 index 0000000000..af11b47a44 --- /dev/null +++ b/internal/analysis/scope/resolve.go @@ -0,0 +1,274 @@ +package scope + +import "fmt" + +// ResolutionError describes why name resolution failed with full provenance. 
+type ResolutionError struct { + Name string + Qualifier string // Table/alias qualifier, if any + Kind ResolutionErrorKind + Scope *Scope // The scope where resolution was attempted + Candidates []string // For ambiguity errors, the competing names + Location int // Source position of the reference +} + +type ResolutionErrorKind int + +const ( + ErrNotFound ResolutionErrorKind = iota + ErrAmbiguous + ErrQualifierNotFound // e.g., "u.name" but "u" doesn't exist +) + +func (e *ResolutionError) Error() string { + switch e.Kind { + case ErrNotFound: + if e.Qualifier != "" { + return fmt.Sprintf("column %q not found in %q", e.Name, e.Qualifier) + } + return fmt.Sprintf("column %q does not exist", e.Name) + case ErrAmbiguous: + return fmt.Sprintf("column reference %q is ambiguous", e.Name) + case ErrQualifierNotFound: + return fmt.Sprintf("table or alias %q does not exist", e.Qualifier) + default: + return fmt.Sprintf("resolution error for %q", e.Name) + } +} + +// ResolutionPath records the edges traversed during successful resolution. +// This is the provenance — it tells you exactly how a name was resolved. +type ResolutionPath struct { + Steps []ResolutionStep +} + +type ResolutionStep struct { + Edge *Edge // nil for the final lookup step + Scope *Scope // The scope where this step occurred +} + +// ResolvedName is the result of successful name resolution. +type ResolvedName struct { + Declaration *Declaration + Path ResolutionPath +} + +// Resolve looks up an unqualified column name in this scope. +// It searches local declarations first, then follows parent edges. +// Returns an error if the name is not found or is ambiguous. +func (s *Scope) Resolve(name string) (*ResolvedName, error) { + return s.resolve(name, nil, 0) +} + +// ResolveQualified looks up a qualified name like "u.name". +// First resolves the qualifier (table/alias), then looks up the column +// in that table's scope. 
+func (s *Scope) ResolveQualified(qualifier, name string) (*ResolvedName, error) { + // First, find the qualifier (table name or alias) + qualScope, err := s.resolveQualifier(qualifier, 0) + if err != nil { + return nil, &ResolutionError{ + Name: name, + Qualifier: qualifier, + Kind: ErrQualifierNotFound, + Scope: s, + } + } + + // Then resolve the column within that scope + var matches []*Declaration + for _, d := range qualScope.Declarations { + if d.Name == name { + matches = append(matches, d) + } + } + + if len(matches) == 0 { + return nil, &ResolutionError{ + Name: name, + Qualifier: qualifier, + Kind: ErrNotFound, + Scope: qualScope, + } + } + if len(matches) > 1 { + return nil, &ResolutionError{ + Name: name, + Qualifier: qualifier, + Kind: ErrAmbiguous, + Scope: qualScope, + } + } + + return &ResolvedName{ + Declaration: matches[0], + Path: ResolutionPath{ + Steps: []ResolutionStep{ + {Scope: s}, + {Scope: qualScope}, + }, + }, + }, nil +} + +const maxResolutionDepth = 20 + +// resolve performs recursive name resolution with cycle detection via depth limit. 
+func (s *Scope) resolve(name string, visited map[*Scope]bool, depth int) (*ResolvedName, error) { + if depth > maxResolutionDepth { + return nil, &ResolutionError{Name: name, Kind: ErrNotFound, Scope: s} + } + if visited == nil { + visited = make(map[*Scope]bool) + } + if visited[s] { + return nil, &ResolutionError{Name: name, Kind: ErrNotFound, Scope: s} + } + visited[s] = true + + // Search local declarations first + var matches []*Declaration + for _, d := range s.Declarations { + if d.Name == name && d.Kind == DeclColumn { + matches = append(matches, d) + } + } + + // Also search table/alias declarations to find columns inside their scopes + for _, d := range s.Declarations { + if (d.Kind == DeclTable || d.Kind == DeclAlias || d.Kind == DeclCTE) && d.Scope != nil { + for _, cd := range d.Scope.Declarations { + if cd.Name == name && cd.Kind == DeclColumn { + matches = append(matches, cd) + } + } + } + } + + if len(matches) == 1 { + return &ResolvedName{ + Declaration: matches[0], + Path: ResolutionPath{ + Steps: []ResolutionStep{{Scope: s}}, + }, + }, nil + } + if len(matches) > 1 { + return nil, &ResolutionError{Name: name, Kind: ErrAmbiguous, Scope: s} + } + + // Follow parent, lateral, and outer edges + for _, edge := range s.Edges { + switch edge.Kind { + case EdgeParent, EdgeLateral, EdgeOuter: + result, err := edge.Target.resolve(name, visited, depth+1) + if err == nil { + result.Path.Steps = append([]ResolutionStep{{Edge: edge, Scope: s}}, result.Path.Steps...) + return result, nil + } + // Propagate ambiguity errors — don't swallow them + if resErr, ok := err.(*ResolutionError); ok && resErr.Kind == ErrAmbiguous { + return nil, resErr + } + } + } + + return nil, &ResolutionError{Name: name, Kind: ErrNotFound, Scope: s} +} + +// resolveQualifier finds the scope associated with a table name or alias. 
+func (s *Scope) resolveQualifier(qualifier string, depth int) (*Scope, error) { + if depth > maxResolutionDepth { + return nil, fmt.Errorf("qualifier %q not found", qualifier) + } + + // Check alias edges first (higher priority) + for _, edge := range s.Edges { + if edge.Kind == EdgeAlias && edge.Label == qualifier { + return edge.Target, nil + } + } + + // Check local table/alias declarations + for _, d := range s.Declarations { + if d.Name == qualifier && (d.Kind == DeclTable || d.Kind == DeclAlias || d.Kind == DeclCTE) && d.Scope != nil { + return d.Scope, nil + } + } + + // Follow parent edges + for _, edge := range s.Edges { + if edge.Kind == EdgeParent || edge.Kind == EdgeLateral || edge.Kind == EdgeOuter { + result, err := edge.Target.resolveQualifier(qualifier, depth+1) + if err == nil { + return result, nil + } + } + } + + return nil, fmt.Errorf("qualifier %q not found", qualifier) +} + +// ResolveColumnRef resolves a column reference that may have 1, 2, or 3 parts: +// - ["name"] -> unqualified column +// - ["alias", "name"] -> table-qualified column +// - ["schema", "table", "name"] -> schema-qualified column (treated as qualifier="table") +func (s *Scope) ResolveColumnRef(parts []string) (*ResolvedName, error) { + switch len(parts) { + case 1: + return s.Resolve(parts[0]) + case 2: + return s.ResolveQualified(parts[0], parts[1]) + case 3: + // For now, ignore schema and use table.column + return s.ResolveQualified(parts[1], parts[2]) + default: + return nil, fmt.Errorf("invalid column reference with %d parts", len(parts)) + } +} + +// AllColumns returns all column declarations visible from this scope, +// optionally filtered by a qualifier. This is used for SELECT * expansion. 
+func (s *Scope) AllColumns(qualifier string) []*Declaration { + if qualifier != "" { + qualScope, err := s.resolveQualifier(qualifier, 0) + if err != nil { + return nil + } + var cols []*Declaration + for _, d := range qualScope.Declarations { + if d.Kind == DeclColumn { + cols = append(cols, d) + } + } + return cols + } + + // Collect from all table/alias declarations in this scope + var cols []*Declaration + seen := make(map[string]bool) + + var collect func(sc *Scope, depth int) + collect = func(sc *Scope, depth int) { + if depth > maxResolutionDepth { + return + } + for _, d := range sc.Declarations { + if (d.Kind == DeclTable || d.Kind == DeclAlias || d.Kind == DeclCTE) && d.Scope != nil { + for _, cd := range d.Scope.Declarations { + if cd.Kind == DeclColumn && !seen[d.Name+"."+cd.Name] { + seen[d.Name+"."+cd.Name] = true + cols = append(cols, cd) + } + } + } + } + for _, edge := range sc.Edges { + if edge.Kind == EdgeParent { + collect(edge.Target, depth+1) + } + } + } + collect(s, 0) + return cols +} diff --git a/internal/analysis/scope/scope.go b/internal/analysis/scope/scope.go new file mode 100644 index 0000000000..57f41505aa --- /dev/null +++ b/internal/analysis/scope/scope.go @@ -0,0 +1,214 @@ +// Package scope implements scope graphs for SQL name resolution. +// +// A scope graph models the visibility and accessibility of names (columns, +// tables, aliases, functions) in a SQL query. Each scope is a node in the +// graph, containing declarations and connected to other scopes via labeled +// edges. Name resolution is path-finding in this graph. +// +// This approach is inspired by the Statix/scope graph framework from +// programming language theory, adapted for SQL's particular scoping rules. +package scope + +import "fmt" + +// EdgeKind labels the relationship between two scopes. +type EdgeKind int + +const ( + // EdgeParent links a child scope to its parent (e.g., WHERE -> FROM). 
+ EdgeParent EdgeKind = iota + // EdgeAlias links an alias name to the scope it refers to (e.g., "u" -> users table scope). + EdgeAlias + // EdgeLateral links a LATERAL subquery to preceding FROM items. + EdgeLateral + // EdgeOuter links a correlated subquery to its outer query scope. + EdgeOuter +) + +func (k EdgeKind) String() string { + switch k { + case EdgeParent: + return "PARENT" + case EdgeAlias: + return "ALIAS" + case EdgeLateral: + return "LATERAL" + case EdgeOuter: + return "OUTER" + default: + return fmt.Sprintf("EdgeKind(%d)", int(k)) + } +} + +// DeclKind describes what kind of entity a declaration represents. +type DeclKind int + +const ( + DeclColumn DeclKind = iota + DeclTable + DeclCTE + DeclFunction + DeclAlias // A table alias (e.g., "u" in "FROM users AS u") +) + +func (k DeclKind) String() string { + switch k { + case DeclColumn: + return "column" + case DeclTable: + return "table" + case DeclCTE: + return "CTE" + case DeclFunction: + return "function" + case DeclAlias: + return "alias" + default: + return fmt.Sprintf("DeclKind(%d)", int(k)) + } +} + +// Type represents a SQL type within the scope system. It's kept simple +// and engine-agnostic — detailed type information lives in the catalog. +type Type struct { + Name string // e.g., "integer", "text", "boolean" + Schema string // e.g., "pg_catalog", "" for default + NotNull bool + IsArray bool + ArrayDims int + Unsigned bool + Length *int +} + +// Equals checks structural type equality (ignoring nullability). +func (t Type) Equals(other Type) bool { + return t.Name == other.Name && t.Schema == other.Schema && t.IsArray == other.IsArray +} + +// IsUnknown returns true if this type hasn't been determined yet. 
+func (t Type) IsUnknown() bool { + return t.Name == "" || t.Name == "any" +} + +var ( + TypeUnknown = Type{Name: "any"} + TypeInt = Type{Name: "integer", NotNull: true} + TypeText = Type{Name: "text", NotNull: true} + TypeBool = Type{Name: "boolean", NotNull: true} + TypeFloat = Type{Name: "float", NotNull: true} + TypeNumeric = Type{Name: "numeric", NotNull: true} +) + +// Declaration is a named entity visible within a scope. +type Declaration struct { + Name string + Kind DeclKind + Type Type + Scope *Scope // For table/CTE declarations, the scope containing their columns + Location int // Source position for error reporting +} + +// Edge connects one scope to another with a labeled relationship. +type Edge struct { + Kind EdgeKind + Label string // For EdgeAlias, the alias name + Target *Scope +} + +// ScopeKind describes the syntactic context that created this scope. +type ScopeKind int + +const ( + ScopeRoot ScopeKind = iota + ScopeFrom + ScopeJoin + ScopeWhere + ScopeSelect + ScopeHaving + ScopeOrderBy + ScopeSubquery + ScopeCTE + ScopeInsert + ScopeUpdate + ScopeDelete + ScopeValues + ScopeReturning + ScopeFunction +) + +func (k ScopeKind) String() string { + names := [...]string{ + "ROOT", "FROM", "JOIN", "WHERE", "SELECT", "HAVING", + "ORDER_BY", "SUBQUERY", "CTE", "INSERT", "UPDATE", + "DELETE", "VALUES", "RETURNING", "FUNCTION", + } + if int(k) < len(names) { + return names[k] + } + return fmt.Sprintf("ScopeKind(%d)", int(k)) +} + +// Scope is a node in the scope graph. It contains declarations and +// edges to other scopes. +type Scope struct { + Kind ScopeKind + Declarations []*Declaration + Edges []*Edge + Location int // Source position of the construct that created this scope +} + +// NewScope creates a new empty scope of the given kind. +func NewScope(kind ScopeKind) *Scope { + return &Scope{ + Kind: kind, + } +} + +// Declare adds a declaration to this scope. 
+func (s *Scope) Declare(d *Declaration) { + s.Declarations = append(s.Declarations, d) +} + +// DeclareColumn is a convenience method for declaring a column. +func (s *Scope) DeclareColumn(name string, typ Type, location int) *Declaration { + d := &Declaration{ + Name: name, + Kind: DeclColumn, + Type: typ, + Location: location, + } + s.Declare(d) + return d +} + +// DeclareTable adds a table declaration with its own column scope. +func (s *Scope) DeclareTable(name string, columnScope *Scope, location int) *Declaration { + d := &Declaration{ + Name: name, + Kind: DeclTable, + Type: TypeUnknown, + Scope: columnScope, + Location: location, + } + s.Declare(d) + return d +} + +// AddEdge connects this scope to another scope via a labeled edge. +func (s *Scope) AddEdge(kind EdgeKind, label string, target *Scope) { + s.Edges = append(s.Edges, &Edge{ + Kind: kind, + Label: label, + Target: target, + }) +} + +// AddParent connects this scope to a parent scope. +func (s *Scope) AddParent(parent *Scope) { + s.AddEdge(EdgeParent, "", parent) +} + +// AddAlias connects an alias name to a target scope. 
+func (s *Scope) AddAlias(alias string, target *Scope) { + s.AddEdge(EdgeAlias, alias, target) +} diff --git a/internal/analysis/scope/scope_test.go b/internal/analysis/scope/scope_test.go new file mode 100644 index 0000000000..fa14486a4b --- /dev/null +++ b/internal/analysis/scope/scope_test.go @@ -0,0 +1,306 @@ +package scope + +import ( + "testing" +) + +func TestResolveUnqualified(t *testing.T) { + // Build scope graph: + // [FROM scope] has table "users" with columns {id, name, email} + // [SELECT scope] → PARENT → [FROM scope] + + usersScope := NewScope(ScopeFrom) + usersScope.DeclareColumn("id", Type{Name: "integer", NotNull: true}, 0) + usersScope.DeclareColumn("name", Type{Name: "text", NotNull: true}, 0) + usersScope.DeclareColumn("email", Type{Name: "text", NotNull: false}, 0) + + fromScope := NewScope(ScopeFrom) + fromScope.DeclareTable("users", usersScope, 0) + + selectScope := NewScope(ScopeSelect) + selectScope.AddParent(fromScope) + + // Resolve "name" from SELECT scope + resolved, err := selectScope.Resolve("name") + if err != nil { + t.Fatalf("expected to resolve 'name', got error: %v", err) + } + if resolved.Declaration.Name != "name" { + t.Errorf("expected declaration name 'name', got %q", resolved.Declaration.Name) + } + if resolved.Declaration.Type.Name != "text" { + t.Errorf("expected type 'text', got %q", resolved.Declaration.Type.Name) + } + if !resolved.Declaration.Type.NotNull { + t.Error("expected 'name' to be NOT NULL") + } +} + +func TestResolveQualified(t *testing.T) { + // SELECT u.name FROM users AS u + + usersScope := NewScope(ScopeFrom) + usersScope.DeclareColumn("id", Type{Name: "integer", NotNull: true}, 0) + usersScope.DeclareColumn("name", Type{Name: "text", NotNull: true}, 0) + + fromScope := NewScope(ScopeFrom) + fromScope.AddAlias("u", usersScope) + fromScope.Declare(&Declaration{ + Name: "u", + Kind: DeclAlias, + Type: TypeUnknown, + Scope: usersScope, + }) + + selectScope := NewScope(ScopeSelect) + 
selectScope.AddParent(fromScope) + + // Resolve "u.name" + resolved, err := selectScope.ResolveQualified("u", "name") + if err != nil { + t.Fatalf("expected to resolve 'u.name', got error: %v", err) + } + if resolved.Declaration.Name != "name" { + t.Errorf("expected 'name', got %q", resolved.Declaration.Name) + } +} + +func TestResolveAmbiguous(t *testing.T) { + // SELECT id FROM users JOIN orders ON ... + // Both tables have an 'id' column → should be ambiguous + + usersScope := NewScope(ScopeFrom) + usersScope.DeclareColumn("id", TypeInt, 0) + usersScope.DeclareColumn("name", TypeText, 0) + + ordersScope := NewScope(ScopeFrom) + ordersScope.DeclareColumn("id", TypeInt, 0) + ordersScope.DeclareColumn("total", TypeNumeric, 0) + + fromScope := NewScope(ScopeFrom) + fromScope.DeclareTable("users", usersScope, 0) + fromScope.DeclareTable("orders", ordersScope, 0) + + selectScope := NewScope(ScopeSelect) + selectScope.AddParent(fromScope) + + _, err := selectScope.Resolve("id") + if err == nil { + t.Fatal("expected ambiguity error for 'id', got nil") + } + resErr, ok := err.(*ResolutionError) + if !ok { + t.Fatalf("expected *ResolutionError, got %T", err) + } + if resErr.Kind != ErrAmbiguous { + t.Errorf("expected ErrAmbiguous, got %v", resErr.Kind) + } +} + +func TestResolveNotFound(t *testing.T) { + fromScope := NewScope(ScopeFrom) + selectScope := NewScope(ScopeSelect) + selectScope.AddParent(fromScope) + + _, err := selectScope.Resolve("nonexistent") + if err == nil { + t.Fatal("expected error for nonexistent column") + } + resErr, ok := err.(*ResolutionError) + if !ok { + t.Fatalf("expected *ResolutionError, got %T", err) + } + if resErr.Kind != ErrNotFound { + t.Errorf("expected ErrNotFound, got %v", resErr.Kind) + } +} + +func TestResolveQualifierNotFound(t *testing.T) { + fromScope := NewScope(ScopeFrom) + selectScope := NewScope(ScopeSelect) + selectScope.AddParent(fromScope) + + _, err := selectScope.ResolveQualified("nonexistent", "col") + if err == nil { + 
t.Fatal("expected error for nonexistent qualifier") + } + resErr, ok := err.(*ResolutionError) + if !ok { + t.Fatalf("expected *ResolutionError, got %T", err) + } + if resErr.Kind != ErrQualifierNotFound { + t.Errorf("expected ErrQualifierNotFound, got %v", resErr.Kind) + } +} + +func TestResolveColumnRef(t *testing.T) { + usersScope := NewScope(ScopeFrom) + usersScope.DeclareColumn("id", TypeInt, 0) + usersScope.DeclareColumn("name", TypeText, 0) + + fromScope := NewScope(ScopeFrom) + fromScope.AddAlias("u", usersScope) + fromScope.Declare(&Declaration{Name: "u", Kind: DeclAlias, Scope: usersScope}) + + selectScope := NewScope(ScopeSelect) + selectScope.AddParent(fromScope) + + tests := []struct { + parts []string + wantName string + wantErr bool + }{ + {[]string{"name"}, "name", false}, + {[]string{"u", "name"}, "name", false}, + {[]string{"public", "u", "name"}, "name", false}, + {[]string{"nonexistent"}, "", true}, + {[]string{"u", "nonexistent"}, "", true}, + } + + for _, tt := range tests { + resolved, err := selectScope.ResolveColumnRef(tt.parts) + if tt.wantErr { + if err == nil { + t.Errorf("ResolveColumnRef(%v): expected error, got nil", tt.parts) + } + continue + } + if err != nil { + t.Errorf("ResolveColumnRef(%v): unexpected error: %v", tt.parts, err) + continue + } + if resolved.Declaration.Name != tt.wantName { + t.Errorf("ResolveColumnRef(%v): got name %q, want %q", tt.parts, resolved.Declaration.Name, tt.wantName) + } + } +} + +func TestAllColumns(t *testing.T) { + usersScope := NewScope(ScopeFrom) + usersScope.DeclareColumn("id", TypeInt, 0) + usersScope.DeclareColumn("name", TypeText, 0) + + ordersScope := NewScope(ScopeFrom) + ordersScope.DeclareColumn("id", TypeInt, 0) + ordersScope.DeclareColumn("total", TypeNumeric, 0) + + fromScope := NewScope(ScopeFrom) + fromScope.AddAlias("u", usersScope) + fromScope.Declare(&Declaration{Name: "u", Kind: DeclAlias, Scope: usersScope}) + fromScope.AddAlias("o", ordersScope) + 
fromScope.Declare(&Declaration{Name: "o", Kind: DeclAlias, Scope: ordersScope}) + + selectScope := NewScope(ScopeSelect) + selectScope.AddParent(fromScope) + + // All columns (SELECT *) + all := selectScope.AllColumns("") + if len(all) != 4 { + t.Errorf("AllColumns(''): got %d columns, want 4", len(all)) + } + + // Qualified (SELECT u.*) + uCols := selectScope.AllColumns("u") + if len(uCols) != 2 { + t.Errorf("AllColumns('u'): got %d columns, want 2", len(uCols)) + } + for _, c := range uCols { + if c.Name != "id" && c.Name != "name" { + t.Errorf("AllColumns('u'): unexpected column %q", c.Name) + } + } +} + +func TestCTEScope(t *testing.T) { + // WITH active_users AS (SELECT ...) SELECT * FROM active_users + + cteScope := NewScope(ScopeCTE) + cteScope.DeclareColumn("id", TypeInt, 0) + cteScope.DeclareColumn("name", TypeText, 0) + + fromScope := NewScope(ScopeFrom) + fromScope.Declare(&Declaration{ + Name: "active_users", + Kind: DeclCTE, + Type: TypeUnknown, + Scope: cteScope, + }) + + selectScope := NewScope(ScopeSelect) + selectScope.AddParent(fromScope) + + // Resolve "name" from the CTE + resolved, err := selectScope.Resolve("name") + if err != nil { + t.Fatalf("expected to resolve 'name' from CTE, got error: %v", err) + } + if resolved.Declaration.Type.Name != "text" { + t.Errorf("expected type 'text', got %q", resolved.Declaration.Type.Name) + } + + // Resolve qualified "active_users.id" + resolved, err = selectScope.ResolveQualified("active_users", "id") + if err != nil { + t.Fatalf("expected to resolve 'active_users.id', got error: %v", err) + } + if resolved.Declaration.Type.Name != "integer" { + t.Errorf("expected type 'integer', got %q", resolved.Declaration.Type.Name) + } +} + +func TestJoinScope(t *testing.T) { + // SELECT u.name, o.total + // FROM users AS u + // JOIN orders AS o ON u.id = o.user_id + + usersScope := NewScope(ScopeFrom) + usersScope.DeclareColumn("id", TypeInt, 0) + usersScope.DeclareColumn("name", TypeText, 0) + + ordersScope := 
NewScope(ScopeFrom) + ordersScope.DeclareColumn("id", TypeInt, 0) + ordersScope.DeclareColumn("user_id", TypeInt, 0) + ordersScope.DeclareColumn("total", TypeNumeric, 0) + + fromScope := NewScope(ScopeFrom) + fromScope.AddAlias("u", usersScope) + fromScope.Declare(&Declaration{Name: "u", Kind: DeclAlias, Scope: usersScope}) + fromScope.AddAlias("o", ordersScope) + fromScope.Declare(&Declaration{Name: "o", Kind: DeclAlias, Scope: ordersScope}) + + selectScope := NewScope(ScopeSelect) + selectScope.AddParent(fromScope) + + // u.name should resolve + resolved, err := selectScope.ResolveQualified("u", "name") + if err != nil { + t.Fatalf("u.name: %v", err) + } + if resolved.Declaration.Type.Name != "text" { + t.Errorf("u.name type: got %q, want 'text'", resolved.Declaration.Type.Name) + } + + // o.total should resolve + resolved, err = selectScope.ResolveQualified("o", "total") + if err != nil { + t.Fatalf("o.total: %v", err) + } + if resolved.Declaration.Type.Name != "numeric" { + t.Errorf("o.total type: got %q, want 'numeric'", resolved.Declaration.Type.Name) + } + + // Unqualified "total" should resolve (only in orders) + resolved, err = selectScope.Resolve("total") + if err != nil { + t.Fatalf("total: %v", err) + } + if resolved.Declaration.Type.Name != "numeric" { + t.Errorf("total type: got %q, want 'numeric'", resolved.Declaration.Type.Name) + } + + // Unqualified "id" should be ambiguous + _, err = selectScope.Resolve("id") + if err == nil { + t.Fatal("expected ambiguity for 'id'") + } +} diff --git a/internal/analysis/sqlanalyze/analyzer.go b/internal/analysis/sqlanalyze/analyzer.go new file mode 100644 index 0000000000..67616dbfdf --- /dev/null +++ b/internal/analysis/sqlanalyze/analyzer.go @@ -0,0 +1,1122 @@ +// Package sqlanalyze implements SQL analysis using scope graphs for name +// resolution and bidirectional type checking for type inference. 
+// +// It walks the sqlc AST, building a scope graph that models the visibility +// of tables, columns, and aliases. It then uses bidirectional type checking +// to infer parameter types and validate expressions. +// +// This package works with both PostgreSQL and MySQL ASTs (via the sqlc +// unified AST representation). +package sqlanalyze + +import ( + "fmt" + + "github.com/sqlc-dev/sqlc/internal/analysis/scope" + "github.com/sqlc-dev/sqlc/internal/analysis/typecheck" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/astutils" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +// Result holds the analysis results for a single query. +type Result struct { + // The root scope of the query's scope graph. + RootScope *scope.Scope + // Inferred types for query parameters. + ParamTypes map[int]*typecheck.ParamTypeInference + // Output columns with their resolved types. + OutputColumns []OutputColumn + // Type errors found during analysis. + Errors []typecheck.TypeError +} + +// OutputColumn describes a column in the query's result set. +type OutputColumn struct { + Name string + Type scope.Type + TableRef string // The table this column came from, if any +} + +// Analyzer performs scope-graph-based analysis on SQL queries. +type Analyzer struct { + catalog *catalog.Catalog + engine config.Engine + checker *typecheck.Checker +} + +// New creates a new analyzer for the given catalog and engine. 
+func New(cat *catalog.Catalog, engine config.Engine) *Analyzer { + var rules typecheck.OperatorRules + switch engine { + case config.EngineMySQL: + rules = &typecheck.MySQLOperatorRules{} + case config.EnginePostgreSQL: + rules = &typecheck.PostgreSQLOperatorRules{} + default: + rules = &typecheck.DefaultOperatorRules{} + } + + return &Analyzer{ + catalog: cat, + engine: engine, + checker: typecheck.NewChecker(rules), + } +} + +// AnalyzeQuery performs full analysis on a SQL statement: builds the scope +// graph, resolves names, infers types bidirectionally, and returns results. +func (a *Analyzer) AnalyzeQuery(raw *ast.RawStmt) (*Result, error) { + if raw == nil || raw.Stmt == nil { + return nil, fmt.Errorf("nil statement") + } + + result := &Result{ + ParamTypes: make(map[int]*typecheck.ParamTypeInference), + } + + // Build the scope graph from the AST + rootScope, err := a.buildScopeGraph(raw.Stmt) + if err != nil { + return nil, fmt.Errorf("building scope graph: %w", err) + } + result.RootScope = rootScope + + // Walk expressions to perform bidirectional type checking + a.typeCheckStatement(raw.Stmt, rootScope) + + // Collect results + result.ParamTypes = a.checker.ParamTypes() + result.Errors = a.checker.Errors() + + // Compute output columns + outputCols, err := a.computeOutputColumns(raw.Stmt, rootScope) + if err != nil { + return nil, fmt.Errorf("computing output columns: %w", err) + } + result.OutputColumns = outputCols + + return result, nil +} + +// buildScopeGraph constructs the scope graph for a SQL statement. 
+func (a *Analyzer) buildScopeGraph(stmt ast.Node) (*scope.Scope, error) { + switch n := stmt.(type) { + case *ast.SelectStmt: + return a.buildSelectScope(n) + case *ast.InsertStmt: + return a.buildInsertScope(n) + case *ast.UpdateStmt: + return a.buildUpdateScope(n) + case *ast.DeleteStmt: + return a.buildDeleteScope(n) + default: + return scope.NewScope(scope.ScopeRoot), nil + } +} + +// buildSelectScope builds the scope graph for a SELECT statement. +// +// The scope structure for SELECT is: +// +// [CTE scope] (if WITH clause exists) +// | +// [FROM scope] ← contains table declarations + aliases +// | +// [WHERE scope] → PARENT → [FROM scope] +// | +// [SELECT scope] → PARENT → [FROM scope] +func (a *Analyzer) buildSelectScope(sel *ast.SelectStmt) (*scope.Scope, error) { + if sel == nil { + return scope.NewScope(scope.ScopeRoot), nil + } + + // Handle UNION queries + if sel.Larg != nil { + return a.buildSelectScope(sel.Larg) + } + + // Build FROM scope + fromScope := scope.NewScope(scope.ScopeFrom) + + // Process CTEs first (WITH clause) + if sel.WithClause != nil { + if err := a.processCTEs(sel.WithClause, fromScope); err != nil { + return nil, err + } + } + + // Process FROM clause + if sel.FromClause != nil { + for _, item := range sel.FromClause.Items { + if err := a.processFromItem(item, fromScope); err != nil { + return nil, err + } + } + } + + // The SELECT scope has the FROM scope as parent + selectScope := scope.NewScope(scope.ScopeSelect) + selectScope.AddParent(fromScope) + + return selectScope, nil +} + +// buildInsertScope builds the scope graph for an INSERT statement. 
+func (a *Analyzer) buildInsertScope(ins *ast.InsertStmt) (*scope.Scope, error) { + insertScope := scope.NewScope(scope.ScopeInsert) + + if ins.WithClause != nil { + if err := a.processCTEs(ins.WithClause, insertScope); err != nil { + return nil, err + } + } + + // Add the target table + if ins.Relation != nil { + if err := a.processFromItem(ins.Relation, insertScope); err != nil { + return nil, err + } + } + + return insertScope, nil +} + +// buildUpdateScope builds the scope graph for an UPDATE statement. +func (a *Analyzer) buildUpdateScope(upd *ast.UpdateStmt) (*scope.Scope, error) { + updateScope := scope.NewScope(scope.ScopeUpdate) + + if upd.WithClause != nil { + if err := a.processCTEs(upd.WithClause, updateScope); err != nil { + return nil, err + } + } + + // Add tables from Relations + if upd.Relations != nil { + for _, item := range upd.Relations.Items { + if err := a.processFromItem(item, updateScope); err != nil { + return nil, err + } + } + } + + // Add tables from FROM clause + if upd.FromClause != nil { + for _, item := range upd.FromClause.Items { + if err := a.processFromItem(item, updateScope); err != nil { + return nil, err + } + } + } + + return updateScope, nil +} + +// buildDeleteScope builds the scope graph for a DELETE statement. +func (a *Analyzer) buildDeleteScope(del *ast.DeleteStmt) (*scope.Scope, error) { + deleteScope := scope.NewScope(scope.ScopeDelete) + + if del.WithClause != nil { + if err := a.processCTEs(del.WithClause, deleteScope); err != nil { + return nil, err + } + } + + if del.Relations != nil { + for _, item := range del.Relations.Items { + if err := a.processFromItem(item, deleteScope); err != nil { + return nil, err + } + } + } + + return deleteScope, nil +} + +// processCTEs adds CTE declarations to the given scope. 
+func (a *Analyzer) processCTEs(with *ast.WithClause, parentScope *scope.Scope) error { + if with == nil || with.Ctes == nil { + return nil + } + + for _, item := range with.Ctes.Items { + cte, ok := item.(*ast.CommonTableExpr) + if !ok || cte.Ctename == nil { + continue + } + + // Build the CTE's scope by analyzing its query + cteQueryScope, err := a.buildScopeGraph(cte.Ctequery) + if err != nil { + continue // Don't fail on CTE analysis errors + } + + // If the CTE has explicit column names, use those + cteScope := scope.NewScope(scope.ScopeCTE) + if cte.Aliascolnames != nil { + for _, nameNode := range cte.Aliascolnames.Items { + if s, ok := nameNode.(*ast.String); ok { + cteScope.DeclareColumn(s.Str, scope.TypeUnknown, 0) + } + } + } else { + // Copy columns from the CTE's query scope + cols := cteQueryScope.AllColumns("") + for _, col := range cols { + cteScope.DeclareColumn(col.Name, col.Type, col.Location) + } + } + + parentScope.Declare(&scope.Declaration{ + Name: *cte.Ctename, + Kind: scope.DeclCTE, + Type: scope.TypeUnknown, + Scope: cteScope, + }) + } + + return nil +} + +// processFromItem adds a FROM clause item (table, join, subquery) to the scope. +func (a *Analyzer) processFromItem(item ast.Node, parentScope *scope.Scope) error { + switch n := item.(type) { + case *ast.RangeVar: + return a.processRangeVar(n, parentScope) + + case *ast.JoinExpr: + return a.processJoinExpr(n, parentScope) + + case *ast.RangeSubselect: + return a.processRangeSubselect(n, parentScope) + + case *ast.RangeFunction: + // Function in FROM clause — add placeholder columns + return a.processRangeFunction(n, parentScope) + + default: + return nil // Ignore unknown FROM item types + } +} + +// processRangeVar looks up a table in the catalog and adds it to the scope. 
func (a *Analyzer) processRangeVar(rv *ast.RangeVar, parentScope *scope.Scope) error {
	if rv == nil || rv.Relname == nil {
		return nil
	}

	tableName := &ast.TableName{Name: *rv.Relname}
	if rv.Schemaname != nil {
		tableName.Schema = *rv.Schemaname
	}

	// Create a scope for this table's columns
	tableScope := scope.NewScope(scope.ScopeFrom)

	// Look up the table in the catalog
	table, err := a.catalog.GetTable(tableName)
	if err != nil {
		// Table might be a CTE — check if it's already declared in parent
		// NOTE(review): when the name does resolve as a CTE, tableScope is
		// left with no column declarations (the CTE's columns are not
		// copied in), so qualified references through this name likely
		// fail to resolve — confirm against scope.Resolve's fallback.
		_, resolveErr := parentScope.ResolveQualified(*rv.Relname, "")
		if resolveErr != nil {
			// Not found anywhere — declare empty scope so analysis can continue
			tableScope.DeclareColumn("*", scope.TypeUnknown, 0)
		}
	} else {
		// Add all columns from the catalog
		for _, col := range table.Columns {
			typ := scope.Type{
				Name:      col.Type.Name,
				Schema:    col.Type.Schema,
				NotNull:   col.IsNotNull,
				IsArray:   col.IsArray,
				ArrayDims: col.ArrayDims,
				Unsigned:  col.IsUnsigned,
				Length:    col.Length,
			}
			tableScope.DeclareColumn(col.Name, typ, 0)
		}
	}

	// Determine the name to use (alias or table name)
	name := *rv.Relname
	if rv.Alias != nil && rv.Alias.Aliasname != nil {
		alias := *rv.Alias.Aliasname
		// Register both the alias edge and the table declaration
		parentScope.AddAlias(alias, tableScope)
		parentScope.Declare(&scope.Declaration{
			Name:  alias,
			Kind:  scope.DeclAlias,
			Type:  scope.TypeUnknown,
			Scope: tableScope,
		})
	} else {
		parentScope.Declare(&scope.Declaration{
			Name:  name,
			Kind:  scope.DeclTable,
			Type:  scope.TypeUnknown,
			Scope: tableScope,
		})
	}

	return nil
}

// processJoinExpr processes a JOIN and adds both sides to the scope.
+func (a *Analyzer) processJoinExpr(join *ast.JoinExpr, parentScope *scope.Scope) error { + if join == nil { + return nil + } + + // Process left side + if join.Larg != nil { + if err := a.processFromItem(join.Larg, parentScope); err != nil { + return err + } + } + + // Process right side + if join.Rarg != nil { + if err := a.processFromItem(join.Rarg, parentScope); err != nil { + return err + } + } + + return nil +} + +// processRangeSubselect processes a subquery in the FROM clause. +func (a *Analyzer) processRangeSubselect(rs *ast.RangeSubselect, parentScope *scope.Scope) error { + if rs == nil { + return nil + } + + subScope := scope.NewScope(scope.ScopeSubquery) + + // Analyze the subquery + if rs.Subquery != nil { + subQueryScope, err := a.buildScopeGraph(rs.Subquery) + if err == nil { + cols := subQueryScope.AllColumns("") + for _, col := range cols { + subScope.DeclareColumn(col.Name, col.Type, col.Location) + } + } + } + + if rs.Alias != nil && rs.Alias.Aliasname != nil { + alias := *rs.Alias.Aliasname + parentScope.AddAlias(alias, subScope) + parentScope.Declare(&scope.Declaration{ + Name: alias, + Kind: scope.DeclAlias, + Type: scope.TypeUnknown, + Scope: subScope, + }) + } + + return nil +} + +// processRangeFunction processes a function call in the FROM clause. 
+func (a *Analyzer) processRangeFunction(rf *ast.RangeFunction, parentScope *scope.Scope) error { + if rf == nil { + return nil + } + + funcScope := scope.NewScope(scope.ScopeFunction) + + // If there's an alias with column definitions, use those + if rf.Alias != nil { + if rf.Alias.Colnames != nil { + for _, nameNode := range rf.Alias.Colnames.Items { + if s, ok := nameNode.(*ast.String); ok { + funcScope.DeclareColumn(s.Str, scope.TypeUnknown, 0) + } + } + } + if rf.Alias.Aliasname != nil { + alias := *rf.Alias.Aliasname + parentScope.AddAlias(alias, funcScope) + parentScope.Declare(&scope.Declaration{ + Name: alias, + Kind: scope.DeclAlias, + Type: scope.TypeUnknown, + Scope: funcScope, + }) + } + } + + return nil +} + +// typeCheckStatement walks the AST and performs bidirectional type checking. +func (a *Analyzer) typeCheckStatement(stmt ast.Node, rootScope *scope.Scope) { + switch n := stmt.(type) { + case *ast.SelectStmt: + a.typeCheckSelect(n, rootScope) + case *ast.InsertStmt: + a.typeCheckInsert(n, rootScope) + case *ast.UpdateStmt: + a.typeCheckUpdate(n, rootScope) + case *ast.DeleteStmt: + a.typeCheckDelete(n, rootScope) + } +} + +// typeCheckSelect type-checks a SELECT statement's expressions. 
+func (a *Analyzer) typeCheckSelect(sel *ast.SelectStmt, selectScope *scope.Scope) { + if sel == nil { + return + } + + // Handle UNION + if sel.Larg != nil { + lScope, _ := a.buildSelectScope(sel.Larg) + if lScope != nil { + a.typeCheckSelect(sel.Larg, lScope) + } + return + } + + // Type-check WHERE clause + if sel.WhereClause != nil { + a.typeCheckExpr(sel.WhereClause, selectScope) + } + + // Type-check HAVING clause + if sel.HavingClause != nil { + a.typeCheckExpr(sel.HavingClause, selectScope) + } + + // Type-check LIMIT/OFFSET (they should be integer) + if sel.LimitCount != nil { + expr := a.astToExpr(sel.LimitCount, selectScope) + if expr != nil { + a.checker.Check(expr, scope.TypeInt, 0) + } + } + if sel.LimitOffset != nil { + expr := a.astToExpr(sel.LimitOffset, selectScope) + if expr != nil { + a.checker.Check(expr, scope.TypeInt, 0) + } + } +} + +// typeCheckInsert type-checks an INSERT statement. +func (a *Analyzer) typeCheckInsert(ins *ast.InsertStmt, insertScope *scope.Scope) { + if ins == nil { + return + } + + // For INSERT ... VALUES, check parameter types against column types + if ins.SelectStmt != nil { + if valSel, ok := ins.SelectStmt.(*ast.SelectStmt); ok && valSel.ValuesLists != nil { + a.typeCheckInsertValues(ins, valSel, insertScope) + } + } +} + +// typeCheckInsertValues infers parameter types in INSERT VALUES from column types. 
+func (a *Analyzer) typeCheckInsertValues(ins *ast.InsertStmt, valSel *ast.SelectStmt, insertScope *scope.Scope) { + if ins.Relation == nil || ins.Relation.Relname == nil { + return + } + + tableName := &ast.TableName{Name: *ins.Relation.Relname} + if ins.Relation.Schemaname != nil { + tableName.Schema = *ins.Relation.Schemaname + } + + table, err := a.catalog.GetTable(tableName) + if err != nil { + return + } + + // Get the column names from the INSERT column list + var targetCols []string + if ins.Cols != nil { + for _, item := range ins.Cols.Items { + if rt, ok := item.(*ast.ResTarget); ok && rt.Name != nil { + targetCols = append(targetCols, *rt.Name) + } + } + } else { + // No explicit columns — use all table columns in order + for _, col := range table.Columns { + targetCols = append(targetCols, col.Name) + } + } + + // Build a map of column name -> type + colTypes := make(map[string]scope.Type) + for _, col := range table.Columns { + colTypes[col.Name] = scope.Type{ + Name: col.Type.Name, + Schema: col.Type.Schema, + NotNull: col.IsNotNull, + IsArray: col.IsArray, + ArrayDims: col.ArrayDims, + Unsigned: col.IsUnsigned, + Length: col.Length, + } + } + + // Type-check each value against its target column + for _, row := range valSel.ValuesLists.Items { + rowList, ok := row.(*ast.List) + if !ok { + continue + } + for i, val := range rowList.Items { + if i >= len(targetCols) { + break + } + colType, exists := colTypes[targetCols[i]] + if !exists { + continue + } + expr := a.astToExpr(val, insertScope) + if expr != nil { + // Use checking mode: the parameter should have the column's type + a.checker.Check(expr, colType, 0) + } + } + } +} + +// typeCheckUpdate type-checks an UPDATE statement. 
+func (a *Analyzer) typeCheckUpdate(upd *ast.UpdateStmt, updateScope *scope.Scope) { + if upd == nil { + return + } + + // Type-check SET clause values against their target columns + if upd.TargetList != nil { + for _, item := range upd.TargetList.Items { + rt, ok := item.(*ast.ResTarget) + if !ok || rt.Name == nil || rt.Val == nil { + continue + } + + // Look up the column type from the scope + resolved, err := updateScope.Resolve(*rt.Name) + if err != nil { + continue + } + + // Check the value against the column's type + expr := a.astToExpr(rt.Val, updateScope) + if expr != nil { + a.checker.Check(expr, resolved.Declaration.Type, rt.Location) + } + } + } + + // Type-check WHERE clause + if upd.WhereClause != nil { + a.typeCheckExpr(upd.WhereClause, updateScope) + } +} + +// typeCheckDelete type-checks a DELETE statement. +func (a *Analyzer) typeCheckDelete(del *ast.DeleteStmt, deleteScope *scope.Scope) { + if del == nil { + return + } + if del.WhereClause != nil { + a.typeCheckExpr(del.WhereClause, deleteScope) + } +} + +// typeCheckExpr walks an AST expression, synthesizing types and +// using checking mode where appropriate. +func (a *Analyzer) typeCheckExpr(node ast.Node, sc *scope.Scope) { + expr := a.astToExpr(node, sc) + if expr != nil { + a.checker.Synth(expr) + } +} + +// astToExpr converts a sqlc AST node to a type-checkable expression. +// This is the bridge between the sqlc AST and the type checker's expression language. 
+func (a *Analyzer) astToExpr(node ast.Node, sc *scope.Scope) typecheck.Expr { + if node == nil { + return nil + } + + switch n := node.(type) { + case *ast.A_Const: + return a.constToExpr(n) + + case *ast.ColumnRef: + return a.columnRefToExpr(n, sc) + + case *ast.ParamRef: + return &typecheck.ParamExpr{ + Number: n.Number, + Location: n.Location, + } + + case *ast.A_Expr: + return a.aExprToExpr(n, sc) + + case *ast.BoolExpr: + return a.boolExprToExpr(n, sc) + + case *ast.FuncCall: + return a.funcCallToExpr(n, sc) + + case *ast.TypeCast: + return a.typeCastToExpr(n, sc) + + case *ast.SubLink: + return a.subLinkToExpr(n) + + case *ast.NullTest: + arg := a.astToExpr(n.Arg, sc) + if arg == nil { + arg = &typecheck.LiteralExpr{Type: scope.TypeUnknown} + } + return &typecheck.NullTestExpr{ + Arg: arg, + IsNot: n.Nulltesttype == ast.NullTestTypeIsNotNull, + Location: n.Location, + } + + case *ast.CaseExpr: + resultType := scope.TypeUnknown + if n.Defresult != nil { + if tc, ok := n.Defresult.(*ast.TypeCast); ok && tc.TypeName != nil { + resultType = typeNameToScopeType(tc.TypeName) + } + } + return &typecheck.CaseExpr{ResultType: resultType, Location: 0} + + case *ast.CoalesceExpr: + var args []typecheck.Expr + if n.Args != nil { + for _, item := range n.Args.Items { + if e := a.astToExpr(item, sc); e != nil { + args = append(args, e) + } + } + } + return &typecheck.CoalesceExpr{Args: args} + + case *ast.In: + return a.inToExpr(n, sc) + + case *ast.BetweenExpr: + return a.betweenToExpr(n, sc) + + case *ast.List: + // For lists (e.g., value lists), just check each item + for _, item := range n.Items { + a.astToExpr(item, sc) // Side-effecting: records param types + } + return nil + + default: + return nil + } +} + +func (a *Analyzer) constToExpr(n *ast.A_Const) typecheck.Expr { + if n == nil { + return nil + } + switch n.Val.(type) { + case *ast.String: + return &typecheck.LiteralExpr{Type: scope.TypeText} + case *ast.Integer: + return &typecheck.LiteralExpr{Type: 
scope.TypeInt} + case *ast.Float: + return &typecheck.LiteralExpr{Type: scope.TypeFloat} + case *ast.Boolean: + return &typecheck.LiteralExpr{Type: scope.TypeBool} + case *ast.Null: + return &typecheck.LiteralExpr{Type: scope.Type{Name: "any"}} + default: + return &typecheck.LiteralExpr{Type: scope.TypeUnknown} + } +} + +func (a *Analyzer) columnRefToExpr(n *ast.ColumnRef, sc *scope.Scope) typecheck.Expr { + if n == nil { + return nil + } + + parts := stringSlice(n.Fields) + resolved, err := sc.ResolveColumnRef(parts) + if err != nil { + return &typecheck.ColumnRefExpr{ + Parts: parts, + ResolvedType: scope.TypeUnknown, + Location: n.Location, + } + } + + return &typecheck.ColumnRefExpr{ + Parts: parts, + ResolvedType: resolved.Declaration.Type, + Location: n.Location, + } +} + +func (a *Analyzer) aExprToExpr(n *ast.A_Expr, sc *scope.Scope) typecheck.Expr { + if n == nil { + return nil + } + + op := astutils.Join(n.Name, ".") + + left := a.astToExpr(n.Lexpr, sc) + right := a.astToExpr(n.Rexpr, sc) + + if left == nil { + left = &typecheck.LiteralExpr{Type: scope.TypeUnknown} + } + if right == nil { + right = &typecheck.LiteralExpr{Type: scope.TypeUnknown} + } + + return &typecheck.BinaryOpExpr{ + Op: op, + Left: left, + Right: right, + Location: n.Location, + } +} + +func (a *Analyzer) boolExprToExpr(n *ast.BoolExpr, sc *scope.Scope) typecheck.Expr { + if n == nil { + return nil + } + + var args []typecheck.Expr + if n.Args != nil { + for _, item := range n.Args.Items { + if e := a.astToExpr(item, sc); e != nil { + args = append(args, e) + } + } + } + + var op string + switch n.Boolop { + case ast.BoolExprTypeAnd: + op = "AND" + case ast.BoolExprTypeOr: + op = "OR" + case ast.BoolExprTypeNot: + op = "NOT" + default: + // IS NULL, IS NOT NULL checks + op = "IS" + } + + return &typecheck.BoolExpr{ + Op: op, + Args: args, + Location: n.Location, + } +} + +func (a *Analyzer) funcCallToExpr(n *ast.FuncCall, sc *scope.Scope) typecheck.Expr { + if n == nil { + return nil + 
} + + // Type-check arguments (side effect: infers param types within args) + var args []typecheck.Expr + if n.Args != nil { + for _, item := range n.Args.Items { + if e := a.astToExpr(item, sc); e != nil { + args = append(args, e) + } + } + } + + // Try to resolve the function's return type from the catalog + returnType := scope.TypeUnknown + fun, err := a.catalog.ResolveFuncCall(n) + if err == nil && fun.ReturnType != nil { + returnType = typeNameToScopeType(fun.ReturnType) + + // Use checking mode on arguments against function parameter types + for i, arg := range args { + if i < len(fun.Args) && fun.Args[i].Type != nil { + expectedType := typeNameToScopeType(fun.Args[i].Type) + a.checker.Check(arg, expectedType, n.Location) + } + } + } + + return &typecheck.FuncCallExpr{ + Name: n.Func.Name, + Args: args, + ReturnType: returnType, + Location: n.Location, + } +} + +func (a *Analyzer) typeCastToExpr(n *ast.TypeCast, sc *scope.Scope) typecheck.Expr { + if n == nil { + return nil + } + + arg := a.astToExpr(n.Arg, sc) + if arg == nil { + arg = &typecheck.LiteralExpr{Type: scope.TypeUnknown} + } + + castType := scope.TypeUnknown + if n.TypeName != nil { + castType = typeNameToScopeType(n.TypeName) + } + + // If the argument is a parameter, infer its type from the cast + if param, ok := arg.(*typecheck.ParamExpr); ok { + a.checker.InferParamFromContext(param.Number, castType, param.Location) + } + + return &typecheck.TypeCastExpr{ + Arg: arg, + CastType: castType, + Location: 0, + } +} + +func (a *Analyzer) subLinkToExpr(n *ast.SubLink) typecheck.Expr { + if n == nil { + return nil + } + + switch n.SubLinkType { + case ast.EXISTS_SUBLINK: + return &typecheck.SubqueryExpr{IsExists: true} + default: + return &typecheck.SubqueryExpr{Columns: []scope.Type{scope.TypeUnknown}} + } +} + +func (a *Analyzer) inToExpr(n *ast.In, sc *scope.Scope) typecheck.Expr { + if n == nil { + return nil + } + + exprNode := a.astToExpr(n.Expr, sc) + if exprNode == nil { + exprNode = 
&typecheck.LiteralExpr{Type: scope.TypeUnknown} + } + + var values []typecheck.Expr + for _, item := range n.List { + if e := a.astToExpr(item, sc); e != nil { + values = append(values, e) + } + } + + // For IN expressions, use checking mode on list items: + // each item should match the type of the expression + exprType := a.checker.Synth(exprNode) + if !exprType.Type.IsUnknown() { + for _, v := range values { + a.checker.Check(v, exprType.Type, 0) + } + } + + return &typecheck.InExpr{ + Expr: exprNode, + Values: values, + } +} + +func (a *Analyzer) betweenToExpr(n *ast.BetweenExpr, sc *scope.Scope) typecheck.Expr { + if n == nil { + return nil + } + + exprNode := a.astToExpr(n.Expr, sc) + low := a.astToExpr(n.Left, sc) + high := a.astToExpr(n.Right, sc) + + if exprNode == nil { + exprNode = &typecheck.LiteralExpr{Type: scope.TypeUnknown} + } + if low == nil { + low = &typecheck.LiteralExpr{Type: scope.TypeUnknown} + } + if high == nil { + high = &typecheck.LiteralExpr{Type: scope.TypeUnknown} + } + + // Between: low and high should match the type of the expression + exprType := a.checker.Synth(exprNode) + if !exprType.Type.IsUnknown() { + a.checker.Check(low, exprType.Type, 0) + a.checker.Check(high, exprType.Type, 0) + } + + return &typecheck.BetweenExpr{ + Expr: exprNode, + Low: low, + High: high, + } +} + +// computeOutputColumns determines the columns in the query's result set. 
+func (a *Analyzer) computeOutputColumns(stmt ast.Node, sc *scope.Scope) ([]OutputColumn, error) { + var targets *ast.List + + switch n := stmt.(type) { + case *ast.SelectStmt: + if n.Larg != nil { + return a.computeOutputColumns(n.Larg, sc) + } + targets = n.TargetList + case *ast.InsertStmt: + targets = n.ReturningList + case *ast.UpdateStmt: + targets = n.ReturningList + case *ast.DeleteStmt: + targets = n.ReturningList + } + + if targets == nil { + return nil, nil + } + + var cols []OutputColumn + for _, item := range targets.Items { + res, ok := item.(*ast.ResTarget) + if !ok { + continue + } + + switch val := res.Val.(type) { + case *ast.ColumnRef: + parts := stringSlice(val.Fields) + + // Handle SELECT * + for _, field := range val.Fields.Items { + if _, isStar := field.(*ast.A_Star); isStar { + qualifier := "" + if len(parts) > 0 && parts[0] != "*" { + qualifier = parts[0] + } + allCols := sc.AllColumns(qualifier) + for _, d := range allCols { + cols = append(cols, OutputColumn{ + Name: d.Name, + Type: d.Type, + }) + } + goto nextTarget + } + } + + // Regular column reference + resolved, err := sc.ResolveColumnRef(parts) + if err != nil { + name := parts[len(parts)-1] + if res.Name != nil { + name = *res.Name + } + cols = append(cols, OutputColumn{ + Name: name, + Type: scope.TypeUnknown, + }) + } else { + name := resolved.Declaration.Name + if res.Name != nil { + name = *res.Name + } + cols = append(cols, OutputColumn{ + Name: name, + Type: resolved.Declaration.Type, + }) + } + + case *ast.FuncCall: + name := val.Func.Name + if res.Name != nil { + name = *res.Name + } + funcExpr := a.funcCallToExpr(val, sc) + result := a.checker.Synth(funcExpr) + cols = append(cols, OutputColumn{ + Name: name, + Type: result.Type, + }) + + case *ast.A_Const: + name := "" + if res.Name != nil { + name = *res.Name + } + constExpr := a.constToExpr(val) + result := a.checker.Synth(constExpr) + cols = append(cols, OutputColumn{ + Name: name, + Type: result.Type, + }) + + 
default: + name := "" + if res.Name != nil { + name = *res.Name + } + if val != nil { + expr := a.astToExpr(val, sc) + if expr != nil { + result := a.checker.Synth(expr) + cols = append(cols, OutputColumn{ + Name: name, + Type: result.Type, + }) + continue + } + } + cols = append(cols, OutputColumn{ + Name: name, + Type: scope.TypeUnknown, + }) + } + nextTarget: + } + + return cols, nil +} + +// Helper functions + +func stringSlice(list *ast.List) []string { + if list == nil { + return nil + } + var result []string + for _, item := range list.Items { + switch n := item.(type) { + case *ast.String: + result = append(result, n.Str) + case *ast.A_Star: + // Don't include star in string slice + } + } + return result +} + +func typeNameToScopeType(tn *ast.TypeName) scope.Type { + if tn == nil { + return scope.TypeUnknown + } + return scope.Type{ + Name: tn.Name, + Schema: tn.Schema, + } +} diff --git a/internal/analysis/sqlanalyze/analyzer_test.go b/internal/analysis/sqlanalyze/analyzer_test.go new file mode 100644 index 0000000000..070b922c81 --- /dev/null +++ b/internal/analysis/sqlanalyze/analyzer_test.go @@ -0,0 +1,579 @@ +package sqlanalyze + +import ( + "testing" + + "github.com/sqlc-dev/sqlc/internal/analysis/scope" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +// buildTestCatalog creates a catalog with users and orders tables. 
func buildTestCatalog(defaultSchema string) *catalog.Catalog {
	c := catalog.New(defaultSchema)

	intType := ast.TypeName{Name: "integer"}
	textType := ast.TypeName{Name: "text"}
	numericType := ast.TypeName{Name: "numeric"}
	boolType := ast.TypeName{Name: "boolean"}

	// NOTE(review): Update's result is ignored here — presumably these
	// statements cannot fail for a fresh catalog; confirm.
	c.Update(ast.Statement{
		Raw: &ast.RawStmt{
			Stmt: &ast.CreateTableStmt{
				Name: &ast.TableName{Name: "users"},
				Cols: []*ast.ColumnDef{
					{Colname: "id", TypeName: &intType, IsNotNull: true},
					{Colname: "name", TypeName: &textType, IsNotNull: true},
					{Colname: "email", TypeName: &textType, IsNotNull: false},
					{Colname: "age", TypeName: &intType, IsNotNull: false},
					{Colname: "active", TypeName: &boolType, IsNotNull: true},
				},
			},
		},
	}, nil)

	c.Update(ast.Statement{
		Raw: &ast.RawStmt{
			Stmt: &ast.CreateTableStmt{
				Name: &ast.TableName{Name: "orders"},
				Cols: []*ast.ColumnDef{
					{Colname: "id", TypeName: &intType, IsNotNull: true},
					{Colname: "user_id", TypeName: &intType, IsNotNull: true},
					{Colname: "total", TypeName: &numericType, IsNotNull: true},
					{Colname: "status", TypeName: &textType, IsNotNull: false},
				},
			},
		},
	}, nil)

	return c
}

func TestAnalyzeSelectSimple(t *testing.T) {
	cat := buildTestCatalog("public")
	a := New(cat, config.EnginePostgreSQL)

	// SELECT name, email FROM users WHERE id = $1
	nameStr := "name"
	emailStr := "email"
	relname := "users"
	stmt := &ast.RawStmt{
		Stmt: &ast.SelectStmt{
			TargetList: &ast.List{
				Items: []ast.Node{
					&ast.ResTarget{
						Val: &ast.ColumnRef{
							Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "name"}}},
						},
						Name: &nameStr,
					},
					&ast.ResTarget{
						Val: &ast.ColumnRef{
							Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "email"}}},
						},
						Name: &emailStr,
					},
				},
			},
			FromClause: &ast.List{
				Items: []ast.Node{
					&ast.RangeVar{Relname: &relname},
				},
			},
			WhereClause: &ast.A_Expr{
				Name: &ast.List{Items: []ast.Node{&ast.String{Str: "="}}},
				Lexpr: &ast.ColumnRef{
					Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}},
				},
				Rexpr: &ast.ParamRef{Number: 1, Location: 40},
			},
		},
	}

	result, err := a.AnalyzeQuery(stmt)
	if err != nil {
		t.Fatalf("AnalyzeQuery failed: %v", err)
	}

	// Check output columns
	if len(result.OutputColumns) != 2 {
		t.Fatalf("expected 2 output columns, got %d", len(result.OutputColumns))
	}
	if result.OutputColumns[0].Name != "name" {
		t.Errorf("col 0 name: got %q, want 'name'", result.OutputColumns[0].Name)
	}
	if result.OutputColumns[0].Type.Name != "text" {
		t.Errorf("col 0 type: got %q, want 'text'", result.OutputColumns[0].Type.Name)
	}
	if result.OutputColumns[1].Name != "email" {
		t.Errorf("col 1 name: got %q, want 'email'", result.OutputColumns[1].Name)
	}

	// Check parameter type inference
	if len(result.ParamTypes) != 1 {
		t.Fatalf("expected 1 param, got %d", len(result.ParamTypes))
	}
	p := result.ParamTypes[1]
	if p == nil {
		t.Fatal("param $1 not found")
	}
	if p.Type.Name != "integer" {
		t.Errorf("param $1 type: got %q, want 'integer'", p.Type.Name)
	}
}

func TestAnalyzeSelectWithAlias(t *testing.T) {
	cat := buildTestCatalog("public")
	a := New(cat, config.EnginePostgreSQL)

	// SELECT u.name FROM users AS u WHERE u.id = $1
	nameStr := "name"
	relname := "users"
	aliasname := "u"
	stmt := &ast.RawStmt{
		Stmt: &ast.SelectStmt{
			TargetList: &ast.List{
				Items: []ast.Node{
					&ast.ResTarget{
						Val: &ast.ColumnRef{
							Fields: &ast.List{Items: []ast.Node{
								&ast.String{Str: "u"},
								&ast.String{Str: "name"},
							}},
						},
						Name: &nameStr,
					},
				},
			},
			FromClause: &ast.List{
				Items: []ast.Node{
					&ast.RangeVar{
						Relname: &relname,
						Alias:   &ast.Alias{Aliasname: &aliasname},
					},
				},
			},
			WhereClause: &ast.A_Expr{
				Name: &ast.List{Items: []ast.Node{&ast.String{Str: "="}}},
				Lexpr: &ast.ColumnRef{
					Fields: &ast.List{Items: []ast.Node{
						&ast.String{Str: "u"},
						&ast.String{Str: "id"},
					}},
				},
				Rexpr: &ast.ParamRef{Number: 1, Location: 45},
			},
		},
	}

	result, err := a.AnalyzeQuery(stmt)
	if err != nil {
		t.Fatalf("AnalyzeQuery failed: %v", err)
	}

	if len(result.OutputColumns) != 1 {
		t.Fatalf("expected 1 output column, got %d", len(result.OutputColumns))
	}
	if result.OutputColumns[0].Type.Name != "text" {
		t.Errorf("col type: got %q, want 'text'", result.OutputColumns[0].Type.Name)
	}

	// $1 should be inferred as integer from u.id
	if p := result.ParamTypes[1]; p == nil {
		t.Error("param $1 not found")
	} else if p.Type.Name != "integer" {
		t.Errorf("param $1: got %q, want 'integer'", p.Type.Name)
	}
}

func TestAnalyzeSelectStar(t *testing.T) {
	cat := buildTestCatalog("public")
	a := New(cat, config.EnginePostgreSQL)

	// SELECT * FROM users
	relname := "users"
	stmt := &ast.RawStmt{
		Stmt: &ast.SelectStmt{
			TargetList: &ast.List{
				Items: []ast.Node{
					&ast.ResTarget{
						Val: &ast.ColumnRef{
							Fields: &ast.List{Items: []ast.Node{&ast.A_Star{}}},
						},
					},
				},
			},
			FromClause: &ast.List{
				Items: []ast.Node{
					&ast.RangeVar{Relname: &relname},
				},
			},
		},
	}

	result, err := a.AnalyzeQuery(stmt)
	if err != nil {
		t.Fatalf("AnalyzeQuery failed: %v", err)
	}

	// users has 5 columns: id, name, email, age, active
	if len(result.OutputColumns) != 5 {
		t.Fatalf("expected 5 output columns, got %d", len(result.OutputColumns))
	}

	// Check that types are resolved
	expectedCols := []struct {
		name     string
		typeName string
	}{
		{"id", "integer"},
		{"name", "text"},
		{"email", "text"},
		{"age", "integer"},
		{"active", "boolean"},
	}
	for i, expected := range expectedCols {
		if i >= len(result.OutputColumns) {
			break
		}
		col := result.OutputColumns[i]
		if col.Name != expected.name {
			t.Errorf("col %d name: got %q, want %q", i, col.Name, expected.name)
		}
		if col.Type.Name != expected.typeName {
			t.Errorf("col %d (%s) type: got %q, want %q", i, expected.name, col.Type.Name, expected.typeName)
		}
	}
}

func TestAnalyzeJoin(t *testing.T) {
	cat := buildTestCatalog("public")
	a := New(cat, config.EnginePostgreSQL)

	// SELECT u.name, o.total FROM users AS u JOIN orders AS o ON u.id = o.user_id
	nameStr := "name"
	totalStr := "total"
	usersStr := "users"
	ordersStr := "orders"
	uAlias := "u"
	oAlias := "o"

	stmt := &ast.RawStmt{
		Stmt: &ast.SelectStmt{
			TargetList: &ast.List{
				Items: []ast.Node{
					&ast.ResTarget{
						Val: &ast.ColumnRef{
							Fields: &ast.List{Items: []ast.Node{
								&ast.String{Str: "u"}, &ast.String{Str: "name"},
							}},
						},
						Name: &nameStr,
					},
					&ast.ResTarget{
						Val: &ast.ColumnRef{
							Fields: &ast.List{Items: []ast.Node{
								&ast.String{Str: "o"}, &ast.String{Str: "total"},
							}},
						},
						Name: &totalStr,
					},
				},
			},
			FromClause: &ast.List{
				Items: []ast.Node{
					&ast.JoinExpr{
						Jointype: ast.JoinTypeInner,
						Larg: &ast.RangeVar{
							Relname: &usersStr,
							Alias:   &ast.Alias{Aliasname: &uAlias},
						},
						Rarg: &ast.RangeVar{
							Relname: &ordersStr,
							Alias:   &ast.Alias{Aliasname: &oAlias},
						},
					},
				},
			},
		},
	}

	result, err := a.AnalyzeQuery(stmt)
	if err != nil {
		t.Fatalf("AnalyzeQuery failed: %v", err)
	}

	if len(result.OutputColumns) != 2 {
		t.Fatalf("expected 2 output columns, got %d", len(result.OutputColumns))
	}
	if result.OutputColumns[0].Type.Name != "text" {
		t.Errorf("u.name type: got %q, want 'text'", result.OutputColumns[0].Type.Name)
	}
	if result.OutputColumns[1].Type.Name != "numeric" {
		t.Errorf("o.total type: got %q, want 'numeric'", result.OutputColumns[1].Type.Name)
	}
}

func TestAnalyzeInsertParamTypes(t *testing.T) {
	cat := buildTestCatalog("public")
	a := New(cat, config.EnginePostgreSQL)

	// INSERT INTO users (name, email) VALUES ($1, $2)
	relname := "users"
	nameCol := "name"
	emailCol := "email"

	stmt := &ast.RawStmt{
		Stmt: &ast.InsertStmt{
			Relation: &ast.RangeVar{Relname: &relname},
			Cols: &ast.List{
				Items: []ast.Node{
					&ast.ResTarget{Name: &nameCol},
					&ast.ResTarget{Name: &emailCol},
				},
			},
			SelectStmt: &ast.SelectStmt{
				ValuesLists: &ast.List{
					Items: []ast.Node{
						&ast.List{
							Items: []ast.Node{
								&ast.ParamRef{Number: 1, Location: 40},
								&ast.ParamRef{Number: 2, Location: 44},
							},
						},
					},
				},
			},
		},
	}

	result, err := a.AnalyzeQuery(stmt)
	if err != nil {
		t.Fatalf("AnalyzeQuery failed: %v", err)
	}

	// $1 should be text (name column), $2 should be text (email column)
	if len(result.ParamTypes) != 2 {
		t.Fatalf("expected 2 params, got %d", len(result.ParamTypes))
	}
	if p := result.ParamTypes[1]; p == nil {
		t.Error("$1 not found")
	} else if p.Type.Name != "text" {
		t.Errorf("$1: got %q, want 'text'", p.Type.Name)
	}
	if p := result.ParamTypes[2]; p == nil {
		t.Error("$2 not found")
	} else if p.Type.Name != "text" {
		t.Errorf("$2: got %q, want 'text'", p.Type.Name)
	}
}

func TestAnalyzeUpdate(t *testing.T) {
	cat := buildTestCatalog("public")
	a := New(cat, config.EnginePostgreSQL)

	// UPDATE users SET name = $1 WHERE id = $2
	relname := "users"
	nameCol := "name"

	stmt := &ast.RawStmt{
		Stmt: &ast.UpdateStmt{
			Relations: &ast.List{
				Items: []ast.Node{
					&ast.RangeVar{Relname: &relname},
				},
			},
			TargetList: &ast.List{
				Items: []ast.Node{
					&ast.ResTarget{
						Name: &nameCol,
						Val:  &ast.ParamRef{Number: 1, Location: 25},
					},
				},
			},
			WhereClause: &ast.A_Expr{
				Name: &ast.List{Items: []ast.Node{&ast.String{Str: "="}}},
				Lexpr: &ast.ColumnRef{
					Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}},
				},
				Rexpr: &ast.ParamRef{Number: 2, Location: 45},
			},
		},
	}

	result, err := a.AnalyzeQuery(stmt)
	if err != nil {
		t.Fatalf("AnalyzeQuery failed: %v", err)
	}

	// $1 should be text (name column type), $2 should be integer (id column type)
	if p := result.ParamTypes[1]; p == nil {
		t.Error("$1 not found")
	} else if p.Type.Name != "text" {
		t.Errorf("$1: got %q, want 'text'", p.Type.Name)
	}
	if p := result.ParamTypes[2]; p == nil {
		t.Error("$2 not found")
	} else if p.Type.Name != "integer" {
		t.Errorf("$2: got %q, want 'integer'", p.Type.Name)
	}
}

func TestAnalyzeDelete(t *testing.T) {
	cat := buildTestCatalog("public")
	a := New(cat, config.EnginePostgreSQL)

	// DELETE FROM users WHERE id = $1
	relname := "users"
	stmt := &ast.RawStmt{
		Stmt: &ast.DeleteStmt{
			Relations: &ast.List{
				Items: []ast.Node{
					&ast.RangeVar{Relname: &relname},
				},
			},
			WhereClause: &ast.A_Expr{
				Name: &ast.List{Items: []ast.Node{&ast.String{Str: "="}}},
				Lexpr: &ast.ColumnRef{
					Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "id"}}},
				},
				Rexpr: &ast.ParamRef{Number: 1, Location: 30},
			},
		},
	}

	result, err := a.AnalyzeQuery(stmt)
	if err != nil {
		t.Fatalf("AnalyzeQuery failed: %v", err)
	}

	if p := result.ParamTypes[1]; p == nil {
		t.Error("$1 not found")
	} else if p.Type.Name != "integer" {
		t.Errorf("$1: got %q, want 'integer'", p.Type.Name)
	}
}

func TestAnalyzeLimitOffset(t *testing.T) {
	cat := buildTestCatalog("public")
	a := New(cat, config.EnginePostgreSQL)

	// SELECT name FROM users LIMIT $1 OFFSET $2
	nameStr := "name"
	relname := "users"
	stmt := &ast.RawStmt{
		Stmt: &ast.SelectStmt{
			TargetList: &ast.List{
				Items: []ast.Node{
					&ast.ResTarget{
						Val: &ast.ColumnRef{
							Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "name"}}},
						},
						Name: &nameStr,
					},
				},
			},
			FromClause: &ast.List{
				Items: []ast.Node{
					&ast.RangeVar{Relname: &relname},
				},
			},
			LimitCount:  &ast.ParamRef{Number: 1, Location: 30},
			LimitOffset: &ast.ParamRef{Number: 2, Location: 40},
		},
	}

	result, err := a.AnalyzeQuery(stmt)
	if err != nil {
		t.Fatalf("AnalyzeQuery failed: %v", err)
	}

	// Both LIMIT and OFFSET params should be integer
	if p := result.ParamTypes[1]; p == nil {
		t.Error("$1 (LIMIT) not found")
	} else if p.Type.Name != "integer" {
		t.Errorf("$1 (LIMIT): got %q, want 'integer'", p.Type.Name)
	}
	if p := result.ParamTypes[2]; p == nil {
		t.Error("$2 (OFFSET) not found")
	} else if p.Type.Name != "integer" {
		t.Errorf("$2 (OFFSET): got %q, want 'integer'", p.Type.Name)
	}
}

func TestScopeGraphStructure(t *testing.T) {
	cat := buildTestCatalog("public")
	a := New(cat, config.EnginePostgreSQL)

	// SELECT name FROM users
	nameStr := "name"
	relname := "users"
	stmt := &ast.RawStmt{
		Stmt: &ast.SelectStmt{
			TargetList: &ast.List{
				Items: []ast.Node{
					&ast.ResTarget{
						Val: &ast.ColumnRef{
							Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "name"}}},
						},
						Name: &nameStr,
					},
				},
			},
			FromClause: &ast.List{
				Items: []ast.Node{
					&ast.RangeVar{Relname: &relname},
				},
			},
		},
	}

	result, err := a.AnalyzeQuery(stmt)
	if err != nil {
		t.Fatalf("AnalyzeQuery failed: %v", err)
	}

	// Verify scope graph structure
	rootScope := result.RootScope
	if rootScope.Kind != scope.ScopeSelect {
		t.Errorf("root scope kind: got %v, want scope.ScopeSelect", rootScope.Kind)
	}

	// Should have a parent edge to the FROM scope
	if len(rootScope.Edges) == 0 {
		t.Fatal("root scope has no edges")
	}

	parentEdge := rootScope.Edges[0]
	if parentEdge.Kind != scope.EdgeParent {
		t.Errorf("edge kind: got %v, want scope.EdgeParent", parentEdge.Kind)
	}

	fromScope := parentEdge.Target
	if fromScope.Kind != scope.ScopeFrom {
		t.Errorf("from scope kind: got %v, want scope.ScopeFrom", fromScope.Kind)
	}

	// FROM scope should have a table declaration for "users"
	if len(fromScope.Declarations) == 0 {
		t.Fatal("FROM scope has no declarations")
	}

	tableDecl := fromScope.Declarations[0]
	if tableDecl.Name != "users" {
		t.Errorf("table declaration: got %q, want 'users'", tableDecl.Name)
	}
	if tableDecl.Kind != scope.DeclTable {
		t.Errorf("declaration kind: got %v, want scope.DeclTable", tableDecl.Kind)
	}

	// The table's scope should have column declarations
	if tableDecl.Scope == nil {
		t.Fatal("table declaration has no scope")
	}
	if len(tableDecl.Scope.Declarations) != 5 {
		t.Errorf("table scope has %d declarations, want 5", len(tableDecl.Scope.Declarations))
	}
}
diff --git a/internal/analysis/typecheck/expr.go b/internal/analysis/typecheck/expr.go
new file mode 100644
index 0000000000..51fb823dff
--- /dev/null
+++ b/internal/analysis/typecheck/expr.go
@@ -0,0 +1,139 @@
package typecheck

import "github.com/sqlc-dev/sqlc/internal/analysis/scope"

// Expr is the interface for type-checkable SQL expressions.
// This is a simplified representation of the SQL AST focused on type information.
type Expr interface {
	exprNode()
	ExprLocation() int
}

// LiteralExpr represents a constant value (string, int, float, bool, null).
type LiteralExpr struct {
	Type     scope.Type
	Location int
}

func (*LiteralExpr) exprNode()           {}
func (e *LiteralExpr) ExprLocation() int { return e.Location }

// ColumnRefExpr represents a resolved column reference.
type ColumnRefExpr struct {
	Parts        []string   // The original parts (e.g., ["u", "name"])
	ResolvedType scope.Type // The type from the catalog after name resolution
	Location     int
}

func (*ColumnRefExpr) exprNode()           {}
func (e *ColumnRefExpr) ExprLocation() int { return e.Location }

// ParamExpr represents a query parameter ($1, $2, ?, etc.).
type ParamExpr struct {
	Number   int // 1-indexed parameter number
	Location int
}

func (*ParamExpr) exprNode()           {}
func (e *ParamExpr) ExprLocation() int { return e.Location }

// BinaryOpExpr represents a binary operation (a + b, a = b, a AND b, etc.).
type BinaryOpExpr struct {
	Op       string
	Left     Expr
	Right    Expr
	Location int
}

func (*BinaryOpExpr) exprNode()           {}
func (e *BinaryOpExpr) ExprLocation() int { return e.Location }

// FuncCallExpr represents a function call.
+type FuncCallExpr struct {
+	Name       string     // Function name as written in the query (e.g. "count")
+	Args       []Expr     // Argument expressions, in call order
+	ReturnType scope.Type // Result type; presumably resolved from the catalog by the caller — TODO confirm
+	Location   int
+}
+
+func (*FuncCallExpr) exprNode()           {}
+func (e *FuncCallExpr) ExprLocation() int { return e.Location }
+
+// TypeCastExpr represents an explicit type cast (e.g., $1::integer).
+type TypeCastExpr struct {
+	Arg      Expr       // Expression being cast
+	CastType scope.Type // Target type of the cast
+	Location int
+}
+
+func (*TypeCastExpr) exprNode()           {}
+func (e *TypeCastExpr) ExprLocation() int { return e.Location }
+
+// SubqueryExpr represents a scalar subquery or EXISTS subquery.
+type SubqueryExpr struct {
+	Columns  []scope.Type // Column types of the subquery result
+	IsExists bool         // True for EXISTS (...), false for a scalar subquery
+	Location int
+}
+
+func (*SubqueryExpr) exprNode()           {}
+func (e *SubqueryExpr) ExprLocation() int { return e.Location }
+
+// BoolExpr represents AND, OR, NOT operations.
+type BoolExpr struct {
+	Op       string // "AND", "OR", "NOT"
+	Args     []Expr // Operands; presumably one for NOT, two or more for AND/OR — TODO confirm
+	Location int
+}
+
+func (*BoolExpr) exprNode()           {}
+func (e *BoolExpr) ExprLocation() int { return e.Location }
+
+// NullTestExpr represents IS NULL / IS NOT NULL.
+type NullTestExpr struct {
+	Arg      Expr // Expression being tested for null-ness
+	IsNot    bool // True for IS NOT NULL
+	Location int
+}
+
+func (*NullTestExpr) exprNode()           {}
+func (e *NullTestExpr) ExprLocation() int { return e.Location }
+
+// CaseExpr represents a CASE WHEN ... THEN ... ELSE ... END expression.
+type CaseExpr struct {
+	ResultType scope.Type // Pre-computed result type; the WHEN/THEN branches are not carried on this node
+	Location   int
+}
+
+func (*CaseExpr) exprNode()           {}
+func (e *CaseExpr) ExprLocation() int { return e.Location }
+
+// CoalesceExpr represents COALESCE(a, b, c, ...).
+type CoalesceExpr struct {
+	Args     []Expr // Arguments in order; the first one with a known type decides the result type
+	Location int
+}
+
+func (*CoalesceExpr) exprNode()           {}
+func (e *CoalesceExpr) ExprLocation() int { return e.Location }
+
+// InExpr represents `expr IN (values...)` or `expr IN (subquery)`.
+type InExpr struct {
+	Expr     Expr   // Left-hand expression being tested for membership
+	Values   []Expr // IN-list values (presumably a single SubqueryExpr in the subquery form — TODO confirm)
+	Location int
+}
+
+func (*InExpr) exprNode()           {}
+func (e *InExpr) ExprLocation() int { return e.Location }
+
+// BetweenExpr represents `expr BETWEEN low AND high`.
+type BetweenExpr struct { + Expr Expr + Low Expr + High Expr + Location int +} + +func (*BetweenExpr) exprNode() {} +func (e *BetweenExpr) ExprLocation() int { return e.Location } diff --git a/internal/analysis/typecheck/rules.go b/internal/analysis/typecheck/rules.go new file mode 100644 index 0000000000..23cf5d9f05 --- /dev/null +++ b/internal/analysis/typecheck/rules.go @@ -0,0 +1,166 @@ +package typecheck + +import "github.com/sqlc-dev/sqlc/internal/analysis/scope" + +// DefaultOperatorRules provides baseline type inference rules shared +// across PostgreSQL and MySQL. Engine-specific rules can override this. +type DefaultOperatorRules struct{} + +func (r *DefaultOperatorRules) BinaryOp(op string, left, right scope.Type) scope.Type { + switch op { + // Comparison operators always produce boolean + case "=", "<>", "!=", "<", ">", "<=", ">=", + "IS", "IS NOT", "LIKE", "ILIKE", "NOT LIKE", + "SIMILAR TO", "~", "~*", "!~", "!~*", + "REGEXP", "NOT REGEXP": + return scope.Type{Name: "boolean", NotNull: true} + + // Logical operators produce boolean + case "AND", "OR": + return scope.Type{Name: "boolean", NotNull: left.NotNull && right.NotNull} + + // Concatenation produces text + case "||": + return scope.Type{Name: "text", NotNull: left.NotNull && right.NotNull} + + // Arithmetic operators + case "+", "-", "*": + return r.arithmeticResult(left, right) + case "/": + result := r.arithmeticResult(left, right) + result.NotNull = left.NotNull && right.NotNull + return result + case "%": + return scope.Type{Name: "integer", NotNull: left.NotNull && right.NotNull} + + // Array operators + case "@>", "<@", "&&": + return scope.Type{Name: "boolean", NotNull: true} + + // JSON operators + case "->": + return scope.Type{Name: "jsonb", NotNull: false} + case "->>": + return scope.Type{Name: "text", NotNull: false} + + default: + // For unknown operators, try to propagate the left type + if !left.IsUnknown() { + return left + } + return right + } +} + +func (r 
*DefaultOperatorRules) UnaryOp(op string, operand scope.Type) scope.Type { + switch op { + case "NOT": + return scope.Type{Name: "boolean", NotNull: operand.NotNull} + case "-", "+": + return operand + case "~": // bitwise not + return operand + default: + return operand + } +} + +func (r *DefaultOperatorRules) CanCast(from, to scope.Type) bool { + if from.IsUnknown() || to.IsUnknown() { + return true + } + if from.Equals(to) { + return true + } + + // Numeric types are generally castable to each other + numericTypes := map[string]bool{ + "integer": true, "int": true, "int4": true, + "bigint": true, "int8": true, + "smallint": true, "int2": true, + "numeric": true, "decimal": true, + "real": true, "float4": true, + "float": true, "double precision": true, "float8": true, + } + if numericTypes[from.Name] && numericTypes[to.Name] { + return true + } + + // Text types are generally castable to each other + textTypes := map[string]bool{ + "text": true, "varchar": true, "char": true, + "character varying": true, "character": true, + "name": true, "citext": true, + } + if textTypes[from.Name] && textTypes[to.Name] { + return true + } + + // Most things can be cast to/from text + if from.Name == "text" || to.Name == "text" { + return true + } + + return true // Be permissive by default +} + +func (r *DefaultOperatorRules) arithmeticResult(left, right scope.Type) scope.Type { + // If either is unknown, use the other + if left.IsUnknown() { + return right + } + if right.IsUnknown() { + return left + } + + // Type promotion hierarchy + hierarchy := map[string]int{ + "smallint": 1, "int2": 1, + "integer": 2, "int": 2, "int4": 2, + "bigint": 3, "int8": 3, + "numeric": 4, "decimal": 4, + "real": 5, "float4": 5, + "float": 6, "double precision": 6, "float8": 6, + } + + lp := hierarchy[left.Name] + rp := hierarchy[right.Name] + + if lp == 0 && rp == 0 { + return scope.Type{Name: "numeric", NotNull: left.NotNull && right.NotNull} + } + if lp >= rp { + return scope.Type{Name: 
left.Name, NotNull: left.NotNull && right.NotNull} + } + return scope.Type{Name: right.Name, NotNull: left.NotNull && right.NotNull} +} + +// PostgreSQLOperatorRules extends default rules with PostgreSQL-specific behavior. +type PostgreSQLOperatorRules struct { + DefaultOperatorRules +} + +func (r *PostgreSQLOperatorRules) BinaryOp(op string, left, right scope.Type) scope.Type { + switch op { + case "||": + // In PostgreSQL, || is string concatenation + return scope.Type{Name: "text", NotNull: left.NotNull && right.NotNull} + default: + return r.DefaultOperatorRules.BinaryOp(op, left, right) + } +} + +// MySQLOperatorRules extends default rules with MySQL-specific behavior. +type MySQLOperatorRules struct { + DefaultOperatorRules +} + +func (r *MySQLOperatorRules) BinaryOp(op string, left, right scope.Type) scope.Type { + switch op { + case "||": + // In MySQL (with default settings), || is logical OR, not concat + return scope.Type{Name: "boolean", NotNull: left.NotNull && right.NotNull} + default: + return r.DefaultOperatorRules.BinaryOp(op, left, right) + } +} diff --git a/internal/analysis/typecheck/typecheck_test.go b/internal/analysis/typecheck/typecheck_test.go new file mode 100644 index 0000000000..71e9cc918f --- /dev/null +++ b/internal/analysis/typecheck/typecheck_test.go @@ -0,0 +1,370 @@ +package typecheck + +import ( + "testing" + + "github.com/sqlc-dev/sqlc/internal/analysis/scope" +) + +func TestSynthLiteral(t *testing.T) { + c := NewChecker(nil) + + tests := []struct { + name string + expr Expr + want string + }{ + {"int", &LiteralExpr{Type: scope.TypeInt}, "integer"}, + {"text", &LiteralExpr{Type: scope.TypeText}, "text"}, + {"bool", &LiteralExpr{Type: scope.TypeBool}, "boolean"}, + {"float", &LiteralExpr{Type: scope.TypeFloat}, "float"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := c.Synth(tt.expr) + if result.Type.Name != tt.want { + t.Errorf("Synth: got type %q, want %q", result.Type.Name, tt.want) + } + 
if result.Source != SourceLiteral { + t.Errorf("Synth: got source %v, want SourceLiteral", result.Source) + } + }) + } +} + +func TestSynthColumnRef(t *testing.T) { + c := NewChecker(nil) + + expr := &ColumnRefExpr{ + Parts: []string{"users", "name"}, + ResolvedType: scope.TypeText, + } + + result := c.Synth(expr) + if result.Type.Name != "text" { + t.Errorf("got type %q, want 'text'", result.Type.Name) + } + if result.Source != SourceCatalog { + t.Errorf("got source %v, want SourceCatalog", result.Source) + } +} + +func TestSynthBinaryOp_Comparison(t *testing.T) { + c := NewChecker(nil) + + // age = 25 → boolean + expr := &BinaryOpExpr{ + Op: "=", + Left: &ColumnRefExpr{ResolvedType: scope.TypeInt}, + Right: &LiteralExpr{Type: scope.TypeInt}, + } + + result := c.Synth(expr) + if result.Type.Name != "boolean" { + t.Errorf("got type %q, want 'boolean'", result.Type.Name) + } +} + +func TestSynthBinaryOp_Arithmetic(t *testing.T) { + c := NewChecker(nil) + + // age + 1 → integer + expr := &BinaryOpExpr{ + Op: "+", + Left: &ColumnRefExpr{ResolvedType: scope.TypeInt}, + Right: &LiteralExpr{Type: scope.TypeInt}, + } + + result := c.Synth(expr) + if result.Type.Name != "integer" { + t.Errorf("got type %q, want 'integer'", result.Type.Name) + } +} + +func TestSynthBinaryOp_Concat(t *testing.T) { + c := NewChecker(nil) + + // name || ' suffix' → text + expr := &BinaryOpExpr{ + Op: "||", + Left: &ColumnRefExpr{ResolvedType: scope.TypeText}, + Right: &LiteralExpr{Type: scope.TypeText}, + } + + result := c.Synth(expr) + if result.Type.Name != "text" { + t.Errorf("got type %q, want 'text'", result.Type.Name) + } +} + +func TestCheckParamInfersType(t *testing.T) { + c := NewChecker(nil) + + // WHERE age > $1 → $1 should be inferred as integer + param := &ParamExpr{Number: 1, Location: 42} + + c.Check(param, scope.TypeInt, 42) + + params := c.ParamTypes() + if len(params) != 1 { + t.Fatalf("expected 1 param, got %d", len(params)) + } + p, ok := params[1] + if !ok { + 
t.Fatal("param $1 not found") + } + if p.Type.Name != "integer" { + t.Errorf("param $1 type: got %q, want 'integer'", p.Type.Name) + } + if p.Source != SourceContext { + t.Errorf("param $1 source: got %v, want SourceContext", p.Source) + } +} + +func TestSynthBinaryOpWithParam(t *testing.T) { + c := NewChecker(nil) + + // age = $1 → $1 should be inferred as integer (from column ref) + expr := &BinaryOpExpr{ + Op: "=", + Left: &ColumnRefExpr{ResolvedType: scope.TypeInt}, + Right: &ParamExpr{Number: 1, Location: 20}, + } + + c.Synth(expr) + + params := c.ParamTypes() + p, ok := params[1] + if !ok { + t.Fatal("param $1 not found after synth") + } + if p.Type.Name != "integer" { + t.Errorf("param $1 type: got %q, want 'integer'", p.Type.Name) + } +} + +func TestCheckMultipleParams(t *testing.T) { + c := NewChecker(nil) + + // WHERE name = $1 AND age > $2 + nameExpr := &BinaryOpExpr{ + Op: "=", + Left: &ColumnRefExpr{ResolvedType: scope.TypeText}, + Right: &ParamExpr{Number: 1, Location: 10}, + } + ageExpr := &BinaryOpExpr{ + Op: ">", + Left: &ColumnRefExpr{ResolvedType: scope.TypeInt}, + Right: &ParamExpr{Number: 2, Location: 30}, + } + + c.Synth(nameExpr) + c.Synth(ageExpr) + + params := c.ParamTypes() + if len(params) != 2 { + t.Fatalf("expected 2 params, got %d", len(params)) + } + if params[1].Type.Name != "text" { + t.Errorf("$1: got %q, want 'text'", params[1].Type.Name) + } + if params[2].Type.Name != "integer" { + t.Errorf("$2: got %q, want 'integer'", params[2].Type.Name) + } +} + +func TestSynthFuncCall(t *testing.T) { + c := NewChecker(nil) + + expr := &FuncCallExpr{ + Name: "count", + ReturnType: scope.Type{Name: "bigint", NotNull: true}, + } + + result := c.Synth(expr) + if result.Type.Name != "bigint" { + t.Errorf("got %q, want 'bigint'", result.Type.Name) + } + if result.Source != SourceFunction { + t.Errorf("got source %v, want SourceFunction", result.Source) + } +} + +func TestSynthTypeCast(t *testing.T) { + c := NewChecker(nil) + + // $1::integer → 
param should get integer type, expr returns integer + expr := &TypeCastExpr{ + Arg: &ParamExpr{Number: 1, Location: 5}, + CastType: scope.TypeInt, + } + + result := c.Synth(expr) + if result.Type.Name != "integer" { + t.Errorf("cast result: got %q, want 'integer'", result.Type.Name) + } +} + +func TestSynthSubquery(t *testing.T) { + c := NewChecker(nil) + + // Scalar subquery with one column + expr := &SubqueryExpr{ + Columns: []scope.Type{scope.TypeInt}, + } + + result := c.Synth(expr) + if result.Type.Name != "integer" { + t.Errorf("subquery: got %q, want 'integer'", result.Type.Name) + } + if result.Kind != KindScalar { + t.Errorf("subquery: got kind %v, want KindScalar", result.Kind) + } +} + +func TestSynthBoolExpr(t *testing.T) { + c := NewChecker(nil) + + expr := &BoolExpr{Op: "AND", Args: []Expr{ + &LiteralExpr{Type: scope.TypeBool}, + &LiteralExpr{Type: scope.TypeBool}, + }} + + result := c.Synth(expr) + if result.Type.Name != "boolean" { + t.Errorf("AND: got %q, want 'boolean'", result.Type.Name) + } +} + +func TestSynthNullTest(t *testing.T) { + c := NewChecker(nil) + + expr := &NullTestExpr{ + Arg: &ColumnRefExpr{ResolvedType: scope.TypeText}, + IsNot: false, + } + + result := c.Synth(expr) + if result.Type.Name != "boolean" { + t.Errorf("IS NULL: got %q, want 'boolean'", result.Type.Name) + } +} + +func TestSynthCoalesce(t *testing.T) { + c := NewChecker(nil) + + expr := &CoalesceExpr{ + Args: []Expr{ + &ColumnRefExpr{ResolvedType: scope.TypeText}, + &LiteralExpr{Type: scope.TypeText}, + }, + } + + result := c.Synth(expr) + if result.Type.Name != "text" { + t.Errorf("COALESCE: got %q, want 'text'", result.Type.Name) + } +} + +func TestSynthIn(t *testing.T) { + c := NewChecker(nil) + + expr := &InExpr{ + Expr: &ColumnRefExpr{ResolvedType: scope.TypeInt}, + Values: []Expr{&LiteralExpr{Type: scope.TypeInt}}, + } + + result := c.Synth(expr) + if result.Type.Name != "boolean" { + t.Errorf("IN: got %q, want 'boolean'", result.Type.Name) + } +} + +func 
TestSynthBetween(t *testing.T) { + c := NewChecker(nil) + + expr := &BetweenExpr{ + Expr: &ColumnRefExpr{ResolvedType: scope.TypeInt}, + Low: &LiteralExpr{Type: scope.TypeInt}, + High: &LiteralExpr{Type: scope.TypeInt}, + } + + result := c.Synth(expr) + if result.Type.Name != "boolean" { + t.Errorf("BETWEEN: got %q, want 'boolean'", result.Type.Name) + } +} + +func TestParamKeepsFirstKnownType(t *testing.T) { + c := NewChecker(nil) + + // First use: $1 = integer column + c.Check(&ParamExpr{Number: 1}, scope.TypeInt, 0) + + // Second use: $1 = text column (conflicting) + c.Check(&ParamExpr{Number: 1}, scope.TypeText, 0) + + // Should keep the first known type + params := c.ParamTypes() + if params[1].Type.Name != "integer" { + t.Errorf("param kept wrong type: got %q, want 'integer'", params[1].Type.Name) + } +} + +func TestMySQLConcatIsOR(t *testing.T) { + c := NewChecker(&MySQLOperatorRules{}) + + // In MySQL default mode, || is OR, not concat + expr := &BinaryOpExpr{ + Op: "||", + Left: &LiteralExpr{Type: scope.TypeBool}, + Right: &LiteralExpr{Type: scope.TypeBool}, + } + + result := c.Synth(expr) + if result.Type.Name != "boolean" { + t.Errorf("MySQL ||: got %q, want 'boolean'", result.Type.Name) + } +} + +func TestPostgreSQLConcatIsText(t *testing.T) { + c := NewChecker(&PostgreSQLOperatorRules{}) + + expr := &BinaryOpExpr{ + Op: "||", + Left: &LiteralExpr{Type: scope.TypeText}, + Right: &LiteralExpr{Type: scope.TypeText}, + } + + result := c.Synth(expr) + if result.Type.Name != "text" { + t.Errorf("PostgreSQL ||: got %q, want 'text'", result.Type.Name) + } +} + +func TestArithmeticTypePromotion(t *testing.T) { + rules := &DefaultOperatorRules{} + + tests := []struct { + name string + left scope.Type + right scope.Type + wantName string + }{ + {"int+int", scope.Type{Name: "integer"}, scope.Type{Name: "integer"}, "integer"}, + {"int+bigint", scope.Type{Name: "integer"}, scope.Type{Name: "bigint"}, "bigint"}, + {"smallint+int", scope.Type{Name: "smallint"}, 
scope.Type{Name: "integer"}, "integer"}, + {"int+numeric", scope.Type{Name: "integer"}, scope.Type{Name: "numeric"}, "numeric"}, + {"real+float", scope.Type{Name: "real"}, scope.Type{Name: "double precision"}, "double precision"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := rules.BinaryOp("+", tt.left, tt.right) + if result.Name != tt.wantName { + t.Errorf("got %q, want %q", result.Name, tt.wantName) + } + }) + } +} diff --git a/internal/analysis/typecheck/types.go b/internal/analysis/typecheck/types.go new file mode 100644 index 0000000000..bf102d908c --- /dev/null +++ b/internal/analysis/typecheck/types.go @@ -0,0 +1,286 @@ +// Package typecheck implements bidirectional type checking for SQL expressions. +// +// Type information flows in two directions: +// - Synthesis (bottom-up): "what type does this expression have?" +// - Checking (top-down): "does this expression have the type I expect?" +// +// This is particularly useful for SQL parameter type inference. When we see +// `WHERE age > $1`, the parameter $1 is in checking mode — its expected type +// is inferred from the context (the type of `age`). When we see `SELECT age + 1`, +// the expression is in synthesis mode — we compute the result type from the operands. +// +// Reference: Dunfield & Krishnaswami, "Bidirectional Typing", ACM Computing Surveys, 2021. +package typecheck + +import ( + "fmt" + + "github.com/sqlc-dev/sqlc/internal/analysis/scope" +) + +// TypeKind classifies types for the checker. +type TypeKind int + +const ( + KindScalar TypeKind = iota // A single value (int, text, bool, ...) + KindRow // A row type (the result of a subquery) + KindSet // A set of rows (table reference) + KindUnknown // Not yet determined +) + +// InferredType is the result of type inference on an expression. 
+type InferredType struct {
+	Type   scope.Type // The inferred SQL type
+	Kind   TypeKind   // Scalar / row / set / unknown classification
+	Source TypeSource // Where this type came from
+}
+
+// TypeSource records provenance for type inference results.
+type TypeSource int
+
+const (
+	SourceCatalog   TypeSource = iota // Came from the database catalog
+	SourceLiteral                     // Inferred from a literal value
+	SourceOperator                    // Inferred from an operator's return type
+	SourceContext                     // Inferred from surrounding context (checking mode)
+	SourceFunction                    // Inferred from a function's return type
+	SourceParameter                   // A parameter whose type was inferred
+	SourceUnknown                     // Could not be determined
+)
+
+// String returns a short human-readable name for the source.
+func (s TypeSource) String() string {
+	names := [...]string{"catalog", "literal", "operator", "context", "function", "parameter", "unknown"}
+	// NOTE(review): only the upper bound is guarded; a negative TypeSource value
+	// would index out of range below. All declared constants are >= 0, so this
+	// only matters if a caller fabricates a value — confirm acceptable.
+	if int(s) < len(names) {
+		return names[s]
+	}
+	return fmt.Sprintf("TypeSource(%d)", int(s))
+}
+
+// ParamTypeInference records a type inferred for a query parameter.
+type ParamTypeInference struct {
+	Number   int        // Parameter number ($1, $2, ...)
+	Type     scope.Type // The inferred type
+	Source   TypeSource // How the type was inferred
+	Location int        // Source position of the parameter
+}
+
+// Checker performs bidirectional type checking on SQL expressions.
+// It accumulates parameter types and type errors across successive
+// Synth/Check calls; it is not reset between expressions.
+type Checker struct {
+	params  map[int]*ParamTypeInference // Accumulated parameter types
+	errors  []TypeError                 // Type errors found
+	opRules OperatorRules               // Engine-specific operator type rules
+}
+
+// TypeError describes a type mismatch found during checking.
+type TypeError struct {
+	Message  string
+	Location int
+	Expected scope.Type // What was expected (in check mode)
+	Got      scope.Type // What was found (synthesized)
+}
+
+// Error implements the error interface by returning the message text.
+func (e TypeError) Error() string {
+	return e.Message
+}
+
+// OperatorRules provides engine-specific type inference rules for operators.
+type OperatorRules interface {
+	// BinaryOp returns the result type of a binary operator given operand types.
+	BinaryOp(op string, left, right scope.Type) scope.Type
+	// UnaryOp returns the result type of a unary operator.
+	UnaryOp(op string, operand scope.Type) scope.Type
+	// CanCast returns whether a value of `from` type can be cast to `to` type.
+	CanCast(from, to scope.Type) bool
+}
+
+// NewChecker creates a new type checker with the given operator rules.
+// If rules is nil, default rules are used.
+func NewChecker(rules OperatorRules) *Checker {
+	if rules == nil {
+		rules = &DefaultOperatorRules{}
+	}
+	return &Checker{
+		params:  make(map[int]*ParamTypeInference),
+		opRules: rules,
+	}
+}
+
+// Synth synthesizes (infers bottom-up) the type of an expression.
+// This is the "what type does this have?" direction.
+func (c *Checker) Synth(expr Expr) InferredType {
+	switch e := expr.(type) {
+	case *LiteralExpr:
+		return InferredType{Type: e.Type, Kind: KindScalar, Source: SourceLiteral}
+
+	case *ColumnRefExpr:
+		return InferredType{Type: e.ResolvedType, Kind: KindScalar, Source: SourceCatalog}
+
+	case *ParamExpr:
+		// In synth mode, a parameter has unknown type unless previously inferred
+		if prev, ok := c.params[e.Number]; ok {
+			return InferredType{Type: prev.Type, Kind: KindScalar, Source: SourceParameter}
+		}
+		return InferredType{Type: scope.TypeUnknown, Kind: KindUnknown, Source: SourceParameter}
+
+	case *BinaryOpExpr:
+		left := c.Synth(e.Left)
+		right := c.Synth(e.Right)
+
+		// If one side is a parameter, use checking mode to infer its type.
+		// NOTE(review): these branches only fire when the param's own type is
+		// already known (a prior inference), in which case recordParam keeps the
+		// existing type — so Check here is effectively a no-op re-confirmation.
+		// The unknown-param case is handled by the recordParam branches below.
+		if lp, ok := e.Left.(*ParamExpr); ok && !left.Type.IsUnknown() {
+			c.Check(e.Left, right.Type, e.Location)
+			_ = lp // parameter type recorded by Check
+		}
+		if rp, ok := e.Right.(*ParamExpr); ok && !right.Type.IsUnknown() {
+			c.Check(e.Right, left.Type, e.Location)
+			_ = rp
+		}
+
+		// If one side is a param and the other is known, infer param from known
+		if left.Type.IsUnknown() && !right.Type.IsUnknown() {
+			if lp, ok := e.Left.(*ParamExpr); ok {
+				c.recordParam(lp.Number, right.Type, SourceContext, lp.Location)
+			}
+			left.Type = right.Type
+		}
+		if right.Type.IsUnknown() && !left.Type.IsUnknown() {
+			if rp, ok := e.Right.(*ParamExpr); ok {
+				c.recordParam(rp.Number, left.Type, SourceContext, rp.Location)
+			}
+			right.Type = left.Type
+		}
+
+		resultType := c.opRules.BinaryOp(e.Op, left.Type, right.Type)
+		return InferredType{Type: resultType, Kind: KindScalar, Source: SourceOperator}
+
+	case *FuncCallExpr:
+		// NOTE(review): Args are not visited here, so parameters inside a
+		// function call are not inferred by Synth — confirm the analyzer
+		// walks arguments separately.
+		return InferredType{Type: e.ReturnType, Kind: KindScalar, Source: SourceFunction}
+
+	case *TypeCastExpr:
+		// NOTE(review): the cast's Arg is not visited, so a parameter inside a
+		// cast ($1::integer) is not recorded here — confirm casts are handled
+		// before reaching Synth.
+		return InferredType{Type: e.CastType, Kind: KindScalar, Source: SourceOperator}
+
+	case *SubqueryExpr:
+		// A single-column subquery is scalar; otherwise it is a row value.
+		if len(e.Columns) == 1 {
+			return InferredType{Type: e.Columns[0], Kind: KindScalar, Source: SourceCatalog}
+		}
+		return InferredType{Type: scope.TypeUnknown, Kind: KindRow, Source: SourceUnknown}
+
+	case *BoolExpr:
+		return InferredType{Type: scope.TypeBool, Kind: KindScalar, Source: SourceOperator}
+
+	case *NullTestExpr:
+		return InferredType{Type: scope.TypeBool, Kind: KindScalar, Source: SourceOperator}
+
+	case *CaseExpr:
+		if e.ResultType.IsUnknown() {
+			return InferredType{Type: scope.TypeUnknown, Kind: KindScalar, Source: SourceUnknown}
+		}
+		return InferredType{Type: e.ResultType, Kind: KindScalar, Source: SourceOperator}
+
+	case *CoalesceExpr:
+		// Type is the type of the first non-null argument
+		for _, arg := range e.Args {
+			t := c.Synth(arg)
+			if !t.Type.IsUnknown() {
+				return t
+			}
+		}
+		return InferredType{Type: scope.TypeUnknown, Kind: KindScalar, Source: SourceUnknown}
+
+	case *InExpr:
+		return InferredType{Type: scope.TypeBool, Kind: KindScalar, Source: SourceOperator}
+
+	case *BetweenExpr:
+		return InferredType{Type: scope.TypeBool, Kind: KindScalar, Source: SourceOperator}
+
+	default:
+		return InferredType{Type: scope.TypeUnknown, Kind: KindUnknown, Source: SourceUnknown}
+	}
+}
+
+// Check verifies that an expression has the expected type (top-down).
+// For parameters, this records the expected type as the parameter's inferred type.
+// This is the "does this have the type I expect?" direction.
+// It returns false (and records a TypeError) only on a provable mismatch;
+// unknown types on either side are treated as compatible.
+func (c *Checker) Check(expr Expr, expected scope.Type, location int) bool {
+	switch e := expr.(type) {
+	case *ParamExpr:
+		// This is the key insight of bidirectional checking for SQL:
+		// when a parameter appears in a context with a known type,
+		// we learn the parameter's type from the context.
+		// NOTE(review): the param's own e.Location is recorded here; the
+		// `location` argument is ignored for this case — confirm intended.
+		c.recordParam(e.Number, expected, SourceContext, e.Location)
+		return true
+
+	case *LiteralExpr:
+		if expected.IsUnknown() {
+			return true
+		}
+		if !c.opRules.CanCast(e.Type, expected) {
+			c.addError(TypeError{
+				Message:  fmt.Sprintf("literal of type %s is not compatible with expected type %s", e.Type.Name, expected.Name),
+				Location: location,
+				Expected: expected,
+				Got:      e.Type,
+			})
+			return false
+		}
+		return true
+
+	case *TypeCastExpr:
+		// An explicit cast always succeeds at the type level
+		return true
+
+	default:
+		// For other expressions, synthesize and compare
+		synth := c.Synth(expr)
+		if synth.Type.IsUnknown() || expected.IsUnknown() {
+			return true // Can't check if either side is unknown
+		}
+		if !synth.Type.Equals(expected) && !c.opRules.CanCast(synth.Type, expected) {
+			c.addError(TypeError{
+				Message:  fmt.Sprintf("type mismatch: expected %s but got %s", expected.Name, synth.Type.Name),
+				Location: location,
+				Expected: expected,
+				Got:      synth.Type,
+			})
+			return false
+		}
+		return true
+	}
+}
+
+// InferParamFromContext infers a parameter's type from its usage context.
+// This is called when we know the expected type from the surrounding expression.
+func (c *Checker) InferParamFromContext(paramNum int, expectedType scope.Type, location int) {
+	c.recordParam(paramNum, expectedType, SourceContext, location)
+}
+
+// recordParam stores an inferred type for parameter `number`, creating the
+// entry on first sight. A later inference only overwrites an earlier one when
+// the earlier type was unknown — i.e. the first *known* type wins and
+// conflicting later inferences are silently ignored.
+func (c *Checker) recordParam(number int, typ scope.Type, source TypeSource, location int) {
+	if existing, ok := c.params[number]; ok {
+		// If we already have a type for this parameter, keep the more specific one
+		if existing.Type.IsUnknown() && !typ.IsUnknown() {
+			existing.Type = typ
+			existing.Source = source
+		}
+		return
+	}
+	c.params[number] = &ParamTypeInference{
+		Number:   number,
+		Type:     typ,
+		Source:   source,
+		Location: location,
+	}
+}
+
+// addError appends a type error to the checker's accumulated error list.
+func (c *Checker) addError(err TypeError) {
+	c.errors = append(c.errors, err)
+}
+
+// ParamTypes returns all inferred parameter types.
+// NOTE(review): this returns the internal map, not a copy — mutations by the
+// caller will be visible to the checker. Confirm callers treat it read-only.
+func (c *Checker) ParamTypes() map[int]*ParamTypeInference {
+	return c.params
+}
+
+// Errors returns all type errors found during checking.
+func (c *Checker) Errors() []TypeError {
+	return c.errors
+}