Skip to content

Commit

Permalink
chore: improve merger testing
Browse files Browse the repository at this point in the history
  • Loading branch information
tconbeer committed Jul 14, 2022
1 parent 2f60e22 commit 88617d2
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 0 deletions.
103 changes: 103 additions & 0 deletions tests/data/unit_tests/test_merger/test_maybe_stubbornly_merge.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
-- try to merge the first line of this segment with the previous segment
count(
*
)
over (
partition by
foofoofoofoofoofoofoofoofoo
order by
foofoofoofoofoofoofoofoofoo asc
rows between unbounded preceding and unbounded following
),

-- try to add this segment to the last line of the previous segment
(
foofoofoofoofoofoofoofoofoo
+ barbarbarbarbarbarbarbarbar
+ bazbazbazbazbazbazbazbazbaz
)
::decimal(
18,
2
),

-- try to add just the first line of this segment to the last
-- line of the previous segment
sum(
case
when
foo
then
foo + bar + baz
when
bar
then
bar + baz + qux
when
baz
then
something_else_long
end
)
over (
partition by
foofoofoofoofoofoofoofoofoo
order by
foofoofoofoofoofoofoofoofoo asc
rows between unbounded preceding and unbounded following
),

-- give up and just return the original segments
a_very_very_long_cte_name_that_is_just_under_eighty_eight_characters_in_length_xxxxxxx
as (
select
1
)
)))))__SQLFMT_OUTPUT__(((((
-- try to merge the first line of this segment with the previous segment
count(*) over (
partition by
foofoofoofoofoofoofoofoofoo
order by
foofoofoofoofoofoofoofoofoo asc
rows between unbounded preceding and unbounded following
),

-- try to add this segment to the last line of the previous segment
(
foofoofoofoofoofoofoofoofoo
+ barbarbarbarbarbarbarbarbar
+ bazbazbazbazbazbazbazbazbaz
)::decimal(18, 2),

-- try to add just the first line of this segment to the last
-- line of the previous segment
sum(
case
when
foo
then
foo + bar + baz
when
bar
then
bar + baz + qux
when
baz
then
something_else_long
end
) over (
partition by
foofoofoofoofoofoofoofoofoo
order by
foofoofoofoofoofoofoofoofoo asc
rows between unbounded preceding and unbounded following
),

-- give up and just return the original segments
a_very_very_long_cte_name_that_is_just_under_eighty_eight_characters_in_length_xxxxxxx
as (
select
1
)
85 changes: 85 additions & 0 deletions tests/unit_tests/test_merger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import itertools
from typing import List

import pytest

from sqlfmt.exception import SqlfmtSegmentError
from sqlfmt.line import Line
from sqlfmt.merger import CannotMergeException, LineMerger
from sqlfmt.mode import Mode
Expand Down Expand Up @@ -421,3 +423,86 @@ def test_do_not_merge_across_query_dividers(merger: LineMerger, sep: str) -> Non
).parse_query(source_string)
merged_lines = merger.maybe_merge_lines(raw_query.lines)
assert raw_query.lines == merged_lines


@pytest.mark.parametrize(
"source_string,expected_result",
[
("case\nfoo\nend\n", True),
("count(\n*\n)\n", True),
("count(\n*\n)\n\n\n", True),
("\n\n\ncount(\n*\n)", True),
("as (\nselect\n1\n)\n", True),
("as foo\n", False),
("from\nfoo\nwhere\n", False),
],
)
def test_tail_closes_head(
merger: LineMerger, source_string: str, expected_result: bool
) -> None:
q = merger.mode.dialect.initialize_analyzer(merger.mode.line_length).parse_query(
source_string
)
assert merger._tail_closes_head(q.lines) == expected_result


@pytest.mark.parametrize(
"source_string,expected_idx",
[
("case\nfoo\nend\n", 0),
("\n\n\ncount(\n*\n)", 3),
("\n\n\n count(\n*\n)", 3),
("\n \n\n count(\n*\n)", 3),
],
)
def test_get_first_nonblank_line(
merger: LineMerger, source_string: str, expected_idx: int
) -> None:
q = merger.mode.dialect.initialize_analyzer(merger.mode.line_length).parse_query(
source_string
)
line, i = merger._get_first_nonblank_line(q.lines)
assert i == expected_idx
assert line == q.lines[i]


@pytest.mark.parametrize(
"source_string",
[
"",
"\n",
"\n\n\n \n\n",
],
)
def test_get_first_nonblank_line_raises(merger: LineMerger, source_string: str) -> None:
q = merger.mode.dialect.initialize_analyzer(merger.mode.line_length).parse_query(
source_string
)
with pytest.raises(SqlfmtSegmentError):
_, _ = merger._get_first_nonblank_line(q.lines)


def test_maybe_stubbornly_merge(merger: LineMerger) -> None:
source_string, expected_string = read_test_data(
"unit_tests/test_merger/test_maybe_stubbornly_merge.sql"
)
q = merger.mode.dialect.initialize_analyzer(merger.mode.line_length).parse_query(
source_string
)
segments = merger._split_into_segments(q.lines)
merged_segments = merger._maybe_stubbornly_merge(segments)
result_string = "".join(
[line.render_with_comments(88) for line in itertools.chain(*merged_segments)]
)
assert result_string == expected_string


def test_maybe_stubbornly_merge_single_segment(merger: LineMerger) -> None:
source_string = "select\na,\nb\n"
q = merger.mode.dialect.initialize_analyzer(merger.mode.line_length).parse_query(
source_string
)
segments = merger._split_into_segments(q.lines)
assert len(segments) == 1
merged_segments = merger._maybe_stubbornly_merge(segments)
assert merged_segments == segments

0 comments on commit 88617d2

Please sign in to comment.