Skip to content

Commit 895f021

Browse files
committed
Grouping of function/procedure calls.
1 parent 9917967 commit 895f021

File tree

3 files changed

+39
-1
lines changed

3 files changed

+39
-1
lines changed

sqlparse/engine/grouping.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,25 @@ def group_typecasts(tlist):
245245
_group_left_right(tlist, T.Punctuation, '::', Identifier)
246246

247247

248+
def group_functions(tlist):
249+
[group_functions(sgroup) for sgroup in tlist.get_sublists()
250+
if not isinstance(sgroup, Function)]
251+
idx = 0
252+
token = tlist.token_next_by_type(idx, T.Name)
253+
while token:
254+
next_ = tlist.token_next(token)
255+
if not isinstance(next_, Parenthesis):
256+
idx = tlist.token_index(token)+1
257+
else:
258+
func = tlist.group_tokens(Function,
259+
tlist.tokens_between(token, next_))
260+
idx = tlist.token_index(func)+1
261+
token = tlist.token_next_by_type(idx, T.Name)
262+
263+
248264
def group(tlist):
249265
for func in [group_parenthesis,
266+
group_functions,
250267
group_comments,
251268
group_where,
252269
group_case,

sqlparse/sql.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,3 +455,17 @@ def get_cases(self):
455455
elif in_value:
456456
ret[-1][1].append(token)
457457
return ret
458+
459+
460+
class Function(TokenList):
461+
"""A function or procedure call."""
462+
463+
__slots__ = ('value', 'ttype', 'tokens')
464+
465+
def get_parameters(self):
466+
"""Return a list of parameters."""
467+
parenthesis = self.tokens[-1]
468+
for t in parenthesis.tokens:
469+
if isinstance(t, IdentifierList):
470+
return t.get_identifiers()
471+
return []

tests/test_grouping.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
class TestGrouping(TestCaseBase):
1111

1212
def test_parenthesis(self):
13-
s ='x1 (x2 (x3) x2) foo (y2) bar'
13+
s ='select (select (x3) x2) and (y2) bar'
1414
parsed = sqlparse.parse(s)[0]
1515
self.ndiffAssertEqual(s, str(parsed))
1616
self.assertEqual(len(parsed.tokens), 9)
@@ -142,6 +142,13 @@ def test_comparsion_exclude(self):
142142
p = sqlparse.parse('(a+1)')[0]
143143
self.assert_(isinstance(p.tokens[0].tokens[1], Comparsion))
144144

145+
def test_function(self):
146+
p = sqlparse.parse('foo()')[0]
147+
self.assert_(isinstance(p.tokens[0], Function))
148+
p = sqlparse.parse('foo(null, bar)')[0]
149+
self.assert_(isinstance(p.tokens[0], Function))
150+
self.assertEqual(len(p.tokens[0].get_parameters()), 2)
151+
145152

146153
class TestStatement(TestCaseBase):
147154

0 commit comments

Comments
 (0)