diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 9eace8810d..70a34fcdc9 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -52,7 +52,8 @@ from cassandra.connection import (ClientRoutesEndPointFactory, ConnectionException, ConnectionShutdown, ConnectionHeartbeat, ProtocolVersionUnsupported, EndPoint, DefaultEndPoint, DefaultEndPointFactory, - SniEndPointFactory, ConnectionBusy, locally_supported_compressions) + SniEndPointFactory, ConnectionBusy, locally_supported_compressions, + SSLSessionCache) from cassandra.cqltypes import UserType import cassandra.cqltypes as types from cassandra.encoder import Encoder @@ -876,6 +877,26 @@ def default_retry_policy(self, policy): .. versionadded:: 3.17.0 """ + ssl_session_cache = None + """ + An optional :class:`.connection.SSLSessionCache` instance used to enable TLS + session resumption (via session tickets or PSK) for all connections managed + by this cluster. + + When :attr:`~Cluster.ssl_context` or :attr:`~Cluster.ssl_options` are set, + a cache is created automatically so that reconnections to the same host can + skip the full TLS handshake. Set this to :const:`None` explicitly to + disable session caching. + + Note: TLS 1.2 sessions are cached immediately after connect. TLS 1.3 + sessions are cached after the CQL handshake completes (Ready / AuthSuccess), + because session tickets are sent asynchronously by the server. + + Note: only the stdlib ``ssl`` reactor paths are supported (asyncore, libev, + gevent, asyncio). Twisted and Eventlet connections use pyOpenSSL and are + not covered by this cache. 
+ """ + sockopts = None """ An optional list of tuples which will be used as arguments to @@ -1217,7 +1238,8 @@ def __init__(self, metadata_request_timeout: Optional[float] = None, column_encryption_policy=None, application_info:Optional[ApplicationInfoBase]=None, - client_routes_config:Optional[ClientRoutesConfig]=None + client_routes_config:Optional[ClientRoutesConfig]=None, + ssl_session_cache=_NOT_SET ): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as @@ -1461,6 +1483,30 @@ def __init__(self, self.ssl_options = ssl_options self.ssl_context = ssl_context + + # Auto-create a session cache when TLS is enabled, unless the caller + # explicitly passed ssl_session_cache (including None to opt out). + if ssl_session_cache is _NOT_SET: + if ssl_context is not None or ssl_options is not None: + self.ssl_session_cache = SSLSessionCache() + else: + self.ssl_session_cache = None + else: + self.ssl_session_cache = ssl_session_cache + + # Warn when the session cache won't be used because the connection + # class uses pyOpenSSL instead of the stdlib ssl module. + if self.ssl_session_cache is not None: + uses_twisted = TwistedConnection and issubclass(self.connection_class, TwistedConnection) + uses_eventlet = EventletConnection and issubclass(self.connection_class, EventletConnection) + if uses_twisted or uses_eventlet: + log.warning( + "ssl_session_cache is set but the connection class %s uses " + "pyOpenSSL, which does not support stdlib ssl session " + "resumption. 
The cache will have no effect.", + self.connection_class.__name__, + ) + self.sockopts = sockopts self.cql_version = cql_version self.max_schema_agreement_wait = max_schema_agreement_wait @@ -1706,6 +1752,7 @@ def _make_connection_kwargs(self, endpoint, kwargs_dict): kwargs_dict.setdefault('sockopts', self.sockopts) kwargs_dict.setdefault('ssl_options', self.ssl_options) kwargs_dict.setdefault('ssl_context', self.ssl_context) + kwargs_dict.setdefault('ssl_session_cache', self.ssl_session_cache) kwargs_dict.setdefault('cql_version', self.cql_version) kwargs_dict.setdefault('protocol_version', self.protocol_version) kwargs_dict.setdefault('user_type_map', self._user_types) diff --git a/cassandra/connection.py b/cassandra/connection.py index 72b273ec37..0388b33bbf 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -22,7 +22,7 @@ import socket import struct import sys -from threading import Thread, Event, RLock, Condition +from threading import Thread, Event, Lock, RLock, Condition import time import ssl import uuid @@ -783,6 +783,45 @@ def generate(self, shard_id: int, total_shards: int): DefaultShardAwarePortGenerator = ShardAwarePortGenerator(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH) +class SSLSessionCache(object): + """ + A thread-safe cache of ``ssl.SSLSession`` objects, keyed by connection TLS + identity. + + When TLS is enabled, the driver stores the negotiated session after each + successful handshake and reuses it for subsequent connections to the same + host, enabling TLS session resumption (tickets / PSK) without any extra + configuration. + + This cache is created automatically by :class:`.Cluster` when + ``ssl_context`` or ``ssl_options`` are set. Pass ``ssl_session_cache=None`` + to :class:`.Cluster` to opt out. + + Note: only the stdlib ``ssl`` module is supported. Twisted and Eventlet + connections use pyOpenSSL, which has a different session API and is not + covered by this cache. 
+ """ + + def __init__(self): + self._lock = Lock() + self._cache = {} + + def get(self, key): + """ + Return the cached ``ssl.SSLSession`` for ``key``, or ``None`` if none + is stored yet. + """ + with self._lock: + return self._cache.get(key) + + def set(self, key, session): + """ + Store ``session`` for ``key``. + """ + with self._lock: + self._cache[key] = session + + class Connection(object): CALLBACK_ERR_THREAD_THRESHOLD = 100 @@ -803,6 +842,7 @@ class Connection(object): endpoint = None ssl_options = None ssl_context = None + _ssl_session_cache = None last_error = None # The current number of operations that are in flight. More precisely, @@ -880,13 +920,15 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None, cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False, user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False, ssl_context=None, owning_pool=None, shard_id=None, total_shards=None, - on_orphaned_stream_released=None, application_info: Optional[ApplicationInfoBase] = None): + on_orphaned_stream_released=None, application_info: Optional[ApplicationInfoBase] = None, + ssl_session_cache=None): # TODO next major rename host to endpoint and remove port kwarg. self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port) self.authenticator = authenticator self.ssl_options = ssl_options.copy() if ssl_options else {} self.ssl_context = ssl_context + self._ssl_session_cache = ssl_session_cache self.sockopts = sockopts self.compression = compression self.cql_version = cql_version @@ -1029,7 +1071,25 @@ def _wrap_socket_from_context(self): server_hostname = self.endpoint.address opts['server_hostname'] = server_hostname - return self.ssl_context.wrap_socket(self._socket, **opts) + ssl_sock = self.ssl_context.wrap_socket(self._socket, **opts) + + # Restore a previously cached session to enable TLS session resumption + # (session tickets / PSK). 
The session must be set *after* + # wrap_socket() (which only creates the SSLSocket) but *before* + # connect(), because connect() triggers the actual TLS handshake + # (via do_handshake_on_connect, which defaults to True). + # _initiate_connection, called after this method returns, performs + # the connect(). + if self._ssl_session_cache is not None: + cached_session = self._ssl_session_cache.get( + self._ssl_session_cache_key()) + if cached_session is not None: + try: + ssl_sock.session = cached_session + except (AttributeError, ValueError, ssl.SSLError): + log.debug("Could not restore TLS session for %s", self.endpoint) + + return ssl_sock def _initiate_connection(self, sockaddr): if self.features.shard_id is not None: @@ -1043,6 +1103,30 @@ def _initiate_connection(self, sockaddr): self._socket.connect(sockaddr) + def _cache_tls_session_if_needed(self): + """ + Store the current TLS session in the cache (if any) so that future + connections to the same endpoint can resume it. + """ + if self._ssl_session_cache is not None and self.ssl_context is not None: + session = getattr(self._socket, 'session', None) + if session is not None: + self._ssl_session_cache.set(self._ssl_session_cache_key(), session) + + def _ssl_session_cache_key(self): + """ + Return a cache key that matches the TLS peer identity. + + ``server_hostname`` is included so that SNI-based connections routed + through the same proxy address/port do not overwrite each other's + sessions. + """ + return ( + self.endpoint.address, + self.endpoint.port, + self.ssl_options.get('server_hostname') if self.ssl_options else None, + ) + # PYTHON-1331 # # Allow implementations specific to an event loop to add additional behaviours @@ -1074,6 +1158,16 @@ def _connect_socket(self): self._initiate_connection(sockaddr) self._socket.settimeout(None) + # Cache the negotiated TLS session for future resumption. + # For TLS 1.2 the session is available right after connect(). 
+ # For TLS 1.3 the server sends the session ticket + # asynchronously after the first application-data exchange, + # so socket.session may still be None here; a second + # attempt is made in _cache_tls_session_if_needed() after + # the CQL handshake completes (see _handle_startup_response + # and _handle_auth_response). + self._cache_tls_session_if_needed() + local_addr = self._socket.getsockname() log.debug("Connection %s: '%s' -> '%s'", id(self), local_addr, sockaddr) @@ -1578,6 +1672,9 @@ def _handle_startup_response(self, startup_response, did_authenticate=False): if ProtocolVersion.has_checksumming_support(self.protocol_version): self._enable_checksumming() + # TLS 1.3: the session ticket is sent after the first + # application-data exchange, so try caching it now. + self._cache_tls_session_if_needed() self.connected_event.set() elif isinstance(startup_response, AuthenticateMessage): log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s", @@ -1634,6 +1731,9 @@ def _handle_auth_response(self, auth_response): self.authenticator.on_authentication_success(auth_response.token) if self._compressor: self.compressor = self._compressor + # TLS 1.3: the session ticket is sent after the first + # application-data exchange, so try caching it now. 
+ self._cache_tls_session_if_needed() self.connected_event.set() elif isinstance(auth_response, AuthChallengeMessage): response = self.authenticator.evaluate_challenge(auth_response.challenge) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 4942fd4d69..fbd51a331b 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -23,6 +23,7 @@ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT +from cassandra.connection import SSLSessionCache from cassandra.pool import Host from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory @@ -634,3 +635,95 @@ def test_no_warning_adding_lbp_ep_to_cluster_with_contact_points(self): ) patched_logger.warning.assert_not_called() + + +class TestSSLSessionCacheAutoCreation(unittest.TestCase): + + def test_cache_created_when_ssl_context_set(self): + import ssl + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + cluster = Cluster(contact_points=['127.0.0.1'], ssl_context=ctx) + assert isinstance(cluster.ssl_session_cache, SSLSessionCache) + + def test_cache_created_when_ssl_options_set(self): + cluster = Cluster(contact_points=['127.0.0.1'], ssl_options={'ca_certs': '/dev/null'}) + assert isinstance(cluster.ssl_session_cache, SSLSessionCache) + + def test_no_cache_when_tls_not_enabled(self): + cluster = Cluster(contact_points=['127.0.0.1']) + assert cluster.ssl_session_cache is None + + def test_explicit_none_disables_cache(self): + import ssl + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = 
ssl.CERT_NONE + cluster = Cluster(contact_points=['127.0.0.1'], ssl_context=ctx, + ssl_session_cache=None) + assert cluster.ssl_session_cache is None + + def test_explicit_custom_cache_used(self): + import ssl + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + custom = SSLSessionCache() + cluster = Cluster(contact_points=['127.0.0.1'], ssl_context=ctx, + ssl_session_cache=custom) + assert cluster.ssl_session_cache is custom + + def test_cache_passed_to_connection_factory(self): + import ssl + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + endpoint = Mock(address='127.0.0.1') + with patch.object(Cluster.connection_class, 'factory', autospec=True, return_value='connection') as factory: + cluster = Cluster(contact_points=['127.0.0.1'], ssl_context=ctx) + cluster.connection_factory(endpoint) + + assert factory.call_args.kwargs['ssl_session_cache'] is cluster.ssl_session_cache + + def test_warning_for_eventlet_connection_class(self): + """A warning is logged when ssl_session_cache is set with EventletConnection.""" + import ssl + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + # Create a real class so issubclass() works throughout Cluster.__init__ + from cassandra.connection import Connection as BaseConn + class FakeEventletConnection(BaseConn): + pass + + with patch('cassandra.cluster.EventletConnection', FakeEventletConnection, create=True), \ + patch('cassandra.cluster.log') as patched_logger: + Cluster(contact_points=['127.0.0.1'], ssl_context=ctx, + connection_class=FakeEventletConnection) + + # At least one warning about pyOpenSSL + warning_calls = [c for c in patched_logger.warning.call_args_list + if 'pyOpenSSL' in str(c)] + assert len(warning_calls) == 1 + + def test_warning_for_twisted_connection_class(self): + """A warning is logged when ssl_session_cache is set with 
TwistedConnection.""" + import ssl + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + from cassandra.connection import Connection as BaseConn + class FakeTwistedConnection(BaseConn): + pass + + with patch('cassandra.cluster.TwistedConnection', FakeTwistedConnection, create=True), \ + patch('cassandra.cluster.log') as patched_logger: + Cluster(contact_points=['127.0.0.1'], ssl_context=ctx, + connection_class=FakeTwistedConnection) + + warning_calls = [c for c in patched_logger.warning.call_args_list + if 'pyOpenSSL' in str(c)] + assert len(warning_calls) == 1 diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 6ac63ff761..0e5429267c 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -22,7 +22,9 @@ from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, - ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator) + ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator, + SniEndPoint, + SSLSessionCache) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, SupportedMessage, ProtocolHandler) @@ -571,3 +573,237 @@ def test_generate_is_repeatable_with_same_mock(self, mock_randrange): second_run = list(itertools.islice(gen.generate(0, 2), 5)) assert first_run == second_run + + +class TestSSLSessionCache(unittest.TestCase): + + @staticmethod + def _key(address, port, server_hostname=None): + return (address, port, server_hostname) + + def test_get_returns_none_when_empty(self): + cache = SSLSessionCache() + assert cache.get(self._key('127.0.0.1', 9042)) is None + + def test_set_and_get(self): + cache = SSLSessionCache() + session = object() # stand-in for 
ssl.SSLSession + cache.set(self._key('127.0.0.1', 9042), session) + assert cache.get(self._key('127.0.0.1', 9042)) is session + + def test_different_keys_are_independent(self): + cache = SSLSessionCache() + s1 = object() + s2 = object() + cache.set(self._key('127.0.0.1', 9042), s1) + cache.set(self._key('127.0.0.2', 9042), s2) + assert cache.get(self._key('127.0.0.1', 9042)) is s1 + assert cache.get(self._key('127.0.0.2', 9042)) is s2 + assert cache.get(self._key('127.0.0.1', 9043)) is None + + def test_sni_keys_are_independent_for_same_proxy(self): + cache = SSLSessionCache() + s1 = object() + s2 = object() + + cache.set(self._key('proxy.example.com', 9042, 'node-a'), s1) + cache.set(self._key('proxy.example.com', 9042, 'node-b'), s2) + + assert cache.get(self._key('proxy.example.com', 9042, 'node-a')) is s1 + assert cache.get(self._key('proxy.example.com', 9042, 'node-b')) is s2 + + def test_overwrite_existing_entry(self): + cache = SSLSessionCache() + old = object() + new = object() + cache.set(self._key('127.0.0.1', 9042), old) + cache.set(self._key('127.0.0.1', 9042), new) + assert cache.get(self._key('127.0.0.1', 9042)) is new + + def test_thread_safety(self): + """Concurrent set/get operations must not raise.""" + import threading + cache = SSLSessionCache() + errors = [] + + def writer(addr_suffix): + try: + for i in range(200): + cache.set(self._key('127.0.0.%d' % addr_suffix, 9042), object()) + except Exception as e: + errors.append(e) + + def reader(addr_suffix): + try: + for i in range(200): + cache.get(self._key('127.0.0.%d' % addr_suffix, 9042)) + except Exception as e: + errors.append(e) + + threads = [] + for n in range(5): + threads.append(threading.Thread(target=writer, args=(n,))) + threads.append(threading.Thread(target=reader, args=(n,))) + for t in threads: + t.start() + for t in threads: + t.join() + + assert errors == [] + + +class TestConnectionSSLSessionRestore(unittest.TestCase): + + @patch.object(Connection, '_connect_socket') + 
@patch.object(Connection, '_send_options_message') + def test_wrap_socket_restores_cached_session(self, _send, _connect): + """_wrap_socket_from_context sets ssl_sock.session from cache.""" + import ssl as _ssl + + mock_ssl_sock = Mock() + mock_ctx = Mock(spec=_ssl.SSLContext) + mock_ctx.check_hostname = False + mock_ctx.wrap_socket.return_value = mock_ssl_sock + + cached = Mock(name='cached_session') + cache = SSLSessionCache() + cache.set(('10.0.0.1', 9042, None), cached) + + conn = Connection.__new__(Connection) + conn.endpoint = DefaultEndPoint('10.0.0.1', 9042) + conn.ssl_context = mock_ctx + conn.ssl_options = {} + conn._ssl_session_cache = cache + + result = conn._wrap_socket_from_context() + assert result is mock_ssl_sock + assert mock_ssl_sock.session == cached + + @patch.object(Connection, '_connect_socket') + @patch.object(Connection, '_send_options_message') + def test_wrap_socket_tolerates_missing_cache(self, _send, _connect): + """No error when _ssl_session_cache is None.""" + import ssl as _ssl + + mock_ssl_sock = Mock() + mock_ctx = Mock(spec=_ssl.SSLContext) + mock_ctx.check_hostname = False + mock_ctx.wrap_socket.return_value = mock_ssl_sock + + conn = Connection.__new__(Connection) + conn.endpoint = DefaultEndPoint('10.0.0.1', 9042) + conn.ssl_context = mock_ctx + conn.ssl_options = {} + conn._ssl_session_cache = None + + result = conn._wrap_socket_from_context() + assert result is mock_ssl_sock + + @patch.object(Connection, '_connect_socket') + @patch.object(Connection, '_send_options_message') + def test_wrap_socket_handles_set_session_failure(self, _send, _connect): + """If setting session raises ssl.SSLError, it is silently ignored.""" + import ssl as _ssl + + mock_ssl_sock = Mock() + type(mock_ssl_sock).session = property( + fget=lambda self: None, + fset=Mock(side_effect=_ssl.SSLError("bad session")), + ) + mock_ctx = Mock(spec=_ssl.SSLContext) + mock_ctx.check_hostname = False + mock_ctx.wrap_socket.return_value = mock_ssl_sock + + cache = 
SSLSessionCache() + cache.set(('10.0.0.1', 9042, None), Mock(name='bad_cached')) + + conn = Connection.__new__(Connection) + conn.endpoint = DefaultEndPoint('10.0.0.1', 9042) + conn.ssl_context = mock_ctx + conn.ssl_options = {} + conn._ssl_session_cache = cache + + # Should NOT raise + result = conn._wrap_socket_from_context() + assert result is mock_ssl_sock + + @patch.object(Connection, '_connect_socket') + @patch.object(Connection, '_send_options_message') + def test_wrap_socket_uses_sni_specific_cached_session(self, _send, _connect): + import ssl as _ssl + + mock_ssl_sock = Mock() + mock_ctx = Mock(spec=_ssl.SSLContext) + mock_ctx.check_hostname = False + mock_ctx.wrap_socket.return_value = mock_ssl_sock + + expected = Mock(name='node_b_session') + cache = SSLSessionCache() + cache.set(('proxy.example.com', 9042, 'node-a'), Mock(name='node_a_session')) + cache.set(('proxy.example.com', 9042, 'node-b'), expected) + + conn = Connection.__new__(Connection) + conn.endpoint = SniEndPoint('proxy.example.com', 'node-b', 9042) + conn.ssl_context = mock_ctx + conn.ssl_options = {'server_hostname': 'node-b'} + conn._ssl_session_cache = cache + + result = conn._wrap_socket_from_context() + assert result is mock_ssl_sock + assert mock_ssl_sock.session == expected + + +class TestConnectionCacheTLSSession(unittest.TestCase): + + def _make_conn(self): + conn = Connection.__new__(Connection) + conn.endpoint = DefaultEndPoint('10.0.0.1', 9042) + conn.ssl_context = Mock() + conn._ssl_session_cache = SSLSessionCache() + conn._socket = Mock() + return conn + + def test_cache_tls_session_stores_session(self): + conn = self._make_conn() + fake_session = Mock(name='ssl_session') + conn._socket.session = fake_session + + conn._cache_tls_session_if_needed() + assert conn._ssl_session_cache.get(('10.0.0.1', 9042, None)) is fake_session + + def test_cache_tls_session_no_op_when_session_none(self): + conn = self._make_conn() + conn._socket.session = None + + 
conn._cache_tls_session_if_needed() + assert conn._ssl_session_cache.get(('10.0.0.1', 9042, None)) is None + + def test_cache_tls_session_no_op_when_cache_none(self): + conn = self._make_conn() + conn._ssl_session_cache = None + conn._socket.session = Mock() + + # Should not raise + conn._cache_tls_session_if_needed() + + def test_cache_tls_session_no_op_when_no_ssl_context(self): + conn = self._make_conn() + conn.ssl_context = None + conn._socket.session = Mock() + + conn._cache_tls_session_if_needed() + assert conn._ssl_session_cache.get(('10.0.0.1', 9042, None)) is None + + def test_cache_tls_session_uses_sni_specific_key(self): + conn = Connection.__new__(Connection) + conn.endpoint = SniEndPoint('proxy.example.com', 'node-b', 9042) + conn.ssl_context = Mock() + conn.ssl_options = {'server_hostname': 'node-b'} + conn._ssl_session_cache = SSLSessionCache() + conn._socket = Mock() + fake_session = Mock(name='ssl_session') + conn._socket.session = fake_session + + conn._cache_tls_session_if_needed() + assert conn._ssl_session_cache.get(('proxy.example.com', 9042, 'node-b')) is fake_session + assert conn._ssl_session_cache.get(('proxy.example.com', 9042, 'node-a')) is None