From 5b07b57972dcced634bee41e2e111f8a726d83ab Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 10 Sep 2019 00:19:36 -0700 Subject: [PATCH 1/3] Change CreateNamedStruct to use CreateNamedStructUnsafe implementation, then delete CreateNamedStructUnsafe --- .../sql/catalyst/expressions/Projection.scala | 8 +--- .../expressions/complexTypeCreator.scala | 41 +------------------ .../optimizer/NormalizeFloatingNumbers.scala | 5 +-- .../expressions/ComplexTypeSuite.scala | 1 - 4 files changed, 3 insertions(+), 52 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index eaaf94baac21..300f075d3276 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -127,12 +127,6 @@ object UnsafeProjection InterpretedUnsafeProjection.createProjection(in) } - protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = { - exprs.map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - } - /** * Returns an UnsafeProjection for given StructType. * @@ -153,7 +147,7 @@ object UnsafeProjection * Returns an UnsafeProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): UnsafeProjection = { - createObject(toUnsafeExprs(exprs)) + createObject(exprs) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 319a7fc87e59..347cea5571b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -366,51 +366,12 @@ trait CreateNamedStructLike extends Expression { // scalastyle:on line.size.limit case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike { - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val rowClass = classOf[GenericInternalRow].getName - val values = ctx.freshName("values") - val valCodes = valExprs.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - s""" - |${eval.code} - |if (${eval.isNull}) { - | $values[$i] = null; - |} else { - | $values[$i] = ${eval.value}; - |} - """.stripMargin - } - val valuesCode = ctx.splitExpressionsWithCurrentInputs( - expressions = valCodes, - funcName = "createNamedStruct", - extraArguments = "Object[]" -> values :: Nil) - - ev.copy(code = - code""" - |Object[] $values = new Object[${valExprs.size}]; - |$valuesCode - |final InternalRow ${ev.value} = new $rowClass($values); - |$values = null; - """.stripMargin, isNull = FalseLiteral) - } - - override def prettyName: String = "named_struct" -} - -/** - * Creates a struct with the given field names and values. This is a variant that returns - * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with - * this expression automatically at runtime. - * - * @param children Seq(name1, val1, name2, val2, ...) 
- */ -case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value) } - override def prettyName: String = "named_struct_unsafe" + override def prettyName: String = "named_struct" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index b036092cf1fc..ea01d9e63eef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window} @@ -114,9 +114,6 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case CreateNamedStruct(children) => CreateNamedStruct(children.map(normalize)) - case CreateNamedStructUnsafe(children) => - CreateNamedStructUnsafe(children.map(normalize)) - case CreateArray(children) => CreateArray(children.map(normalize)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 0c4438987cd2..9039cd645159 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -369,7 +369,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val b = AttributeReference("b", IntegerType)() checkMetadata(CreateStruct(Seq(a, b))) checkMetadata(CreateNamedStruct(Seq("a", a, "b", b))) - checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b))) } test("StringToMap") { From 426d856dfcb895decec794b9501feb335935aba8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 10 Sep 2019 00:24:13 -0700 Subject: [PATCH 2/3] Remove CreateNamedStructLike --- .../expressions/complexTypeCreator.scala | 33 ++++++++----------- .../sql/catalyst/optimizer/ComplexTypes.scala | 4 +-- .../sql/catalyst/optimizer/expressions.scala | 4 +-- .../scala/org/apache/spark/sql/Column.scala | 2 +- 4 files changed, 19 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 347cea5571b2..3726a333d7a7 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -295,9 +295,21 @@ object CreateStruct extends FunctionBuilder {
 }
 
 /**
- * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]].
+ * Creates a struct with the given field names and values.
+ *
+ * @param children Seq(name1, val1, name2, val2, ...)
  */
-trait CreateNamedStructLike extends Expression {
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_("a", 1, "b", 2, "c", 3);
+       {"a":1,"b":2,"c":3}
+  """)
+// scalastyle:on line.size.limit
+case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
+
   lazy val (nameExprs, valExprs) = children.grouped(2).map {
     case Seq(name, value) => (name, value)
   }.toList.unzip
@@ -348,23 +360,6 @@ trait CreateNamedStructLike extends Expression {
   override def eval(input: InternalRow): Any = {
     InternalRow(valExprs.map(_.eval(input)): _*)
   }
-}
-
-/**
- * Creates a struct with the given field names and values
- *
- * @param children Seq(name1, val1, name2, val2, ...)
- */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.",
-  examples = """
-    Examples:
-      > SELECT _FUNC_("a", 1, "b", 2, "c", 3);
-       {"a":1,"b":2,"c":3}
-  """)
-// scalastyle:on line.size.limit
-case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = GenerateUnsafeProjection.createCode(ctx, valExprs)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index db7d6d3254bd..d73ac3adac3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -37,8 +37,8 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
     case a: Aggregate => a
     case p => p.transformExpressionsUp {
       // Remove redundant field extraction.
-      case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) =>
-        createNamedStructLike.valExprs(ordinal)
+      case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
+        createNamedStruct.valExprs(ordinal)
 
       // Remove redundant array indexing.
       case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 39709529c00d..61798fba8617 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -227,8 +227,8 @@ object OptimizeIn extends Rule[LogicalPlan] {
         if (newList.length == 1
           // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
           // TODO: we exclude them in this rule.
-          && !v.isInstanceOf[CreateNamedStructLike]
-          && !newList.head.isInstanceOf[CreateNamedStructLike]) {
+          && !v.isInstanceOf[CreateNamedStruct]
+          && !newList.head.isInstanceOf[CreateNamedStruct]) {
           EqualTo(v, newList.head)
         } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
           val hSet = newList.map(e => e.eval(EmptyRow))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index b0de3c85aaef..ddf5d6720e51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -199,7 +199,7 @@ class Column(val expr: Expression) extends Logging {
       UnresolvedAlias(a, Some(Column.generateAlias))
 
     // Wait until the struct is resolved. This will generate a nicer looking alias.
-    case struct: CreateNamedStructLike => UnresolvedAlias(struct)
+    case struct: CreateNamedStruct => UnresolvedAlias(struct)
 
     case expr: Expression => Alias(expr, toPrettySQL(expr))()
   }

From d3925bc02604fdb7223091c066018b2142c38753 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Tue, 10 Sep 2019 01:17:35 -0700
Subject: [PATCH 3/3] Update ExpressionEvalHelper.checkResult to properly compare different InternalRow implementations

---
 .../spark/sql/catalyst/expressions/ExpressionEvalHelper.scala | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index bc1f31b101c6..aef9686975d7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -138,6 +138,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
     case (result: Float, expected: Float) =>
       if (expected.isNaN) result.isNaN else expected == result
     case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema)
+    case (result: InternalRow, expected: InternalRow) =>
+      // UnsafeRow and GenericInternalRow don't compare equal via ==, so
+      // compare the two rows field by field using the expression's schema.
+      val structType = exprDataType.asInstanceOf[StructType]
+      result.numFields == expected.numFields &&
+        structType.zipWithIndex.forall { case (field, i) =>
+          checkResult(result.get(i, field.dataType), expected.get(i, field.dataType),
+            field.dataType, field.nullable)
+        }
     case _ => result == expected
   }
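
Series note (not part of any patch): a minimal sketch for exercising the consolidated code path locally. After patch 1, named_struct generates code through GenerateUnsafeProjection and returns an UnsafeRow directly (the path CreateNamedStructUnsafe previously owned), while interpreted eval still builds a GenericInternalRow; the patch-3 helper change lets tests compare those two row implementations field by field. The object name and session setup below are illustrative, not part of these patches:

import org.apache.spark.sql.SparkSession

object NamedStructDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("named-struct-demo")
      .getOrCreate()

    // named_struct takes alternating field names and values, matching
    // CreateNamedStruct's children layout: Seq(name1, val1, name2, val2, ...).
    // The result must be the same whether the expression runs through the
    // interpreted path (GenericInternalRow) or codegen (UnsafeRow).
    spark.sql("""SELECT named_struct("a", 1, "b", 2, "c", 3) AS s""")
      .show(truncate = false)

    spark.stop()
  }
}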