Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,20 @@ object CometConf extends ShimCometConf {
.booleanConf
.createWithDefault(true)

  // Feature flag for the experimental "direct native shuffle" optimization (off by default).
  // Note: enabling this flag is necessary but not sufficient — the optimization is only
  // applied when additional conditions hold (native shuffle mode, non-range partitioning,
  // no scalar subqueries, and every input source is a native scan); see
  // CometShuffleExchangeExec.directNativeExecutionInfo for the full gating logic.
  val COMET_SHUFFLE_DIRECT_NATIVE_ENABLED: ConfigEntry[Boolean] =
    conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.directNative.enabled")
      .category(CATEGORY_SHUFFLE)
      .doc(
        "When enabled, the native shuffle writer will directly execute the child native plan " +
          "instead of reading intermediate batches via JNI. This optimization avoids the " +
          "JNI round-trip for native plans whose inputs are all native scans " +
          "(CometNativeScanExec, CometIcebergNativeScanExec). Supports single and multi-source " +
          "plans (e.g., joins over native scans). " +
          "This is an experimental feature and is disabled by default.")
      .internal()
      .booleanConf
      .createWithDefault(false)

val COMET_SHUFFLE_DIRECT_READ_ENABLED: ConfigEntry[Boolean] =
conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.directRead.enabled")
.category(CATEGORY_SHUFFLE)
Expand Down
4 changes: 1 addition & 3 deletions native/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType

import org.apache.comet.serde.OperatorOuterClass.Operator

/**
* A [[ShuffleDependency]] that allows us to identify the shuffle dependency as a Comet shuffle.
*/
Expand All @@ -49,7 +51,11 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
val outputAttributes: Seq[Attribute] = Seq.empty,
val shuffleWriteMetrics: Map[String, SQLMetric] = Map.empty,
val numParts: Int = 0,
val rangePartitionBounds: Option[Seq[InternalRow]] = None)
val rangePartitionBounds: Option[Seq[InternalRow]] = None,
// For direct native execution: the child's native plan to compose with ShuffleWriter
val childNativePlan: Option[Operator] = None,
val commonByKey: Map[String, Array[Byte]] = Map.empty,
val perPartitionByKey: Map[String, Array[Array[Byte]]] = Map.empty)
extends ShuffleDependency[K, V, C](
_rdd,
partitioner,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Exp
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometPlan, CometSinkPlaceHolder}
import org.apache.spark.sql.comet.{CometIcebergNativeScanExec, CometMetricNode, CometNativeExec, CometNativeScanExec, CometPlan, CometSinkPlaceHolder}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.ScalarSubquery
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
Expand All @@ -52,6 +53,7 @@ import org.apache.comet.CometConf
import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE}
import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleManagerEnabled, withInfo}
import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.operator.CometSink
import org.apache.comet.shims.ShimCometShuffleExchangeExec

Expand Down Expand Up @@ -89,9 +91,113 @@ case class CometShuffleExchangeExec(
private lazy val serializer: Serializer =
new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))

  /**
   * Information about the direct native execution optimization. When the child is a native plan
   * whose input sources are all fully native scans (CometNativeScanExec,
   * CometIcebergNativeScanExec), we can pass the child's native plan to the shuffle writer and
   * execute e.g. Scan -> Filter -> Project -> ShuffleWriter entirely in native code, avoiding
   * the JNI round-trip for intermediate batches. Multi-source plans (e.g., joins over native
   * scans) are supported, provided every source is a native scan and all scans report the same
   * partition count.
   *
   * JVM scan wrappers (CometScanExec, CometBatchScanExec) still require JNI input and are not
   * optimized. RangePartitioning and plans containing scalar subqueries also fall back to the
   * normal path (see inline comments below for why).
   *
   * Evaluates to None whenever the optimization does not apply; the shuffle then executes the
   * child plan and passes intermediate batches as before.
   */
  @transient private lazy val directNativeExecutionInfo: Option[DirectNativeExecutionInfo] = {
    if (!CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.get()) {
      None
    } else if (shuffleType != CometNativeShuffle) {
      None
    } else {
      // Check if direct native execution is possible
      outputPartitioning match {
        case _: RangePartitioning =>
          // RangePartitioning requires sampling the data to compute bounds,
          // which requires executing the child plan. Fall back to current behavior.
          None
        case _ =>
          child match {
            case nativeChild: CometNativeExec =>
              // Find input sources using foreachUntilCometInput
              val inputSources = scala.collection.mutable.ArrayBuffer.empty[SparkPlan]
              nativeChild.foreachUntilCometInput(nativeChild)(inputSources += _)

              // Optimize when all input sources are native scans
              // (CometNativeScanExec, CometIcebergNativeScanExec).
              // JVM scan wrappers (CometScanExec, CometBatchScanExec) still need JNI input,
              // so we don't optimize those.
              // Check if the plan contains subqueries (e.g., bloom filters with might_contain).
              // Subqueries are registered with the parent execution context ID, but direct
              // native shuffle creates a new execution context, so subquery lookup would fail.
              val containsSubquery = nativeChild.exists { p =>
                p.expressions.exists(_.exists(_.isInstanceOf[ScalarSubquery]))
              }
              if (containsSubquery) {
                // Fall back to avoid subquery lookup failures
                None
              } else {
                // Check that ALL input sources are native scans (file-reading, no JNI)
                val allNativeScans = inputSources.nonEmpty && inputSources.forall {
                  case _: CometNativeScanExec => true
                  case _: CometIcebergNativeScanExec => true
                  case _ => false
                }
                if (allNativeScans) {
                  // Collect per-partition plan data from all native scans
                  val (commonByKey, perPartitionByKey) =
                    nativeChild.findAllPlanData(nativeChild)
                  // All scans must have the same partition count
                  val partitionCounts = perPartitionByKey.values.map(_.length).toSet
                  if (partitionCounts.size <= 1) {
                    val numPartitions = partitionCounts.headOption.getOrElse(0)
                    if (numPartitions == 0) {
                      // Empty table (no data files) - fall back to normal execution
                      None
                    } else {
                      Some(
                        DirectNativeExecutionInfo(
                          nativeChild.nativeOp,
                          numPartitions,
                          commonByKey,
                          perPartitionByKey))
                    }
                  } else {
                    None // Partition count mismatch across scans
                  }
                } else {
                  None
                }
              }
            case _ =>
              None
          }
      }
    }
  }

/**
* Returns true if direct native execution optimization is being used for this shuffle. This is
* primarily intended for testing to verify the optimization is applied correctly.
*/
def isDirectNativeExecution: Boolean = directNativeExecutionInfo.isDefined

/**
* Creates an RDD that provides empty iterators for each partition. Used when direct native
* execution is enabled - the shuffle writer will execute the full native plan which reads data
* directly (no JNI input needed).
*/
private def createEmptyPartitionRDD(numPartitions: Int): RDD[ColumnarBatch] = {
sparkContext.parallelize(Seq.empty[ColumnarBatch], numPartitions)
}

@transient lazy val inputRDD: RDD[_] = if (shuffleType == CometNativeShuffle) {
// CometNativeShuffle assumes that the input plan is Comet plan.
child.executeColumnar()
directNativeExecutionInfo match {
case Some(info) =>
// Direct native execution: create an RDD with empty partitions.
// The shuffle writer will execute the full native plan which reads data directly.
createEmptyPartitionRDD(info.numPartitions)
case None =>
// Fall back to current behavior: execute child and pass intermediate batches
child.executeColumnar()
}
} else if (shuffleType == CometColumnarShuffle) {
// CometColumnarShuffle uses Spark's row-based execute() API. For Spark row-based plans,
// rows flow directly. For Comet native plans, their doExecute() wraps with ColumnarToRowExec
Expand Down Expand Up @@ -142,7 +248,10 @@ case class CometShuffleExchangeExec(
child.output,
outputPartitioning,
serializer,
metrics)
metrics,
directNativeExecutionInfo.map(_.childNativePlan),
directNativeExecutionInfo.map(_.commonByKey).getOrElse(Map.empty),
directNativeExecutionInfo.map(_.perPartitionByKey).getOrElse(Map.empty))
metrics("numPartitions").set(dep.partitioner.numPartitions)
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
SQLMetrics.postDriverMetricUpdates(
Expand Down Expand Up @@ -586,7 +695,11 @@ object CometShuffleExchangeExec
outputAttributes: Seq[Attribute],
outputPartitioning: Partitioning,
serializer: Serializer,
metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
metrics: Map[String, SQLMetric],
childNativePlan: Option[Operator] = None,
commonByKey: Map[String, Array[Byte]] = Map.empty,
perPartitionByKey: Map[String, Array[Array[Byte]]] = Map.empty)
: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
val numParts = rdd.getNumPartitions

// The code block below is mostly brought over from
Expand Down Expand Up @@ -653,7 +766,10 @@ object CometShuffleExchangeExec
outputAttributes = outputAttributes,
shuffleWriteMetrics = metrics,
numParts = numParts,
rangePartitionBounds = rangePartitionBounds)
rangePartitionBounds = rangePartitionBounds,
childNativePlan = childNativePlan,
commonByKey = commonByKey,
perPartitionByKey = perPartitionByKey)
dependency
}

Expand Down Expand Up @@ -858,3 +974,17 @@ object CometShuffleExchangeExec
dependency
}
}

/**
 * Information needed for direct native execution optimization.
 *
 * @param childNativePlan
 *   The child's native operator plan to compose with ShuffleWriter
 * @param numPartitions
 *   The number of partitions (from the underlying scan)
 * @param commonByKey
 *   Serialized plan data shared across all partitions, as returned by
 *   CometNativeExec.findAllPlanData (presumably keyed per scan — confirm against that method)
 * @param perPartitionByKey
 *   Serialized per-partition plan data from findAllPlanData; when this class is constructed,
 *   every value array has length numPartitions (enforced by the partition-count check in
 *   CometShuffleExchangeExec.directNativeExecutionInfo)
 */
private[shuffle] case class DirectNativeExecutionInfo(
    childNativePlan: Operator,
    numPartitions: Int,
    commonByKey: Map[String, Array[Byte]],
    perPartitionByKey: Map[String, Array[Array[Byte]]])
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,10 @@ class CometShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
mapId,
context,
metrics,
dep.rangePartitionBounds)
dep.rangePartitionBounds,
dep.childNativePlan,
dep.commonByKey,
dep.perPartitionByKey)
case bypassMergeSortHandle: CometBypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new CometBypassMergeSortShuffleWriter(
env.blockManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ abstract class CometNativeExec extends CometExec {
* @return
* (commonByKey, perPartitionByKey) - common data is shared, per-partition varies
*/
private def findAllPlanData(
private[comet] def findAllPlanData(
plan: SparkPlan): (Map[String, Array[Byte]], Map[String, Array[Array[Byte]]]) = {
plan match {
// Found an Iceberg scan with planning data
Expand Down
Loading
Loading