Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve unsupported DDL parsing; squash revealed bugs #329

Merged
merged 7 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor: parse unsupported DDL more granularly instead of with one big DATA token
  • Loading branch information
tconbeer committed Nov 28, 2022
commit a32e211144130788bcd24eb194f909ae0a642807
34 changes: 0 additions & 34 deletions src/sqlfmt/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,40 +327,6 @@ def handle_nonreserved_keyword(
action(analyzer, source_string, match)


def handle_possible_unsupported_ddl(
    analyzer: "Analyzer", source_string: str, match: re.Match
) -> None:
    """
    Lex a match that looks like unsupported DDL/SQL.

    A DATA token is built from the match and turned into a node. If that
    node is at depth 0, the node is appended to the analyzer's buffer and
    the analyzer position is advanced to the end of the token. Otherwise
    the current position is matched against the ordinary "name" rule and
    the result is added to the buffer as a NAME token.

    Raises AssertionError if the "name" rule fails to match at the
    analyzer's current position.
    """
    token = Token.from_match(source_string, match, TokenType.DATA)
    node = analyzer.node_manager.create_node(
        token=token, previous_node=analyzer.previous_node
    )
    if node.depth[0] == 0:
        analyzer.node_buffer.append(node)
        analyzer.pos = token.epos
    else:
        # this looks like unsupported ddl/sql, but we're inside a query already, so
        # it's probably just an ordinary name
        name_rule = analyzer.get_rule(rule_name="name")
        name_match = name_rule.program.match(source_string, pos=analyzer.pos)
        # Adjacent string literals are concatenated verbatim, so each fragment
        # must carry its own trailing space to keep the message readable.
        assert name_match, (
            "Internal Error! Please open an issue. "
            "An error occurred when lexing unsupported SQL "
            f"at position {match.span(1)[0]}:\n"
            f"{source_string[slice(*match.span(1))]}"
        )
        add_node_to_buffer(
            analyzer=analyzer,
            source_string=source_string,
            match=name_match,
            token_type=TokenType.NAME,
        )


def lex_ruleset(
analyzer: "Analyzer",
source_string: str,
Expand Down
6 changes: 3 additions & 3 deletions src/sqlfmt/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Line:
previous_node: Optional[Node] # last node of prior line, if any
nodes: List[Node] = field(default_factory=list)
comments: List[Comment] = field(default_factory=list)
formatting_disabled: bool = False
formatting_disabled: List[Token] = field(default_factory=list)

def __str__(self) -> str:
return self._calc_str
Expand Down Expand Up @@ -143,7 +143,7 @@ def from_nodes(
nodes=nodes,
comments=comments,
formatting_disabled=nodes[0].formatting_disabled
or nodes[-1].formatting_disabled,
+ nodes[-1].formatting_disabled,
)
else:
line = Line(
Expand All @@ -152,7 +152,7 @@ def from_nodes(
comments=comments,
formatting_disabled=previous_node.formatting_disabled
if previous_node
else False,
else [],
)

return line
Expand Down
6 changes: 3 additions & 3 deletions src/sqlfmt/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class Node:
refer to open brackets (keywords and parens) or jinja blocks (e.g., {% if foo %})
that increase the syntax depth (and therefore printed indentation) of this Node

formatting_disabled: a boolean indicating that sqlfmt should print the raw token
instead of the formatted values for this Node
formatting_disabled: a list of FMT_OFF tokens that precede this node and prevent
it from being formatted
"""

token: Token
Expand All @@ -53,7 +53,7 @@ class Node:
value: str
open_brackets: List["Node"] = field(default_factory=list)
open_jinja_blocks: List["Node"] = field(default_factory=list)
formatting_disabled: bool = False
formatting_disabled: List[Token] = field(default_factory=list)

def __str__(self) -> str:
"""
Expand Down
44 changes: 36 additions & 8 deletions src/sqlfmt/node_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional

from sqlfmt.exception import SqlfmtBracketError
from sqlfmt.line import Line
Expand All @@ -25,11 +25,9 @@ def create_node(self, token: Token, previous_node: Optional[Node]) -> Node:
if previous_node is None:
open_brackets = []
open_jinja_blocks = []
formatting_disabled = False
else:
open_brackets = previous_node.open_brackets.copy()
open_jinja_blocks = previous_node.open_jinja_blocks.copy()
formatting_disabled = previous_node.formatting_disabled

# add the previous node to the list of open brackets or jinja blocks
if previous_node.is_unterm_keyword or previous_node.is_opening_bracket:
Expand Down Expand Up @@ -70,11 +68,7 @@ def create_node(self, token: Token, previous_node: Optional[Node]) -> Node:
prev_token, extra_whitespace = get_previous_token(previous_node)
prefix = self.whitespace(token, prev_token, extra_whitespace)
value = self.standardize_value(token)

if token.type in (TokenType.FMT_OFF, TokenType.DATA):
formatting_disabled = True
elif prev_token and prev_token.type in (TokenType.FMT_ON, TokenType.DATA):
formatting_disabled = False
formatting_disabled = self.disable_formatting(token, previous_node)

return Node(
token=token,
Expand Down Expand Up @@ -245,6 +239,40 @@ def standardize_value(self, token: Token) -> str:
else:
return token.token

def disable_formatting(
    self, token: Token, previous_node: Optional[Node]
) -> List[Token]:
    """
    Compute the formatting_disabled list for the node that will be created
    from token, given the node that precedes it.

    The result starts as a copy of the previous node's list (or an empty
    list when there is no previous node) and is then adjusted based on the
    type of token and the type of the previous node's token.
    """
    # Work on a copy so the previous node's own list is never mutated.
    if previous_node:
        fmt_off_stack = previous_node.formatting_disabled.copy()
    else:
        fmt_off_stack = []

    # A FMT_OFF or DATA token adds itself to the list.
    opens_region = token.type in (TokenType.FMT_OFF, TokenType.DATA)
    if opens_region:
        fmt_off_stack.append(token)

    # When the previous node's token was FMT_ON or DATA, drop the most
    # recent entry (if there is one).
    if fmt_off_stack and previous_node:
        if previous_node.token.type in (TokenType.FMT_ON, TokenType.DATA):
            fmt_off_stack.pop()

    # A semicolon drops the most recent entry unless that entry's text
    # contains "fmt: off" (compared case-insensitively).
    if fmt_off_stack and token.type == TokenType.SEMICOLON:
        innermost = fmt_off_stack[-1]
        if "fmt: off" not in innermost.token.lower():
            fmt_off_stack.pop()

    return fmt_off_stack

def append_newline(self, line: Line) -> None:
"""
Create a new NEWLINE token and append it to the end of line
Expand Down
114 changes: 55 additions & 59 deletions src/sqlfmt/rules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
ALTER_WAREHOUSE,
CREATE_FUNCTION,
CREATE_WAREHOUSE,
NEWLINE,
SQL_COMMENT,
SQL_QUOTED_EXP,
group,
)
from sqlfmt.rules.core import CORE as CORE
Expand Down Expand Up @@ -235,62 +232,61 @@
name="unsupported_ddl",
priority=2999,
pattern=group(
group(
r"alter",
r"attach\s+rls\s+policy",
r"cache\s+table",
r"clear\s+cache",
r"cluster",
r"comment",
r"copy",
r"create",
r"deallocate",
r"declare",
r"describe",
r"desc\s+datashare",
r"desc\s+identity\s+provider",
r"delete",
r"detach\s+rls\s+policy",
r"discard",
r"do",
r"drop",
r"execute",
r"export",
r"fetch",
r"get",
r"handler",
r"import\s+foreign\s+schema",
r"import\s+table",
# snowflake: "insert into" or "insert overwrite into"
# snowflake: has insert() function
# spark: "insert overwrite" without the trailing "into"
# redshift/pg: "insert into" only
# bigquery: bare "insert" is okay
r"insert(\s+overwrite)?(\s+into)?",
r"list",
r"lock",
r"merge",
r"move",
# prepare transaction statements are simple enough
# so we'll allow them
r"prepare(?!\s+transaction)",
r"put",
r"reassign\s+owned",
r"remove",
r"rename\s+table",
r"repair",
r"security\s+label",
r"select\s+into",
r"truncate",
r"unload",
r"update",
r"validate",
)
+ r"(?!\()"
+ rf"\b({SQL_COMMENT}|{SQL_QUOTED_EXP}|[^'`\"$;w])*?"
r"alter",
r"attach\s+rls\s+policy",
r"cache\s+table",
r"clear\s+cache",
r"cluster",
r"comment",
r"copy",
r"create",
r"deallocate",
r"declare",
r"describe",
r"desc\s+datashare",
r"desc\s+identity\s+provider",
r"delete",
r"detach\s+rls\s+policy",
r"discard",
r"do",
r"drop",
r"execute",
r"export",
r"fetch",
r"get",
r"handler",
r"import\s+foreign\s+schema",
r"import\s+table",
# snowflake: "insert into" or "insert overwrite into"
# snowflake: has insert() function
# spark: "insert overwrite" without the trailing "into"
# redshift/pg: "insert into" only
# bigquery: bare "insert" is okay
r"insert(\s+overwrite)?(\s+into)?",
r"list",
r"lock",
r"merge",
r"move",
# prepare transaction statements are simple enough
# so we'll allow them
r"prepare(?!\s+transaction)",
r"put",
r"reassign\s+owned",
r"remove",
r"rename\s+table",
r"repair",
r"security\s+label",
r"select\s+into",
r"truncate",
r"unload",
r"update",
r"validate",
)
+ rf"{NEWLINE}*"
+ group(r";", r"$"),
action=actions.handle_possible_unsupported_ddl,
+ r"(?!\()"
+ group(r"\W", r"$"),
action=partial(
actions.handle_nonreserved_keyword,
action=partial(actions.add_node_to_buffer, token_type=TokenType.FMT_OFF),
),
),
]
13 changes: 13 additions & 0 deletions tests/data/unformatted/999_unsupported_ddl.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
CREATE PUBLICATION insert_only FOR TABLE mydata
WITH (publish = 'insert');
CREATE PUBLICATION production_publication FOR TABLE users, departments, TABLES IN SCHEMA production;
CREATE PUBLICATION users_filtered FOR TABLE users (user_id, firstname);
SELECT
1;
)))))__SQLFMT_OUTPUT__(((((
CREATE PUBLICATION insert_only FOR TABLE mydata
WITH (publish = 'insert');
CREATE PUBLICATION production_publication FOR TABLE users, departments, TABLES IN SCHEMA production;
CREATE PUBLICATION users_filtered FOR TABLE users (user_id, firstname);
select 1
;
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
select 0;
-- fmt: off
select 1;
-- fmt: off
select 2;
-- unsupported ddl
CREATE PUBLICATION mypublication FOR TABLE users, departments;
-- fmt: on
select 3;
-- fmt: on
select 4;
-- fmt: on
-- fmt: on
-- fmt: on
-- fmt: on
1 change: 1 addition & 0 deletions tests/functional_tests/test_general_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
"unformatted/408_alter_function_snowflake_examples.sql",
"unformatted/409_create_external_function.sql",
"unformatted/410_create_warehouse.sql",
"unformatted/999_unsupported_ddl.sql",
],
)
def test_formatting(p: str) -> None:
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,9 +467,9 @@ def test_handle_unsupported_ddl(default_analyzer: Analyzer) -> None:
query = default_analyzer.parse_query(source_string=source_string.lstrip())
assert len(query.lines) == 3
first_create_line = query.lines[0]
assert len(first_create_line.nodes) == 3
assert first_create_line.nodes[0].token.type == TokenType.DATA
assert first_create_line.nodes[1].token.type == TokenType.SEMICOLON
assert len(first_create_line.nodes) == 9
assert first_create_line.nodes[0].token.type == TokenType.FMT_OFF
assert first_create_line.nodes[-2].token.type == TokenType.SEMICOLON

select_line = query.lines[1]
assert len(select_line.nodes) == 8
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_simple_line(
"\tvalue='with',\n"
"\topen_brackets=[],\n"
"\topen_jinja_blocks=[],\n"
"\tformatting_disabled=False\n"
"\tformatting_disabled=[]\n"
")"
)
assert repr(simple_line.nodes[0]) == expected_node_repr
Expand Down Expand Up @@ -405,7 +405,7 @@ def test_formatting_disabled(default_mode: Mode) -> None:
True, # --fmt: on
False, # where format is true
]
actual = [line.formatting_disabled for line in q.lines]
actual = [bool(line.formatting_disabled) for line in q.lines]
assert actual == expected


Expand Down
Loading