Skip to content

Commit 12e127f

Browse files
authored
Model.batch_get: add guard-rails (pynamodb#1184)
For models with a range key, fail if: - item is a `str` ("accidental" iterable) - item is an iterable with != 2 items
1 parent 0cf2e94 commit 12e127f

File tree

2 files changed

+64
-6
lines changed

2 files changed

+64
-6
lines changed

pynamodb/models.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -357,15 +357,23 @@ def batch_get(
357357
keys_to_get = []
358358
item = items.pop()
359359
if range_key_attribute:
360-
hash_key, range_key = cls._serialize_keys(item[0], item[1]) # type: ignore
360+
if isinstance(item, str):
361+
raise ValueError(f'Invalid key value {item!r}: '
362+
'expected non-str iterable with exactly 2 elements (hash key, range key)')
363+
try:
364+
hash_key, range_key = item
365+
except (TypeError, ValueError):
366+
raise ValueError(f'Invalid key value {item!r}: '
367+
'expected iterable with exactly 2 elements (hash key, range key)')
368+
hash_key_ser, range_key_ser = cls._serialize_keys(hash_key, range_key)
361369
keys_to_get.append({
362-
hash_key_attribute.attr_name: hash_key,
363-
range_key_attribute.attr_name: range_key
370+
hash_key_attribute.attr_name: hash_key_ser,
371+
range_key_attribute.attr_name: range_key_ser,
364372
})
365373
else:
366-
hash_key = cls._serialize_keys(item)[0]
374+
hash_key_ser, _ = cls._serialize_keys(item)
367375
keys_to_get.append({
368-
hash_key_attribute.attr_name: hash_key
376+
hash_key_attribute.attr_name: hash_key_ser
369377
})
370378

371379
while keys_to_get:

tests/test_model.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import base64
55
import json
66
import copy
7+
import re
78
from datetime import datetime
89
from datetime import timedelta
910
from datetime import timezone
@@ -1845,7 +1846,6 @@ def test_batch_get(self):
18451846
}
18461847
self.assertEqual(params, req.call_args[0][1])
18471848

1848-
18491849
with patch(PATCH_METHOD) as req:
18501850
item_keys = [('hash-{}'.format(x), '{}'.format(x)) for x in range(10)]
18511851
item_keys_copy = list(item_keys)
@@ -1906,6 +1906,56 @@ def fake_batch_get(*batch_args):
19061906
for item in UserModel.batch_get(item_keys):
19071907
self.assertIsNotNone(item)
19081908

1909+
def test_batch_get__range_key(self):
1910+
with patch(PATCH_METHOD) as req:
1911+
req.return_value = {
1912+
'UnprocessedKeys': {},
1913+
'Responses': {
1914+
'UserModel': [],
1915+
}
1916+
}
1917+
items = [(f'hash-{x}', f'range-{x}') for x in range(10)]
1918+
_ = list(UserModel.batch_get(items))
1919+
1920+
actual_keys = req.call_args[0][1]['RequestItems']['UserModel']['Keys']
1921+
actual_keys.sort(key=json.dumps)
1922+
assert actual_keys == [
1923+
{'user_name': {'S': f'hash-{x}'}, 'user_id': {'S': f'range-{x}'}}
1924+
for x in range(10)
1925+
]
1926+
1927+
def test_batch_get__range_key__invalid__string(self):
1928+
with patch(PATCH_METHOD) as req:
1929+
req.return_value = {
1930+
'UnprocessedKeys': {},
1931+
'Responses': {
1932+
'UserModel': [],
1933+
}
1934+
}
1935+
with pytest.raises(
1936+
ValueError,
1937+
match=re.escape(
1938+
"Invalid key value 'ab': expected non-str iterable with exactly 2 elements (hash key, range key)"
1939+
)
1940+
):
1941+
_ = list(UserModel.batch_get(['ab']))
1942+
1943+
def test_batch_get__range_key__invalid__3_elements(self):
1944+
with patch(PATCH_METHOD) as req:
1945+
req.return_value = {
1946+
'UnprocessedKeys': {},
1947+
'Responses': {
1948+
'UserModel': [],
1949+
}
1950+
}
1951+
with pytest.raises(
1952+
ValueError,
1953+
match=re.escape(
1954+
"Invalid key value ('a', 'b', 'c'): expected iterable with exactly 2 elements (hash key, range key)"
1955+
)
1956+
):
1957+
_ = list(UserModel.batch_get([('a', 'b', 'c')]))
1958+
19091959
def test_batch_write(self):
19101960
"""
19111961
Model.batch_write

0 commit comments

Comments
 (0)