Skip to content

Commit d850f81

Browse files
committed
(improvement)change test_policy_performance tests to be a benchmark test
1. Move to regular Pytest, within the performance subdir, as part of a benchmark module 2. Renamed tests/integration/standard/column_encryption/test_policies.py -> test_encrypted_policies.py - we had two test_policies.py which conflicted when trying to run all unit tests. Example run: (scylla-driver) ykaul@ykaul:~/github/python-driver$ SCYLLA_VERSION=release:2025.4.2 PROTOCOL_VERSION=4 pytest -s -m benchmark ============================================================================================================ test session starts ============================================================================================================= platform linux -- Python 3.14.2, pytest-8.4.2, pluggy-1.6.0 rootdir: /home/ykaul/github/python-driver configfile: pyproject.toml collected 1798 items / 1792 deselected / 6 selected tests/performance/test_policy_performance.py Pinned to CPU 0 DCAware | 100000 | 0.1176 | 850 Kops/s . RackAware | 100000 | 0.1774 | 563 Kops/s . TokenAware(DCAware) | 100000 | 0.6666 | 150 Kops/s . TokenAware(RackAware) | 100000 | 0.7195 | 138 Kops/s . Default(DCAware) | 100000 | 0.1481 | 675 Kops/s . HostFilter(DCAware) | 100000 | 0.2416 | 413 Kops/s . Signed-off-by: Yaniv Kaul <yaniv.kaul@scylladb.com>
1 parent 47b64aa commit d850f81

4 files changed

Lines changed: 245 additions & 216 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ log_level = "DEBUG"
122122
log_date_format = "%Y-%m-%d %H:%M:%S"
123123
xfail_strict = true
124124
addopts = "-rf"
125+
markers = [
126+
"benchmark: marks tests as performance benchmarks (deselect with '-m \"not benchmark\"')",
127+
]
125128

126129
[tool.setuptools_scm]
127130
version_file = "cassandra/_version.py"

tests/integration/standard/column_encryption/test_policies.py renamed to tests/integration/standard/column_encryption/test_encrypted_policies.py

File renamed without changes.
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
import time
2+
import uuid
3+
import struct
4+
import os
5+
import statistics
6+
from unittest.mock import Mock
7+
import pytest
8+
9+
"A micro-benchmark for performance of policies"
10+
11+
from cassandra.policies import (
12+
DCAwareRoundRobinPolicy,
13+
RackAwareRoundRobinPolicy,
14+
TokenAwarePolicy,
15+
DefaultLoadBalancingPolicy,
16+
HostFilterPolicy
17+
)
18+
from cassandra.pool import Host
19+
from cassandra.cluster import SimpleConvictionPolicy
20+
21+
# Mock for Connection/EndPoint since Host expects it
22+
class MockEndPoint:
    """Minimal endpoint stand-in that carries only an ``address``.

    ``str()`` of an instance is the address itself.
    """

    __slots__ = ("address",)

    def __init__(self, address):
        self.address = address

    def __str__(self):
        return self.address
29+
30+
class MockStatement:
    """Lightweight statement stub exposing the attributes set in ``__init__``.

    Holds a routing key plus keyspace/table names (defaulting to ``"ks"`` and
    ``"tbl"``) and always reports that it is not an LWT.
    """

    __slots__ = ("routing_key", "keyspace", "table")

    def __init__(self, routing_key, keyspace="ks", table="tbl"):
        self.routing_key, self.keyspace, self.table = routing_key, keyspace, table

    def is_lwt(self):
        """Return ``False``: this stub never represents a lightweight transaction."""
        return False
40+
41+
class MockTokenMap:
    """Token-map stub that delegates replica lookup to a supplied callable.

    ``token_class`` is a ``Mock`` whose ``from_key`` returns the key unchanged,
    so the "token" handed to ``get_replicas`` is the routing key itself.
    """

    __slots__ = ("token_class", "get_replicas_func")

    def __init__(self, get_replicas_func):
        token_class = Mock()
        token_class.from_key = lambda k: k
        self.token_class = token_class
        self.get_replicas_func = get_replicas_func

    def get_replicas(self, keyspace, token):
        """Return whatever ``get_replicas_func(keyspace, token)`` returns."""
        lookup = self.get_replicas_func
        return lookup(keyspace, token)
50+
51+
class MockTablets:
    """Tablet-registry stub that never finds a tablet."""

    __slots__ = ()

    def get_tablet_for_key(self, keyspace, table, key):
        """Return ``None`` for every key, i.e. no tablet is ever found."""
        return None
55+
56+
class MockMetadata:
    """Cluster-metadata stub backed by a replica callable and an address map.

    Exposes ``_tablets`` (a ``MockTablets``), ``token_map`` (a ``MockTokenMap``
    built on the same replica callable), and the lookup helpers defined below.
    """

    __slots__ = ("_tablets", "token_map", "get_replicas_func", "hosts_by_address")

    def __init__(self, get_replicas_func, hosts_by_address):
        self.hosts_by_address = hosts_by_address
        self.get_replicas_func = get_replicas_func
        self.token_map = MockTokenMap(get_replicas_func)
        self._tablets = MockTablets()

    def can_support_partitioner(self):
        """Always report that the partitioner is supported."""
        return True

    def get_replicas(self, keyspace, key):
        """Return ``get_replicas_func(keyspace, key)``."""
        lookup = self.get_replicas_func
        return lookup(keyspace, key)

    def get_host(self, addr):
        """Return the host registered under ``addr``, or ``None`` if absent."""
        return self.hosts_by_address.get(addr)
72+
73+
class MockCluster:
    """Cluster stub whose only attribute is ``metadata``."""

    __slots__ = ("metadata",)

    def __init__(self, metadata):
        self.metadata = metadata
77+
78+
@pytest.fixture(scope="module")
def benchmark_setup():
    """Build the host topology and query workload shared by the benchmark tests.

    While this module's benchmarks run, the process is pinned (best effort) to
    a single CPU. The previous affinity mask is restored on fixture teardown so
    the pinning does not leak into tests from other modules that run later in
    the same pytest session.

    Yields:
        dict with the keys
        ``hosts`` (list of 45 hosts: 5 DCs x 3 racks x 3 nodes),
        ``hosts_map`` (host_id -> Host),
        ``replicas_map`` (routing key -> list of 3 replica hosts),
        ``queries`` (list of ``MockStatement``), and
        ``query_count`` (100000).
    """
    previous_affinity = None
    if hasattr(os, 'sched_setaffinity'):
        try:
            previous_affinity = os.sched_getaffinity(0)
            # The affinity mask is a set with no defined order; min() makes
            # "first available CPU" deterministic.
            cpu = min(previous_affinity)
            os.sched_setaffinity(0, {cpu})
            print(f"Pinned to CPU {cpu}")
        except (OSError, ValueError) as e:
            # Pinning is only a noise-reduction aid; carry on unpinned.
            previous_affinity = None
            print(f"Could not pin CPU: {e}")

    # 1. Topology: 5 DCs, 3 Racks/DC, 3 Nodes/Rack = 45 Nodes
    dcs = ['dc{}'.format(i) for i in range(5)]
    racks = ['rack{}'.format(i) for i in range(3)]
    nodes_per_rack = 3

    hosts = []
    hosts_map = {}  # host_id -> Host

    # Deterministic generation: one 127.0.<subnet>.x subnet per (dc, rack) pair
    ip_counter = 0
    subnet_counter = 0
    for dc in dcs:
        for rack in racks:
            subnet_counter += 1
            for node_idx in range(nodes_per_rack):
                ip_counter += 1
                address = "127.0.{}.{}".format(subnet_counter, node_idx + 1)
                h_id = uuid.UUID(int=ip_counter)
                h = Host(MockEndPoint(address), SimpleConvictionPolicy, host_id=h_id)
                h.set_location_info(dc, rack)
                hosts.append(h)
                hosts_map[h_id] = h

    # 2. Queries: 100,000 deterministic queries keyed by packed integers
    query_count = 100000
    replica_count = 3
    host_count = len(hosts)
    queries = []
    replicas_map = {}  # routing_key -> list of replica hosts
    for i in range(query_count):
        key = struct.pack('>I', i)
        queries.append(MockStatement(routing_key=key))
        # Pre-calculated replicas for TokenAware: hosts i, i+1, i+2 (mod 45),
        # standing in for what metadata.get_replicas would return.
        replicas_map[key] = [hosts[(i + r) % host_count] for r in range(replica_count)]

    try:
        yield {
            'hosts': hosts,
            'hosts_map': hosts_map,
            'replicas_map': replicas_map,
            'queries': queries,
            'query_count': query_count,
        }
    finally:
        if previous_affinity is not None:
            try:
                os.sched_setaffinity(0, previous_affinity)
            except OSError as e:
                print(f"Could not restore CPU affinity: {e}")
139+
140+
141+
def _get_replicas_side_effect(replicas_map, keyspace, key):
142+
return replicas_map.get(key, [])
143+
144+
145+
def _setup_cluster_mock(hosts, replicas_map):
    """Assemble a ``MockCluster`` whose metadata serves ``hosts`` and ``replicas_map``.

    Hosts are indexed by ``host.address`` when that is set, otherwise by
    ``host.endpoint.address``; hosts with neither are left out of the index.
    """
    by_address = {}
    for candidate in hosts:
        address = getattr(candidate, 'address', None)
        if address is None:
            endpoint = getattr(candidate, 'endpoint', None)
            if endpoint is not None:
                address = getattr(candidate.endpoint, 'address', None)
        if address is None:
            continue
        by_address[address] = candidate

    def get_replicas_func(ks, key):
        return _get_replicas_side_effect(replicas_map, ks, key)

    return MockCluster(MockMetadata(get_replicas_func, by_address))
157+
158+
159+
def _run_benchmark(name, policy, setup_data):
    """Time ``policy.make_query_plan`` over every query and print a summary row.

    The policy is populated from a mock cluster built out of ``setup_data``,
    warmed up with 100 plans for the first query, then timed over the full
    query list five times. The median pass duration is reported.

    Returns:
        Queries per second, computed from the median duration.
    """
    all_hosts = setup_data['hosts']
    workload = setup_data['queries']
    replica_lookup = setup_data['replicas_map']

    # Setup
    policy.populate(_setup_cluster_mock(all_hosts, replica_lookup), all_hosts)

    # Warmup
    first_query = workload[0]
    for _ in range(100):
        list(policy.make_query_plan(working_keyspace="ks", query=first_query))

    # Several timed passes to reduce noise
    passes = 5
    elapsed_per_pass = []
    for _ in range(passes):
        began = time.perf_counter()
        for q in workload:
            # Drain the iterator so the full plan-generation cost is paid
            for _ in policy.make_query_plan(working_keyspace="ks", query=q):
                pass
        finished = time.perf_counter()
        elapsed_per_pass.append(finished - began)

    # The median filters out outlier passes
    duration = statistics.median(elapsed_per_pass)

    count = len(workload)
    ops_per_sec = count / duration
    kops = int(ops_per_sec / 1000)

    print(f"\n{name:<30} | {count:<10} | {duration:<10.4f} | {kops:<10} Kops/s")
    return ops_per_sec
195+
196+
197+
@pytest.mark.benchmark
def test_dc_aware(benchmark_setup):
    """Measure query-plan throughput of a bare DCAwareRoundRobinPolicy."""
    # Local DC = dc0, 1 remote host per DC
    _run_benchmark(
        "DCAware",
        DCAwareRoundRobinPolicy(local_dc='dc0', used_hosts_per_remote_dc=1),
        benchmark_setup,
    )
203+
204+
205+
@pytest.mark.benchmark
def test_rack_aware(benchmark_setup):
    """Measure query-plan throughput of a bare RackAwareRoundRobinPolicy."""
    # Local DC = dc0, Local Rack = rack0, 1 remote host per DC
    _run_benchmark(
        "RackAware",
        RackAwareRoundRobinPolicy(
            local_dc='dc0', local_rack='rack0', used_hosts_per_remote_dc=1
        ),
        benchmark_setup,
    )
211+
212+
213+
@pytest.mark.benchmark
def test_token_aware_wrapping_dc_aware(benchmark_setup):
    """Measure TokenAwarePolicy layered over DCAwareRoundRobinPolicy."""
    dc_aware = DCAwareRoundRobinPolicy(local_dc='dc0', used_hosts_per_remote_dc=1)
    # Replica shuffling is off so the plan order is deterministic if needed
    token_aware = TokenAwarePolicy(dc_aware, shuffle_replicas=False)
    _run_benchmark("TokenAware(DCAware)", token_aware, benchmark_setup)
219+
220+
221+
@pytest.mark.benchmark
def test_token_aware_wrapping_rack_aware(benchmark_setup):
    """Measure TokenAwarePolicy layered over RackAwareRoundRobinPolicy."""
    rack_aware = RackAwareRoundRobinPolicy(
        local_dc='dc0', local_rack='rack0', used_hosts_per_remote_dc=1
    )
    token_aware = TokenAwarePolicy(rack_aware, shuffle_replicas=False)
    _run_benchmark("TokenAware(RackAware)", token_aware, benchmark_setup)
227+
228+
229+
@pytest.mark.benchmark
def test_default_wrapping_dc_aware(benchmark_setup):
    """Measure DefaultLoadBalancingPolicy layered over DCAwareRoundRobinPolicy."""
    dc_aware = DCAwareRoundRobinPolicy(local_dc='dc0', used_hosts_per_remote_dc=1)
    wrapped = DefaultLoadBalancingPolicy(dc_aware)
    _run_benchmark("Default(DCAware)", wrapped, benchmark_setup)
235+
236+
237+
@pytest.mark.benchmark
def test_host_filter_wrapping_dc_aware(benchmark_setup):
    """Measure HostFilterPolicy layered over DCAwareRoundRobinPolicy.

    The predicate rejects every host whose rack is ``'rack2'``.
    """
    dc_aware = DCAwareRoundRobinPolicy(local_dc='dc0', used_hosts_per_remote_dc=1)
    filtered = HostFilterPolicy(
        child_policy=dc_aware,
        predicate=lambda host: host.rack != 'rack2',
    )
    _run_benchmark("HostFilter(DCAware)", filtered, benchmark_setup)

0 commit comments

Comments
 (0)