Skip to content

Commit 2c498ff

Browse files
committed
Allow specification of alternate return formats
1 parent c1c0b41 commit 2c498ff

File tree

4 files changed

+57
-13
lines changed

4 files changed

+57
-13
lines changed

cassandra/cluster.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
WriteTimeoutErrorMessage,
1414
UnavailableErrorMessage,
1515
OverloadedErrorMessage,
16-
IsBootstrappingErrorMessage)
16+
IsBootstrappingErrorMessage, named_tuple_factory,
17+
dict_factory)
1718
from cassandra.metadata import Metadata
1819
from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy,
1920
ExponentialReconnectionPolicy, HostDistance,
@@ -430,6 +431,19 @@ class Session(object):
430431
keyspace = None
431432
is_shutdown = False
432433

434+
row_factory = staticmethod(named_tuple_factory)
435+
"""
436+
The format to return row results in. By default, each
437+
returned row will be a named tuple. You can alternatively
438+
use any of the following:
439+
440+
- :func:`cassandra.decoder.tuple_factory`
441+
- :func:`cassandra.decoder.named_tuple_factory`
442+
- :func:`cassandra.decoder.dict_factory`
443+
- :func:`cassandra.decoder.ordered_dict_factory`
444+
445+
"""
446+
433447
_lock = None
434448
_pools = None
435449
_load_balancer = None
@@ -798,9 +812,14 @@ def _refresh_schema(self, connection, keyspace=None, table=None):
798812

799813
if ks_query:
800814
ks_result, cf_result, col_result = connection.wait_for_responses(ks_query, cf_query, col_query)
815+
ks_result = dict_factory(*ks_result.results)
816+
cf_result = dict_factory(*cf_result.results)
817+
col_result = dict_factory(*col_result.results)
801818
else:
802819
ks_result = None
803820
cf_result, col_result = connection.wait_for_responses(cf_query, col_query)
821+
cf_result = dict_factory(*cf_result.results)
822+
col_result = dict_factory(*col_result.results)
804823

805824
self._cluster.metadata.rebuild_schema(keyspace, table, ks_result, cf_result, col_result)
806825

@@ -817,12 +836,14 @@ def _refresh_node_list_and_token_map(self, connection):
817836
peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=cl)
818837
local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=cl)
819838
peers_result, local_result = connection.wait_for_responses(peers_query, local_query)
839+
peers_result = dict_factory(*peers_result.results)
820840

821841
partitioner = None
822842
token_map = {}
823843

824844
if local_result.results:
825-
local_row = local_result.results[0]
845+
local_rows = dict_factory(*(local_result.results))
846+
local_row = local_rows[0]
826847
cluster_name = local_row["cluster_name"]
827848
self._cluster.metadata.cluster_name = cluster_name
828849

@@ -836,7 +857,7 @@ def _refresh_node_list_and_token_map(self, connection):
836857
token_map[host] = tokens
837858

838859
found_hosts = set()
839-
for row in peers_result.results:
860+
for row in peers_result:
840861
addr = row.get("rpc_address")
841862

842863
# TODO handle ipv6 equivalent
@@ -909,14 +930,15 @@ def wait_for_schema_agreement(self, connection=None):
909930
peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl)
910931
local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl)
911932
peers_result, local_result = connection.wait_for_responses(peers_query, local_query)
933+
peers_result = dict_factory(*peers_result.results)
912934

913935
versions = set()
914936
if local_result.results:
915-
local_row = local_result.results[0]
937+
local_row = dict_factory(*local_result.results)[0]
916938
if local_row.get("schema_version"):
917939
versions.add(local_row.get("schema_version"))
918940

919-
for row in peers_result.results:
941+
for row in peers_result:
920942
if not row.get("rpc_address") or not row.get("schema_version"):
921943
continue
922944

@@ -1074,6 +1096,7 @@ class ResponseFuture(object):
10741096

10751097
def __init__(self, session, message, query):
10761098
self.session = session
1099+
self.row_factory = session.row_factory
10771100
self.message = message
10781101
self.query = query
10791102

@@ -1149,7 +1172,10 @@ def _set_result(self, response):
11491172
self.session.cluster.control_connection,
11501173
self)
11511174
else:
1152-
self._set_final_result(getattr(response, 'results', None))
1175+
results = getattr(response, 'results', None)
1176+
if results:
1177+
results = self.row_factory(*results)
1178+
self._set_final_result(results)
11531179
elif isinstance(response, ErrorMessage):
11541180
retry_policy = self.query.retry_policy
11551181
if not retry_policy:

cassandra/decoder.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
from collections import namedtuple, OrderedDict
1718
import socket
1819
try:
1920
from cStringIO import StringIO
@@ -46,6 +47,23 @@ def warn(msg):
4647
print msg
4748

4849

50+
def tuple_factory(colnames, rows):
51+
return rows
52+
53+
54+
def named_tuple_factory(colnames, rows):
55+
Row = namedtuple('Row', colnames)
56+
return [Row(*row) for row in rows]
57+
58+
59+
def dict_factory(colnames, rows):
60+
return [dict(zip(colnames, row)) for row in rows]
61+
62+
63+
def ordered_dict_factory(colnames, rows):
64+
return [OrderedDict(zip(colnames, row)) for row in rows]
65+
66+
4967
class PreparedResult:
5068
def __init__(self, queryid, param_metadata):
5169
self.queryid = queryid
@@ -393,8 +411,8 @@ def recv_results_rows(cls, f):
393411
rows = [cls.recv_row(f, len(colspecs)) for x in xrange(rowcount)]
394412
colnames = [c[2] for c in colspecs]
395413
coltypes = [c[3] for c in colspecs]
396-
return [dict(zip(colnames, [ctype.from_binary(val) for ctype, val in zip(coltypes, row)]))
397-
for row in rows]
414+
return (colnames, [tuple(ctype.from_binary(val) for ctype, val in zip(coltypes, row))
415+
for row in rows])
398416

399417
@classmethod
400418
def recv_results_prepared(cls, f):

cassandra/metadata.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ def rebuild_schema(self, keyspace, table, ks_results, cf_results, col_results):
6363
cf_def_rows = defaultdict(list)
6464
col_def_rows = defaultdict(lambda: defaultdict(list))
6565

66-
for row in cf_results.results:
66+
for row in cf_results:
6767
cf_def_rows[row["keyspace_name"]].append(row)
6868

69-
for row in col_results.results:
69+
for row in col_results:
7070
ksname = row["keyspace_name"]
7171
cfname = row["columnfamily_name"]
7272
col_def_rows[ksname][cfname].append(row)
@@ -75,7 +75,7 @@ def rebuild_schema(self, keyspace, table, ks_results, cf_results, col_results):
7575
if not table:
7676
# ks_results is not None
7777
added_keyspaces = set()
78-
for row in ks_results.results:
78+
for row in ks_results:
7979
keyspace_meta = self._build_keyspace_metadata(row)
8080
for table_row in cf_def_rows.get(keyspace_meta.name, []):
8181
table_meta = self._build_table_metadata(

example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def main():
1515
s = c.connect()
1616

1717
rows = s.execute("SELECT keyspace_name FROM system.schema_keyspaces")
18-
if KEYSPACE in [row.values()[0] for row in rows]:
18+
if KEYSPACE in [row[0] for row in rows]:
1919
print "dropping existing keyspace..."
2020
s.execute("DROP KEYSPACE " + KEYSPACE)
2121

@@ -55,7 +55,7 @@ def main():
5555
log.exeception()
5656

5757
for row in rows:
58-
print '\t'.join(row.values())
58+
print '\t'.join(row)
5959

6060
s.execute("DROP KEYSPACE " + KEYSPACE)
6161

0 commit comments

Comments
 (0)