Skip to content

Commit

Permalink
fix/#343/splitter recursion depth (#346)
Browse files Browse the repository at this point in the history
* fix #343: refactor splitter from recursion to iteration

* refactor: replace if previous_node with if previous_node is not None; eliminates ~50% of calls to node.__len__ and 15% of calls to node.__str__

* fix: add test for splitter edge case with no trailing newline
  • Loading branch information
tconbeer authored Jan 5, 2023
1 parent edcf372 commit 49d11af
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 44 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file.

## [Unreleased]

- fixed a bug where very long lines could raise `RecursionError` ([#343](https://github.com/tconbeer/sqlfmt/issues/343) - thank you [@kcem-flyr](https://github.com/kcem-flyr)!).

## [0.14.2] - 2022-12-12

### Formatting Changes + Bug Fixes
Expand Down
6 changes: 3 additions & 3 deletions src/sqlfmt/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def handle_number(analyzer: "Analyzer", source_string: str, match: re.Match) ->
node.
"""
first_char = source_string[match.span(1)[0] : match.span(1)[0] + 1]
if first_char in ["+", "-"] and analyzer.previous_node:
if first_char in ["+", "-"] and analyzer.previous_node is not None:
prev_token, _ = get_previous_token(analyzer.previous_node)
if prev_token and prev_token.type in (
TokenType.NUMBER,
Expand Down Expand Up @@ -364,7 +364,7 @@ def handle_jinja_block_keyword(
"""
Lex tags like {% elif ... %} and {% else %} that continue an open jinja block
"""
if analyzer.previous_node:
if analyzer.previous_node is not None:
try:
start_tag = analyzer.previous_node.open_jinja_blocks[-1]
except IndexError:
Expand Down Expand Up @@ -432,7 +432,7 @@ def handle_jinja_block_end(
"""
Lex tags like {% endif %} and {% endfor %} that close an open jinja block
"""
if analyzer.previous_node:
if analyzer.previous_node is not None:
try:
start_tag = analyzer.previous_node.open_jinja_blocks[-1]
except IndexError:
Expand Down
5 changes: 4 additions & 1 deletion src/sqlfmt/jinjafmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,10 @@ def format_line(self, line: Line) -> List[Line]:
chain(
*[
self.format_line(new_line)
for new_line in splitter.split_at_index(line, i)
for new_line in [
splitter.split_at_index(line, 0, i, line.comments),
splitter.split_at_index(line, i, -1, []),
]
]
)
)
Expand Down
16 changes: 10 additions & 6 deletions src/sqlfmt/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def open_brackets(self) -> List[Node]:
"""
if self.nodes:
return self.nodes[0].open_brackets
elif self.previous_node:
elif self.previous_node is not None:
return self.previous_node.open_brackets
else:
return []
Expand All @@ -63,7 +63,7 @@ def open_jinja_blocks(self) -> List[Node]:
"""
if self.nodes:
return self.nodes[0].open_jinja_blocks
elif self.previous_node:
elif self.previous_node is not None:
return self.previous_node.open_jinja_blocks
else:
return []
Expand Down Expand Up @@ -147,7 +147,7 @@ def from_nodes(
nodes=nodes,
comments=comments,
formatting_disabled=previous_node.formatting_disabled
if previous_node
if previous_node is not None
else [],
)

Expand Down Expand Up @@ -258,7 +258,11 @@ def closes_bracket_from_previous_line(self) -> bool:
that matches a bracket on a preceding line. False for unterminated
keywords or any lines with matched brackets
"""
if self.previous_node and self.previous_node.open_brackets and self.nodes:
if (
self.previous_node is not None
and self.previous_node.open_brackets
and self.nodes
):
explicit_brackets = [
b for b in self.previous_node.open_brackets if b.is_opening_bracket
]
Expand All @@ -276,7 +280,7 @@ def previous_line_has_open_jinja_blocks_not_keywords(self) -> bool:
after a jinja block keyword, like {% else %}/{% elif %}
"""
if (
self.previous_node
self.previous_node is not None
and self.previous_node.open_jinja_blocks
and not self.previous_node.open_jinja_blocks[-1].is_jinja_block_keyword
):
Expand All @@ -292,7 +296,7 @@ def closes_jinja_block_from_previous_line(self) -> bool:
"""
if (
self.nodes
and self.previous_node
and self.previous_node is not None
and self.previous_node.open_jinja_blocks
and (
self.previous_node.open_jinja_blocks[-1]
Expand Down
4 changes: 3 additions & 1 deletion src/sqlfmt/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def has_preceding_between_operator(self) -> bool:
"""
True if this node has a preceding "between" operator at the same depth
"""
prev = self.previous_node.previous_node if self.previous_node else None
prev = (
self.previous_node.previous_node if self.previous_node is not None else None
)
while prev and prev.depth >= self.depth:
if prev.depth == self.depth and prev.is_the_between_operator:
return True
Expand Down
8 changes: 5 additions & 3 deletions src/sqlfmt/node_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,17 @@ def disable_formatting(
the token and previous node.
"""
formatting_disabled = (
previous_node.formatting_disabled.copy() if previous_node else []
previous_node.formatting_disabled.copy()
if previous_node is not None
else []
)

if token.type in (TokenType.FMT_OFF, TokenType.DATA):
formatting_disabled.append(token)

if (
formatting_disabled
and previous_node
and previous_node is not None
and previous_node.token.type
in (
TokenType.FMT_ON,
Expand Down Expand Up @@ -290,7 +292,7 @@ def append_newline(self, line: Line) -> None:
if line.nodes:
previous_node = line.nodes[-1]
previous_token = line.nodes[-1].token
elif line.previous_node:
elif line.previous_node is not None:
previous_node = line.previous_node
previous_token = line.previous_node.token

Expand Down
72 changes: 43 additions & 29 deletions src/sqlfmt/splitter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Iterator, Tuple
from typing import List, Tuple

from sqlfmt.comment import Comment
from sqlfmt.line import Line
from sqlfmt.node import Node
from sqlfmt.node_manager import NodeManager
Expand All @@ -10,33 +11,48 @@
class LineSplitter:
node_manager: NodeManager

def maybe_split(self, line: Line) -> Iterator[Line]:
def maybe_split(self, line: Line) -> List[Line]:
"""
Evaluates a line for splitting. If line matches criteria for splitting,
yields new lines; otherwise yields original line
returns a list of new lines; otherwise returns a list of only the original line.
We used to do this recursively, but very long lines (with >500 splits) would
raise RecursionError.
"""

if line.formatting_disabled:
yield line
return
return [line]

new_lines: List[Line] = []
comments = line.comments
head = 0
always_split_after = never_split_after = False
for i, node in enumerate(line.nodes):
if node.is_newline:
# can't split just before a newline
yield line
break
if head == 0:
new_lines.append(line)
else:
new_lines.append(self.split_at_index(line, head, i, comments))
return new_lines
elif (
i > 0
i > head
and not never_split_after
and not node.formatting_disabled
and (always_split_after or self.maybe_split_before(node))
):
yield from self.split_at_index(line, i)
break
new_line = self.split_at_index(line, head, i, comments)
new_lines.append(new_line)
comments = [] # only first split gets original comments
head = i
# node now follows a new newline node, so we need to update
# its previous node (this can impact its depth)
node.previous_node = new_line.nodes[-1]

always_split_after, never_split_after = self.maybe_split_after(node)

new_lines.append(self.split_at_index(line, head, -1, comments))
return new_lines

def maybe_split_before(self, node: Node) -> bool:
"""
Return True if we should split before node
Expand Down Expand Up @@ -101,26 +117,24 @@ def maybe_split_after(self, node: Node) -> Tuple[bool, bool]:
else:
return False, False

def split_at_index(self, line: Line, index: int) -> Iterator[Line]:
def split_at_index(
self, line: Line, head: int, index: int, comments: List[Comment]
) -> Line:
"""
Split a line before nodes[index]. Recursively maybe_split
resulting lines. Yields new lines
Return a new line comprised of the nodes line[head:index], plus a newline node
"""
assert index > 0, "Cannot split at start of line!"
head, tail = line.nodes[:index], line.nodes[index:]
assert head[0] is not None, "Cannot split at start of line!"
if index == -1:
new_nodes = line.nodes[head:]
else:
assert index > head, "Cannot split at start of line!"
new_nodes = line.nodes[head:index]

head_line = Line.from_nodes(
previous_node=line.previous_node,
nodes=head,
comments=line.comments,
new_line = Line.from_nodes(
previous_node=new_nodes[0].previous_node,
nodes=new_nodes,
comments=comments,
)
self.node_manager.append_newline(head_line)
yield head_line
if not new_line.nodes[-1].is_newline:
self.node_manager.append_newline(new_line)

tail_line = Line.from_nodes(
previous_node=head_line.nodes[-1],
nodes=tail,
comments=[],
)
yield from self.maybe_split(tail_line)
return new_line
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- source: https://github.com/tconbeer/sqlfmt/issues/343
WITH associated_selector AS (SELECT DISTINCT recent.int_flight_id AS anon_1 FROM net_leg_metrics AS recent LEFT OUTER JOIN net_leg_metrics AS xdays ON recent.int_flight_id = xdays.int_flight_id AND xdays.scenario_date = (('2022-12-22')::DATE - interval '7 days')::DATE AND true LEFT OUTER JOIN (SELECT flyr_cabin_class_hierarchy.cabin_class AS cabin_class, flyr_cabin_class_hierarchy.hierarchy AS hierarchy FROM (VALUES ('Y', 1), ('W', 2), ('J', 3), ('F', 4)) AS flyr_cabin_class_hierarchy (cabin_class, hierarchy)) AS flyr_cabin_class_hierarchy ON recent.cabin_class = flyr_cabin_class_hierarchy.cabin_class WHERE recent.client_cabin_class IN ('Y') AND recent.scenario_date = ('2022-12-22')::DATE), recent_leg_od AS (SELECT mappings_table.int_od_id AS int_od_id, sum(mappings_table.bookings_leg_od) AS bookings_leg_od, sum(mappings_table.revenue_leg_od) AS revenue_leg_od FROM net_od_leg_mapping AS mappings_table WHERE mappings_table.scenario_date = ('2022-12-22')::DATE AND mappings_table.int_flight_id IN (SELECT associated_selector.anon_1 FROM associated_selector) GROUP BY mappings_table.int_od_id), xdays_leg_od AS (SELECT mappings_table.int_od_id AS int_od_id, sum(mappings_table.bookings_leg_od) AS bookings_leg_od, sum(mappings_table.revenue_leg_od) AS revenue_leg_od FROM net_od_leg_mapping AS mappings_table WHERE mappings_table.scenario_date = (('2022-12-22')::DATE - interval '7 days')::DATE AND mappings_table.int_flight_id IN (SELECT associated_selector.anon_1 FROM associated_selector) GROUP BY mappings_table.int_od_id) SELECT round(CAST((round(CAST(sum(xdays_leg_od.bookings_leg_od) AS NUMERIC), 0) / nullif(round(CAST(sum(xdays.bookings) AS NUMERIC), 0), 0)) * 100 AS NUMERIC), 0) AS x_day_leg_od_bookings_share_percent_to_x_day_bookings, mode() WITHIN GROUP (ORDER BY recent.comp_fare_cirrus_match_airline DESC) AS comp_fare_cirrus_match_airline, round(CAST(round(CAST(sum(recent.bookings) AS NUMERIC), 0) - round(CAST(sum(recent.bookings_baseline) AS NUMERIC), 0) AS NUMERIC), 0) AS net_bookings_diff_to_baseline, CASE WHEN (sum(xdays_leg_od.bookings_leg_od) = 0) THEN 0 ELSE round(CAST(sum(xdays_leg_od.revenue_leg_od) / sum(xdays_leg_od.bookings_leg_od) AS NUMERIC), 0) END AS x_day_leg_od_revenue_per_booking, round(CAST(sum(recent_leg_od.revenue_leg_od) - sum(xdays_leg_od.revenue_leg_od) AS NUMERIC), 0) AS x_day_leg_od_revenue_build, round(CAST(sum(recent.bookings) AS NUMERIC), 0) AS net_bookings, round(CAST(avg(recent.wtp_lac_price) AS NUMERIC), 0) AS wtp_fare, round(CAST(((sum(recent.final_revenue_expected) * 100) / nullif(sum(recent.final_revenue_baseline), 0) - 100) - ((sum(xdays.final_revenue_expected) * 100) / nullif(sum(xdays.final_revenue_baseline), 0) - 100) AS NUMERIC), 1) AS x_day_final_revenue_expected_build_percent_diff_to_baseline, round(CAST(sum(recent_leg_od.revenue_leg_od) AS NUMERIC), 0) AS leg_od_revenue, round(CAST(avg(recent.target_price_frm) AS NUMERIC), 0) AS target_fare_frm, round(CAST(CASE WHEN (sum(recent_leg_od.bookings_leg_od) = 0) THEN 0 ELSE sum(recent_leg_od.revenue_leg_od) / sum(recent_leg_od.bookings_leg_od) END - CASE WHEN (sum(xdays_leg_od.bookings_leg_od) = 0) THEN 0 ELSE sum(xdays_leg_od.revenue_leg_od) / sum(xdays_leg_od.bookings_leg_od) END AS NUMERIC), 0) AS x_day_leg_od_revenue_per_booking_build, sum(CASE WHEN (recent.target_price != recent.target_price_frm AND recent.wtp_lac_price != recent.target_price_frm) THEN 1 ELSE 0 END) AS number_of_impacted_subjects, round(CAST(sum(recent.revenue) AS NUMERIC), 0) AS revenue, round(CAST(avg(recent.target_price) AS NUMERIC), 0) AS target_fare, round(CAST(avg(recent.lowest_vff * fx_rates_for_lowest_vff.exchange_rate) AS NUMERIC), 0) AS lowest_vff, CASE WHEN (sum(recent_leg_od.bookings_leg_od) = 0) THEN 0 ELSE round(CAST(sum(recent_leg_od.revenue_leg_od) / sum(recent_leg_od.bookings_leg_od) AS NUMERIC), 0) END AS leg_od_revenue_per_booking, round(CAST(avg(recent.comp_fare_cirrus_match_fare * fx_rates_comp_fare_cirrus_match_currency.exchange_rate) AS NUMERIC), 0) AS comp_fare_cirrus_match_fare, max(recent.rt_market) AS rt_market, round(CAST(sum(xdays_leg_od.revenue_leg_od) AS NUMERIC), 0) AS x_day_leg_od_revenue, round(CAST(sum(recent.final_revenue_expected) AS NUMERIC), 0) AS final_revenue_expected, round(CAST(sum(recent_leg_od.bookings_leg_od) AS NUMERIC), 0) AS leg_od_bookings, round(CAST(sum(recent.revenue) - sum(recent.revenue_baseline) AS NUMERIC), 0) AS diff_to_erb, round(CAST(sum(recent.final_bookings_expected) AS NUMERIC), 0) AS final_net_bookings_expected, round(CAST(avg(recent.wtp_lac_price_frm) AS NUMERIC), 0) AS wtp_fare_frm, round(CAST((sum(xdays_leg_od.revenue_leg_od) / nullif(sum(xdays.revenue), 0)) * 100 AS NUMERIC), 0) AS x_day_leg_od_revenue_share_percent_to_x_day_revenue, round(CAST(avg(recent.wtp_lac_rank) AS NUMERIC), 1) AS wtp_lac_rank, round(CAST(round(CAST(sum(recent.final_bookings_expected) AS NUMERIC), 0) - round(CAST(sum(xdays.final_bookings_expected) AS NUMERIC), 0) AS NUMERIC), 0) AS x_day_final_net_bookings_expected_build, round(CAST((round(CAST(sum(recent_leg_od.bookings_leg_od) AS NUMERIC), 0) / nullif(round(CAST(sum(recent.bookings) AS NUMERIC), 0), 0)) * 100 AS NUMERIC), 0) AS leg_od_bookings_share_percent_to_bookings, round(CAST(round(CAST(sum(recent_leg_od.bookings_leg_od) AS NUMERIC), 0) - round(CAST(sum(xdays_leg_od.bookings_leg_od) AS NUMERIC), 0) AS NUMERIC), 0) AS x_day_leg_od_bookings_build, max(recent.client_cabin_class) AS cabin_class, round(CAST((sum(recent.final_revenue_expected) * 100) / nullif(sum(recent.final_revenue_baseline), 0) - 100 AS NUMERIC), 1) AS final_revenue_expected_percent_diff_to_baseline, count(*) AS number_of_subjects, round(CAST((sum(recent_leg_od.revenue_leg_od) / nullif(sum(recent.revenue), 0)) * 100 AS NUMERIC), 0) AS leg_od_revenue_share_percent_to_revenue, round(CAST(avg(CASE WHEN (recent.lowest_vff * fx_rates_for_lowest_vff.exchange_rate = recent.wtp_lac_price) THEN 100 ELSE 0 END) AS NUMERIC), 1) AS lowest_vff_share, round(CAST(avg(recent.wtp_lac_price) - avg(recent.lowest_vff * fx_rates_for_lowest_vff.exchange_rate) AS NUMERIC), 0) AS wtp_fare_diff_to_lowest_vff, round(CAST(sum(xdays_leg_od.bookings_leg_od) AS NUMERIC), 0) AS x_day_leg_od_bookings FROM net_od_metrics AS recent LEFT OUTER JOIN net_od_metrics AS xdays ON recent.int_od_id = xdays.int_od_id AND xdays.scenario_date = (('2022-12-22')::DATE - interval '7 days')::DATE AND true JOIN recent_leg_od ON recent_leg_od.int_od_id = recent.int_od_id LEFT OUTER JOIN xdays_leg_od ON xdays_leg_od.int_od_id = xdays.int_od_id LEFT OUTER JOIN (SELECT flyr_cabin_class_hierarchy.cabin_class AS cabin_class, flyr_cabin_class_hierarchy.hierarchy AS hierarchy FROM (VALUES ('Y', 1), ('W', 2), ('J', 3), ('F', 4)) AS flyr_cabin_class_hierarchy (cabin_class, hierarchy)) AS flyr_cabin_class_hierarchy ON recent.cabin_class = flyr_cabin_class_hierarchy.cabin_class LEFT OUTER JOIN fx_rates AS fx_rates_comp_fare_cirrus_match_currency ON fx_rates_comp_fare_cirrus_match_currency.from_currency = recent.comp_fare_cirrus_match_currency AND fx_rates_comp_fare_cirrus_match_currency.to_currency = 'USD' LEFT OUTER JOIN fx_rates AS fx_rates_for_lowest_vff ON fx_rates_for_lowest_vff.from_currency = recent.currency AND fx_rates_for_lowest_vff.to_currency = 'USD' WHERE true AND recent.scenario_date = ('2022-12-22')::DATE GROUP BY recent.rt_market, recent.client_cabin_class HAVING true ORDER BY round(CAST(sum(recent_leg_od.revenue_leg_od) AS NUMERIC), 0) DESC NULLS LAST, max(recent.dep_date) ASC, max(recent.origin) ASC, max(recent.destination) ASC, max(recent.dep_time) ASC, max(flyr_cabin_class_hierarchy.hierarchy) ASC LIMIT 25 OFFSET 0
2 changes: 1 addition & 1 deletion tests/unit_tests/test_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_simple_append_newline(simple_line: Line, node_manager: NodeManager) ->
# this line already ends with a newline
last_node = simple_line.nodes[-1]
assert last_node.token.type is TokenType.NEWLINE
assert last_node.previous_node
assert last_node.previous_node is not None
assert last_node.previous_node.token.type is not TokenType.NEWLINE

node_manager.append_newline(simple_line)
Expand Down
Loading

0 comments on commit 49d11af

Please sign in to comment.