Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
WIP
  • Loading branch information
kyleconroy committed May 10, 2024
commit 2ee6b81be714f5e3fbfb1b60d3a1e5ccd680d209
25 changes: 25 additions & 0 deletions internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"log"
"os"
"path/filepath"
"runtime/trace"
Expand All @@ -27,6 +28,8 @@ import (
"github.com/sqlc-dev/sqlc/internal/info"
"github.com/sqlc-dev/sqlc/internal/multierr"
"github.com/sqlc-dev/sqlc/internal/opts"
"github.com/sqlc-dev/sqlc/internal/pgx/createdb"
"github.com/sqlc-dev/sqlc/internal/pgx/poolcache"
"github.com/sqlc-dev/sqlc/internal/plugin"
"github.com/sqlc-dev/sqlc/internal/remote"
"github.com/sqlc-dev/sqlc/internal/sql/sqlpath"
Expand Down Expand Up @@ -316,9 +319,31 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C
}
return nil, true
}

{
uri := combo.Global.Servers[0].URI
cache := poolcache.New()
pool, err := cache.Open(ctx, uri)
if err != nil {
log.Println("cache.Open", err)
return nil, false
}
creator := createdb.New(uri, pool)
dburi, db, err := creator.Create(ctx, c.SchemaHash, c.Schema)
if err != nil {
log.Println("creator.Create", err)
}
fmt.Println(db)

combo.Package.Database.URI = dburi
combo.Package.Database.Managed = false
c.UpdateAnalyzer(combo.Package.Database)
}

if parserOpts.Debug.DumpCatalog {
debug.Dump(c.Catalog())
}

if err := c.ParseQueries(sql.Queries, parserOpts); err != nil {
fmt.Fprintf(stderr, "# package %s\n", name)
if parserErr, ok := err.(*multierr.Error); ok {
Expand Down
4 changes: 0 additions & 4 deletions internal/compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package compiler
import (
"errors"
"fmt"
"hash/fnv"
"io"
"os"
"path/filepath"
Expand Down Expand Up @@ -32,15 +31,13 @@ func (c *Compiler) parseCatalog(schemas []string) error {
return err
}
merr := multierr.New()
h := fnv.New64()
for _, filename := range files {
blob, err := os.ReadFile(filename)
if err != nil {
merr.Add(filename, "", 0, err)
continue
}
contents := migrations.RemoveRollbackStatements(string(blob))
io.WriteString(h, contents)
c.schema = append(c.schema, contents)
stmts, err := c.parser.Parse(strings.NewReader(contents))
if err != nil {
Expand All @@ -54,7 +51,6 @@ func (c *Compiler) parseCatalog(schemas []string) error {
}
}
}
c.schemaHash = fmt.Sprintf("%x", h.Sum(nil))
if len(merr.Errs()) > 0 {
return merr
}
Expand Down
13 changes: 10 additions & 3 deletions internal/compiler/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ type Compiler struct {
analyzer analyzer.Analyzer
client pb.QuickClient

schema []string
schemaHash string
schema []string
}

func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, error) {
Expand All @@ -53,7 +52,7 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err
if conf.Database != nil {
if conf.Analyzer.Database == nil || *conf.Analyzer.Database {
c.analyzer = analyzer.Cached(
pganalyze.New(c.client, *conf.Database),
pganalyze.New(c.client, combo.Global.Servers, *conf.Database),
combo.Global,
*conf.Database,
)
Expand All @@ -65,6 +64,14 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err
return c, nil
}

// UpdateAnalyzer replaces the compiler's analyzer with a cached pganalyze
// analyzer bound to db. It is called after the database configuration has
// been rewritten (e.g. pointed at an auto-created database) so that
// subsequent query analysis uses the new connection settings.
func (c *Compiler) UpdateAnalyzer(db *config.Database) {
	c.analyzer = analyzer.Cached(
		// pganalyze.New now requires the configured servers (its signature
		// changed in this same change set); omitting them fails to compile.
		pganalyze.New(c.client, c.combo.Global.Servers, *db),
		c.combo.Global,
		*db,
	)
}

func (c *Compiler) Catalog() *catalog.Catalog {
return c.catalog
}
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type Config struct {
// Database configures the database connection sqlc uses for query analysis.
type Database struct {
	// URI is the connection string for the analysis database.
	URI string `json:"uri" yaml:"uri"`
	// Managed selects a managed (hosted) database instead of a user-supplied
	// one. NOTE(review): presumably provisioned via the quickdb client —
	// confirm against the analyzer code paths.
	Managed bool `json:"managed" yaml:"managed"`
	// Auto makes the analyzer create a database automatically, named from a
	// hash of the schema migrations (see Analyzer.createDb).
	Auto bool `json:"auto" yaml:"auto"`
}

type Cloud struct {
Expand Down
50 changes: 37 additions & 13 deletions internal/engine/postgresql/analyzer/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@ import (
"context"
"errors"
"fmt"
"hash/fnv"
"io"
"strings"
"sync"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"golang.org/x/sync/singleflight"

core "github.com/sqlc-dev/sqlc/internal/analysis"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/opts"
"github.com/sqlc-dev/sqlc/internal/pgx/poolcache"
pb "github.com/sqlc-dev/sqlc/internal/quickdb/v1"
"github.com/sqlc-dev/sqlc/internal/shfmt"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
Expand All @@ -22,22 +26,28 @@ import (
)

// Analyzer performs PostgreSQL query analysis against a live database
// connection. NOTE: the rendered diff fused the pre- and post-change field
// sets into one invalid struct; this is the post-change version only.
type Analyzer struct {
	db       config.Database
	client   pb.QuickClient
	pool     *pgxpool.Pool
	dbg      opts.Debug
	replacer *shfmt.Replacer
	// In-memory caches for prepared-statement metadata lookups.
	formats sync.Map
	columns sync.Map
	tables  sync.Map
	// servers are the configured database servers; only servers[0] is used
	// today (see createDb).
	servers []config.Server
	// serverCache caches pgx pools per server URI.
	serverCache *poolcache.Cache
	// flight deduplicates concurrent auto-creation of the same database.
	flight singleflight.Group
}

func New(client pb.QuickClient, db config.Database) *Analyzer {
func New(client pb.QuickClient, servers []config.Server, db config.Database) *Analyzer {
return &Analyzer{
db: db,
dbg: opts.DebugFromEnv(),
client: client,
replacer: shfmt.NewReplacer(nil),
// TODO: Pick first
servers: servers,
db: db,
dbg: opts.DebugFromEnv(),
client: client,
replacer: shfmt.NewReplacer(nil),
serverCache: poolcache.New(),
}
}

Expand Down Expand Up @@ -99,6 +109,14 @@ type columnKey struct {
Attr uint16
}

// fnv returns the hex-encoded 64-bit FNV-1 digest of the concatenated
// migration statements. It gives a stable fingerprint for a schema, used to
// derive a deterministic database name.
func (a *Analyzer) fnv(migrations []string) string {
	hasher := fnv.New64()
	for _, stmt := range migrations {
		// hash.Hash's Write never returns an error.
		hasher.Write([]byte(stmt))
	}
	return fmt.Sprintf("%x", hasher.Sum(nil))
}

// Cache these types in memory
func (a *Analyzer) columnInfo(ctx context.Context, field pgconn.FieldDescription) (*pgColumn, error) {
key := columnKey{field.TableOID, field.TableAttributeNumber}
Expand Down Expand Up @@ -211,6 +229,12 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat
uri = edb.Uri
} else if a.dbg.OnlyManagedDatabases {
return nil, fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed")
} else if a.db.Auto {
var err error
uri, err = a.createDb(ctx, migrations)
if err != nil {
return nil, err
}
} else {
uri = a.replacer.Replace(a.db.URI)
}
Expand Down
68 changes: 68 additions & 0 deletions internal/engine/postgresql/analyzer/createdb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package analyzer

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net/url"
	"strings"

	"github.com/jackc/pgx/v5"
)

// createDb ensures that a database seeded with the given migrations exists on
// the first configured server and returns its connection URI. The database
// name is derived from a hash of the migration contents, so identical schemas
// share one database and repeated runs are cheap no-ops. Concurrent callers
// for the same schema are collapsed via singleflight.
func (a *Analyzer) createDb(ctx context.Context, migrations []string) (string, error) {
	hash := a.fnv(migrations)
	name := fmt.Sprintf("sqlc_%s", hash)

	serverUri := a.replacer.Replace(a.servers[0].URI)
	pool, err := a.serverCache.Open(ctx, serverUri)
	if err != nil {
		return "", err
	}

	uri, err := url.Parse(serverUri)
	if err != nil {
		return "", err
	}
	uri.Path = name

	key := uri.String()
	_, err, _ = a.flight.Do(key, func() (interface{}, error) {
		// Bind parameters work in SELECT, so use one instead of string
		// interpolation (resolves the previous TODO).
		row := pool.QueryRow(ctx,
			`SELECT datname FROM pg_database WHERE datname = $1`, name)

		var datname string
		err := row.Scan(&datname)
		switch {
		case err == nil:
			// Already provisioned by an earlier run; nothing to do.
			slog.Info("database exists", "name", name)
			return nil, nil
		case !errors.Is(err, pgx.ErrNoRows):
			// A transient query failure is not the same as "not found";
			// surface it instead of attempting CREATE DATABASE.
			return nil, fmt.Errorf("checking for database %s: %w", name, err)
		}

		slog.Info("creating database", "name", name)
		// CREATE DATABASE cannot take bind parameters; quote the identifier
		// instead. name is hex-derived, but sanitize defensively.
		if _, err := pool.Exec(ctx, fmt.Sprintf(`CREATE DATABASE %s`, pgx.Identifier{name}.Sanitize())); err != nil {
			return nil, fmt.Errorf("create database %s: %w", name, err)
		}

		conn, err := pgx.Connect(ctx, uri.String())
		if err != nil {
			return nil, fmt.Errorf("connect %s: %w", name, err)
		}
		defer conn.Close(ctx)

		// Apply each migration statement to the fresh database.
		for _, q := range migrations {
			if strings.TrimSpace(q) == "" {
				continue
			}
			if _, err := conn.Exec(ctx, q); err != nil {
				return nil, fmt.Errorf("%s: %w", q, err)
			}
		}
		return nil, nil
	})
	if err != nil {
		return "", err
	}
	// err is nil here; return it explicitly as nil for clarity.
	return key, nil
}
Loading