Skip to content

Commit 1417977

Browse files
authored
Merge pull request apache#1091 from datastax/python-1258
PYTHON-1258: Implement protocol v5 checksumming
2 parents 2473aae + d992e81 commit 1417977

13 files changed

Lines changed: 640 additions & 39 deletions

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Not released
55
Features
66
--------
77
* Ensure the driver can connect when invalid peer hosts are in system.peers (PYTHON-1260)
8+
* Implement protocol v5 checksumming (PYTHON-1258)
89

910
Bug Fixes
1011
---------

cassandra/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,10 @@ def has_continuous_paging_support(cls, version):
235235
def has_continuous_paging_next_pages(cls, version):
236236
return version >= cls.DSE_V2
237237

238+
@classmethod
def has_checksumming_support(cls, version):
    """
    Tell whether ``version`` falls in the checksumming range.

    Returns True when ``version`` is at least ``cls.V5`` and strictly
    lower than ``cls.DSE_V1``, False otherwise.
    """
    if version < cls.V5:
        return False
    return version < cls.DSE_V1
241+
238242

239243
class WriteType(object):
240244
"""

cassandra/connection.py

Lines changed: 131 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from threading import Thread, Event, RLock, Condition
2828
import time
2929
import ssl
30+
import weakref
31+
3032

3133
if 'gevent.monkey' in sys.modules:
3234
from gevent.queue import Queue, Empty
@@ -42,11 +44,15 @@
4244
AuthResponseMessage, AuthChallengeMessage,
4345
AuthSuccessMessage, ProtocolException,
4446
RegisterMessage, ReviseRequestMessage)
47+
from cassandra.segment import SegmentCodec, CrcException
4548
from cassandra.util import OrderedDict
4649

4750

4851
log = logging.getLogger(__name__)
4952

53+
segment_codec_no_compression = SegmentCodec()
54+
segment_codec_lz4 = None
55+
5056
# We use an ordered dictionary and specifically add lz4 before
5157
# snappy so that lz4 will be preferred. Changing the order of this
5258
# will change the compression preferences for the driver.
@@ -88,6 +94,7 @@ def lz4_decompress(byts):
8894
return lz4_block.decompress(byts[3::-1] + byts[4:])
8995

9096
locally_supported_compressions['lz4'] = (lz4_compress, lz4_decompress)
97+
segment_codec_lz4 = SegmentCodec(lz4_compress, lz4_decompress)
9198

9299
try:
93100
import snappy
@@ -426,6 +433,10 @@ class ProtocolError(Exception):
426433
pass
427434

428435

436+
class CrcMismatchException(ConnectionException):
    """
    Connection-level error raised in place of a segment ``CrcException``,
    so that callers handling ``ConnectionException`` also see CRC failures.
    """
438+
439+
429440
class ContinuousPagingState(object):
430441
"""
431442
A class for specifying continuous paging state, only supported starting with DSE_V2.
@@ -601,6 +612,55 @@ def int_from_buf_item(i):
601612
int_from_buf_item = ord
602613

603614

615+
class _ConnectionIOBuffer(object):
616+
"""
617+
Abstraction class to ease the use of the different connection io buffers. With
618+
protocol V5 and checksumming, the data is read, validated and copied to another
619+
cql frame buffer.
620+
"""
621+
_io_buffer = None
622+
_cql_frame_buffer = None
623+
_connection = None
624+
625+
def __init__(self, connection):
626+
self._io_buffer = io.BytesIO()
627+
self._connection = weakref.proxy(connection)
628+
629+
@property
630+
def io_buffer(self):
631+
return self._io_buffer
632+
633+
@property
634+
def cql_frame_buffer(self):
635+
return self._cql_frame_buffer if self.is_checksumming_enabled else \
636+
self._io_buffer
637+
638+
def set_checksumming_buffer(self):
639+
self.reset_io_buffer()
640+
self._cql_frame_buffer = io.BytesIO()
641+
642+
@property
643+
def is_checksumming_enabled(self):
644+
return self._connection._is_checksumming_enabled
645+
646+
def readable_io_bytes(self):
647+
return self.io_buffer.tell()
648+
649+
def readable_cql_frame_bytes(self):
650+
return self.cql_frame_buffer.tell()
651+
652+
def reset_io_buffer(self):
653+
self._io_buffer = io.BytesIO(self._io_buffer.read())
654+
self._io_buffer.seek(0, 2) # 2 == SEEK_END
655+
656+
def reset_cql_frame_buffer(self):
657+
if self.is_checksumming_enabled:
658+
self._cql_frame_buffer = io.BytesIO(self._cql_frame_buffer.read())
659+
self._cql_frame_buffer.seek(0, 2) # 2 == SEEK_END
660+
else:
661+
self.reset_io_buffer()
662+
663+
604664
class Connection(object):
605665

606666
CALLBACK_ERR_THREAD_THRESHOLD = 100
@@ -656,7 +716,6 @@ class Connection(object):
656716

657717
allow_beta_protocol_version = False
658718

659-
_iobuf = None
660719
_current_frame = None
661720

662721
_socket = None
@@ -667,6 +726,13 @@ class Connection(object):
667726
_check_hostname = False
668727
_product_type = None
669728

729+
_is_checksumming_enabled = False
730+
731+
@property
732+
def _iobuf(self):
733+
# backward compatibility, to avoid any change in the reactors
734+
return self._io_buffer.io_buffer
735+
670736
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
671737
ssl_options=None, sockopts=None, compression=True,
672738
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
@@ -690,7 +756,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
690756
self.no_compact = no_compact
691757
self._push_watchers = defaultdict(set)
692758
self._requests = {}
693-
self._iobuf = io.BytesIO()
759+
self._io_buffer = _ConnectionIOBuffer(self)
694760
self._continuous_paging_sessions = {}
695761
self._socket_writable = True
696762

@@ -831,6 +897,12 @@ def _connect_socket(self):
831897
for args in self.sockopts:
832898
self._socket.setsockopt(*args)
833899

900+
def _enable_checksumming(self):
    """
    Switch the connection to checksummed segments.

    Prepares the dedicated cql frame buffer, raises the checksumming flag and
    picks the segment codec: the lz4 one when a compressor is set on the
    connection, the uncompressed one otherwise.
    """
    # The frame buffer must exist before the flag is raised: once the flag is
    # set, the buffer wrapper hands out the dedicated cql frame buffer.
    self._io_buffer.set_checksumming_buffer()
    self._is_checksumming_enabled = True

    if self.compressor:
        codec = segment_codec_lz4
    else:
        codec = segment_codec_no_compression
    self._segment_codec = codec

    log.debug("Enabling protocol checksumming on connection (%s).", id(self))
905+
834906
def close(self):
835907
raise NotImplementedError()
836908

@@ -933,7 +1005,14 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message,
9331005
# queue the decoder function with the request
9341006
# this allows us to inject custom functions per request to encode, decode messages
9351007
self._requests[request_id] = (cb, decoder, result_metadata)
936-
msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, allow_beta_protocol_version=self.allow_beta_protocol_version)
1008+
msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor,
1009+
allow_beta_protocol_version=self.allow_beta_protocol_version)
1010+
1011+
if self._is_checksumming_enabled:
1012+
buffer = io.BytesIO()
1013+
self._segment_codec.encode(buffer, msg)
1014+
msg = buffer.getvalue()
1015+
9371016
self.push(msg)
9381017
return len(msg)
9391018

@@ -1012,7 +1091,7 @@ def control_conn_disposed(self):
10121091

10131092
@defunct_on_error
10141093
def _read_frame_header(self):
1015-
buf = self._iobuf.getvalue()
1094+
buf = self._io_buffer.cql_frame_buffer.getvalue()
10161095
pos = len(buf)
10171096
if pos:
10181097
version = int_from_buf_item(buf[0]) & PROTOCOL_VERSION_MASK
@@ -1028,29 +1107,51 @@ def _read_frame_header(self):
10281107
self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size)
10291108
return pos
10301109

1031-
def _reset_frame(self):
1032-
self._iobuf = io.BytesIO(self._iobuf.read())
1033-
self._iobuf.seek(0, 2) # io.SEEK_END == 2 (constant not present in 2.6)
1034-
self._current_frame = None
1110+
@defunct_on_error
def _process_segment_buffer(self):
    """
    Try to decode one segment from the io buffer and append its payload to
    the cql frame buffer.

    Nothing is decoded while fewer bytes than a segment header (with its crc)
    are available, or while the segment announced by the header is not fully
    received yet.

    Raises CrcMismatchException when the codec reports a CrcException.
    """
    available = self._io_buffer.readable_io_bytes()
    if available < self._segment_codec.header_length_with_crc:
        return

    try:
        # the header is decoded from the start of the io buffer
        self._io_buffer.io_buffer.seek(0)
        header = self._segment_codec.decode_header(self._io_buffer.io_buffer)

        if available < header.segment_length:
            # not enough data to read the segment
            # NOTE(review): process_io_buffer calls reset_io_buffer() right
            # after this method, and that call keeps only the bytes located
            # after the cursor. With the cursor moved to the end, the partial
            # segment looks like it gets dropped -- confirm.
            self._io_buffer.io_buffer.seek(0, 2)
            return

        segment = self._segment_codec.decode(self._iobuf, header)
        self._io_buffer.cql_frame_buffer.write(segment.payload)
    except CrcException as exc:
        # re-raise an exception that inherits from ConnectionException
        raise CrcMismatchException(str(exc), self.endpoint)
10351126

10361127
def process_io_buffer(self):
10371128
while True:
1129+
if self._is_checksumming_enabled:
1130+
self._process_segment_buffer()
1131+
self._io_buffer.reset_io_buffer()
1132+
10381133
if not self._current_frame:
10391134
pos = self._read_frame_header()
10401135
else:
1041-
pos = self._iobuf.tell()
1136+
pos = self._io_buffer.readable_cql_frame_bytes()
10421137

10431138
if not self._current_frame or pos < self._current_frame.end_pos:
1139+
if self._is_checksumming_enabled and self._io_buffer.readable_io_bytes():
1140+
# We have a multi-segments message and we need to read more
1141+
# data to complete the current cql frame
1142+
continue
1143+
10441144
# we don't have a complete header yet or we
10451145
# already saw a header, but we don't have a
10461146
# complete message yet
10471147
return
10481148
else:
10491149
frame = self._current_frame
1050-
self._iobuf.seek(frame.body_offset)
1051-
msg = self._iobuf.read(frame.end_pos - frame.body_offset)
1150+
self._io_buffer.cql_frame_buffer.seek(frame.body_offset)
1151+
msg = self._io_buffer.cql_frame_buffer.read(frame.end_pos - frame.body_offset)
10521152
self.process_msg(frame, msg)
1053-
self._reset_frame()
1153+
self._io_buffer.reset_cql_frame_buffer()
1154+
self._current_frame = None
10541155

10551156
@defunct_on_error
10561157
def process_msg(self, header, body):
@@ -1185,11 +1286,19 @@ def _handle_options_response(self, options_response):
11851286
compression_type = k
11861287
break
11871288

1188-
# set the decompressor here, but set the compressor only after
1189-
# a successful Ready message
1190-
self._compression_type = compression_type
1191-
self._compressor, self.decompressor = \
1192-
locally_supported_compressions[compression_type]
1289+
# If snappy compression is selected with v5+checksumming, the connection
1290+
# will fail with OTO. Only lz4 is supported
1291+
if (compression_type == 'snappy' and
1292+
ProtocolVersion.has_checksumming_support(self.protocol_version)):
1293+
log.debug("Snappy compression is not supported with protocol version %s and "
1294+
"checksumming. Consider installing lz4. Disabling compression.", self.protocol_version)
1295+
compression_type = None
1296+
else:
1297+
# set the decompressor here, but set the compressor only after
1298+
# a successful Ready message
1299+
self._compression_type = compression_type
1300+
self._compressor, self.decompressor = \
1301+
locally_supported_compressions[compression_type]
11931302

11941303
self._send_startup_message(compression_type, no_compact=self.no_compact)
11951304

@@ -1210,6 +1319,7 @@ def _send_startup_message(self, compression=None, no_compact=False):
12101319
def _handle_startup_response(self, startup_response, did_authenticate=False):
12111320
if self.is_defunct:
12121321
return
1322+
12131323
if isinstance(startup_response, ReadyMessage):
12141324
if self.authenticator:
12151325
log.warning("An authentication challenge was not sent, "
@@ -1220,6 +1330,10 @@ def _handle_startup_response(self, startup_response, did_authenticate=False):
12201330
log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.endpoint)
12211331
if self._compressor:
12221332
self.compressor = self._compressor
1333+
1334+
if ProtocolVersion.has_checksumming_support(self.protocol_version):
1335+
self._enable_checksumming()
1336+
12231337
self.connected_event.set()
12241338
elif isinstance(startup_response, AuthenticateMessage):
12251339
log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s",

cassandra/marshal.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def _make_packer(format_string):
2828
int8_pack, int8_unpack = _make_packer('>b')
2929
uint64_pack, uint64_unpack = _make_packer('>Q')
3030
uint32_pack, uint32_unpack = _make_packer('>I')
31+
uint32_le_pack, uint32_le_unpack = _make_packer('<I')
3132
uint16_pack, uint16_unpack = _make_packer('>H')
3233
uint8_pack, uint8_unpack = _make_packer('>B')
3334
float_pack, float_unpack = _make_packer('>f')

cassandra/protocol.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
UserAggregateDescriptor, SchemaTargetType)
3232
from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
3333
uint8_pack, int8_unpack, uint64_pack, header_pack,
34-
v3_header_pack, uint32_pack)
34+
v3_header_pack, uint32_pack, uint32_le_unpack, uint32_le_pack)
3535
from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
3636
CounterColumnType, DateType, DecimalType,
3737
DoubleType, FloatType, Int32Type,
@@ -1115,7 +1115,9 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta
11151115
msg.send_body(body, protocol_version)
11161116
body = body.getvalue()
11171117

1118-
if compressor and len(body) > 0:
1118+
# With checksumming, the compression is done at the segment frame encoding
1119+
if (not ProtocolVersion.has_checksumming_support(protocol_version)
1120+
and compressor and len(body) > 0):
11191121
body = compressor(body)
11201122
flags |= COMPRESSED_FLAG
11211123

@@ -1155,7 +1157,8 @@ def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcod
11551157
:param decompressor: optional decompression function to inflate the body
11561158
:return: a message decoded from the body and frame attributes
11571159
"""
1158-
if flags & COMPRESSED_FLAG:
1160+
if (not ProtocolVersion.has_checksumming_support(protocol_version) and
1161+
flags & COMPRESSED_FLAG):
11591162
if decompressor is None:
11601163
raise RuntimeError("No de-compressor available for compressed frame!")
11611164
body = decompressor(body)
@@ -1271,6 +1274,33 @@ def read_int(f):
12711274
return int32_unpack(f.read(4))
12721275

12731276

1277+
def read_uint_le(f, size=4):
    """
    Read ``size`` little endian bytes from ``f`` and return them as an
    unsigned integer.

    The 4 bytes case is unpacked in one go; any other size is assembled one
    byte at a time, least significant byte first.
    """
    if size == 4:
        return uint32_le_unpack(f.read(4))

    result = 0
    shift = 0
    for _ in range(size):
        # mask to 8 bits so that every byte contributes an unsigned value
        result |= (read_byte(f) & 0xFF) << shift
        shift += 8
    return result
1290+
1291+
1292+
def write_uint_le(f, i, size=4):
    """
    Write the unsigned integer ``i`` to ``f`` as ``size`` little endian bytes.

    The 4 bytes case is packed in one go; any other size is written one byte
    at a time, least significant byte first.
    """
    if size == 4:
        f.write(uint32_le_pack(i))
        return

    for shift in range(0, size * 8, 8):
        write_byte(f, (i >> shift) & 0xFF)
1302+
1303+
12741304
def write_int(f, i):
12751305
f.write(int32_pack(i))
12761306

0 commit comments

Comments
 (0)