This class handles parsing from both regular InputStreams (with data copying) and ArrowBuf
+ * (with zero-copy slicing for large fields like app_metadata and body).
+ *
+ * <p>Small fields (descriptor, header) are always copied. Large fields (app_metadata, body) use
+ * zero-copy slicing when parsing from ArrowBuf.
+ */
+final class FlightDataParser {
+ // Counter used to give each per-message child allocator a distinct name ("arrow-msg-<n>").
+ private static final AtomicLong ALLOCATOR_ID = new AtomicLong();
+
+ // Protobuf wire format tags for FlightData fields
+ // Each tag is (field number << 3) | wire type; all four fields are length-delimited.
+ private static final int DESCRIPTOR_TAG =
+ (FlightData.FLIGHT_DESCRIPTOR_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
+ private static final int HEADER_TAG =
+ (FlightData.DATA_HEADER_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
+ private static final int BODY_TAG =
+ (FlightData.DATA_BODY_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
+ private static final int APP_METADATA_TAG =
+ (FlightData.APP_METADATA_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
+
+ /** Base class for FlightData readers with common parsing logic. */
+ abstract static class FlightDataReader {
+ protected final BufferAllocator allocator;
+
+ protected FlightDescriptor descriptor;
+ protected MessageMetadataResult header;
+ protected ArrowBuf appMetadata;
+ protected ArrowBuf body;
+
+ FlightDataReader(BufferAllocator allocator) {
+ this.allocator = allocator;
+ }
+
+ /** Parses the FlightData and returns an ArrowMessage. */
+ final ArrowMessage toMessage() {
+ try {
+ parseFields();
+ ArrowBuf adjustedBody = adjustBodyForHeaderType();
+ ArrowMessage message =
+ new ArrowMessage(descriptor, header, appMetadata, adjustedBody, getMessageAllocator());
+ // Ownership transferred to ArrowMessage
+ appMetadata = null;
+ body = null;
+ return message;
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ } finally {
+ cleanup();
+ }
+ }
+
+ private ArrowBuf adjustBodyForHeaderType() {
+ if (header == null) {
+ return body;
+ }
+ switch (ArrowMessage.HeaderType.getHeader(header.headerType())) {
+ case SCHEMA:
+ if (body != null && body.capacity() == 0) {
+ body.close();
+ return null;
+ }
+ break;
+ case DICTIONARY_BATCH:
+ case RECORD_BATCH:
+ if (body == null) {
+ return allocator.getEmpty();
+ }
+ break;
+ case NONE:
+ case TENSOR:
+ default:
+ break;
+ }
+ return body;
+ }
+
+ private void parseFields() throws IOException {
+ while (hasRemaining()) {
+ int tag = readTag();
+ if (tag == -1) {
+ break;
+ }
+ int size = readLength();
+ switch (tag) {
+ case DESCRIPTOR_TAG:
+ {
+ byte[] bytes = readBytes(size);
+ descriptor = FlightDescriptor.parseFrom(bytes);
+ break;
+ }
+ case HEADER_TAG:
+ {
+ byte[] bytes = readBytes(size);
+ header = MessageMetadataResult.create(ByteBuffer.wrap(bytes), size);
+ break;
+ }
+ case APP_METADATA_TAG:
+ {
+ // Called before reading a new value to handle duplicate protobuf fields
+ // (last occurrence wins per spec) and prevent memory leaks.
+ closeAppMetadata();
+ appMetadata = readBuffer(size);
+ break;
+ }
+ case BODY_TAG:
+ {
+ // Called before reading a new value to handle duplicate protobuf fields
+ // (last occurrence wins per spec) and prevent memory leaks.
+ closeBody();
+ body = readBuffer(size);
+ break;
+ }
+ default:
+ // ignore unknown fields
+ }
+ }
+ }
+
+ /** Returns true if there is more data to read. */
+ protected abstract boolean hasRemaining() throws IOException;
+
+ /** Reads the next protobuf tag, or -1 if no more data. */
+ protected abstract int readTag() throws IOException;
+
+ /** Reads a varint-encoded length. */
+ protected abstract int readLength() throws IOException;
+
+ /** Reads the specified number of bytes into a new byte array. */
+ protected abstract byte[] readBytes(int size) throws IOException;
+
+ /** Reads the specified number of bytes into an ArrowBuf. */
+ protected abstract ArrowBuf readBuffer(int size) throws IOException;
+
+ /** Additional resources that should be transferred to the parsed ArrowMessage. */
+ protected BufferAllocator getMessageAllocator() {
+ return null;
+ }
+
+ /** Called in finally block to clean up resources. Subclasses can override to add cleanup. */
+ protected void cleanup() {
+ closeAppMetadata();
+ closeBody();
+ }
+
+ private void closeAppMetadata() {
+ if (appMetadata != null) {
+ appMetadata.close();
+ appMetadata = null;
+ }
+ }
+
+ private void closeBody() {
+ if (body != null) {
+ body.close();
+ body = null;
+ }
+ }
+ }
+
+ /** Parses FlightData from an InputStream, copying data into Arrow-managed buffers. */
+ static final class InputStreamReader extends FlightDataReader {
+   private final InputStream stream;
+
+   InputStreamReader(BufferAllocator allocator, InputStream stream) {
+     super(allocator);
+     this.stream = stream;
+   }
+
+   /** Reports more data while the stream says at least one byte is available. */
+   @Override
+   protected boolean hasRemaining() throws IOException {
+     return stream.available() > 0;
+   }
+
+   /** Reads one varint tag, or returns -1 when the stream is already at its end. */
+   @Override
+   protected int readTag() throws IOException {
+     int tagFirstByte = stream.read();
+     if (tagFirstByte == -1) {
+       return -1;
+     }
+     return CodedInputStream.readRawVarint32(tagFirstByte, stream);
+   }
+
+   /** Reads one varint length from the stream. */
+   @Override
+   protected int readLength() throws IOException {
+     int firstByte = stream.read();
+     return CodedInputStream.readRawVarint32(firstByte, stream);
+   }
+
+   /** Reads exactly {@code size} bytes into a new heap array. */
+   @Override
+   protected byte[] readBytes(int size) throws IOException {
+     byte[] bytes = new byte[size];
+     ByteStreams.readFully(stream, bytes);
+     return bytes;
+   }
+
+   /**
+    * Copies exactly {@code size} bytes from the stream into a newly allocated ArrowBuf.
+    *
+    * @return a buffer owned by the caller with its writer index set to {@code size}
+    * @throws IOException if the stream ends early; the allocated buffer is released first
+    */
+   @Override
+   protected ArrowBuf readBuffer(int size) throws IOException {
+     ArrowBuf buf = allocator.buffer(size);
+     try {
+       byte[] heapBytes = new byte[size];
+       ByteStreams.readFully(stream, heapBytes);
+       buf.writeBytes(heapBytes);
+       buf.writerIndex(size);
+       return buf;
+     } catch (Throwable t) {
+       // The buffer has not been handed to anyone yet, so nobody else would release it.
+       buf.close();
+       throw t;
+     }
+   }
+ }
+
+ /** Parses FlightData from an ArrowBuf, using zero-copy slicing for large fields. */
+ static final class ArrowBufReader extends FlightDataReader {
+   private static final Logger LOG = LoggerFactory.getLogger(ArrowBufReader.class);
+
+   // Child allocator that accounts for the wrapped transport memory of this one message.
+   private final BufferAllocator messageAllocator;
+   // ArrowBuf view over the detached gRPC buffer; slices returned by readBuffer share it.
+   private final ArrowBuf backingBuffer;
+   private final CodedInputStream codedInput;
+   // Set once getMessageAllocator() has handed messageAllocator out.
+   private boolean transferred;
+
+   ArrowBufReader(
+       BufferAllocator allocator, BufferAllocator messageAllocator, ArrowBuf backingBuffer) {
+     super(allocator);
+     this.messageAllocator = messageAllocator;
+     this.backingBuffer = backingBuffer;
+     ByteBuffer buffer = backingBuffer.nioBuffer(0, (int) backingBuffer.capacity());
+     this.codedInput = CodedInputStream.newInstance(buffer);
+   }
+
+   /**
+    * Creates a zero-copy reader when the stream can be detached and exposes its whole remaining
+    * content as one direct ByteBuffer.
+    *
+    * @param allocator parent allocator for the per-message child allocator
+    * @param stream the gRPC-provided stream to inspect
+    * @return a reader over the detached buffer, or {@code null} if the stream does not qualify;
+    *     in that case the stream has not been detached
+    */
+   static ArrowBufReader tryArrowBufReader(BufferAllocator allocator, InputStream stream) {
+     if (!(stream instanceof Detachable) || !(stream instanceof HasByteBuffer)) {
+       return null;
+     }
+
+     HasByteBuffer hasByteBuffer = (HasByteBuffer) stream;
+     if (!hasByteBuffer.byteBufferSupported()) {
+       return null;
+     }
+
+     ByteBuffer peekBuffer = hasByteBuffer.getByteBuffer();
+     if (peekBuffer == null || !peekBuffer.isDirect()) {
+       return null;
+     }
+
+     try {
+       int available = stream.available();
+       // The peeked buffer must hold everything the stream still has to offer.
+       if (available > 0 && peekBuffer.remaining() < available) {
+         return null;
+       }
+     } catch (IOException ioe) {
+       return null;
+     }
+
+     // From here on this method owns the detached stream and must close it on any failure.
+     InputStream detachedStream = ((Detachable) stream).detach();
+     BufferAllocator messageAllocator = null;
+     try {
+       ByteBuffer detachedBuffer =
+           Objects.requireNonNull(
+               ((HasByteBuffer) detachedStream).getByteBuffer(),
+               "detached stream returned no ByteBuffer");
+
+       long bufferAddress = MemoryUtil.getByteBufferAddress(detachedBuffer);
+       int bufferSize = detachedBuffer.remaining();
+
+       ForeignAllocation foreignAllocation =
+           new ForeignAllocation(bufferSize, bufferAddress + detachedBuffer.position()) {
+             @Override
+             protected void release0() {
+               closeQuietly(detachedStream);
+             }
+           };
+
+       // Keep detached transport memory scoped to this message until a downstream retain.
+       messageAllocator =
+           allocator.newChildAllocator(
+               "arrow-msg-" + ALLOCATOR_ID.incrementAndGet(), 0, bufferSize);
+
+       ArrowBuf backingBuffer = messageAllocator.wrapForeignAllocation(foreignAllocation);
+       return new ArrowBufReader(allocator, messageAllocator, backingBuffer);
+     } catch (Throwable t) {
+       closeQuietly(messageAllocator);
+       closeQuietly(detachedStream);
+       throw t;
+     }
+   }
+
+   /** Closes the stream if non-null, logging (not propagating) any IOException. */
+   private static void closeQuietly(InputStream stream) {
+     if (stream != null) {
+       try {
+         stream.close();
+       } catch (IOException e) {
+         LOG.debug("Error closing detached gRPC stream", e);
+       }
+     }
+   }
+
+   /** Closes the allocator if non-null, logging (not propagating) any exception. */
+   private static void closeQuietly(BufferAllocator allocator) {
+     if (allocator != null) {
+       try {
+         allocator.close();
+       } catch (Exception e) {
+         LOG.debug("Error closing message allocator", e);
+       }
+     }
+   }
+
+   /**
+    * Releases this reader's reference to the backing buffer and, unless the message allocator
+    * was handed out, closes that allocator too. Each step runs even if an earlier one throws.
+    */
+   @Override
+   protected void cleanup() {
+     try {
+       super.cleanup();
+     } finally {
+       try {
+         backingBuffer.close();
+       } finally {
+         if (!transferred) {
+           closeQuietly(messageAllocator);
+         }
+       }
+     }
+   }
+
+   @Override
+   protected boolean hasRemaining() throws IOException {
+     return !codedInput.isAtEnd();
+   }
+
+   /** Reads the next tag; a tag of 0 is reported as -1 to match the base-class contract. */
+   @Override
+   protected int readTag() throws IOException {
+     int tag = codedInput.readTag();
+     return tag == 0 ? -1 : tag;
+   }
+
+   @Override
+   protected int readLength() throws IOException {
+     return codedInput.readRawVarint32();
+   }
+
+   @Override
+   protected byte[] readBytes(int size) throws IOException {
+     // Reads size bytes and creates a copy
+     return codedInput.readRawBytes(size);
+   }
+
+   /**
+    * Returns a slice of the backing buffer covering the next {@code size} bytes and advances the
+    * coded input past them. The slice carries its own reference on the backing buffer.
+    */
+   @Override
+   protected ArrowBuf readBuffer(int size) throws IOException {
+     // CodedInputStream advances the shared ByteBuffer; use its read count for zero-copy slicing.
+     int offset = codedInput.getTotalBytesRead();
+     // Skipping first validates size against the remaining input before a reference is taken.
+     codedInput.skipRawBytes(size);
+     backingBuffer.getReferenceManager().retain();
+     return backingBuffer.slice(offset, size);
+   }
+
+   /** Hands out the per-message allocator and records that cleanup() must not close it. */
+   @Override
+   protected BufferAllocator getMessageAllocator() {
+     transferred = true;
+     return messageAllocator;
+   }
+ }
+}
diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java
index 15cfd6ba8..478d49766 100644
--- a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java
+++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java
@@ -318,13 +318,18 @@ public boolean next() {
/** Update our metadata reference with a new one from this message. */
private void updateMetadata(ArrowMessage msg) {
- if (this.applicationMetadata != null) {
- this.applicationMetadata.close();
+ ArrowBuf retainedMetadata = null;
+ if (msg.getApplicationMetadata() != null) {
+ // Re-associate metadata with the stream allocator so it can outlive this message.
+ retainedMetadata =
+ msg.getApplicationMetadata()
+ .getReferenceManager()
+ .retain(msg.getApplicationMetadata(), allocator);
}
- this.applicationMetadata = msg.getApplicationMetadata();
if (this.applicationMetadata != null) {
- this.applicationMetadata.getReferenceManager().retain();
+ this.applicationMetadata.close();
}
+ this.applicationMetadata = retainedMetadata;
}
/** Ensure the Arrow metadata version doesn't change mid-stream. */
@@ -424,50 +429,49 @@ public void onNext(ArrowMessage msg) {
}
if (msg.getApplicationMetadata() != null) {
enqueue(msg);
+ } else {
+ AutoCloseables.closeNoChecked(msg);
}
break;
}
case SCHEMA:
{
- Schema schema = msg.asSchema();
-
- // if there is app metadata in the schema message, make sure
- // that we don't leak it.
- ArrowBuf meta = msg.getApplicationMetadata();
- if (meta != null) {
- meta.close();
- }
-
- final List This could be solved by BufferInputStream exposing Drainable.
- */
-public class GetReadableBuffer {
-
- private static final Field READABLE_BUFFER;
- private static final Class> BUFFER_INPUT_STREAM;
-
- static {
- Field tmpField = null;
- Class> tmpClazz = null;
- try {
- Class> clazz = Class.forName("io.grpc.internal.ReadableBuffers$BufferInputStream");
-
- Field f = clazz.getDeclaredField("buffer");
- f.setAccessible(true);
- // don't set until we've gotten past all exception cases.
- tmpField = f;
- tmpClazz = clazz;
- } catch (Exception e) {
- new RuntimeException("Failed to initialize GetReadableBuffer, falling back to slow path", e)
- .printStackTrace();
- }
- READABLE_BUFFER = tmpField;
- BUFFER_INPUT_STREAM = tmpClazz;
- }
-
- /**
- * Extracts the ReadableBuffer for the given input stream.
- *
- * @param is Must be an instance of io.grpc.internal.ReadableBuffers$BufferInputStream or null
- * will be returned.
- */
- public static ReadableBuffer getReadableBuffer(InputStream is) {
-
- if (BUFFER_INPUT_STREAM == null || !is.getClass().equals(BUFFER_INPUT_STREAM)) {
- return null;
- }
-
- try {
- return (ReadableBuffer) READABLE_BUFFER.get(is);
- } catch (Exception ex) {
- throw Throwables.propagate(ex);
- }
- }
-
- /**
- * Helper method to read a gRPC-provided InputStream into an ArrowBuf.
- *
- * @param stream The stream to read from. Should be an instance of {@link #BUFFER_INPUT_STREAM}.
- * @param buf The buffer to read into.
- * @param size The number of bytes to read.
- * @param fastPath Whether to enable the fast path (i.e. detect whether the stream is a {@link
- * #BUFFER_INPUT_STREAM}).
- * @throws IOException if there is an error reading form the stream
- */
- public static void readIntoBuffer(
- final InputStream stream, final ArrowBuf buf, final int size, final boolean fastPath)
- throws IOException {
- ReadableBuffer readableBuffer = fastPath ? getReadableBuffer(stream) : null;
- byte[] heapBytes = new byte[size];
- if (readableBuffer != null) {
- readableBuffer.readBytes(heapBytes, 0, size);
- } else {
- ByteStreams.readFully(stream, heapBytes);
- }
- buf.writeBytes(heapBytes);
- buf.writerIndex(size);
- }
-}
diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java
new file mode 100644
index 000000000..b4b0669a2
--- /dev/null
+++ b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java
@@ -0,0 +1,526 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.flight;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+
+import com.google.common.collect.Iterables;
+import com.google.common.io.ByteStreams;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.CodedOutputStream;
+import io.grpc.Detachable;
+import io.grpc.HasByteBuffer;
+import io.grpc.protobuf.ProtoUtils;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.lang.reflect.Constructor;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.arrow.flight.FlightProducer.CallContext;
+import org.apache.arrow.flight.FlightProducer.ServerStreamListener;
+import org.apache.arrow.flight.impl.Flight.FlightData;
+import org.apache.arrow.flight.impl.Flight.FlightDescriptor;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.ipc.message.IpcOption;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.commons.lang3.tuple.Pair;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Tests FlightData parsing including duplicate field handling, well-formed messages, and zero-copy
+ * behavior. Covers both InputStream (with copying) and ArrowBuf (zero-copy) parsing paths. Verifies
+ * that duplicate protobuf fields use last-occurrence-wins semantics without memory leaks.
+ */
+public class TestArrowMessageParse {
+
+ // Root allocator shared by the tests; created fresh in setUp() and closed in tearDown().
+ private BufferAllocator allocator;
+
+ /** Creates the allocator used by each test, with no practical memory limit. */
+ @BeforeEach
+ public void setUp() {
+ allocator = new RootAllocator(Long.MAX_VALUE);
+ }
+
+ /** Closes the allocator after each test. */
+ @AfterEach
+ public void tearDown() {
+ // NOTE(review): presumably close() fails on outstanding buffers, acting as a leak check - confirm.
+ allocator.close();
+ }
+
+ /** Confirms that the InputStream path keeps only the last of two app_metadata fields. */
+ @Test
+ public void testDuplicateAppMetadataInputStream() throws Exception {
+   byte[] discarded = new byte[] {1, 2, 3};
+   byte[] expected = new byte[] {4, 5, 6, 7, 8};
+   byte[] wire =
+       buildFlightDataDescriptors(
+           List.of(
+               Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, discarded),
+               Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, expected)));
+
+   try (ArrowMessage parsed =
+       ArrowMessage.createMarshaller(allocator).parse(new ByteArrayInputStream(wire))) {
+     ArrowBuf metadata = parsed.getApplicationMetadata();
+     assertNotNull(metadata);
+     // readableBytes() rather than capacity(): the allocator may round the capacity up.
+     assertEquals(expected.length, metadata.readableBytes());
+
+     byte[] copied = new byte[expected.length];
+     metadata.getBytes(0, copied);
+     assertArrayEquals(expected, copied);
+   }
+   // Everything allocated during parsing must be released once the message is closed.
+   assertEquals(0, allocator.getAllocatedMemory());
+ }
+
+ /** Confirms that the zero-copy ArrowBuf path keeps only the last of two app_metadata fields. */
+ @Test
+ public void testDuplicateAppMetadataArrowBuf() throws Exception {
+   byte[] discarded = new byte[] {1, 2, 3};
+   byte[] expected = new byte[] {4, 5, 6, 7, 8};
+
+   // Nothing may be allocated before parsing starts.
+   assertEquals(0, allocator.getAllocatedMemory());
+
+   byte[] wire =
+       buildFlightDataDescriptors(
+           List.of(
+               Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, discarded),
+               Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, expected)));
+   InputStream directStream = MockGrpcInputStream.ofDirectBuffer(wire);
+
+   try (ArrowMessage parsed = ArrowMessage.createMarshaller(allocator).parse(directStream)) {
+     ArrowBuf metadata = parsed.getApplicationMetadata();
+     assertNotNull(metadata);
+     assertEquals(expected.length, metadata.readableBytes());
+
+     byte[] copied = new byte[expected.length];
+     metadata.getBytes(0, copied);
+     assertArrayEquals(expected, copied);
+
+     // Zero-copy: the only accounted memory is the backing buffer holding the serialized message.
+     assertEquals(wire.length, allocator.getAllocatedMemory());
+   }
+   assertEquals(0, allocator.getAllocatedMemory());
+ }
+
+ /** Confirms that the InputStream path keeps only the last of two body fields. */
+ @Test
+ public void testDuplicateBodyInputStream() throws Exception {
+   byte[] discarded = new byte[] {10, 20, 30};
+   byte[] expected = new byte[] {40, 50, 60, 70};
+   byte[] wire =
+       buildFlightDataDescriptors(
+           List.of(
+               Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, discarded),
+               Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, expected)));
+
+   try (ArrowMessage parsed =
+       ArrowMessage.createMarshaller(allocator).parse(new ByteArrayInputStream(wire))) {
+     ArrowBuf onlyBody = Iterables.getOnlyElement(parsed.getBufs());
+     assertNotNull(onlyBody);
+     assertEquals(expected.length, onlyBody.readableBytes());
+
+     byte[] copied = new byte[expected.length];
+     onlyBody.getBytes(0, copied);
+     assertArrayEquals(expected, copied);
+   }
+   // Everything allocated during parsing must be released once the message is closed.
+   assertEquals(0, allocator.getAllocatedMemory());
+ }
+
+ /** Confirms that the zero-copy ArrowBuf path keeps only the last of two body fields. */
+ @Test
+ public void testDuplicateBodyArrowBuf() throws Exception {
+   byte[] discarded = new byte[] {10, 20, 30};
+   byte[] expected = new byte[] {40, 50, 60, 70};
+
+   // Nothing may be allocated before parsing starts.
+   assertEquals(0, allocator.getAllocatedMemory());
+
+   byte[] wire =
+       buildFlightDataDescriptors(
+           List.of(
+               Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, discarded),
+               Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, expected)));
+   InputStream directStream = MockGrpcInputStream.ofDirectBuffer(wire);
+
+   try (ArrowMessage parsed = ArrowMessage.createMarshaller(allocator).parse(directStream)) {
+     ArrowBuf onlyBody = Iterables.getOnlyElement(parsed.getBufs());
+     assertNotNull(onlyBody);
+     assertEquals(expected.length, onlyBody.readableBytes());
+
+     byte[] copied = new byte[expected.length];
+     onlyBody.getBytes(0, copied);
+     assertArrayEquals(expected, copied);
+
+     // Zero-copy: the only accounted memory is the backing buffer holding the serialized message.
+     assertEquals(wire.length, allocator.getAllocatedMemory());
+   }
+   assertEquals(0, allocator.getAllocatedMemory());
+ }
+
+ /** Parses a message carrying all four fields through the InputStream path and checks each. */
+ @Test
+ public void testFieldsInputStream() throws Exception {
+   byte[] expectedMetadata = new byte[] {100, 101, 102};
+   byte[] expectedBody = new byte[] {50, 51, 52, 53, 54};
+   FlightDescriptor wantedDescriptor = createTestDescriptor();
+
+   byte[] wire = buildFlightDataWithBothFields(expectedMetadata, expectedBody);
+
+   try (ArrowMessage parsed =
+       ArrowMessage.createMarshaller(allocator).parse(new ByteArrayInputStream(wire))) {
+     // Descriptor and header type (the helper embeds a Schema header).
+     assertEquals(wantedDescriptor, parsed.getDescriptor());
+     assertEquals(ArrowMessage.HeaderType.SCHEMA, parsed.getMessageType());
+
+     // Application metadata round-trips byte for byte.
+     ArrowBuf metadata = parsed.getApplicationMetadata();
+     assertNotNull(metadata);
+     assertEquals(expectedMetadata.length, metadata.readableBytes());
+     byte[] metadataCopy = new byte[expectedMetadata.length];
+     metadata.getBytes(0, metadataCopy);
+     assertArrayEquals(expectedMetadata, metadataCopy);
+
+     // Body round-trips byte for byte.
+     ArrowBuf onlyBody = Iterables.getOnlyElement(parsed.getBufs());
+     assertNotNull(onlyBody);
+     assertEquals(expectedBody.length, onlyBody.readableBytes());
+     byte[] bodyCopy = new byte[expectedBody.length];
+     onlyBody.getBytes(0, bodyCopy);
+     assertArrayEquals(expectedBody, bodyCopy);
+   }
+   assertEquals(0, allocator.getAllocatedMemory());
+ }
+
+ /** Parses a message carrying all four fields through the zero-copy path and checks each. */
+ @Test
+ public void testFieldsArrowBuf() throws Exception {
+   byte[] expectedMetadata = new byte[] {100, 101, 102};
+   byte[] expectedBody = new byte[] {50, 51, 52, 53, 54};
+   FlightDescriptor wantedDescriptor = createTestDescriptor();
+
+   assertEquals(0, allocator.getAllocatedMemory());
+
+   byte[] wire = buildFlightDataWithBothFields(expectedMetadata, expectedBody);
+   InputStream directStream = MockGrpcInputStream.ofDirectBuffer(wire);
+
+   try (ArrowMessage parsed = ArrowMessage.createMarshaller(allocator).parse(directStream)) {
+     // Descriptor and header type (the helper embeds a Schema header).
+     assertEquals(wantedDescriptor, parsed.getDescriptor());
+     assertEquals(ArrowMessage.HeaderType.SCHEMA, parsed.getMessageType());
+
+     // Application metadata round-trips byte for byte.
+     ArrowBuf metadata = parsed.getApplicationMetadata();
+     assertNotNull(metadata);
+     assertEquals(expectedMetadata.length, metadata.readableBytes());
+     byte[] metadataCopy = new byte[expectedMetadata.length];
+     metadata.getBytes(0, metadataCopy);
+     assertArrayEquals(expectedMetadata, metadataCopy);
+
+     // Body round-trips byte for byte.
+     ArrowBuf onlyBody = Iterables.getOnlyElement(parsed.getBufs());
+     assertNotNull(onlyBody);
+     assertEquals(expectedBody.length, onlyBody.readableBytes());
+     byte[] bodyCopy = new byte[expectedBody.length];
+     onlyBody.getBytes(0, bodyCopy);
+     assertArrayEquals(expectedBody, bodyCopy);
+
+     // Zero-copy: the only accounted memory is the backing buffer holding the serialized message.
+     assertEquals(wire.length, allocator.getAllocatedMemory());
+   }
+   assertEquals(0, allocator.getAllocatedMemory());
+ }
+
+ /**
+  * Verifies that heap buffers fall back to InputStream path without calling detach(), and that
+  * the fallback releases everything it allocated once the message is closed.
+  */
+ @Test
+ public void testHeapBufferFallbackDoesNotDetach() throws Exception {
+   byte[] appMetadataBytes = new byte[] {8, 9};
+   byte[] bodyBytes = new byte[] {10, 11, 12};
+
+   byte[] serialized = buildFlightDataWithBothFields(appMetadataBytes, bodyBytes);
+   MockGrpcInputStream stream = MockGrpcInputStream.ofHeapBuffer(serialized);
+
+   try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) {
+     assertNotNull(message.getDescriptor());
+     assertEquals(0, stream.getDetachCount());
+   }
+   // Same leak check as the other fallback test: nothing may remain allocated.
+   assertEquals(0, allocator.getAllocatedMemory());
+ }
+
+ /** Confirms the parser copies from the stream, without detaching, when no ByteBuffer is given. */
+ @Test
+ public void testNullByteBufferFallbackToInputStream() throws Exception {
+   byte[] expectedMetadata = new byte[] {20, 21, 22};
+   byte[] expectedBody = new byte[] {30, 31, 32, 33};
+   FlightDescriptor wantedDescriptor = createTestDescriptor();
+
+   byte[] wire = buildFlightDataWithBothFields(expectedMetadata, expectedBody);
+   MockGrpcInputStream mockStream = new MockGrpcInputStream(ByteBuffer.wrap(wire), false);
+
+   try (ArrowMessage parsed = ArrowMessage.createMarshaller(allocator).parse(mockStream)) {
+     assertEquals(wantedDescriptor, parsed.getDescriptor());
+
+     ArrowBuf metadata = parsed.getApplicationMetadata();
+     assertNotNull(metadata);
+     byte[] metadataCopy = new byte[expectedMetadata.length];
+     metadata.getBytes(0, metadataCopy);
+     assertArrayEquals(expectedMetadata, metadataCopy);
+
+     ArrowBuf onlyBody = Iterables.getOnlyElement(parsed.getBufs());
+     assertNotNull(onlyBody);
+     byte[] bodyCopy = new byte[expectedBody.length];
+     onlyBody.getBytes(0, bodyCopy);
+     assertArrayEquals(expectedBody, bodyCopy);
+
+     // The fallback path must never have detached the stream.
+     assertEquals(0, mockStream.getDetachCount());
+   }
+   assertEquals(0, allocator.getAllocatedMemory());
+ }
+
+ @Test
+ public void testRealFlightSmallBatchLifecycle() throws Exception {
+ try (BufferAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE);
+ FlightServer server =
+ FlightServer.builder(
+ rootAllocator,
+ Location.forGrpcInsecure("localhost", 0),
+ new NoOpFlightProducer() {
+ @Override
+ public void getStream(
+ CallContext context, Ticket ticket, ServerStreamListener listener) {
+ try (VectorSchemaRoot root =
+ VectorSchemaRoot.of(new BigIntVector("a", rootAllocator))) {
+ BigIntVector vector = (BigIntVector) root.getVector(0);
+ vector.allocateNew(8);
+ for (int i = 0; i < 8; i++) {
+ vector.set(i, i);
+ }
+ root.setRowCount(8);
+ listener.start(root);
+ listener.putNext();
+ listener.completed();
+ }
+ }
+ })
+ .build()
+ .start();
+ FlightClient client = FlightClient.builder(rootAllocator, server.getLocation()).build();
+ FlightStream stream = client.getStream(new Ticket(new byte[] {1}))) {
+ while (stream.next()) {
+ assertEquals(8, stream.getRoot().getRowCount());
+ }
+ }
+ }
+
+ /**
+  * Serializes a 4095-row record batch, parses it back from a gRPC BufferInputStream and loads
+  * it into a fresh root, checking allocator accounting before and after the message is closed.
+  */
+ @Test
+ public void testBufferInputStreamLargeRecordBatchLifecycle() throws Exception {
+   final int rowCount = 4095;
+   byte[] batchBytes;
+   try (BufferAllocator writerAllocator =
+           allocator.newChildAllocator("writer", 0, Long.MAX_VALUE);
+       VectorSchemaRoot root = VectorSchemaRoot.of(new BigIntVector("a", writerAllocator))) {
+     BigIntVector vector = (BigIntVector) root.getVector(0);
+     vector.allocateNew(rowCount);
+     for (int i = 0; i < rowCount; i++) {
+       vector.set(i, i);
+     }
+     root.setRowCount(rowCount);
+
+     try (ArrowRecordBatch batch = new VectorUnloader(root).getRecordBatch();
+         InputStream grpcStream =
+             ArrowMessage.createMarshaller(writerAllocator).stream(
+                 new ArrowMessage(batch, null, false, IpcOption.DEFAULT))) {
+       batchBytes = ByteStreams.toByteArray(grpcStream);
+     }
+   }
+
+   try (BufferAllocator parseAllocator = allocator.newChildAllocator("parse", 0, Long.MAX_VALUE)) {
+     try (VectorSchemaRoot loadedRoot =
+         VectorSchemaRoot.of(new BigIntVector("a", parseAllocator))) {
+       // try-with-resources closes the message even when one of the assertions below fails.
+       try (ArrowMessage message =
+           ArrowMessage.createMarshaller(parseAllocator)
+               .parse(createGrpcBufferInputStream(batchBytes))) {
+         assertEquals(batchBytes.length, parseAllocator.getAllocatedMemory());
+         assertEquals(ArrowMessage.HeaderType.RECORD_BATCH, message.getMessageType());
+
+         try (ArrowRecordBatch batch = message.asRecordBatch()) {
+           new VectorLoader(loadedRoot).load(batch);
+         }
+       }
+
+       // The loaded vectors are still open, so the accounted memory is unchanged after close.
+       assertEquals(batchBytes.length, parseAllocator.getAllocatedMemory());
+       assertEquals(rowCount, loadedRoot.getRowCount());
+       assertEquals(4094L, ((BigIntVector) loadedRoot.getVector(0)).get(rowCount - 1));
+     }
+   }
+
+   assertEquals(0, allocator.getAllocatedMemory());
+ }
+
+ // Helper methods to build complete FlightData messages
+
+ /**
+  * Builds a gRPC-internal {@code ReadableBuffers$BufferInputStream} over a direct Netty buffer
+  * holding {@code data}. The classes involved are looked up reflectively by name.
+  *
+  * @param data bytes the returned stream should yield
+  * @return the reflectively constructed stream
+  * @throws Exception if a class or constructor cannot be found or invoked
+  */
+ private InputStream createGrpcBufferInputStream(byte[] data) throws Exception {
+   ByteBuf byteBuf = Unpooled.directBuffer(data.length);
+   byteBuf.writeBytes(data);
+
+   Class<?> readableBufferClass = Class.forName("io.grpc.internal.ReadableBuffer");
+   Class<?> nettyReadableBufferClass = Class.forName("io.grpc.netty.NettyReadableBuffer");
+   Constructor<?> readableBufferCtor =
+       nettyReadableBufferClass.getDeclaredConstructor(ByteBuf.class);
+   readableBufferCtor.setAccessible(true);
+   Object readableBuffer = readableBufferCtor.newInstance(byteBuf);
+
+   Class<?> bufferInputStreamClass =
+       Class.forName("io.grpc.internal.ReadableBuffers$BufferInputStream");
+   Constructor<?> streamCtor = bufferInputStreamClass.getDeclaredConstructor(readableBufferClass);
+   streamCtor.setAccessible(true);
+   return (InputStream) streamCtor.newInstance(readableBuffer);
+ }
+
+ /** Builds the PATH descriptor with segments "test" and "path" used by the field tests. */
+ private FlightDescriptor createTestDescriptor() {
+   FlightDescriptor.Builder builder = FlightDescriptor.newBuilder();
+   builder.setType(FlightDescriptor.DescriptorType.PATH);
+   builder.addPath("test");
+   builder.addPath("path");
+   return builder.build();
+ }
+
+ /** Serializes a two-column schema (nullable int32 "id", nullable utf8 "name") to header bytes. */
+ private byte[] createSchemaHeader() {
+   Field idField = Field.nullable("id", new ArrowType.Int(32, true));
+   Field nameField = Field.nullable("name", new ArrowType.Utf8());
+   Schema twoColumns = new Schema(Arrays.asList(idField, nameField));
+
+   ByteBuffer encoded = MessageSerializer.serializeMetadata(twoColumns, IpcOption.DEFAULT);
+   byte[] header = new byte[encoded.remaining()];
+   encoded.get(header);
+   return header;
+ }
+
+ /**
+  * Serializes a FlightData message carrying the test descriptor, a schema header and the given
+  * app metadata and body.
+  *
+  * @param appMetadata bytes for the app_metadata field
+  * @param body bytes for the data_body field
+  * @return the message bytes produced by the protobuf marshaller
+  * @throws IOException if reading the marshalled stream fails
+  */
+ private byte[] buildFlightDataWithBothFields(byte[] appMetadata, byte[] body) throws IOException {
+   FlightData.Builder builder = FlightData.newBuilder();
+   builder.setFlightDescriptor(createTestDescriptor());
+   builder.setDataHeader(ByteString.copyFrom(createSchemaHeader()));
+   builder.setAppMetadata(ByteString.copyFrom(appMetadata));
+   builder.setDataBody(ByteString.copyFrom(body));
+   FlightData assembled = builder.build();
+
+   try (InputStream encoded =
+       ProtoUtils.marshaller(FlightData.getDefaultInstance()).stream(assembled)) {
+     return ByteStreams.toByteArray(encoded);
+   }
+ }
+
+ // Helper methods to build FlightData messages with duplicate fields
+
+ private byte[] buildFlightDataDescriptors(List