Commit 8b4eff0

[GLUTEN-11088][VL] Spark 4.0: Fix ArrowEvalPythonExecSuite (#11288)
* [GLUTEN-11088][VL] Spark 4.0: Fix ArrowEvalPythonExecSuite
1 parent 30d1ab6 commit 8b4eff0

12 files changed: +147 −50 lines changed

backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala

Lines changed: 59 additions & 24 deletions
@@ -35,7 +35,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.sql.utils.{SparkArrowUtil, SparkSchemaUtil, SparkVectorUtil}
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SparkVersionUtil, Utils}

 import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
 import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
@@ -49,12 +49,12 @@ import scala.collection.mutable.ArrayBuffer
 class ColumnarArrowPythonRunner(
     funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
-    argOffsets: Array[Array[Int]],
+    argMetas: Array[Array[(Int, Option[String])]],
     schema: StructType,
     timeZoneId: String,
     conf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric])
-  extends BasePythonRunnerShim(funcs, evalType, argOffsets, pythonMetrics) {
+  extends BasePythonRunnerShim(funcs, evalType, argMetas, pythonMetrics) {

   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

@@ -149,7 +149,7 @@ class ColumnarArrowPythonRunner(
         PythonRDD.writeUTF(k, dataOut)
         PythonRDD.writeUTF(v, dataOut)
       }
-      ColumnarArrowPythonRunner.this.writeUdf(dataOut, argOffsets)
+      ColumnarArrowPythonRunner.this.writeUdf(dataOut, argMetas)
     }

     // For Spark earlier than 4.0. It overrides the corresponding abstract method
@@ -165,6 +165,12 @@ class ColumnarArrowPythonRunner(
     }

     def writeToStreamHelper(dataOut: DataOutputStream): Boolean = {
+      if (!inputIterator.hasNext) {
+        // See https://issues.apache.org/jira/browse/SPARK-44705:
+        // Starting from Spark 4.0, we should return false once the iterator is drained out,
+        // otherwise Spark won't stop calling this method repeatedly.
+        return false
+      }
       var numRows: Long = 0
       val arrowSchema = SparkSchemaUtil.toArrowSchema(schema, timeZoneId)
       val allocator = ArrowBufferAllocators.contextInstance()
@@ -264,7 +270,7 @@ case class ColumnarArrowEvalPythonExec(

   protected def evaluateColumnar(
       funcs: Seq[(ChainedPythonFunctions, Long)],
-      argOffsets: Array[Array[Int]],
+      argMetas: Array[Array[(Int, Option[String])]],
       iter: Iterator[ColumnarBatch],
       schema: StructType,
       context: TaskContext): Iterator[ColumnarBatch] = {
@@ -274,7 +280,7 @@ case class ColumnarArrowEvalPythonExec(
     val columnarBatchIter = new ColumnarArrowPythonRunner(
       funcs,
       evalType,
-      argOffsets,
+      argMetas,
       schema,
       sessionLocalTimeZone,
       pythonRunnerConf,
@@ -306,22 +312,51 @@ case class ColumnarArrowEvalPythonExec(
     val allInputs = new ArrayBuffer[Expression]
     val dataTypes = new ArrayBuffer[DataType]
     val originalOffsets = new ArrayBuffer[Int]
-    val argOffsets = inputs.map {
-      input =>
-        input.map {
-          e =>
-            if (allInputs.exists(_.semanticEquals(e))) {
-              allInputs.indexWhere(_.semanticEquals(e))
-            } else {
-              val offset = child.output.indexWhere(
-                _.exprId.equals(e.asInstanceOf[AttributeReference].exprId))
-              originalOffsets += offset
-              allInputs += e
-              dataTypes += e.dataType
-              allInputs.length - 1
-            }
-        }.toArray
-    }.toArray
+    val argMetas: Array[Array[(Int, Option[String])]] = if (SparkVersionUtil.gteSpark40) {
+      // Spark 4.0 requires ArgumentMetadata rather than trivial integer-based offset.
+      // See https://issues.apache.org/jira/browse/SPARK-44918.
+      inputs.map {
+        input =>
+          input.map {
+            e =>
+              val (key, value) = e match {
+                case EvalPythonExecBase.NamedArgumentExpressionShim(key, value) =>
+                  (Some(key), value)
+                case _ =>
+                  (None, e)
+              }
+              val pair: (Int, Option[String]) = if (allInputs.exists(_.semanticEquals(value))) {
+                allInputs.indexWhere(_.semanticEquals(value)) -> key
+              } else {
+                val offset = child.output.indexWhere(
+                  _.exprId.equals(e.asInstanceOf[AttributeReference].exprId))
+                originalOffsets += offset
+                allInputs += value
+                dataTypes += value.dataType
+                (allInputs.length - 1) -> key
+              }
+              pair
+          }.toArray
+      }.toArray
+    } else {
+      inputs.map {
+        input =>
+          input.map {
+            e =>
+              val pair: (Int, Option[String]) = if (allInputs.exists(_.semanticEquals(e))) {
+                allInputs.indexWhere(_.semanticEquals(e)) -> None
+              } else {
+                val offset = child.output.indexWhere(
+                  _.exprId.equals(e.asInstanceOf[AttributeReference].exprId))
+                originalOffsets += offset
+                allInputs += e
+                dataTypes += e.dataType
+                (allInputs.length - 1) -> None
+              }
+              pair
+          }.toArray
+      }.toArray
+    }
     val schema = StructType(dataTypes.zipWithIndex.map {
       case (dt, i) =>
         StructField(s"_$i", dt)
@@ -339,15 +374,15 @@ case class ColumnarArrowEvalPythonExec(
           inputCbCache += inputCb
           numInputRows += inputCb.numRows
           // We only need to pass the referred cols data to python worker for evaluation.
-          var colsForEval = new ArrayBuffer[ColumnVector]()
+          val colsForEval = new ArrayBuffer[ColumnVector]()
           for (i <- originalOffsets) {
             colsForEval += inputCb.column(i)
           }
           new ColumnarBatch(colsForEval.toArray, inputCb.numRows())
         }

       val outputColumnarBatchIterator =
-        evaluateColumnar(pyFuncs, argOffsets, inputBatchIter, schema, context)
+        evaluateColumnar(pyFuncs, argMetas, inputBatchIter, schema, context)
       val res =
         outputColumnarBatchIterator.zipWithIndex.map {
           case (outputCb, batchId) =>

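The essence of the change above: on Spark 4.0 each UDF input is tracked as an (offset, optional name) pair so that named arguments (SPARK-44918) survive the trip to the Python worker, while earlier Spark versions keep the purely positional bookkeeping. A minimal standalone sketch of that dedup-and-record logic, with plain strings standing in for Catalyst expressions (nothing below is Gluten or Spark API), might look like this:

import scala.collection.mutable.ArrayBuffer

// Standalone sketch: deduplicate UDF inputs and record (offset, optional name) pairs.
// Strings stand in for Catalyst expressions; `name` marks a hypothetical named argument.
object ArgMetaSketch {
  final case class Arg(column: String, name: Option[String] = None)

  def buildArgMetas(inputs: Seq[Seq[Arg]]): Array[Array[(Int, Option[String])]] = {
    val allInputs = new ArrayBuffer[String]
    inputs.map { udfArgs =>
      udfArgs.map { arg =>
        // Reuse the offset if the same column was already referenced, otherwise append it.
        val offset =
          if (allInputs.contains(arg.column)) allInputs.indexOf(arg.column)
          else { allInputs += arg.column; allInputs.length - 1 }
        offset -> arg.name
      }.toArray
    }.toArray
  }

  def main(args: Array[String]): Unit = {
    val metas = buildArgMetas(
      Seq(Seq(Arg("a"), Arg("b", name = Some("scale"))), Seq(Arg("a"))))
    // Prints: (0,None),(1,Some(scale)) | (0,None)
    println(metas.map(_.mkString(",")).mkString(" | "))
  }
}
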
backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala

Lines changed: 3 additions & 5 deletions
@@ -39,8 +39,7 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
       .set("spark.executor.cores", "1")
   }

-  // TODO: fix on spark-4.0
-  testWithMaxSparkVersion("arrow_udf test: without projection", "3.5") {
+  test("arrow_udf test: without projection") {
     lazy val base =
       Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
         .toDF("a", "b")
@@ -60,8 +59,7 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
     checkAnswer(df2, expected)
   }

-  // TODO: fix on spark-4.0
-  testWithMaxSparkVersion("arrow_udf test: with unrelated projection", "3.5") {
+  test("arrow_udf test: with unrelated projection") {
     lazy val base =
       Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
         .toDF("a", "b")
@@ -81,7 +79,7 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
     checkAnswer(df, expected)
   }

-  // TODO: fix on spark-4.0
+  // A fix needed for Spark 4.0 change in https://github.com/apache/spark/pull/42864.
   testWithMaxSparkVersion("arrow_udf test: with preprojection", "3.5") {
     lazy val base =
       Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))

shims/spark32/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala

Lines changed: 9 additions & 4 deletions
@@ -28,9 +28,12 @@ import java.net.Socket
 abstract class BasePythonRunnerShim(
     funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
-    argOffsets: Array[Array[Int]],
+    argMetas: Array[Array[(Int, Option[String])]],
     pythonMetrics: Map[String, SQLMetric])
-  extends BasePythonRunner[ColumnarBatch, ColumnarBatch](funcs.map(_._1), evalType, argOffsets) {
+  extends BasePythonRunner[ColumnarBatch, ColumnarBatch](
+    funcs.map(_._1),
+    evalType,
+    argMetas.map(_.map(_._1))) {
   // The type aliases below provide consistent type names in child classes,
   // ensuring code compatibility with both Spark 4.0 and earlier versions.
   type Writer = WriterThread
@@ -43,8 +46,10 @@ abstract class BasePythonRunnerShim(
       partitionIndex: Int,
      context: TaskContext): Writer

-  protected def writeUdf(dataOut: DataOutputStream, argOffsets: Array[Array[Int]]): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argOffsets)
+  protected def writeUdf(
+      dataOut: DataOutputStream,
+      argMetas: Array[Array[(Int, Option[String])]]): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argMetas.map(_.map(_._1)))
   }

   override protected def newWriterThread(

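The shim accepts the richer argMetas signature so callers compile unchanged on every Spark version, but since the pre-4.0 BasePythonRunner and PythonUDFRunner.writeUDFs only understand positional offsets, the optional names are simply dropped. A toy illustration of that degradation (sample data only, not Gluten code):

// Toy example of how the shim turns (offset, optional name) metadata back into bare offsets.
val argMetas: Array[Array[(Int, Option[String])]] =
  Array(Array(0 -> None, 2 -> None), Array(1 -> Some("scale")))
val argOffsets: Array[Array[Int]] = argMetas.map(_.map(_._1))
// argOffsets == Array(Array(0, 2), Array(1)); the name "scale" is discarded on Spark < 4.0.
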
shims/spark32/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala

Lines changed: 7 additions & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.types.StructType

 abstract class EvalPythonExecBase extends EvalPythonExec {
@@ -32,3 +33,9 @@ abstract class EvalPythonExecBase extends EvalPythonExec {
     throw new IllegalStateException("EvalPythonExecTransformer doesn't support evaluate")
   }
 }
+
+object EvalPythonExecBase {
+  object NamedArgumentExpressionShim {
+    def unapply(expr: Expression): Option[(String, Expression)] = None
+  }
+}

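The companion object added here gives version-agnostic code a single extractor to match against. On Spark 3.2 through 3.4 named arguments do not exist, so this stub's unapply always returns None and every expression falls through to the positional branch; only the Spark 3.5 shim further down binds a real key/value pair. A small usage sketch, assuming the shim above and a Catalyst Expression are on the classpath:

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.python.EvalPythonExecBase

// With the Spark 3.2 stub, this match always hits the second case, because
// NamedArgumentExpressionShim.unapply is hard-wired to None on this version.
def describeArg(expr: Expression): String = expr match {
  case EvalPythonExecBase.NamedArgumentExpressionShim(key, _) => s"named argument '$key'"
  case _ => "positional argument"
}
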
shims/spark33/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala

Lines changed: 9 additions & 4 deletions
@@ -28,9 +28,12 @@ import java.net.Socket
 abstract class BasePythonRunnerShim(
     funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
-    argOffsets: Array[Array[Int]],
+    argMetas: Array[Array[(Int, Option[String])]],
     pythonMetrics: Map[String, SQLMetric])
-  extends BasePythonRunner[ColumnarBatch, ColumnarBatch](funcs.map(_._1), evalType, argOffsets) {
+  extends BasePythonRunner[ColumnarBatch, ColumnarBatch](
+    funcs.map(_._1),
+    evalType,
+    argMetas.map(_.map(_._1))) {
   // The type aliases below provide consistent type names in child classes,
   // ensuring code compatibility with both Spark 4.0 and earlier versions.
   type Writer = WriterThread
@@ -43,8 +46,10 @@ abstract class BasePythonRunnerShim(
       partitionIndex: Int,
       context: TaskContext): Writer

-  protected def writeUdf(dataOut: DataOutputStream, argOffsets: Array[Array[Int]]): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argOffsets)
+  protected def writeUdf(
+      dataOut: DataOutputStream,
+      argMetas: Array[Array[(Int, Option[String])]]): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argMetas.map(_.map(_._1)))
   }

   override protected def newWriterThread(

shims/spark33/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala

Lines changed: 7 additions & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.types.StructType

 abstract class EvalPythonExecBase extends EvalPythonExec {
@@ -32,3 +33,9 @@ abstract class EvalPythonExecBase extends EvalPythonExec {
     throw new IllegalStateException("EvalPythonExecTransformer doesn't support evaluate")
   }
 }
+
+object EvalPythonExecBase {
+  object NamedArgumentExpressionShim {
+    def unapply(expr: Expression): Option[(String, Expression)] = None
+  }
+}

shims/spark34/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala

Lines changed: 9 additions & 4 deletions
@@ -28,9 +28,12 @@ import java.net.Socket
 abstract class BasePythonRunnerShim(
     funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
-    argOffsets: Array[Array[Int]],
+    argMetas: Array[Array[(Int, Option[String])]],
     pythonMetrics: Map[String, SQLMetric])
-  extends BasePythonRunner[ColumnarBatch, ColumnarBatch](funcs.map(_._1), evalType, argOffsets) {
+  extends BasePythonRunner[ColumnarBatch, ColumnarBatch](
+    funcs.map(_._1),
+    evalType,
+    argMetas.map(_.map(_._1))) {
   // The type aliases below provide consistent type names in child classes,
   // ensuring code compatibility with both Spark 4.0 and earlier versions.
   type Writer = WriterThread
@@ -43,8 +46,10 @@ abstract class BasePythonRunnerShim(
       partitionIndex: Int,
       context: TaskContext): Writer

-  protected def writeUdf(dataOut: DataOutputStream, argOffsets: Array[Array[Int]]): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argOffsets)
+  protected def writeUdf(
+      dataOut: DataOutputStream,
+      argMetas: Array[Array[(Int, Option[String])]]): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argMetas.map(_.map(_._1)))
   }

   override protected def newWriterThread(

shims/spark34/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala

Lines changed: 7 additions & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.types.StructType

 abstract class EvalPythonExecBase extends EvalPythonExec {
@@ -32,3 +33,9 @@ abstract class EvalPythonExecBase extends EvalPythonExec {
     throw new IllegalStateException("EvalPythonExecTransformer doesn't support evaluate")
   }
 }
+
+object EvalPythonExecBase {
+  object NamedArgumentExpressionShim {
+    def unapply(expr: Expression): Option[(String, Expression)] = None
+  }
+}

shims/spark35/src/main/scala/org/apache/spark/sql/execution/python/BasePythonRunnerShim.scala

Lines changed: 6 additions & 4 deletions
@@ -28,12 +28,12 @@ import java.net.Socket
 abstract class BasePythonRunnerShim(
     funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
-    argOffsets: Array[Array[Int]],
+    argMetas: Array[Array[(Int, Option[String])]],
     pythonMetrics: Map[String, SQLMetric])
   extends BasePythonRunner[ColumnarBatch, ColumnarBatch](
     funcs.map(_._1),
     evalType,
-    argOffsets,
+    argMetas.map(_.map(_._1)),
     None) {
   // The type aliases below provide consistent type names in child classes,
   // ensuring code compatibility with both Spark 4.0 and earlier versions.
@@ -47,8 +47,10 @@ abstract class BasePythonRunnerShim(
       partitionIndex: Int,
       context: TaskContext): Writer

-  protected def writeUdf(dataOut: DataOutputStream, argOffsets: Array[Array[Int]]): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argOffsets)
+  protected def writeUdf(
+      dataOut: DataOutputStream,
+      argMetas: Array[Array[(Int, Option[String])]]): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs.map(_._1), argMetas.map(_.map(_._1)))
   }

   override protected def newWriterThread(

shims/spark35/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecBase.scala

Lines changed: 10 additions & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedArgumentExpression}
 import org.apache.spark.sql.types.StructType

 abstract class EvalPythonExecBase extends EvalPythonExec {
@@ -32,3 +33,12 @@ abstract class EvalPythonExecBase extends EvalPythonExec {
     throw new IllegalStateException("EvalPythonExecTransformer doesn't support evaluate")
   }
 }
+
+object EvalPythonExecBase {
+  object NamedArgumentExpressionShim {
+    def unapply(expr: Expression): Option[(String, Expression)] = expr match {
+      case NamedArgumentExpression(key, value) => Some((key, value))
+      case _ => None
+    }
+  }
+}

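On Spark 3.5 the shim delegates to the real NamedArgumentExpression, so the same kind of match shown after the spark32 shim now actually recovers the argument name, which becomes the Option[String] part of argMetas. A rough sketch of that round trip, assuming a Spark 3.5 classpath (Literal is used only to have a concrete Expression; this is illustrative, not Gluten code):

import org.apache.spark.sql.catalyst.expressions.{Literal, NamedArgumentExpression}
import org.apache.spark.sql.execution.python.EvalPythonExecBase.NamedArgumentExpressionShim

// Wrapping a value in NamedArgumentExpression and matching via the shim recovers the name.
val expr = NamedArgumentExpression("scale", Literal(10))
val meta: (Int, Option[String]) = expr match {
  case NamedArgumentExpressionShim(key, _) => 0 -> Some(key) // offset 0 is arbitrary here
  case _ => 0 -> None
}
// meta == (0, Some("scale"))
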