@@ -618,6 +618,46 @@ func pydanticNode(name string) *pyast.ClassDef {
618618 }
619619}
620620
621+ func sqlalchemyNode (name string ) * pyast.ClassDef {
622+ return & pyast.ClassDef {
623+ Name : name ,
624+ Bases : []* pyast.Node {
625+ {
626+ Node : & pyast.Node_Attribute {
627+ Attribute : & pyast.Attribute {
628+ Value : & pyast.Node {
629+ Node : & pyast.Node_Name {
630+ Name : & pyast.Name {Id : "hello" },
631+ },
632+ },
633+ Attr : "Base" ,
634+ },
635+ },
636+ },
637+ },
638+ }
639+ }
640+
641+ func sqlalchemyBaseNode () * pyast.ClassDef {
642+ return & pyast.ClassDef {
643+ Name : "Base" ,
644+ Bases : []* pyast.Node {
645+ {
646+ Node : & pyast.Node_Attribute {
647+ Attribute : & pyast.Attribute {
648+ Value : & pyast.Node {
649+ Node : & pyast.Node_Name {
650+ Name : & pyast.Name {Id : "sqlalchemy.orm" },
651+ },
652+ },
653+ Attr : "DeclarativeBase" ,
654+ },
655+ },
656+ },
657+ },
658+ }
659+ }
660+
621661func fieldNode (f Field ) * pyast.Node {
622662 return & pyast.Node {
623663 Node : & pyast.Node_AnnAssign {
@@ -630,6 +670,14 @@ func fieldNode(f Field) *pyast.Node {
630670 }
631671}
632672
673+ func fieldPassNode () * pyast.Node {
674+ return & pyast.Node {
675+ Node : & pyast.Node_Pass {
676+ Pass : & pyast.Pass {},
677+ },
678+ }
679+ }
680+
633681func typeRefNode (base string , parts ... string ) * pyast.Node {
634682 n := poet .Name (base )
635683 for _ , p := range parts {
@@ -753,6 +801,71 @@ func buildModelsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
753801 return & pyast.Node {Node : & pyast.Node_Module {Module : mod }}
754802}
755803
804+ func buildOrmTree (ctx * pyTmplCtx , i * importer ) * pyast.Node {
805+ mod := moduleNode (ctx .SqlcVersion , "" )
806+ std , pkg := i .ormImportSpecs ()
807+ mod .Body = append (mod .Body , buildImportGroup (std ), buildImportGroup (pkg ))
808+
809+ for _ , e := range ctx .Enums {
810+ def := & pyast.ClassDef {
811+ Name : e .Name ,
812+ Bases : []* pyast.Node {
813+ poet .Name ("str" ),
814+ poet .Attribute (poet .Name ("enum" ), "Enum" ),
815+ },
816+ }
817+ if e .Comment != "" {
818+ def .Body = append (def .Body , & pyast.Node {
819+ Node : & pyast.Node_Expr {
820+ Expr : & pyast.Expr {
821+ Value : poet .Constant (e .Comment ),
822+ },
823+ },
824+ })
825+ }
826+ for _ , c := range e .Constants {
827+ def .Body = append (def .Body , assignNode (c .Name , poet .Constant (c .Value )))
828+ }
829+ mod .Body = append (mod .Body , & pyast.Node {
830+ Node : & pyast.Node_ClassDef {
831+ ClassDef : def ,
832+ },
833+ })
834+ }
835+
836+ // generate Base class
837+ def := sqlalchemyBaseNode ()
838+ def .Body = append (def .Body , fieldPassNode ())
839+ mod .Body = append (mod .Body , & pyast.Node {
840+ Node : & pyast.Node_ClassDef {
841+ ClassDef : def ,
842+ },
843+ })
844+
845+ for _ , m := range ctx .Orms {
846+ def := sqlalchemyNode (m .Name )
847+ if m .Comment != "" {
848+ def .Body = append (def .Body , & pyast.Node {
849+ Node : & pyast.Node_Expr {
850+ Expr : & pyast.Expr {
851+ Value : poet .Constant (m .Comment ),
852+ },
853+ },
854+ })
855+ }
856+ for _ , f := range m .Fields {
857+ def .Body = append (def .Body , fieldNode (f ))
858+ }
859+ mod .Body = append (mod .Body , & pyast.Node {
860+ Node : & pyast.Node_ClassDef {
861+ ClassDef : def ,
862+ },
863+ })
864+ }
865+
866+ return & pyast.Node {Node : & pyast.Node_Module {Module : mod }}
867+ }
868+
756869func querierClassDef () * pyast.ClassDef {
757870 return & pyast.ClassDef {
758871 Name : "Querier" ,
@@ -1077,6 +1190,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
10771190type pyTmplCtx struct {
10781191 SqlcVersion string
10791192 Models []Struct
1193+ Orms []Struct
10801194 Queries []Query
10811195 Enums []Enum
10821196 SourceName string
@@ -1116,6 +1230,7 @@ func Generate(_ context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenRes
11161230
11171231 tctx := pyTmplCtx {
11181232 Models : models ,
1233+ Orms : models ,
11191234 Queries : queries ,
11201235 Enums : enums ,
11211236 SqlcVersion : req .SqlcVersion ,
@@ -1127,6 +1242,13 @@ func Generate(_ context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenRes
11271242 tctx .SourceName = "models.py"
11281243 output ["models.py" ] = string (result .Python )
11291244
1245+ // sqlalchemy models
1246+ if conf .EmitSQLAlchemyModels {
1247+ tctx .SourceName = "orm.py"
1248+ ormResult := pyprint .Print (buildOrmTree (& tctx , i ), pyprint.Options {})
1249+ output ["orm.py" ] = string (ormResult .Python )
1250+ }
1251+
11301252 files := map [string ]struct {}{}
11311253 for _ , q := range queries {
11321254 files [q .SourceName ] = struct {}{}
0 commit comments