Skip to content

Commit 5104cdf

Browse files
committed
perf: buffer accumulation in BatchMessage.send_body()
Replace per-write_value()/write_byte()/write_short() calls in BatchMessage.send_body() with buffer accumulation (list.append + b"".join + single f.write()), reducing f.write() calls from Q*(4 + 2*P) + footer to 1 for Q queries with P params each. Benchmark results (Python 3.14, Cython .so, 50K iters, best of 3, quiet machine): Scenario Before After Speedup 10 queries x 2 params (128D vec) 8364 ns 4475 ns 1.87x 10 queries x 2 params (768D vec) 8081 ns 5516 ns 1.47x 50 queries x 2 params (128D vec) 32368 ns 16271 ns 1.99x 10 queries x 10 text params 19138 ns 9051 ns 2.11x 50 queries x 10 text params 86845 ns 40020 ns 2.17x 10 unprepared x 2 params 8666 ns 4252 ns 2.04x Also updates test_batch_message_with_keyspace to use BytesIO for byte-level verification (compatible with single-write output). Adds 7 batch-specific unit tests covering prepared, unprepared, mixed, empty, many-query, NULL/UNSET, and vector parameter scenarios. Includes benchmark script benchmarks/bench_batch_send_body.py.
1 parent ac64459 commit 5104cdf

3 files changed

Lines changed: 290 additions & 27 deletions

File tree

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright DataStax, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Benchmark: BatchMessage.send_body() for vector and scalar workloads.
17+
18+
Measures the actual loaded module's BatchMessage.send_body() method.
19+
Run this before and after optimization to compare.
20+
21+
Usage:
22+
# Build baseline .so, then:
23+
python benchmarks/bench_batch_send_body.py
24+
# Apply optimization, rebuild .so, then:
25+
python benchmarks/bench_batch_send_body.py
26+
"""
27+
28+
import io
29+
import struct
30+
import time
31+
import timeit
32+
import sys
33+
import os
34+
35+
# Ensure the repo root is importable
36+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
37+
38+
import cassandra.protocol
39+
from cassandra.protocol import BatchMessage
40+
from cassandra.query import BatchType
41+
from cassandra.marshal import int32_pack
42+
43+
44+
# ---------------------------------------------------------------------------
45+
# Scenario builders
46+
# ---------------------------------------------------------------------------
47+
48+
49+
def make_batch_vector_queries(num_queries, dim):
50+
"""Batch of prepared INSERT with (int32_key, float_vector) params."""
51+
vector_bytes = struct.pack(f">{dim}f", *([0.1] * dim))
52+
key_bytes = int32_pack(42)
53+
return [
54+
(True, b"\x01\x02\x03\x04\x05\x06\x07\x08", [key_bytes, vector_bytes])
55+
for _ in range(num_queries)
56+
]
57+
58+
59+
def make_batch_scalar_queries(num_queries, num_params, param_size=20):
60+
"""Batch of prepared INSERT with N text columns of param_size bytes."""
61+
params = [b"\x41" * param_size for _ in range(num_params)]
62+
return [
63+
(True, b"\x01\x02\x03\x04\x05\x06\x07\x08", list(params))
64+
for _ in range(num_queries)
65+
]
66+
67+
68+
def make_batch_unprepared_queries(num_queries, num_params, param_size=20):
69+
"""Batch of unprepared INSERT statements."""
70+
stmt = "INSERT INTO ks.tbl (k, v) VALUES (?, ?)"
71+
params = [b"\x41" * param_size for _ in range(num_params)]
72+
return [(False, stmt, list(params)) for _ in range(num_queries)]
73+
74+
75+
# ---------------------------------------------------------------------------
76+
# Config
77+
# ---------------------------------------------------------------------------
78+
79+
PROTO_VERSION = 4
80+
ITERATIONS = 50_000
81+
REPEATS = 3
82+
83+
SCENARIOS = [
84+
("10 queries x 2 params (128D vec)", make_batch_vector_queries(10, 128)),
85+
("10 queries x 2 params (768D vec)", make_batch_vector_queries(10, 768)),
86+
("50 queries x 2 params (128D vec)", make_batch_vector_queries(50, 128)),
87+
("10 queries x 10 text params", make_batch_scalar_queries(10, 10, 20)),
88+
("50 queries x 10 text params", make_batch_scalar_queries(50, 10, 20)),
89+
("10 unprepared x 2 params", make_batch_unprepared_queries(10, 2, 20)),
90+
]
91+
92+
93+
# ---------------------------------------------------------------------------
94+
# Benchmark
95+
# ---------------------------------------------------------------------------
96+
97+
98+
def bench_batch(queries, iterations, repeats):
99+
"""Benchmark BatchMessage.send_body(), return best ns/call."""
100+
msg = BatchMessage(
101+
batch_type=BatchType.LOGGED,
102+
queries=queries,
103+
consistency_level=1,
104+
timestamp=1234567890123456,
105+
)
106+
f = io.BytesIO()
107+
108+
def run():
109+
f.seek(0)
110+
f.truncate()
111+
msg.send_body(f, PROTO_VERSION)
112+
113+
t = timeit.repeat(run, number=iterations, repeat=repeats, timer=time.process_time)
114+
return min(t) / iterations * 1e9
115+
116+
117+
def main():
118+
is_cython = cassandra.protocol.__file__.endswith(".so")
119+
print(f"Python: {sys.version.split()[0]}")
120+
print(f"Module: {cassandra.protocol.__file__}")
121+
print(f"Cython: {'YES (.so loaded)' if is_cython else 'NO (pure Python .py)'}")
122+
print(f"Config: proto v{PROTO_VERSION}, {ITERATIONS:,} iters, best of {REPEATS}")
123+
print()
124+
print(f"{'Scenario':45s} {'ns/call':>10s} {'bytes':>8s}")
125+
print(f"{'-' * 45} {'-' * 10} {'-' * 8}")
126+
127+
for label, queries in SCENARIOS:
128+
# Measure output size
129+
msg = BatchMessage(
130+
batch_type=BatchType.LOGGED,
131+
queries=queries,
132+
consistency_level=1,
133+
timestamp=1234567890123456,
134+
)
135+
f = io.BytesIO()
136+
msg.send_body(f, PROTO_VERSION)
137+
nbytes = len(f.getvalue())
138+
139+
ns = bench_batch(queries, ITERATIONS, REPEATS)
140+
print(f"{label:45s} {ns:8.1f} {nbytes:>6d}")
141+
142+
print()
143+
144+
145+
if __name__ == "__main__":
146+
main()

cassandra/protocol.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -923,21 +923,36 @@ def __init__(self, batch_type, queries, consistency_level,
923923
self.keyspace = keyspace
924924

925925
def send_body(self, f, protocol_version):
926-
write_byte(f, self.batch_type.value)
927-
write_short(f, len(self.queries))
926+
# Buffer accumulation: collect all bytes and write once.
927+
_i32 = int32_pack
928+
_u16 = uint16_pack
929+
_u8 = uint8_pack
930+
parts = [_u8(self.batch_type.value), _u16(len(self.queries))]
931+
_p = parts.append
928932
for prepared, string_or_query_id, params in self.queries:
929933
if not prepared:
930-
write_byte(f, 0)
931-
write_longstring(f, string_or_query_id)
934+
_p(_u8(0))
935+
if isinstance(string_or_query_id, str):
936+
string_or_query_id = string_or_query_id.encode('utf8')
937+
_p(_i32(len(string_or_query_id)))
938+
_p(string_or_query_id)
932939
else:
933-
write_byte(f, 1)
934-
write_short(f, len(string_or_query_id))
935-
f.write(string_or_query_id)
936-
write_short(f, len(params))
940+
_p(_u8(1))
941+
_p(_u16(len(string_or_query_id)))
942+
_p(string_or_query_id)
943+
_p(_u16(len(params)))
937944
for param in params:
938-
write_value(f, param)
945+
if param is None:
946+
_p(_i32(-1))
947+
elif param is _UNSET_VALUE:
948+
_p(_i32(-2))
949+
else:
950+
if isinstance(param, str):
951+
param = param.encode('utf8')
952+
_p(_i32(len(param)))
953+
_p(param)
939954

940-
write_consistency_level(f, self.consistency_level)
955+
_p(_u16(self.consistency_level))
941956
flags = 0
942957
if self.serial_consistency_level:
943958
flags |= _WITH_SERIAL_CONSISTENCY_FLAG
@@ -951,18 +966,24 @@ def send_body(self, f, protocol_version):
951966
"Keyspaces may only be set on queries with protocol version "
952967
"5 or higher. Consider setting Cluster.protocol_version to 5.")
953968
if ProtocolVersion.uses_int_query_flags(protocol_version):
954-
write_int(f, flags)
969+
_p(_i32(flags))
955970
else:
956-
write_byte(f, flags)
971+
_p(_u8(flags))
957972

958973
if self.serial_consistency_level:
959-
write_consistency_level(f, self.serial_consistency_level)
974+
_p(_u16(self.serial_consistency_level))
960975
if self.timestamp is not None:
961-
write_long(f, self.timestamp)
976+
_p(uint64_pack(self.timestamp))
962977

963978
if ProtocolVersion.uses_keyspace_flag(protocol_version):
964979
if self.keyspace is not None:
965-
write_string(f, self.keyspace)
980+
ks = self.keyspace
981+
if isinstance(ks, str):
982+
ks = ks.encode('utf8')
983+
_p(_u16(len(ks)))
984+
_p(ks)
985+
986+
f.write(b"".join(parts))
966987

967988

968989
known_event_types = frozenset((

tests/unit/test_protocol.py

Lines changed: 108 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def test_keyspace_written_with_length(self):
170170

171171
def test_batch_message_with_keyspace(self):
172172
self.maxDiff = None
173-
io = Mock(name='io')
173+
buf = io.BytesIO()
174174
batch = BatchMessage(
175175
batch_type=BatchType.LOGGED,
176176
queries=((False, 'stmt a', ('param a',)),
@@ -180,18 +180,27 @@ def test_batch_message_with_keyspace(self):
180180
consistency_level=3,
181181
keyspace='ks'
182182
)
183-
batch.send_body(io, protocol_version=5)
184-
self._check_calls(io,
185-
((b'\x00',), (b'\x00\x03',), (b'\x00',),
186-
(b'\x00\x00\x00\x06',), (b'stmt a',),
187-
(b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param a',),
188-
(b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt b',),
189-
(b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param b',),
190-
(b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt c',),
191-
(b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param c',),
192-
(b'\x00\x03',),
193-
(b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',))
183+
batch.send_body(buf, protocol_version=5)
184+
expected = (
185+
b'\x00' # batch type LOGGED
186+
b'\x00\x03' # 3 queries
187+
b'\x00' # not prepared
188+
b'\x00\x00\x00\x06' b'stmt a' # longstring 'stmt a'
189+
b'\x00\x01' # 1 param
190+
b'\x00\x00\x00\x07' b'param a' # write_value 'param a'
191+
b'\x00' # not prepared
192+
b'\x00\x00\x00\x06' b'stmt b'
193+
b'\x00\x01'
194+
b'\x00\x00\x00\x07' b'param b'
195+
b'\x00'
196+
b'\x00\x00\x00\x06' b'stmt c'
197+
b'\x00\x01'
198+
b'\x00\x00\x00\x07' b'param c'
199+
b'\x00\x03' # consistency level
200+
b'\x00\x00\x00\x80' # flags (keyspace)
201+
b'\x00\x02' b'ks' # keyspace
194202
)
203+
self.assertEqual(buf.getvalue(), expected)
195204

196205
class WriteQueryParamsBufferAccumulationTest(unittest.TestCase):
197206
"""
@@ -373,3 +382,90 @@ def test_single_unset_param(self):
373382
raw = self._execute_msg_bytes(msg, protocol_version=4)
374383
self.assertIn(expected, raw)
375384

385+
# -- BatchMessage buffer accumulation tests ---------------------------
386+
387+
@staticmethod
388+
def _batch_msg_bytes(queries, protocol_version=4, **kwargs):
389+
"""Serialize a BatchMessage and return the raw bytes."""
390+
msg = BatchMessage(batch_type=BatchType.LOGGED, queries=queries,
391+
consistency_level=1, **kwargs)
392+
buf = io.BytesIO()
393+
msg.send_body(buf, protocol_version)
394+
return buf.getvalue()
395+
396+
def test_batch_prepared_queries_with_params(self):
397+
"""Batch of prepared queries with byte params serializes correctly."""
398+
queries = [
399+
(True, b'\x01\x02\x03\x04', [b'val1', b'val2']),
400+
(True, b'\x01\x02\x03\x04', [b'val3', None]),
401+
]
402+
raw = self._batch_msg_bytes(queries)
403+
self.assertIn(b'val1', raw)
404+
self.assertIn(b'val2', raw)
405+
self.assertIn(b'val3', raw)
406+
self.assertIn(int32_pack(-1), raw) # NULL
407+
408+
def test_batch_unprepared_queries(self):
409+
"""Batch of unprepared (string) queries serializes correctly."""
410+
queries = [
411+
(False, 'INSERT INTO t (k) VALUES (?)', [b'\x01']),
412+
(False, 'INSERT INTO t (k) VALUES (?)', [b'\x02']),
413+
]
414+
raw = self._batch_msg_bytes(queries)
415+
self.assertIn(b'INSERT INTO t (k) VALUES (?)', raw)
416+
417+
def test_batch_mixed_prepared_unprepared(self):
418+
"""Batch mixing prepared and unprepared queries."""
419+
queries = [
420+
(False, 'SELECT 1', []),
421+
(True, b'\xab\xcd', [b'data']),
422+
]
423+
raw = self._batch_msg_bytes(queries)
424+
self.assertIn(b'SELECT 1', raw)
425+
self.assertIn(b'data', raw)
426+
427+
def test_batch_empty_queries(self):
428+
"""Batch with zero queries."""
429+
raw = self._batch_msg_bytes([])
430+
self.assertIn(uint16_pack(0), raw)
431+
432+
def test_batch_many_queries(self):
433+
"""Batch with 50 queries to exercise accumulation at scale."""
434+
queries = [
435+
(True, b'\x01\x02', [b'param_%03d' % i])
436+
for i in range(50)
437+
]
438+
raw = self._batch_msg_bytes(queries)
439+
self.assertIn(uint16_pack(50), raw)
440+
for i in range(50):
441+
self.assertIn(b'param_%03d' % i, raw)
442+
443+
def test_batch_null_and_unset_params(self):
444+
"""Batch params with NULL and UNSET values."""
445+
queries = [
446+
(True, b'\x01', [None, _UNSET_VALUE, b'ok']),
447+
]
448+
raw = self._batch_msg_bytes(queries, protocol_version=4)
449+
self.assertIn(int32_pack(-1), raw) # NULL
450+
self.assertIn(int32_pack(-2), raw) # UNSET
451+
self.assertIn(b'ok', raw)
452+
453+
def test_batch_vector_params(self):
454+
"""Batch with large vector params (simulating bulk vector INSERT)."""
455+
vector = struct.pack('128f', *([0.5] * 128))
456+
queries = [
457+
(True, b'\x01\x02', [int32_pack(i), vector])
458+
for i in range(10)
459+
]
460+
raw = self._batch_msg_bytes(queries)
461+
# 10 copies of the vector should appear
462+
count = 0
463+
start = 0
464+
while True:
465+
idx = raw.find(vector, start)
466+
if idx == -1:
467+
break
468+
count += 1
469+
start = idx + 1
470+
self.assertEqual(count, 10)
471+

0 commit comments

Comments
 (0)