Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
from cassandra.connection import (ClientRoutesEndPointFactory, ConnectionException, ConnectionShutdown,
ConnectionHeartbeat, ProtocolVersionUnsupported,
EndPoint, DefaultEndPoint, DefaultEndPointFactory,
SniEndPointFactory, ConnectionBusy, locally_supported_compressions)
SniEndPointFactory, ConnectionBusy, locally_supported_compressions,
SSLSessionCache)
from cassandra.cqltypes import UserType
import cassandra.cqltypes as types
from cassandra.encoder import Encoder
Expand Down Expand Up @@ -876,6 +877,26 @@
.. versionadded:: 3.17.0
"""

ssl_session_cache = None
"""
An optional :class:`.connection.SSLSessionCache` instance used to enable TLS
session resumption (via session tickets or PSK) for all connections managed
by this cluster.

When :attr:`~Cluster.ssl_context` or :attr:`~Cluster.ssl_options` are set,
a cache is created automatically so that reconnections to the same host can
skip the full TLS handshake. Set this to :const:`None` explicitly to
disable session caching.

Note: TLS 1.2 sessions are cached immediately after connect. TLS 1.3
sessions are cached after the CQL handshake completes (Ready / AuthSuccess),
because session tickets are sent asynchronously by the server.

Note: only the stdlib ``ssl`` reactor paths are supported (asyncore, libev,
gevent, asyncio). Twisted and Eventlet connections use pyOpenSSL and are
not covered by this cache.
"""

sockopts = None
"""
An optional list of tuples which will be used as arguments to
Expand Down Expand Up @@ -1217,7 +1238,8 @@
metadata_request_timeout: Optional[float] = None,
column_encryption_policy=None,
application_info:Optional[ApplicationInfoBase]=None,
client_routes_config:Optional[ClientRoutesConfig]=None
client_routes_config:Optional[ClientRoutesConfig]=None,
ssl_session_cache=_NOT_SET
):
"""
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
Expand Down Expand Up @@ -1461,6 +1483,30 @@

self.ssl_options = ssl_options
self.ssl_context = ssl_context

# Auto-create a session cache when TLS is enabled, unless the caller
# explicitly passed ssl_session_cache (including None to opt out).
if ssl_session_cache is _NOT_SET:
if ssl_context is not None or ssl_options is not None:
self.ssl_session_cache = SSLSessionCache()
else:
self.ssl_session_cache = None
else:
self.ssl_session_cache = ssl_session_cache

# Warn when the session cache won't be used because the connection
# class uses pyOpenSSL instead of the stdlib ssl module.
if self.ssl_session_cache is not None:
uses_twisted = TwistedConnection and issubclass(self.connection_class, TwistedConnection)
uses_eventlet = EventletConnection and issubclass(self.connection_class, EventletConnection)
if uses_twisted or uses_eventlet:
log.warning(
"ssl_session_cache is set but the connection class %s uses "
"pyOpenSSL, which does not support stdlib ssl session "
"resumption. The cache will have no effect.",
self.connection_class.__name__,
)

self.sockopts = sockopts
self.cql_version = cql_version
self.max_schema_agreement_wait = max_schema_agreement_wait
Expand Down Expand Up @@ -1706,6 +1752,7 @@
kwargs_dict.setdefault('sockopts', self.sockopts)
kwargs_dict.setdefault('ssl_options', self.ssl_options)
kwargs_dict.setdefault('ssl_context', self.ssl_context)
kwargs_dict.setdefault('ssl_session_cache', self.ssl_session_cache)
kwargs_dict.setdefault('cql_version', self.cql_version)
kwargs_dict.setdefault('protocol_version', self.protocol_version)
kwargs_dict.setdefault('user_type_map', self._user_types)
Expand Down Expand Up @@ -4340,7 +4387,7 @@
self._scheduled_tasks.discard(task)
fn, args, kwargs = task
kwargs = dict(kwargs)
future = self._executor.submit(fn, *args, **kwargs)

Check failure on line 4390 in cassandra/cluster.py

View workflow job for this annotation

GitHub Actions / test libev (3.12)

cannot schedule new futures after shutdown
future.add_done_callback(self._log_if_failed)
else:
self._queue.put_nowait((run_at, i, task))
Expand Down
106 changes: 103 additions & 3 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import socket
import struct
import sys
from threading import Thread, Event, RLock, Condition
from threading import Thread, Event, Lock, RLock, Condition
import time
import ssl
import uuid
Expand Down Expand Up @@ -783,6 +783,45 @@ def generate(self, shard_id: int, total_shards: int):
DefaultShardAwarePortGenerator = ShardAwarePortGenerator(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH)


class SSLSessionCache(object):
"""
A thread-safe cache of ``ssl.SSLSession`` objects, keyed by connection TLS
identity.

When TLS is enabled, the driver stores the negotiated session after each
successful handshake and reuses it for subsequent connections to the same
host, enabling TLS session resumption (tickets / PSK) without any extra
configuration.

This cache is created automatically by :class:`.Cluster` when
``ssl_context`` or ``ssl_options`` are set. Pass ``ssl_session_cache=None``
to :class:`.Cluster` to opt out.

Note: only the stdlib ``ssl`` module is supported. Twisted and Eventlet
connections use pyOpenSSL, which has a different session API and is not
covered by this cache.
"""

def __init__(self):
self._lock = Lock()
self._cache = {}

def get(self, key):
"""
Return the cached ``ssl.SSLSession`` for ``key``, or ``None`` if none
is stored yet.
"""
with self._lock:
return self._cache.get(key)

def set(self, key, session):
"""
Store ``session`` for ``key``.
"""
with self._lock:
self._cache[key] = session


class Connection(object):

CALLBACK_ERR_THREAD_THRESHOLD = 100
Expand All @@ -803,6 +842,7 @@ class Connection(object):
endpoint = None
ssl_options = None
ssl_context = None
_ssl_session_cache = None
last_error = None

# The current number of operations that are in flight. More precisely,
Expand Down Expand Up @@ -880,13 +920,15 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False,
ssl_context=None, owning_pool=None, shard_id=None, total_shards=None,
on_orphaned_stream_released=None, application_info: Optional[ApplicationInfoBase] = None):
on_orphaned_stream_released=None, application_info: Optional[ApplicationInfoBase] = None,
ssl_session_cache=None):
# TODO next major rename host to endpoint and remove port kwarg.
self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port)

self.authenticator = authenticator
self.ssl_options = ssl_options.copy() if ssl_options else {}
self.ssl_context = ssl_context
self._ssl_session_cache = ssl_session_cache
self.sockopts = sockopts
self.compression = compression
self.cql_version = cql_version
Expand Down Expand Up @@ -1029,7 +1071,25 @@ def _wrap_socket_from_context(self):
server_hostname = self.endpoint.address
opts['server_hostname'] = server_hostname

return self.ssl_context.wrap_socket(self._socket, **opts)
ssl_sock = self.ssl_context.wrap_socket(self._socket, **opts)

# Restore a previously cached session to enable TLS session resumption
# (session tickets / PSK). The session must be set *after*
# wrap_socket() (which only creates the SSLSocket) but *before*
# connect(), because connect() triggers the actual TLS handshake
# (via do_handshake_on_connect, which defaults to True).
# _initiate_connection, called after this method returns, performs
# the connect().
if self._ssl_session_cache is not None:
cached_session = self._ssl_session_cache.get(
self._ssl_session_cache_key())
if cached_session is not None:
try:
ssl_sock.session = cached_session
except (AttributeError, ssl.SSLError):
log.debug("Could not restore TLS session for %s", self.endpoint)

return ssl_sock

def _initiate_connection(self, sockaddr):
if self.features.shard_id is not None:
Expand All @@ -1043,6 +1103,30 @@ def _initiate_connection(self, sockaddr):

self._socket.connect(sockaddr)

def _cache_tls_session_if_needed(self):
"""
Store the current TLS session in the cache (if any) so that future
connections to the same endpoint can resume it.
"""
if self._ssl_session_cache is not None and self.ssl_context is not None:
session = getattr(self._socket, 'session', None)
if session is not None:
self._ssl_session_cache.set(self._ssl_session_cache_key(), session)

def _ssl_session_cache_key(self):
"""
Return a cache key that matches the TLS peer identity.

``server_hostname`` is included so that SNI-based connections routed
through the same proxy address/port do not overwrite each other's
sessions.
"""
return (
self.endpoint.address,
self.endpoint.port,
self.ssl_options.get('server_hostname') if self.ssl_options else None,
)

# PYTHON-1331
#
# Allow implementations specific to an event loop to add additional behaviours
Expand Down Expand Up @@ -1074,6 +1158,16 @@ def _connect_socket(self):
self._initiate_connection(sockaddr)
self._socket.settimeout(None)

# Cache the negotiated TLS session for future resumption.
# For TLS 1.2 the session is available right after connect().
# For TLS 1.3 the server sends the session ticket
# asynchronously after the first application-data exchange,
# so socket.session may still be None here; a second
# attempt is made in _cache_tls_session_if_needed() after
# the CQL handshake completes (see _handle_startup_response
# and _handle_auth_response).
self._cache_tls_session_if_needed()

local_addr = self._socket.getsockname()
log.debug("Connection %s: '%s' -> '%s'", id(self), local_addr, sockaddr)

Expand Down Expand Up @@ -1578,6 +1672,9 @@ def _handle_startup_response(self, startup_response, did_authenticate=False):
if ProtocolVersion.has_checksumming_support(self.protocol_version):
self._enable_checksumming()

# TLS 1.3: the session ticket is sent after the first
# application-data exchange, so try caching it now.
self._cache_tls_session_if_needed()
self.connected_event.set()
elif isinstance(startup_response, AuthenticateMessage):
log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s",
Expand Down Expand Up @@ -1634,6 +1731,9 @@ def _handle_auth_response(self, auth_response):
self.authenticator.on_authentication_success(auth_response.token)
if self._compressor:
self.compressor = self._compressor
# TLS 1.3: the session ticket is sent after the first
# application-data exchange, so try caching it now.
self._cache_tls_session_if_needed()
self.connected_event.set()
elif isinstance(auth_response, AuthChallengeMessage):
response = self.authenticator.evaluate_challenge(auth_response.challenge)
Expand Down
93 changes: 93 additions & 0 deletions tests/unit/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion
from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \
ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT
from cassandra.connection import SSLSessionCache
from cassandra.pool import Host
from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy
from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory
Expand Down Expand Up @@ -634,3 +635,95 @@ def test_no_warning_adding_lbp_ep_to_cluster_with_contact_points(self):
)

patched_logger.warning.assert_not_called()


class TestSSLSessionCacheAutoCreation(unittest.TestCase):

def test_cache_created_when_ssl_context_set(self):
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
cluster = Cluster(contact_points=['127.0.0.1'], ssl_context=ctx)
assert isinstance(cluster.ssl_session_cache, SSLSessionCache)

def test_cache_created_when_ssl_options_set(self):
cluster = Cluster(contact_points=['127.0.0.1'], ssl_options={'ca_certs': '/dev/null'})
assert isinstance(cluster.ssl_session_cache, SSLSessionCache)

def test_no_cache_when_tls_not_enabled(self):
cluster = Cluster(contact_points=['127.0.0.1'])
assert cluster.ssl_session_cache is None

def test_explicit_none_disables_cache(self):
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
cluster = Cluster(contact_points=['127.0.0.1'], ssl_context=ctx,
ssl_session_cache=None)
assert cluster.ssl_session_cache is None

def test_explicit_custom_cache_used(self):
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
custom = SSLSessionCache()
cluster = Cluster(contact_points=['127.0.0.1'], ssl_context=ctx,
ssl_session_cache=custom)
assert cluster.ssl_session_cache is custom

def test_cache_passed_to_connection_factory(self):
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
endpoint = Mock(address='127.0.0.1')
with patch.object(Cluster.connection_class, 'factory', autospec=True, return_value='connection') as factory:
cluster = Cluster(contact_points=['127.0.0.1'], ssl_context=ctx)
cluster.connection_factory(endpoint)

assert factory.call_args.kwargs['ssl_session_cache'] is cluster.ssl_session_cache

def test_warning_for_eventlet_connection_class(self):
"""A warning is logged when ssl_session_cache is set with EventletConnection."""
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE

# Create a real class so issubclass() works throughout Cluster.__init__
from cassandra.connection import Connection as BaseConn
class FakeEventletConnection(BaseConn):
pass

with patch('cassandra.cluster.EventletConnection', FakeEventletConnection, create=True), \
patch('cassandra.cluster.log') as patched_logger:
Cluster(contact_points=['127.0.0.1'], ssl_context=ctx,
connection_class=FakeEventletConnection)

# At least one warning about pyOpenSSL
warning_calls = [c for c in patched_logger.warning.call_args_list
if 'pyOpenSSL' in str(c)]
assert len(warning_calls) == 1

def test_warning_for_twisted_connection_class(self):
"""A warning is logged when ssl_session_cache is set with TwistedConnection."""
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE

from cassandra.connection import Connection as BaseConn
class FakeTwistedConnection(BaseConn):
pass

with patch('cassandra.cluster.TwistedConnection', FakeTwistedConnection, create=True), \
patch('cassandra.cluster.log') as patched_logger:
Cluster(contact_points=['127.0.0.1'], ssl_context=ctx,
connection_class=FakeTwistedConnection)

warning_calls = [c for c in patched_logger.warning.call_args_list
if 'pyOpenSSL' in str(c)]
assert len(warning_calls) == 1
Loading
Loading