Skip to content

Commit

Permalink
feat: add safety check, close #328 (#349)
Browse files Browse the repository at this point in the history
  • Loading branch information
tconbeer authored Jan 6, 2023
1 parent ecf16da commit e207ea6
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 6 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@ All notable changes to this project will be documented in this file.

## [Unreleased]

### Features

- by default, sqlfmt now runs an additional safety check that parses the formatted output to ensure it contains all of the same content as the raw input. This incurs a slight (~20%) performance penalty. To bypass this safety check, you can use the command line option `--fast`, the corresponding TOML or environment variable config, or pass `Mode(fast=True)` to any API method. The safety check is automatically bypassed if sqlfmt is run with the `--check` or `--diff` options. If the safety check fails, the CLI will include an error in the report, and the `format_string` API will raise a `SqlfmtEquivalenceError`, which is a subclass of `SqlfmtError`.

## [0.14.3] - 2023-01-05

### Formatting Changes + Bug Fixes

- fixed a bug where very long lines could raise `RecursionError` ([#343](https://github.com/tconbeer/sqlfmt/issues/343) - thank you [@kcem-flyr](https://github.com/kcem-flyr)!).

## [0.14.2] - 2022-12-12
Expand Down
63 changes: 59 additions & 4 deletions src/sqlfmt/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from functools import partial
from glob import glob
from itertools import zip_longest
from pathlib import Path
from typing import (
Awaitable,
Expand All @@ -18,10 +19,12 @@

from tqdm import tqdm

from sqlfmt.analyzer import Analyzer
from sqlfmt.cache import Cache, check_cache, clear_cache, load_cache, write_cache
from sqlfmt.exception import SqlfmtError
from sqlfmt.exception import SqlfmtEquivalenceError, SqlfmtError
from sqlfmt.formatter import QueryFormatter
from sqlfmt.mode import Mode
from sqlfmt.mode import Mode as Mode
from sqlfmt.query import Query
from sqlfmt.report import STDIN_PATH, Report, SqlFormatResult

T = TypeVar("T")
Expand All @@ -31,13 +34,21 @@
def format_string(source_string: str, mode: Mode) -> str:
    """
    Takes a raw query string and a mode as input, returns the formatted query
    as a string, or raises a SqlfmtError if the string cannot be formatted.

    Unless mode.fast, mode.check, or mode.diff is set, also performs a safety
    check that re-lexes the formatted output and raises a
    SqlfmtEquivalenceError if its tokens do not match the tokens of the
    original input.
    """
    analyzer = mode.dialect.initialize_analyzer(line_length=mode.line_length)
    raw_query = analyzer.parse_query(source_string=source_string)
    formatter = QueryFormatter(mode)
    formatted_query = formatter.format(raw_query)
    result = str(formatted_query)

    # The safety check parses the output a second time, so it is skipped
    # in fast mode and whenever sqlfmt is only checking or diffing.
    if not mode.fast and not mode.check and not mode.diff:
        _perform_safety_check(analyzer, raw_query, result)

    return result


def run(
Expand Down Expand Up @@ -246,3 +257,47 @@ def _read_path_or_stdin(path: Path) -> str:
with open(path, "r") as f:
source = f.read()
return source


def _perform_safety_check(analyzer: Analyzer, raw_query: Query, result: str) -> None:
    """
    Raises a SqlfmtEquivalenceError if re-lexing the result produces a
    different sequence of token types than the original query.

    Only token types are compared (not the token text), and token types
    whose is_equivalent_in_output property is False are ignored on both
    sides. Returns None if the two filtered sequences are equal.
    """
    result_query = analyzer.parse_query(source_string=result)
    filtered_raw_tokens = [
        token.type for token in raw_query.tokens if token.type.is_equivalent_in_output
    ]
    filtered_result_tokens = [
        token.type
        for token in result_query.tokens
        if token.type.is_equivalent_in_output
    ]

    # Compare explicitly instead of using an assert statement: asserts are
    # removed when Python runs with -O, which would silently disable
    # this check.
    if filtered_raw_tokens == filtered_result_tokens:
        return

    raw_len = len(filtered_raw_tokens)
    result_len = len(filtered_result_tokens)
    mismatch_pos = 0
    mismatch_raw = ""
    mismatch_res = ""

    # zip_longest pads the shorter list with None, so a token that was
    # dropped from (or appended to) the end is reported as a mismatch
    # against "None".
    for i, (raw, res) in enumerate(
        zip_longest(filtered_raw_tokens, filtered_result_tokens)
    ):
        if raw != res:
            mismatch_pos = i
            mismatch_raw = str(raw)
            mismatch_res = str(res)
            break

    raise SqlfmtEquivalenceError(
        "There was a problem formatting your query that "
        "caused the safety check to fail. Please open an "
        f"issue. Raw query was {raw_len} tokens; formatted "
        f"query was {result_len} tokens. First mismatching "
        f"token at position {mismatch_pos}: raw: {mismatch_raw}; "
        f"result: {mismatch_res}."
    )
13 changes: 13 additions & 0 deletions src/sqlfmt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@
"called. To exclude multiple globs, repeat the --exclude option."
),
)
@click.option(
"--fast/--safe",
envvar="SQLFMT_FAST",
default=False,
help=(
"By default, sqlfmt re-processes the output it produces in "
"order to run a safety check and ensure that all tokens from "
"the input are present in the output. This can add 15-20% to "
"the processing time for new files. To disable this safety "
"check, use the --fast option. To force the safety check, "
"use --safe."
),
)
@click.option(
"--single-process",
envvar="SQLFMT_SINGLE_PROCESS",
Expand Down
9 changes: 9 additions & 0 deletions src/sqlfmt/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ class SqlfmtSegmentError(SqlfmtError):
pass


class SqlfmtEquivalenceError(SqlfmtError):
    """
    Signals a failed safety check: lexing the formatted result
    did not yield the same tokens as lexing the raw query
    """


class SqlfmtControlFlowException(Exception):
"""
Generic exception for exceptions used to manage control
Expand Down
3 changes: 2 additions & 1 deletion src/sqlfmt/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class Mode:
"""
A Mode is a container for all sqlfmt config, including formatting config and
report config
report config. For more info on each option, see cli.py
"""

SQL_EXTENSIONS: List[str] = field(default_factory=lambda: [".sql", ".sql.jinja"])
Expand All @@ -19,6 +19,7 @@ class Mode:
check: bool = False
diff: bool = False
exclude: List[str] = field(default_factory=list)
fast: bool = False
single_process: bool = False
no_jinjafmt: bool = False
reset_cache: bool = False
Expand Down
7 changes: 7 additions & 0 deletions src/sqlfmt/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ def is_always_lowercased(self) -> bool:
TokenType.SET_OPERATOR,
]

@cached_property
def is_equivalent_in_output(self) -> bool:
    """
    True for every token type except NEWLINE and COMMENT. Token types
    for which this is False are left out when the tokens of a raw query
    are compared against the tokens of its formatted output.
    """
    ignored_types = (TokenType.NEWLINE, TokenType.COMMENT)
    return self not in ignored_types


class Token(NamedTuple):
"""
Expand Down
31 changes: 30 additions & 1 deletion tests/unit_tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@

from sqlfmt.api import (
_format_many,
_perform_safety_check,
_update_source_files,
format_string,
get_matching_paths,
initialize_progress_bar,
run,
)
from sqlfmt.exception import SqlfmtBracketError, SqlfmtError
from sqlfmt.exception import SqlfmtBracketError, SqlfmtEquivalenceError, SqlfmtError
from sqlfmt.mode import Mode


Expand Down Expand Up @@ -319,3 +320,31 @@ def test_initialize_disabled_progress_bar(no_progressbar_mode: Mode) -> None:
assert progress_callback is not None
progress_callback("foo") # type: ignore
assert progress_bar.format_dict.get("n") == 0


def test_perform_safety_check(default_mode: Mode) -> None:
    """
    The safety check must raise when the result drops or changes a token,
    and must pass when the result only differs by newlines or comments.
    """
    analyzer = default_mode.dialect.initialize_analyzer(
        line_length=default_mode.line_length
    )
    raw_query = analyzer.parse_query("select 1, 2, 3\n")

    # the result is missing the final token of the raw query
    with pytest.raises(SqlfmtEquivalenceError) as dropped:
        _perform_safety_check(analyzer, raw_query, "select 1, 2, \n")
    assert "Raw query was 6 tokens; formatted query was 5 tokens." in str(dropped.value)

    # the result replaces the first number with a name
    with pytest.raises(SqlfmtEquivalenceError) as changed:
        _perform_safety_check(analyzer, raw_query, "select a, 2, 3\n")
    expected_message = (
        "First mismatching token at position 1: raw: TokenType.NUMBER; "
        "result: TokenType.NAME."
    )
    assert expected_message in str(changed.value)

    # extra newlines and added comments must not raise
    equivalent_results = (
        "select\n    1, 2, 3\n",
        "select\n-- new comment\n    1, 2, 3\n",
    )
    for equivalent in equivalent_results:
        _perform_safety_check(analyzer, raw_query, equivalent)
9 changes: 9 additions & 0 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,12 @@ def test_preformatted_no_progressbar(
args = f"{preformatted_dir.as_posix()} --check --no-progressbar"
results = sqlfmt_runner.invoke(sqlfmt_main, args=args)
assert results.exit_code == 0


@pytest.mark.parametrize("option", ["--fast", "--safe"])
def test_preformatted_fast_safe(
    sqlfmt_runner: CliRunner, preformatted_dir: Path, option: str
) -> None:
    """
    Running --check against the preformatted_dir fixture must exit 0
    whether the safety check is disabled (--fast) or forced (--safe).
    """
    cli_args = f"{preformatted_dir.as_posix()} --check {option}"
    outcome = sqlfmt_runner.invoke(sqlfmt_main, args=cli_args)
    assert outcome.exit_code == 0

0 comments on commit e207ea6

Please sign in to comment.