Skip to content

Commit af9b82e

Browse files
committed
Reorder grouping code and func call order
Remove repeated for-each/for grouping
1 parent 56b28dc commit af9b82e

File tree

1 file changed

+66
-67
lines changed

1 file changed

+66
-67
lines changed

sqlparse/engine/grouping.py

Lines changed: 66 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -9,41 +9,11 @@
99
from sqlparse import tokens as T
1010
from sqlparse.utils import recurse, imt
1111

12-
M_ROLE = (T.Keyword, ('null', 'role'))
13-
M_SEMICOLON = (T.Punctuation, ';')
14-
M_COMMA = (T.Punctuation, ',')
15-
1612
T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float)
1713
T_STRING = (T.String, T.String.Single, T.String.Symbol)
1814
T_NAME = (T.Name, T.Name.Placeholder)
1915

2016

21-
def _group_left_right(tlist, m, cls,
22-
valid_left=lambda t: t is not None,
23-
valid_right=lambda t: t is not None,
24-
semicolon=False):
25-
"""Groups together tokens that are joined by a middle token. ie. x < y"""
26-
for token in list(tlist):
27-
if token.is_group() and not isinstance(token, cls):
28-
_group_left_right(token, m, cls, valid_left, valid_right,
29-
semicolon)
30-
continue
31-
if not token.match(*m):
32-
continue
33-
34-
tidx = tlist.token_index(token)
35-
pidx, prev_ = tlist.token_prev(tidx)
36-
nidx, next_ = tlist.token_next(tidx)
37-
38-
if valid_left(prev_) and valid_right(next_):
39-
if semicolon:
40-
# only overwrite if a semicolon present.
41-
snidx, _ = tlist.token_next_by(m=M_SEMICOLON, idx=nidx)
42-
nidx = snidx or nidx
43-
# Luckily, this leaves the position of `token` intact.
44-
tlist.group_tokens(cls, pidx, nidx, extend=True)
45-
46-
4717
def _group_matching(tlist, cls):
4818
"""Groups Tokens that have beginning and end."""
4919
opens = []
@@ -69,6 +39,18 @@ def _group_matching(tlist, cls):
6939
tlist.group_tokens(cls, oidx, cidx)
7040

7141

42+
def group_brackets(tlist):
43+
_group_matching(tlist, sql.SquareBrackets)
44+
45+
46+
def group_parenthesis(tlist):
47+
_group_matching(tlist, sql.Parenthesis)
48+
49+
50+
def group_case(tlist):
51+
_group_matching(tlist, sql.Case)
52+
53+
7254
def group_if(tlist):
7355
_group_matching(tlist, sql.If)
7456

@@ -77,16 +59,54 @@ def group_for(tlist):
7759
_group_matching(tlist, sql.For)
7860

7961

80-
def group_foreach(tlist):
81-
_group_matching(tlist, sql.For)
82-
83-
8462
def group_begin(tlist):
8563
_group_matching(tlist, sql.Begin)
8664

8765

66+
def _group_left_right(tlist, m, cls,
67+
valid_left=lambda t: t is not None,
68+
valid_right=lambda t: t is not None,
69+
semicolon=False):
70+
"""Groups together tokens that are joined by a middle token. ie. x < y"""
71+
for token in list(tlist):
72+
if token.is_group() and not isinstance(token, cls):
73+
_group_left_right(token, m, cls, valid_left, valid_right,
74+
semicolon)
75+
continue
76+
if not token.match(*m):
77+
continue
78+
79+
tidx = tlist.token_index(token)
80+
pidx, prev_ = tlist.token_prev(tidx)
81+
nidx, next_ = tlist.token_next(tidx)
82+
83+
if valid_left(prev_) and valid_right(next_):
84+
if semicolon:
85+
# only overwrite if a semicolon present.
86+
m_semicolon = T.Punctuation, ';'
87+
snidx, _ = tlist.token_next_by(m=m_semicolon, idx=nidx)
88+
nidx = snidx or nidx
89+
# Luckily, this leaves the position of `token` intact.
90+
tlist.group_tokens(cls, pidx, nidx, extend=True)
91+
92+
93+
def group_typecasts(tlist):
94+
_group_left_right(tlist, (T.Punctuation, '::'), sql.Identifier)
95+
96+
97+
def group_period(tlist):
98+
lfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Identifier),
99+
t=(T.Name, T.String.Symbol,))
100+
101+
rfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Function),
102+
t=(T.Name, T.String.Symbol, T.Wildcard))
103+
104+
_group_left_right(tlist, (T.Punctuation, '.'), sql.Identifier,
105+
valid_left=lfunc, valid_right=rfunc)
106+
107+
88108
def group_as(tlist):
89-
lfunc = lambda tk: not imt(tk, t=T.Keyword) or tk.value == 'NULL'
109+
lfunc = lambda tk: not imt(tk, t=T.Keyword) or tk.normalized == 'NULL'
90110
rfunc = lambda tk: not imt(tk, t=(T.DML, T.DDL))
91111
_group_left_right(tlist, (T.Keyword, 'AS'), sql.Identifier,
92112
valid_left=lfunc, valid_right=rfunc)
@@ -109,10 +129,6 @@ def group_comparison(tlist):
109129
valid_left=func, valid_right=func)
110130

111131

112-
def group_case(tlist):
113-
_group_matching(tlist, sql.Case)
114-
115-
116132
@recurse(sql.Identifier)
117133
def group_identifier(tlist):
118134
T_IDENT = (T.String.Symbol, T.Name)
@@ -123,17 +139,6 @@ def group_identifier(tlist):
123139
tidx, token = tlist.token_next_by(t=T_IDENT, idx=tidx)
124140

125141

126-
def group_period(tlist):
127-
lfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Identifier),
128-
t=(T.Name, T.String.Symbol,))
129-
130-
rfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Function),
131-
t=(T.Name, T.String.Symbol, T.Wildcard))
132-
133-
_group_left_right(tlist, (T.Punctuation, '.'), sql.Identifier,
134-
valid_left=lfunc, valid_right=rfunc)
135-
136-
137142
def group_arrays(tlist):
138143
tidx, token = tlist.token_next_by(i=sql.SquareBrackets)
139144
while token:
@@ -168,6 +173,9 @@ def group_operator(tlist):
168173

169174
@recurse(sql.IdentifierList)
170175
def group_identifier_list(tlist):
176+
M_ROLE = T.Keyword, ('null', 'role')
177+
M_COMMA = T.Punctuation, ','
178+
171179
I_IDENT_LIST = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
172180
sql.IdentifierList, sql.Operation)
173181
T_IDENT_LIST = (T_NUMERICAL + T_STRING + T_NAME +
@@ -186,14 +194,6 @@ def group_identifier_list(tlist):
186194
tidx, token = tlist.token_next_by(m=M_COMMA, idx=tidx)
187195

188196

189-
def group_brackets(tlist):
190-
_group_matching(tlist, sql.SquareBrackets)
191-
192-
193-
def group_parenthesis(tlist):
194-
_group_matching(tlist, sql.Parenthesis)
195-
196-
197197
@recurse(sql.Comment)
198198
def group_comments(tlist):
199199
tidx, token = tlist.token_next_by(t=T.Comment)
@@ -237,10 +237,6 @@ def group_aliased(tlist):
237237
tidx, token = tlist.token_next_by(i=I_ALIAS, t=T.Number, idx=tidx)
238238

239239

240-
def group_typecasts(tlist):
241-
_group_left_right(tlist, (T.Punctuation, '::'), sql.Identifier)
242-
243-
244240
@recurse(sql.Function)
245241
def group_functions(tlist):
246242
has_create = False
@@ -286,11 +282,17 @@ def align_comments(tlist):
286282
def group(stmt):
287283
for func in [
288284
group_comments,
285+
286+
# _group_matching
289287
group_brackets,
290288
group_parenthesis,
289+
group_case,
290+
group_if,
291+
group_for,
292+
group_begin,
293+
291294
group_functions,
292295
group_where,
293-
group_case,
294296
group_period,
295297
group_arrays,
296298
group_identifier,
@@ -301,12 +303,9 @@ def group(stmt):
301303
group_aliased,
302304
group_assignment,
303305
group_comparison,
306+
304307
align_comments,
305308
group_identifier_list,
306-
group_if,
307-
group_for,
308-
group_foreach,
309-
group_begin,
310309
]:
311310
func(stmt)
312311
return stmt

0 commit comments

Comments
 (0)