Skip to content

Commit

Permalink
Improve unsupported DDL parsing; squash revealed bugs (#329)
Browse files Browse the repository at this point in the history
* fix #326: do not match unsupported ddl followed by open parens

* refactor: parse unsupported DDL more granularly instead of with one big DATA token

* chore: update changelog

* fix: catch SqlfmtSegmentError and handle appropriately, do not use finally; fixes issue where we could drop lines entirely

* fix: stop using tail recursion for create_segments_from_lines

* chore: update primer refs

* refactor: simplify create_segments, add tests
  • Loading branch information
tconbeer authored Nov 30, 2022
1 parent bd08b42 commit 768f949
Show file tree
Hide file tree
Showing 27 changed files with 1,104 additions and 170 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ All notable changes to this project will be documented in this file.
- sqlfmt now supports `alter function` and `drop function` statements ([#310](https://github.com/tconbeer/sqlfmt/issues/310), [#311](https://github.com/tconbeer/sqlfmt/issues/311)), and Snowflake's `create external function` statements ([#322](https://github.com/tconbeer/sqlfmt/issues/322)).
- sqlfmt better supports numeric constants (number literals), including those using scientific notation (e.g., `1.5e-9`) and the unary `+` or `-` operators (e.g., `+3`), and is now smarter about when the `-` symbol is the unary negative or binary subtraction operator. ([#321](https://github.com/tconbeer/sqlfmt/issues/321) - thank you [@liaopeiyuan](https://github.com/liaopeiyuan)!).
- fixed a bug where we added extra whitespace to the end of empty comment lines ([#319](https://github.com/tconbeer/sqlfmt/issues/319) - thank you [@eherde](https://github.com/eherde)!).
- fixed a bug where wrapping unsupported DDL in jinja would cause a parsing error ([#326](https://github.com/tconbeer/sqlfmt/issues/326) - thank you [@ETG-msimons](https://github.com/ETG-msimons)!). Also improved parsing of unsupported DDL and made false positives less likely.
- fixed a bug where we could have unsafely run *black* against jinja that contained Python keywords and their safe alternatives (e.g., `return(return_())`).
- fixed a bug where we deleted some extra whitespace lines (and in very rare cases, nonblank lines).
- fixed a bug where Python recursion limits could cause incorrect formatting in rare cases.

## [0.13.0] - 2022-11-01

Expand Down
34 changes: 0 additions & 34 deletions src/sqlfmt/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,40 +327,6 @@ def handle_nonreserved_keyword(
action(analyzer, source_string, match)


def handle_possible_unsupported_ddl(
    analyzer: "Analyzer", source_string: str, match: re.Match
) -> None:
    """
    Lex the matched text as a DATA token if the new node would sit at
    syntax depth 0; otherwise re-match the current position against the
    ordinary "name" rule and lex it as a NAME token instead.

    Raises an AssertionError (internal error) if the fallback "name" rule
    fails to match at the current position.
    """
    token = Token.from_match(source_string, match, TokenType.DATA)
    node = analyzer.node_manager.create_node(
        token=token, previous_node=analyzer.previous_node
    )
    if node.depth[0] == 0:
        # top-level statement: treat the whole match as unsupported DDL
        analyzer.node_buffer.append(node)
        analyzer.pos = token.epos
    else:
        # this looks like unsupported ddl/sql, but we're inside a query already, so
        # it's probably just an ordinary name
        name_rule = analyzer.get_rule(rule_name="name")
        name_match = name_rule.program.match(source_string, pos=analyzer.pos)
        # BUG FIX: the original message concatenated adjacent string literals
        # without separating whitespace ("issue.An error ... SQLat position")
        assert name_match, (
            "Internal Error! Please open an issue. "
            "An error occurred when lexing unsupported SQL "
            f"at position {match.span(1)[0]}:\n"
            f"{source_string[slice(*match.span(1))]}"
        )
        add_node_to_buffer(
            analyzer=analyzer,
            source_string=source_string,
            match=name_match,
            token_type=TokenType.NAME,
        )


def lex_ruleset(
analyzer: "Analyzer",
source_string: str,
Expand Down
21 changes: 18 additions & 3 deletions src/sqlfmt/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Line:
previous_node: Optional[Node] # last node of prior line, if any
nodes: List[Node] = field(default_factory=list)
comments: List[Comment] = field(default_factory=list)
formatting_disabled: bool = False
formatting_disabled: List[Token] = field(default_factory=list)

def __str__(self) -> str:
return self._calc_str
Expand Down Expand Up @@ -143,7 +143,7 @@ def from_nodes(
nodes=nodes,
comments=comments,
formatting_disabled=nodes[0].formatting_disabled
or nodes[-1].formatting_disabled,
+ nodes[-1].formatting_disabled,
)
else:
line = Line(
Expand All @@ -152,7 +152,7 @@ def from_nodes(
comments=comments,
formatting_disabled=previous_node.formatting_disabled
if previous_node
else False,
else [],
)

return line
Expand Down Expand Up @@ -349,3 +349,18 @@ def opens_new_bracket(self) -> bool:
return True
else:
return False

def starts_new_segment(self, prev_segment_depth: Tuple[int, int]) -> bool:
    """
    Return True if this line should begin a new segment, given the depth
    of the first line of the previous segment.
    """
    at_or_above_prev = (
        self.depth <= prev_segment_depth
        or self.depth[1] < prev_segment_depth[1]
    )
    if not at_or_above_prev:
        # strictly deeper than the previous segment: continue it
        return False
    # a line that only closes a bracket or simple jinja block opened on
    # the previous line (or a blank line), at exactly the previous
    # segment's depth, belongs in the same segment as the first line
    closes_previous = (
        self.closes_bracket_from_previous_line
        or self.closes_simple_jinja_block_from_previous_line
        or self.is_blank_line
    )
    return not (closes_previous and self.depth == prev_segment_depth)
30 changes: 19 additions & 11 deletions src/sqlfmt/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,16 @@ def maybe_merge_lines(self, lines: List[Line]) -> List[Line]:
# indented lines
else:
only_segment = segments[0]
_, i = only_segment.head
merged_lines.extend(only_segment[: i + 1])
for segment in only_segment.split_after(i):
merged_lines.extend(self.maybe_merge_lines(segment))
finally:
return merged_lines
try:
_, i = only_segment.head
except SqlfmtSegmentError:
merged_lines.extend(only_segment)
else:
merged_lines.extend(only_segment[: i + 1])
for segment in only_segment.split_after(i):
merged_lines.extend(self.maybe_merge_lines(segment))

return merged_lines

def _fix_standalone_operators(self, segments: List[Segment]) -> List[Segment]:
"""
Expand Down Expand Up @@ -262,8 +266,8 @@ def _try_merge_operator_segments(
]
except CannotMergeException:
new_segments = self._maybe_merge_operators(segments, op_tiers)
finally:
return new_segments

return new_segments

def _maybe_stubbornly_merge(self, segments: List[Segment]) -> List[Segment]:
"""
Expand Down Expand Up @@ -330,7 +334,11 @@ def _stubbornly_merge(
"""
new_segments = prev_segments.copy()
prev_segment = new_segments.pop()
head, i = segment.head
try:
head, i = segment.head
except SqlfmtSegmentError:
new_segments.extend([prev_segment, segment])
return new_segments

# try to merge the first line of this segment with the previous segment
try:
Expand All @@ -355,5 +363,5 @@ def _stubbornly_merge(
except CannotMergeException:
# give up and just return the original segments
new_segments.extend([prev_segment, segment])
finally:
return new_segments

return new_segments
6 changes: 3 additions & 3 deletions src/sqlfmt/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class Node:
refer to open brackets (keywords and parens) or jinja blocks (e.g., {% if foo %})
that increase the syntax depth (and therefore printed indentation) of this Node
formatting_disabled: a boolean indicating that sqlfmt should print the raw token
instead of the formatted values for this Node
formatting_disabled: a list of FMT_OFF tokens that precede this node and prevent
it from being formatted
"""

token: Token
Expand All @@ -53,7 +53,7 @@ class Node:
value: str
open_brackets: List["Node"] = field(default_factory=list)
open_jinja_blocks: List["Node"] = field(default_factory=list)
formatting_disabled: bool = False
formatting_disabled: List[Token] = field(default_factory=list)

def __str__(self) -> str:
"""
Expand Down
48 changes: 40 additions & 8 deletions src/sqlfmt/node_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional

from sqlfmt.exception import SqlfmtBracketError
from sqlfmt.line import Line
Expand All @@ -25,11 +25,9 @@ def create_node(self, token: Token, previous_node: Optional[Node]) -> Node:
if previous_node is None:
open_brackets = []
open_jinja_blocks = []
formatting_disabled = False
else:
open_brackets = previous_node.open_brackets.copy()
open_jinja_blocks = previous_node.open_jinja_blocks.copy()
formatting_disabled = previous_node.formatting_disabled

# add the previous node to the list of open brackets or jinja blocks
if previous_node.is_unterm_keyword or previous_node.is_opening_bracket:
Expand Down Expand Up @@ -70,11 +68,7 @@ def create_node(self, token: Token, previous_node: Optional[Node]) -> Node:
prev_token, extra_whitespace = get_previous_token(previous_node)
prefix = self.whitespace(token, prev_token, extra_whitespace)
value = self.standardize_value(token)

if token.type in (TokenType.FMT_OFF, TokenType.DATA):
formatting_disabled = True
elif prev_token and prev_token.type in (TokenType.FMT_ON, TokenType.DATA):
formatting_disabled = False
formatting_disabled = self.disable_formatting(token, previous_node)

return Node(
token=token,
Expand Down Expand Up @@ -245,6 +239,44 @@ def standardize_value(self, token: Token) -> str:
else:
return token.token

def disable_formatting(
    self, token: Token, previous_node: Optional[Node]
) -> List[Token]:
    """
    Compute the formatting_disabled token stack for the node that will
    be created from token, starting from the previous node's stack.
    """
    if previous_node is None:
        disabled: List[Token] = []
    else:
        disabled = previous_node.formatting_disabled.copy()

    # fmt: off comments and unsupported-DDL DATA tokens both push onto
    # the stack
    if token.type in (TokenType.FMT_OFF, TokenType.DATA):
        disabled.append(token)

    prev_reenables = previous_node is not None and previous_node.token.type in (
        TokenType.FMT_ON,
        TokenType.DATA,
    )
    if disabled and prev_reenables:
        disabled.pop()

    # formatting can be disabled because of unsupported
    # ddl. When we hit a semicolon we need to pop
    # all of the formatting disabled tokens caused by ddl
    # off the stack
    if token.type == TokenType.SEMICOLON:
        while disabled and "fmt:" not in disabled[-1].token.lower():
            disabled.pop()

    return disabled

def append_newline(self, line: Line) -> None:
"""
Create a new NEWLINE token and append it to the end of line
Expand Down
113 changes: 55 additions & 58 deletions src/sqlfmt/rules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
ALTER_WAREHOUSE,
CREATE_FUNCTION,
CREATE_WAREHOUSE,
NEWLINE,
SQL_COMMENT,
SQL_QUOTED_EXP,
group,
)
from sqlfmt.rules.core import CORE as CORE
Expand Down Expand Up @@ -235,61 +232,61 @@
name="unsupported_ddl",
priority=2999,
pattern=group(
group(
r"alter",
r"attach\s+rls\s+policy",
r"cache\s+table",
r"clear\s+cache",
r"cluster",
r"comment",
r"copy",
r"create",
r"deallocate",
r"declare",
r"describe",
r"desc\s+datashare",
r"desc\s+identity\s+provider",
r"delete",
r"detach\s+rls\s+policy",
r"discard",
r"do",
r"drop",
r"execute",
r"export",
r"fetch",
r"get",
r"handler",
r"import\s+foreign\s+schema",
r"import\s+table",
# snowflake: "insert into" or "insert overwrite into"
# snowflake: has insert() function
# spark: "insert overwrite" without the trailing "into"
# redshift/pg: "insert into" only
# bigquery: bare "insert" is okay
r"insert(\s+overwrite)?(\s+into)?(?!\()",
r"list",
r"lock",
r"merge",
r"move",
# prepare transaction statements are simple enough
# so we'll allow them
r"prepare(?!\s+transaction)",
r"put",
r"reassign\s+owned",
r"remove",
r"rename\s+table",
r"repair",
r"security\s+label",
r"select\s+into",
r"truncate",
r"unload",
r"update",
r"validate",
)
+ rf"\b({SQL_COMMENT}|{SQL_QUOTED_EXP}|[^'`\"$;])*?"
r"alter",
r"attach\s+rls\s+policy",
r"cache\s+table",
r"clear\s+cache",
r"cluster",
r"comment",
r"copy",
r"create",
r"deallocate",
r"declare",
r"describe",
r"desc\s+datashare",
r"desc\s+identity\s+provider",
r"delete",
r"detach\s+rls\s+policy",
r"discard",
r"do",
r"drop",
r"execute",
r"export",
r"fetch",
r"get",
r"handler",
r"import\s+foreign\s+schema",
r"import\s+table",
# snowflake: "insert into" or "insert overwrite into"
# snowflake: has insert() function
# spark: "insert overwrite" without the trailing "into"
# redshift/pg: "insert into" only
# bigquery: bare "insert" is okay
r"insert(\s+overwrite)?(\s+into)?",
r"list",
r"lock",
r"merge",
r"move",
# prepare transaction statements are simple enough
# so we'll allow them
r"prepare(?!\s+transaction)",
r"put",
r"reassign\s+owned",
r"remove",
r"rename\s+table",
r"repair",
r"security\s+label",
r"select\s+into",
r"truncate",
r"unload",
r"update",
r"validate",
)
+ rf"{NEWLINE}*"
+ group(r";", r"$"),
action=actions.handle_possible_unsupported_ddl,
+ r"(?!\()"
+ group(r"\W", r"$"),
action=partial(
actions.handle_nonreserved_keyword,
action=partial(actions.add_node_to_buffer, token_type=TokenType.FMT_OFF),
),
),
]
Loading

0 comments on commit 768f949

Please sign in to comment.