Skip to content

Commit

Permalink
feat: add handling for any encoding, default to utf-8 (#385)
Browse files Browse the repository at this point in the history
* feat: add handling for any encoding, default to utf-8

* fix: change option name to inherit, not none

* fix: change approach for better newline handling on Windows

* fix: prevent key errors from utf aliases

* fix: support more utf aliases

---------

Co-authored-by: Ted Conbeer <[email protected]>
  • Loading branch information
tconbeer and Ted Conbeer authored Feb 24, 2023
1 parent e6f99d0 commit 16322af
Show file tree
Hide file tree
Showing 13 changed files with 218 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
.tox
.venv
.vscode
build
dist
tests/.coverage
tests/.results
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file.

## [Unreleased]

- sqlfmt now defaults to reading and writing files using the `utf-8` encoding. Previously, we used Python's default behavior of using the encoding from the host machine's locale. However, as `utf-8` becomes a de-facto standard, this was causing issues for some Windows users, whose locale was set to use older encodings. You can use the `--encoding` option to specify a different encoding. Setting encoding to `inherit`, e.g., `sqlfmt --encoding inherit foo.sql` will revert to the old behavior of using the host's locale. sqlfmt will detect and preserve a UTF BOM if it is present. If you specify `--encoding utf-8-sig`, sqlfmt will always write a UTF-8 BOM in the formatted file. ([#350](https://github.com/tconbeer/sqlfmt/issues/350), [#381]((https://github.com/tconbeer/sqlfmt/issues/381)), [#383]((https://github.com/tconbeer/sqlfmt/issues/383)) - thank you [@profesia-company](https://github.com/profesia-company), [@cmcnicoll](https://github.com/cmcnicoll), [@aersam](https://github.com/aersam), and [@ryanmeekins](https://github.com/ryanmeekins)!)

## [0.16.0] - 2023-01-27

### Formatting Changes + Bug Fixes
Expand Down
78 changes: 69 additions & 9 deletions src/sqlfmt/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import codecs
import concurrent.futures
import locale
import sys
from functools import partial
from glob import glob
Expand All @@ -9,6 +11,7 @@
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
Optional,
Expand All @@ -21,7 +24,7 @@

from sqlfmt.analyzer import Analyzer
from sqlfmt.cache import Cache, check_cache, clear_cache, load_cache, write_cache
from sqlfmt.exception import SqlfmtEquivalenceError, SqlfmtError
from sqlfmt.exception import SqlfmtEquivalenceError, SqlfmtError, SqlfmtUnicodeError
from sqlfmt.formatter import QueryFormatter
from sqlfmt.mode import Mode as Mode
from sqlfmt.query import Query
Expand Down Expand Up @@ -172,6 +175,8 @@ def _format_many(
source_path=path,
source_string="",
formatted_string="",
encoding="",
utf_bom="",
from_cache=True,
)
)
Expand Down Expand Up @@ -218,17 +223,23 @@ def _format_one(path: Path, mode: Mode) -> SqlFormatResult:
Runs format_string on the contents of a single file (found at path). Handles
potential user errors in formatted code, and returns a SqlfmtResult
"""
source = _read_path_or_stdin(path)
source, encoding, utf_bom = _read_path_or_stdin(path, mode)
try:
formatted = format_string(source, mode)
return SqlFormatResult(
source_path=path, source_string=source, formatted_string=formatted
source_path=path,
source_string=source,
formatted_string=formatted,
encoding=encoding,
utf_bom=utf_bom,
)
except SqlfmtError as e:
return SqlFormatResult(
source_path=path,
source_string=source,
formatted_string="",
encoding=encoding,
utf_bom=utf_bom,
exception=e,
)

Expand All @@ -241,22 +252,71 @@ def _update_source_files(results: Iterable[SqlFormatResult]) -> None:
"""
for res in results:
if res.has_changed and res.source_path != STDIN_PATH and res.formatted_string:
with open(res.source_path, "w") as f:
with open(res.source_path, "w", encoding=res.encoding) as f:
f.write(res.formatted_string)


def _read_path_or_stdin(path: Path) -> str:
def _read_path_or_stdin(path: Path, mode: Mode) -> Tuple[str, str, str]:
"""
If passed a Path, calls open() and read() and returns contents as a string.
If passed a Path, calls open() and read() and returns contents as
a tuple of strings. The first element is the contents of the file; the
second element is the encoding used to read the file; the third
element is either the utf BOM or an empty string.
If passed Path("-"), calls sys.stdin.read()
"""
encoding = (
(
locale.getpreferredencoding()
if mode.encoding.lower() == "inherit"
else mode.encoding
)
.lower()
.replace("-", "_")
)
bom_map: Dict[str, List[bytes]] = {
"utf": [codecs.BOM_UTF8],
"utf8": [codecs.BOM_UTF8],
"u8": [codecs.BOM_UTF8],
"utf16": [codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE],
"u16": [codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE],
"utf16le": [codecs.BOM_UTF16_LE],
"utf16be": [codecs.BOM_UTF16_BE],
"utf32": [codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE],
"u32": [codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE],
"utf32le": [codecs.BOM_UTF32_LE],
"utf32be": [codecs.BOM_UTF32_BE],
}
detected_bom = ""
if path == STDIN_PATH:
# todo: customize encoding of stdin
source = sys.stdin.read()
else:
with open(path, "r") as f:
source = f.read()
return source
try:
with open(path, "r", encoding=encoding) as f:
source = f.read()
if encoding.startswith("utf") and encoding != "utf_8_sig":
for b in [
bom.decode(encoding)
for bom in bom_map.get(encoding.replace("_", ""), [])
]:
if source.startswith(b):
detected_bom = b
source = source[len(b) :]
break

except UnicodeDecodeError as e:
raise SqlfmtUnicodeError(
f"Error reading file {path}\n"
f"File could not be decoded using {encoding}. "
f"Specifically, {repr(e.object)} at position {e.start} failed "
f"with: {e.reason}.\n"
"You can specify a different encoding by running sqlfmt "
"with the --encoding option. Or set --encoding to 'none' to "
"use the system default encoding. We suggest always using "
"utf-8 for all files."
)
return source, encoding, detected_bom


def _perform_safety_check(analyzer: Analyzer, raw_query: Query, result: str) -> None:
Expand Down
9 changes: 9 additions & 0 deletions src/sqlfmt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@
"called. To exclude multiple globs, repeat the --exclude option."
),
)
@click.option(
"--encoding",
envvar="SQLFMT_ENCODING",
help=(
"The encoding to use when reading and writing .sql files. Defaults "
"to utf-8. Set to 'inherit' to read the system default encoding. utf "
"encodings will detect and preserve the BOM if one is present."
),
)
@click.option(
"--fast/--safe",
envvar="SQLFMT_FAST",
Expand Down
9 changes: 9 additions & 0 deletions src/sqlfmt/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ class SqlfmtConfigError(SqlfmtError):
pass


class SqlfmtUnicodeError(SqlfmtError):
"""
Raised while reading input if the input cannot be
decoded into a Python string
"""

pass


class SqlfmtParsingError(SqlfmtError):
"""
Raised during lexing if sqlfmt encounters a token that does
Expand Down
1 change: 1 addition & 0 deletions src/sqlfmt/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Mode:
check: bool = False
diff: bool = False
exclude: List[str] = field(default_factory=list)
encoding: str = "utf-8"
fast: bool = False
single_process: bool = False
no_jinjafmt: bool = False
Expand Down
2 changes: 2 additions & 0 deletions src/sqlfmt/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class SqlFormatResult:
source_path: Path
source_string: str
formatted_string: str
encoding: str
utf_bom: str
exception: Optional[SqlfmtError] = None
from_cache: bool = False

Expand Down
1 change: 1 addition & 0 deletions tests/data/fast/preformatted/006_has_bom.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select 1
1 change: 1 addition & 0 deletions tests/functional_tests/test_general_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"preformatted/002_select_from_where.sql",
"preformatted/003_literals.sql",
"preformatted/004_with_select.sql",
"preformatted/005_fmt_off.sql",
"preformatted/301_multiline_jinjafmt.sql",
"preformatted/400_create_table.sql",
"unformatted/100_select_case.sql",
Expand Down
56 changes: 55 additions & 1 deletion tests/unit_tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import codecs
import io
import os
from pathlib import Path
Expand All @@ -9,13 +10,19 @@
from sqlfmt.api import (
_format_many,
_perform_safety_check,
_read_path_or_stdin,
_update_source_files,
format_string,
get_matching_paths,
initialize_progress_bar,
run,
)
from sqlfmt.exception import SqlfmtBracketError, SqlfmtEquivalenceError, SqlfmtError
from sqlfmt.exception import (
SqlfmtBracketError,
SqlfmtEquivalenceError,
SqlfmtError,
SqlfmtUnicodeError,
)
from sqlfmt.mode import Mode


Expand Down Expand Up @@ -348,3 +355,50 @@ def test_perform_safety_check(default_mode: Mode) -> None:
# does not raise
_perform_safety_check(analyzer, raw_query, "select\n 1, 2, 3\n")
_perform_safety_check(analyzer, raw_query, "select\n-- new comment\n 1, 2, 3\n")


@pytest.mark.parametrize(
"encoding,bom",
[
("utf-8", b""),
("utf-8", codecs.BOM_UTF8),
("utf-8-sig", b""), # encoding with utf-8-sig will add a bom
("utf-16", b""),
("utf_16_be", codecs.BOM_UTF16_BE),
("utf_16_le", codecs.BOM_UTF16_LE),
("utf-32", b""),
("utf_32_be", codecs.BOM_UTF32_BE),
("utf_32_le", codecs.BOM_UTF32_LE),
("cp1250", b""),
("cp1252", b""),
("latin-1", b""),
("ascii", b""),
],
)
def test_read_path_or_stdin_many_encodings(
encoding: str, bom: bytes, tmp_path: Path
) -> None:
p = tmp_path / "q.sql"
# create a new file with the specified encoding and BOM
raw_query = "select\n\n\n1\n"
file_contents = bom + raw_query.encode(encoding)
with open(p, "wb") as f:
f.write(file_contents)

mode = Mode(encoding=encoding)
actual_source, actual_encoding, actual_bom = _read_path_or_stdin(p, mode)
assert actual_source == raw_query
assert actual_encoding == encoding.lower().replace("-", "_")
assert actual_bom == bom.decode(encoding)


def test_read_path_or_stdin_error(tmp_path: Path) -> None:
p = tmp_path / "q.sql"
with open(p, "w", encoding="utf-8") as f:
f.write("select 'ň' as ch")

mode = Mode(encoding="cp1250")
with pytest.raises(SqlfmtUnicodeError) as exc_info:
_, _, _ = _read_path_or_stdin(p, mode)

assert "cp1250" in str(exc_info.value)
34 changes: 30 additions & 4 deletions tests/unit_tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,36 @@ def small_cache(sample_paths: Dict[str, Path], sample_stat: Tuple[float, int]) -
@pytest.fixture
def results_for_caching(sample_paths: Dict[str, Path]) -> List[SqlFormatResult]:
results = [
SqlFormatResult(sample_paths["001"], "select 1\n", "select 1\n"),
SqlFormatResult(sample_paths["002"], "select 1\n", "", from_cache=True),
SqlFormatResult(sample_paths["003"], "select 'abc'\n", "select\n 'abc'\n"),
SqlFormatResult(sample_paths["900"], "!\n", "", SqlfmtError("oops")),
SqlFormatResult(
sample_paths["001"],
"select 1\n",
"select 1\n",
encoding="utf-8",
utf_bom="",
),
SqlFormatResult(
sample_paths["002"],
"select 1\n",
"",
encoding="utf-8",
utf_bom="",
from_cache=True,
),
SqlFormatResult(
sample_paths["003"],
"select 'abc'\n",
"select\n 'abc'\n",
encoding="utf-8",
utf_bom="",
),
SqlFormatResult(
sample_paths["900"],
"!\n",
"",
encoding="utf-8",
utf_bom="",
exception=SqlfmtError("oops"),
),
]
return results

Expand Down
30 changes: 28 additions & 2 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import locale
import re
import subprocess
import sys
Expand Down Expand Up @@ -95,7 +96,7 @@ def test_preformatted_short_lines_env(
)
assert results.exit_code == 0
print(results.stderr)
assert "4 files formatted" in results.stderr
assert "5 files formatted" in results.stderr

# test that CLI flag overrides ENV VAR
args = f"{preformatted_dir.as_posix()} -l 88 --check"
Expand All @@ -104,7 +105,7 @@ def test_preformatted_short_lines_env(
)
assert results.exit_code == 0
print(results.stderr)
assert "5 files passed formatting check" in results.stderr
assert "6 files passed formatting check" in results.stderr


def test_unformatted_check(sqlfmt_runner: CliRunner, unformatted_dir: Path) -> None:
Expand Down Expand Up @@ -183,3 +184,28 @@ def test_preformatted_fast_safe(
args = f"{preformatted_dir.as_posix()} --check {option}"
results = sqlfmt_runner.invoke(sqlfmt_main, args=args)
assert results.exit_code == 0


def test_preformatted_utf_8_sig_encoding(
sqlfmt_runner: CliRunner, preformatted_dir: Path
) -> None:
args = f"{preformatted_dir.as_posix()} --check --encoding utf-8-sig"
results = sqlfmt_runner.invoke(sqlfmt_main, args=args)
assert results.exit_code == 0


def test_preformatted_inherit_encoding(
sqlfmt_runner: CliRunner, preformatted_dir: Path
) -> None:
args = f"{preformatted_dir.as_posix()} --check --encoding inherit"
results = sqlfmt_runner.invoke(sqlfmt_main, args=args)
if locale.getpreferredencoding().lower().replace("-", "_") == "utf_8":
assert results.exit_code == 0
else:
# this directory includes a file that starts with a BOM. We'll
# get a weird symbol if decoded with anything other than utf-8,
# like cp-1252, which is the default on many Windows machines
assert results.exit_code == 2
assert results.stderr.startswith("1 file had errors")
assert "006_has_bom.sql" in results.stderr
assert "Could not parse SQL at position 1" in results.stderr
Loading

0 comments on commit 16322af

Please sign in to comment.