Skip to content

Commit 42d649c

Browse files
committed
feat: wip generate orm.py
now - generate Base class
1 parent 4b62b13 commit 42d649c

File tree

5 files changed

+177
-3
lines changed

5 files changed

+177
-3
lines changed

_examples/gen/sqlc/orm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
# Code generated by sqlc. DO NOT EDIT.
22
# versions:
33
# sqlc v1.20.0
4-
import pydantic
4+
import sqlalchemy
55
from typing import Optional
66

77

8-
class Author(pydantic.BaseModel):
8+
class Base(sqlalchemy.orm.DeclarativeBase):
9+
pass
10+
11+
12+
class Author(hello.Base):
913
id: int
1014
name: str
1115
age: int

_examples/gen/sqlc/query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sqlalchemy
88
import sqlalchemy.ext.asyncio
99

10-
from authors import models
10+
from . import models
1111

1212

1313
CREATE_AUTHOR = """-- name: create_author \\:one

internal/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ type Config struct {
99
EmitPydanticModels bool `json:"emit_pydantic_models"`
1010
QueryParameterLimit *int32 `json:"query_parameter_limit"`
1111
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"`
12+
EmitSQLAlchemyModels bool `json:"emit_sqlalchemy_models"`
1213
}

internal/gen.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,46 @@ func pydanticNode(name string) *pyast.ClassDef {
618618
}
619619
}
620620

621+
func sqlalchemyNode(name string) *pyast.ClassDef {
622+
return &pyast.ClassDef{
623+
Name: name,
624+
Bases: []*pyast.Node{
625+
{
626+
Node: &pyast.Node_Attribute{
627+
Attribute: &pyast.Attribute{
628+
Value: &pyast.Node{
629+
Node: &pyast.Node_Name{
630+
Name: &pyast.Name{Id: "hello"},
631+
},
632+
},
633+
Attr: "Base",
634+
},
635+
},
636+
},
637+
},
638+
}
639+
}
640+
641+
func sqlalchemyBaseNode() *pyast.ClassDef {
642+
return &pyast.ClassDef{
643+
Name: "Base",
644+
Bases: []*pyast.Node{
645+
{
646+
Node: &pyast.Node_Attribute{
647+
Attribute: &pyast.Attribute{
648+
Value: &pyast.Node{
649+
Node: &pyast.Node_Name{
650+
Name: &pyast.Name{Id: "sqlalchemy.orm"},
651+
},
652+
},
653+
Attr: "DeclarativeBase",
654+
},
655+
},
656+
},
657+
},
658+
}
659+
}
660+
621661
func fieldNode(f Field) *pyast.Node {
622662
return &pyast.Node{
623663
Node: &pyast.Node_AnnAssign{
@@ -630,6 +670,14 @@ func fieldNode(f Field) *pyast.Node {
630670
}
631671
}
632672

673+
func fieldPassNode() *pyast.Node {
674+
return &pyast.Node{
675+
Node: &pyast.Node_Pass{
676+
Pass: &pyast.Pass{},
677+
},
678+
}
679+
}
680+
633681
func typeRefNode(base string, parts ...string) *pyast.Node {
634682
n := poet.Name(base)
635683
for _, p := range parts {
@@ -753,6 +801,71 @@ func buildModelsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
753801
return &pyast.Node{Node: &pyast.Node_Module{Module: mod}}
754802
}
755803

804+
func buildOrmTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
805+
mod := moduleNode(ctx.SqlcVersion, "")
806+
std, pkg := i.ormImportSpecs()
807+
mod.Body = append(mod.Body, buildImportGroup(std), buildImportGroup(pkg))
808+
809+
for _, e := range ctx.Enums {
810+
def := &pyast.ClassDef{
811+
Name: e.Name,
812+
Bases: []*pyast.Node{
813+
poet.Name("str"),
814+
poet.Attribute(poet.Name("enum"), "Enum"),
815+
},
816+
}
817+
if e.Comment != "" {
818+
def.Body = append(def.Body, &pyast.Node{
819+
Node: &pyast.Node_Expr{
820+
Expr: &pyast.Expr{
821+
Value: poet.Constant(e.Comment),
822+
},
823+
},
824+
})
825+
}
826+
for _, c := range e.Constants {
827+
def.Body = append(def.Body, assignNode(c.Name, poet.Constant(c.Value)))
828+
}
829+
mod.Body = append(mod.Body, &pyast.Node{
830+
Node: &pyast.Node_ClassDef{
831+
ClassDef: def,
832+
},
833+
})
834+
}
835+
836+
// generate Base class
837+
def := sqlalchemyBaseNode()
838+
def.Body = append(def.Body, fieldPassNode())
839+
mod.Body = append(mod.Body, &pyast.Node{
840+
Node: &pyast.Node_ClassDef{
841+
ClassDef: def,
842+
},
843+
})
844+
845+
for _, m := range ctx.Orms {
846+
def := sqlalchemyNode(m.Name)
847+
if m.Comment != "" {
848+
def.Body = append(def.Body, &pyast.Node{
849+
Node: &pyast.Node_Expr{
850+
Expr: &pyast.Expr{
851+
Value: poet.Constant(m.Comment),
852+
},
853+
},
854+
})
855+
}
856+
for _, f := range m.Fields {
857+
def.Body = append(def.Body, fieldNode(f))
858+
}
859+
mod.Body = append(mod.Body, &pyast.Node{
860+
Node: &pyast.Node_ClassDef{
861+
ClassDef: def,
862+
},
863+
})
864+
}
865+
866+
return &pyast.Node{Node: &pyast.Node_Module{Module: mod}}
867+
}
868+
756869
func querierClassDef() *pyast.ClassDef {
757870
return &pyast.ClassDef{
758871
Name: "Querier",
@@ -1077,6 +1190,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
10771190
type pyTmplCtx struct {
10781191
SqlcVersion string
10791192
Models []Struct
1193+
Orms []Struct
10801194
Queries []Query
10811195
Enums []Enum
10821196
SourceName string
@@ -1116,6 +1230,7 @@ func Generate(_ context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenRes
11161230

11171231
tctx := pyTmplCtx{
11181232
Models: models,
1233+
Orms: models,
11191234
Queries: queries,
11201235
Enums: enums,
11211236
SqlcVersion: req.SqlcVersion,
@@ -1127,6 +1242,13 @@ func Generate(_ context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenRes
11271242
tctx.SourceName = "models.py"
11281243
output["models.py"] = string(result.Python)
11291244

1245+
// sqlalchemy models
1246+
if conf.EmitSQLAlchemyModels {
1247+
tctx.SourceName = "orm.py"
1248+
ormResult := pyprint.Print(buildOrmTree(&tctx, i), pyprint.Options{})
1249+
output["orm.py"] = string(ormResult.Python)
1250+
}
1251+
11301252
files := map[string]struct{}{}
11311253
for _, q := range queries {
11321254
files[q.SourceName] = struct{}{}

internal/imports.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ func queryValueUses(name string, qv QueryValue) bool {
7878
func (i *importer) Imports(fileName string) []string {
7979
if fileName == "models.py" {
8080
return i.modelImports()
81+
} else if fileName == "orm.py" {
82+
return i.ormImports()
8183
}
8284
return i.queryImports(fileName)
8385
}
@@ -119,6 +121,41 @@ func (i *importer) modelImportSpecs() (map[string]importSpec, map[string]importS
119121
return std, pkg
120122
}
121123

124+
func (i *importer) ormImportSpecs() (map[string]importSpec, map[string]importSpec) {
125+
modelUses := func(name string) bool {
126+
for _, model := range i.Models {
127+
if structUses(name, model) {
128+
return true
129+
}
130+
}
131+
return false
132+
}
133+
134+
std := stdImports(modelUses)
135+
if i.C.EmitSQLAlchemyModels {
136+
std["sqlalchemy"] = importSpec{Module: "sqlalchemy"}
137+
}
138+
if len(i.Enums) > 0 {
139+
std["enum"] = importSpec{Module: "enum"}
140+
}
141+
142+
pkg := make(map[string]importSpec)
143+
144+
for _, o := range i.Settings.Overrides {
145+
if pyTypeIsSet(o) {
146+
mod, _, found := strings.Cut(o.CodeType, ".")
147+
if !found {
148+
continue
149+
}
150+
if modelUses(o.CodeType) {
151+
pkg[mod] = importSpec{Module: mod}
152+
}
153+
}
154+
}
155+
156+
return std, pkg
157+
}
158+
122159
func (i *importer) modelImports() []string {
123160
std, pkg := i.modelImportSpecs()
124161
importLines := []string{
@@ -129,6 +166,16 @@ func (i *importer) modelImports() []string {
129166
return importLines
130167
}
131168

169+
func (i *importer) ormImports() []string {
170+
std, pkg := i.ormImportSpecs()
171+
importLines := []string{
172+
buildImportBlock(std),
173+
"",
174+
buildImportBlock(pkg),
175+
}
176+
return importLines
177+
}
178+
132179
func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map[string]importSpec) {
133180
queryUses := func(name string) bool {
134181
for _, q := range i.Queries {

0 commit comments

Comments
 (0)