Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Nested Set and Call Blocks #340

Merged
merged 8 commits into from
Dec 12, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor: handle jinja block keywords with new action
  • Loading branch information
tconbeer committed Dec 11, 2022
commit ecea382dfbd461f31b46b26e65ac1d0dc0ac6420
182 changes: 26 additions & 156 deletions src/sqlfmt/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,6 @@ def group(*choices: str) -> str:
return f"({'|'.join(choices)})"


def simplify_jinja_regex(pattern: str) -> str:
    """
    Return a human-readable token from a regex pattern.

    Applies a fixed series of literal substring substitutions, in order:
    escaped curly braces are unescaped, the "-?" fragment is removed, and
    every "\\s*" fragment becomes a single space. Anything else in the
    pattern is returned unchanged.
    """
    # dicts preserve insertion order, so the substitutions run in the
    # order they are written here
    substitutions = {
        "\\{": "{",
        "\\}": "}",
        "-?": "",
        "\\s*": " ",
    }
    simplified = pattern
    for regex_fragment, readable in substitutions.items():
        simplified = simplified.replace(regex_fragment, readable)
    return simplified


def raise_sqlfmt_bracket_error(
_: "Analyzer", source_string: str, match: re.Match
) -> None:
Expand Down Expand Up @@ -352,167 +339,49 @@ def lex_ruleset(
analyzer.pop_rules()


def handle_jinja_data_block(
    analyzer: "Analyzer",
    source_string: str,
    match: re.Match,
    end_rule_name: str,
) -> None:
    """
    Lex a whole jinja data block, like {% set my_var %}data{% endset %},
    as a single DATA token. The contents between the two tags need not be
    sql or python, and should not be formatted.

    The end tag is located by searching source_string, from the analyzer's
    current position, with the rule named end_rule_name. The resulting
    node is appended to the analyzer's node buffer and the analyzer's
    position is moved to the end of the end tag.

    Raises SqlfmtBracketError if no end tag is found; otherwise always
    raises StopRulesetLexing once the node has been buffered.
    """
    closing_rule = analyzer.get_rule(rule_name=end_rule_name)
    closing_match = closing_rule.program.search(source_string, pos=analyzer.pos)

    # group 1 of the start-tag match marks where the block begins
    block_start, start_tag_end = match.span(1)

    if closing_match is None:
        start_tag_text = source_string[block_start:start_tag_end]
        raise SqlfmtBracketError(
            f"Encountered unterminated Jinja set or call block '{start_tag_text}' at "
            f"position {block_start}. Expected end tag: "
            f"{simplify_jinja_regex(closing_rule.pattern)}"
        )

    # the token runs from the start of the start tag through the end of
    # the end tag, inclusive of both tags
    block_end = closing_match.span(1)[1]
    node = analyzer.node_manager.create_node(
        token=Token(
            type=TokenType.DATA,
            prefix=source_string[analyzer.pos : block_start],
            token=source_string[block_start:block_end],
            spos=block_start,
            epos=block_end,
        ),
        previous_node=analyzer.previous_node,
    )
    analyzer.node_buffer.append(node)
    analyzer.pos = block_end
    raise StopRulesetLexing


def handle_jinja_block(
def handle_jinja_block_start(
analyzer: "Analyzer",
source_string: str,
match: re.Match,
start_name: str,
end_name: str,
other_names: List[str],
end_reset_sql_depth: bool = False,
) -> None:
"""
An if block, like {% if cond %}code{% else %}other_code{% endif %}
needs special handling, since the depth of the jinja tags is determined
by the code they contain.
Lex tags like {% if ... %} and {% for ... %} that open a jinja block
"""
# for some jinja blocks, we need to reset the state after each branch
previous_node = analyzer.previous_node
# add the start tag to the buffer
add_node_to_buffer(
analyzer=analyzer,
source_string=source_string,
match=match,
token_type=TokenType.JINJA_BLOCK_START,
)

# configure the block parser
start_rule = analyzer.get_rule(rule_name=start_name)
end_rule = analyzer.get_rule(rule_name=end_name)
other_rules = [analyzer.get_rule(rule_name=r) for r in other_names]
patterns = [start_rule.pattern, end_rule.pattern] + [r.pattern for r in other_rules]
program = re.compile(
MAYBE_WHITESPACES + group(*patterns), re.IGNORECASE | re.DOTALL
)

while True:
# search ahead for the next matching control tag
next_tag_match = program.search(source_string, analyzer.pos)
if not next_tag_match:

raise SqlfmtBracketError(
f"Encountered unterminated Jinja block at position"
f" {match.span(0)[0]}. Expected end tag: "
f"{simplify_jinja_regex(end_rule.pattern)}"
)
# otherwise, if the tag matches, lex everything up to that token
# using the ruleset that was active before jinja
next_tag_pos = next_tag_match.span(0)[0]
jinja_rules = analyzer.pop_rules()
analyzer.stops.append(next_tag_pos)
try:
analyzer.lex(source_string)
except StopIteration:
analyzer.stops.pop()

analyzer.push_rules(jinja_rules)
# it is possible for the next_tag_match found above to have already been lexed.
# but if it hasn't, we need to process it
if analyzer.pos == next_tag_pos:
# if this is another start tag, we have nested jinja blocks,
# so we recurse a level deeper
if start_rule.program.match(source_string, analyzer.pos):
try:
handle_jinja_block(
analyzer=analyzer,
source_string=source_string,
match=next_tag_match,
start_name=start_name,
end_name=end_name,
other_names=other_names,
)
except StopRulesetLexing:
continue
# if this is the tag that ends the block, add it to the
# buffer
elif end_rule.program.match(source_string, analyzer.pos):
add_node_to_buffer(
analyzer=analyzer,
source_string=source_string,
match=next_tag_match,
token_type=TokenType.JINJA_BLOCK_END,
)
if end_reset_sql_depth and analyzer.previous_node:
if previous_node:
analyzer.previous_node.open_brackets = (
previous_node.open_brackets.copy()
)
else:
analyzer.previous_node.open_brackets = []
break
# otherwise, this is an elif or else statement; we add it to
# the buffer, but with the previous node set to the node before
# the if statement (to reset the depth)
else:
add_node_to_buffer(
analyzer=analyzer,
source_string=source_string,
match=next_tag_match,
token_type=TokenType.JINJA_BLOCK_KEYWORD,
previous_node=previous_node,
override_analyzer_prev_node=True,
)
else:
continue

raise StopRulesetLexing


def handle_jinja_block_start(
def handle_jinja_block_keyword(
analyzer: "Analyzer",
source_string: str,
match: re.Match,
start_rule_names: List[str],
) -> None:
"""
Lex tags like {% if ... %} and {% for ... %} that open a jinja block
Lex tags like {% elif ... %} and {% else %} that continue an open jinja block
"""
add_node_to_buffer(
analyzer=analyzer,
source_string=source_string,
match=match,
token_type=TokenType.JINJA_BLOCK_START,
)
raise StopRulesetLexing
if analyzer.previous_node and analyzer.previous_node.open_jinja_blocks:
start_tag = analyzer.previous_node.open_jinja_blocks[-1]
previous_node = start_tag.previous_node
start_rules = [analyzer.get_rule(name) for name in start_rule_names]
matches = [rule.program.match(start_tag.value) for rule in start_rules]
if any(matches):
add_node_to_buffer(
analyzer=analyzer,
source_string=source_string,
match=match,
token_type=TokenType.JINJA_BLOCK_KEYWORD,
previous_node=previous_node,
override_analyzer_prev_node=True,
)
raise StopRulesetLexing

raise_sqlfmt_bracket_error(analyzer, source_string, match)


def handle_jinja_data_block_start(
Expand Down Expand Up @@ -548,16 +417,17 @@ def handle_jinja_block_end(
analyzer: "Analyzer",
source_string: str,
match: re.Match,
start_rule_name: str,
start_rule_names: List[str],
reset_sql_depth: bool = False,
) -> None:
"""
Lex tags like {% endif %} and {% endfor %} that close an open jinja block
"""
if analyzer.previous_node and analyzer.previous_node.open_jinja_blocks:
start_tag = analyzer.previous_node.open_jinja_blocks[-1]
start_rule = analyzer.get_rule(start_rule_name)
if start_rule.program.match(start_tag.value):
start_rules = [analyzer.get_rule(name) for name in start_rule_names]
matches = [rule.program.match(start_tag.value) for rule in start_rules]
if any(matches):
add_node_to_buffer(
analyzer=analyzer,
source_string=source_string,
Expand Down
53 changes: 33 additions & 20 deletions src/sqlfmt/rules/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
pattern=group(JINJA_SHARED_PATTERNS["endset"]),
action=partial(
actions.handle_jinja_block_end,
start_rule_name="jinja_set_block_start",
start_rule_names=["jinja_set_block_start"],
reset_sql_depth=True,
),
),
Expand All @@ -49,7 +49,7 @@
pattern=group(JINJA_SHARED_PATTERNS["endcall"]),
action=partial(
actions.handle_jinja_block_end,
start_rule_name="jinja_call_block_start",
start_rule_names=["jinja_call_block_start"],
reset_sql_depth=True,
),
),
Expand Down Expand Up @@ -93,33 +93,46 @@
name="jinja_if_block_start",
priority=200,
pattern=group(r"\{%-?\s*if.*?-?%\}"),
action=partial(
actions.handle_jinja_block,
start_name="jinja_if_block_start",
end_name="jinja_if_block_end",
other_names=[
"jinja_elif_block_start",
"jinja_else_block_start",
],
),
action=actions.handle_jinja_block_start,
),
Rule(
name="jinja_elif_block_start",
priority=201,
pattern=group(r"\{%-?\s*elif\s+\w+.*?-?%\}"),
action=actions.raise_sqlfmt_bracket_error,
action=partial(
actions.handle_jinja_block_keyword,
start_rule_names=[
"jinja_if_block_start",
"jinja_elif_block_start",
"jinja_else_block_start",
],
),
),
Rule(
name="jinja_else_block_start",
priority=202,
pattern=group(r"\{%-?\s*else\s*-?%\}"),
action=actions.raise_sqlfmt_bracket_error,
action=partial(
actions.handle_jinja_block_keyword,
start_rule_names=[
"jinja_if_block_start",
"jinja_elif_block_start",
"jinja_else_block_start",
],
),
),
Rule(
name="jinja_if_block_end",
priority=203,
pattern=group(r"\{%-?\s*endif\s*-?%\}"),
action=actions.raise_sqlfmt_bracket_error,
action=partial(
actions.handle_jinja_block_end,
start_rule_names=[
"jinja_if_block_start",
"jinja_elif_block_start",
"jinja_else_block_start",
],
),
),
Rule(
name="jinja_for_block_start",
Expand All @@ -132,7 +145,7 @@
priority=211,
pattern=group(r"\{%-?\s*endfor\s*-?%\}"),
action=partial(
actions.handle_jinja_block_end, start_rule_name="jinja_for_block_start"
actions.handle_jinja_block_end, start_rule_names=["jinja_for_block_start"]
),
),
Rule(
Expand All @@ -147,7 +160,7 @@
pattern=group(r"\{%-?\s*endmacro\s*-?%\}"),
action=partial(
actions.handle_jinja_block_end,
start_rule_name="jinja_macro_block_start",
start_rule_names=["jinja_macro_block_start"],
reset_sql_depth=True,
),
),
Expand All @@ -163,7 +176,7 @@
pattern=group(r"\{%-?\s*endtest\s*-?%\}"),
action=partial(
actions.handle_jinja_block_end,
start_rule_name="jinja_test_block_start",
start_rule_names=["jinja_test_block_start"],
reset_sql_depth=True,
),
),
Expand All @@ -179,7 +192,7 @@
pattern=group(r"\{%-?\s*endsnapshot\s*-?%\}"),
action=partial(
actions.handle_jinja_block_end,
start_rule_name="jinja_snapshot_block_start",
start_rule_names=["jinja_snapshot_block_start"],
reset_sql_depth=True,
),
),
Expand All @@ -195,7 +208,7 @@
pattern=group(r"\{%-?\s*endmaterialization\s*-?%\}"),
action=partial(
actions.handle_jinja_block_end,
start_rule_name="jinja_materialization_block_start",
start_rule_names=["jinja_materialization_block_start"],
reset_sql_depth=True,
),
),
Expand Down Expand Up @@ -226,7 +239,7 @@
pattern=group(JINJA_SHARED_PATTERNS["endcall"]),
action=partial(
actions.handle_jinja_block_end,
start_rule_name="jinja_call_block_start",
start_rule_names=["jinja_call_statement_block_start"],
reset_sql_depth=True,
),
),
Expand Down
Loading