diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index bc0ca31dc635..c9c8fdb676b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -226,12 +226,14 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
       }
     }
 
+    @scala.annotation.nowarn("cat=deprecation")
     def recursiveTransform(arg: Any): AnyRef = arg match {
       case e: Expression => transformExpression(e)
       case Some(value) => Some(recursiveTransform(value))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
-      case stream: LazyList[_] => stream.map(recursiveTransform).force
+      case stream: Stream[_] => stream.map(recursiveTransform).force
+      case lazyList: LazyList[_] => lazyList.map(recursiveTransform).force
       case seq: Iterable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
       case null => null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 23d26854a767..6683f2dbfb39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.trees
 
 import java.util.UUID
 
+import scala.annotation.nowarn
 import scala.collection.{mutable, Map}
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
@@ -378,12 +379,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
       case nonChild: AnyRef => nonChild
       case null => null
     }
+    @nowarn("cat=deprecation")
     val newArgs = mapProductIterator {
       case s: StructType => s // Don't convert struct types to some other type of Seq[StructField]
       // Handle Seq[TreeNode] in TreeNode parameters.
-      case s: LazyList[_] =>
-        // LazyList is lazy so we need to force materialization
+      case s: Stream[_] =>
+        // Stream is lazy so we need to force materialization
         s.map(mapChild).force
+      case l: LazyList[_] =>
+        // LazyList is lazy so we need to force materialization
+        l.map(mapChild).force
       case s: Seq[_] =>
         s.map(mapChild)
       case m: Map[_, _] =>
@@ -801,6 +806,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
       case other => other
     }
 
+    @nowarn("cat=deprecation")
     val newArgs = mapProductIterator {
       case arg: TreeNode[_] if containsChild(arg) =>
         arg.asInstanceOf[BaseType].clone()
@@ -813,7 +819,8 @@
         case (_, other) => other
       }
       case d: DataType => d // Avoid unpacking Structs
-      case args: LazyList[_] => args.map(mapChild).force // Force materialization on stream
+      case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
+      case args: LazyList[_] => args.map(mapChild).force // Force materialization on LazyList
       case args: Iterable[_] => args.map(mapChild)
       case nonChild: AnyRef => nonChild
       case null => null
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index 31f7e07143c5..f783083d0a44 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import scala.annotation.nowarn
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -83,6 +85,26 @@ class LogicalPlanSuite extends SparkFunSuite {
   }
 
   test("transformExpressions works with a Stream") {
+    val id1 = NamedExpression.newExprId
+    val id2 = NamedExpression.newExprId
+    @nowarn("cat=deprecation")
+    val plan = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(2), "b")(exprId = id2)),
+      OneRowRelation())
+    val result = plan.transformExpressions {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    @nowarn("cat=deprecation")
+    val expected = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(3), "b")(exprId = id2)),
+      OneRowRelation())
+    assert(result.sameResult(expected))
+  }
+
+  test("SPARK-45685: transformExpressions works with a LazyList") {
     val id1 = NamedExpression.newExprId
     val id2 = NamedExpression.newExprId
     val plan = Project(LazyList(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 4dbadef93a07..21542d43eac9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.trees
 import java.math.BigInteger
 import java.util.UUID
 
+import scala.annotation.nowarn
 import scala.collection.mutable.ArrayBuffer
 
 import org.json4s.JsonAST._
@@ -693,6 +694,22 @@
   }
 
   test("transform works on stream of children") {
+    @nowarn("cat=deprecation")
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the
+    // situation in which the TreeNode.mapChildren function's change detection is not triggered. A
+    // stream's first element is typically materialized, so in order to not trip the TreeNode change
+    // detection logic, we should not change the first element in the sequence.
+    val result = before.transform {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    @nowarn("cat=deprecation")
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
+
+  test("SPARK-45685: transform works on LazyList of children") {
     val before = Coalesce(LazyList(Literal(1), Literal(2)))
     // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the
     // situation in which the TreeNode.mapChildren function's change detection is not triggered. A
@@ -707,6 +724,16 @@
   }
 
   test("withNewChildren on stream of children") {
+    @nowarn("cat=deprecation")
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    @nowarn("cat=deprecation")
+    val result = before.withNewChildren(Stream(Literal(1), Literal(3)))
+    @nowarn("cat=deprecation")
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
+
+  test("SPARK-45685: withNewChildren on LazyList of children") {
     val before = Coalesce(LazyList(Literal(1), Literal(2)))
     val result = before.withNewChildren(LazyList(Literal(1), Literal(3)))
     val expected = Coalesce(LazyList(Literal(1), Literal(3)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 382f8cf8861a..6ec0836f704c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -165,8 +165,10 @@ trait CodegenSupport extends SparkPlan {
       }
     }
 
+    @scala.annotation.nowarn("cat=deprecation")
     val inputVars = inputVarsCandidate match {
-      case stream: LazyList[ExprCode] => stream.force
+      case stream: Stream[ExprCode] => stream.force
+      case lazyList: LazyList[ExprCode] => lazyList.force
       case other => other
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 15de4c5cc5b2..1400ee25f431 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -744,6 +744,14 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   }
 
   test("SPARK-24500: create union with stream of children") {
+    @scala.annotation.nowarn("cat=deprecation")
+    val df = Union(Stream(
+      Range(1, 1, 1, 1),
+      Range(1, 2, 1, 1)))
+    df.queryExecution.executedPlan.execute()
+  }
+
+  test("SPARK-45685: create union with LazyList of children") {
     val df = Union(LazyList(
       Range(1, 1, 1, 1),
       Range(1, 2, 1, 1)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 3aaf61ffba46..4d2d46582892 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -785,6 +785,16 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
   }
 
   test("SPARK-26680: Stream in groupBy does not cause StackOverflowError") {
+    @scala.annotation.nowarn("cat=deprecation")
+    val groupByCols = Stream(col("key"))
+    val df = Seq((1, 2), (2, 3), (1, 3)).toDF("key", "value")
+      .groupBy(groupByCols: _*)
+      .max("value")
+
+    checkAnswer(df, Seq(Row(1, 3), Row(2, 3)))
+  }
+
+  test("SPARK-45685: LazyList in groupBy does not cause StackOverflowError") {
     val groupByCols = LazyList(col("key"))
     val df = Seq((1, 2), (2, 3), (1, 3)).toDF("key", "value")
       .groupBy(groupByCols: _*)
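
Why the `.force` calls above matter, in a minimal, self-contained sketch (plain Scala 2.13, no Spark classes; `ForceDemo`, `mapChild`, and `calls` are illustrative names invented for this note, not identifiers from the patch): a lazily mapped Stream or LazyList must be forced before change-detection code can observe the rewritten children, and Stream differs from LazyList in being head-strict, which is why the tests above deliberately leave the first child unchanged. The `@nowarn("cat=deprecation")` annotations in the patch play the same role as the one below, since Stream is deprecated in Scala 2.13 but Stream-typed children still need to be handled.

import scala.annotation.nowarn

object ForceDemo {
  def main(args: Array[String]): Unit = {
    var calls = 0
    // Stand-in for the per-child rewrite in TreeNode.mapChildren: leaves 1
    // alone, rewrites everything else, and counts how often it actually ran.
    def mapChild(x: Int): Int = { calls += 1; if (x == 1) x else x + 1 }

    // LazyList.map evaluates nothing up front, so code that maps the children
    // and then compares them with the originals to detect changes would
    // conclude "nothing changed" and keep the old node.
    val mapped = LazyList(1, 2).map(mapChild)
    assert(calls == 0)

    // .force materializes every element, which is what the patched
    // QueryPlan/TreeNode/CodegenSupport cases do.
    val forced = mapped.force
    assert(calls == 2)
    assert(forced.toList == List(1, 3))

    // Stream (deprecated in 2.13) is head-strict: map evaluates the first
    // element eagerly. That is why the tests above only rewrite the second
    // child when they want to exercise the lazy path.
    calls = 0
    @nowarn("cat=deprecation")
    val streamMapped = Stream(1, 2).map(mapChild)
    assert(calls == 1)
  }
}

Note also that in each patched match expression the two lazy cases must stay ahead of the generic `case seq: Iterable[_]` / `case s: Seq[_]` alternatives: `map` dispatched through those cases on a Stream or LazyList would return another unforced lazy sequence and reintroduce the bug.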