Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
feat(mysql): Use forked driver to get prepared statement metadata
Updates the MySQL analyzer to use the sqlc-dev/mysql forked driver which
exposes column and parameter metadata from COM_STMT_PREPARE responses.
This provides more accurate type information directly from MySQL.

The forked driver adds a StmtMetadata interface with ColumnMetadata() and
ParamMetadata() methods that return type info including DatabaseTypeName,
Nullable, Unsigned, and Length fields.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
  • Loading branch information
kyleconroy and claude committed Nov 30, 2025
commit f1238e438f6a6fe14ce5eb568f85d5dd53b6b3fc
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,5 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
)

replace github.com/go-sql-driver/mysql => github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
Expand Down Expand Up @@ -159,6 +157,8 @@ github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2 h1:kmCAKKtOgK6EXXQX9oPdEASIhgor7TCpWxD8NtcqVcU=
github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2/go.mod h1:TrDMWzjNTKvJeK2GC8uspG+PWyPLiY9QKvwdWpAdlZE=
github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU=
github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down
192 changes: 82 additions & 110 deletions internal/engine/dolphin/analyzer/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package analyzer
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"hash/fnv"
"io"
"strings"
"sync"

_ "github.com/go-sql-driver/mysql"
"github.com/go-sql-driver/mysql"

core "github.com/sqlc-dev/sqlc/internal/analysis"
"github.com/sqlc-dev/sqlc/internal/config"
Expand Down Expand Up @@ -139,90 +140,102 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat
}
}

// Count parameters in the query
paramCount := countParameters(query)

// Try to prepare the statement first to validate syntax
stmt, err := a.conn.PrepareContext(ctx, query)
// Get metadata directly from prepared statement via driver connection
result, err := a.getStatementMetadata(ctx, n, query, ps)
if err != nil {
return nil, a.extractSqlErr(n, err)
return nil, err
}
stmt.Close()

return result, nil
}

// getStatementMetadata uses the MySQL driver's prepared statement metadata API
// to get column and parameter type information without executing the query
func (a *Analyzer) getStatementMetadata(ctx context.Context, n ast.Node, query string, ps *named.ParamSet) (*core.Analysis, error) {
var result core.Analysis

// For SELECT queries, execute with default parameter values to get column metadata
if isSelectQuery(query) {
cols, err := a.getColumnMetadata(ctx, query, paramCount)
if err == nil {
result.Columns = cols
}
// If we fail to get column metadata, fall through to return empty columns
// and let the catalog-based inference handle it
// Get a raw connection to access driver-level prepared statement
conn, err := a.conn.Conn(ctx)
if err != nil {
return nil, a.extractSqlErr(n, fmt.Errorf("failed to get connection: %w", err))
}
defer conn.Close()

// Build parameter info
for i := 1; i <= paramCount; i++ {
name := ""
if ps != nil {
name, _ = ps.NameFor(i)
err = conn.Raw(func(driverConn any) error {
// Get the driver connection that supports PrepareContext
preparer, ok := driverConn.(driver.ConnPrepareContext)
if !ok {
return fmt.Errorf("driver connection does not support PrepareContext")
}
result.Params = append(result.Params, &core.Parameter{
Number: int32(i),
Column: &core.Column{
Name: name,
DataType: "any",
NotNull: false,
},
})
}

return &result, nil
}

// isSelectQuery checks if a query is a SELECT statement
func isSelectQuery(query string) bool {
trimmed := strings.TrimSpace(strings.ToUpper(query))
return strings.HasPrefix(trimmed, "SELECT") ||
strings.HasPrefix(trimmed, "WITH") // CTEs
}
// Prepare the statement - this sends COM_STMT_PREPARE to MySQL
// and receives column and parameter metadata
stmt, err := preparer.PrepareContext(ctx, query)
if err != nil {
return err
}
defer stmt.Close()

// Access the metadata via the StmtMetadata interface from our forked driver
meta, ok := stmt.(mysql.StmtMetadata)
if !ok {
// Fallback: just use param count from NumInput
paramCount := stmt.NumInput()
for i := 1; i <= paramCount; i++ {
name := ""
if ps != nil {
name, _ = ps.NameFor(i)
}
result.Params = append(result.Params, &core.Parameter{
Number: int32(i),
Column: &core.Column{
Name: name,
DataType: "any",
NotNull: false,
},
})
}
return nil
}

// getColumnMetadata executes the query with default values to retrieve column information
func (a *Analyzer) getColumnMetadata(ctx context.Context, query string, paramCount int) ([]*core.Column, error) {
// Generate default parameter values (use 1 for all - works for most types)
args := make([]any, paramCount)
for i := range args {
args[i] = 1
}
// Get column metadata
for _, col := range meta.ColumnMetadata() {
result.Columns = append(result.Columns, &core.Column{
Name: col.Name,
DataType: strings.ToLower(col.DatabaseTypeName),
NotNull: !col.Nullable,
Unsigned: col.Unsigned,
Length: int32(col.Length),
})
}

// Wrap query to avoid fetching data: SELECT * FROM (query) AS _sqlc_wrapper LIMIT 0
// This ensures we get column metadata without executing the actual query
wrappedQuery := fmt.Sprintf("SELECT * FROM (%s) AS _sqlc_wrapper LIMIT 0", query)
// Get parameter metadata
paramMeta := meta.ParamMetadata()
for i, param := range paramMeta {
name := ""
if ps != nil {
name, _ = ps.NameFor(i + 1)
}
result.Params = append(result.Params, &core.Parameter{
Number: int32(i + 1),
Column: &core.Column{
Name: name,
DataType: strings.ToLower(param.DatabaseTypeName),
NotNull: !param.Nullable,
Unsigned: param.Unsigned,
Length: int32(param.Length),
},
})
}

rows, err := a.conn.QueryContext(ctx, wrappedQuery, args...)
if err != nil {
// If wrapped query fails, try direct query with LIMIT 0
// Some queries may not support being wrapped (e.g., queries with UNION at the end)
return nil, err
}
defer rows.Close()
return nil
})

colTypes, err := rows.ColumnTypes()
if err != nil {
return nil, err
}

var columns []*core.Column
for _, col := range colTypes {
nullable, _ := col.Nullable()
columns = append(columns, &core.Column{
Name: col.Name(),
DataType: strings.ToLower(col.DatabaseTypeName()),
NotNull: !nullable,
})
return nil, a.extractSqlErr(n, err)
}

return columns, nil
return &result, nil
}

// replaceDatabase replaces the database name in a MySQL DSN
Expand Down Expand Up @@ -253,47 +266,6 @@ func replaceDatabase(dsn string, newDB string) string {
return dsn[:slashIdx+1] + newDB + dsn[slashIdx+paramIdx:]
}

// countParameters counts the number of ? placeholders in a query
func countParameters(query string) int {
count := 0
inString := false
stringChar := byte(0)
escaped := false

for i := 0; i < len(query); i++ {
c := query[i]

if escaped {
escaped = false
continue
}

if c == '\\' {
escaped = true
continue
}

if inString {
if c == stringChar {
inString = false
}
continue
}

if c == '\'' || c == '"' || c == '`' {
inString = true
stringChar = c
continue
}

if c == '?' {
count++
}
}

return count
}

func (a *Analyzer) extractSqlErr(n ast.Node, err error) error {
if err == nil {
return nil
Expand Down
Loading