@@ -51,19 +51,21 @@ def _group_left_right(tlist, ttype, value, cls,
5151 ttype , value )
5252
5353
54+ def _find_matching (idx , tlist , start_ttype , start_value , end_ttype , end_value ):
55+ depth = 1
56+ for tok in tlist .tokens [idx :]:
57+ if tok .match (start_ttype , start_value ):
58+ depth += 1
59+ elif tok .match (end_ttype , end_value ):
60+ depth -= 1
61+ if depth == 1 :
62+ return tok
63+ return None
64+
65+
5466def _group_matching (tlist , start_ttype , start_value , end_ttype , end_value ,
5567 cls , include_semicolon = False , recurse = False ):
56- def _find_matching (i , tl , stt , sva , ett , eva ):
57- depth = 1
58- for n in xrange (i , len (tl .tokens )):
59- t = tl .tokens [n ]
60- if t .match (stt , sva ):
61- depth += 1
62- elif t .match (ett , eva ):
63- depth -= 1
64- if depth == 1 :
65- return t
66- return None
68+
6769 [_group_matching (sgroup , start_ttype , start_value , end_ttype , end_value ,
6870 cls , include_semicolon ) for sgroup in tlist .get_sublists ()
6971 if recurse ]
@@ -157,16 +159,17 @@ def _consume_cycle(tl, i):
157159 lambda y : (y .match (T .Punctuation , '.' )
158160 or y .ttype in (T .Operator ,
159161 T .Wildcard ,
160- T .ArrayIndex ,
161- T . Name )),
162+ T .Name )
163+ or isinstance ( y , sql . SquareBrackets )),
162164 lambda y : (y .ttype in (T .String .Symbol ,
163165 T .Name ,
164166 T .Wildcard ,
165- T .ArrayIndex ,
166167 T .Literal .String .Single ,
167168 T .Literal .Number .Integer ,
168169 T .Literal .Number .Float )
169- or isinstance (y , (sql .Parenthesis , sql .Function )))))
170+ or isinstance (y , (sql .Parenthesis ,
171+ sql .SquareBrackets ,
172+ sql .Function )))))
170173 for t in tl .tokens [i :]:
171174 # Don't take whitespaces into account.
172175 if t .ttype is T .Whitespace :
@@ -275,9 +278,48 @@ def group_identifier_list(tlist):
275278 tcomma = next_
276279
277280
278- def group_parenthesis (tlist ):
279- _group_matching (tlist , T .Punctuation , '(' , T .Punctuation , ')' ,
280- sql .Parenthesis )
281+ def group_brackets (tlist ):
282+ """Group parentheses () or square brackets []
283+
284+ This is just like _group_matching, but complicated by the fact that
285+ round brackets can contain square bracket groups and vice versa
286+ """
287+
288+ if isinstance (tlist , (sql .Parenthesis , sql .SquareBrackets )):
289+ idx = 1
290+ else :
291+ idx = 0
292+
293+ # Find the first opening bracket
294+ token = tlist .token_next_match (idx , T .Punctuation , ['(' , '[' ])
295+
296+ while token :
297+ start_val = token .value # either '(' or '['
298+ if start_val == '(' :
299+ end_val = ')'
300+ group_class = sql .Parenthesis
301+ else :
302+ end_val = ']'
303+ group_class = sql .SquareBrackets
304+
305+ tidx = tlist .token_index (token )
306+
307+ # Find the corresponding closing bracket
308+ end = _find_matching (tidx , tlist , T .Punctuation , start_val ,
309+ T .Punctuation , end_val )
310+
311+ if end is None :
312+ idx = tidx + 1
313+ else :
314+ group = tlist .group_tokens (group_class ,
315+ tlist .tokens_between (token , end ))
316+
317+ # Check for nested bracket groups within this group
318+ group_brackets (group )
319+ idx = tlist .token_index (group ) + 1
320+
321+ # Find the next opening bracket
322+ token = tlist .token_next_match (idx , T .Punctuation , ['(' , '[' ])
281323
282324
283325def group_comments (tlist ):
@@ -393,7 +435,7 @@ def align_comments(tlist):
393435def group (tlist ):
394436 for func in [
395437 group_comments ,
396- group_parenthesis ,
438+ group_brackets ,
397439 group_functions ,
398440 group_where ,
399441 group_case ,
0 commit comments