From b6a2adc14a7cb52aa4e950827af181ec6ad37d22 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sun, 8 Mar 2026 10:46:47 +0200 Subject: [PATCH 1/5] Reduce data copies in connection receive path Replace getvalue() with getbuffer() memoryview in _read_frame_header and frame body extraction to avoid full-buffer copies. Add _reset_buffer() helper using getbuffer()[pos:] instead of read() to reduce allocations. Wrap memoryview usage in try/finally to ensure release before mutation. Increase in_buffer_size from 4096 to 16384 to reduce recv() call overhead. --- cassandra/connection.py | 827 +++++++++++++++++++++++++++++----------- 1 file changed, 600 insertions(+), 227 deletions(-) diff --git a/cassandra/connection.py b/cassandra/connection.py index 72b273ec37..126ffd0b08 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -40,15 +40,32 @@ else: from queue import Queue, Empty # noqa -from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut, ProtocolVersion +from cassandra import ( + ConsistencyLevel, + AuthenticationFailed, + OperationTimedOut, + ProtocolVersion, +) from cassandra.marshal import int32_pack -from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage, - StartupMessage, ErrorMessage, CredentialsMessage, - QueryMessage, ResultMessage, ProtocolHandler, - InvalidRequestException, SupportedMessage, - AuthResponseMessage, AuthChallengeMessage, - AuthSuccessMessage, ProtocolException, - RegisterMessage, ReviseRequestMessage) +from cassandra.protocol import ( + ReadyMessage, + AuthenticateMessage, + OptionsMessage, + StartupMessage, + ErrorMessage, + CredentialsMessage, + QueryMessage, + ResultMessage, + ProtocolHandler, + InvalidRequestException, + SupportedMessage, + AuthResponseMessage, + AuthChallengeMessage, + AuthSuccessMessage, + ProtocolException, + RegisterMessage, + ReviseRequestMessage, +) from cassandra.segment import SegmentCodec, CrcException from cassandra.util import OrderedDict from cassandra.shard_info import ShardingInfo @@ -66,7 +83,9 @@ try: import lz4 except ImportError: - log.debug("lz4 package could not be imported. LZ4 Compression will not be available") + log.debug( + "lz4 package could not be imported. LZ4 Compression will not be available" + ) pass else: # The compress and decompress functions we need were moved from the lz4 to @@ -105,7 +124,9 @@ def lz4_decompress(byts): try: import snappy except ImportError: - log.debug("snappy package could not be imported. Snappy Compression will not be available") + log.debug( + "snappy package could not be imported. Snappy Compression will not be available" + ) pass else: # work around apparently buggy snappy decompress @@ -113,11 +134,15 @@ def decompress(byts): if byts == '\x00': return '' return snappy.decompress(byts) + locally_supported_compressions['snappy'] = (snappy.compress, decompress) -DRIVER_NAME, DRIVER_VERSION = 'ScyllaDB Python Driver', sys.modules['cassandra'].__version__ +DRIVER_NAME, DRIVER_VERSION = ( + "ScyllaDB Python Driver", + sys.modules["cassandra"].__version__, +) -PROTOCOL_VERSION_MASK = 0x7f +PROTOCOL_VERSION_MASK = 0x7F HEADER_DIRECTION_FROM_CLIENT = 0x00 HEADER_DIRECTION_TO_CLIENT = 0x80 @@ -172,7 +197,6 @@ def resolve(self): class EndPointFactory(object): - cluster = None def configure(self, cluster): @@ -211,8 +235,11 @@ def resolve(self): return self._address, self._port def __eq__(self, other): - return isinstance(other, DefaultEndPoint) and \ - self.address == other.address and self.port == other.port + return ( + isinstance(other, DefaultEndPoint) + and self.address == other.address + and self.port == other.port + ) def __hash__(self): return hash((self.address, self.port)) @@ -228,7 +255,6 @@ def __repr__(self): class DefaultEndPointFactory(EndPointFactory): - port = None """ If no port is discovered in the row, this is the default port @@ -241,6 +267,7 @@ def __init__(self, port=None): def create(self, row): # TODO next major... move this class so we don't need this kind of hack from cassandra.metadata import _NodeInfo + addr = _NodeInfo.get_broadcast_rpc_address(row) port = _NodeInfo.get_broadcast_rpc_port(row) if port is None: @@ -248,9 +275,7 @@ def create(self, row): # create the endpoint with the translated address # TODO next major, create a TranslatedEndPoint type - return DefaultEndPoint( - self.cluster.address_translator.translate(addr), - port) + return DefaultEndPoint(self.cluster.address_translator.translate(addr), port) @total_ordering @@ -279,41 +304,55 @@ def ssl_options(self): def resolve(self): try: - resolved_addresses = socket.getaddrinfo(self._proxy_address, self._port, - socket.AF_UNSPEC, socket.SOCK_STREAM) + resolved_addresses = socket.getaddrinfo( + self._proxy_address, self._port, socket.AF_UNSPEC, socket.SOCK_STREAM + ) except socket.gaierror: - log.debug('Could not resolve sni proxy hostname "%s" ' - 'with port %d' % (self._proxy_address, self._port)) + log.debug( + 'Could not resolve sni proxy hostname "%s" ' + "with port %d" % (self._proxy_address, self._port) + ) raise # round-robin pick - self._resolved_address = sorted(addr[4][0] for addr in resolved_addresses)[self._index % len(resolved_addresses)] + self._resolved_address = sorted(addr[4][0] for addr in resolved_addresses)[ + self._index % len(resolved_addresses) + ] self._index += 1 return self._resolved_address, self._port def __eq__(self, other): - return (isinstance(other, SniEndPoint) and - self.address == other.address and self.port == other.port and - self._server_name == other._server_name) + return ( + isinstance(other, SniEndPoint) + and self.address == other.address + and self.port == other.port + and self._server_name == other._server_name + ) def __hash__(self): return hash((self.address, self.port, self._server_name)) def __lt__(self, other): - return ((self.address, self.port, self._server_name) < - (other.address, other.port, self._server_name)) + return (self.address, self.port, self._server_name) < ( + other.address, + other.port, + self._server_name, + ) def __str__(self): return str("%s:%d:%s" % (self.address, self.port, self._server_name)) def __repr__(self): - return "<%s: %s:%d:%s>" % (self.__class__.__name__, - self.address, self.port, self._server_name) + return "<%s: %s:%d:%s>" % ( + self.__class__.__name__, + self.address, + self.port, + self._server_name, + ) class SniEndPointFactory(EndPointFactory): - def __init__(self, proxy_address, port, node_domain=None): self._proxy_address = proxy_address self._port = port @@ -323,7 +362,11 @@ def create(self, row): host_id = row.get("host_id") if host_id is None: raise ValueError("No host_id to create the SniEndPoint") - address = "{}.{}".format(host_id, self._node_domain) if self._node_domain else str(host_id) + address = ( + "{}.{}".format(host_id, self._node_domain) + if self._node_domain + else str(host_id) + ) return SniEndPoint(self._proxy_address, str(address), self._port) def create_from_sni(self, sni): @@ -342,7 +385,9 @@ class ClientRoutesEndPointFactory(EndPointFactory): client_routes_handler: _ClientRoutesHandler default_port: int - def __init__(self, client_routes_handler: _ClientRoutesHandler, default_port: int = None) -> None: + def __init__( + self, client_routes_handler: _ClientRoutesHandler, default_port: int = None + ) -> None: """ :param client_routes_handler: _ClientRoutesHandler instance to lookup routes :param default_port: Default port if none found in row @@ -358,13 +403,18 @@ def create(self, row: Dict[str, Any]) -> 'ClientRoutesEndPoint': (route lookup) and DNS resolution happen later in resolve(). """ from cassandra.metadata import _NodeInfo + host_id = row.get("host_id") if host_id is None: raise ValueError("No host_id to create ClientRoutesEndPoint") addr = _NodeInfo.get_broadcast_rpc_address(row) - port = _NodeInfo.get_broadcast_rpc_port(row) or _NodeInfo.get_broadcast_port(row) or self.default_port + port = ( + _NodeInfo.get_broadcast_rpc_port(row) + or _NodeInfo.get_broadcast_port(row) + or self.default_port + ) return ClientRoutesEndPoint( host_id=host_id, @@ -399,8 +449,10 @@ def resolve(self): return self.address, None def __eq__(self, other): - return (isinstance(other, UnixSocketEndPoint) and - self._unix_socket_path == other._unix_socket_path) + return ( + isinstance(other, UnixSocketEndPoint) + and self._unix_socket_path == other._unix_socket_path + ) def __hash__(self): return hash(self._unix_socket_path) @@ -430,7 +482,13 @@ class ClientRoutesEndPoint(EndPoint): _original_address: str _original_port: int - def __init__(self, host_id: uuid.UUID, handler: _ClientRoutesHandler, original_address: str, original_port: int = None) -> None: + def __init__( + self, + host_id: uuid.UUID, + handler: _ClientRoutesHandler, + original_address: str, + original_port: int = None, + ) -> None: """ :param host_id: Host UUID for route lookup :param handler: _ClientRoutesHandler instance @@ -466,23 +524,30 @@ def resolve(self) -> Tuple[str, int]: return result def __eq__(self, other): - return (isinstance(other, ClientRoutesEndPoint) and - self._host_id == other._host_id and - self._original_address == other._original_address) + return ( + isinstance(other, ClientRoutesEndPoint) + and self._host_id == other._host_id + and self._original_address == other._original_address + ) def __hash__(self): return hash((self._host_id, self._original_address)) def __lt__(self, other): - return ((self._host_id, self._original_address) < - (other._host_id, other._original_address)) + return (self._host_id, self._original_address) < ( + other._host_id, + other._original_address, + ) def __str__(self): return str("%s (host_id=%s)" % (self._original_address, self._host_id)) def __repr__(self): return "<%s: host_id=%s, original_addr=%s>" % ( - self.__class__.__name__, self._host_id, self._original_address) + self.__class__.__name__, + self._host_id, + self._original_address, + ) class _Frame(object): @@ -496,16 +561,25 @@ def __init__(self, version, flags, stream, opcode, body_offset, end_pos): def __eq__(self, other): # facilitates testing if isinstance(other, _Frame): - return (self.version == other.version and - self.flags == other.flags and - self.stream == other.stream and - self.opcode == other.opcode and - self.body_offset == other.body_offset and - self.end_pos == other.end_pos) + return ( + self.version == other.version + and self.flags == other.flags + and self.stream == other.stream + and self.opcode == other.opcode + and self.body_offset == other.body_offset + and self.end_pos == other.end_pos + ) return NotImplemented def __str__(self): - return "ver({0}); flags({1:04b}); stream({2}); op({3}); offset({4}); len({5})".format(self.version, self.flags, self.stream, self.opcode, self.body_offset, self.end_pos - self.body_offset) + return "ver({0}); flags({1:04b}); stream({2}); op({3}); offset({4}); len({5})".format( + self.version, + self.flags, + self.stream, + self.opcode, + self.body_offset, + self.end_pos - self.body_offset, + ) NONBLOCKING = (errno.EAGAIN, errno.EWOULDBLOCK) @@ -530,6 +604,7 @@ class ConnectionShutdown(ConnectionException): """ Raised when a connection has been marked as defunct or has been closed. """ + pass @@ -537,6 +612,7 @@ class ProtocolVersionUnsupported(ConnectionException): """ Server rejected startup message due to unsupported protocol version """ + def __init__(self, endpoint, startup_version): msg = "Unsupported protocol version on %s: %d" % (endpoint, startup_version) super(ProtocolVersionUnsupported, self).__init__(msg, endpoint) @@ -548,6 +624,7 @@ class ConnectionBusy(Exception): An attempt was made to send a message through a :class:`.Connection` that was already at the max number of in-flight operations. """ + pass @@ -555,12 +632,14 @@ class ProtocolError(Exception): """ Communication did not match the protocol that this driver expects. """ + pass class CrcMismatchException(ConnectionException): pass + class ContinuousPagingSession(object): def __init__(self, stream_id, decoder, row_factory, connection, state): self.stream_id = stream_id @@ -635,9 +714,15 @@ def maybe_request_more(self): max_queue_size = self._state.max_queue_size num_in_flight = self._state.num_pages_requested - self._state.num_pages_received space_in_queue = max_queue_size - len(self._page_queue) - num_in_flight - log.debug("Session %s from %s, space in CP queue: %s, requested: %s, received: %s, num_in_flight: %s", - self.stream_id, self.connection.host, space_in_queue, self._state.num_pages_requested, - self._state.num_pages_received, num_in_flight) + log.debug( + "Session %s from %s, space in CP queue: %s, requested: %s, received: %s, num_in_flight: %s", + self.stream_id, + self.connection.host, + space_in_queue, + self._state.num_pages_requested, + self._state.num_pages_received, + num_in_flight, + ) if space_in_queue >= max_queue_size / 2: self.update_next_pages(space_in_queue) @@ -645,37 +730,64 @@ def maybe_request_more(self): def update_next_pages(self, num_next_pages): try: self._state.num_pages_requested += num_next_pages - log.debug("Updating backpressure for session %s from %s", self.stream_id, self.connection.host) + log.debug( + "Updating backpressure for session %s from %s", + self.stream_id, + self.connection.host, + ) with self.connection.lock: - self.connection.send_msg(ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE, - self.stream_id, - next_pages=num_next_pages), - self.connection.get_request_id(), - self._on_backpressure_response) + self.connection.send_msg( + ReviseRequestMessage( + ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE, + self.stream_id, + next_pages=num_next_pages, + ), + self.connection.get_request_id(), + self._on_backpressure_response, + ) except ConnectionShutdown as ex: - log.debug("Failed to update backpressure for session %s from %s, connection is shutdown", - self.stream_id, self.connection.host) + log.debug( + "Failed to update backpressure for session %s from %s, connection is shutdown", + self.stream_id, + self.connection.host, + ) self.on_error(ex) def _on_backpressure_response(self, response): if isinstance(response, ResultMessage): log.debug("Paging session %s backpressure updated.", self.stream_id) else: - log.error("Failed updating backpressure for session %s from %s: %s", self.stream_id, self.connection.host, - response.to_exception() if hasattr(response, 'to_exception') else response) + log.error( + "Failed updating backpressure for session %s from %s: %s", + self.stream_id, + self.connection.host, + response.to_exception() + if hasattr(response, "to_exception") + else response, + ) self.on_error(response) def cancel(self): try: - log.debug("Canceling paging session %s from %s", self.stream_id, self.connection.host) + log.debug( + "Canceling paging session %s from %s", + self.stream_id, + self.connection.host, + ) with self.connection.lock: - self.connection.send_msg(ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_CANCEL, - self.stream_id), - self.connection.get_request_id(), - self._on_cancel_response) + self.connection.send_msg( + ReviseRequestMessage( + ReviseRequestMessage.RevisionType.PAGING_CANCEL, self.stream_id + ), + self.connection.get_request_id(), + self._on_cancel_response, + ) except ConnectionShutdown: - log.debug("Failed to cancel session %s from %s, connection is shutdown", - self.stream_id, self.connection.host) + log.debug( + "Failed to cancel session %s from %s, connection is shutdown", + self.stream_id, + self.connection.host, + ) with self._condition: self._stop = True @@ -685,8 +797,14 @@ def _on_cancel_response(self, response): if isinstance(response, ResultMessage): log.debug("Paging session %s canceled.", self.stream_id) else: - log.error("Failed canceling streaming session %s from %s: %s", self.stream_id, self.connection.host, - response.to_exception() if hasattr(response, 'to_exception') else response) + log.error( + "Failed canceling streaming session %s from %s: %s", + self.stream_id, + self.connection.host, + response.to_exception() + if hasattr(response, "to_exception") + else response, + ) self.released = True @@ -698,6 +816,7 @@ def wrapper(self, *args, **kwargs): return f(self, *args, **kwargs) except Exception as exc: self.defunct(exc) + return wrapper @@ -710,6 +829,7 @@ class _ConnectionIOBuffer(object): protocol V5 and checksumming, the data is read, validated and copied to another cql frame buffer. """ + _io_buffer = None _cql_frame_buffer = None _connection = None @@ -725,8 +845,9 @@ def io_buffer(self): @property def cql_frame_buffer(self): - return self._cql_frame_buffer if self.is_checksumming_enabled else \ - self._io_buffer + return ( + self._cql_frame_buffer if self.is_checksumming_enabled else self._io_buffer + ) def set_checksumming_buffer(self): self.reset_io_buffer() @@ -738,7 +859,7 @@ def is_checksumming_enabled(self): @property def has_consumed_segment(self): - return self._segment_consumed; + return self._segment_consumed def readable_io_bytes(self): return self.io_buffer.tell() @@ -746,14 +867,28 @@ def readable_io_bytes(self): def readable_cql_frame_bytes(self): return self.cql_frame_buffer.tell() + @staticmethod + def _reset_buffer(buf): + """ + Reset a BytesIO buffer by discarding consumed data. + + Uses ``getbuffer()[pos:]`` (a zero-copy memoryview slice) instead of + ``.read()`` which would first allocate an intermediate ``bytes`` object. + The ``BytesIO()`` constructor still copies the data into its own backing + store either way, so the net saving is one temporary ``bytes`` allocation + on the hot receive path. + """ + pos = buf.tell() + new_buf = io.BytesIO(buf.getbuffer()[pos:]) + new_buf.seek(0, 2) # 2 == SEEK_END + return new_buf + def reset_io_buffer(self): - self._io_buffer = io.BytesIO(self._io_buffer.read()) - self._io_buffer.seek(0, 2) # 2 == SEEK_END + self._io_buffer = self._reset_buffer(self._io_buffer) def reset_cql_frame_buffer(self): if self.is_checksumming_enabled: - self._cql_frame_buffer = io.BytesIO(self._cql_frame_buffer.read()) - self._cql_frame_buffer.seek(0, 2) # 2 == SEEK_END + self._cql_frame_buffer = self._reset_buffer(self._cql_frame_buffer) else: self.reset_io_buffer() @@ -771,23 +906,31 @@ def _align(value: int, total_shards: int): return value + total_shards - shift def generate(self, shard_id: int, total_shards: int): - start = self._align(random.randrange(self.start_port, self.end_port), total_shards) + shard_id + start = ( + self._align(random.randrange(self.start_port, self.end_port), total_shards) + + shard_id + ) beginning = self._align(self.start_port, total_shards) + shard_id - available_ports = itertools.chain(range(start, self.end_port, total_shards), - range(beginning, start, total_shards)) + available_ports = itertools.chain( + range(start, self.end_port, total_shards), + range(beginning, start, total_shards), + ) for port in available_ports: yield port -DefaultShardAwarePortGenerator = ShardAwarePortGenerator(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH) +DefaultShardAwarePortGenerator = ShardAwarePortGenerator( + DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH +) class Connection(object): - CALLBACK_ERR_THREAD_THRESHOLD = 100 - in_buffer_size = 4096 + # 16 KiB recv buffer reduces the number of syscalls when reading + # large result sets, at a modest per-connection memory cost. + in_buffer_size = 16384 out_buffer_size = 4096 cql_version = None @@ -875,14 +1018,32 @@ def _iobuf(self): # backward compatibility, to avoid any change in the reactors return self._io_buffer.io_buffer - def __init__(self, host='127.0.0.1', port=9042, authenticator=None, - ssl_options=None, sockopts=None, compression: Union[bool, str] = True, - cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False, - user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False, - ssl_context=None, owning_pool=None, shard_id=None, total_shards=None, - on_orphaned_stream_released=None, application_info: Optional[ApplicationInfoBase] = None): + def __init__( + self, + host="127.0.0.1", + port=9042, + authenticator=None, + ssl_options=None, + sockopts=None, + compression: Union[bool, str] = True, + cql_version=None, + protocol_version=ProtocolVersion.MAX_SUPPORTED, + is_control_connection=False, + user_type_map=None, + connect_timeout=None, + allow_beta_protocol_version=False, + no_compact=False, + ssl_context=None, + owning_pool=None, + shard_id=None, + total_shards=None, + on_orphaned_stream_released=None, + application_info: Optional[ApplicationInfoBase] = None, + ): # TODO next major rename host to endpoint and remove port kwarg. - self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port) + self.endpoint = ( + host if isinstance(host, EndPoint) else DefaultEndPoint(host, port) + ) self.authenticator = authenticator self.ssl_options = ssl_options.copy() if ssl_options else {} @@ -984,15 +1145,29 @@ def factory(cls, endpoint, timeout, host_conn = None, *args, **kwargs): raise conn.last_error elif not conn.connected_event.is_set(): conn.close() - raise OperationTimedOut("Timed out creating connection (%s seconds)" % timeout) + raise OperationTimedOut( + "Timed out creating connection (%s seconds)" % timeout + ) else: return conn def _build_ssl_context_from_options(self): # Extract a subset of names from self.ssl_options which apply to SSLContext creation - ssl_context_opt_names = ['ssl_version', 'cert_reqs', 'check_hostname', 'keyfile', 'certfile', 'ca_certs', 'ciphers'] - opts = {k:self.ssl_options.get(k, None) for k in ssl_context_opt_names if k in self.ssl_options} + ssl_context_opt_names = [ + "ssl_version", + "cert_reqs", + "check_hostname", + "keyfile", + "certfile", + "ca_certs", + "ciphers", + ] + opts = { + k: self.ssl_options.get(k, None) + for k in ssl_context_opt_names + if k in self.ssl_options + } # Python >= 3.10 requires either PROTOCOL_TLS_CLIENT or PROTOCOL_TLS_SERVER so we'll get ahead of things by always # being explicit @@ -1019,13 +1194,22 @@ def _wrap_socket_from_context(self): # Extract a subset of names from self.ssl_options which apply to SSLContext.wrap_socket (or at least the parts # of it that don't involve building an SSLContext under the covers) - wrap_socket_opt_names = ['server_side', 'do_handshake_on_connect', 'suppress_ragged_eofs', 'server_hostname'] - opts = {k:self.ssl_options.get(k, None) for k in wrap_socket_opt_names if k in self.ssl_options} + wrap_socket_opt_names = [ + "server_side", + "do_handshake_on_connect", + "suppress_ragged_eofs", + "server_hostname", + ] + opts = { + k: self.ssl_options.get(k, None) + for k in wrap_socket_opt_names + if k in self.ssl_options + } # PYTHON-1186: set the server_hostname only if the SSLContext has # check_hostname enabled and it is not already provided by the EndPoint ssl options #opts['server_hostname'] = self.endpoint.address - if (self.ssl_context.check_hostname and 'server_hostname' not in opts): + if self.ssl_context.check_hostname and "server_hostname" not in opts: server_hostname = self.endpoint.address opts['server_hostname'] = server_hostname @@ -1033,13 +1217,20 @@ def _wrap_socket_from_context(self): def _initiate_connection(self, sockaddr): if self.features.shard_id is not None: - for port in DefaultShardAwarePortGenerator.generate(self.features.shard_id, self.total_shards): + for port in DefaultShardAwarePortGenerator.generate( + self.features.shard_id, self.total_shards + ): try: self._socket.bind(('', port)) break except Exception as ex: log.debug("port=%d couldn't bind cause: %s", port, str(ex)) - log.debug('connection (%r) port=%d should be shard_id=%d', id(self), port, port % self.total_shards) + log.debug( + "connection (%r) port=%d should be shard_id=%d", + id(self), + port, + port % self.total_shards, + ) self._socket.connect(sockaddr) @@ -1055,9 +1246,13 @@ def _get_socket_addresses(self): if hasattr(socket, 'AF_UNIX') and self.endpoint.socket_family == socket.AF_UNIX: return [(socket.AF_UNIX, socket.SOCK_STREAM, 0, None, address)] - addresses = socket.getaddrinfo(address, port, self.endpoint.socket_family, socket.SOCK_STREAM) + addresses = socket.getaddrinfo( + address, port, self.endpoint.socket_family, socket.SOCK_STREAM + ) if not addresses: - raise ConnectionException("getaddrinfo returned empty list for %s" % (self.endpoint,)) + raise ConnectionException( + "getaddrinfo returned empty list for %s" % (self.endpoint,) + ) return addresses @@ -1065,7 +1260,7 @@ def _connect_socket(self): sockerr = None addresses = self._get_socket_addresses() port = None - for (af, socktype, proto, _, sockaddr) in addresses: + for af, socktype, proto, _, sockaddr in addresses: try: self._socket = self._socket_impl.socket(af, socktype, proto) if self.ssl_context: @@ -1093,8 +1288,11 @@ def _connect_socket(self): sockerr = err if sockerr: - raise socket.error(sockerr.errno, "Tried connecting to %s. Last error: %s" % - ([a[4] for a in addresses], sockerr.strerror or sockerr)) + raise socket.error( + sockerr.errno, + "Tried connecting to %s. Last error: %s" + % ([a[4] for a in addresses], sockerr.strerror or sockerr), + ) if self.sockopts: for args in self.sockopts: @@ -1107,7 +1305,9 @@ def _enable_compression(self): def _enable_checksumming(self): self._io_buffer.set_checksumming_buffer() self._is_checksumming_enabled = True - self._segment_codec = segment_codec_lz4 if self.compressor else segment_codec_no_compression + self._segment_codec = ( + segment_codec_lz4 if self.compressor else segment_codec_no_compression + ) log.debug("Enabling protocol checksumming on connection (%s).", id(self)) def close(self): @@ -1122,11 +1322,16 @@ def defunct(self, exc): exc_info = sys.exc_info() # if we are not handling an exception, just use the passed exception, and don't try to format exc_info with the message if any(exc_info): - log.debug("Defuncting connection (%s) to %s:", - id(self), self.endpoint, exc_info=exc_info) + log.debug( + "Defuncting connection (%s) to %s:", + id(self), + self.endpoint, + exc_info=exc_info, + ) else: - log.debug("Defuncting connection (%s) to %s: %s", - id(self), self.endpoint, exc) + log.debug( + "Defuncting connection (%s) to %s: %s", id(self), self.endpoint, exc + ) self.last_error = exc self.close() @@ -1154,9 +1359,13 @@ def try_callback(cb): try: cb(new_exc) except Exception: - log.warning("Ignoring unhandled exception while erroring requests for a " - "failed connection (%s) to host %s:", - id(self), self.endpoint, exc_info=True) + log.warning( + "Ignoring unhandled exception while erroring requests for a " + "failed connection (%s) to host %s:", + id(self), + self.endpoint, + exc_info=True, + ) # run first callback from this thread to ensure pool state before leaving cb, _, _ = requests.popitem()[1] @@ -1171,6 +1380,7 @@ def try_callback(cb): def err_all_callbacks(): for cb, _, _ in requests.values(): try_callback(cb) + if len(requests) < Connection.CALLBACK_ERR_THREAD_THRESHOLD: err_all_callbacks() else: @@ -1201,7 +1411,15 @@ def handle_pushed(self, response): except Exception: log.exception("Pushed event handler errored, ignoring:") - def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=None): + def send_msg( + self, + msg, + request_id, + cb, + encoder=ProtocolHandler.encode_message, + decoder=ProtocolHandler.decode_message, + result_metadata=None, + ): if self.is_defunct: msg = "Connection to %s is defunct" % self.endpoint if self.last_error: @@ -1218,8 +1436,13 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, # queue the decoder function with the request # this allows us to inject custom functions per request to encode, decode messages self._requests[request_id] = (cb, decoder, result_metadata) - msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, - allow_beta_protocol_version=self.allow_beta_protocol_version) + msg = encoder( + msg, + request_id, + self.protocol_version, + compressor=self.compressor, + allow_beta_protocol_version=self.allow_beta_protocol_version, + ) if self._is_checksumming_enabled: buffer = io.BytesIO() @@ -1260,9 +1483,11 @@ def wait_for_responses(self, *msgs, **kwargs): self.in_flight += available for i, request_id in enumerate(request_ids): - self.send_msg(msgs[messages_sent + i], - request_id, - partial(waiter.got_response, index=messages_sent + i)) + self.send_msg( + msgs[messages_sent + i], + request_id, + partial(waiter.got_response, index=messages_sent + i), + ) messages_sent += available if messages_sent == len(msgs): @@ -1288,8 +1513,8 @@ def register_watcher(self, event_type, callback, register_timeout=None): """ self._push_watchers[event_type].add(callback) self.wait_for_response( - RegisterMessage(event_list=[event_type]), - timeout=register_timeout) + RegisterMessage(event_list=[event_type]), timeout=register_timeout + ) def register_watchers(self, type_callback_dict, register_timeout=None): """ @@ -1299,7 +1524,8 @@ def register_watchers(self, type_callback_dict, register_timeout=None): self._push_watchers[event_type].add(callback) self.wait_for_response( RegisterMessage(event_list=type_callback_dict.keys()), - timeout=register_timeout) + timeout=register_timeout, + ) def control_conn_disposed(self): self.is_control_connection = False @@ -1307,19 +1533,30 @@ def control_conn_disposed(self): @defunct_on_error def _read_frame_header(self): - buf = self._io_buffer.cql_frame_buffer.getvalue() - pos = len(buf) + cql_buf = self._io_buffer.cql_frame_buffer + pos = cql_buf.tell() if pos: - version = buf[0] & PROTOCOL_VERSION_MASK - if version not in ProtocolVersion.SUPPORTED_VERSIONS: - raise ProtocolError("This version of the driver does not support protocol version %d" % version) - # this frame header struct is everything after the version byte - header_size = frame_header_v3.size + 1 - if pos >= header_size: - flags, stream, op, body_len = frame_header_v3.unpack_from(buf, 1) - if body_len < 0: - raise ProtocolError("Received negative body length: %r" % body_len) - self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size) + buf = cql_buf.getbuffer() + try: + version = buf[0] & PROTOCOL_VERSION_MASK + if version not in ProtocolVersion.SUPPORTED_VERSIONS: + raise ProtocolError( + "This version of the driver does not support protocol version %d" + % version + ) + # this frame header struct is everything after the version byte + header_size = frame_header_v3.size + 1 + if pos >= header_size: + flags, stream, op, body_len = frame_header_v3.unpack_from(buf, 1) + if body_len < 0: + raise ProtocolError( + "Received negative body length: %r" % body_len + ) + self._current_frame = _Frame( + version, flags, stream, op, header_size, body_len + header_size + ) + finally: + del buf # release memoryview before any buffer mutation return pos @defunct_on_error @@ -1328,7 +1565,9 @@ def _process_segment_buffer(self): if readable_bytes >= self._segment_codec.header_length_with_crc: try: self._io_buffer.io_buffer.seek(0) - segment_header = self._segment_codec.decode_header(self._io_buffer.io_buffer) + segment_header = self._segment_codec.decode_header( + self._io_buffer.io_buffer + ) if readable_bytes >= segment_header.segment_length: segment = self._segment_codec.decode(self._iobuf, segment_header) @@ -1351,7 +1590,10 @@ def process_io_buffer(self): self._process_segment_buffer() self._io_buffer.reset_io_buffer() - if self._is_checksumming_enabled and not self._io_buffer.has_consumed_segment: + if ( + self._is_checksumming_enabled + and not self._io_buffer.has_consumed_segment + ): # We couldn't read an entire segment from the io buffer, so return # control to allow more bytes to be read off the wire return @@ -1362,7 +1604,10 @@ def process_io_buffer(self): pos = self._io_buffer.readable_cql_frame_bytes() if not self._current_frame or pos < self._current_frame.end_pos: - if self._is_checksumming_enabled and self._io_buffer.readable_io_bytes(): + if ( + self._is_checksumming_enabled + and self._io_buffer.readable_io_bytes() + ): # We have a multi-segments message and we need to read more # data to complete the current cql frame continue @@ -1373,8 +1618,16 @@ def process_io_buffer(self): return else: frame = self._current_frame - self._io_buffer.cql_frame_buffer.seek(frame.body_offset) - msg = self._io_buffer.cql_frame_buffer.read(frame.end_pos - frame.body_offset) + # Use memoryview to avoid intermediate allocation, then + # convert to bytes. Explicitly release the memoryview + # before any buffer mutation (seek / reset). + cql_buf = self._io_buffer.cql_frame_buffer + buf = cql_buf.getbuffer() + try: + msg = bytes(buf[frame.body_offset : frame.end_pos]) + finally: + del buf # release memoryview before buffer mutation + cql_buf.seek(frame.end_pos) self.process_msg(frame, msg) self._io_buffer.reset_cql_frame_buffer() self._current_frame = None @@ -1413,11 +1666,23 @@ def process_msg(self, header, body): return try: - response = decoder(header.version, self.features, self.user_type_map, stream_id, - header.flags, header.opcode, body, self.decompressor, result_metadata) + response = decoder( + header.version, + self.features, + self.user_type_map, + stream_id, + header.flags, + header.opcode, + body, + self.decompressor, + result_metadata, + ) except Exception as exc: - log.exception("Error decoding response from Cassandra. " - "%s; buffer: %r", header, self._iobuf.getvalue()) + log.exception( + "Error decoding response from Cassandra. %s; buffer: %r", + header, + self._iobuf.getvalue(), + ) if callback is not None: callback(exc) self.defunct(exc) @@ -1429,7 +1694,11 @@ def process_msg(self, header, body): if 'unsupported protocol version' in response.message: self.is_unsupported_proto_version = True else: - log.error("Closing connection %s due to protocol error: %s", self, response.summary_msg()) + log.error( + "Closing connection %s due to protocol error: %s", + self, + response.summary_msg(), + ) self.defunct(response) if callback is not None: callback(response) @@ -1463,8 +1732,14 @@ def remove_continuous_paging_session(self, stream_id): @defunct_on_error def _send_options_message(self): - log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.endpoint) - self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response) + log.debug( + "Sending initial options message for new connection (%s) to %s", + id(self), + self.endpoint, + ) + self.send_msg( + OptionsMessage(), self.get_request_id(), self._handle_options_response + ) @defunct_on_error def _handle_options_response(self, options_response): @@ -1476,14 +1751,20 @@ def _handle_options_response(self, options_response): if isinstance(options_response, ConnectionException): raise options_response else: - log.error("Did not get expected SupportedMessage response; " - "instead, got: %s", options_response) - raise ConnectionException("Did not get expected SupportedMessage " - "response; instead, got: %s" - % (options_response,)) - - log.debug("Received options response on new connection (%s) from %s", - id(self), self.endpoint) + log.error( + "Did not get expected SupportedMessage response; instead, got: %s", + options_response, + ) + raise ConnectionException( + "Did not get expected SupportedMessage " + "response; instead, got: %s" % (options_response,) + ) + + log.debug( + "Received options response on new connection (%s) from %s", + id(self), + self.endpoint, + ) supported_cql_versions = options_response.cql_versions remote_supported_compressions = options_response.options['COMPRESSION'] self._product_type = options_response.options.get('PRODUCT_TYPE', [None])[0] @@ -1498,21 +1779,25 @@ def _handle_options_response(self, options_response): raise ProtocolError( "cql_version %r is not supported by remote (w/ native " "protocol). Supported versions: %r" - % (self.cql_version, supported_cql_versions)) + % (self.cql_version, supported_cql_versions) + ) else: self.cql_version = supported_cql_versions[0] self._compressor = None compression_type = None if self.compression: - overlap = (set(locally_supported_compressions.keys()) & - set(remote_supported_compressions)) + overlap = set(locally_supported_compressions.keys()) & set( + remote_supported_compressions + ) if len(overlap) == 0: if locally_supported_compressions: - log.error("No available compression types supported on both ends." - " locally supported: %r. remotely supported: %r", - locally_supported_compressions.keys(), - remote_supported_compressions) + log.error( + "No available compression types supported on both ends." + " locally supported: %r. remotely supported: %r", + locally_supported_compressions.keys(), + remote_supported_compressions, + ) else: compression_type = None if isinstance(self.compression, str): @@ -1520,7 +1805,8 @@ def _handle_options_response(self, options_response): if self.compression not in remote_supported_compressions: raise ProtocolError( "The requested compression type (%s) is not supported by the Cassandra server at %s" - % (self.compression, self.endpoint)) + % (self.compression, self.endpoint) + ) compression_type = self.compression else: # our locally supported compressions are ordered to prefer @@ -1532,26 +1818,38 @@ def _handle_options_response(self, options_response): # If snappy compression is selected with v5+checksumming, the connection # will fail with OTO. Only lz4 is supported - if (compression_type == 'snappy' and - ProtocolVersion.has_checksumming_support(self.protocol_version)): - log.debug("Snappy compression is not supported with protocol version %s and " - "checksumming. Consider installing lz4. Disabling compression.", self.protocol_version) + if ( + compression_type == "snappy" + and ProtocolVersion.has_checksumming_support(self.protocol_version) + ): + log.debug( + "Snappy compression is not supported with protocol version %s and " + "checksumming. Consider installing lz4. Disabling compression.", + self.protocol_version, + ) compression_type = None else: # set the decompressor here, but set the compressor only after # a successful Ready message self._compression_type = compression_type - self._compressor, self.decompressor = \ + self._compressor, self.decompressor = ( locally_supported_compressions[compression_type] + ) - self._send_startup_message(compression_type, no_compact=self.no_compact, extra_options=options) + self._send_startup_message( + compression_type, no_compact=self.no_compact, extra_options=options + ) @defunct_on_error - def _send_startup_message(self, compression=None, no_compact=False, extra_options=None): + def _send_startup_message( + self, compression=None, no_compact=False, extra_options=None + ): log.debug("Sending StartupMessage on %s", self) - opts = {'DRIVER_NAME': DRIVER_NAME, - 'DRIVER_VERSION': DRIVER_VERSION, - **extra_options} + opts = { + "DRIVER_NAME": DRIVER_NAME, + 'DRIVER_VERSION': DRIVER_VERSION, + **extra_options, + } if compression: opts['COMPRESSION'] = compression if no_compact: @@ -1567,12 +1865,18 @@ def _handle_startup_response(self, startup_response, did_authenticate=False): if isinstance(startup_response, ReadyMessage): if self.authenticator: - log.warning("An authentication challenge was not sent, " - "this is suspicious because the driver expects " - "authentication (configured authenticator = %s)", - self.authenticator.__class__.__name__) - - log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.endpoint) + log.warning( + "An authentication challenge was not sent, " + "this is suspicious because the driver expects " + "authentication (configured authenticator = %s)", + self.authenticator.__class__.__name__, + ) + + log.debug( + "Got ReadyMessage on new connection (%s) from %s", + id(self), + self.endpoint, + ) self._enable_compression() if ProtocolVersion.has_checksumming_support(self.protocol_version): @@ -1580,13 +1884,20 @@ def _handle_startup_response(self, startup_response, did_authenticate=False): self.connected_event.set() elif isinstance(startup_response, AuthenticateMessage): - log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s", - id(self), self.endpoint, startup_response.authenticator) + log.debug( + "Got AuthenticateMessage on new connection (%s) from %s: %s", + id(self), + self.endpoint, + startup_response.authenticator, + ) if self.authenticator is None: - log.error("Failed to authenticate to %s. If you are trying to connect to a DSE cluster, " - "consider using TransitionalModePlainTextAuthProvider " - "if DSE authentication is configured with transitional mode" % (self.host,)) + log.error( + "Failed to authenticate to %s. If you are trying to connect to a DSE cluster, " + "consider using TransitionalModePlainTextAuthProvider " + "if DSE authentication is configured with transitional mode" + % (self.host,) + ) raise AuthenticationFailed('Remote end requires authentication') self._enable_compression() @@ -1600,24 +1911,38 @@ def _handle_startup_response(self, startup_response, did_authenticate=False): self.send_msg(cm, self.get_request_id(), cb=callback) else: log.debug("Sending SASL-based auth response on %s", self) - self.authenticator.server_authenticator_class = startup_response.authenticator + self.authenticator.server_authenticator_class = ( + startup_response.authenticator + ) initial_response = self.authenticator.initial_response() initial_response = "" if initial_response is None else initial_response - self.send_msg(AuthResponseMessage(initial_response), self.get_request_id(), - self._handle_auth_response) + self.send_msg( + AuthResponseMessage(initial_response), + self.get_request_id(), + self._handle_auth_response, + ) elif isinstance(startup_response, ErrorMessage): - log.debug("Received ErrorMessage on new connection (%s) from %s: %s", - id(self), self.endpoint, startup_response.summary_msg()) + log.debug( + "Received ErrorMessage on new connection (%s) from %s: %s", + id(self), + self.endpoint, + startup_response.summary_msg(), + ) if did_authenticate: raise AuthenticationFailed( - "Failed to authenticate to %s: %s" % - (self.endpoint, startup_response.summary_msg())) + "Failed to authenticate to %s: %s" + % (self.endpoint, startup_response.summary_msg()) + ) else: raise ConnectionException( "Failed to initialize new connection to %s: %s" - % (self.endpoint, startup_response.summary_msg())) + % (self.endpoint, startup_response.summary_msg()) + ) elif isinstance(startup_response, ConnectionShutdown): - log.debug("Connection to %s was closed during the startup handshake", (self.endpoint)) + log.debug( + "Connection to %s was closed during the startup handshake", + (self.endpoint), + ) raise startup_response else: msg = "Unexpected response during Connection setup: %r" @@ -1641,13 +1966,21 @@ def _handle_auth_response(self, auth_response): log.debug("Responding to auth challenge on %s", self) self.send_msg(msg, self.get_request_id(), self._handle_auth_response) elif isinstance(auth_response, ErrorMessage): - log.debug("Received ErrorMessage on new connection (%s) from %s: %s", - id(self), self.endpoint, auth_response.summary_msg()) + log.debug( + "Received ErrorMessage on new connection (%s) from %s: %s", + id(self), + self.endpoint, + auth_response.summary_msg(), + ) raise AuthenticationFailed( - "Failed to authenticate to %s: %s" % - (self.endpoint, auth_response.summary_msg())) + "Failed to authenticate to %s: %s" + % (self.endpoint, auth_response.summary_msg()) + ) elif isinstance(auth_response, ConnectionShutdown): - log.debug("Connection to %s was closed during the authentication process", self.endpoint) + log.debug( + "Connection to %s was closed during the authentication process", + self.endpoint, + ) raise auth_response else: msg = "Unexpected response during Connection authentication to %s: %r" @@ -1658,8 +1991,9 @@ def set_keyspace_blocking(self, keyspace): if not keyspace or keyspace == self.keyspace: return - query = QueryMessage(query='USE "%s"' % (keyspace,), - consistency_level=ConsistencyLevel.ONE) + query = QueryMessage( + query='USE "%s"' % (keyspace,), consistency_level=ConsistencyLevel.ONE + ) try: result = self.wait_for_response(query) except InvalidRequestException as ire: @@ -1667,7 +2001,8 @@ def set_keyspace_blocking(self, keyspace): raise ire.to_exception() except Exception as exc: conn_exc = ConnectionException( - "Problem while setting keyspace: %r" % (exc,), self.endpoint) + "Problem while setting keyspace: %r" % (exc,), self.endpoint + ) self.defunct(conn_exc) raise conn_exc @@ -1675,7 +2010,8 @@ def set_keyspace_blocking(self, keyspace): self.keyspace = keyspace else: conn_exc = ConnectionException( - "Problem while setting keyspace: %r" % (result,), self.endpoint) + "Problem while setting keyspace: %r" % (result,), self.endpoint + ) self.defunct(conn_exc) raise conn_exc @@ -1712,8 +2048,9 @@ def set_keyspace_async(self, keyspace, callback): callback(self, None) return - query = QueryMessage(query='USE "%s"' % (keyspace,), - consistency_level=ConsistencyLevel.ONE) + query = QueryMessage( + query='USE "%s"' % (keyspace,), consistency_level=ConsistencyLevel.ONE + ) def process_result(result): if isinstance(result, ResultMessage): @@ -1722,8 +2059,15 @@ def process_result(result): elif isinstance(result, InvalidRequestException): callback(self, result.to_exception()) else: - callback(self, self.defunct(ConnectionException( - "Problem while setting keyspace: %r" % (result,), self.endpoint))) + callback( + self, + self.defunct( + ConnectionException( + "Problem while setting keyspace: %r" % (result,), + self.endpoint, + ) + ), + ) # We've incremented self.in_flight above, so we "have permission" to # acquire a new request id @@ -1745,12 +2089,17 @@ def __str__(self): elif self.is_closed: status = " (closed)" - return "<%s(%r) %s%s>" % (self.__class__.__name__, id(self), self.endpoint, status) + return "<%s(%r) %s%s>" % ( + self.__class__.__name__, + id(self), + self.endpoint, + status, + ) + __repr__ = __str__ class ResponseWaiter(object): - def __init__(self, connection, num_responses, fail_on_error): self.connection = connection self.pending = num_responses @@ -1805,14 +2154,23 @@ def __init__(self, connection, owner): self._event = Event() self.connection = connection self.owner = owner - log.debug("Sending options message heartbeat on idle connection (%s) %s", - id(connection), connection.endpoint) + log.debug( + "Sending options message heartbeat on idle connection (%s) %s", + id(connection), + connection.endpoint, + ) with connection.lock: if connection.in_flight < connection.max_request_id: connection.in_flight += 1 - connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback) + connection.send_msg( + OptionsMessage(), + connection.get_request_id(), + self._options_callback, + ) else: - self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold") + self._exception = Exception( + "Failed to send heartbeat because connection 'in_flight' exceeds threshold" + ) self._event.set() def wait(self, timeout): @@ -1821,23 +2179,29 @@ def wait(self, timeout): if self._exception: raise self._exception else: - raise OperationTimedOut("Connection heartbeat timeout after %s seconds" % (timeout,), self.connection.endpoint) + raise OperationTimedOut( + "Connection heartbeat timeout after %s seconds" % (timeout,), + self.connection.endpoint, + ) def _options_callback(self, response): if isinstance(response, SupportedMessage): - log.debug("Received options response on connection (%s) from %s", - id(self.connection), self.connection.endpoint) + log.debug( + "Received options response on connection (%s) from %s", + id(self.connection), + self.connection.endpoint, + ) else: if isinstance(response, ConnectionException): self._exception = response else: - self._exception = ConnectionException("Received unexpected response to OptionsMessage: %s" - % (response,)) + self._exception = ConnectionException( + "Received unexpected response to OptionsMessage: %s" % (response,) + ) self._event.set() class ConnectionHeartbeat(Thread): - def __init__(self, interval_sec, get_connection_holders, timeout): Thread.__init__(self, name="Connection heartbeat") self._interval = interval_sec @@ -1858,7 +2222,9 @@ def run(self): futures = [] failed_connections = [] try: - for connections, owner in [(o.get_connections(), o) for o in self._get_connection_holders()]: + for connections, owner in [ + (o.get_connections(), o) for o in self._get_connection_holders() + ]: for connection in connections: self._raise_if_stopped() if not (connection.is_defunct or connection.is_closed): @@ -1866,14 +2232,20 @@ def run(self): try: futures.append(HeartbeatFuture(connection, owner)) except Exception as e: - log.warning("Failed sending heartbeat message on connection (%s) to %s", - id(connection), connection.endpoint) + log.warning( + "Failed sending heartbeat message on connection (%s) to %s", + id(connection), + connection.endpoint, + ) failed_connections.append((connection, owner, e)) else: connection.reset_idle() else: - log.debug("Cannot send heartbeat message on connection (%s) to %s", - id(connection), connection.endpoint) + log.debug( + "Cannot send heartbeat message on connection (%s) to %s", + id(connection), + connection.endpoint, + ) # make sure the owner sees this defunt/closed connection owner.return_connection(connection) self._raise_if_stopped() @@ -1891,8 +2263,11 @@ def run(self): connection.in_flight -= 1 connection.reset_idle() except Exception as e: - log.warning("Heartbeat failed for connection (%s) to %s", - id(connection), connection.endpoint) + log.warning( + "Heartbeat failed for connection (%s) to %s", + id(connection), + connection.endpoint, + ) failed_connections.append((f.connection, f.owner, e)) timeout = self._timeout - (time.time() - start_time) @@ -1922,7 +2297,6 @@ def _raise_if_stopped(self): class Timer(object): - canceled = False def __init__(self, timeout, callback): @@ -1947,7 +2321,6 @@ def finish(self, time_now): class TimerManager(object): - def __init__(self): self._queue = [] self._new_timers = [] From 626148cbf1954c4bca6993fa0e7d735223c8f5b5 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sun, 8 Mar 2026 10:46:53 +0200 Subject: [PATCH 2/5] Add BytesReader to replace BytesIO in decode_message Introduce a lightweight BytesReader class that operates directly on bytes data without BytesIO overhead. Materializes memoryview to bytes once in __init__ instead of checking on every read(). Includes remaining_buffer() method for zero-copy handoff to Cython parsers. --- cassandra/protocol.py | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 4628c7ee0e..a2c82436f5 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -53,6 +53,43 @@ class NotSupportedError(Exception): class InternalError(Exception): pass + +class BytesReader: + """ + Lightweight reader for bytes data without BytesIO overhead. + Provides the same read() interface but operates directly on a + bytes object, avoiding internal buffer copies. + + Unlike io.BytesIO.read(n), read(n) raises EOFError when fewer than + n bytes remain. This is intentional: protocol parsing should fail + fast on truncated or malformed frames rather than silently returning + short data. + """ + __slots__ = ('_data', '_pos', '_size') + + def __init__(self, data): + # Materialize memoryview up front so read() never needs to check + self._data = bytes(data) if isinstance(data, memoryview) else data + self._pos = 0 + self._size = len(self._data) + + def read(self, n=-1): + if n < 0: + result = self._data[self._pos:] + self._pos = self._size + else: + end = self._pos + n + if end > self._size: + raise EOFError("Cannot read past the end of the buffer") + result = self._data[self._pos:end] + self._pos = end + return result + + def remaining_buffer(self): + """Return (underlying_bytes, current_position) for zero-copy handoff.""" + return self._data, self._pos + + ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type']) HEADER_DIRECTION_TO_CLIENT = 0x80 @@ -1155,7 +1192,8 @@ def decode_message(cls, protocol_version, protocol_features, user_type_map, stre body = decompressor(body) flags ^= COMPRESSED_FLAG - body = io.BytesIO(body) + # Use lightweight BytesReader instead of io.BytesIO to avoid buffer copy + body = BytesReader(body) if flags & TRACING_FLAG: trace_id = UUID(bytes=body.read(16)) flags ^= TRACING_FLAG From b1f566f5f3a63fd3335c36a58ac6c53e01307d72 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sun, 8 Mar 2026 10:47:00 +0200 Subject: [PATCH 3/5] Eliminate copy in Cython row parser handoff (OPT-1) Add offset parameter to BytesIOReader so it can start reading from the middle of an existing buffer, avoiding the full-body copy at the Python-to-Cython boundary. Update row_parser.pyx to use f.remaining_buffer() for zero-copy handoff with hasattr fallback. Track _initial_offset for error recovery. --- cassandra/bytesio.pxd | 1 + cassandra/bytesio.pyx | 12 +++++++++--- cassandra/row_parser.pyx | 10 ++++++++-- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/cassandra/bytesio.pxd b/cassandra/bytesio.pxd index d52d3fa8fe..40edcd996d 100644 --- a/cassandra/bytesio.pxd +++ b/cassandra/bytesio.pxd @@ -17,4 +17,5 @@ cdef class BytesIOReader: cdef char *buf_ptr cdef Py_ssize_t pos cdef Py_ssize_t size + cdef Py_ssize_t _initial_offset cdef char *read(self, Py_ssize_t n = ?) except NULL diff --git a/cassandra/bytesio.pyx b/cassandra/bytesio.pyx index 1a57911fcf..4244e74d7c 100644 --- a/cassandra/bytesio.pyx +++ b/cassandra/bytesio.pyx @@ -16,12 +16,18 @@ cdef class BytesIOReader: """ This class provides efficient support for reading bytes from a 'bytes' buffer, by returning char * values directly without allocating intermediate objects. + + An optional offset allows reading from the middle of an existing buffer, + avoiding a copy when only a suffix of the bytes is needed. """ - def __init__(self, bytes buf): + def __init__(self, bytes buf, Py_ssize_t offset=0): + if offset < 0 or offset > len(buf): + raise ValueError("offset %d out of range for buffer of length %d" % (offset, len(buf))) self.buf = buf - self.size = len(buf) - self.buf_ptr = self.buf + self._initial_offset = offset + self.size = len(buf) - offset + self.buf_ptr = self.buf + offset cdef char *read(self, Py_ssize_t n = -1) except NULL: """Read at most size bytes from the file diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx index 88277a4593..5a99dfac36 100644 --- a/cassandra/row_parser.pyx +++ b/cassandra/row_parser.pyx @@ -35,13 +35,19 @@ def make_recv_results_rows(ColumnParser colparser): desc = ParseDesc(self.column_names, self.column_types, column_encryption_policy, [ColDesc(md[0], md[1], md[2]) for md in column_metadata], make_deserializers(self.column_types), protocol_version) - reader = BytesIOReader(f.read()) + # Zero-copy handoff: reuse the underlying bytes buffer at its current + # position instead of copying via f.read(). + if hasattr(f, 'remaining_buffer'): + buf_data, buf_offset = f.remaining_buffer() + reader = BytesIOReader(buf_data, buf_offset) + else: + reader = BytesIOReader(f.read()) try: self.parsed_rows = colparser.parse_rows(reader, desc) except Exception as e: # Use explicitly the TupleRowParser to display better error messages for column decoding failures rowparser = TupleRowParser() - reader.buf_ptr = reader.buf + reader.buf_ptr = reader.buf + reader._initial_offset reader.pos = 0 rowcount = read_int(reader) for i in range(rowcount): From 8def949fe8921e3ce353ff63c2c2b46a159c5e65 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Sun, 8 Mar 2026 10:47:06 +0200 Subject: [PATCH 4/5] Add tests for BytesReader and BytesIOReader offset parameter 13 BytesReader tests covering read operations, remaining_buffer(), memoryview materialization, empty data, and EOFError handling. 9 BytesIOReader tests covering offset initialization, boundary conditions, read behavior with offset, and error cases. --- tests/unit/test_bytesio_reader.py | 67 ++++++++ tests/unit/test_connection.py | 42 ++++- tests/unit/test_protocol.py | 271 +++++++++++++++++++++++++----- 3 files changed, 337 insertions(+), 43 deletions(-) create mode 100644 tests/unit/test_bytesio_reader.py diff --git a/tests/unit/test_bytesio_reader.py b/tests/unit/test_bytesio_reader.py new file mode 100644 index 0000000000..bb9e8eb781 --- /dev/null +++ b/tests/unit/test_bytesio_reader.py @@ -0,0 +1,67 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import pytest + +try: + from cassandra.bytesio import BytesIOReader + + has_cython = True +except ImportError: + has_cython = False + + +@pytest.mark.skipif(not has_cython, reason="Cython extensions not compiled") +class BytesIOReaderTest(unittest.TestCase): + """Tests for the Cython BytesIOReader, including the offset parameter. + + Note: BytesIOReader.read() is a cdef method, so it cannot be called + directly from Python. Reading with an offset is exercised through the + end-to-end decode_message test in test_protocol.py which goes through + the Cython row parser path (remaining_buffer -> BytesIOReader(buf, offset)). + """ + + def test_construct_no_offset(self): + # Should not raise + reader = BytesIOReader(b"\x00\x01\x02\x03\x04\x05") + + def test_construct_with_zero_offset(self): + reader = BytesIOReader(b"hello world", 0) + + def test_construct_with_offset(self): + reader = BytesIOReader(b"header_row_data", 7) + + def test_construct_offset_at_end(self): + data = b"abcdef" + reader = BytesIOReader(data, len(data)) + + def test_construct_negative_offset_raises(self): + with self.assertRaises(ValueError): + BytesIOReader(b"hello", -1) + + def test_construct_offset_past_end_raises(self): + with self.assertRaises(ValueError): + BytesIOReader(b"hello", 6) + + def test_construct_offset_way_past_end_raises(self): + with self.assertRaises(ValueError): + BytesIOReader(b"hello", 100) + + def test_construct_empty_buffer_zero_offset(self): + reader = BytesIOReader(b"", 0) + + def test_construct_empty_buffer_nonzero_offset_raises(self): + with self.assertRaises(ValueError): + BytesIOReader(b"", 1) diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 6ac63ff761..b88b5bee2d 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -22,7 +22,8 @@ from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, - ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator) + ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator, + _ConnectionIOBuffer) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, SupportedMessage, ProtocolHandler) @@ -571,3 +572,42 @@ def test_generate_is_repeatable_with_same_mock(self, mock_randrange): second_run = list(itertools.islice(gen.generate(0, 2), 5)) assert first_run == second_run + + +class ResetBufferTest(unittest.TestCase): + """Tests for _ConnectionIOBuffer._reset_buffer static method.""" + + def test_preserves_remaining_data(self): + buf = BytesIO() + buf.write(b"already_consumed_new_data") + buf.seek(17) # position after "already_consumed_" + result = _ConnectionIOBuffer._reset_buffer(buf) + self.assertEqual(result.getvalue(), b"new_data") + # Cursor is at SEEK_END, ready for further writes + self.assertEqual(result.tell(), len(b"new_data")) + + def test_empty_remaining(self): + buf = BytesIO() + buf.write(b"all_consumed") + buf.seek(12) + result = _ConnectionIOBuffer._reset_buffer(buf) + self.assertEqual(result.getvalue(), b"") + self.assertEqual(result.tell(), 0) + + def test_nothing_consumed(self): + buf = BytesIO() + buf.write(b"all_remaining") + buf.seek(0) + result = _ConnectionIOBuffer._reset_buffer(buf) + self.assertEqual(result.getvalue(), b"all_remaining") + # Cursor is at SEEK_END, ready for further writes + self.assertEqual(result.tell(), len(b"all_remaining")) + + def test_new_buffer_is_writable(self): + buf = BytesIO() + buf.write(b"head_tail") + buf.seek(5) + result = _ConnectionIOBuffer._reset_buffer(buf) + result.seek(0, 2) # seek to end + result.write(b"_more") + self.assertEqual(result.getvalue(), b"tail_more") diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index 9704811239..3e46d82a9c 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -13,15 +13,26 @@ # limitations under the License. import unittest +from io import BytesIO from unittest.mock import Mock from cassandra import ProtocolVersion, UnsupportedOperation from cassandra.protocol import ( - PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation, - _PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG, - _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG, - BatchMessage + PrepareMessage, + QueryMessage, + ExecuteMessage, + UnsupportedOperation, + _PAGING_OPTIONS_FLAG, + _WITH_SERIAL_CONSISTENCY_FLAG, + _PAGE_SIZE_FLAG, + _WITH_PAGING_STATE_FLAG, + BatchMessage, + BytesReader, + ProtocolHandler, + SupportedMessage, + ReadyMessage, + write_stringmultimap, ) from cassandra.query import BatchType from cassandra.marshal import uint32_unpack @@ -30,7 +41,6 @@ class MessageTest(unittest.TestCase): - def test_prepare_message(self): """ Test to check the appropriate calls are made @@ -57,16 +67,26 @@ def test_execute_message(self): io = Mock() message.send_body(io, 4) - self._check_calls(io, [(b'\x00\x01',), (b'1',), (b'\x00\x04',), (b'\x01',), (b'\x00\x00',)]) + self._check_calls( + io, [(b"\x00\x01",), (b"1",), (b"\x00\x04",), (b"\x01",), (b"\x00\x00",)] + ) io.reset_mock() message.result_metadata_id = 'foo' message.send_body(io, 5) - self._check_calls(io, [(b'\x00\x01',), (b'1',), - (b'\x00\x03',), (b'foo',), - (b'\x00\x04',), - (b'\x00\x00\x00\x01',), (b'\x00\x00',)]) + self._check_calls( + io, + [ + (b"\x00\x01",), + (b"1",), + (b'\x00\x03',), + (b"foo",), + (b'\x00\x04',), + (b'\x00\x00\x00\x01',), + (b"\x00\x00",), + ], + ) def test_query_message(self): """ @@ -82,11 +102,16 @@ def test_query_message(self): io = Mock() message.send_body(io, 4) - self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00',)]) + self._check_calls( + io, [(b"\x00\x00\x00\x01",), (b"a",), (b"\x00\x03",), (b"\x00",)] + ) io.reset_mock() message.send_body(io, 5) - self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00\x00\x00\x00',)]) + self._check_calls( + io, + [(b"\x00\x00\x00\x01",), (b"a",), (b"\x00\x03",), (b"\x00\x00\x00\x00",)], + ) def _check_calls(self, io, expected): assert tuple(c[1] for c in io.write.mock_calls) == tuple(expected) @@ -118,13 +143,16 @@ def test_prepare_flag_with_keyspace(self): for version in ProtocolVersion.SUPPORTED_VERSIONS: if ProtocolVersion.uses_keyspace_flag(version): message.send_body(io, version) - self._check_calls(io, [ - (b'\x00\x00\x00\x01',), - (b'a',), - (b'\x00\x00\x00\x01',), - (b'\x00\x02',), - (b'ks',), - ]) + self._check_calls( + io, + [ + (b'\x00\x00\x00\x01',), + (b'a',), + (b'\x00\x00\x00\x01',), + (b'\x00\x02',), + (b'ks',), + ], + ) else: with pytest.raises(UnsupportedOperation): message.send_body(io, version) @@ -150,42 +178,201 @@ def test_keyspace_written_with_length(self): QueryMessage('a', consistency_level=3, keyspace='ks').send_body( io, protocol_version=5 ) - self._check_calls(io, base_expected + [ - (b'\x00\x02',), # length of keyspace string - (b'ks',), - ]) + self._check_calls( + io, + base_expected + + [ + (b'\x00\x02',), # length of keyspace string + (b'ks',), + ], + ) io.reset_mock() QueryMessage('a', consistency_level=3, keyspace='keyspace').send_body( io, protocol_version=5 ) - self._check_calls(io, base_expected + [ - (b'\x00\x08',), # length of keyspace string - (b'keyspace',), - ]) + self._check_calls( + io, + base_expected + + [ + (b'\x00\x08',), # length of keyspace string + (b'keyspace',), + ], + ) def test_batch_message_with_keyspace(self): self.maxDiff = None io = Mock(name='io') batch = BatchMessage( batch_type=BatchType.LOGGED, - queries=((False, 'stmt a', ('param a',)), - (False, 'stmt b', ('param b',)), - (False, 'stmt c', ('param c',)) - ), + queries=( + (False, "stmt a", ("param a",)), + (False, 'stmt b', ('param b',)), + (False, "stmt c", ("param c",)), + ), consistency_level=3, - keyspace='ks' + keyspace="ks", ) batch.send_body(io, protocol_version=5) - self._check_calls(io, - ((b'\x00',), (b'\x00\x03',), (b'\x00',), - (b'\x00\x00\x00\x06',), (b'stmt a',), - (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param a',), - (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt b',), - (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param b',), - (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt c',), - (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param c',), - (b'\x00\x03',), - (b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',)) + self._check_calls( + io, + ( + (b"\x00",), + (b"\x00\x03",), + (b"\x00",), + (b"\x00\x00\x00\x06",), + (b"stmt a",), + (b"\x00\x01",), + (b"\x00\x00\x00\x07",), + ("param a",), + (b"\x00",), + (b"\x00\x00\x00\x06",), + (b"stmt b",), + (b"\x00\x01",), + (b"\x00\x00\x00\x07",), + ("param b",), + (b"\x00",), + (b"\x00\x00\x00\x06",), + (b"stmt c",), + (b"\x00\x01",), + (b"\x00\x00\x00\x07",), + ("param c",), + (b"\x00\x03",), + (b"\x00\x00\x00\x80",), + (b"\x00\x02",), + (b"ks",), + ), + ) + + +class BytesReaderTest(unittest.TestCase): + """Tests for the BytesReader class used in decode_message.""" + + def test_read_exact(self): + r = BytesReader(b"abcdef") + self.assertEqual(r.read(3), b"abc") + self.assertEqual(r.read(3), b"def") + + def test_read_sequential(self): + r = BytesReader(b"\x00\x01\x02\x03") + self.assertEqual(r.read(1), b"\x00") + self.assertEqual(r.read(2), b"\x01\x02") + self.assertEqual(r.read(1), b"\x03") + + def test_read_zero_bytes(self): + r = BytesReader(b"abc") + self.assertEqual(r.read(0), b"") + self.assertEqual(r.read(3), b"abc") + + def test_read_all_no_args(self): + r = BytesReader(b"hello") + self.assertEqual(r.read(), b"hello") + + def test_read_all_negative(self): + r = BytesReader(b"hello") + self.assertEqual(r.read(-1), b"hello") + + def test_read_all_after_partial(self): + r = BytesReader(b"hello world") + r.read(6) + self.assertEqual(r.read(), b"world") + + def test_read_past_end_raises(self): + r = BytesReader(b"abc") + with self.assertRaises(EOFError): + r.read(4) + + def test_read_past_end_after_partial(self): + r = BytesReader(b"abc") + r.read(2) + with self.assertRaises(EOFError): + r.read(2) + + def test_empty_data(self): + r = BytesReader(b"") + self.assertEqual(r.read(), b"") + self.assertEqual(r.read(0), b"") + with self.assertRaises(EOFError): + r.read(1) + + def test_memoryview_input(self): + data = b"hello world" + r = BytesReader(memoryview(data)) + result = r.read(5) + self.assertIsInstance(result, bytes) + self.assertEqual(result, b"hello") + + def test_return_type_is_bytes(self): + r = BytesReader(b"\x00\x01\x02") + result = r.read(3) + self.assertIsInstance(result, bytes) + + def test_remaining_buffer(self): + r = BytesReader(b"header_row_data") + r.read(7) # consume "header_" + buf, pos = r.remaining_buffer() + self.assertEqual(buf, b"header_row_data") + self.assertEqual(pos, 7) + self.assertEqual(buf[pos:], b"row_data") + + def test_remaining_buffer_at_start(self): + r = BytesReader(b"all_data") + buf, pos = r.remaining_buffer() + self.assertEqual(pos, 0) + self.assertEqual(buf, b"all_data") + + +class DecodeMessageTest(unittest.TestCase): + """ + End-to-end tests for ProtocolHandler.decode_message using BytesReader. + + These verify that real message types round-trip through the decode path + that now uses BytesReader instead of io.BytesIO. + """ + + def _decode(self, opcode, body): + return ProtocolHandler.decode_message( + protocol_version=ProtocolVersion.MAX_SUPPORTED, + protocol_features=None, + user_type_map={}, + stream_id=0, + flags=0, + opcode=opcode, + body=body, + decompressor=None, + result_metadata=None, + ) + + def test_ready_message_empty_body(self): + """ReadyMessage has an empty body (opcode 0x02).""" + msg = self._decode(0x02, b"") + self.assertIsInstance(msg, ReadyMessage) + self.assertEqual(msg.stream_id, 0) + self.assertIsNone(msg.trace_id) + self.assertIsNone(msg.custom_payload) + + def test_supported_message_with_body(self): + """SupportedMessage reads a stringmultimap from body (opcode 0x06).""" + buf = BytesIO() + write_stringmultimap( + buf, + { + "CQL_VERSION": ["3.4.5"], + "COMPRESSION": ["lz4", "snappy"], + }, ) + body = buf.getvalue() + msg = self._decode(0x06, body) + self.assertIsInstance(msg, SupportedMessage) + self.assertEqual(msg.cql_versions, ["3.4.5"]) + self.assertEqual(msg.options["COMPRESSION"], ["lz4", "snappy"]) + + def test_decode_with_memoryview_body(self): + """decode_message should accept a memoryview body (BytesReader materializes it).""" + buf = BytesIO() + write_stringmultimap(buf, {"CQL_VERSION": ["3.0.0"]}) + body = memoryview(buf.getvalue()) + msg = self._decode(0x06, body) + self.assertIsInstance(msg, SupportedMessage) + self.assertEqual(msg.cql_versions, ["3.0.0"]) From 213c0e855d5666b4b87d0fb8c5532b779db7d606 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Mon, 9 Mar 2026 09:47:40 +0200 Subject: [PATCH 5/5] Add isolated decode_message benchmark for measuring copy-reduction impact Standalone benchmark (no cluster required) that constructs synthetic RESULT/ROWS wire-format bodies and measures ProtocolHandler.decode_message() throughput across 8 scenarios (small to 16MB, narrow to 20-column wide). Pins to a single CPU core via sched_setaffinity for consistent results. Supports both Cython and pure-Python paths, with --cprofile option. --- benchmarks/decode_benchmark.py | 579 +++++++++++++++++++++++++++++++++ 1 file changed, 579 insertions(+) create mode 100644 benchmarks/decode_benchmark.py diff --git a/benchmarks/decode_benchmark.py b/benchmarks/decode_benchmark.py new file mode 100644 index 0000000000..7117e36fbf --- /dev/null +++ b/benchmarks/decode_benchmark.py @@ -0,0 +1,579 @@ +#!/usr/bin/env python3 +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Isolated benchmark for ProtocolHandler.decode_message(). + +Measures the throughput of decoding synthetic RESULT/ROWS messages of +varying sizes. Does NOT require a live Cassandra/Scylla cluster. + +Run on both ``master`` and the ``remove_copies`` branch to compare: + + python benchmarks/decode_benchmark.py + python benchmarks/decode_benchmark.py --scenarios small_100,large_5k_1KB + python benchmarks/decode_benchmark.py --cython-only --iterations 20 + python benchmarks/decode_benchmark.py --cprofile medium_1k_1KB +""" + +from __future__ import print_function + +import argparse +import gc +import os +import statistics +import struct +import sys +import time + +# --------------------------------------------------------------------------- +# Pin to a single CPU core for consistent results +# --------------------------------------------------------------------------- +try: + os.sched_setaffinity(0, {0}) +except (AttributeError, OSError): + # sched_setaffinity is Linux-only; silently skip on other platforms + pass + +# --------------------------------------------------------------------------- +# Make sure the driver package is importable from the repo root +# --------------------------------------------------------------------------- +_benchdir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.join(_benchdir, "..")) + +from io import BytesIO +from cassandra.marshal import int32_pack, int64_pack, double_pack +from cassandra.protocol import ( + write_int, + write_short, + write_string, + write_value, + ProtocolHandler, + _ProtocolHandler, + HAVE_CYTHON, +) + +# --------------------------------------------------------------------------- +# CQL type codes (native protocol v4) +# --------------------------------------------------------------------------- +TYPE_BIGINT = 0x0002 +TYPE_BLOB = 0x0003 +TYPE_DOUBLE = 0x0007 +TYPE_INT = 0x0009 +TYPE_VARCHAR = 0x000D + +# Metadata flag +_FLAGS_GLOBAL_TABLES_SPEC = 0x0001 + +# ResultMessage kind +_RESULT_KIND_ROWS = 0x0002 + +# ResultMessage opcode +_OPCODE_RESULT = 0x08 + + +# ====================================================================== +# Synthetic message construction +# ====================================================================== + + +def _build_rows_body(columns, row_values_fn, row_count): + """ + Build the raw bytes for a RESULT/ROWS message body. + + Parameters + ---------- + columns : list of (name: str, type_code: int) + Column definitions. + row_values_fn : callable() -> list[bytes|None] + Returns one row of pre-encoded cell values each time it is called. + row_count : int + Number of rows to encode. + + Returns + ------- + bytes + Complete RESULT body ready for ``decode_message()``. + """ + buf = BytesIO() + + # kind = ROWS + write_int(buf, _RESULT_KIND_ROWS) + + # --- metadata --- + write_int(buf, _FLAGS_GLOBAL_TABLES_SPEC) + write_int(buf, len(columns)) + write_string(buf, "ks") + write_string(buf, "tbl") + for col_name, type_code in columns: + write_string(buf, col_name) + write_short(buf, type_code) + + # --- rows --- + write_int(buf, row_count) + for _ in range(row_count): + for cell in row_values_fn(): + write_value(buf, cell) + + return buf.getvalue() + + +# ====================================================================== +# Scenario definitions +# ====================================================================== + + +def _make_text(size): + """Return a UTF-8 encoded bytes value of exactly *size* bytes.""" + return b"x" * size + + +def _scenario_small_100(): + """100 rows, 3 cols (text 50B, int, bigint) ~3 KB""" + columns = [ + ("col_text", TYPE_VARCHAR), + ("col_int", TYPE_INT), + ("col_bigint", TYPE_BIGINT), + ] + text_val = _make_text(50) + int_val = int32_pack(42) + bigint_val = int64_pack(123456789) + row = lambda: [text_val, int_val, bigint_val] + return _build_rows_body(columns, row, 100) + + +def _scenario_medium_1k_256B(): + """1000 rows, 3 cols (text 256B, int, bigint) ~273 KB""" + columns = [ + ("col_text", TYPE_VARCHAR), + ("col_int", TYPE_INT), + ("col_bigint", TYPE_BIGINT), + ] + text_val = _make_text(256) + int_val = int32_pack(42) + bigint_val = int64_pack(123456789) + row = lambda: [text_val, int_val, bigint_val] + return _build_rows_body(columns, row, 1000) + + +def _scenario_medium_1k_1KB(): + """1000 rows, 3 cols (text 1024B, int, bigint) ~1 MB""" + columns = [ + ("col_text", TYPE_VARCHAR), + ("col_int", TYPE_INT), + ("col_bigint", TYPE_BIGINT), + ] + text_val = _make_text(1024) + int_val = int32_pack(42) + bigint_val = int64_pack(123456789) + row = lambda: [text_val, int_val, bigint_val] + return _build_rows_body(columns, row, 1000) + + +def _scenario_large_5k_1KB(): + """5000 rows, 3 cols (text 1024B, int, bigint) ~5 MB""" + columns = [ + ("col_text", TYPE_VARCHAR), + ("col_int", TYPE_INT), + ("col_bigint", TYPE_BIGINT), + ] + text_val = _make_text(1024) + int_val = int32_pack(42) + bigint_val = int64_pack(123456789) + row = lambda: [text_val, int_val, bigint_val] + return _build_rows_body(columns, row, 5000) + + +def _scenario_large_1k_4KB(): + """1000 rows, 3 cols (text 4096B, int, bigint) ~4 MB""" + columns = [ + ("col_text", TYPE_VARCHAR), + ("col_int", TYPE_INT), + ("col_bigint", TYPE_BIGINT), + ] + text_val = _make_text(4096) + int_val = int32_pack(42) + bigint_val = int64_pack(123456789) + row = lambda: [text_val, int_val, bigint_val] + return _build_rows_body(columns, row, 1000) + + +def _scenario_wide_5k_doubles(): + """5000 rows, 10 cols (10x double) ~586 KB""" + columns = [("col_d%d" % i, TYPE_DOUBLE) for i in range(10)] + vals = [double_pack(1.0 + i * 0.1) for i in range(10)] + row = lambda: list(vals) + return _build_rows_body(columns, row, 5000) + + +def _scenario_wide_1k_20cols(): + """1000 rows, 20 cols (10x text 64B, 5x int, 3x bigint, 2x double) ~850 KB""" + columns = [] + for i in range(10): + columns.append(("col_text%d" % i, TYPE_VARCHAR)) + for i in range(5): + columns.append(("col_int%d" % i, TYPE_INT)) + for i in range(3): + columns.append(("col_bigint%d" % i, TYPE_BIGINT)) + for i in range(2): + columns.append(("col_double%d" % i, TYPE_DOUBLE)) + + text_val = _make_text(64) + int_val = int32_pack(42) + bigint_val = int64_pack(123456789) + double_val = double_pack(3.14159) + + def row(): + cells = [] + for _ in range(10): + cells.append(text_val) + for _ in range(5): + cells.append(int_val) + for _ in range(3): + cells.append(bigint_val) + for _ in range(2): + cells.append(double_val) + return cells + + return _build_rows_body(columns, row, 1000) + + +def _scenario_blob_1k_16KB(): + """1000 rows, 2 cols (int, 16 KB blob) ~16 MB""" + columns = [ + ("col_int", TYPE_INT), + ("col_blob", TYPE_BLOB), + ] + int_val = int32_pack(42) + blob_val = os.urandom(16384) + row = lambda: [int_val, blob_val] + return _build_rows_body(columns, row, 1000) + + +SCENARIOS = { + "small_100": ("100 rows, 3 cols (text 50B, int, bigint)", _scenario_small_100), + "medium_1k_256B": ( + "1000 rows, 3 cols (text 256B, int, bigint)", + _scenario_medium_1k_256B, + ), + "medium_1k_1KB": ( + "1000 rows, 3 cols (text 1024B, int, bigint)", + _scenario_medium_1k_1KB, + ), + "large_5k_1KB": ( + "5000 rows, 3 cols (text 1024B, int, bigint)", + _scenario_large_5k_1KB, + ), + "large_1k_4KB": ( + "1000 rows, 3 cols (text 4096B, int, bigint)", + _scenario_large_1k_4KB, + ), + "wide_5k_doubles": ("5000 rows, 10 cols (10x double)", _scenario_wide_5k_doubles), + "wide_1k_20cols": ( + "1000 rows, 20 cols (10x text64, 5x int, ...)", + _scenario_wide_1k_20cols, + ), + "blob_1k_16KB": ("1000 rows, 2 cols (int, 16 KB blob)", _scenario_blob_1k_16KB), +} + +# Ordered list so output is deterministic +SCENARIO_ORDER = [ + "small_100", + "medium_1k_256B", + "medium_1k_1KB", + "large_5k_1KB", + "large_1k_4KB", + "wide_5k_doubles", + "wide_1k_20cols", + "blob_1k_16KB", +] + + +# ====================================================================== +# Benchmark runner +# ====================================================================== + + +def _decode(handler, body): + """Call decode_message with the standard benchmark parameters.""" + return handler.decode_message( + protocol_version=4, + protocol_features=None, + user_type_map={}, + stream_id=0, + flags=0, + opcode=_OPCODE_RESULT, + body=body, + decompressor=None, + result_metadata=None, + ) + + +def _run_iterations(handler, body, iterations, warmup): + """ + Run *warmup* + *iterations* decode calls, return list of elapsed + times (seconds) for the measured iterations only. + """ + # Warm-up: let JIT / caches settle + for _ in range(warmup): + _decode(handler, body) + + gc.disable() + try: + times = [] + for _ in range(iterations): + t0 = time.perf_counter() + _decode(handler, body) + t1 = time.perf_counter() + times.append(t1 - t0) + finally: + gc.enable() + return times + + +def _format_time(seconds): + """Human-readable time string. Always reports in microseconds for + consistent cross-scenario comparison.""" + return "%.1f us" % (seconds * 1e6) + + +def _format_throughput(body_size, seconds): + """MB/s throughput string.""" + mb = body_size / (1024 * 1024) + return "%.1f MB/s" % (mb / seconds) + + +def _report(label, times, body_size, row_count): + """Print a single result line.""" + t_min = min(times) + t_med = statistics.median(times) + t_mean = statistics.mean(times) + rows_per_sec = row_count / t_med + + if rows_per_sec >= 1e6: + rps_str = "%.2fM rows/s" % (rows_per_sec / 1e6) + elif rows_per_sec >= 1e3: + rps_str = "%.0fK rows/s" % (rows_per_sec / 1e3) + else: + rps_str = "%.0f rows/s" % rows_per_sec + + print( + " %-14s min=%s median=%s mean=%s (%s, %s)" + % ( + label, + _format_time(t_min), + _format_time(t_med), + _format_time(t_mean), + _format_throughput(body_size, t_med), + rps_str, + ) + ) + + +def _extract_row_count(scenario_name): + """Infer the row count from the scenario name for rows/s reporting.""" + mapping = { + "small_100": 100, + "medium_1k_256B": 1000, + "medium_1k_1KB": 1000, + "large_5k_1KB": 5000, + "large_1k_4KB": 1000, + "wide_5k_doubles": 5000, + "wide_1k_20cols": 1000, + "blob_1k_16KB": 1000, + } + return mapping.get(scenario_name, 0) + + +def run_benchmark( + scenarios, iterations, warmup, cython_only, python_only, cprofile_scenario +): + """ + Run the benchmark for each requested scenario. + """ + print("=" * 78) + print("Decode Benchmark") + print("=" * 78) + print(" Cython available : %s" % HAVE_CYTHON) + print(" Iterations : %d (+ %d warmup)" % (iterations, warmup)) + print(" CPU pinned : %s" % _is_pinned()) + print() + + handlers = [] + if not python_only: + if HAVE_CYTHON: + handlers.append(("Cython", ProtocolHandler)) + elif not cython_only: + print(" [NOTE] Cython extensions not available, skipping Cython path\n") + if not cython_only: + handlers.append(("Python", _ProtocolHandler)) + + if not handlers: + print( + "ERROR: no handlers selected (Cython not available and --cython-only set)" + ) + sys.exit(1) + + profiler = None + if cprofile_scenario: + import cProfile + + profiler = cProfile.Profile() + + for name in scenarios: + desc, builder = SCENARIOS[name] + body = builder() + body_size = len(body) + row_count = _extract_row_count(name) + + print("Scenario: %s (%s, %s body)" % (name, desc, _format_size(body_size))) + + for label, handler in handlers: + if profiler and name == cprofile_scenario: + profiler.enable() + + times = _run_iterations(handler, body, iterations, warmup) + + if profiler and name == cprofile_scenario: + profiler.disable() + + _report(label + ":", times, body_size, row_count) + + print() + + if profiler: + print("-" * 78) + print("cProfile results for scenario '%s':" % cprofile_scenario) + print("-" * 78) + import pstats + + stats = pstats.Stats(profiler) + stats.strip_dirs() + stats.sort_stats("cumulative") + stats.print_stats(30) + + +def _format_size(nbytes): + """Human-readable byte size.""" + if nbytes >= 1024 * 1024: + return "%.1f MB" % (nbytes / (1024 * 1024)) + elif nbytes >= 1024: + return "%.1f KB" % (nbytes / 1024) + else: + return "%d B" % nbytes + + +def _is_pinned(): + """Check if the process is pinned to a single CPU core.""" + try: + affinity = os.sched_getaffinity(0) + return len(affinity) == 1 + except (AttributeError, OSError): + return False + + +# ====================================================================== +# CLI +# ====================================================================== + + +def main(): + parser = argparse.ArgumentParser( + description="Isolated decode_message benchmark (no cluster required)" + ) + parser.add_argument( + "--iterations", + "-n", + type=int, + default=10, + help="Number of timed iterations per scenario (default: 10)", + ) + parser.add_argument( + "--warmup", + "-w", + type=int, + default=3, + help="Number of warmup iterations (default: 3)", + ) + parser.add_argument( + "--scenarios", + "-s", + type=str, + default=None, + help="Comma-separated list of scenarios to run (default: all). " + "Available: %s" % ", ".join(SCENARIO_ORDER), + ) + parser.add_argument( + "--cython-only", + action="store_true", + default=False, + help="Only benchmark the Cython (fast) path", + ) + parser.add_argument( + "--python-only", + action="store_true", + default=False, + help="Only benchmark the pure-Python path", + ) + parser.add_argument( + "--cprofile", + type=str, + default=None, + metavar="SCENARIO", + help="Enable cProfile for the named scenario and print top-30 stats", + ) + parser.add_argument( + "--list", + action="store_true", + default=False, + help="List available scenarios and exit", + ) + + args = parser.parse_args() + + if args.list: + print("Available scenarios:") + for name in SCENARIO_ORDER: + desc, builder = SCENARIOS[name] + print(" %-20s %s" % (name, desc)) + sys.exit(0) + + if args.cython_only and args.python_only: + parser.error("--cython-only and --python-only are mutually exclusive") + + if args.scenarios: + selected = [s.strip() for s in args.scenarios.split(",")] + for s in selected: + if s not in SCENARIOS: + parser.error("Unknown scenario: %s" % s) + else: + selected = list(SCENARIO_ORDER) + + if args.cprofile and args.cprofile not in selected: + parser.error( + "--cprofile scenario '%s' is not in the selected scenarios" % args.cprofile + ) + + run_benchmark( + scenarios=selected, + iterations=args.iterations, + warmup=args.warmup, + cython_only=args.cython_only, + python_only=args.python_only, + cprofile_scenario=args.cprofile, + ) + + +if __name__ == "__main__": + main()