Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
POC
  • Loading branch information
sapk committed Aug 31, 2023
commit f0b46ea61dc2fba888616f160832afa35ce0b067
3 changes: 3 additions & 0 deletions internal/compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -43,7 +44,9 @@ func (c *Compiler) parseCatalog(schemas []string) error {
merr.Add(filename, contents, 0, err)
continue
}

for i := range stmts {
log.Printf("stmts[%d]: %#v", i, stmts[i].Raw.Stmt)
if err := c.catalog.Update(stmts[i], c); err != nil {
merr.Add(filename, contents, stmts[i].Pos(), err)
continue
Expand Down
52 changes: 50 additions & 2 deletions internal/endtoend/testdata/enum_alter/mysql/go/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

105 changes: 45 additions & 60 deletions internal/engine/dolphin/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,7 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node {
case pcast.AlterTableAddColumns:
for _, def := range spec.NewColumns {
name := def.Name.String()
columnDef := ast.ColumnDef{
Colname: def.Name.String(),
TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())},
IsNotNull: isNotNull(def),
IsUnsigned: isUnsigned(def),
}
if def.Tp.GetFlen() >= 0 {
length := def.Tp.GetFlen()
columnDef.Length = &length
}
columnDef := convertColumnDef(def)
alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{
Name: &name,
Subtype: ast.AT_AddColumn,
Expand All @@ -77,36 +68,20 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node {

for _, def := range spec.NewColumns {
name := def.Name.String()
columnDef := ast.ColumnDef{
Colname: def.Name.String(),
TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())},
IsNotNull: isNotNull(def),
IsUnsigned: isUnsigned(def),
}
if def.Tp.GetFlen() >= 0 {
length := def.Tp.GetFlen()
columnDef.Length = &length
}
columnDef := convertColumnDef(def)
alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{
Name: &name,
Subtype: ast.AT_AddColumn,
Def: &columnDef,
})

log.Printf("CHANGE COLUMN: %#v\n%#v\n%#v", columnDef, columnDef.TypeName, columnDef.Vals)
}

case pcast.AlterTableModifyColumn:
for _, def := range spec.NewColumns {
name := def.Name.String()
columnDef := ast.ColumnDef{
Colname: def.Name.String(),
TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())},
IsNotNull: isNotNull(def),
IsUnsigned: isUnsigned(def),
}
if def.Tp.GetFlen() >= 0 {
length := def.Tp.GetFlen()
columnDef.Length = &length
}
columnDef := convertColumnDef(def)
alt.Cmds.Items = append(alt.Cmds.Items, &ast.AlterTableCmd{
Name: &name,
Subtype: ast.AT_DropColumn,
Expand All @@ -116,6 +91,8 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node {
Subtype: ast.AT_AddColumn,
Def: &columnDef,
})

log.Printf("MODIFY COLUMN: %#v\n%#v\n%#v", columnDef, columnDef.TypeName, columnDef.Vals)
}

case pcast.AlterTableAlterColumn:
Expand Down Expand Up @@ -249,36 +226,9 @@ func (c *cc) convertCreateTableStmt(n *pcast.CreateTableStmt) ast.Node {
create.ReferTable = parseTableName(n.ReferTable)
}
for _, def := range n.Cols {
var vals *ast.List
if len(def.Tp.GetElems()) > 0 {
vals = &ast.List{}
for i := range def.Tp.GetElems() {
vals.Items = append(vals.Items, &ast.String{
Str: def.Tp.GetElems()[i],
})
}
}
comment := ""
for _, opt := range def.Options {
switch opt.Tp {
case pcast.ColumnOptionComment:
if value, ok := opt.Expr.(*driver.ValueExpr); ok {
comment = value.GetString()
}
}
}
columnDef := ast.ColumnDef{
Colname: def.Name.String(),
TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())},
IsNotNull: isNotNull(def),
IsUnsigned: isUnsigned(def),
Comment: comment,
Vals: vals,
}
if def.Tp.GetFlen() >= 0 {
length := def.Tp.GetFlen()
columnDef.Length = &length
}
columnDef := convertColumnDef(def)

log.Printf("CREATE COLUMN: %#v\n%#v\n%#v", columnDef, columnDef.TypeName, columnDef.Vals)
create.Cols = append(create.Cols, &columnDef)
}
for _, opt := range n.Options {
Expand All @@ -290,6 +240,41 @@ func (c *cc) convertCreateTableStmt(n *pcast.CreateTableStmt) ast.Node {
return create
}

func convertColumnDef(def *pcast.ColumnDef) ast.ColumnDef {
var vals *ast.List
if len(def.Tp.GetElems()) > 0 {
vals = &ast.List{}
for i := range def.Tp.GetElems() {
vals.Items = append(vals.Items, &ast.String{
Str: def.Tp.GetElems()[i],
})
}
}
comment := ""
for _, opt := range def.Options {
switch opt.Tp {
case pcast.ColumnOptionComment:
if value, ok := opt.Expr.(*driver.ValueExpr); ok {
comment = value.GetString()
}
}
}
columnDef := ast.ColumnDef{
Colname: def.Name.String(),
TypeName: &ast.TypeName{Name: types.TypeToStr(def.Tp.GetType(), def.Tp.GetCharset())},
IsNotNull: isNotNull(def),
IsUnsigned: isUnsigned(def),
Comment: comment,
Vals: vals,
}
if def.Tp.GetFlen() >= 0 {
length := def.Tp.GetFlen()
columnDef.Length = &length
}

return columnDef
}

func (c *cc) convertColumnNameExpr(n *pcast.ColumnNameExpr) *ast.ColumnRef {
var items []ast.Node
if schema := n.Name.Schema.String(); schema != "" {
Expand Down
69 changes: 39 additions & 30 deletions internal/sql/catalog/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package catalog
import (
"errors"
"fmt"
"log"

"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
Expand Down Expand Up @@ -41,7 +42,7 @@ func (table *Table) isExistColumn(cmd *ast.AlterTableCmd) (int, error) {
return -1, nil
}

func (table *Table) addColumn(cmd *ast.AlterTableCmd) error {
func (table *Table) addColumn(c *Catalog, cmd *ast.AlterTableCmd) error {
for _, c := range table.Columns {
if c.Name == cmd.Def.Colname {
if !cmd.MissingOk {
Expand All @@ -51,15 +52,13 @@ func (table *Table) addColumn(cmd *ast.AlterTableCmd) error {
}
}

table.Columns = append(table.Columns, &Column{
Name: cmd.Def.Colname,
Type: *cmd.Def.TypeName,
IsNotNull: cmd.Def.IsNotNull,
IsUnsigned: cmd.Def.IsUnsigned,
IsArray: cmd.Def.IsArray,
ArrayDims: cmd.Def.ArrayDims,
Length: cmd.Def.Length,
})
tc, err := c.defToColumn(table.Rel, cmd.Def)
if err != nil {
return err
}
log.Printf("addColumn COLUMN: %#v\n%#v\n%#v", tc, tc.Type, cmd.Def.Vals)

table.Columns = append(table.Columns, tc)
return nil
}

Expand Down Expand Up @@ -187,7 +186,7 @@ func (c *Catalog) alterTable(stmt *ast.AlterTableStmt) error {
case *ast.AlterTableCmd:
switch cmd.Subtype {
case ast.AT_AddColumn:
if err := table.addColumn(cmd); err != nil {
if err := table.addColumn(c, cmd); err != nil {
return err
}
case ast.AT_AlterColumnType:
Expand Down Expand Up @@ -305,26 +304,11 @@ func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error {
continue
}

tc := &Column{
Name: col.Colname,
Type: *col.TypeName,
IsNotNull: col.IsNotNull,
IsUnsigned: col.IsUnsigned,
IsArray: col.IsArray,
ArrayDims: col.ArrayDims,
Comment: col.Comment,
Length: col.Length,
}
if col.Vals != nil {
typeName := ast.TypeName{
Name: fmt.Sprintf("%s_%s", stmt.Name.Name, col.Colname),
}
s := &ast.CreateEnumStmt{TypeName: &typeName, Vals: col.Vals}
if err := c.createEnum(s); err != nil {
return err
}
tc.Type = typeName
tc, err := c.defToColumn(stmt.Name, col)
if err != nil {
return err
}
log.Printf("createTable COLUMN: %#v\n%#v\n%#v", tc, tc.Type, col.Vals)
tbl.Columns = append(tbl.Columns, tc)
}
}
Expand All @@ -340,6 +324,31 @@ func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error {
return nil
}

func (c *Catalog) defToColumn(table *ast.TableName, col *ast.ColumnDef) (*Column, error) {
tc := &Column{
Name: col.Colname,
Type: *col.TypeName,
IsNotNull: col.IsNotNull,
IsUnsigned: col.IsUnsigned,
IsArray: col.IsArray,
ArrayDims: col.ArrayDims,
Comment: col.Comment,
Length: col.Length,
}
if col.Vals != nil {
typeName := ast.TypeName{
Name: fmt.Sprintf("%s_%s", table.Name, col.Colname),
}
s := &ast.CreateEnumStmt{TypeName: &typeName, Vals: col.Vals}
if err := c.createOrSetEnum(s); err != nil {
return nil, err
}
tc.Type = typeName
}

return tc, nil
}

func (c *Catalog) dropTable(stmt *ast.DropTableStmt) error {
for _, name := range stmt.Tables {
ns := name.Schema
Expand Down
36 changes: 36 additions & 0 deletions internal/sql/catalog/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package catalog
import (
"errors"
"fmt"

"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
)
Expand Down Expand Up @@ -90,6 +91,41 @@ func (c *Catalog) createEnum(stmt *ast.CreateEnumStmt) error {
return nil
}

func (c *Catalog) createOrSetEnum(stmt *ast.CreateEnumStmt) error {
ns := stmt.TypeName.Schema
if ns == "" {
ns = c.DefaultSchema
}
schema, err := c.getSchema(ns)
if err != nil {
return err
}
// Because tables have associated data types, the type name must also
// be distinct from the name of any existing table in the same
// schema.
// https://www.postgresql.org/docs/current/sql-createtype.html
tbl := &ast.TableName{
Name: stmt.TypeName.Name,
}
if _, _, err := schema.getTable(tbl); err == nil {
return sqlerr.RelationExists(tbl.Name)
}
if typ, _, err := schema.getType(stmt.TypeName); err == nil {
enum, ok := typ.(*Enum)
if !ok {
return fmt.Errorf("type is not an enum: %s", stmt.TypeName.Name)
}
enum.Vals = stringSlice(stmt.Vals)

return nil
}
schema.Types = append(schema.Types, &Enum{
Name: stmt.TypeName.Name,
Vals: stringSlice(stmt.Vals),
})
return nil
}

func stringSlice(list *ast.List) []string {
items := []string{}
for _, item := range list.Items {
Expand Down