Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cassandra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def emit(self, record):

logging.getLogger('cassandra').addHandler(NullHandler())

# Driver version; this release bumps the patch component to 6.
# (The paste contained both the pre- and post-bump tuples; the merged
# state of the PR keeps only (3, 24, 6).)
__version_info__ = (3, 24, 6)
__version__ = '.'.join(map(str, __version_info__))


Expand Down
78 changes: 43 additions & 35 deletions cassandra/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""
Connection pooling and host management.
"""

from concurrent.futures import Future
from functools import total_ordering
import logging
import socket
Expand Down Expand Up @@ -411,6 +411,7 @@ def __init__(self, host, host_distance, session):
# and are waiting until all requests time out or complete
# so that we can dispose of them.
self._trash = set()
self._shard_connections_futures = []

if host_distance == HostDistance.IGNORED:
log.debug("Not opening connection to ignored host %s", self.host)
Expand Down Expand Up @@ -537,9 +538,9 @@ def return_connection(self, connection, stream_was_orphaned=False):
if is_down:
self.shutdown()
else:
connection.close()
del self._connections[connection.shard_id]
with self._lock:
connection.close()
self._connections.pop(connection.shard_id, None)
if self._is_replacing:
return
self._is_replacing = True
Expand Down Expand Up @@ -568,23 +569,22 @@ def _replace(self, connection):
if self.is_shutdown:
return

log.debug("Replacing connection (%s) to %s", id(connection), self.host)
try:
if connection.shard_id in self._connections.keys():
del self._connections[connection.shard_id]
if self.host.sharding_info:
self._connecting.add(connection.shard_id)
self._open_connection_to_missing_shard(connection.shard_id)
log.debug("Replacing connection (%s) to %s", id(connection), self.host)
try:
if connection.shard_id in self._connections.keys():
del self._connections[connection.shard_id]
if self.host.sharding_info:
self._connecting.add(connection.shard_id)
self._session.submit(self._open_connection_to_missing_shard, connection.shard_id)
else:
connection = self._session.cluster.connection_factory(self.host.endpoint, owning_pool=self)
if self._keyspace:
connection.set_keyspace_blocking(self._keyspace)
self._connections[connection.shard_id] = connection
except Exception:
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
self._session.submit(self._replace, connection)
else:
connection = self._session.cluster.connection_factory(self.host.endpoint, owning_pool=self)
if self._keyspace:
connection.set_keyspace_blocking(self._keyspace)
self._connections[connection.shard_id] = connection
except Exception:
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
self._session.submit(self._replace, connection)
else:
with self._lock:
self._is_replacing = False
self._stream_available_condition.notify()

Expand All @@ -597,11 +597,14 @@ def shutdown(self):
self.is_shutdown = True
self._stream_available_condition.notify_all()

if self._connections:
for c in self._connections.values():
log.debug("Closing connection (%s) to %s", id(c), self.host)
c.close()
self._connections = {}
for future in self._shard_connections_futures:
future.cancel()

if self._connections:
for connection in self._connections.values():
log.debug("Closing connection (%s) to %s", id(connection), self.host)
connection.close()
self._connections.clear()

self._close_excess_connections()

Expand All @@ -620,7 +623,7 @@ def _close_excess_connections(self):
if not self._excess_connections:
return
conns = self._excess_connections
self._excess_connections = set()
self._excess_connections.clear()

for c in conns:
log.debug("Closing excess connection (%s) to %s", id(c), self.host)
Expand Down Expand Up @@ -653,7 +656,9 @@ def _open_connection_to_missing_shard(self, shard_id):
if self.is_shutdown:
log.debug("Pool for host %s is in shutdown, closing the new connection (%s)", id(conn), self.host)
conn.close()
elif conn.shard_id not in self._connections.keys() or self._connections[conn.shard_id].orphaned_threshold_reached:
return
old_conn = self._connections.get(conn.shard_id)
if old_conn is None or old_conn.orphaned_threshold_reached:
log.debug(
"New connection (%s) created to shard_id=%i on host %s",
id(conn),
Expand Down Expand Up @@ -698,7 +703,8 @@ def _open_connection_to_missing_shard(self, shard_id):
else:
self._trash.add(old_conn)
if self._keyspace:
self._connections[conn.shard_id].set_keyspace_blocking(self._keyspace)
if old_conn := self._connections.get(conn.shard_id):
old_conn.set_keyspace_blocking(self._keyspace)
num_missing_or_needing_replacement = self.num_missing_or_needing_replacement
log.debug(
"Connected to %s/%i shards on host %s (%i missing or needs replacement)",
Expand Down Expand Up @@ -750,9 +756,11 @@ def _open_connections_for_all_shards(self):
if self.is_shutdown:
return

for shard_id in range(self.host.sharding_info.shards_count):
self._connecting.add(shard_id)
self._session.submit(self._open_connection_to_missing_shard, shard_id)
for shard_id in range(self.host.sharding_info.shards_count):
future = self._session.submit(self._open_connection_to_missing_shard, shard_id)
if isinstance(future, Future):
self._connecting.add(shard_id)
self._shard_connections_futures.append(future)

def _set_keyspace_for_all_conns(self, keyspace, callback):
"""
Expand All @@ -779,15 +787,15 @@ def connection_finished_setting_keyspace(conn, error):
callback(self, errors)

self._keyspace = keyspace
for conn in self._connections.values():
for conn in list(self._connections.values()):
conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace)

def get_connections(self):
    """Return the pool's current connections as a list.

    Returns an empty list when the pool has no connections. Binding
    ``self._connections`` to a local first gives one consistent read of
    the attribute even if another thread swaps the dict concurrently.
    """
    # NOTE(review): the diff paste contained both the old (`c = ...`) and
    # new (`connections = ...`) variants; this is the merged version.
    connections = self._connections
    return list(connections.values()) if connections else []

def get_state(self):
    """Return a snapshot dict of pool health.

    Keys: ``shutdown`` (bool), ``open_count`` (int), ``in_flights``
    (list of per-connection in-flight request counts).
    """
    # list() copies the values up front so concurrent mutation of
    # self._connections during iteration cannot raise RuntimeError.
    in_flights = [c.in_flight for c in list(self._connections.values())]
    return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights}

@property
Expand All @@ -797,7 +805,7 @@ def num_missing_or_needing_replacement(self):

@property
def open_count(self):
    """Number of live connections: present, not closed, not defunct.

    ``None`` placeholders in the dict are skipped by the truthiness
    check on ``c``.
    """
    # list() copies the values so concurrent mutation of _connections
    # during iteration cannot raise RuntimeError; generator form replaces
    # the old `sum([1 if ... else 0 ...])` list-building idiom.
    return sum(1 for c in list(self._connections.values())
               if c and not (c.is_closed or c.is_defunct))

@property
def _excess_connection_limit(self):
Expand Down
62 changes: 59 additions & 3 deletions tests/unit/test_host_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor
import logging
import time

from cassandra.shard_info import _ShardingInfo

try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
import unittest # noqa
import unittest.mock as mock

from mock import Mock, NonCallableMagicMock
from mock import Mock, NonCallableMagicMock, MagicMock
from threading import Thread, Event, Lock

from cassandra.cluster import Session
Expand All @@ -26,6 +32,8 @@
from cassandra.pool import Host, NoConnectionsAvailable
from cassandra.policies import HostDistance, SimpleConvictionPolicy

LOGGER = logging.getLogger(__name__)


class _PoolTests(unittest.TestCase):
__test__ = False
Expand Down Expand Up @@ -79,7 +87,8 @@ def test_failed_wait_for_connection(self):
def test_successful_wait_for_connection(self):
host = Mock(spec=Host, address='ip1')
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100, lock=Lock())
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100,
lock=Lock())
session.cluster.connection_factory.return_value = conn

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
Expand Down Expand Up @@ -266,3 +275,50 @@ class HostConnectionTests(_PoolTests):
PoolImpl = HostConnection
uses_single_connection = True

def test_fast_shutdown(self):
    """Regression test: repeatedly build a pool and shut it down while
    its shard-connection futures may still be pending in the executor,
    verifying shutdown does not hang or race with in-flight connects.
    """

    class MockSession(MagicMock):
        # Class-level defaults read by the pool under test.
        is_shutdown = False
        keyspace = "reprospace"

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.connection_counter = 0
            self.cluster = MagicMock()
            # The slow initializer delays worker start-up so shutdown()
            # can race against still-queued connection attempts.
            self.cluster.executor = ThreadPoolExecutor(max_workers=2, initializer=self.executor_init)
            self.cluster.signal_connection_failure = lambda *args, **kwargs: False
            self.cluster.connection_factory = self.mock_connection_factory

        def submit(self, fn, *args, **kwargs):
            LOGGER.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs)
            if not self.is_shutdown:
                return self.cluster.executor.submit(fn, *args, **kwargs)

        def mock_connection_factory(self, *args, **kwargs):
            # Each fake connection gets a fresh, monotonically
            # increasing shard_id.
            conn = MagicMock()
            conn.is_shutdown = False
            conn.is_defunct = False
            conn.is_closed = False
            conn.shard_id = self.connection_counter
            self.connection_counter += 1
            conn.sharding_info = _ShardingInfo(shard_id=1, shards_count=14,
                                               partitioner="", sharding_algorithm="", sharding_ignore_msb=0)
            return conn

        def executor_init(self, *args):
            time.sleep(0.5)
            LOGGER.info("Future start: %s", args)

    for attempt in range(20):
        LOGGER.info("Testing fast shutdown %d / 20 times", attempt + 1)
        host = MagicMock()
        host.endpoint = "1.2.3.4"
        session = MockSession()

        pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
        LOGGER.info("Initialized pool %s", pool)
        LOGGER.info("Connections: %s", pool._connections)
        time.sleep(0.5)
        pool.shutdown()
        time.sleep(3)
        session.cluster.executor.shutdown()