diff --git a/.github/workflows/ci-pr-validation.yaml b/.github/workflows/ci-pr-validation.yaml index df2417de..dfe6d678 100644 --- a/.github/workflows/ci-pr-validation.yaml +++ b/.github/workflows/ci-pr-validation.yaml @@ -78,7 +78,7 @@ jobs: python3 -m pip install -U pip setuptools wheel requests python3 setup.py bdist_wheel WHEEL=$(find dist -name '*.whl') - pip3 install ${WHEEL}[avro] + pip3 install ${WHEEL}[avro,protobuf] - name: Run Oauth2 tests run: | diff --git a/pulsar/schema/__init__.py b/pulsar/schema/__init__.py index efa68066..e3fa49e8 100644 --- a/pulsar/schema/__init__.py +++ b/pulsar/schema/__init__.py @@ -22,3 +22,4 @@ from .schema import Schema, BytesSchema, StringSchema, JsonSchema from .schema_avro import AvroSchema +from .schema_protobuf import ProtobufNativeSchema diff --git a/pulsar/schema/schema_protobuf.py b/pulsar/schema/schema_protobuf.py new file mode 100644 index 00000000..1852cd97 --- /dev/null +++ b/pulsar/schema/schema_protobuf.py @@ -0,0 +1,145 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
import base64
import _pulsar

from .schema import Schema

# protobuf is an optional dependency: record whether it could be imported so
# the matching ProtobufNativeSchema implementation is selected further down.
try:
    from google.protobuf import descriptor_pb2
    from google.protobuf.message import Message as ProtobufMessage
except ImportError:
    HAS_PROTOBUF = False
else:
    HAS_PROTOBUF = True


def _collect_file_descriptors(file_descriptor, visited, file_descriptor_set):
    """
    Add ``file_descriptor`` and everything it depends on to ``file_descriptor_set``.

    Dependencies are walked depth-first and appended before the file that
    needs them. ``visited`` holds the names of the files already appended, so
    a file that is reachable through several paths is only emitted once.

    Parameters
    ----------
    file_descriptor:
        File descriptor to add; its ``dependencies`` are processed first.
    visited:
        Mutable set of file names already handled (updated in place).
    file_descriptor_set:
        ``FileDescriptorSet`` whose ``file`` list receives the copies.
    """
    current_name = file_descriptor.name
    if current_name not in visited:
        for dependency in file_descriptor.dependencies:
            _collect_file_descriptors(dependency, visited, file_descriptor_set)
        # The name is recorded only once all dependencies have been handled.
        visited.add(current_name)
        file_proto = descriptor_pb2.FileDescriptorProto()
        file_descriptor.CopyToProto(file_proto)
        file_descriptor_set.file.append(file_proto)


def _build_schema_definition(descriptor):
    """
    Describe a protobuf message type with the ``ProtobufNativeSchemaData`` keys.

    Intended to mirror ``ProtobufNativeSchemaUtils.serialize()`` in the Java
    client.

    Parameters
    ----------
    descriptor:
        Descriptor of the root message type.

    Returns
    -------
    dict
        ``fileDescriptorSet`` is the base64 text of a serialized
        ``FileDescriptorSet`` containing the message's file and its
        dependencies, ``rootMessageTypeName`` is ``descriptor.full_name`` and
        ``rootFileDescriptorName`` is ``descriptor.file.name``.
    """
    root_file = descriptor.file
    descriptor_set = descriptor_pb2.FileDescriptorSet()
    _collect_file_descriptors(root_file, set(), descriptor_set)
    encoded_set = base64.b64encode(descriptor_set.SerializeToString())
    return {
        "fileDescriptorSet": encoded_set.decode('utf-8'),
        "rootMessageTypeName": descriptor.full_name,
        "rootFileDescriptorName": root_file.name,
    }


if HAS_PROTOBUF:
    class ProtobufNativeSchema(Schema):
        """
        Pulsar schema that carries protobuf messages in protobuf binary form.

        The schema definition handed to the base ``Schema`` holds the
        ``fileDescriptorSet``, ``rootMessageTypeName`` and
        ``rootFileDescriptorName`` entries, the layout used by the Java
        client's ProtobufNativeSchema.

        Parameters
        ----------
        record_cls:
            A generated protobuf message class (subclass of
            google.protobuf.message.Message).

        Example
        -------
        .. code-block:: python

            import pulsar
            from pulsar.schema import ProtobufNativeSchema
            from my_proto_pb2 import MyMessage

            client = pulsar.Client('pulsar://localhost:6650')
            schema = ProtobufNativeSchema(MyMessage)
            producer = client.create_producer('my-topic', schema=schema)
            consumer = client.subscribe('my-topic', 'my-sub', schema=schema)

            producer.send(MyMessage(field='value'))

            received = consumer.receive(timeout_millis=5000)
            assert isinstance(received.value(), MyMessage)
            consumer.acknowledge(received)

            consumer.close()
            producer.close()
            client.close()
        """

        def __init__(self, record_cls):
            """Validate ``record_cls`` and register its descriptor-based definition.

            Raises ``TypeError`` when ``record_cls`` is not a class derived
            from the protobuf ``Message`` base.
            """
            is_message_class = (
                isinstance(record_cls, type)
                and issubclass(record_cls, ProtobufMessage)
            )
            if not is_message_class:
                raise TypeError(
                    f'record_cls must be a protobuf Message subclass, got {record_cls!r}'
                )
            definition = _build_schema_definition(record_cls.DESCRIPTOR)
            super().__init__(
                record_cls,
                _pulsar.SchemaType.PROTOBUF_NATIVE,
                definition,
                'PROTOBUF_NATIVE',
            )

        def encode(self, obj):
            """Return ``obj`` serialized with ``SerializeToString()``.

            The object is first passed through ``_validate_object_type``.
            """
            self._validate_object_type(obj)
            payload = obj.SerializeToString()
            return payload

        def decode(self, data):
            """Parse ``data`` into a new instance of the record class."""
            message_cls = self._record_cls
            return message_cls.FromString(data)

        def __str__(self):
            """Return ``ProtobufNativeSchema(<record class name>)``."""
            return f'ProtobufNativeSchema({self._record_cls.__name__})'

else:
    class ProtobufNativeSchema(Schema):
        """Placeholder used when protobuf is not installed; it cannot be instantiated."""

        def __init__(self, _record_cls=None):
            """Always raise, telling the user how to install protobuf."""
            raise Exception(
                "protobuf library support was not found. "
                "Install it with: pip install protobuf"
            )

        def encode(self, obj):
            """No-op; ``__init__`` always raises, so no instance exists to call this."""
            return None

        def decode(self, data):
            """No-op; ``__init__`` always raises, so no instance exists to call this."""
            return None
def _add_protobuf_field(message, name, number, field_type, type_name=None):
    """Append one optional field to the message descriptor proto ``message``.

    ``type_name`` is only set when it is truthy; the tests use it for the
    message-typed field.
    """
    new_field = message.field.add(
        name=name,
        number=number,
        label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,
        type=field_type,
    )
    if type_name:
        new_field.type_name = type_name


def _get_message_classes(pool, message_names):
    """Return a tuple with the message class for each name found in ``pool``.

    Uses ``message_factory.GetMessageClass`` when the installed protobuf
    provides it, otherwise falls back to ``MessageFactory(pool).GetPrototype``.
    """
    to_class = getattr(message_factory, 'GetMessageClass', None)
    if to_class is None:
        legacy_factory = message_factory.MessageFactory(pool)
        to_class = legacy_factory.GetPrototype
    return tuple(
        to_class(pool.FindMessageTypeByName(full_name))
        for full_name in message_names
    )


def _build_protobuf_test_messages():
    """Build the three test message classes from an in-memory proto3 file.

    Returns the classes for ``test.TestMessage``,
    ``test.TestMessageWithNested`` and ``test.TestInner``, in that order.
    """
    kinds = descriptor_pb2.FieldDescriptorProto
    # (message name, ((field name, number, type, type_name), ...)) in the
    # order the messages are declared in the file.
    layout = (
        ('TestMessage', (
            ('name', 1, kinds.TYPE_STRING, None),
            ('value', 2, kinds.TYPE_INT32, None),
        )),
        ('TestMessageWithNested', (
            ('str_field', 1, kinds.TYPE_STRING, None),
            ('int_field', 2, kinds.TYPE_INT32, None),
            ('double_field', 3, kinds.TYPE_DOUBLE, None),
            ('nested', 4, kinds.TYPE_MESSAGE, '.test.TestInner'),
        )),
        ('TestInner', (
            ('inner_str', 1, kinds.TYPE_STRING, None),
            ('inner_int', 2, kinds.TYPE_INT64, None),
        )),
    )

    file_proto = descriptor_pb2.FileDescriptorProto()
    file_proto.name = 'test_schema.proto'
    file_proto.package = 'test'
    file_proto.syntax = 'proto3'
    for message_name, field_specs in layout:
        message_proto = file_proto.message_type.add()
        message_proto.name = message_name
        for field_spec in field_specs:
            _add_protobuf_field(message_proto, *field_spec)

    # A private pool keeps these test types out of the default descriptor pool.
    pool = descriptor_pool.DescriptorPool()
    pool.AddSerializedFile(file_proto.SerializeToString())
    wanted = ('test.TestMessage', 'test.TestMessageWithNested', 'test.TestInner')
    return _get_message_classes(pool, wanted)


TestMessage, TestMessageWithNested, TestInner = _build_protobuf_test_messages()


class ProtobufNativeSchemaTest(TestCase):
    """Unit tests for ProtobufNativeSchema that need no Pulsar broker."""

    @staticmethod
    def _definition_of(schema):
        """Parse the JSON schema definition exposed by ``schema.schema_info()``."""
        return json.loads(schema.schema_info().schema())

    def test_schema_type(self):
        """The schema reports the PROTOBUF_NATIVE schema type."""
        import _pulsar
        info = ProtobufNativeSchema(TestMessage).schema_info()
        self.assertEqual(info.schema_type(), _pulsar.SchemaType.PROTOBUF_NATIVE)

    def test_schema_definition_keys(self):
        """The definition JSON exposes all three required keys."""
        definition = self._definition_of(ProtobufNativeSchema(TestMessage))
        for required_key in ('fileDescriptorSet', 'rootMessageTypeName', 'rootFileDescriptorName'):
            self.assertIn(required_key, definition)

    def test_schema_definition_values(self):
        """Root message and root file names are taken from the descriptor."""
        definition = self._definition_of(ProtobufNativeSchema(TestMessage))
        self.assertEqual(definition['rootMessageTypeName'], 'test.TestMessage')
        self.assertEqual(definition['rootFileDescriptorName'], 'test_schema.proto')

    def test_file_descriptor_set_is_valid_base64_proto(self):
        """fileDescriptorSet decodes from base64 into a parseable FileDescriptorSet."""
        from google.protobuf import descriptor_pb2
        definition = self._definition_of(ProtobufNativeSchema(TestMessage))
        raw_bytes = base64.b64decode(definition['fileDescriptorSet'])
        parsed_set = descriptor_pb2.FileDescriptorSet.FromString(raw_bytes)
        self.assertIn('test_schema.proto', [entry.name for entry in parsed_set.file])

    def test_encode_decode_roundtrip(self):
        """Decoding an encoded message yields the original field values."""
        schema = ProtobufNativeSchema(TestMessage)
        roundtripped = schema.decode(schema.encode(TestMessage(name='hello', value=42)))
        self.assertEqual(roundtripped.name, 'hello')
        self.assertEqual(roundtripped.value, 42)

    def test_encode_produces_protobuf_binary(self):
        """Encoded bytes can be parsed back by the message class itself."""
        schema = ProtobufNativeSchema(TestMessage)
        message = TestMessage(name='pulsar', value=100)
        # Parse with protobuf directly rather than through the schema.
        self.assertEqual(TestMessage.FromString(schema.encode(message)), message)

    def test_encode_decode_nested_message(self):
        """Round-trip keeps scalar fields and the nested message intact."""
        schema = ProtobufNativeSchema(TestMessageWithNested)
        inner = TestInner(inner_str='inner', inner_int=999)
        outer = TestMessageWithNested(
            str_field='test',
            int_field=7,
            double_field=3.14,
            nested=inner,
        )
        result = schema.decode(schema.encode(outer))
        self.assertEqual(result.str_field, 'test')
        self.assertEqual(result.int_field, 7)
        self.assertAlmostEqual(result.double_field, 3.14)
        self.assertEqual(result.nested.inner_str, 'inner')
        self.assertEqual(result.nested.inner_int, 999)

    def test_wrong_type_raises(self):
        """encode() rejects objects that are not instances of the record class."""
        schema = ProtobufNativeSchema(TestMessage)
        self.assertRaises(TypeError, schema.encode, "not a protobuf message")

    def test_non_message_class_raises(self):
        """The constructor rejects classes that are not protobuf messages."""
        self.assertRaises(TypeError, ProtobufNativeSchema, str)

    def test_schema_name(self):
        """The schema is registered under the name 'PROTOBUF_NATIVE'."""
        info = ProtobufNativeSchema(TestMessage).schema_info()
        self.assertEqual(info.name(), 'PROTOBUF_NATIVE')