diff --git a/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala b/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala
index 0c623b2b5fe1..bfa33a804c38 100644
--- a/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.sql.utils.{SparkArrowUtil, SparkSchemaUtil, SparkVectorUtil}
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SparkVersionUtil, Utils}
 
 import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
 import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
@@ -49,12 +49,12 @@ import scala.collection.mutable.ArrayBuffer
 
 class ColumnarArrowPythonRunner(
     funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
-    argOffsets: Array[Array[Int]],
+    argMetas: Array[Array[(Int, Option[String])]],
     schema: StructType,
     timeZoneId: String,
     conf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric])
-  extends BasePythonRunnerShim(funcs, evalType, argOffsets, pythonMetrics) {
+  extends BasePythonRunnerShim(funcs, evalType, argMetas, pythonMetrics) {
 
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
@@ -149,7 +149,7 @@ class ColumnarArrowPythonRunner(
           PythonRDD.writeUTF(k, dataOut)
           PythonRDD.writeUTF(v, dataOut)
         }
-        ColumnarArrowPythonRunner.this.writeUdf(dataOut, argOffsets)
+        ColumnarArrowPythonRunner.this.writeUdf(dataOut, argMetas)
       }
 
       // For Spark earlier than 4.0. It overrides the corresponding abstract method
@@ -165,6 +165,12 @@ class ColumnarArrowPythonRunner(
       }
 
       def writeToStreamHelper(dataOut: DataOutputStream): Boolean = {
+        if (!inputIterator.hasNext) {
+          // See https://issues.apache.org/jira/browse/SPARK-44705:
+          // Starting from Spark 4.0, we should return false once the iterator is drained;
+          // otherwise Spark will keep calling this method.
+          return false
+        }
         var numRows: Long = 0
         val arrowSchema = SparkSchemaUtil.toArrowSchema(schema, timeZoneId)
         val allocator = ArrowBufferAllocators.contextInstance()
@@ -264,7 +270,7 @@ case class ColumnarArrowEvalPythonExec(
 
   protected def evaluateColumnar(
       funcs: Seq[(ChainedPythonFunctions, Long)],
-      argOffsets: Array[Array[Int]],
+      argMetas: Array[Array[(Int, Option[String])]],
      iter: Iterator[ColumnarBatch],
       schema: StructType,
       context: TaskContext): Iterator[ColumnarBatch] = {
@@ -274,7 +280,7 @@ case class ColumnarArrowEvalPythonExec(
     val columnarBatchIter = new ColumnarArrowPythonRunner(
       funcs,
       evalType,
-      argOffsets,
+      argMetas,
       schema,
       sessionLocalTimeZone,
       pythonRunnerConf,
@@ -306,22 +312,51 @@ case class ColumnarArrowEvalPythonExec(
     val allInputs = new ArrayBuffer[Expression]
     val dataTypes = new ArrayBuffer[DataType]
     val originalOffsets = new ArrayBuffer[Int]
-    val argOffsets = inputs.map {
-      input =>
-        input.map {
-          e =>
-            if (allInputs.exists(_.semanticEquals(e))) {
-              allInputs.indexWhere(_.semanticEquals(e))
-            } else {
-              val offset = child.output.indexWhere(
-                _.exprId.equals(e.asInstanceOf[AttributeReference].exprId))
-              originalOffsets += offset
-              allInputs += e
-              dataTypes += e.dataType
-              allInputs.length - 1
-            }
-        }.toArray
-    }.toArray
+    val argMetas: Array[Array[(Int, Option[String])]] = if (SparkVersionUtil.gteSpark40) {
+      // Spark 4.0 requires ArgumentMetadata rather than plain integer offsets.
+      // See https://issues.apache.org/jira/browse/SPARK-44918.
+      inputs.map {
+        input =>
+          input.map {
+            e =>
+              val (key, value) = e match {
+                case EvalPythonExecBase.NamedArgumentExpressionShim(key, value) =>
+                  (Some(key), value)
+                case _ =>
+                  (None, e)
+              }
+              val pair: (Int, Option[String]) = if (allInputs.exists(_.semanticEquals(value))) {
+                allInputs.indexWhere(_.semanticEquals(value)) -> key
+              } else {
+                val offset = child.output.indexWhere(
+                  _.exprId.equals(value.asInstanceOf[AttributeReference].exprId))
+                originalOffsets += offset
+                allInputs += value
+                dataTypes += value.dataType
+                (allInputs.length - 1) -> key
+              }
+              pair
+          }.toArray
+      }.toArray
+    } else {
+      inputs.map {
+        input =>
+          input.map {
+            e =>
+              val pair: (Int, Option[String]) = if (allInputs.exists(_.semanticEquals(e))) {
+                allInputs.indexWhere(_.semanticEquals(e)) -> None
+              } else {
+                val offset = child.output.indexWhere(
+                  _.exprId.equals(e.asInstanceOf[AttributeReference].exprId))
+                originalOffsets += offset
+                allInputs += e
+                dataTypes += e.dataType
+                (allInputs.length - 1) -> None
+              }
+              pair
+          }.toArray
+      }.toArray
+    }
 
     val schema = StructType(dataTypes.zipWithIndex.map {
       case (dt, i) => StructField(s"_$i", dt)
@@ -339,7 +374,7 @@ case class ColumnarArrowEvalPythonExec(
           inputCbCache += inputCb
           numInputRows += inputCb.numRows
           // We only need to pass the referred cols data to python worker for evaluation.
-          var colsForEval = new ArrayBuffer[ColumnVector]()
+          val colsForEval = new ArrayBuffer[ColumnVector]()
           for (i <- originalOffsets) {
             colsForEval += inputCb.column(i)
           }
@@ -347,7 +382,7 @@
         }
 
         val outputColumnarBatchIterator =
-          evaluateColumnar(pyFuncs, argOffsets, inputBatchIter, schema, context)
+          evaluateColumnar(pyFuncs, argMetas, inputBatchIter, schema, context)
 
         val res = outputColumnarBatchIterator.zipWithIndex.map {
           case (outputCb, batchId) =>
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala
index 0ea34ec6ad63..52a17995f386 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala
@@ -39,8 +39,7 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
       .set("spark.executor.cores", "1")
   }
 
-  // TODO: fix on spark-4.0
-  testWithMaxSparkVersion("arrow_udf test: without projection", "3.5") {
+  test("arrow_udf test: without projection") {
     lazy val base =
       Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
         .toDF("a", "b")
@@ -60,8 +59,7 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
     checkAnswer(df2, expected)
   }
 
-  // TODO: fix on spark-4.0
-  testWithMaxSparkVersion("arrow_udf test: with unrelated projection", "3.5") {
+  test("arrow_udf test: with unrelated projection") {
     lazy val base =
       Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
         .toDF("a", "b")
@@ -81,7 +79,7 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
     checkAnswer(df, expected)
   }
 
-  // TODO: fix on spark-4.0
+  // A fix is needed for the Spark 4.0 change in https://github.com/apache/spark/pull/42864.
   testWithMaxSparkVersion("arrow_udf test: with preprojection", "3.5") {
     lazy val base =
       Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
index fc0f82690509..82d971d5d6a4 100644
--- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
@@ -28,9 +28,12 @@ import java.net.Socket
 abstract class BasePythonRunnerShim(
     funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
-    argOffsets: Array[Array[Int]],
+    argMetas: Array[Array[(Int, Option[String])]],
     pythonMetrics: Map[String, SQLMetric])
-  extends BasePythonRunner[ColumnarBatch, ColumnarBatch](funcs.map(_._1), evalType, argOffsets) {
+  extends BasePythonRunner[ColumnarBatch, ColumnarBatch](
+    funcs.map(_._1),
+    evalType,
+    argMetas.map(_.map(_._1))) {
   // The type aliases below provide consistent type names in child classes,
   // ensuring code compatibility with both Spark 4.0 and earlier versions.
   type Writer = WriterThread
@@ -43,8 +46,10 @@ abstract class BasePythonRunnerShim(
       partitionIndex: Int,
       context: TaskContext): Writer
 
-  protected def writeUdf(dataOut: DataOutputStream, argOffsets: Array[Array[Int]]): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argOffsets)
+  protected def writeUdf(
+      dataOut: DataOutputStream,
+      argMetas: Array[Array[(Int, Option[String])]]): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argMetas.map(_.map(_._1)))
   }
 
   override protected def newWriterThread(
diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
index 843d7e0169d5..7221e330a7dd 100644
--- a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.types.StructType
 
 abstract class EvalPythonExecBase extends EvalPythonExec {
@@ -32,3 +33,9 @@ abstract class EvalPythonExecBase extends EvalPythonExec {
     throw new IllegalStateException("EvalPythonExecTransformer doesn't support evaluate")
   }
 }
+
+object EvalPythonExecBase {
+  object NamedArgumentExpressionShim {
+    def unapply(expr: Expression): Option[(String, Expression)] = None
+  }
+}
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
index fc0f82690509..82d971d5d6a4 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
@@ -28,9 +28,12 @@ import java.net.Socket
 abstract class BasePythonRunnerShim(
     funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
-    argOffsets: Array[Array[Int]],
+    argMetas: Array[Array[(Int, Option[String])]],
     pythonMetrics: Map[String, SQLMetric])
-  extends BasePythonRunner[ColumnarBatch, ColumnarBatch](funcs.map(_._1), evalType, argOffsets) {
+  extends BasePythonRunner[ColumnarBatch, ColumnarBatch](
+    funcs.map(_._1),
+    evalType,
+    argMetas.map(_.map(_._1))) {
   // The type aliases below provide consistent type names in child classes,
   // ensuring code compatibility with both Spark 4.0 and earlier versions.
   type Writer = WriterThread
@@ -43,8 +46,10 @@ abstract class BasePythonRunnerShim(
       partitionIndex: Int,
       context: TaskContext): Writer
 
-  protected def writeUdf(dataOut: DataOutputStream, argOffsets: Array[Array[Int]]): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argOffsets)
+  protected def writeUdf(
+      dataOut: DataOutputStream,
+      argMetas: Array[Array[(Int, Option[String])]]): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argMetas.map(_.map(_._1)))
   }
 
   override protected def newWriterThread(
diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
index 843d7e0169d5..7221e330a7dd 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.types.StructType
 
 abstract class EvalPythonExecBase extends EvalPythonExec {
@@ -32,3 +33,9 @@ abstract class EvalPythonExecBase extends EvalPythonExec {
     throw new IllegalStateException("EvalPythonExecTransformer doesn't support evaluate")
   }
 }
+
+object EvalPythonExecBase {
+  object NamedArgumentExpressionShim {
+    def unapply(expr: Expression): Option[(String, Expression)] = None
+  }
+}
diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
index fc0f82690509..82d971d5d6a4 100644
--- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
+++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
@@ -28,9 +28,12 @@ import java.net.Socket
 abstract class BasePythonRunnerShim(
     funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
-    argOffsets: Array[Array[Int]],
+    argMetas: Array[Array[(Int, Option[String])]],
     pythonMetrics: Map[String, SQLMetric])
-  extends BasePythonRunner[ColumnarBatch, ColumnarBatch](funcs.map(_._1), evalType, argOffsets) {
+  extends BasePythonRunner[ColumnarBatch, ColumnarBatch](
+    funcs.map(_._1),
+    evalType,
+    argMetas.map(_.map(_._1))) {
   // The type aliases below provide consistent type names in child classes,
   // ensuring code compatibility with both Spark 4.0 and earlier versions.
   type Writer = WriterThread
@@ -43,8 +46,10 @@ abstract class BasePythonRunnerShim(
       partitionIndex: Int,
       context: TaskContext): Writer
 
-  protected def writeUdf(dataOut: DataOutputStream, argOffsets: Array[Array[Int]]): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argOffsets)
+  protected def writeUdf(
+      dataOut: DataOutputStream,
+      argMetas: Array[Array[(Int, Option[String])]]): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argMetas.map(_.map(_._1)))
   }
 
   override protected def newWriterThread(
diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
index 843d7e0169d5..7221e330a7dd 100644
--- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
+++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.types.StructType
 
 abstract class EvalPythonExecBase extends EvalPythonExec {
@@ -32,3 +33,9 @@ abstract class EvalPythonExecBase extends EvalPythonExec {
     throw new IllegalStateException("EvalPythonExecTransformer doesn't support evaluate")
   }
 }
+
+object EvalPythonExecBase {
+  object NamedArgumentExpressionShim {
+    def unapply(expr: Expression): Option[(String, Expression)] = None
+  }
+}
diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
index 1a14622f87f0..ecabc9a93b7b 100644
--- a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
+++ b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
@@ -28,12 +28,12 @@ import java.net.Socket
 abstract class BasePythonRunnerShim(
     funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
-    argOffsets: Array[Array[Int]],
+    argMetas: Array[Array[(Int, Option[String])]],
     pythonMetrics: Map[String, SQLMetric])
   extends BasePythonRunner[ColumnarBatch, ColumnarBatch](
     funcs.map(_._1),
     evalType,
-    argOffsets,
+    argMetas.map(_.map(_._1)),
     None) {
   // The type aliases below provide consistent type names in child classes,
   // ensuring code compatibility with both Spark 4.0 and earlier versions.
@@ -47,8 +47,10 @@ abstract class BasePythonRunnerShim(
       partitionIndex: Int,
       context: TaskContext): Writer
 
-  protected def writeUdf(dataOut: DataOutputStream, argOffsets: Array[Array[Int]]): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argOffsets)
+  protected def writeUdf(
+      dataOut: DataOutputStream,
+      argMetas: Array[Array[(Int, Option[String])]]): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argMetas.map(_.map(_._1)))
   }
 
   override protected def newWriterThread(
diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
index 843d7e0169d5..3e74f5490114 100644
--- a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
+++ b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedArgumentExpression}
 import org.apache.spark.sql.types.StructType
 
 abstract class EvalPythonExecBase extends EvalPythonExec {
@@ -32,3 +33,12 @@ abstract class EvalPythonExecBase extends EvalPythonExec {
     throw new IllegalStateException("EvalPythonExecTransformer doesn't support evaluate")
   }
 }
+
+object EvalPythonExecBase {
+  object NamedArgumentExpressionShim {
+    def unapply(expr: Expression): Option[(String, Expression)] = expr match {
+      case NamedArgumentExpression(key, value) => Some((key, value))
+      case _ => None
+    }
+  }
+}
diff --git a/shims/spark40/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala b/shims/spark40/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
index b84ff9b9f66f..127a0fc3cfc9 100644
--- a/shims/spark40/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
+++ b/shims/spark40/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala
@@ -20,6 +20,7 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonWorker}
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import java.io.DataOutputStream
@@ -27,12 +28,12 @@
 abstract class BasePythonRunnerShim(
     funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
-    argOffsets: Array[Array[Int]],
+    argMetas: Array[Array[(Int, Option[String])]],
     pythonMetrics: Map[String, SQLMetric])
   extends BasePythonRunner[ColumnarBatch, ColumnarBatch](
     funcs.map(_._1),
     evalType,
-    argOffsets,
+    argMetas.map(_.map(_._1)),
     None,
     pythonMetrics) {
 
@@ -43,8 +44,14 @@ abstract class BasePythonRunnerShim(
       partitionIndex: Int,
       context: TaskContext): Writer
 
-  protected def writeUdf(dataOut: DataOutputStream, argOffsets: Array[Array[Int]]): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, None)
+  protected def writeUdf(
+      dataOut: DataOutputStream,
+      argMetas: Array[Array[(Int, Option[String])]]): Unit = {
+    PythonUDFRunner.writeUDFs(
+      dataOut,
+      funcs,
+      argMetas.map(_.map(pair => ArgumentMetadata(pair._1, pair._2))),
+      None)
   }
 
   override protected def newWriter(
diff --git a/shims/spark40/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala b/shims/spark40/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
index 99acc2644f58..7ad7ca6b09ee 100644
--- a/shims/spark40/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
+++ b/shims/spark40/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.sql.execution.python.EvalPythonEvaluatorFactory
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedArgumentExpression}
 
 abstract class EvalPythonExecBase extends EvalPythonExec {
 
@@ -24,3 +24,12 @@ abstract class EvalPythonExecBase extends EvalPythonExec {
     throw new IllegalStateException("EvalPythonExecTransformer doesn't support evaluate")
   }
 }
+
+object EvalPythonExecBase {
+  object NamedArgumentExpressionShim {
+    def unapply(expr: Expression): Option[(String, Expression)] = expr match {
+      case NamedArgumentExpression(key, value) => Some((key, value))
+      case _ => None
+    }
+  }
+}
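Note on the compatibility pattern used throughout the patch: since SPARK-44918, Spark 4.0's PythonUDFRunner.writeUDFs describes each UDF argument with EvalPythonExec.ArgumentMetadata (an offset plus an optional keyword-argument name), while the Spark 3.2-3.5 runners only accept plain integer offsets. The planner therefore builds a version-neutral (Int, Option[String]) pair per argument and each shim lowers it to whatever its Spark version expects. The sketch below restates that conversion in isolation; it is a minimal, hedged illustration, not code from the patch, and every name in it (ArgMetaShimSketch, Expr, Attr, NamedArg, the local ArgumentMetadata, toArgMetas, toOffsets, toArgumentMetadata) is a stand-in introduced only so the example compiles on its own.

import scala.collection.mutable.ArrayBuffer

// Standalone sketch of the offset/metadata lowering performed by the shims (stand-in types only).
object ArgMetaShimSketch {
  // Stand-ins for Catalyst's Expression / NamedArgumentExpression and for Spark 4.0's
  // EvalPythonExec.ArgumentMetadata; the real classes live in Spark itself.
  sealed trait Expr
  final case class Attr(name: String) extends Expr
  final case class NamedArg(key: String, value: Expr) extends Expr
  final case class ArgumentMetadata(offset: Int, name: Option[String])

  // Role of EvalPythonExecBase.NamedArgumentExpressionShim: on Spark 3.5/4.0 it unwraps
  // NamedArgumentExpression(key, value); on Spark 3.2-3.4 unapply always returns None,
  // so every argument falls through to the positional case.
  object NamedArgumentExpressionShim {
    def unapply(e: Expr): Option[(String, Expr)] = e match {
      case NamedArg(k, v) => Some((k, v))
      case _ => None
    }
  }

  // Planner-side representation: one (offset into the deduplicated inputs, optional name) pair.
  def toArgMetas(args: Seq[Expr], allInputs: ArrayBuffer[Expr]): Array[(Int, Option[String])] =
    args.map { e =>
      val (key, value) = e match {
        case NamedArgumentExpressionShim(k, v) => (Some(k), v)
        case _ => (None, e)
      }
      val offset = allInputs.indexOf(value) match {
        case -1 =>
          allInputs += value
          allInputs.length - 1
        case found => found
      }
      offset -> key
    }.toArray

  // Spark 3.x shims: drop the names and keep the bare offsets for BasePythonRunner/writeUDFs.
  def toOffsets(metas: Array[Array[(Int, Option[String])]]): Array[Array[Int]] =
    metas.map(_.map(_._1))

  // Spark 4.0 shim: promote each pair to ArgumentMetadata before calling writeUDFs.
  def toArgumentMetadata(
      metas: Array[Array[(Int, Option[String])]]): Array[Array[ArgumentMetadata]] =
    metas.map(_.map { case (offset, name) => ArgumentMetadata(offset, name) })
}

For example, with args = Seq(Attr("a"), NamedArg("scale", Attr("b"))) and an empty buffer, toArgMetas yields Array((0, None), (1, Some("scale"))); toOffsets reduces that to Array(0, 1), which is what the pre-4.0 runners send to the Python worker, while toArgumentMetadata keeps the keyword name for Spark 4.0.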