Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat(sqlpath): Support filepath.Glob patterns
  • Loading branch information
kyleconroy committed Nov 7, 2023
commit 6286e7096d52cc02d6363dd63b5152c093e25c95
62 changes: 23 additions & 39 deletions internal/sql/sqlpath/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,33 @@ import (
"github.com/sqlc-dev/sqlc/internal/migrations"
)

// Return a list of SQL files in the listed paths. Only includes files ending
// in .sql. Omits hidden files, directories, and migrations.
func Glob(paths []string) ([]string, error) {
paths, err := expandGlobs(paths)
if err != nil {
return nil, err
// Return a list of SQL files in the listed paths.
//
// Only includes files ending in .sql. Omits hidden files, directories, and
// down migrations.

// If a path contains *, ?, [, or ], treat the path as a pattern and expand it
// filepath.Glob.
func Glob(patterns []string) ([]string, error) {
var files, paths []string
for _, pattern := range patterns {
if strings.ContainsAny(pattern, "*?[]") {
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
// if len(matches) == 0 {
// slog.Warn("zero files matched", "pattern", pattern)
// }
paths = append(paths, matches...)
} else {
paths = append(paths, pattern)
}
}
var files []string
for _, path := range paths {
f, err := os.Stat(path)
if err != nil {
return nil, fmt.Errorf("path %s does not exist", path)
return nil, fmt.Errorf("path error: %w", err)
}
if f.IsDir() {
listing, err := os.ReadDir(path)
Expand Down Expand Up @@ -49,34 +64,3 @@ func Glob(paths []string) ([]string, error) {
}
return sqlFiles, nil
}

func expandGlobs(paths []string) ([]string, error) {
expandedPatterns := make([]string, 0, len(paths))
for _, pattern := range paths {
expansion, err := filepath.Glob(pattern)
if err != nil {
return nil, fmt.Errorf("failed to expand pattern %q: %w", pattern, err)
}
if len(expansion) == 0 {
fi, err := os.Lstat(pattern)
if err != nil {
return nil, fmt.Errorf("failed to stat path %q: %w", pattern, err)
}
if fi == nil {
return nil, fmt.Errorf("failed to stat path %q: %w", pattern, os.ErrNotExist)
}
var isFilepath bool
for _, mask := range []os.FileMode{os.ModeDir, os.ModeSymlink, os.FileMode(0x400)} {
if fi.Mode()&mask == 0 {
isFilepath = true
break
}
}
if !isFilepath {
continue
}
}
expandedPatterns = append(expandedPatterns, expansion...)
}
return expandedPatterns, nil
}
9 changes: 4 additions & 5 deletions internal/sql/sqlpath/read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package sqlpath

import (
"fmt"
"github.com/google/go-cmp/cmp"
"testing"

"github.com/google/go-cmp/cmp"
)

// Returns a list of SQL files from given paths.
Expand All @@ -30,7 +31,6 @@ func TestReturnsNilListWhenNoSQLFilesFound(t *testing.T) {

// Act
result, err := Glob(paths)

// Assert
var expected []string
if !cmp.Equal(result, expected) {
Expand Down Expand Up @@ -107,8 +107,7 @@ func TestReturnsErrorWhenPathDoesNotExist(t *testing.T) {
if err == nil {
t.Errorf("Expected an error, but got nil")
} else {
expectedError := fmt.Errorf(`failed to stat path "non_existent_path": ` +
`lstat non_existent_path: no such file or directory`)
expectedError := fmt.Errorf("path error: stat non_existent_path: no such file or directory")
if !cmp.Equal(err.Error(), expectedError.Error()) {
t.Errorf("Expected error %v, but got %v", expectedError, err)
}
Expand All @@ -130,7 +129,7 @@ func TestReturnsErrorWhenDirectoryCannotBeRead(t *testing.T) {
if err == nil {
t.Errorf("Expected an error, but got nil")
} else {
expectedError := fmt.Errorf("open testdata/unreadable: permission denied")
expectedError := fmt.Errorf("path error: stat testdata/unreadable: no such file or directory")
if !cmp.Equal(err.Error(), expectedError.Error()) {
t.Errorf("Expected error %v, but got %v", expectedError, err)
}
Expand Down