From e7e78be65f6153e66db99bd818bf7dab94811917 Mon Sep 17 00:00:00 2001
From: yangjie01
Date: Thu, 13 Jun 2024 15:28:17 +0800
Subject: [PATCH 1/3] init

---
 .../apache/spark/sql/catalyst/plans/QueryPlan.scala |  4 +++-
 .../apache/spark/sql/catalyst/trees/TreeNode.scala  | 13 ++++++++++---
 .../spark/sql/execution/WholeStageCodegenExec.scala |  4 +++-
 3 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index bc0ca31dc635..c9c8fdb676b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -226,12 +226,14 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
       }
     }
 
+    @scala.annotation.nowarn("cat=deprecation")
     def recursiveTransform(arg: Any): AnyRef = arg match {
       case e: Expression => transformExpression(e)
       case Some(value) => Some(recursiveTransform(value))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
-      case stream: LazyList[_] => stream.map(recursiveTransform).force
+      case stream: Stream[_] => stream.map(recursiveTransform).force
+      case lazyList: LazyList[_] => lazyList.map(recursiveTransform).force
       case seq: Iterable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
       case null => null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 23d26854a767..f90ea92dac95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.trees
 
 import java.util.UUID
 
+import scala.annotation.nowarn
 import scala.collection.{mutable, Map}
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
@@ -356,6 +357,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
    * Returns a copy of this node with the children replaced.
    * TODO: Validate somewhere (in debug mode?) that children are ordered correctly.
    */
+  @nowarn("cat=deprecation")
   protected final def legacyWithNewChildren(newChildren: Seq[BaseType]): BaseType = {
     assert(newChildren.size == children.size, "Incorrect number of children")
     var changed = false
@@ -381,9 +383,12 @@
     val newArgs = mapProductIterator {
       case s: StructType => s // Don't convert struct types to some other type of Seq[StructField]
       // Handle Seq[TreeNode] in TreeNode parameters.
-      case s: LazyList[_] =>
-        // LazyList is lazy so we need to force materialization
+      case s: Stream[_] =>
+        // Stream is lazy so we need to force materialization
         s.map(mapChild).force
+      case l: LazyList[_] =>
+        // LazyList is lazy so we need to force materialization
+        l.map(mapChild).force
       case s: Seq[_] =>
         s.map(mapChild)
       case m: Map[_, _] =>
@@ -781,6 +786,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
     }
   }
 
+  @nowarn("cat=deprecation")
   override def clone(): BaseType = {
     def mapChild(child: Any): Any = child match {
       case arg: TreeNode[_] if containsChild(arg) =>
@@ -813,7 +819,8 @@
           case (_, other) => other
         }
       case d: DataType => d // Avoid unpacking Structs
-      case args: LazyList[_] => args.map(mapChild).force // Force materialization on stream
+      case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
+      case args: LazyList[_] => args.map(mapChild).force // Force materialization on LazyList
       case args: Iterable[_] => args.map(mapChild)
       case nonChild: AnyRef => nonChild
       case null => null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 382f8cf8861a..4651a8097ed0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -150,6 +150,7 @@ trait CodegenSupport extends SparkPlan {
    *
    * Note that `outputVars` and `row` can't both be null.
    */
+  @scala.annotation.nowarn("cat=deprecation")
   final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
     val inputVarsCandidate =
       if (outputVars != null) {
@@ -166,7 +167,8 @@ trait CodegenSupport extends SparkPlan {
       }
     }
 
     val inputVars = inputVarsCandidate match {
-      case stream: LazyList[ExprCode] => stream.force
+      case stream: Stream[ExprCode] => stream.force
+      case lazyList: LazyList[ExprCode] => lazyList.force
       case other => other
     }

From 1ef80a905ddc0409883bff5314d893ab02c8eaae Mon Sep 17 00:00:00 2001
From: yangjie01
Date: Thu, 13 Jun 2024 15:49:52 +0800
Subject: [PATCH 2/3] move nowarn

---
 .../scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 4 ++--
 .../apache/spark/sql/execution/WholeStageCodegenExec.scala   | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index f90ea92dac95..6683f2dbfb39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -357,7 +357,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
    * Returns a copy of this node with the children replaced.
    * TODO: Validate somewhere (in debug mode?) that children are ordered correctly.
   */
-  @nowarn("cat=deprecation")
   protected final def legacyWithNewChildren(newChildren: Seq[BaseType]): BaseType = {
     assert(newChildren.size == children.size, "Incorrect number of children")
     var changed = false
@@ -380,6 +379,7 @@
       case nonChild: AnyRef => nonChild
       case null => null
     }
+    @nowarn("cat=deprecation")
     val newArgs = mapProductIterator {
       case s: StructType => s // Don't convert struct types to some other type of Seq[StructField]
       // Handle Seq[TreeNode] in TreeNode parameters.
@@ -786,7 +786,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
     }
   }
 
-  @nowarn("cat=deprecation")
   override def clone(): BaseType = {
     def mapChild(child: Any): Any = child match {
       case arg: TreeNode[_] if containsChild(arg) =>
@@ -807,6 +806,7 @@
       case other => other
     }
 
+    @nowarn("cat=deprecation")
    val newArgs = mapProductIterator {
      case arg: TreeNode[_] if containsChild(arg) =>
        arg.asInstanceOf[BaseType].clone()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 4651a8097ed0..6ec0836f704c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -150,7 +150,6 @@ trait CodegenSupport extends SparkPlan {
    *
    * Note that `outputVars` and `row` can't both be null.
    */
-  @scala.annotation.nowarn("cat=deprecation")
   final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
     val inputVarsCandidate =
       if (outputVars != null) {
@@ -166,6 +165,7 @@ trait CodegenSupport extends SparkPlan {
       }
     }
 
+    @scala.annotation.nowarn("cat=deprecation")
     val inputVars = inputVarsCandidate match {
       case stream: Stream[ExprCode] => stream.force
       case lazyList: LazyList[ExprCode] => lazyList.force

From dbe04498805b24d0c779bca017e16a9cc2f74a28 Mon Sep 17 00:00:00 2001
From: yangjie01
Date: Thu, 13 Jun 2024 16:48:44 +0800
Subject: [PATCH 3/3] test case

---
 .../sql/catalyst/plans/LogicalPlanSuite.scala | 22 +++++++++++++++
 .../sql/catalyst/trees/TreeNodeSuite.scala    | 27 +++++++++++++++++++
 .../spark/sql/execution/PlannerSuite.scala    |  8 ++++++
 .../execution/WholeStageCodegenSuite.scala    | 10 +++++++
 4 files changed, 67 insertions(+)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index 31f7e07143c5..f783083d0a44 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import scala.annotation.nowarn
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -83,6 +85,26 @@ class LogicalPlanSuite extends SparkFunSuite {
   }
 
   test("transformExpressions works with a Stream") {
+    val id1 = NamedExpression.newExprId
+    val id2 = NamedExpression.newExprId
+    @nowarn("cat=deprecation")
+    val plan = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(2), "b")(exprId = id2)),
+      OneRowRelation())
+    val result = plan.transformExpressions {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    @nowarn("cat=deprecation")
+    val expected = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(3), "b")(exprId = id2)),
+      OneRowRelation())
+    assert(result.sameResult(expected))
+  }
+
+  test("SPARK-45685: transformExpressions works with a LazyList") {
     val id1 = NamedExpression.newExprId
     val id2 = NamedExpression.newExprId
     val plan = Project(LazyList(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 4dbadef93a07..21542d43eac9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.trees
 import java.math.BigInteger
 import java.util.UUID
 
+import scala.annotation.nowarn
 import scala.collection.mutable.ArrayBuffer
 
 import org.json4s.JsonAST._
@@ -693,6 +694,22 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   }
 
   test("transform works on stream of children") {
+    @nowarn("cat=deprecation")
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the
+    // situation in which the TreeNode.mapChildren function's change detection is not triggered. A
+    // stream's first element is typically materialized, so in order to not trip the TreeNode change
+    // detection logic, we should not change the first element in the sequence.
+    val result = before.transform {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    @nowarn("cat=deprecation")
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
+
+  test("SPARK-45685: transform works on LazyList of children") {
     val before = Coalesce(LazyList(Literal(1), Literal(2)))
     // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the
     // situation in which the TreeNode.mapChildren function's change detection is not triggered. A
@@ -707,6 +724,16 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   }
 
   test("withNewChildren on stream of children") {
+    @nowarn("cat=deprecation")
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    @nowarn("cat=deprecation")
+    val result = before.withNewChildren(Stream(Literal(1), Literal(3)))
+    @nowarn("cat=deprecation")
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
+
+  test("SPARK-45685: withNewChildren on LazyList of children") {
     val before = Coalesce(LazyList(Literal(1), Literal(2)))
     val result = before.withNewChildren(LazyList(Literal(1), Literal(3)))
     val expected = Coalesce(LazyList(Literal(1), Literal(3)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 15de4c5cc5b2..1400ee25f431 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -744,6 +744,14 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   }
 
   test("SPARK-24500: create union with stream of children") {
+    @scala.annotation.nowarn("cat=deprecation")
+    val df = Union(Stream(
+      Range(1, 1, 1, 1),
+      Range(1, 2, 1, 1)))
+    df.queryExecution.executedPlan.execute()
+  }
+
+  test("SPARK-45685: create union with LazyList of children") {
     val df = Union(LazyList(
       Range(1, 1, 1, 1),
       Range(1, 2, 1, 1)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 3aaf61ffba46..4d2d46582892 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -785,6 +785,16 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
   }
 
   test("SPARK-26680: Stream in groupBy does not cause StackOverflowError") {
+    @scala.annotation.nowarn("cat=deprecation")
+    val groupByCols = Stream(col("key"))
+    val df = Seq((1, 2), (2, 3), (1, 3)).toDF("key", "value")
+      .groupBy(groupByCols: _*)
+      .max("value")
+
+    checkAnswer(df, Seq(Row(1, 3), Row(2, 3)))
+  }
+
+  test("SPARK-45685: LazyList in groupBy does not cause StackOverflowError") {
    val groupByCols = LazyList(col("key"))
    val df = Seq((1, 2), (2, 3), (1, 3)).toDF("key", "value")
      .groupBy(groupByCols: _*)
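The recurring `.force` calls are the heart of all three patches. The standalone sketch below (plain Scala with illustrative names, not Spark code) shows why a lazy sequence has to be materialized when a rewrite relies on a side-effecting change flag, which is how `TreeNode.mapChildren` detects that new children were produced:

object LazyForceDemo {
  def main(args: Array[String]): Unit = {
    var changed = false // stand-in for the `changed` flag in TreeNode.mapChildren
    def rewrite(i: Int): Int =
      if (i != 1) { changed = true; i + 1 } else i

    val children = LazyList(1, 2)
    val mapped = children.map(rewrite)
    // LazyList.map is lazy: rewrite has not run yet, so the flag is still
    // unset and a caller consulting it now would conclude nothing changed.
    println(changed) // false

    mapped.force // materialize every element, running rewrite(1) and rewrite(2)
    println(changed)       // true
    println(mapped.toList) // List(1, 3)
  }
}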
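The reason PATCH 1/3 matches `Stream` and `LazyList` with separate cases, rather than relying on one of them, is that the two are unrelated concrete types in Scala 2.13. A minimal sketch, assuming nothing beyond the standard library:

// Deprecation warnings from the Stream usages below are expected; the
// object-level annotation silences them, as the patches do in Spark code.
@scala.annotation.nowarn("cat=deprecation")
object StreamVsLazyList {
  def main(args: Array[String]): Unit = {
    val s = Stream(1, 2, 3)
    val l = LazyList(1, 2, 3)
    println(s.isInstanceOf[LazyList[_]]) // false: a Stream is not a LazyList
    println(l.isInstanceOf[Stream[_]])   // false: nor the reverse
    // Both are lazy, so each needs its own `case ... => ....force`; a match
    // handling only LazyList would send a Stream to a generic Iterable case,
    // where `map` dispatches to Stream's lazy implementation and defers work.
  }
}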
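PATCH 2/3 moves each `@nowarn("cat=deprecation")` from the enclosing method onto the specific local definition whose pattern match mentions the deprecated `Stream` type, keeping the suppression as narrow as possible. A sketch of the difference, with hypothetical names (`describe`, `kind`):

import scala.annotation.nowarn

object NowarnScopeDemo {
  def describe(xs: Iterable[_]): String = {
    // The annotation covers only this definition, so deprecated usage
    // elsewhere in `describe` would still be reported by the compiler.
    @nowarn("cat=deprecation")
    val kind = xs match {
      case _: Stream[_]   => "Stream (deprecated since Scala 2.13)"
      case _: LazyList[_] => "LazyList"
      case _              => "other Iterable"
    }
    kind
  }

  @nowarn("cat=deprecation") // constructing a Stream is itself deprecated
  def main(args: Array[String]): Unit = {
    println(describe(Stream(1)))   // Stream (deprecated since Scala 2.13)
    println(describe(LazyList(1))) // LazyList
    println(describe(List(1)))     // other Iterable
  }
}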