Skip to content

Commit

Permalink
fix: prevent line splitting on between's and, closes #124
Browse files Browse the repository at this point in the history
  • Loading branch information
tconbeer committed Feb 8, 2022
1 parent 1de65be commit 3d956f3
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 2 deletions.
29 changes: 29 additions & 0 deletions src/sqlfmt/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,35 @@ def is_multiplication_star(self) -> bool:
in (TokenType.UNTERM_KEYWORD, TokenType.COMMA, TokenType.DOT)
)

@property
def is_the_between_operator(self) -> bool:
    """
    Whether this node's token is a WORD_OPERATOR and the node's value is
    exactly "between".
    """
    is_word_operator = self.token.type == TokenType.WORD_OPERATOR
    return is_word_operator and self.value == "between"

@cached_property
def is_the_and_after_the_between_operator(self) -> bool:
    """
    Whether this node is the "and" that belongs to a preceding "between".

    Only a BOOLEAN_OPERATOR node whose value is "and" can qualify. For such
    a node, earlier nodes are visited one at a time (via previous_node) for
    as long as their depth is not less than this node's depth. Nodes at a
    greater depth are skipped over. Among nodes at the same depth, the
    first "between" operator found makes the result True, while the first
    BOOLEAN_OPERATOR found ends the search with False.
    """
    if not (self.token.type == TokenType.BOOLEAN_OPERATOR and self.value == "and"):
        return False

    candidate = self.previous_node
    while candidate and candidate.depth >= self.depth:
        if candidate.depth == self.depth:
            if candidate.is_the_between_operator:
                return True
            if candidate.token.type == TokenType.BOOLEAN_OPERATOR:
                # another boolean operator at our depth sits between this
                # node and any earlier "between", so stop looking
                break
        candidate = candidate.previous_node
    return False

@property
def is_low_priority_merge_operator(self) -> bool:
    """
    Whether this node's token type is either BOOLEAN_OPERATOR or ON.
    """
    low_priority_types = (TokenType.BOOLEAN_OPERATOR, TokenType.ON)
    return self.token.type in low_priority_types
Expand Down
3 changes: 2 additions & 1 deletion src/sqlfmt/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ def maybe_split_before(self, node: Node) -> bool:
):
return True
# split before any operator unless the previous node is a closing
# bracket or statement
# bracket or statement, or it is the "and" following a "between"
elif (
node.is_operator
and node.previous_node
and not node.previous_node.is_closing_bracket
and not node.is_the_and_after_the_between_operator
):
return True
else:
Expand Down
37 changes: 37 additions & 0 deletions tests/data/unformatted/111_chained_boolean_between.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
select
radio,
mcc,
net as mnc,
area as lac,
cell % 65536 as cid,
cell / 65536 as rnc,
cell as long_cid,
lon,
lat
from
towershift
where
radio != 'CDMA'
and mcc between 200 and 799 and net between 1 and 999 and area between 0 and 65535
and cell between 0 and 268435455 and lon between -180 and 180
and lat between -90 and 90
)))))__SQLFMT_OUTPUT__(((((
select
radio,
mcc,
net as mnc,
area as lac,
cell % 65536 as cid,
cell / 65536 as rnc,
cell as long_cid,
lon,
lat
from towershift
where
radio != 'CDMA'
and mcc between 200 and 799
and net between 1 and 999
and area between 0 and 65535
and cell between 0 and 268435455
and lon between -180 and 180
and lat between -90 and 90
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
where
radio != 'CDMA'
and mcc between 200 and 799
and net between 1+1+1+1+1+1+1+1+1+1+1+1+1+1+1 and 999
and area between smallest_area and biggest_area
and cell between (select min(cell_number) from numbers) and (
select max(cell_number) from numbers
) and lon between -180 and 180
or lat between -90 and 90
2 changes: 1 addition & 1 deletion tests/functional_tests/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_end_to_end_check_unformatted(
result = sqlfmt_runner.invoke(sqlfmt_main, args=args)

assert result
assert "13 files" in result.stderr
assert "14 files" in result.stderr
assert "failed formatting check" in result.stderr

if "-q" in options or "--quiet" in options:
Expand Down
1 change: 1 addition & 0 deletions tests/functional_tests/test_general_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"unformatted/108_test_block.sql",
"unformatted/109_lateral_flatten.sql",
"unformatted/110_other_identifiers.sql",
"unformatted/111_chained_boolean_between.sql",
"unformatted/200_base_model.sql",
"unformatted/300_jinjafmt.sql",
],
Expand Down
36 changes: 36 additions & 0 deletions tests/unit_tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,39 @@ def test_from_token_raises_bracket_error_on_jinja_block_end() -> None:
)
with pytest.raises(SqlfmtBracketError):
_ = Node.from_token(t, previous_node=None)


@pytest.mark.parametrize(
    "token,result", [("between", True), ("BETWEEN", True), ("like", False)]
)
def test_is_the_between_operator(token: str, result: bool) -> None:
    """
    Build a lone WORD_OPERATOR node from the given token text and check
    that is_the_between_operator matches the expected result.
    """
    word_operator = Token(
        type=TokenType.WORD_OPERATOR, prefix="", token=token, spos=0, epos=7
    )
    node = Node.from_token(word_operator, previous_node=None)
    assert node.is_the_between_operator is result


def test_is_the_and_after_the_between_operator(default_mode: Mode) -> None:
    """
    Parse the fixture query and check is_the_and_after_the_between_operator
    on every node: BOOLEAN_OPERATOR nodes at even positions (0, 2, ...) must
    be False, those at odd positions must be True, and every node of any
    other token type must be False.
    """
    source_string, _ = read_test_data(
        "unit_tests/test_line/test_is_the_and_after_the_between_operator.sql"
    )
    analyzer = default_mode.dialect.initialize_analyzer(
        line_length=default_mode.line_length
    )
    query = analyzer.parse_query(source_string=source_string)

    # partition the parsed nodes by token type in a single pass
    boolean_operators = []
    everything_else = []
    for node in query.nodes:
        if node.token.type == TokenType.BOOLEAN_OPERATOR:
            boolean_operators.append(node)
        else:
            everything_else.append(node)

    plain_booleans = boolean_operators[::2]
    between_ands = boolean_operators[1::2]

    assert not any(n.is_the_and_after_the_between_operator for n in plain_booleans)
    assert all(n.is_the_and_after_the_between_operator for n in between_ands)
    assert not any(n.is_the_and_after_the_between_operator for n in everything_else)
25 changes: 25 additions & 0 deletions tests/unit_tests/test_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,28 @@ def test_jinja_block_split(splitter: LineSplitter) -> None:
actual_result = "".join([str(line).lstrip() for line in split_lines])

assert actual_result == expected_result


def test_split_at_and(splitter: LineSplitter) -> None:
    """
    Split a one-line query containing two "between ... and ..." clauses
    chained with boolean "and"s, and compare the rendered lines against the
    expected list, in which each between's "and" stays on the same line as
    its "between".
    """
    source_string = "select 1 where a between b and c and d between e and f and a < b\n"
    analyzer = splitter.mode.dialect.initialize_analyzer(splitter.mode.line_length)
    raw_query = analyzer.parse_query(source_string)

    # split every raw line first, then render the pieces to strings
    split_lines: List[Line] = [
        piece
        for raw_line in raw_query.lines
        for piece in splitter.maybe_split(raw_line)
    ]
    actual_result = list(map(str, split_lines))

    expected_result = [
        "select\n",
        " 1\n",
        "where\n",
        " a\n",
        " between b and c\n",
        " and d\n",
        " between e and f\n",
        " and a\n",
        " < b\n",
    ]
    assert actual_result == expected_result

0 comments on commit 3d956f3

Please sign in to comment.