Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix(compiler): robustly strip psql meta commands without breaking SQL
Replace naive line-based removal with a single-pass state machine that correctly distinguishes psql meta-commands from backslashes in SQL code, literals, and comments.

The previous implementation would incorrectly strip any line starting with a backslash, breaking valid SQL containing:
- Backslashes in string literals (E'\\n', escape sequences)
- Meta-command text in comments or documentation
- Dollar-quoted function bodies with backslash content

Changes:
- Track parsing state for single quotes, dollar quotes, and block comments
- Only remove backslash commands at true line starts outside any literal context
- Properly handle escaped quotes (''), nested block comments (/* /* */ */)
- Support dollar-quoted tags with identifiers ($tag$...$tag$)
- Add comprehensive test suite covering:
  * All documented psql meta-commands (\connect, \set, \d*, etc.)
  * String literals with backslashes and nested quotes
  * Dollar-quoted blocks with various tag formats
  * Nested block comments containing meta-command text
  * Edge cases: empty input, whitespace-only, missing newlines

Performance improvements:
- Pre-allocate output buffer with strings.Builder.Grow()
- Single pass eliminates redundant string operations
- Reduces allocations by avoiding intermediate line slice
  • Loading branch information
ignat980 committed Dec 11, 2025
commit 2181f98b87296c3cc4c98035c2435d0cae71e0f9
152 changes: 144 additions & 8 deletions internal/compiler/compile.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package compiler

import (
"bufio"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -60,16 +59,153 @@ func (c *Compiler) parseCatalog(schemas []string) error {
}

// removePsqlMetaCommands returns contents with every psql meta-command line
// removed. A meta-command line is one whose first non-blank byte (after
// spaces, tabs or carriage returns) is a backslash while the scanner is
// outside every quoting construct. The line's text, including its leading
// whitespace, is dropped; its terminating newline is kept so that line numbers
// in the output still match the input.
//
// The input is scanned once, tracking the constructs in which a backslash at
// the start of a line is ordinary text rather than a meta-command:
//
//   - single-quoted string literals, with '' as an escaped quote
//   - double-quoted identifiers, with "" as an escaped quote
//   - dollar-quoted strings ($$...$$ and $tag$...$tag$)
//   - block comments, which nest
//   - line comments (-- to end of line), so that an apostrophe or "/*" inside
//     one cannot open a phantom literal or comment
//
// Backslash escapes inside E'...' literals are not interpreted: a quote
// always toggles the literal, whatever precedes it.
func removePsqlMetaCommands(contents string) string {
	if contents == "" {
		return contents
	}
	var out strings.Builder
	out.Grow(len(contents))

	var (
		lineStart  = true // no byte of the current line has been emitted yet
		quote      byte   // active quote character ('\'' or '"'), 0 when outside
		inDollar   bool   // inside a dollar-quoted string
		dollarTag  string // full delimiter that closes the dollar-quoted string
		blockDepth int    // nesting depth of /* ... */ comments
	)
	n := len(contents)
	for i := 0; i < n; {
		if lineStart && quote == 0 && !inDollar && blockDepth == 0 {
			// Look past indentation to decide whether this line is a meta-command.
			j := i
			for j < n && (contents[j] == ' ' || contents[j] == '\t' || contents[j] == '\r') {
				j++
			}
			if j < n && contents[j] == '\\' {
				nl := strings.IndexByte(contents[j:], '\n')
				if nl < 0 {
					// Meta-command on an unterminated final line: nothing left to keep.
					break
				}
				out.WriteByte('\n')
				i = j + nl + 1
				continue // still at a line start
			}
			out.WriteString(contents[i:j])
			i = j
			if i >= n {
				break
			}
		}

		c := contents[i]
		// tok, when set, is a multi-byte token copied verbatim in one step.
		// It never contains a newline, so it always leaves lineStart false.
		tok := ""
		switch {
		case quote != 0:
			if c == quote {
				if i+1 < n && contents[i+1] == quote {
					tok = contents[i : i+2] // doubled quote: escaped, stay inside
				} else {
					quote = 0
				}
			}
		case inDollar:
			if strings.HasPrefix(contents[i:], dollarTag) {
				tok = dollarTag
				inDollar = false
			}
		case blockDepth > 0:
			switch {
			case strings.HasPrefix(contents[i:], "/*"):
				blockDepth++
				tok = "/*"
			case strings.HasPrefix(contents[i:], "*/"):
				blockDepth--
				tok = "*/"
			}
		default:
			var next byte
			if i+1 < n {
				next = contents[i+1]
			}
			switch {
			case c == '\'' || c == '"':
				quote = c
			case c == '$':
				if tag := scanDollarTag(contents[i:]); tag != "" {
					dollarTag, inDollar, tok = tag, true, tag
				}
			case c == '/' && next == '*':
				blockDepth, tok = 1, "/*"
			case c == '-' && next == '-':
				// Copy the rest of the line verbatim; the newline itself is
				// handled by the next iteration.
				tok = contents[i:]
				if nl := strings.IndexByte(tok, '\n'); nl >= 0 {
					tok = tok[:nl]
				}
			}
		}

		if tok != "" {
			out.WriteString(tok)
			i += len(tok)
			lineStart = false
			continue
		}
		out.WriteByte(c)
		lineStart = c == '\n'
		i++
	}
	return out.String()
}

// scanDollarTag reports the dollar-quote delimiter that starts s, including
// both '$' characters ("$$", "$tag$"), or "" if s does not start with one.
// s must begin with '$'. A tag follows identifier rules, so it may not begin
// with a digit; this keeps positional parameters such as $1 from being read as
// the start of a delimiter.
func scanDollarTag(s string) string {
	for j := 1; j < len(s); j++ {
		b := s[j]
		switch {
		case b == '$':
			return s[:j+1]
		case !isDollarTagChar(b), j == 1 && b >= '0' && b <= '9':
			return ""
		}
	}
	return ""
}

// isDollarTagChar reports whether b may appear inside a dollar-quote tag:
// ASCII letters, digits, underscore, or any byte of a multi-byte UTF-8
// sequence (non-ASCII letters are valid in identifiers).
func isDollarTagChar(b byte) bool {
	return b == '_' || b >= 0x80 ||
		(b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
}

func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) {
Expand Down
159 changes: 159 additions & 0 deletions internal/compiler/psql_meta_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package compiler

import (
"fmt"
"strings"
"testing"
)

// allPsqlMetaCommands lists psql backslash meta-command names, each including
// its leading backslash. The CoversDocumentedMetaCommands subtest feeds every
// entry through removePsqlMetaCommands as the first token of a line.
// NOTE(review): presumably mirrors the meta-command list in the psql reference
// documentation — confirm against the targeted PostgreSQL version when updating.
var allPsqlMetaCommands = []string{
`\a`, `\bind`, `\bind_named`, `\c`, `\connect`, `\C`, `\cd`, `\close_prepared`, `\conninfo`, `\copy`,
`\copyright`, `\crosstabview`, `\d`, `\da`, `\dA`, `\dAc`, `\dAf`, `\dAo`, `\dAp`, `\db`,
`\dc`, `\dconfig`, `\dC`, `\dd`, `\dD`, `\ddp`, `\dE`, `\di`, `\dm`, `\ds`,
`\dt`, `\dv`, `\des`, `\det`, `\deu`, `\dew`, `\df`, `\dF`, `\dFd`, `\dFp`,
`\dFt`, `\dg`, `\dl`, `\dL`, `\dn`, `\do`, `\dO`, `\dp`, `\dP`, `\drds`,
`\drg`, `\dRp`, `\dRs`, `\dT`, `\du`, `\dx`, `\dX`, `\dy`, `\e`, `\edit`,
`\echo`, `\ef`, `\encoding`, `\ev`, `\f`, `\g`, `\gdesc`, `\getenv`, `\gexec`, `\gset`,
`\gx`, `\h`, `\help`, `\H`, `\html`, `\i`, `\include`, `\if`, `\elif`, `\else`,
`\endif`, `\ir`, `\include_relative`, `\l`, `\list`, `\lo_export`, `\lo_import`, `\lo_list`, `\lo_unlink`, `\o`,
`\out`, `\p`, `\print`, `\parse`, `\password`, `\prompt`, `\pset`, `\q`, `\quit`, `\qecho`,
`\r`, `\reset`, `\restrict`, `\s`, `\set`, `\setenv`, `\sf`, `\sv`, `\startpipeline`, `\sendpipeline`,
`\syncpipeline`, `\endpipeline`, `\flushrequest`, `\flush`, `\getresults`, `\t`, `\T`, `\timing`, `\unrestrict`, `\unset`,
`\w`, `\write`, `\warn`, `\watch`, `\x`, `\z`, `\!`, `\?`, `\;`,
}

// TestRemovePsqlMetaCommands_TableDriven checks removePsqlMetaCommands against
// fixed input/output pairs, then verifies that a line starting with each entry
// of allPsqlMetaCommands is reduced to its bare newline while the following
// SQL line is left intact.
func TestRemovePsqlMetaCommands_TableDriven(t *testing.T) {
	// Inputs that must come back unchanged are named once so that the in and
	// want fields of their table entries cannot drift apart.
	inDoubleQuoted := "CREATE TABLE \"foo\\bar\" (id int);\nSELECT \"foo\\bar\"." +
		"id FROM \"foo\\bar\";\n"
	inValidSQL := "CREATE TABLE t (id int);\nINSERT INTO t VALUES (1);\n"
	inWhitespaceOnly := " \t "
	inNoTrailingNewline := "SELECT 1"
	inBackslashNotAtStart := "SELECT '\\not_meta' AS col;\n SELECT '\\still_not_meta';\n"
	inDoubleSingleQuotes := "INSERT INTO t VALUES ('It''s fine');\n"

	tests := []struct {
		name string
		in   string
		want string
	}{
		{
			name: "RemovesTopLevelMetaCommands",
			in:   "CREATE TABLE public.authors();\n\\connect test\n \\set ON_ERROR_STOP on\nSELECT 1;\n",
			want: "CREATE TABLE public.authors();\n\n\nSELECT 1;\n",
		},
		{
			name: "IgnoresBackslashesInStrings",
			in:   "SELECT E'\\n' || E'\\' || '\n\\restrict inside';\nSELECT E'\n\\still_string\n';\n\\connect nope\n",
			want: "SELECT E'\\n' || E'\\' || '\n\\restrict inside';\nSELECT E'\n\\still_string\n';\n\n",
		},
		{
			name: "PreservesDollarQuotedBlocks",
			in:   "DO $$\n\\this_should_stay\n$$;\n\\connect other\n",
			want: "DO $$\n\\this_should_stay\n$$;\n\n",
		},
		{
			name: "IgnoresBlockComments",
			in:   "/*\n\\comment_not_meta\n*/\n\\set x 1\nSELECT 1;\n",
			want: "/*\n\\comment_not_meta\n*/\n\nSELECT 1;\n",
		},
		{
			name: "LeavesValidSqlUntouched",
			in:   inValidSQL,
			want: inValidSQL,
		},
		{
			name: "HandlesEmptyInput",
			in:   "",
			want: "",
		},
		{
			name: "PreservesWhitespaceOnlyInput",
			in:   inWhitespaceOnly,
			want: inWhitespaceOnly,
		},
		{
			name: "PreservesFinalLineWithoutNewline",
			in:   inNoTrailingNewline,
			want: inNoTrailingNewline,
		},
		{
			name: "BackslashInDoubleQuotedIdentifier",
			in:   inDoubleQuoted,
			want: inDoubleQuoted,
		},
		{
			name: "BackslashNotAtLineStart",
			in:   inBackslashNotAtStart,
			want: inBackslashNotAtStart,
		},
		{
			name: "DoubleSingleQuotesRemain",
			in:   inDoubleSingleQuotes,
			want: inDoubleSingleQuotes,
		},
		{
			// Raw strings: the continuation lines must stay at column 0 because
			// their leading bytes are part of the test data.
			name: "MetaCommandTextInsideLiteral",
			in: `INSERT INTO logs VALUES ('Remember to run \connect later');
SELECT E'\n\connect\n' as literal;` + "\n",
			want: `INSERT INTO logs VALUES ('Remember to run \connect later');
SELECT E'\n\connect\n' as literal;` + "\n",
		},
		{
			// Only the final \connect line sits outside every comment, so it is
			// the only one replaced by an empty line.
			name: "BlockCommentsPreserveMetaText",
			in: `/* outer block begins
/* nested: run \connect test_db for interactive work */
documenting with \connect text shouldn't strip SQL
*/
SELECT 1;
/* Change instructions:
\connect reporting

Reason: run maintenance scripts as reporting user.
*/
\connect should_go
`,
			want: `/* outer block begins
/* nested: run \connect test_db for interactive work */
documenting with \connect text shouldn't strip SQL
*/
SELECT 1;
/* Change instructions:
\connect reporting

Reason: run maintenance scripts as reporting user.
*/

`,
		},
		{
			name: "DollarTagWithIdentifier",
			in:   "DO $foo$\n\\inside\n$foo$;\n\\set should_go\n",
			want: "DO $foo$\n\\inside\n$foo$;\n\n",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := removePsqlMetaCommands(tc.in)
			if got != tc.want {
				t.Fatalf("unexpected output after stripping meta commands:\nwant=%q\ngot =%q", tc.want, got)
			}
		})
	}

	t.Run("CoversDocumentedMetaCommands", func(t *testing.T) {
		// The whole meta-command line collapses to its newline and the SQL
		// line passes through untouched. Comparing the exact output (rather
		// than substring checks) also catches partial stripping.
		const want = "\nSELECT 42;\n"
		for _, cmd := range allPsqlMetaCommands {
			t.Run(fmt.Sprintf("strip_%s", strings.TrimPrefix(cmd, `\`)), func(t *testing.T) {
				input := fmt.Sprintf("%s -- meta command\nSELECT 42;\n", cmd)
				if got := removePsqlMetaCommands(input); got != want {
					t.Fatalf("unexpected output for meta command %q:\nwant=%q\ngot =%q", cmd, want, got)
				}
			})
		}
	})
}