@@ -152,46 +152,42 @@ def group_arrays(tlist):
152152
153153@recurse (sql .Identifier )
154154def group_operator (tlist ):
155- I_CYCLE = (sql .SquareBrackets , sql .Parenthesis , sql .Function ,
155+ ttypes = T_NUMERICAL + T_STRING + T_NAME
156+ clss = (sql .SquareBrackets , sql .Parenthesis , sql .Function ,
156157 sql .Identifier , sql .Operation )
157- # wilcards wouldn't have operations next to them
158- T_CYCLE = T_NUMERICAL + T_STRING + T_NAME
159- func = lambda tk : imt (tk , i = I_CYCLE , t = T_CYCLE )
160158
161- tidx , token = tlist .token_next_by (t = (T .Operator , T .Wildcard ))
162- while token :
163- pidx , prev_ = tlist .token_prev (tidx )
164- nidx , next_ = tlist .token_next (tidx )
159+ def match (token ):
160+ return imt (token , t = (T .Operator , T .Wildcard ))
165161
166- if func (prev_ ) and func (next_ ):
167- token .ttype = T .Operator
168- tlist .group_tokens (sql .Operation , pidx , nidx )
169- tidx = pidx
162+ def valid (token ):
163+ return imt (token , i = clss , t = ttypes )
164+
165+ def post (tlist , pidx , tidx , nidx ):
166+ tlist [tidx ].ttype = T .Operator
167+ return pidx , nidx
170168
171- tidx , token = tlist . token_next_by ( t = ( T . Operator , T . Wildcard ), idx = tidx )
169+ _group ( tlist , sql . Operation , match , valid , valid , post , extend = False )
172170
173171
174- @recurse (sql .IdentifierList )
175172def group_identifier_list (tlist ):
176- M_ROLE = T .Keyword , ('null' , 'role' )
177- M_COMMA = T .Punctuation , ','
173+ m_role = T .Keyword , ('null' , 'role' )
174+ m_comma = T .Punctuation , ','
175+ clss = (sql .Function , sql .Case , sql .Identifier , sql .Comparison ,
176+ sql .IdentifierList , sql .Operation )
177+ ttypes = (T_NUMERICAL + T_STRING + T_NAME +
178+ (T .Keyword , T .Comment , T .Wildcard ))
178179
179- I_IDENT_LIST = (sql .Function , sql .Case , sql .Identifier , sql .Comparison ,
180- sql .IdentifierList , sql .Operation )
181- T_IDENT_LIST = (T_NUMERICAL + T_STRING + T_NAME +
182- (T .Keyword , T .Comment , T .Wildcard ))
180+ def match (token ):
181+ return imt (token , m = m_comma )
183182
184- func = lambda t : imt (t , i = I_IDENT_LIST , m = M_ROLE , t = T_IDENT_LIST )
183+ def func (token ):
184+ return imt (token , i = clss , m = m_role , t = ttypes )
185185
186- tidx , token = tlist .token_next_by (m = M_COMMA )
187- while token :
188- pidx , prev_ = tlist .token_prev (tidx )
189- nidx , next_ = tlist .token_next (tidx )
186+ def post (tlist , pidx , tidx , nidx ):
187+ return pidx , nidx
190188
191- if func (prev_ ) and func (next_ ):
192- tlist .group_tokens (sql .IdentifierList , pidx , nidx , extend = True )
193- tidx = pidx
194- tidx , token = tlist .token_next_by (m = M_COMMA , idx = tidx )
189+ _group (tlist , sql .IdentifierList , match ,
190+ valid_left = func , valid_right = func , post = post , extend = True )
195191
196192
197193@recurse (sql .Comment )
@@ -309,3 +305,25 @@ def group(stmt):
309305 ]:
310306 func (stmt )
311307 return stmt
308+
309+
310+ def _group (tlist , cls , match ,
311+ valid_left = lambda t : True ,
312+ valid_right = lambda t : True ,
313+ post = None ,
314+ extend = True ):
315+ """Groups together tokens that are joined by a middle token. ie. x < y"""
316+ for token in list (tlist ):
317+ if token .is_group () and not isinstance (token , cls ):
318+ _group (token , cls , match , valid_left , valid_right , post , extend )
319+ continue
320+ if not match (token ):
321+ continue
322+
323+ tidx = tlist .token_index (token )
324+ pidx , prev_ = tlist .token_prev (tidx )
325+ nidx , next_ = tlist .token_next (tidx )
326+
327+ if valid_left (prev_ ) and valid_right (next_ ):
328+ from_idx , to_idx = post (tlist , pidx , tidx , nidx )
329+ tlist .group_tokens (cls , from_idx , to_idx , extend = extend )
0 commit comments