@@ -623,15 +623,8 @@ func sqlalchemyNode(name string) *pyast.ClassDef {
623623 Name : name ,
624624 Bases : []* pyast.Node {
625625 {
626- Node : & pyast.Node_Attribute {
627- Attribute : & pyast.Attribute {
628- Value : & pyast.Node {
629- Node : & pyast.Node_Name {
630- Name : & pyast.Name {Id : "hello" },
631- },
632- },
633- Attr : "Base" ,
634- },
626+ Node : & pyast.Node_Name {
627+ Name : & pyast.Name {Id : "Base" },
635628 },
636629 },
637630 },
@@ -643,15 +636,8 @@ func sqlalchemyBaseNode() *pyast.ClassDef {
643636 Name : "Base" ,
644637 Bases : []* pyast.Node {
645638 {
646- Node : & pyast.Node_Attribute {
647- Attribute : & pyast.Attribute {
648- Value : & pyast.Node {
649- Node : & pyast.Node_Name {
650- Name : & pyast.Name {Id : "sqlalchemy.orm" },
651- },
652- },
653- Attr : "DeclarativeBase" ,
654- },
639+ Node : & pyast.Node_Name {
640+ Name : & pyast.Name {Id : "DeclarativeBase" },
655641 },
656642 },
657643 },
@@ -670,6 +656,18 @@ func fieldNode(f Field) *pyast.Node {
670656 }
671657}
672658
659+ func ormFieldNode (f Field ) * pyast.Node {
660+ return & pyast.Node {
661+ Node : & pyast.Node_AnnAssign {
662+ AnnAssign : & pyast.AnnAssign {
663+ Target : & pyast.Name {Id : f .Name },
664+ Annotation : subscriptNode ("Mapped" , f .Type .Annotation ()),
665+ Comment : f .Comment ,
666+ },
667+ },
668+ }
669+ }
670+
673671func fieldPassNode () * pyast.Node {
674672 return & pyast.Node {
675673 Node : & pyast.Node_Pass {
@@ -716,6 +714,7 @@ func buildImportGroup(specs map[string]importSpec) *pyast.Node {
716714 var body []* pyast.Node
717715 for _ , spec := range buildImportBlock2 (specs ) {
718716 if len (spec .Names ) > 0 && spec .Names [0 ] != "" {
717+ // e.g. from sqlalchemy import create_engine
719718 imp := & pyast.ImportFrom {
720719 Module : spec .Module ,
721720 }
@@ -834,6 +833,7 @@ func buildOrmTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
834833 }
835834
836835 // generate Base class
836+ // Base class is a dependency for all orm models
837837 def := sqlalchemyBaseNode ()
838838 def .Body = append (def .Body , fieldPassNode ())
839839 mod .Body = append (mod .Body , & pyast.Node {
@@ -842,7 +842,7 @@ func buildOrmTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
842842 },
843843 })
844844
845- for _ , m := range ctx .Orms {
845+ for _ , m := range ctx .Models {
846846 def := sqlalchemyNode (m .Name )
847847 if m .Comment != "" {
848848 def .Body = append (def .Body , & pyast.Node {
@@ -853,8 +853,11 @@ func buildOrmTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
853853 },
854854 })
855855 }
856+ // add table name as class attribute
857+ def .Body = append (def .Body , assignNode ("__tablename__" , poet .Constant (m .Table .Name )))
858+
856859 for _ , f := range m .Fields {
857- def .Body = append (def .Body , fieldNode (f ))
860+ def .Body = append (def .Body , ormFieldNode (f ))
858861 }
859862 mod .Body = append (mod .Body , & pyast.Node {
860863 Node : & pyast.Node_ClassDef {
@@ -1190,7 +1193,6 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
11901193type pyTmplCtx struct {
11911194 SqlcVersion string
11921195 Models []Struct
1193- Orms []Struct
11941196 Queries []Query
11951197 Enums []Enum
11961198 SourceName string
@@ -1230,7 +1232,6 @@ func Generate(_ context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenRes
12301232
12311233 tctx := pyTmplCtx {
12321234 Models : models ,
1233- Orms : models ,
12341235 Queries : queries ,
12351236 Enums : enums ,
12361237 SqlcVersion : req .SqlcVersion ,
0 commit comments