Skip to content

Commit bf26160

Browse files
committed
Merge pull request andialbrecht#177 from darikg/brackets
Better square bracket / array index handling
2 parents 15b0cb9 + acdebef commit bf26160

5 files changed

Lines changed: 125 additions & 52 deletions

File tree

sqlparse/engine/grouping.py

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,21 @@ def _group_left_right(tlist, ttype, value, cls,
5151
ttype, value)
5252

5353

54+
def _find_matching(idx, tlist, start_ttype, start_value, end_ttype, end_value):
    """Return the token closing the bracket that opens at ``tlist.tokens[idx]``.

    The scan begins at position ``idx`` itself, so the opening token is
    counted as well; that is why the nesting counter starts at 1 and a
    match is reported when it drops back to 1 rather than to 0.

    :param idx: Index in ``tlist.tokens`` at which scanning starts.
    :param tlist: Object exposing a ``tokens`` sequence.
    :param start_ttype: Token type passed to ``match`` for an opener.
    :param start_value: Token value passed to ``match`` for an opener.
    :param end_ttype: Token type passed to ``match`` for a closer.
    :param end_value: Token value passed to ``match`` for a closer.
    :returns: The closing token, or ``None`` if no closer brings the
        nesting counter back to 1 before the tokens run out.
    """
    nesting = 1
    for candidate in tlist.tokens[idx:]:
        if candidate.match(start_ttype, start_value):
            nesting += 1
            continue
        if not candidate.match(end_ttype, end_value):
            continue
        # A closer: step out one level and see whether we are back at the
        # level we started from.
        nesting -= 1
        if nesting == 1:
            return candidate
    return None
64+
65+
5466
def _group_matching(tlist, start_ttype, start_value, end_ttype, end_value,
5567
cls, include_semicolon=False, recurse=False):
56-
def _find_matching(i, tl, stt, sva, ett, eva):
57-
depth = 1
58-
for n in xrange(i, len(tl.tokens)):
59-
t = tl.tokens[n]
60-
if t.match(stt, sva):
61-
depth += 1
62-
elif t.match(ett, eva):
63-
depth -= 1
64-
if depth == 1:
65-
return t
66-
return None
68+
6769
[_group_matching(sgroup, start_ttype, start_value, end_ttype, end_value,
6870
cls, include_semicolon) for sgroup in tlist.get_sublists()
6971
if recurse]
@@ -157,16 +159,17 @@ def _consume_cycle(tl, i):
157159
lambda y: (y.match(T.Punctuation, '.')
158160
or y.ttype in (T.Operator,
159161
T.Wildcard,
160-
T.ArrayIndex,
161-
T.Name)),
162+
T.Name)
163+
or isinstance(y, sql.SquareBrackets)),
162164
lambda y: (y.ttype in (T.String.Symbol,
163165
T.Name,
164166
T.Wildcard,
165-
T.ArrayIndex,
166167
T.Literal.String.Single,
167168
T.Literal.Number.Integer,
168169
T.Literal.Number.Float)
169-
or isinstance(y, (sql.Parenthesis, sql.Function)))))
170+
or isinstance(y, (sql.Parenthesis,
171+
sql.SquareBrackets,
172+
sql.Function)))))
170173
for t in tl.tokens[i:]:
171174
# Don't take whitespaces into account.
172175
if t.ttype is T.Whitespace:
@@ -275,9 +278,48 @@ def group_identifier_list(tlist):
275278
tcomma = next_
276279

277280

278-
def group_parenthesis(tlist):
279-
_group_matching(tlist, T.Punctuation, '(', T.Punctuation, ')',
280-
sql.Parenthesis)
281+
def group_brackets(tlist):
    """Group parenthesised ``(...)`` and square-bracketed ``[...]`` runs.

    This does the same job as ``_group_matching`` but handles both bracket
    kinds in a single pass, because round brackets can contain square
    bracket groups and vice versa.  Every matched run is wrapped in
    ``sql.Parenthesis`` or ``sql.SquareBrackets`` and the new group is then
    processed recursively.  An opening bracket without a matching closer is
    left ungrouped and the scan resumes right after it.
    """
    openers = ['(', '[']

    # If tlist is itself a bracket group, skip position 0 (presumably its
    # own opening bracket) so that it is not matched a second time.
    inside_group = isinstance(tlist, (sql.Parenthesis, sql.SquareBrackets))
    idx = 1 if inside_group else 0

    # Find the first opening bracket
    token = tlist.token_next_match(idx, T.Punctuation, openers)

    while token:
        opener = token.value  # either '(' or '['
        if opener == '(':
            closer, group_class = ')', sql.Parenthesis
        else:
            closer, group_class = ']', sql.SquareBrackets

        tidx = tlist.token_index(token)

        # Find the corresponding closing bracket
        end = _find_matching(tidx, tlist, T.Punctuation, opener,
                             T.Punctuation, closer)

        if end is not None:
            group = tlist.group_tokens(group_class,
                                       tlist.tokens_between(token, end))

            # Check for nested bracket groups within this group
            group_brackets(group)
            idx = tlist.token_index(group) + 1
        else:
            # Unbalanced opener: step over it and keep looking.
            idx = tidx + 1

        # Find the next opening bracket
        token = tlist.token_next_match(idx, T.Punctuation, openers)
281323

282324

283325
def group_comments(tlist):
@@ -393,7 +435,7 @@ def align_comments(tlist):
393435
def group(tlist):
394436
for func in [
395437
group_comments,
396-
group_parenthesis,
438+
group_brackets,
397439
group_functions,
398440
group_where,
399441
group_case,

sqlparse/lexer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,10 @@ class Lexer(object):
194194
(r"'(''|\\\\|\\'|[^'])*'", tokens.String.Single),
195195
# not a real string literal in ANSI SQL:
196196
(r'(""|".*?[^\\]")', tokens.String.Symbol),
197-
(r'(?<=[\w\]])(\[[^\]]*?\])', tokens.Punctuation.ArrayIndex),
198-
(r'(\[[^\]]+\])', tokens.Name),
197+
# sqlite names can be escaped with [square brackets]. left bracket
198+
# cannot be preceded by a word character, ']' or ')' --
199+
# otherwise it's probably an array index
200+
(r'(?<![\w\])])(\[[^\]]+\])', tokens.Name),
199201
(r'((LEFT\s+|RIGHT\s+|FULL\s+)?(INNER\s+|OUTER\s+|STRAIGHT\s+)?|(CROSS\s+|NATURAL\s+)?)?JOIN\b', tokens.Keyword),
200202
(r'END(\s+IF|\s+LOOP)?\b', tokens.Keyword),
201203
(r'NOT NULL\b', tokens.Keyword),

sqlparse/sql.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -511,11 +511,12 @@ def get_ordering(self):
511511
return ordering.value.upper()
512512

513513
def get_array_indices(self):
514-
"""Returns an iterator of index expressions as strings"""
514+
"""Returns an iterator of index token lists"""
515515

516-
# Use [1:-1] index to discard the square brackets
517-
return (tok.value[1:-1] for tok in self.tokens
518-
if tok.ttype in T.ArrayIndex)
516+
for tok in self.tokens:
517+
if isinstance(tok, SquareBrackets):
518+
# Use [1:-1] index to discard the square brackets
519+
yield tok.tokens[1:-1]
519520

520521

521522
class IdentifierList(TokenList):
@@ -542,6 +543,15 @@ def _groupable_tokens(self):
542543
return self.tokens[1:-1]
543544

544545

546+
class SquareBrackets(TokenList):
    """A token group delimited by square brackets."""

    __slots__ = ('value', 'ttype', 'tokens')

    @property
    def _groupable_tokens(self):
        # Leave out the first and last tokens (presumably the brackets
        # themselves) so only the enclosed tokens are offered for grouping.
        inner = self.tokens[1:-1]
        return inner
554+
545555
class Assignment(TokenList):
546556
"""An assignment like 'var := val;'"""
547557
__slots__ = ('value', 'ttype', 'tokens')

sqlparse/tokens.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def __repr__(self):
5757
String = Literal.String
5858
Number = Literal.Number
5959
Punctuation = Token.Punctuation
60-
ArrayIndex = Punctuation.ArrayIndex
6160
Operator = Token.Operator
6261
Comparison = Operator.Comparison
6362
Wildcard = Token.Wildcard

tests/test_parse.py

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def test_single_quotes_with_linebreaks(): # issue118
215215
assert p[0].ttype is T.String.Single
216216

217217

218-
def test_array_indexed_column():
218+
def test_sqlite_identifiers():
219219
# Make sure we still parse sqlite style escapes
220220
p = sqlparse.parse('[col1],[col2]')[0].tokens
221221
assert (len(p) == 1
@@ -227,39 +227,59 @@ def test_array_indexed_column():
227227
types = [tok.ttype for tok in p.flatten()]
228228
assert types == [T.Name, T.Operator, T.Name]
229229

230+
231+
def test_simple_1d_array_index():
230232
p = sqlparse.parse('col[1]')[0].tokens
231-
assert (len(p) == 1
232-
and tuple(p[0].get_array_indices()) == ('1',)
233-
and p[0].get_name() == 'col')
233+
assert len(p) == 1
234+
assert p[0].get_name() == 'col'
235+
indices = list(p[0].get_array_indices())
236+
assert (len(indices) == 1 # 1-dimensional index
237+
and len(indices[0]) == 1 # index is single token
238+
and indices[0][0].value == '1')
234239

235-
p = sqlparse.parse('col[1][1:5] as mycol')[0].tokens
236-
assert (len(p) == 1
237-
and tuple(p[0].get_array_indices()) == ('1', '1:5')
238-
and p[0].get_name() == 'mycol'
239-
and p[0].get_real_name() == 'col')
240-
241-
p = sqlparse.parse('col[1][other_col]')[0].tokens
242-
assert len(p) == 1 and tuple(p[0].get_array_indices()) == ('1', 'other_col')
243-
244-
sql = 'SELECT col1, my_1d_array[2] as alias1, my_2d_array[2][5] as alias2'
245-
p = sqlparse.parse(sql)[0].tokens
246-
assert len(p) == 3 and isinstance(p[2], sqlparse.sql.IdentifierList)
247-
ids = list(p[2].get_identifiers())
248-
assert (ids[0].get_name() == 'col1'
249-
and tuple(ids[0].get_array_indices()) == ()
250-
and ids[1].get_name() == 'alias1'
251-
and ids[1].get_real_name() == 'my_1d_array'
252-
and tuple(ids[1].get_array_indices()) == ('2',)
253-
and ids[2].get_name() == 'alias2'
254-
and ids[2].get_real_name() == 'my_2d_array'
255-
and tuple(ids[2].get_array_indices()) == ('2', '5'))
240+
241+
def test_2d_array_index():
    """Two consecutive subscripts yield two array indices on one token."""
    tokens = sqlparse.parse('col[x][(y+1)*2]')[0].tokens
    assert len(tokens) == 1
    ident = tokens[0]
    assert ident.get_name() == 'col'
    indices = list(ident.get_array_indices())
    assert len(indices) == 2  # 2-dimensional index
246+
247+
248+
def test_array_index_function_result():
    """A subscript applied to a function call is grouped with that call."""
    tokens = sqlparse.parse('somefunc()[1]')[0].tokens
    assert len(tokens) == 1
    indices = list(tokens[0].get_array_indices())
    assert len(indices) == 1
252+
253+
254+
def test_schema_qualified_array_index():
    """A subscript on a schema-qualified column keeps parent, name and index."""
    tokens = sqlparse.parse('schem.col[1]')[0].tokens
    assert len(tokens) == 1
    ident = tokens[0]
    assert ident.get_parent_name() == 'schem'
    assert ident.get_name() == 'col'
    first_index = list(ident.get_array_indices())[0]
    assert first_index[0].value == '1'
260+
261+
262+
def test_aliased_array_index():
    """An alias after a subscripted column is kept apart from the index."""
    tokens = sqlparse.parse('col[1] x')[0].tokens
    assert len(tokens) == 1
    ident = tokens[0]
    assert ident.get_alias() == 'x'
    assert ident.get_real_name() == 'col'
    first_index = list(ident.get_array_indices())[0]
    assert first_index[0].value == '1'
268+
269+
270+
def test_array_literal():
    """``ARRAY[...]`` parses into two top-level tokens and seven leaves."""
    # See issue #176
    statement = sqlparse.parse('ARRAY[%s, %s]')[0]
    assert len(statement.tokens) == 2
    leaves = list(statement.flatten())
    assert len(leaves) == 7
256275

257276

258277
def test_typed_array_definition():
259278
# array indices aren't grouped with builtins, but make sure we can extract
260279
# indentifer names
261280
p = sqlparse.parse('x int, y int[], z int')[0]
262-
names = [x.get_name() for x in p.get_sublists()]
281+
names = [x.get_name() for x in p.get_sublists()
282+
if isinstance(x, sqlparse.sql.Identifier)]
263283
assert names == ['x', 'y', 'z']
264284

265285

0 commit comments

Comments
 (0)