Skip to content

Commit

Permalink
refactor: code style
Browse files Browse the repository at this point in the history
  • Loading branch information
Bin-Huang committed Aug 5, 2022
1 parent d4eb1c1 commit 34624fa
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 189 deletions.
32 changes: 16 additions & 16 deletions generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ func {{.Name}}({{.Params}}) *{{.Struct}} {
{{ end }}
`

func generateCode(pkgName string, importResuts []ResultImport, results []Result) (string, error) {
// GenerateCode generate constructors code
func GenerateCode(pkgName string, importInfos []ImportInfo, structInfos []StructInfo) (string, error) {
// remove duplicate imports
importResuts = UniqueImports(importResuts)
importInfos = filterUniqueImports(importInfos)

// generate code with template
t, err := template.New("").Parse(templ)
Expand All @@ -53,21 +54,21 @@ func generateCode(pkgName string, importResuts []ResultImport, results []Result)
}
data := o{
"PkgName": pkgName,
"Imports": importResuts,
"Imports": importInfos,
"Constructor": []o{},
}
constructors := []o{}
for _, result := range results {
for _, structInfo := range structInfos {
params := []string{}
fields := []string{}
for _, field := range result.Fields {
for _, field := range structInfo.Fields {
params = append(params, fmt.Sprintf("%v %v", toLowerCamel(field.Name), field.Type))
fields = append(fields, fmt.Sprintf("%v: %v,", field.Name, toLowerCamel(field.Name)))
}
constructors = append(constructors, o{
"Name": "New" + strcase.ToCamel(result.StructName),
"Struct": result.StructName,
"Init": result.Init,
"Name": "New" + strcase.ToCamel(structInfo.StructName),
"Struct": structInfo.StructName,
"Init": structInfo.Init,
"Params": strings.Join(params, ", "),
"Fields": strings.Join(fields, "\n"),
})
Expand All @@ -80,15 +81,14 @@ func generateCode(pkgName string, importResuts []ResultImport, results []Result)
}

// format code
buf, err := FormatSource(buffer.Bytes())
buf, err := formatCode(buffer.Bytes())
if err != nil {
return "", err
}
return string(buf), nil
}

// FormatSource ...
func FormatSource(source []byte) ([]byte, error) {
func formatCode(source []byte) ([]byte, error) {
output, err := imports.Process("", source, &imports.Options{
AllErrors: true,
Comments: true,
Expand All @@ -102,17 +102,17 @@ func FormatSource(source []byte) ([]byte, error) {
if bytes.Equal(source, output) {
return output, nil
}
return FormatSource(output)
return formatCode(output)
}

// UniqueImports remove duplicate imports, return unqiue imports
func UniqueImports(imports []ResultImport) []ResultImport {
hash := map[string]ResultImport{}
// filterUniqueImports remove duplicate imports, return unqiue imports
func filterUniqueImports(imports []ImportInfo) []ImportInfo {
hash := map[string]ImportInfo{}
for _, importInfo := range imports {
key := fmt.Sprintf("%v|%v", importInfo.Name, importInfo.Path)
hash[key] = importInfo
}
ret := []ResultImport{}
ret := []ImportInfo{}
for _, importInfo := range hash {
ret = append(ret, importInfo)
}
Expand Down
176 changes: 4 additions & 172 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,12 @@
package main

import (
"bufio"
"fmt"
"io/ioutil"
"os"
"strings"
"time"

"go/ast"
"go/types"

"go/parser"
"go/token"

"golang.org/x/tools/go/packages"
)

var fset = token.NewFileSet()

func main() {
pkg, err := GetPackageInfo(".")
if err != nil {
Expand All @@ -29,8 +17,8 @@ func main() {
if isGeneratedRecently(genFilename) {
return
}
allImports := []ResultImport{}
allResults := []Result{}
allImports := []ImportInfo{}
allResults := []StructInfo{}
for _, filename := range pkg.GoFiles {
has, err := IncludeMakeMark(filename)
if err != nil {
Expand All @@ -39,7 +27,7 @@ func main() {
if !has {
continue
}
results, imports, err := ParseFile(filename)
results, imports, err := ParseCodeFile(filename)
if err != nil {
panic(err)
}
Expand All @@ -49,7 +37,7 @@ func main() {
allImports = append(allImports, imports...)
allResults = append(allResults, results...)
}
code, err := generateCode(pkg.Name, allImports, allResults)
code, err := GenerateCode(pkg.Name, allImports, allResults)
if err != nil {
panic(err)
}
Expand All @@ -60,162 +48,6 @@ func main() {
fmt.Printf("make-constructor: %v: wrote %v\n", pkg.PkgPath, genFilename)
}

// GetPackageInfo get the Go package information in the dir
func GetPackageInfo(dir string) (*packages.Package, error) {
pkgs, err := packages.Load(&packages.Config{
Mode: packages.NeedName | packages.NeedFiles,
Tests: false,
}, dir)
if err != nil {
return nil, fmt.Errorf("failed to load packages: %w", err)
}
if len(pkgs) == 0 {
return nil, fmt.Errorf("cannot find any package in %v", dir)
}
return pkgs[0], nil
}

// IncludeMakeMark ...
func IncludeMakeMark(filepath string) (bool, error) {
file, err := os.Open(filepath)
if err != nil {
return false, err
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if IsMakeComment(line) {
return true, nil
}
}
return false, nil
}

// IsMakeComment ...
func IsMakeComment(s string) bool {
s = strings.TrimSpace(s)
return strings.HasPrefix(s, "//go:generate") && strings.Contains(s, "make-constructor")
}

// IsInitModeEnable check if this struct enable the init mode
func IsInitModeEnable(s string) bool {
return strings.Contains(s, "init")
}

// BuildAST ...
func BuildAST(filename string) (*ast.File, error) {
astFile, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
if err != nil {
return nil, fmt.Errorf("failed to build AST from file(%v): %w", filename, err)
}
return astFile, nil
}

// ResultField ...
type ResultField struct {
Name string
Type string
}

// ResultImport ...
type ResultImport struct {
Name string
Path string
}

// Result ...
type Result struct {
StructName string
Init bool
Fields []ResultField
}

// ParseFile ...
func ParseFile(filename string) ([]Result, []ResultImport, error) {
results := []Result{}
imports := []ResultImport{}
astFile, err := BuildAST(filename)
if err != nil {
return results, imports, err
}
for _, decl := range astFile.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok {
continue
}

var initMode bool
if genDecl.Tok == token.TYPE {
needGen := false
for _, doc := range genDecl.Doc.List {
if IsMakeComment(doc.Text) {
needGen = true
initMode = IsInitModeEnable(doc.Text)
break
}
}
if !needGen {
continue
}
}

for _, spec := range genDecl.Specs {
importSpec, ok := spec.(*ast.ImportSpec)
if ok {
var name string
if importSpec.Name != nil {
name = importSpec.Name.Name
}
imports = append(imports, ResultImport{
Name: name,
Path: importSpec.Path.Value,
})
continue
}

typeSpec, ok := spec.(*ast.TypeSpec)
if !ok {
continue
}
structType, ok := typeSpec.Type.(*ast.StructType)
if !ok {
continue
}
resultFields := []ResultField{}
for _, field := range structType.Fields.List {
fieldType := types.ExprString(field.Type)
var fieldName string
if len(field.Names) > 0 {
fieldName = field.Names[0].Name
} else {
// handle embeded struct cases just like this:
// type Foo struct {
// pkg.Struct,
// }
items := strings.Split(fieldType, ".")
fieldName = items[len(items)-1]
// handle pointer cases just like this:
// type Foo struct {
// *pkg.Struct,
// }
fieldName = strings.TrimPrefix(fieldName, "*")
}
resultFields = append(resultFields, ResultField{
Type: fieldType,
Name: fieldName,
})
}
results = append(results, Result{
StructName: typeSpec.Name.Name,
Fields: resultFields,
Init: initMode,
})
}
}
return results, imports, nil
}

func isGeneratedRecently(genFilename string) bool {
stat, err := os.Stat(genFilename)
if err != nil {
Expand Down
Loading

0 comments on commit 34624fa

Please sign in to comment.