Skip to content

Commit 48821ef

Browse files
committed
feat: with Mapped class annotation
1 parent 42d649c commit 48821ef

File tree

5 files changed

+49
-32
lines changed

5 files changed

+49
-32
lines changed

_examples/gen/sqlc/orm.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
# Code generated by sqlc. DO NOT EDIT.
22
# versions:
33
# sqlc v1.20.0
4-
import sqlalchemy
4+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
55
from typing import Optional
66

77

8-
class Base(sqlalchemy.orm.DeclarativeBase):
8+
class Base(DeclarativeBase):
99
pass
1010

1111

12-
class Author(hello.Base):
13-
id: int
14-
name: str
15-
age: int
16-
bio: Optional[str]
17-
is_active: bool
12+
class Author(Base):
13+
__tablename__ = "authors"
14+
id: Mapped[int]
15+
name: Mapped[str]
16+
age: Mapped[int]
17+
bio: Mapped[Optional[str]]
18+
is_active: Mapped[bool]

internal/gen.go

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -623,15 +623,8 @@ func sqlalchemyNode(name string) *pyast.ClassDef {
623623
Name: name,
624624
Bases: []*pyast.Node{
625625
{
626-
Node: &pyast.Node_Attribute{
627-
Attribute: &pyast.Attribute{
628-
Value: &pyast.Node{
629-
Node: &pyast.Node_Name{
630-
Name: &pyast.Name{Id: "hello"},
631-
},
632-
},
633-
Attr: "Base",
634-
},
626+
Node: &pyast.Node_Name{
627+
Name: &pyast.Name{Id: "Base"},
635628
},
636629
},
637630
},
@@ -643,15 +636,8 @@ func sqlalchemyBaseNode() *pyast.ClassDef {
643636
Name: "Base",
644637
Bases: []*pyast.Node{
645638
{
646-
Node: &pyast.Node_Attribute{
647-
Attribute: &pyast.Attribute{
648-
Value: &pyast.Node{
649-
Node: &pyast.Node_Name{
650-
Name: &pyast.Name{Id: "sqlalchemy.orm"},
651-
},
652-
},
653-
Attr: "DeclarativeBase",
654-
},
639+
Node: &pyast.Node_Name{
640+
Name: &pyast.Name{Id: "DeclarativeBase"},
655641
},
656642
},
657643
},
@@ -670,6 +656,18 @@ func fieldNode(f Field) *pyast.Node {
670656
}
671657
}
672658

659+
func ormFieldNode(f Field) *pyast.Node {
660+
return &pyast.Node{
661+
Node: &pyast.Node_AnnAssign{
662+
AnnAssign: &pyast.AnnAssign{
663+
Target: &pyast.Name{Id: f.Name},
664+
Annotation: subscriptNode("Mapped", f.Type.Annotation()),
665+
Comment: f.Comment,
666+
},
667+
},
668+
}
669+
}
670+
673671
func fieldPassNode() *pyast.Node {
674672
return &pyast.Node{
675673
Node: &pyast.Node_Pass{
@@ -716,6 +714,7 @@ func buildImportGroup(specs map[string]importSpec) *pyast.Node {
716714
var body []*pyast.Node
717715
for _, spec := range buildImportBlock2(specs) {
718716
if len(spec.Names) > 0 && spec.Names[0] != "" {
717+
// e.g. from sqlalchemy import create_engine
719718
imp := &pyast.ImportFrom{
720719
Module: spec.Module,
721720
}
@@ -834,6 +833,7 @@ func buildOrmTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
834833
}
835834

836835
// generate Base class
836+
// Base class is a dependency for all orm models
837837
def := sqlalchemyBaseNode()
838838
def.Body = append(def.Body, fieldPassNode())
839839
mod.Body = append(mod.Body, &pyast.Node{
@@ -842,7 +842,7 @@ func buildOrmTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
842842
},
843843
})
844844

845-
for _, m := range ctx.Orms {
845+
for _, m := range ctx.Models {
846846
def := sqlalchemyNode(m.Name)
847847
if m.Comment != "" {
848848
def.Body = append(def.Body, &pyast.Node{
@@ -853,8 +853,11 @@ func buildOrmTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
853853
},
854854
})
855855
}
856+
// add table name as class attribute
857+
def.Body = append(def.Body, assignNode("__tablename__", poet.Constant(m.Table.Name)))
858+
856859
for _, f := range m.Fields {
857-
def.Body = append(def.Body, fieldNode(f))
860+
def.Body = append(def.Body, ormFieldNode(f))
858861
}
859862
mod.Body = append(mod.Body, &pyast.Node{
860863
Node: &pyast.Node_ClassDef{
@@ -1190,7 +1193,6 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
11901193
type pyTmplCtx struct {
11911194
SqlcVersion string
11921195
Models []Struct
1193-
Orms []Struct
11941196
Queries []Query
11951197
Enums []Enum
11961198
SourceName string
@@ -1230,7 +1232,6 @@ func Generate(_ context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenRes
12301232

12311233
tctx := pyTmplCtx{
12321234
Models: models,
1233-
Orms: models,
12341235
Queries: queries,
12351236
Enums: enums,
12361237
SqlcVersion: req.SqlcVersion,

internal/imports.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ func (i *importer) modelImportSpecs() (map[string]importSpec, map[string]importS
121121
return std, pkg
122122
}
123123

124+
// ormImportSpecs returns the standard and package imports for the ORM.
124125
func (i *importer) ormImportSpecs() (map[string]importSpec, map[string]importSpec) {
125126
modelUses := func(name string) bool {
126127
for _, model := range i.Models {
@@ -133,7 +134,9 @@ func (i *importer) ormImportSpecs() (map[string]importSpec, map[string]importSpe
133134

134135
std := stdImports(modelUses)
135136
if i.C.EmitSQLAlchemyModels {
136-
std["sqlalchemy"] = importSpec{Module: "sqlalchemy"}
137+
std["sqlalchemy.orm.DeclarativeBase"] = importSpec{Module: "sqlalchemy.orm", Name: "DeclarativeBase"}
138+
std["sqlalchemy.orm.Mapped"] = importSpec{Module: "sqlalchemy.orm", Name: "Mapped"}
139+
std["sqlalchemy.orm.mapped_column"] = importSpec{Module: "sqlalchemy.orm", Name: "mapped_column"}
137140
}
138141
if len(i.Enums) > 0 {
139142
std["enum"] = importSpec{Module: "enum"}

internal/printer/printer.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,18 @@ func (w *writer) printNode(node *ast.Node, indent int32) {
130130
}
131131
}
132132

133+
func (w *writer) printAnnAssignWithMapped(aa *ast.AnnAssign, indent int32) {
134+
if aa.Comment != "" {
135+
w.print("# ")
136+
w.print(aa.Comment)
137+
w.print("\n")
138+
w.printIndent(indent)
139+
}
140+
w.printName(aa.Target, indent)
141+
w.print(": ")
142+
w.printNode(aa.Annotation, indent)
143+
}
144+
133145
func (w *writer) printAnnAssign(aa *ast.AnnAssign, indent int32) {
134146
if aa.Comment != "" {
135147
w.print("# ")

sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://dist/sqlc-gen-python-orm.wasm
6-
sha256: 2b6fe52d6ad5eb67cf783f5e57cd6cea3ce298115ca1ee7245215a41681dbdc5
6+
sha256: 25d1a43b594104740b8c368956944bfe5482d6ada35a3b2faf7dfd968917c9c9
77
sql:
88
- schema: "_examples/schema.sql"
99
queries: "_examples/query.sql"

0 commit comments

Comments
 (0)