diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java index 623c2b81be..107cfa0c2f 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -19,6 +19,8 @@ import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.replaceSemiColons; import io.netty.util.concurrent.DefaultThreadFactory; +import java.sql.PreparedStatement; +import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; import java.util.HashMap; @@ -257,4 +259,35 @@ BufferAllocator getBufferAllocator() { public ArrowFlightMetaImpl getMeta() { return (ArrowFlightMetaImpl) this.meta; } + + @Override + public PreparedStatement prepareStatement(final String sql) throws SQLException { + checkOpen(); + return prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + } + + @Override + public PreparedStatement prepareStatement( + final String sql, final int resultSetType, final int resultSetConcurrency) + throws SQLException { + checkOpen(); + return prepareStatement(sql, resultSetType, resultSetConcurrency, getHoldability()); + } + + @Override + public PreparedStatement prepareStatement( + final String sql, + final int resultSetType, + final int resultSetConcurrency, + final int resultSetHoldability) + throws SQLException { + checkOpen(); + return ArrowFlightPreparedStatement.builder(this) + .withQuery(sql) + .withGeneratedHandle() + .withResultSetType(resultSetType) + .withResultSetConcurrency(resultSetConcurrency) + .withResultSetHoldability(resultSetHoldability) + .build(); + } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java deleted file mode 100644 index 37ee93722a..0000000000 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.driver.jdbc; - -import java.sql.SQLException; -import java.sql.Statement; -import org.apache.arrow.flight.FlightInfo; - -/** A {@link Statement} that deals with {@link FlightInfo}. */ -public interface ArrowFlightInfoStatement extends Statement { - - @Override - ArrowFlightConnection getConnection() throws SQLException; - - /** - * Executes the query this {@link Statement} is holding. - * - * @return the {@link FlightInfo} for the results. - * @throws SQLException on error. - */ - FlightInfo executeFlightInfoQuery() throws SQLException; -} diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java index e1ccfc820f..202b491f5c 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java @@ -20,7 +20,6 @@ import java.sql.SQLException; import java.util.Properties; import java.util.TimeZone; -import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; import org.apache.arrow.memory.RootAllocator; import org.apache.calcite.avatica.AvaticaConnection; import org.apache.calcite.avatica.AvaticaFactory; @@ -79,20 +78,20 @@ public ArrowFlightPreparedStatement newPreparedStatement( final Meta.Signature signature, final int resultType, final int resultSetConcurrency, - final int resultSetHoldability) - throws SQLException { + final int resultSetHoldability) { final ArrowFlightConnection flightConnection = (ArrowFlightConnection) connection; - ArrowFlightSqlClientHandler.PreparedStatement preparedStatement = - flightConnection.getMeta().getPreparedStatement(statementHandle); + final AvaticaStatement existingStatement = + flightConnection.statementMap.get(statementHandle.id); + if (existingStatement instanceof ArrowFlightPreparedStatement) { + return (ArrowFlightPreparedStatement) existingStatement; + } + if (existingStatement != null) { + throw new IllegalStateException( + "Unexpected statement type found for prepared statement handle: " + statementHandle); + } - return ArrowFlightPreparedStatement.newPreparedStatement( - flightConnection, - preparedStatement, - statementHandle, - signature, - resultType, - resultSetConcurrency, - resultSetHoldability); + throw new IllegalStateException( + "PreparedStatement was not pre-created for handle: " + statementHandle); } @Override diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java index 2885f7895b..d383d239d1 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java @@ -67,7 +67,7 @@ public final class ArrowFlightJdbcFlightStreamResultSet throws SQLException { super(statement, state, signature, resultSetMetaData, timeZone, firstFrame); this.connection = (ArrowFlightConnection) statement.connection; - this.flightInfo = ((ArrowFlightInfoStatement) statement).executeFlightInfoQuery(); + this.flightInfo = ((ArrowFlightMetaStatement) statement).executeFlightInfoQuery(); } /** Private constructor for fromFlightInfo. */ @@ -106,7 +106,7 @@ static ArrowFlightJdbcFlightStreamResultSet fromFlightInfo( final TimeZone timeZone = TimeZone.getDefault(); final QueryState state = new QueryState(); - final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null, null); + final Meta.Signature signature = ArrowFlightMetaImpl.buildDefaultSignature(); final AvaticaResultSetMetaData resultSetMetaData = new AvaticaResultSetMetaData(null, null, signature); diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java index 49334951de..e084a6e9c3 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java @@ -73,7 +73,7 @@ public static ArrowFlightJdbcVectorSchemaRootResultSet fromVectorSchemaRoot( final TimeZone timeZone = TimeZone.getDefault(); final QueryState state = new QueryState(); - final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null, null); + final Meta.Signature signature = ArrowFlightMetaImpl.buildDefaultSignature(); final AvaticaResultSetMetaData resultSetMetaData = new AvaticaResultSetMetaData(null, null, signature); diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java index 64529b50c8..4da182ca63 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java @@ -17,20 +17,17 @@ package org.apache.arrow.driver.jdbc; import java.sql.Connection; +import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLTimeoutException; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement; -import org.apache.arrow.driver.jdbc.utils.AvaticaParameterBinder; import org.apache.arrow.driver.jdbc.utils.ConvertUtils; -import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.calcite.avatica.AvaticaConnection; import org.apache.calcite.avatica.AvaticaParameter; +import org.apache.calcite.avatica.AvaticaStatement; import org.apache.calcite.avatica.ColumnMetaData; import org.apache.calcite.avatica.MetaImpl; import org.apache.calcite.avatica.NoSuchStatementException; @@ -39,8 +36,6 @@ /** Metadata handler for Arrow Flight. */ public class ArrowFlightMetaImpl extends MetaImpl { - private final Map statementHandlePreparedStatementMap; - /** * Constructs a {@link MetaImpl} object specific for Arrow Flight. * @@ -48,43 +43,12 @@ public class ArrowFlightMetaImpl extends MetaImpl { */ public ArrowFlightMetaImpl(final AvaticaConnection connection) { super(connection); - this.statementHandlePreparedStatementMap = new ConcurrentHashMap<>(); setDefaultConnectionProperties(); } - /** Construct a signature. */ - static Signature newSignature(final String sql, Schema resultSetSchema, Schema parameterSchema) { - List columnMetaData = - resultSetSchema == null - ? new ArrayList<>() - : ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields()); - List parameters = - parameterSchema == null - ? new ArrayList<>() - : ConvertUtils.convertArrowFieldsToAvaticaParameters(parameterSchema.getFields()); - StatementType statementType = - resultSetSchema == null || resultSetSchema.getFields().isEmpty() - ? StatementType.IS_DML - : StatementType.SELECT; - return new Signature( - columnMetaData, - sql, - parameters, - Collections.emptyMap(), - null, // unnecessary, as SQL requests use ArrowFlightJdbcCursor - statementType); - } - @Override public void closeStatement(final StatementHandle statementHandle) { - PreparedStatement preparedStatement = - statementHandlePreparedStatementMap.remove(new StatementHandleKey(statementHandle)); - // Testing if the prepared statement was created because the statement can be - // not created until - // this moment - if (preparedStatement != null) { - preparedStatement.close(); - } + getMetaStatement(statementHandle).closeStatement(); } @Override @@ -97,36 +61,7 @@ public ExecuteResult execute( final StatementHandle statementHandle, final List typedValues, final long maxRowCount) { - Preconditions.checkArgument( - connection.id.equals(statementHandle.connectionId), "Connection IDs are not consistent"); - PreparedStatement preparedStatement = getPreparedStatement(statementHandle); - - if (preparedStatement == null) { - throw new IllegalStateException("Prepared statement not found: " + statementHandle); - } - - new AvaticaParameterBinder( - preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator()) - .bind(typedValues); - - if (statementHandle.signature == null - || statementHandle.signature.statementType == StatementType.IS_DML) { - // Update query - long updatedCount = preparedStatement.executeUpdate(); - return new ExecuteResult( - Collections.singletonList( - MetaResultSet.count(statementHandle.connectionId, statementHandle.id, updatedCount))); - } else { - // TODO Why is maxRowCount ignored? - return new ExecuteResult( - Collections.singletonList( - MetaResultSet.create( - statementHandle.connectionId, - statementHandle.id, - true, - statementHandle.signature, - null))); - } + return getMetaStatement(statementHandle).execute(statementHandle, typedValues, maxRowCount); } @Override @@ -141,24 +76,7 @@ public ExecuteResult execute( public ExecuteBatchResult executeBatch( final StatementHandle statementHandle, final List> parameterValuesList) throws IllegalStateException { - Preconditions.checkArgument( - connection.id.equals(statementHandle.connectionId), "Connection IDs are not consistent"); - PreparedStatement preparedStatement = getPreparedStatement(statementHandle); - - if (preparedStatement == null) { - throw new IllegalStateException("Prepared statement not found: " + statementHandle); - } - - final AvaticaParameterBinder binder = - new AvaticaParameterBinder( - preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator()); - for (int i = 0; i < parameterValuesList.size(); i++) { - binder.bind(parameterValuesList.get(i), i); - } - - // Update query - long[] updatedCounts = {preparedStatement.executeUpdate()}; - return new ExecuteBatchResult(updatedCounts); + return getMetaStatement(statementHandle).executeBatch(statementHandle, parameterValuesList); } @Override @@ -173,22 +91,19 @@ public Frame fetch( String.format("%s does not use frames.", this), AvaticaConnection.HELPER.unsupported()); } - private PreparedStatement prepareForHandle(final String query, StatementHandle handle) { - final PreparedStatement preparedStatement = - ((ArrowFlightConnection) connection).getClientHandler().prepare(query); - handle.signature = - newSignature( - query, preparedStatement.getDataSetSchema(), preparedStatement.getParameterSchema()); - statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement); - return preparedStatement; - } - @Override public StatementHandle prepare( final ConnectionHandle connectionHandle, final String query, final long maxRowCount) { - final StatementHandle handle = super.createStatement(connectionHandle); - prepareForHandle(query, handle); - return handle; + try { + // This is the Avatica entry point used by Connection.prepareStatement(String). + ArrowFlightPreparedStatement stmt = + (ArrowFlightPreparedStatement) + connection.prepareStatement( + query, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + return stmt.handle; + } catch (SQLException e) { + throw new RuntimeException(e); + } } @Override @@ -198,6 +113,7 @@ public ExecuteResult prepareAndExecute( final long maxRowCount, final PrepareCallback prepareCallback) throws NoSuchStatementException { + // This is the Avatica entry point used by Statement.execute(String). return prepareAndExecute( statementHandle, query, maxRowCount, -1 /* Not used */, prepareCallback); } @@ -211,19 +127,9 @@ public ExecuteResult prepareAndExecute( final PrepareCallback callback) throws NoSuchStatementException { try { - PreparedStatement preparedStatement = prepareForHandle(query, handle); - final StatementType statementType = preparedStatement.getType(); - - final long updateCount = - statementType.equals(StatementType.UPDATE) ? preparedStatement.executeUpdate() : -1; - synchronized (callback.getMonitor()) { - callback.clear(); - callback.assign(handle.signature, null, updateCount); - } - callback.execute(); - final MetaResultSet metaResultSet = - MetaResultSet.create(handle.connectionId, handle.id, false, handle.signature, null); - return new ExecuteResult(Collections.singletonList(metaResultSet)); + // This is the Avatica entry point used by Statement.execute(String). + return getMetaStatement(handle) + .prepareAndExecute(query, maxRowCount, maxRowsInFirstFrame, callback); } catch (SQLTimeoutException e) { // So far AvaticaStatement(executeInternal) only handles NoSuchStatement and // Runtime @@ -280,45 +186,51 @@ void setDefaultConnectionProperties() { .setTransactionIsolation(Connection.TRANSACTION_NONE); } - PreparedStatement getPreparedStatement(StatementHandle statementHandle) { - return statementHandlePreparedStatementMap.get(new StatementHandleKey(statementHandle)); - } - - // Helper used to look up prepared statement instances later. Avatica doesn't - // give us the - // signature in - // an UPDATE code path so we can't directly use StatementHandle as a map key. - private static final class StatementHandleKey { - public final String connectionId; - public final int id; - - StatementHandleKey(StatementHandle statementHandle) { - this.connectionId = statementHandle.connectionId; - this.id = statementHandle.id; + private ArrowFlightMetaStatement getMetaStatement(StatementHandle statementHandle) { + AvaticaStatement statement = connection.statementMap.get(statementHandle.id); + if (statement instanceof ArrowFlightMetaStatement) { + return (ArrowFlightMetaStatement) statement; } + throw new IllegalStateException("Statement not found: " + statementHandle); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } + public static Signature buildDefaultSignature() { + return buildSignature(null, StatementType.SELECT); + } - StatementHandleKey that = (StatementHandleKey) o; + public static Signature buildSignature(final String sql, final StatementType type) { + return buildSignature(sql, null, null, type); + } - if (id != that.id) { - return false; - } - return connectionId.equals(that.connectionId); - } + /** Builds an Avatica signature from Arrow result and parameter schemas. */ + public static Signature buildSignature( + final String sql, final Schema resultSetSchema, final Schema parameterSchema) { + StatementType statementType = + resultSetSchema == null || resultSetSchema.getFields().isEmpty() + ? StatementType.IS_DML + : StatementType.SELECT; + return buildSignature(sql, resultSetSchema, parameterSchema, statementType); + } - @Override - public int hashCode() { - int result = connectionId.hashCode(); - result = 31 * result + id; - return result; - } + private static Signature buildSignature( + final String sql, + final Schema resultSetSchema, + final Schema parameterSchema, + final StatementType statementType) { + List columnMetaData = + resultSetSchema == null + ? new ArrayList<>() + : ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields()); + List parameters = + parameterSchema == null + ? new ArrayList<>() + : ConvertUtils.convertArrowFieldsToAvaticaParameters(parameterSchema.getFields()); + return new Signature( + columnMetaData, + sql, + parameters, + Collections.emptyMap(), + null, // unnecessary, as SQL requests use ArrowFlightJdbcCursor + statementType); } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaStatement.java new file mode 100644 index 0000000000..415af19e8f --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaStatement.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc; + +import java.sql.SQLException; +import java.sql.Statement; +import java.util.List; +import org.apache.arrow.flight.FlightInfo; +import org.apache.calcite.avatica.Meta.ExecuteBatchResult; +import org.apache.calcite.avatica.Meta.ExecuteResult; +import org.apache.calcite.avatica.Meta.PrepareCallback; +import org.apache.calcite.avatica.Meta.StatementHandle; +import org.apache.calcite.avatica.remote.TypedValue; + +/** Statement capabilities used by {@link ArrowFlightMetaImpl}. */ +interface ArrowFlightMetaStatement extends Statement { + + @Override + ArrowFlightConnection getConnection() throws SQLException; + + FlightInfo executeFlightInfoQuery() throws SQLException; + + /** + * Avatica routes {@link Statement#execute(String)} through Meta.prepareAndExecute(...), so plain + * statements still need this hook even when they support direct executeQuery/executeUpdate paths. + */ + ExecuteResult prepareAndExecute( + String query, long maxRowCount, int maxRowsInFirstFrame, PrepareCallback callback) + throws SQLException; + + default ExecuteResult execute( + final StatementHandle statementHandle, + final List typedValues, + final long maxRowCount) { + throw new IllegalStateException( + "Statement operation is not supported for handle: " + statementHandle); + } + + default ExecuteBatchResult executeBatch( + final StatementHandle statementHandle, final List> parameterValuesList) { + throw new IllegalStateException( + "Statement operation is not supported for handle: " + statementHandle); + } + + default void closeStatement() {} +} diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java index d7af6902f4..bd7ebbe0e4 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java @@ -16,51 +16,48 @@ */ package org.apache.arrow.driver.jdbc; -import java.sql.PreparedStatement; import java.sql.SQLException; +import java.util.Collections; +import java.util.List; import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; +import org.apache.arrow.driver.jdbc.utils.AvaticaParameterBinder; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.types.pojo.Schema; import org.apache.calcite.avatica.AvaticaPreparedStatement; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.Meta.ExecuteBatchResult; +import org.apache.calcite.avatica.Meta.ExecuteResult; +import org.apache.calcite.avatica.Meta.MetaResultSet; +import org.apache.calcite.avatica.Meta.PrepareCallback; import org.apache.calcite.avatica.Meta.Signature; import org.apache.calcite.avatica.Meta.StatementHandle; +import org.apache.calcite.avatica.Meta.StatementType; +import org.apache.calcite.avatica.remote.TypedValue; -/** Arrow Flight JBCS's implementation {@link PreparedStatement}. */ +/** Arrow Flight JDBC's implementation {@link java.sql.PreparedStatement}. */ public class ArrowFlightPreparedStatement extends AvaticaPreparedStatement - implements ArrowFlightInfoStatement { + implements ArrowFlightMetaStatement { - private final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement; + private ArrowFlightSqlClientHandler.PreparedStatement preparedStatement; private ArrowFlightPreparedStatement( final ArrowFlightConnection connection, - final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement, final StatementHandle handle, final Signature signature, + final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) throws SQLException { super(connection, handle, signature, resultSetType, resultSetConcurrency, resultSetHoldability); this.preparedStatement = Preconditions.checkNotNull(preparedStatement); + this.handle.signature = signature; + setSignature(signature); } - static ArrowFlightPreparedStatement newPreparedStatement( - final ArrowFlightConnection connection, - final ArrowFlightSqlClientHandler.PreparedStatement preparedStmt, - final StatementHandle statementHandle, - final Signature signature, - final int resultSetType, - final int resultSetConcurrency, - final int resultSetHoldability) - throws SQLException { - return new ArrowFlightPreparedStatement( - connection, - preparedStmt, - statementHandle, - signature, - resultSetType, - resultSetConcurrency, - resultSetHoldability); + static Builder builder(final ArrowFlightConnection connection) { + return new Builder(connection); } @Override @@ -68,14 +65,202 @@ public ArrowFlightConnection getConnection() throws SQLException { return (ArrowFlightConnection) super.getConnection(); } + ExecuteResult prepareAndExecute(final PrepareCallback callback) throws SQLException { + ensurePrepared(); + final StatementType statementType = preparedStatement.getType(); + final long updateCount = + statementType.equals(StatementType.UPDATE) ? preparedStatement.executeUpdate() : -1; + synchronized (callback.getMonitor()) { + callback.clear(); + callback.assign(handle.signature, null, updateCount); + } + callback.execute(); + final MetaResultSet metaResultSet = + MetaResultSet.create(handle.connectionId, handle.id, false, handle.signature, null); + return new ExecuteResult(Collections.singletonList(metaResultSet)); + } + + @Override + public ExecuteResult prepareAndExecute( + final String query, + final long maxRowCount, + final int maxRowsInFirstFrame, + final PrepareCallback callback) + throws SQLException { + + return ArrowFlightPreparedStatement.builder(getConnection()) + .withQuery(query) + .withExistingStatement(this) + .build() + .prepareAndExecute(callback); + } + + Schema getDataSetSchema() { + ensurePrepared(); + return preparedStatement.getDataSetSchema(); + } + @Override public synchronized void close() throws SQLException { - this.preparedStatement.close(); super.close(); } + void closePreparedResources() { + if (preparedStatement != null) { + preparedStatement.close(); + preparedStatement = null; + } + } + + ExecuteResult executeWithTypedValues( + final StatementHandle statementHandle, + final List typedValues, + final long maxRowCount) { + ensurePrepared(); + Preconditions.checkArgument( + connection.id.equals(statementHandle.connectionId), "Connection IDs are not consistent"); + new AvaticaParameterBinder( + preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator()) + .bind(typedValues); + + if (statementHandle.signature == null + || statementHandle.signature.statementType == StatementType.IS_DML) { + long updatedCount = preparedStatement.executeUpdate(); + return new ExecuteResult( + Collections.singletonList( + MetaResultSet.count(statementHandle.connectionId, statementHandle.id, updatedCount))); + } + + // TODO Why is maxRowCount ignored? + return new ExecuteResult( + Collections.singletonList( + MetaResultSet.create( + statementHandle.connectionId, + statementHandle.id, + true, + statementHandle.signature, + null))); + } + + ExecuteBatchResult executeBatchWithTypedValues( + final StatementHandle statementHandle, final List> parameterValuesList) { + ensurePrepared(); + Preconditions.checkArgument( + connection.id.equals(statementHandle.connectionId), "Connection IDs are not consistent"); + final AvaticaParameterBinder binder = + new AvaticaParameterBinder( + preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator()); + for (int i = 0; i < parameterValuesList.size(); i++) { + binder.bind(parameterValuesList.get(i), i); + } + + long[] updatedCounts = {preparedStatement.executeUpdate()}; + return new ExecuteBatchResult(updatedCounts); + } + + @Override + public ExecuteResult execute( + final StatementHandle statementHandle, + final List typedValues, + final long maxRowCount) { + return executeWithTypedValues(statementHandle, typedValues, maxRowCount); + } + + @Override + public ExecuteBatchResult executeBatch( + final StatementHandle statementHandle, final List> parameterValuesList) { + return executeBatchWithTypedValues(statementHandle, parameterValuesList); + } + + @Override + public void closeStatement() { + closePreparedResources(); + } + @Override public FlightInfo executeFlightInfoQuery() throws SQLException { + ensurePrepared(); return preparedStatement.executeQuery(); } + + private void ensurePrepared() { + if (preparedStatement == null) { + throw new IllegalStateException("PreparedStatement is already closed."); + } + } + + static final class Builder { + private final ArrowFlightConnection connection; + private StatementHandle handle; + private String query; + private Integer resultSetType; + private Integer resultSetConcurrency; + private Integer resultSetHoldability; + private boolean generateHandle; + + private Builder(final ArrowFlightConnection connection) { + this.connection = Preconditions.checkNotNull(connection); + } + + Builder withQuery(final String query) { + this.query = Preconditions.checkNotNull(query); + return this; + } + + Builder withGeneratedHandle() { + this.generateHandle = true; + this.handle = null; + return this; + } + + Builder withExistingStatement(final AvaticaStatement statement) throws SQLException { + Preconditions.checkNotNull(statement); + this.generateHandle = false; + this.handle = Preconditions.checkNotNull(statement.handle); + this.resultSetType = statement.getResultSetType(); + this.resultSetConcurrency = statement.getResultSetConcurrency(); + this.resultSetHoldability = statement.getResultSetHoldability(); + return this; + } + + Builder withResultSetType(final int resultSetType) { + this.resultSetType = resultSetType; + return this; + } + + Builder withResultSetConcurrency(final int resultSetConcurrency) { + this.resultSetConcurrency = resultSetConcurrency; + return this; + } + + Builder withResultSetHoldability(final int resultSetHoldability) { + this.resultSetHoldability = resultSetHoldability; + return this; + } + + ArrowFlightPreparedStatement build() throws SQLException { + Preconditions.checkNotNull(query); + Preconditions.checkNotNull(resultSetType); + Preconditions.checkNotNull(resultSetConcurrency); + Preconditions.checkNotNull(resultSetHoldability); + if (!generateHandle && handle == null) { + throw new IllegalStateException("PreparedStatement builder requires a handle."); + } + + final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement = + connection.getClientHandler().prepare(query); + final Signature signature = + ArrowFlightMetaImpl.buildSignature( + query, preparedStatement.getDataSetSchema(), preparedStatement.getParameterSchema()); + + return new ArrowFlightPreparedStatement( + connection, + generateHandle ? null : handle, + signature, + preparedStatement, + resultSetType, + resultSetConcurrency, + resultSetHoldability); + } + } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java index 577aee3b4a..0df8f20d2a 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java @@ -16,17 +16,24 @@ */ package org.apache.arrow.driver.jdbc; +import java.sql.ResultSet; import java.sql.SQLException; -import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement; import org.apache.arrow.driver.jdbc.utils.ConvertUtils; import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStatusCode; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.AvaticaResultSet; import org.apache.calcite.avatica.AvaticaStatement; import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.Meta.ExecuteResult; +import org.apache.calcite.avatica.Meta.PrepareCallback; import org.apache.calcite.avatica.Meta.StatementHandle; +import org.apache.calcite.avatica.Meta.StatementType; /** A SQL statement for querying data from an Arrow Flight server. */ -public class ArrowFlightStatement extends AvaticaStatement implements ArrowFlightInfoStatement { +public class ArrowFlightStatement extends AvaticaStatement implements ArrowFlightMetaStatement { ArrowFlightStatement( final ArrowFlightConnection connection, @@ -42,20 +49,140 @@ public ArrowFlightConnection getConnection() throws SQLException { return (ArrowFlightConnection) super.getConnection(); } + @Override + public ExecuteResult prepareAndExecute( + final String query, + final long maxRowCount, + final int maxRowsInFirstFrame, + final PrepareCallback callback) + throws SQLException { + // Keep Avatica Statement.execute(String) behavior: Avatica calls Meta.prepareAndExecute, + // which resolves to this statement hook. + this.closeStatement(); + + return ArrowFlightPreparedStatement.builder(getConnection()) + .withQuery(query) + .withExistingStatement(this) + .build() + .prepareAndExecute(callback); + } + + @Override + public ResultSet executeQuery(final String sql) throws SQLException { + checkOpen(); + updateCount = -1; + switchToDirectStatementMode(); + try { + final Meta.Signature signature = + ArrowFlightMetaImpl.buildSignature(sql, StatementType.SELECT); + setSignature(signature); + return executeQueryInternal(signature, false); + } catch (Exception exception) { + throw wrapStatementExecutionException(sql, exception); + } + } + + @Override + public long executeLargeUpdate(final String sql) throws SQLException { + checkOpen(); + clearOpenResultSet(); + updateCount = -1; + switchToDirectStatementMode(); + + try { + final long updatedCount = getConnection().getClientHandler().executeUpdate(sql); + setSignature(ArrowFlightMetaImpl.buildSignature(sql, StatementType.IS_DML)); + updateCount = updatedCount; + return updatedCount; + } catch (Exception exception) { + throw wrapStatementExecutionException(sql, exception); + } + } + @Override public FlightInfo executeFlightInfoQuery() throws SQLException { - final PreparedStatement preparedStatement = - getConnection().getMeta().getPreparedStatement(handle); + final ArrowFlightConnection connection = getConnection(); final Meta.Signature signature = getSignature(); if (signature == null) { return null; } - final Schema resultSetSchema = preparedStatement.getDataSetSchema(); - signature.columns.addAll( - ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields())); - setSignature(signature); + // A Statement handle can point to either this direct statement instance or a prepared + // statement instance created by Avatica Statement.execute(String) through + // Meta.prepareAndExecute. + final AvaticaStatement currentStatement = connection.statementMap.get(handle.id); + if (currentStatement instanceof ArrowFlightPreparedStatement) { + // Prepared path: reuse the current statement implementation associated with the handle. + final FlightInfo flightInfo = + ((ArrowFlightPreparedStatement) currentStatement).executeFlightInfoQuery(); + updateSignatureColumnsFromFlightInfo(signature, flightInfo); + return flightInfo; + } - return preparedStatement.executeQuery(); + // Direct Statement.executeQuery(String) / executeUpdate(String) path. + final FlightInfo flightInfo = connection.getClientHandler().getInfo(signature.sql); + updateSignatureColumnsFromFlightInfo(signature, flightInfo); + return flightInfo; + } + + private void updateSignatureColumnsFromFlightInfo( + final Meta.Signature signature, final FlightInfo flightInfo) { + final Schema resultSetSchema = flightInfo.getSchemaOptional().orElse(null); + if (resultSetSchema != null) { + signature.columns.addAll( + ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields())); + setSignature(signature); + } + } + + private SQLException wrapStatementExecutionException(final String sql, final Exception exception) + throws SQLException { + if (!(exception instanceof SQLException)) { + return AvaticaConnection.HELPER.createException( + "Error while executing SQL \"" + sql + "\": " + exception.getMessage(), exception); + } + final SQLException sqlException = (SQLException) exception; + final String prefix = "Error while executing SQL \"" + sql + "\""; + final String message = sqlException.getMessage(); + if (message != null && message.startsWith(prefix)) { + return sqlException; + } + final Throwable cause = sqlException.getCause(); + if (cause instanceof FlightRuntimeException) { + final FlightStatusCode statusCode = ((FlightRuntimeException) cause).status().code(); + if (statusCode == FlightStatusCode.UNAVAILABLE) { + return sqlException; + } + } + return AvaticaConnection.HELPER.createException(prefix + ": " + message, sqlException); + } + + private void clearOpenResultSet() throws SQLException { + synchronized (this) { + if (openResultSet != null) { + final AvaticaResultSet resultSet = openResultSet; + openResultSet = null; + try { + resultSet.close(); + } catch (Exception exception) { + throw AvaticaConnection.HELPER.createException( + "Error while closing previous result set", exception); + } + } + } + } + + private void switchToDirectStatementMode() throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final AvaticaStatement existingStatement = connection.statementMap.get(handle.id); + if (existingStatement == this) { + return; + } + if (existingStatement instanceof ArrowFlightPreparedStatement) { + // Release resources from previously attached statement implementation before switching back + // to direct statement mode for executeQuery/executeUpdate. + ((ArrowFlightPreparedStatement) existingStatement).closeStatement(); + } + connection.statementMap.put(handle.id, this); } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index f0ea284239..08b2c5f93e 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -267,6 +267,16 @@ public FlightInfo getInfo(final String query) { return sqlClient.execute(query, getOptions()); } + /** + * Executes an update query directly, without creating a prepared statement first. + * + * @param query The update query. + * @return the number of rows affected. + */ + public long executeUpdate(final String query) { + return sqlClient.executeUpdate(query, getOptions()); + } + @Override public void close() throws SQLException { if (catalog.isPresent()) { diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/example/ArrowFlightJdbcSampleApp.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/example/ArrowFlightJdbcSampleApp.java new file mode 100644 index 0000000000..10ce0bd285 --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/example/ArrowFlightJdbcSampleApp.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc.example; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Properties; + +/** + * Minimal sample app for using the Arrow Flight SQL JDBC driver. + * + *

Defaults are configured for a local Dremio instance: + * + *

    + *
  • host: {@code localhost} + *
  • port: {@code 32010} + *
  • user: {@code dremio} + *
  • password: {@code dremio123} + *
+ * + *

Arguments are optional and positional: + * + *

+ *   [host] [port] [user] [password] [selectSql] [updateSql]
+ * 
+ * + *

If {@code updateSql} is omitted, only {@code Statement.executeQuery(...)} is executed. + */ +public final class ArrowFlightJdbcSampleApp { + private static final String DEFAULT_HOST = "localhost"; + private static final int DEFAULT_PORT = 32010; + private static final String DEFAULT_USER = "dremio"; + private static final String DEFAULT_PASSWORD = "dremio123"; + private static final String DEFAULT_SELECT_SQL = "SELECT 1 AS sample_value"; + + private ArrowFlightJdbcSampleApp() {} + + public static void main(final String[] args) throws Exception { + final String host = getArg(args, 0, DEFAULT_HOST); + final int port = Integer.parseInt(getArg(args, 1, Integer.toString(DEFAULT_PORT))); + final String user = getArg(args, 2, DEFAULT_USER); + final String password = getArg(args, 3, DEFAULT_PASSWORD); + final String selectSql = getArg(args, 4, DEFAULT_SELECT_SQL); + final String updateSql = getArg(args, 5, ""); + + final String url = String.format("jdbc:arrow-flight-sql://%s:%d", host, port); + final Properties properties = new Properties(); + properties.setProperty("user", user); + properties.setProperty("password", password); + properties.setProperty("useEncryption", "false"); + + System.out.println("Connecting to " + url); + try (Connection connection = DriverManager.getConnection(url, properties); + Statement statement = connection.createStatement()) { + runSelect(statement, selectSql); + + if (updateSql.isEmpty()) { + System.out.println( + "Skipping Statement.executeUpdate(...) because no updateSql argument was provided."); + } else { + runUpdate(statement, updateSql); + } + } + } + + private static void runSelect(final Statement statement, final String selectSql) + throws SQLException { + System.out.println("Running Statement.executeQuery: " + selectSql); + try (ResultSet resultSet = statement.executeQuery(selectSql)) { + final ResultSetMetaData metadata = resultSet.getMetaData(); + final int columnCount = metadata.getColumnCount(); + int rowCount = 0; + while (resultSet.next()) { + rowCount++; + final StringBuilder rowBuilder = new StringBuilder(); + for (int i = 1; i <= columnCount; i++) { + if (i > 1) { + rowBuilder.append(", "); + } + rowBuilder.append(metadata.getColumnLabel(i)).append('=').append(resultSet.getObject(i)); + } + System.out.println("row " + rowCount + ": " + rowBuilder); + } + System.out.println("Statement.executeQuery returned " + rowCount + " row(s)"); + } + } + + private static void runUpdate(final Statement statement, final String updateSql) + throws SQLException { + System.out.println("Running Statement.executeUpdate: " + updateSql); + final int updateCount = statement.executeUpdate(updateSql); + System.out.println("Statement.executeUpdate affected " + updateCount + " row(s)"); + } + + private static String getArg(final String[] args, final int index, final String defaultValue) { + if (index >= args.length) { + return defaultValue; + } + final String arg = args[index]; + return arg == null || arg.isEmpty() ? defaultValue : arg; + } +} diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java index 0369c3a162..138a8e5b76 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java @@ -21,6 +21,8 @@ import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; import java.nio.charset.StandardCharsets; @@ -98,6 +100,22 @@ public void testSimpleQueryNoParameterBindingWithExecute() throws SQLException { } } + @Test + public void testPrepareStatementRegistersCreatedStatementByGeneratedHandle() throws SQLException { + final String query = CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD; + final ArrowFlightConnection flightConnection = (ArrowFlightConnection) connection; + + try (final PreparedStatement preparedStatement = connection.prepareStatement(query)) { + final ArrowFlightPreparedStatement arrowPreparedStatement = + (ArrowFlightPreparedStatement) preparedStatement; + + assertNotNull(flightConnection.statementMap.get(arrowPreparedStatement.handle.id)); + assertSame( + arrowPreparedStatement, + flightConnection.statementMap.get(arrowPreparedStatement.handle.id)); + } + } + @Test public void testQueryWithParameterBinding() throws SQLException { final String query = "Fake query with parameters"; diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java index 632cb0ba56..20e2059722 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java @@ -18,10 +18,13 @@ import static org.hamcrest.CoreMatchers.allOf; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; import java.sql.Connection; import java.sql.ResultSet; @@ -137,6 +140,34 @@ public void testExecuteShouldRunSelectQuery() throws SQLException { is(allOf(equalTo(statement.getLargeUpdateCount()), equalTo(-1L)))); } + @Test + public void testExecuteReplacesStatementMapEntryWithPreparedStatement() throws SQLException { + final ArrowFlightStatement arrowStatement = (ArrowFlightStatement) statement; + final ArrowFlightConnection arrowConnection = (ArrowFlightConnection) connection; + + assertThat(statement.execute(SAMPLE_QUERY_CMD), is(true)); + + final Object preparedStatement = arrowConnection.statementMap.get(arrowStatement.handle.id); + + assertNotNull(preparedStatement); + assertSame(preparedStatement, arrowConnection.statementMap.get(arrowStatement.handle.id)); + assertThat(preparedStatement, instanceOf(ArrowFlightPreparedStatement.class)); + } + + @Test + public void testExecuteQueryRestoresStatementMapEntryWithStatement() throws SQLException { + final ArrowFlightStatement arrowStatement = (ArrowFlightStatement) statement; + final ArrowFlightConnection arrowConnection = (ArrowFlightConnection) connection; + + assertThat(statement.execute(SAMPLE_QUERY_CMD), is(true)); + + try (ResultSet resultSet = statement.executeQuery(SAMPLE_QUERY_CMD)) { + assertThat(resultSet.next(), is(true)); + } + + assertSame(arrowStatement, arrowConnection.statementMap.get(arrowStatement.handle.id)); + } + @Test public void testExecuteShouldRunUpdateQueryForSmallUpdate() throws SQLException { assertThat(statement.execute(SAMPLE_UPDATE_QUERY), is(false)); // Means this is an UPDATE query. diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java index f7c31c590c..05e85227f0 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java @@ -22,6 +22,7 @@ import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.startsWith; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -224,4 +225,13 @@ public void testShouldFailToPrepareStatementForBadStatement() { } assertThat(count, is(1)); } + + @Test + public void testExecuteLargeUpdateShouldWrapBadStatement() { + final String badQuery = "BAD INVALID UPDATE"; + final SQLException exception = + assertThrows(SQLException.class, () -> statement.executeLargeUpdate(badQuery)); + assertThat( + exception.getMessage(), startsWith(format("Error while executing SQL \"%s\"", badQuery))); + } } diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementProtocolTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementProtocolTest.java new file mode 100644 index 0000000000..c5c35a9173 --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementProtocolTest.java @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.protobuf.Message; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Collections; +import java.util.function.Consumer; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.flight.FlightProducer.ServerStreamListener; +import org.apache.arrow.flight.sql.FlightSqlProducer.Schemas; +import org.apache.arrow.flight.sql.FlightSqlUtils; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +public class ArrowFlightStatementProtocolTest { + private static final String SELECT_QUERY = "SELECT * FROM PROTOCOL_SELECT"; + private static final String UPDATE_QUERY = "UPDATE PROTOCOL_UPDATE"; + private static final Schema QUERY_SCHEMA = + new Schema(Collections.singletonList(Field.nullable("id", MinorType.INT.getType()))); + + private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer(); + + @RegisterExtension + public static final FlightServerTestExtension FLIGHT_SERVER_TEST_EXTENSION = + FlightServerTestExtension.createStandardTestExtension(PRODUCER); + + private Connection connection; + + @BeforeAll + public static void setUpBeforeClass() { + PRODUCER.addSelectQuery( + SELECT_QUERY, + QUERY_SCHEMA, + Collections.singletonList( + listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(QUERY_SCHEMA, allocator)) { + IntVector vector = (IntVector) root.getVector("id"); + vector.setSafe(0, 1); + root.setRowCount(1); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + })); + PRODUCER.addUpdateQuery(UPDATE_QUERY, 1); + + final Message commandGetDbSchemas = CommandGetDbSchemas.getDefaultInstance(); + final Consumer commandGetSchemasResultProducer = + listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = + VectorSchemaRoot.create(Schemas.GET_SCHEMAS_SCHEMA, allocator)) { + final VarCharVector catalogName = (VarCharVector) root.getVector("catalog_name"); + final VarCharVector schemaName = (VarCharVector) root.getVector("db_schema_name"); + catalogName.setSafe(0, new Text("catalog_name #0")); + schemaName.setSafe(0, new Text("db_schema_name #0")); + root.setRowCount(1); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + PRODUCER.addCatalogQuery(commandGetDbSchemas, commandGetSchemasResultProducer); + } + + @BeforeEach + public void setUp() throws SQLException { + PRODUCER.clearActionTypeCounter(); + PRODUCER.clearCommandTypeCounter(); + connection = FLIGHT_SERVER_TEST_EXTENSION.getConnection(false); + } + + @AfterEach + public void tearDown() throws Exception { + AutoCloseables.close(connection); + } + + @AfterAll + public static void tearDownAfterClass() throws Exception { + AutoCloseables.close(PRODUCER); + } + + @Test + public void testStatementExecuteQueryUsesStatementProtocol() throws SQLException { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(SELECT_QUERY)) { + assertTrue(resultSet.next()); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(0)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(0)); + } + + @Test + public void testStatementExecuteUsesPreparedProtocolForQuery() throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.execute(SELECT_QUERY), is(true)); + try (ResultSet resultSet = statement.getResultSet()) { + assertTrue(resultSet.next()); + } + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(0)); + } + + @Test + public void testStatementExecuteUpdateUsesStatementProtocol() throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.executeUpdate(UPDATE_QUERY), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(0)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_UPDATE, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(0)); + } + + @Test + public void testStatementExecuteUsesPreparedProtocolForUpdate() throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.execute(UPDATE_QUERY), is(false)); + assertThat(statement.getUpdateCount(), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(1)); + } + + @Test + public void testStatementExecuteThenExecuteUpdateUsesStatementProtocol() throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.execute(SELECT_QUERY), is(true)); + try (ResultSet resultSet = statement.getResultSet()) { + assertTrue(resultSet.next()); + } + assertThat(statement.executeUpdate(UPDATE_QUERY), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_UPDATE, 0), + is(1)); + } + + @Test + public void testStatementExecuteUpdateThenExecuteQueryUsesStatementProtocol() + throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.execute(UPDATE_QUERY), is(false)); + assertThat(statement.getUpdateCount(), is(1)); + try (ResultSet resultSet = statement.executeQuery(SELECT_QUERY)) { + assertTrue(resultSet.next()); + } + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(1)); + } + + @Test + public void testPreparedStatementExecuteQueryUsesPreparedProtocol() throws SQLException { + try (PreparedStatement statement = connection.prepareStatement(SELECT_QUERY); + ResultSet resultSet = statement.executeQuery()) { + assertTrue(resultSet.next()); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(0)); + } + + @Test + public void testPreparedStatementExecuteUsesPreparedProtocolForQuery() throws SQLException { + try (PreparedStatement statement = connection.prepareStatement(SELECT_QUERY)) { + assertThat(statement.execute(), is(true)); + try (ResultSet resultSet = statement.getResultSet()) { + assertTrue(resultSet.next()); + } + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(0)); + } + + @Test + public void testPreparedStatementExecuteUpdateUsesPreparedProtocol() throws SQLException { + try (PreparedStatement statement = connection.prepareStatement(UPDATE_QUERY)) { + assertThat(statement.executeUpdate(), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(1)); + } + + @Test + public void testPreparedStatementExecuteUsesPreparedProtocolForUpdate() throws SQLException { + try (PreparedStatement statement = connection.prepareStatement(UPDATE_QUERY)) { + assertThat(statement.execute(), is(false)); + assertThat(statement.getUpdateCount(), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(1)); + } + + @Test + public void testMetadataGetSchemasUsesJdbcApi() throws SQLException { + final DatabaseMetaData metaData = connection.getMetaData(); + try (ResultSet resultSet = metaData.getSchemas()) { + assertTrue(resultSet.next()); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(0)); + } +} diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java index 45c2a96404..230c1346fb 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java @@ -103,6 +103,12 @@ public final class MockFlightSqlProducer implements FlightSqlProducer { private final Map>> expectedParameterValues = new HashMap<>(); private final Map actionTypeCounter = new HashMap<>(); + private final Map commandTypeCounter = new HashMap<>(); + + public static final String COMMAND_STATEMENT_QUERY = "statement_query"; + public static final String COMMAND_STATEMENT_UPDATE = "statement_update"; + public static final String COMMAND_PREPARED_STATEMENT_QUERY = "prepared_statement_query"; + public static final String COMMAND_PREPARED_STATEMENT_UPDATE = "prepared_statement_update"; private static FlightInfo getFlightInfoExportedAndImportedKeys( final Message message, final FlightDescriptor descriptor) { @@ -269,6 +275,7 @@ public FlightInfo getFlightInfoStatement( final CommandStatementQuery commandStatementQuery, final CallContext callContext, final FlightDescriptor flightDescriptor) { + incrementCommandTypeCounter(COMMAND_STATEMENT_QUERY); final String query = commandStatementQuery.getQuery(); final Entry> queryInfo = Preconditions.checkNotNull( @@ -289,6 +296,7 @@ public FlightInfo getFlightInfoPreparedStatement( final CommandPreparedStatementQuery commandPreparedStatementQuery, final CallContext callContext, final FlightDescriptor flightDescriptor) { + incrementCommandTypeCounter(COMMAND_PREPARED_STATEMENT_QUERY); final ByteString preparedStatementHandle = commandPreparedStatementQuery.getPreparedStatementHandle(); @@ -356,6 +364,7 @@ public Runnable acceptPutStatement( final CallContext callContext, final FlightStream flightStream, final StreamListener streamListener) { + incrementCommandTypeCounter(COMMAND_STATEMENT_UPDATE); return () -> { final String query = commandStatementUpdate.getQuery(); final BiConsumer> resultProvider = @@ -429,6 +438,7 @@ public Runnable acceptPutPreparedStatementUpdate( final CallContext callContext, final FlightStream flightStream, final StreamListener streamListener) { + incrementCommandTypeCounter(COMMAND_PREPARED_STATEMENT_UPDATE); final ByteString handle = commandPreparedStatementUpdate.getPreparedStatementHandle(); final String query = Preconditions.checkNotNull( @@ -651,10 +661,22 @@ public void clearActionTypeCounter() { actionTypeCounter.clear(); } + public void clearCommandTypeCounter() { + commandTypeCounter.clear(); + } + public Map getActionTypeCounter() { return actionTypeCounter; } + public Map getCommandTypeCounter() { + return commandTypeCounter; + } + + private void incrementCommandTypeCounter(String commandType) { + commandTypeCounter.put(commandType, commandTypeCounter.getOrDefault(commandType, 0) + 1); + } + private void getStreamCatalogFunctions( final Message ticket, final ServerStreamListener serverStreamListener) { Preconditions.checkNotNull(