Skip to content

Commit

Permalink
Support all valid frame clauses (#406)
Browse files Browse the repository at this point in the history
* fix: support frame clauses using range, groups

* chore: update primer refs
  • Loading branch information
tconbeer authored Apr 14, 2023
1 parent 01eab66 commit f360c1e
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ All notable changes to this project will be documented in this file.
See also [this discussion](https://github.com/tconbeer/sqlfmt/discussions/317). Thank you
[@dave-connors-3](https://github.com/dave-connors-3) and
[@alrocar](https://github.com/alrocar)!
- sqlfmt now supports all [Postgres frame clauses](https://www.postgresql.org/docs/current/sql-expressions.html#SYNTAX-WINDOW-FUNCTIONS), not just those that start with `rows between`. ([#404](https://github.com/tconbeer/sqlfmt/issues/404))

## [0.17.1] - 2023-04-12

Expand Down
9 changes: 8 additions & 1 deletion src/sqlfmt/rules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@
r"then",
r"else",
r"partition\s+by",
r"rows\s+between",
r"values",
# in pg, RETURNING can be the last clause of
# a DELETE statement
Expand All @@ -157,6 +156,14 @@
+ group(r"\W", r"$"),
action=partial(actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD),
),
Rule(
# Matches the opening of a window frame clause: a frame
# mode keyword (range, rows, or groups), an optional
# "between", and then the first frame bound, which is
# either "unbounded"/an integer literal followed by
# "preceding" or "following", or "current row".
# NOTE(review): the offset is matched with \d+ only, so a
# bound written as any other expression (e.g., an interval
# or a decimal) will not match this rule -- confirm that
# is intended.
name="frame_clause",
priority=1305,
# The first group holds the mode keyword plus its trailing
# whitespace; the final group requires a non-word character
# or end of input after the bound.
pattern=group(r"(range|rows|groups)\s+")
+ group(r"(between\s+)?((unbounded|\d+)\s+(preceding|following)|current\s+row)")
+ group(r"\W", r"$"),
action=partial(actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD),
),
Rule(
# BQ arrays use an offset(n) function for
# indexing that we do not want to match. This
Expand Down
2 changes: 1 addition & 1 deletion src/sqlfmt_primer/primer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_projects() -> List[SQLProject]:
SQLProject(
name="rittman",
git_url="https://github.com/tconbeer/rittman_ra_data_warehouse.git",
git_ref="f44544f", # sqlfmt 54b8edd
git_ref="418af64", # sqlfmt cd38a6c
expected_changed=0,
expected_unchanged=307,
expected_errored=4, # true mismatching brackets
Expand Down
16 changes: 14 additions & 2 deletions tests/data/fast/unformatted/103_window_functions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ select
row_number() over () as c,
count(case when a is null then 1 end) over (partition by user_id, date_trunc('year', performed_at)) as d,
first_value(a ignore nulls) over (partition by user_id order by performed_at desc rows between unbounded preceding and unbounded following) as e,
count(*) filter (WHERE a is null) over (partition by user_id, date_trunc('year', performed_at)) as f
count(*) filter (WHERE a is null) over (partition by user_id, date_trunc('year', performed_at)) as f,
first_value(a ignore nulls) over (partition by user_id order by performed_at desc range between unbounded preceding and unbounded following) as g,
last_value(a ignore nulls) over (partition by user_id order by performed_at asc rows 5 preceding exclude current row) as h
from
my_table
)))))__SQLFMT_OUTPUT__(((((
Expand All @@ -22,5 +24,15 @@ select
) as e,
count(*) filter (where a is null) over (
partition by user_id, date_trunc('year', performed_at)
) as f
) as f,
first_value(a ignore nulls) over (
partition by user_id
order by performed_at desc
range between unbounded preceding and unbounded following
) as g,
last_value(a ignore nulls) over (
partition by user_id
order by performed_at asc
rows 5 preceding exclude current row
) as h
from my_table
16 changes: 14 additions & 2 deletions tests/data/unformatted/103_window_functions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ select
row_number() over () as c,
count(case when a is null then 1 end) over (partition by user_id, date_trunc('year', performed_at)) as d,
first_value(a ignore nulls) over (partition by user_id order by performed_at desc rows between unbounded preceding and unbounded following) as e,
count(*) filter (WHERE a is null) over (partition by user_id, date_trunc('year', performed_at)) as f
count(*) filter (WHERE a is null) over (partition by user_id, date_trunc('year', performed_at)) as f,
first_value(a ignore nulls) over (partition by user_id order by performed_at desc range between unbounded preceding and unbounded following) as g,
last_value(a ignore nulls) over (partition by user_id order by performed_at asc rows 5 preceding exclude current row) as h
from
my_table
)))))__SQLFMT_OUTPUT__(((((
Expand All @@ -22,5 +24,15 @@ select
) as e,
count(*) filter (where a is null) over (
partition by user_id, date_trunc('year', performed_at)
) as f
) as f,
first_value(a ignore nulls) over (
partition by user_id
order by performed_at desc
range between unbounded preceding and unbounded following
) as g,
last_value(a ignore nulls) over (
partition by user_id
order by performed_at asc
rows 5 preceding exclude current row
) as h
from my_table
25 changes: 25 additions & 0 deletions tests/unit_tests/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,31 @@ def test_regex_anti_match(
assert match is None, f"{rule_name} regex should not match {value}"


@pytest.mark.parametrize(
    "ruleset,rule_name,value,matched_value",
    [
        (MAIN, "frame_clause", "rows between unbounded preceding", "rows "),
        (MAIN, "frame_clause", "rows unbounded preceding", "rows "),
        (MAIN, "frame_clause", "rows 1 preceding", "rows "),
        (MAIN, "frame_clause", "range between current row", "range "),
        (MAIN, "frame_clause", "range 1 following", "range "),
        (MAIN, "frame_clause", "range current row", "range "),
        (MAIN, "frame_clause", "groups between 1 preceding", "groups "),
    ],
)
def test_regex_partial_match(
    ruleset: List[Rule], rule_name: str, value: str, matched_value: str
) -> None:
    """
    Check that a rule's compiled regex matches `value`, and that
    its first capture group covers exactly `matched_value` (i.e.,
    only a leading portion of the text that the regex consumed).
    """
    program = get_rule(ruleset, rule_name).program
    found = program.match(value)
    assert found is not None, f"{rule_name} regex doesn't match {value}"

    # Slice the input by the span of group 1 rather than reading the
    # group directly, so the comparison is against the source text.
    group_start, group_end = found.span(1)
    captured = value[group_start:group_end]

    assert (
        captured == matched_value
    ), f"{rule_name} regex doesn't exactly match {matched_value}"


def test_regex_should_not_match_empty_string() -> None:
rules = itertools.chain.from_iterable(ALL_RULESETS)
for rule in rules:
Expand Down

0 comments on commit f360c1e

Please sign in to comment.