Skip to content

Commit 74b3464

Browse files
committed
Re-Write grouping functions
1 parent af9b82e commit 74b3464

File tree

1 file changed

+47
-29
lines changed

1 file changed

+47
-29
lines changed

sqlparse/engine/grouping.py

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -152,46 +152,42 @@ def group_arrays(tlist):
152152

153153
@recurse(sql.Identifier)
154154
def group_operator(tlist):
155-
I_CYCLE = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
155+
ttypes = T_NUMERICAL + T_STRING + T_NAME
156+
clss = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
156157
sql.Identifier, sql.Operation)
157-
# wilcards wouldn't have operations next to them
158-
T_CYCLE = T_NUMERICAL + T_STRING + T_NAME
159-
func = lambda tk: imt(tk, i=I_CYCLE, t=T_CYCLE)
160158

161-
tidx, token = tlist.token_next_by(t=(T.Operator, T.Wildcard))
162-
while token:
163-
pidx, prev_ = tlist.token_prev(tidx)
164-
nidx, next_ = tlist.token_next(tidx)
159+
def match(token):
160+
return imt(token, t=(T.Operator, T.Wildcard))
165161

166-
if func(prev_) and func(next_):
167-
token.ttype = T.Operator
168-
tlist.group_tokens(sql.Operation, pidx, nidx)
169-
tidx = pidx
162+
def valid(token):
163+
return imt(token, i=clss, t=ttypes)
164+
165+
def post(tlist, pidx, tidx, nidx):
166+
tlist[tidx].ttype = T.Operator
167+
return pidx, nidx
170168

171-
tidx, token = tlist.token_next_by(t=(T.Operator, T.Wildcard), idx=tidx)
169+
_group(tlist, sql.Operation, match, valid, valid, post, extend=False)
172170

173171

174-
@recurse(sql.IdentifierList)
175172
def group_identifier_list(tlist):
176-
M_ROLE = T.Keyword, ('null', 'role')
177-
M_COMMA = T.Punctuation, ','
173+
m_role = T.Keyword, ('null', 'role')
174+
m_comma = T.Punctuation, ','
175+
clss = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
176+
sql.IdentifierList, sql.Operation)
177+
ttypes = (T_NUMERICAL + T_STRING + T_NAME +
178+
(T.Keyword, T.Comment, T.Wildcard))
178179

179-
I_IDENT_LIST = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
180-
sql.IdentifierList, sql.Operation)
181-
T_IDENT_LIST = (T_NUMERICAL + T_STRING + T_NAME +
182-
(T.Keyword, T.Comment, T.Wildcard))
180+
def match(token):
181+
return imt(token, m=m_comma)
183182

184-
func = lambda t: imt(t, i=I_IDENT_LIST, m=M_ROLE, t=T_IDENT_LIST)
183+
def func(token):
184+
return imt(token, i=clss, m=m_role, t=ttypes)
185185

186-
tidx, token = tlist.token_next_by(m=M_COMMA)
187-
while token:
188-
pidx, prev_ = tlist.token_prev(tidx)
189-
nidx, next_ = tlist.token_next(tidx)
186+
def post(tlist, pidx, tidx, nidx):
187+
return pidx, nidx
190188

191-
if func(prev_) and func(next_):
192-
tlist.group_tokens(sql.IdentifierList, pidx, nidx, extend=True)
193-
tidx = pidx
194-
tidx, token = tlist.token_next_by(m=M_COMMA, idx=tidx)
189+
_group(tlist, sql.IdentifierList, match,
190+
valid_left=func, valid_right=func, post=post, extend=True)
195191

196192

197193
@recurse(sql.Comment)
@@ -309,3 +305,25 @@ def group(stmt):
309305
]:
310306
func(stmt)
311307
return stmt
308+
309+
310+
def _group(tlist, cls, match,
311+
valid_left=lambda t: True,
312+
valid_right=lambda t: True,
313+
post=None,
314+
extend=True):
315+
"""Groups together tokens that are joined by a middle token. ie. x < y"""
316+
for token in list(tlist):
317+
if token.is_group() and not isinstance(token, cls):
318+
_group(token, cls, match, valid_left, valid_right, post, extend)
319+
continue
320+
if not match(token):
321+
continue
322+
323+
tidx = tlist.token_index(token)
324+
pidx, prev_ = tlist.token_prev(tidx)
325+
nidx, next_ = tlist.token_next(tidx)
326+
327+
if valid_left(prev_) and valid_right(next_):
328+
from_idx, to_idx = post(tlist, pidx, tidx, nidx)
329+
tlist.group_tokens(cls, from_idx, to_idx, extend=extend)

0 commit comments

Comments
 (0)