Skip to content

Commit

Permalink
Improve Support for Numeric Constants (#325)
Browse files Browse the repository at this point in the history
* feat: improve numeric constant handling, close #321

* chore: update primer refs
  • Loading branch information
tconbeer authored Nov 21, 2022
1 parent f37bc5c commit bd08b42
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ All notable changes to this project will be documented in this file.
- sqlfmt now resets the SQL depth of a query after encountering an `{% endmacro %}`, `{% endtest %}`, `{% endcall %}`, or `{% endmaterialization %}` tag.
- sqlfmt now supports `create warehouse` and `alter warehouse` statements ([#312](https://github.com/tconbeer/sqlfmt/issues/312), [#299](https://github.com/tconbeer/sqlfmt/issues/312)).
- sqlfmt now supports `alter function` and `drop function` statements ([#310](https://github.com/tconbeer/sqlfmt/issues/310), [#311](https://github.com/tconbeer/sqlfmt/issues/311)), and Snowflake's `create external function` statements ([#322](https://github.com/tconbeer/sqlfmt/issues/322)).
- sqlfmt better supports numeric constants (number literals), including those using scientific notation (e.g., `1.5e-9`) and the unary `+` or `-` operators (e.g., `+3`), and is now smarter about when the `-` symbol is the unary negative or binary subtraction operator. ([#321](https://github.com/tconbeer/sqlfmt/issues/321) - thank you [@liaopeiyuan](https://github.com/liaopeiyuan)!).
- fixed a bug where we added extra whitespace to the end of empty comment lines ([#319](https://github.com/tconbeer/sqlfmt/issues/319) - thank you [@eherde](https://github.com/eherde)!).
- fixed a bug where we could have unsafely run *black* against jinja that contained Python keywords and their safe alternatives (e.g., `return(return_())`).

Expand Down
41 changes: 41 additions & 0 deletions src/sqlfmt/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,47 @@ def handle_set_operator(
analyzer.pos = token.epos


def handle_number(analyzer: "Analyzer", source_string: str, match: re.Match) -> None:
"""
We don't know if a token like "-3" or "+4" is properly a unary operator,
or a poorly-spaced binary operator, so we have to check the previous
node.
"""
first_char = source_string[match.span(1)[0] : match.span(1)[0] + 1]
if first_char in ["+", "-"] and analyzer.previous_node:
prev_token, _ = get_previous_token(analyzer.previous_node)
if prev_token and prev_token.type in (
TokenType.NUMBER,
TokenType.NAME,
TokenType.QUOTED_NAME,
TokenType.STATEMENT_END,
TokenType.BRACKET_CLOSE,
):
# This is a binary operator. Create a new match for only the
# operator token
op_prog = re.compile(r"\s*(\+|-)")
op_match = op_prog.match(source_string, pos=analyzer.pos)
assert op_match, "Internal error! Could not match symbol of binary operator"
add_node_to_buffer(
analyzer=analyzer,
source_string=source_string,
match=op_match,
token_type=TokenType.OPERATOR,
)
# we don't have to handle the rest of the number; this
# will get called again by analyzer.lex
return

# in all other cases, this is a number with/out a unary operator, and we lex it
# as a single token
add_node_to_buffer(
analyzer=analyzer,
source_string=source_string,
match=match,
token_type=TokenType.NUMBER,
)


def handle_nonreserved_keyword(
analyzer: "Analyzer",
source_string: str,
Expand Down
6 changes: 3 additions & 3 deletions src/sqlfmt/rules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@
name="number",
priority=350,
pattern=group(
r"-?\d+\.?\d*",
r"-?\.\d+",
r"(\+|-)?\d+(\.\d*)?(e(\+|-)?\d+)?",
r"(\+|-)?\.\d+(e(\+|-)?\d+)?",
),
action=partial(actions.add_node_to_buffer, token_type=TokenType.NUMBER),
action=actions.handle_number,
),
Rule(
name="semicolon",
Expand Down
4 changes: 2 additions & 2 deletions src/sqlfmt_primer/primer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_projects() -> List[SQLProject]:
SQLProject(
name="gitlab",
git_url="https://github.com/tconbeer/gitlab-analytics-sqlfmt.git",
git_ref="b53ed29", # sqlfmt 6aa20a1
git_ref="52dc37d", # sqlfmt a0e5254
expected_changed=4,
expected_unchanged=2413,
expected_errored=0,
Expand All @@ -39,7 +39,7 @@ def get_projects() -> List[SQLProject]:
SQLProject(
name="rittman",
git_url="https://github.com/tconbeer/rittman_ra_data_warehouse.git",
git_ref="5d838dd", # sqlfmt b792a79
git_ref="dd47b23", # sqlfmt a0e5254
expected_changed=0,
expected_unchanged=307,
expected_errored=4, # true mismatching brackets
Expand Down
23 changes: 23 additions & 0 deletions tests/data/unformatted/125_numeric_constants.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
select 1, 1.0, +1, +1.0, -1, -1.0, 1e9, 1e+9, 1e-9, -1e9, -1e+9, +1e-9, -1e-9, 1.6e9, 1.6e-9, 1., 1.e9, 1.e-9, -1., -1.e-9
)))))__SQLFMT_OUTPUT__(((((
select
1,
1.0,
+1,
+1.0,
-1,
-1.0,
1e9,
1e+9,
1e-9,
-1e9,
-1e+9,
+1e-9,
-1e-9,
1.6e9,
1.6e-9,
1.,
1.e9,
1.e-9,
-1.,
-1.e-9
1 change: 1 addition & 0 deletions tests/functional_tests/test_general_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"unformatted/122_values.sql",
"unformatted/123_spark_keywords.sql",
"unformatted/124_bq_compound_types.sql",
"unformatted/125_numeric_constants.sql",
"unformatted/200_base_model.sql",
"unformatted/201_basic_snapshot.sql",
"unformatted/202_unpivot_macro.sql",
Expand Down
27 changes: 27 additions & 0 deletions tests/unit_tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,3 +558,30 @@ def test_handle_closing_angle_bracket(default_analyzer: Analyzer) -> None:
assert array_line.nodes[-3].is_closing_bracket
assert array_line.nodes[-4].is_closing_bracket
assert all([line.nodes[1].is_operator for line in query.lines[2:]])


def test_handle_number_unary(default_analyzer: Analyzer) -> None:
source_string = """
select
+1,
-2,
-1 + -2,
"""
query = default_analyzer.parse_query(source_string=source_string.lstrip())
numbers = [str(n).strip() for n in query.nodes if n.token.type == TokenType.NUMBER]
assert numbers == ["+1", "-2", "-1", "-2"]


def test_handle_number_binary(default_analyzer: Analyzer) -> None:
source_string = """
select
1 +1,
1 -1,
-1+2,
something-2,
(something)+2,
case when true then foo else bar end+2
"""
query = default_analyzer.parse_query(source_string=source_string.lstrip())
numbers = [str(n).strip() for n in query.nodes if n.token.type == TokenType.NUMBER]
assert numbers == ["1", "1", "1", "1", "-1", "2", "2", "2", "2"]
5 changes: 5 additions & 0 deletions tests/unit_tests/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def get_rule(ruleset: List[Rule], rule_name: str) -> Rule:
(CORE, "star", "*"),
(CORE, "number", "145.8"),
(CORE, "number", "-.58"),
(CORE, "number", "+145.8"),
(CORE, "number", "+.58"),
(CORE, "number", "1e9"),
(CORE, "number", "1e-9"),
(CORE, "number", "1.55e-9"),
(CORE, "bracket_open", "["),
(CORE, "bracket_close", ")"),
(CORE, "double_colon", "::"),
Expand Down

0 comments on commit bd08b42

Please sign in to comment.