2727from threading import Thread , Event , RLock , Condition
2828import time
2929import ssl
30+ import weakref
31+
3032
3133if 'gevent.monkey' in sys .modules :
3234 from gevent .queue import Queue , Empty
4244 AuthResponseMessage , AuthChallengeMessage ,
4345 AuthSuccessMessage , ProtocolException ,
4446 RegisterMessage , ReviseRequestMessage )
47+ from cassandra .segment import SegmentCodec , CrcException
4548from cassandra .util import OrderedDict
4649
4750
4851log = logging .getLogger (__name__ )
4952
53+ segment_codec_no_compression = SegmentCodec ()
54+ segment_codec_lz4 = None
55+
5056# We use an ordered dictionary and specifically add lz4 before
5157# snappy so that lz4 will be preferred. Changing the order of this
5258# will change the compression preferences for the driver.
@@ -88,6 +94,7 @@ def lz4_decompress(byts):
8894 return lz4_block .decompress (byts [3 ::- 1 ] + byts [4 :])
8995
9096 locally_supported_compressions ['lz4' ] = (lz4_compress , lz4_decompress )
97+ segment_codec_lz4 = SegmentCodec (lz4_compress , lz4_decompress )
9198
9299try :
93100 import snappy
@@ -426,6 +433,10 @@ class ProtocolError(Exception):
426433 pass
427434
428435
436+ class CrcMismatchException (ConnectionException ):
437+ pass
438+
439+
429440class ContinuousPagingState (object ):
430441 """
431442 A class for specifying continuous paging state, only supported starting with DSE_V2.
@@ -601,6 +612,55 @@ def int_from_buf_item(i):
601612 int_from_buf_item = ord
602613
603614
615+ class _ConnectionIOBuffer (object ):
616+ """
617+ Abstraction class to ease the use of the different connection io buffers. With
618+ protocol V5 and checksumming, the data is read, validated and copied to another
619+ cql frame buffer.
620+ """
621+ _io_buffer = None
622+ _cql_frame_buffer = None
623+ _connection = None
624+
625+ def __init__ (self , connection ):
626+ self ._io_buffer = io .BytesIO ()
627+ self ._connection = weakref .proxy (connection )
628+
629+ @property
630+ def io_buffer (self ):
631+ return self ._io_buffer
632+
633+ @property
634+ def cql_frame_buffer (self ):
635+ return self ._cql_frame_buffer if self .is_checksumming_enabled else \
636+ self ._io_buffer
637+
638+ def set_checksumming_buffer (self ):
639+ self .reset_io_buffer ()
640+ self ._cql_frame_buffer = io .BytesIO ()
641+
642+ @property
643+ def is_checksumming_enabled (self ):
644+ return self ._connection ._is_checksumming_enabled
645+
646+ def readable_io_bytes (self ):
647+ return self .io_buffer .tell ()
648+
649+ def readable_cql_frame_bytes (self ):
650+ return self .cql_frame_buffer .tell ()
651+
652+ def reset_io_buffer (self ):
653+ self ._io_buffer = io .BytesIO (self ._io_buffer .read ())
654+ self ._io_buffer .seek (0 , 2 ) # 2 == SEEK_END
655+
656+ def reset_cql_frame_buffer (self ):
657+ if self .is_checksumming_enabled :
658+ self ._cql_frame_buffer = io .BytesIO (self ._cql_frame_buffer .read ())
659+ self ._cql_frame_buffer .seek (0 , 2 ) # 2 == SEEK_END
660+ else :
661+ self .reset_io_buffer ()
662+
663+
604664class Connection (object ):
605665
606666 CALLBACK_ERR_THREAD_THRESHOLD = 100
@@ -656,7 +716,6 @@ class Connection(object):
656716
657717 allow_beta_protocol_version = False
658718
659- _iobuf = None
660719 _current_frame = None
661720
662721 _socket = None
@@ -667,6 +726,13 @@ class Connection(object):
667726 _check_hostname = False
668727 _product_type = None
669728
729+ _is_checksumming_enabled = False
730+
731+ @property
732+ def _iobuf (self ):
733+ # backward compatibility, to avoid any change in the reactors
734+ return self ._io_buffer .io_buffer
735+
670736 def __init__ (self , host = '127.0.0.1' , port = 9042 , authenticator = None ,
671737 ssl_options = None , sockopts = None , compression = True ,
672738 cql_version = None , protocol_version = ProtocolVersion .MAX_SUPPORTED , is_control_connection = False ,
@@ -690,7 +756,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
690756 self .no_compact = no_compact
691757 self ._push_watchers = defaultdict (set )
692758 self ._requests = {}
693- self ._iobuf = io . BytesIO ( )
759+ self ._io_buffer = _ConnectionIOBuffer ( self )
694760 self ._continuous_paging_sessions = {}
695761 self ._socket_writable = True
696762
@@ -831,6 +897,12 @@ def _connect_socket(self):
831897 for args in self .sockopts :
832898 self ._socket .setsockopt (* args )
833899
900+ def _enable_checksumming (self ):
901+ self ._io_buffer .set_checksumming_buffer ()
902+ self ._is_checksumming_enabled = True
903+ self ._segment_codec = segment_codec_lz4 if self .compressor else segment_codec_no_compression
904+ log .debug ("Enabling protocol checksumming on connection (%s)." , id (self ))
905+
834906 def close (self ):
835907 raise NotImplementedError ()
836908
@@ -933,7 +1005,14 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message,
9331005 # queue the decoder function with the request
9341006 # this allows us to inject custom functions per request to encode, decode messages
9351007 self ._requests [request_id ] = (cb , decoder , result_metadata )
936- msg = encoder (msg , request_id , self .protocol_version , compressor = self .compressor , allow_beta_protocol_version = self .allow_beta_protocol_version )
1008+ msg = encoder (msg , request_id , self .protocol_version , compressor = self .compressor ,
1009+ allow_beta_protocol_version = self .allow_beta_protocol_version )
1010+
1011+ if self ._is_checksumming_enabled :
1012+ buffer = io .BytesIO ()
1013+ self ._segment_codec .encode (buffer , msg )
1014+ msg = buffer .getvalue ()
1015+
9371016 self .push (msg )
9381017 return len (msg )
9391018
@@ -1012,7 +1091,7 @@ def control_conn_disposed(self):
10121091
10131092 @defunct_on_error
10141093 def _read_frame_header (self ):
1015- buf = self ._iobuf .getvalue ()
1094+ buf = self ._io_buffer . cql_frame_buffer .getvalue ()
10161095 pos = len (buf )
10171096 if pos :
10181097 version = int_from_buf_item (buf [0 ]) & PROTOCOL_VERSION_MASK
@@ -1028,29 +1107,51 @@ def _read_frame_header(self):
10281107 self ._current_frame = _Frame (version , flags , stream , op , header_size , body_len + header_size )
10291108 return pos
10301109
1031- def _reset_frame (self ):
1032- self ._iobuf = io .BytesIO (self ._iobuf .read ())
1033- self ._iobuf .seek (0 , 2 ) # io.SEEK_END == 2 (constant not present in 2.6)
1034- self ._current_frame = None
1110+ @defunct_on_error
1111+ def _process_segment_buffer (self ):
1112+ readable_bytes = self ._io_buffer .readable_io_bytes ()
1113+ if readable_bytes >= self ._segment_codec .header_length_with_crc :
1114+ try :
1115+ self ._io_buffer .io_buffer .seek (0 )
1116+ segment_header = self ._segment_codec .decode_header (self ._io_buffer .io_buffer )
1117+ if readable_bytes >= segment_header .segment_length :
1118+ segment = self ._segment_codec .decode (self ._iobuf , segment_header )
1119+ self ._io_buffer .cql_frame_buffer .write (segment .payload )
1120+ else :
1121+ # not enough data to read the segment
1122+ self ._io_buffer .io_buffer .seek (0 , 2 )
1123+ except CrcException as exc :
1124+ # re-raise an exception that inherits from ConnectionException
1125+ raise CrcMismatchException (str (exc ), self .endpoint )
10351126
10361127 def process_io_buffer (self ):
10371128 while True :
1129+ if self ._is_checksumming_enabled :
1130+ self ._process_segment_buffer ()
1131+ self ._io_buffer .reset_io_buffer ()
1132+
10381133 if not self ._current_frame :
10391134 pos = self ._read_frame_header ()
10401135 else :
1041- pos = self ._iobuf . tell ()
1136+ pos = self ._io_buffer . readable_cql_frame_bytes ()
10421137
10431138 if not self ._current_frame or pos < self ._current_frame .end_pos :
1139+ if self ._is_checksumming_enabled and self ._io_buffer .readable_io_bytes ():
1140+ # We have a multi-segments message and we need to read more
1141+ # data to complete the current cql frame
1142+ continue
1143+
10441144 # we don't have a complete header yet or we
10451145 # already saw a header, but we don't have a
10461146 # complete message yet
10471147 return
10481148 else :
10491149 frame = self ._current_frame
1050- self ._iobuf .seek (frame .body_offset )
1051- msg = self ._iobuf .read (frame .end_pos - frame .body_offset )
1150+ self ._io_buffer . cql_frame_buffer .seek (frame .body_offset )
1151+ msg = self ._io_buffer . cql_frame_buffer .read (frame .end_pos - frame .body_offset )
10521152 self .process_msg (frame , msg )
1053- self ._reset_frame ()
1153+ self ._io_buffer .reset_cql_frame_buffer ()
1154+ self ._current_frame = None
10541155
10551156 @defunct_on_error
10561157 def process_msg (self , header , body ):
@@ -1185,11 +1286,19 @@ def _handle_options_response(self, options_response):
11851286 compression_type = k
11861287 break
11871288
1188- # set the decompressor here, but set the compressor only after
1189- # a successful Ready message
1190- self ._compression_type = compression_type
1191- self ._compressor , self .decompressor = \
1192- locally_supported_compressions [compression_type ]
1289+ # If snappy compression is selected with v5+checksumming, the connection
1290+ # will fail with OTO. Only lz4 is supported
1291+ if (compression_type == 'snappy' and
1292+ ProtocolVersion .has_checksumming_support (self .protocol_version )):
1293+ log .debug ("Snappy compression is not supported with protocol version %s and "
1294+ "checksumming. Consider installing lz4. Disabling compression." , self .protocol_version )
1295+ compression_type = None
1296+ else :
1297+ # set the decompressor here, but set the compressor only after
1298+ # a successful Ready message
1299+ self ._compression_type = compression_type
1300+ self ._compressor , self .decompressor = \
1301+ locally_supported_compressions [compression_type ]
11931302
11941303 self ._send_startup_message (compression_type , no_compact = self .no_compact )
11951304
@@ -1210,6 +1319,7 @@ def _send_startup_message(self, compression=None, no_compact=False):
12101319 def _handle_startup_response (self , startup_response , did_authenticate = False ):
12111320 if self .is_defunct :
12121321 return
1322+
12131323 if isinstance (startup_response , ReadyMessage ):
12141324 if self .authenticator :
12151325 log .warning ("An authentication challenge was not sent, "
@@ -1220,6 +1330,10 @@ def _handle_startup_response(self, startup_response, did_authenticate=False):
12201330 log .debug ("Got ReadyMessage on new connection (%s) from %s" , id (self ), self .endpoint )
12211331 if self ._compressor :
12221332 self .compressor = self ._compressor
1333+
1334+ if ProtocolVersion .has_checksumming_support (self .protocol_version ):
1335+ self ._enable_checksumming ()
1336+
12231337 self .connected_event .set ()
12241338 elif isinstance (startup_response , AuthenticateMessage ):
12251339 log .debug ("Got AuthenticateMessage on new connection (%s) from %s: %s" ,
0 commit comments