Support Nested Set and Call Blocks #340

Merged · 8 commits · Dec 12, 2022

6 changes: 5 additions & 1 deletion CHANGELOG.md

@@ -4,12 +4,16 @@ All notable changes to this project will be documented in this file.
 
 ## [Unreleased]
 
+### Formatting Changes + Bug Fixes
+
+- fixed a bug where nested `{% set %}` and `{% call %}` blocks would cause a parsing error ([#338](https://github.com/tconbeer/sqlfmt/issues/338) - thank you [@AndrewLane](https://github.com/AndrewLane)!).
+
 ## [0.14.1] - 2022-12-06
 
 ### Formatting Changes + Bug Fixes
 
 - sqlfmt now supports `is [not] distinct from` as a word operator ([#327](https://github.com/tconbeer/sqlfmt/issues/327) - thank you [@IgnorantWalking](https://github.com/IgnorantWalking), [@kadekillary](https://github.com/kadekillary)!).
-- fixed a bug where jinja `{% call %}` blocks that called a macro that wasn't `statement` caused a parsing error ([#335](https://github.com/tconbeer/sqlfmt/issues/327) - thank you [@AndrewLane](https://github.com/AndrewLane)!).
+- fixed a bug where jinja `{% call %}` blocks that called a macro that wasn't `statement` caused a parsing error ([#335](https://github.com/tconbeer/sqlfmt/issues/335) - thank you [@AndrewLane](https://github.com/AndrewLane)!).
 
 ### Performance
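
For context, the headline fix lets `{% set %}` and `{% call %}` data blocks nest. Below is a minimal reproduction sketch; the input is hypothetical (not taken verbatim from issue #338), and it assumes sqlfmt's documented Python API (`Mode` and `format_string` from `sqlfmt.api`):

```python
# Hypothetical nested-block input: a {% set %} data block inside a
# {% call %} block. Per the changelog entry above, inputs like this
# previously raised a parsing error; with this PR they lex and format.
from sqlfmt.api import Mode, format_string

source = """
{% call statement('grants', fetch_result=False) %}
    {% set grantees %}reporter, analyst{% endset %}
    grant select on my_table to {{ grantees }}
{% endcall %}
"""

print(format_string(source, mode=Mode()))
```
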
244 changes: 100 additions & 144 deletions src/sqlfmt/actions.py

@@ -1,12 +1,8 @@
 import re
-from typing import TYPE_CHECKING, Callable, List, Optional, Type
+from typing import TYPE_CHECKING, Callable, List, Optional
 
 from sqlfmt.comment import Comment
-from sqlfmt.exception import (
-    SqlfmtBracketError,
-    SqlfmtControlFlowException,
-    StopJinjaLexing,
-)
+from sqlfmt.exception import SqlfmtBracketError, StopRulesetLexing
 from sqlfmt.line import Line
 from sqlfmt.node import Node, get_previous_token
 from sqlfmt.rule import MAYBE_WHITESPACES, Rule
@@ -23,19 +19,6 @@ def group(*choices: str) -> str:
     return f"({'|'.join(choices)})"
 
 
-# return a human-readable token from a regex pattern
-def simplify_jinja_regex(pattern: str) -> str:
-    replacements = [
-        ("\\{", "{"),
-        ("\\}", "}"),
-        ("-?", ""),
-        ("\\s*", " "),
-    ]
-    for old, new in replacements:
-        pattern = pattern.replace(old, new)
-    return pattern
-
-
 def raise_sqlfmt_bracket_error(
     _: "Analyzer", source_string: str, match: re.Match
 ) -> None:
@@ -110,10 +93,10 @@ def add_jinja_comment_to_buffer(
     """
     Create a COMMENT token from the match, then create a Comment
    from that token and append it to the Analyzer's buffer; raise
-    StopJinjaLexing to revert to SQL lexing
+    StopRulesetLexing to revert to SQL lexing
     """
     add_comment_to_buffer(analyzer, source_string, match)
-    raise StopJinjaLexing
+    raise StopRulesetLexing
 
 
 def handle_newline(
@@ -345,162 +328,135 @@ def lex_ruleset(
     source_string: str,
     _: re.Match,
     new_ruleset: List["Rule"],
-    stop_exception: Type[SqlfmtControlFlowException],
 ) -> None:
     """
     Makes a nested call to analyzer.lex, with the new ruleset activated.
     """
     analyzer.push_rules(new_ruleset)
     try:
         analyzer.lex(source_string)
-    except stop_exception:
+    except StopRulesetLexing:
         analyzer.pop_rules()
 
 
-def handle_jinja_data_block(
-    analyzer: "Analyzer",
-    source_string: str,
-    match: re.Match,
-    end_rule_name: str,
-) -> None:
-    """
-    A set block, like {% set my_var %}data{% endset %} should be parsed
-    as a single DATA token... the data between the two tags need not be
-    sql or python, and should not be formatted.
-    """
-    # find the ending tag
-    end_rule = analyzer.get_rule(rule_name=end_rule_name)
-    end_match = end_rule.program.search(source_string, pos=analyzer.pos)
-    if end_match is None:
-        spos, epos = match.span(1)
-        raw_token = source_string[spos:epos]
-        raise SqlfmtBracketError(
-            f"Encountered unterminated Jinja set or call block '{raw_token}' at "
-            f"position {spos}. Expected end tag: "
-            f"{simplify_jinja_regex(end_rule.pattern)}"
-        )
-    # the data token is everything between the start and end tags, inclusive
-    data_spos = match.span(1)[0]
-    data_epos = end_match.span(1)[1]
-    data_token = Token(
-        type=TokenType.DATA,
-        prefix=source_string[analyzer.pos : data_spos],
-        token=source_string[data_spos:data_epos],
-        spos=data_spos,
-        epos=data_epos,
-    )
-    data_node = analyzer.node_manager.create_node(
-        token=data_token, previous_node=analyzer.previous_node
-    )
-    analyzer.node_buffer.append(data_node)
-    analyzer.pos = data_epos
-    raise StopJinjaLexing
-
-
-def handle_jinja_block(
-    analyzer: "Analyzer",
-    source_string: str,
-    match: re.Match,
-    start_name: str,
-    end_name: str,
-    other_names: List[str],
-    end_reset_sql_depth: bool = False,
-) -> None:
-    """
-    An if block, like {% if cond %}code{% else %}other_code{% endif %}
-    needs special handling, since the depth of the jinja tags is determined
-    by the code they contain.
-    """
-    # for some jinja blocks, we need to reset the state after each branch
-    previous_node = analyzer.previous_node
-    # add the start tag to the buffer
-    add_node_to_buffer(
-        analyzer=analyzer,
-        source_string=source_string,
-        match=match,
-        token_type=TokenType.JINJA_BLOCK_START,
-    )
-
-    # configure the block parser
-    start_rule = analyzer.get_rule(rule_name=start_name)
-    end_rule = analyzer.get_rule(rule_name=end_name)
-    other_rules = [analyzer.get_rule(rule_name=r) for r in other_names]
-    patterns = [start_rule.pattern, end_rule.pattern] + [r.pattern for r in other_rules]
-    program = re.compile(
-        MAYBE_WHITESPACES + group(*patterns), re.IGNORECASE | re.DOTALL
-    )
-
-    while True:
-        # search ahead for the next matching control tag
-        next_tag_match = program.search(source_string, analyzer.pos)
-        if not next_tag_match:
-            raise SqlfmtBracketError(
-                f"Encountered unterminated Jinja block at position"
-                f" {match.span(0)[0]}. Expected end tag: "
-                f"{simplify_jinja_regex(end_rule.pattern)}"
-            )
-        # otherwise, if the tag matches, lex everything up to that token
-        # using the ruleset that was active before jinja
-        next_tag_pos = next_tag_match.span(0)[0]
-        jinja_rules = analyzer.pop_rules()
-        analyzer.stops.append(next_tag_pos)
-        try:
-            analyzer.lex(source_string)
-        except StopIteration:
-            analyzer.stops.pop()
-
-        analyzer.push_rules(jinja_rules)
-        # it is possible for the next_tag_match found above to have already been lexed.
-        # but if it hasn't, we need to process it
-        if analyzer.pos == next_tag_pos:
-            # if this is another start tag, we have nested jinja blocks,
-            # so we recurse a level deeper
-            if start_rule.program.match(source_string, analyzer.pos):
-                try:
-                    handle_jinja_block(
-                        analyzer=analyzer,
-                        source_string=source_string,
-                        match=next_tag_match,
-                        start_name=start_name,
-                        end_name=end_name,
-                        other_names=other_names,
-                    )
-                except StopJinjaLexing:
-                    continue
-            # if this the tag that ends the block, add it to the
-            # buffer
-            elif end_rule.program.match(source_string, analyzer.pos):
-                add_node_to_buffer(
-                    analyzer=analyzer,
-                    source_string=source_string,
-                    match=next_tag_match,
-                    token_type=TokenType.JINJA_BLOCK_END,
-                )
-                if end_reset_sql_depth and analyzer.previous_node:
-                    if previous_node:
-                        analyzer.previous_node.open_brackets = (
-                            previous_node.open_brackets.copy()
-                        )
-                    else:
-                        analyzer.previous_node.open_brackets = []
-                break
-            # otherwise, this is an elif or else statement; we add it to
-            # the buffer, but with the previous node set to the node before
-            # the if statement (to reset the depth)
-            else:
-                add_node_to_buffer(
-                    analyzer=analyzer,
-                    source_string=source_string,
-                    match=next_tag_match,
-                    token_type=TokenType.JINJA_BLOCK_KEYWORD,
-                    previous_node=previous_node,
-                    override_analyzer_prev_node=True,
-                )
-        else:
-            continue
-
-    raise StopJinjaLexing
+def handle_jinja_block_start(
+    analyzer: "Analyzer",
+    source_string: str,
+    match: re.Match,
+) -> None:
+    """
+    Lex tags like {% if ... %} and {% for ... %} that open a jinja block
+    """
+    add_node_to_buffer(
+        analyzer=analyzer,
+        source_string=source_string,
+        match=match,
+        token_type=TokenType.JINJA_BLOCK_START,
+    )
+    raise StopRulesetLexing
+
+
+def handle_jinja_block_keyword(
+    analyzer: "Analyzer",
+    source_string: str,
+    match: re.Match,
+) -> None:
+    """
+    Lex tags like {% elif ... %} and {% else %} that continue an open jinja block
+    """
+    if analyzer.previous_node:
+        try:
+            start_tag = analyzer.previous_node.open_jinja_blocks[-1]
+        except IndexError:
+            # {% if foo %}{% else %} is allowed, but then previous
+            # node won't have any open jinja blocks yet.
+            # when creating the node, we check to make sure these
+            # match
+            start_tag = analyzer.previous_node
+
+        previous_node = start_tag.previous_node
+
+        add_node_to_buffer(
+            analyzer=analyzer,
+            source_string=source_string,
+            match=match,
+            token_type=TokenType.JINJA_BLOCK_KEYWORD,
+            previous_node=previous_node,
+            override_analyzer_prev_node=True,
+        )
+        raise StopRulesetLexing
+
+    else:
+        raise_sqlfmt_bracket_error(analyzer, source_string, match)
+
+
+def handle_jinja_data_block_start(
+    analyzer: "Analyzer",
+    source_string: str,
+    match: re.Match,
+    new_ruleset: Optional[List[Rule]],
+    raises: bool = True,
+) -> None:
+    """
+    Lex tags like {% set foo %} and {% call my_macro %} that open a jinja block
+    that can contain arbitrary data.
+
+    This can get called from the JINJA ruleset, in which case we need to
+    raise an additional StopRulesetLexing after the JINJA_DATA segment
+    is fully lexed.
+    """
+    # add the start tag to the buffer
+    add_node_to_buffer(
+        analyzer=analyzer,
+        source_string=source_string,
+        match=match,
+        token_type=TokenType.JINJA_BLOCK_START,
+    )
+    if new_ruleset is None:
+        new_ruleset = analyzer.rules
+    lex_ruleset(
+        analyzer,
+        source_string,
+        match,
+        new_ruleset=new_ruleset,
+    )
+    if raises:
+        raise StopRulesetLexing
+
+
+def handle_jinja_block_end(
+    analyzer: "Analyzer",
+    source_string: str,
+    match: re.Match,
+    reset_sql_depth: bool = False,
+) -> None:
+    """
+    Lex tags like {% endif %} and {% endfor %} that close an open jinja block
+    """
+    if analyzer.previous_node:
+        try:
+            start_tag = analyzer.previous_node.open_jinja_blocks[-1]
+        except IndexError:
+            # {% if foo %}{% else %} is allowed, but then previous
+            # node won't have any open jinja blocks yet.
+            # when creating the node, we check to make sure these
+            # match
+            start_tag = analyzer.previous_node
+
+        add_node_to_buffer(
+            analyzer=analyzer,
+            source_string=source_string,
+            match=match,
+            token_type=TokenType.JINJA_BLOCK_END,
+        )
+
+        if reset_sql_depth:
+            analyzer.previous_node.open_brackets = start_tag.open_brackets.copy()
+
+        raise StopRulesetLexing
+
+    else:
+        # No open jinja blocks or none that match this token
+        raise_sqlfmt_bracket_error(analyzer, source_string=source_string, match=match)
 
 
 def handle_jinja(
@@ -523,7 +479,7 @@ def handle_jinja(
         end_name=end_name,
         token_type=token_type,
     )
-    raise StopJinjaLexing
+    raise StopRulesetLexing
 
 
 def handle_potentially_nested_tokens(
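
To make the new control flow concrete, here is a minimal, self-contained sketch of the pattern the refactor adopts. These are simplified stand-ins, not sqlfmt's actual classes (the real `Analyzer`, `Rule`, and actions carry far more state): a start-tag action pushes a new ruleset and recurses into `lex()`, and the matching end-tag action raises `StopRulesetLexing`, which `lex_ruleset()` catches to pop back exactly one ruleset. This is what replaces the old `stops`/`StopIteration` scan-ahead machinery.

```python
import re
from typing import Callable, List, NamedTuple


class StopRulesetLexing(Exception):
    """Control-flow signal: finish lexing with the current ruleset."""


class Rule(NamedTuple):
    program: re.Pattern
    action: Callable[["MiniAnalyzer", str, re.Match], None]


class MiniAnalyzer:
    def __init__(self, rules: List[Rule]) -> None:
        self.rules = rules
        self.rule_stack: List[List[Rule]] = []
        self.pos = 0
        self.tokens: List[str] = []

    def lex(self, source: str) -> None:
        # run the first matching rule's action until the input is consumed
        while self.pos < len(source):
            for rule in self.rules:
                match = rule.program.match(source, self.pos)
                if match:
                    rule.action(self, source, match)
                    break

    def lex_ruleset(self, source: str, new_ruleset: List[Rule]) -> None:
        # activate the nested ruleset; an end-tag action raising
        # StopRulesetLexing pops us back to the previous ruleset
        self.rule_stack.append(self.rules)
        self.rules = new_ruleset
        try:
            self.lex(source)
        except StopRulesetLexing:
            self.rules = self.rule_stack.pop()


def start_action(a: "MiniAnalyzer", s: str, m: re.Match) -> None:
    a.tokens.append(m.group())
    a.pos = m.end()
    a.lex_ruleset(s, DATA_RULES)  # nesting = one more push/lex frame


def end_action(a: "MiniAnalyzer", s: str, m: re.Match) -> None:
    a.tokens.append(m.group())
    a.pos = m.end()
    raise StopRulesetLexing  # unwind exactly one level


def data_action(a: "MiniAnalyzer", s: str, m: re.Match) -> None:
    a.pos = m.end()  # swallow arbitrary data, one character at a time


# one ruleset serves as both the top-level and nested ruleset here,
# purely to keep the sketch short
DATA_RULES = [
    Rule(re.compile(r"\{%\s*set\s+\w+\s*%\}"), start_action),
    Rule(re.compile(r"\{%\s*endset\s*%\}"), end_action),
    Rule(re.compile(r"."), data_action),
]

analyzer = MiniAnalyzer(DATA_RULES)
analyzer.lex("{% set a %}x{% set b %}y{% endset %}z{% endset %}")
print(analyzer.tokens)
# ['{% set a %}', '{% set b %}', '{% endset %}', '{% endset %}']
```

Because the terminator is recognized by a rule in the active ruleset rather than found by searching ahead from the start tag, an inner end tag can no longer be mistaken for the terminator of an outer block, which is what makes arbitrary nesting work.
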
4 changes: 0 additions & 4 deletions src/sqlfmt/analyzer.py

@@ -27,7 +27,6 @@ class Analyzer:
     comment_buffer: List[Comment] = field(default_factory=list)
     line_buffer: List[Line] = field(default_factory=list)
     rule_stack: List[List[Rule]] = field(default_factory=list)
-    stops: List[int] = field(default_factory=list)
     pos: int = 0
 
     @property
@@ -125,9 +124,6 @@ def lex_one(self, source_string: str) -> None:
 
         Mutates the analyzer's buffers and pos
         """
-        if self.stops and self.pos >= self.stops[-1]:
-            raise StopIteration
-
        for rule in self.rules:
             match = rule.program.match(source_string, self.pos)
             if match:
11 changes: 1 addition & 10 deletions src/sqlfmt/exception.py

@@ -64,19 +64,10 @@ class InlineCommentException(SqlfmtControlFlowException):
     pass
 
 
-class StopJinjaLexing(SqlfmtControlFlowException):
-    """
-    Raised by the Analyzer or one of its actions to indicate
-    that further lexing should use the main ruleset
-    """
-
-    pass
-
-
 class StopRulesetLexing(SqlfmtControlFlowException):
     """
     Raised by the Analyzer or one of its actions to indicate
-    that further lexing should use the main ruleset
+    that further lexing should use the previous ruleset
     """
 
     pass