Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Be paranoid about concurrency
  • Loading branch information
kyleconroy committed Aug 25, 2025
commit 9bdaf614464726f45253111ed2ef8aefb684f32b
4 changes: 4 additions & 0 deletions internal/sqltest/docker/enabled.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ package docker
import (
"fmt"
"os/exec"

"golang.org/x/sync/singleflight"
)

var flight singleflight.Group

func Installed() error {
if _, err := exec.LookPath("docker"); err != nil {
return fmt.Errorf("docker not found: %w", err)
Expand Down
104 changes: 60 additions & 44 deletions internal/sqltest/docker/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,84 +5,100 @@ import (
"database/sql"
"fmt"
"os/exec"
"sync"
"strings"
"time"

_ "github.com/go-sql-driver/mysql"
)

var mysqlSync sync.Once
var mysqlHost string

func StartMySQLServer(c context.Context) (string, error) {
if err := Installed(); err != nil {
return "", err
}
if mysqlHost != "" {
return mysqlHost, nil
}
value, err, _ := flight.Do("mysql", func() (interface{}, error) {
host, err := startMySQLServer(c)
if err != nil {
return "", err
}
mysqlHost = host
return host, nil
})
if err != nil {
return "", err
}
data, ok := value.(string)
if !ok {
return "", fmt.Errorf("returned value was not a string")
}
return data, nil
}

func startMySQLServer(c context.Context) (string, error) {
{
_, err := exec.Command("docker", "pull", "mysql:8").CombinedOutput()
_, err := exec.Command("docker", "pull", "mysql:9").CombinedOutput()
if err != nil {
return "", fmt.Errorf("docker pull: mysql:8 %w", err)
return "", fmt.Errorf("docker pull: mysql:9 %w", err)
}
}

var syncErr error
mysqlSync.Do(func() {
ctx, cancel := context.WithTimeout(c, 10*time.Second)
defer cancel()
var exists bool
{
cmd := exec.Command("docker", "container", "inspect", "sqlc_sqltest_docker_mysql")
// This means we've already started the container
exists = cmd.Run() == nil
}

if !exists {
cmd := exec.Command("docker", "run",
"--name", "sqlc_sqltest_docker_mysql",
"-e", "MYSQL_ROOT_PASSWORD=mysecretpassword",
"-e", "MYSQL_DATABASE=dinotest",
"-p", "3306:3306",
"-d",
"mysql:8",
"mysql:9",
)

output, err := cmd.CombinedOutput()
fmt.Println(string(output))
if err != nil {
syncErr = err
return
}

// Create a ticker that fires every 10ms
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
msg := `Conflict. The container name "/sqlc_sqltest_docker_mysql" is already in use by container`
if !strings.Contains(string(output), msg) && err != nil {
return "", err
}
}

uri := "root:mysecretpassword@/dinotest"
ctx, cancel := context.WithTimeout(c, 10*time.Second)
defer cancel()

db, err := sql.Open("mysql", uri)
if err != nil {
syncErr = fmt.Errorf("sql.Open: %w", err)
return
}
// Create a ticker that fires every 10ms
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
syncErr = fmt.Errorf("timeout reached: %w", ctx.Err())
return

case <-ticker.C:
// Run your function here
if err := db.PingContext(ctx); err != nil {
continue
}
mysqlHost = uri
return
}
}
})
uri := "root:mysecretpassword@/dinotest?multiStatements=true&parseTime=true"

if syncErr != nil {
return "", syncErr
db, err := sql.Open("mysql", uri)
if err != nil {
return "", fmt.Errorf("sql.Open: %w", err)
}

if mysqlHost == "" {
return "", fmt.Errorf("mysql server setup failed")
}
defer db.Close()

for {
select {
case <-ctx.Done():
return "", fmt.Errorf("timeout reached: %w", ctx.Err())

return mysqlHost, nil
case <-ticker.C:
// Run your function here
if err := db.PingContext(ctx); err != nil {
continue
}
return uri, nil
}
}
}
98 changes: 57 additions & 41 deletions internal/sqltest/docker/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,57 @@ import (
"fmt"
"log/slog"
"os/exec"
"sync"
"strings"
"time"

"github.com/jackc/pgx/v5"
)

var postgresSync sync.Once
var postgresHost string

func StartPostgreSQLServer(c context.Context) (string, error) {
if err := Installed(); err != nil {
return "", err
}
if postgresHost != "" {
return postgresHost, nil
}
value, err, _ := flight.Do("postgresql", func() (interface{}, error) {
host, err := startPostgreSQLServer(c)
if err != nil {
return "", err
}
postgresHost = host
return host, err
})
if err != nil {
return "", err
}
data, ok := value.(string)
if !ok {
return "", fmt.Errorf("returned value was not a string")
}
return data, nil
}

func startPostgreSQLServer(c context.Context) (string, error) {
{
_, err := exec.Command("docker", "pull", "postgres:16").CombinedOutput()
if err != nil {
return "", fmt.Errorf("docker pull: postgres:16 %w", err)
}
}

var syncErr error
postgresSync.Do(func() {
ctx, cancel := context.WithTimeout(c, 5*time.Second)
defer cancel()
uri := "postgres://postgres:mysecretpassword@localhost:5432/postgres?sslmode=disable"

var exists bool
{
cmd := exec.Command("docker", "container", "inspect", "sqlc_sqltest_docker_postgres")
// This means we've already started the container
exists = cmd.Run() == nil
}

if !exists {
cmd := exec.Command("docker", "run",
"--name", "sqlc_sqltest_docker_postgres",
"-e", "POSTGRES_PASSWORD=mysecretpassword",
Expand All @@ -43,47 +68,38 @@ func StartPostgreSQLServer(c context.Context) (string, error) {

output, err := cmd.CombinedOutput()
fmt.Println(string(output))
if err != nil {
syncErr = err
return

msg := `Conflict. The container name "/sqlc_sqltest_docker_postgres" is already in use by container`
if !strings.Contains(string(output), msg) && err != nil {
return "", err
}
}

// Create a ticker that fires every 10ms
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
ctx, cancel := context.WithTimeout(c, 5*time.Second)
defer cancel()

uri := "postgres://postgres:mysecretpassword@localhost:5432/postgres?sslmode=disable"
// Create a ticker that fires every 10ms
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
syncErr = fmt.Errorf("timeout reached: %w", ctx.Err())
return
for {
select {
case <-ctx.Done():
return "", fmt.Errorf("timeout reached: %w", ctx.Err())

case <-ticker.C:
// Run your function here
conn, err := pgx.Connect(ctx, uri)
if err != nil {
slog.Debug("sqltest", "connect", err)
continue
}
if err := conn.Ping(ctx); err != nil {
slog.Error("sqltest", "ping", err)
continue
}
postgresHost = uri
return
case <-ticker.C:
// Run your function here
conn, err := pgx.Connect(ctx, uri)
if err != nil {
slog.Debug("sqltest", "connect", err)
continue
}
defer conn.Close(ctx)
if err := conn.Ping(ctx); err != nil {
slog.Error("sqltest", "ping", err)
continue
}
return uri, nil
}
})

if syncErr != nil {
return "", syncErr
}

if postgresHost == "" {
return "", fmt.Errorf("postgres server setup failed")
}

return postgresHost, nil
}
Loading