Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,6 @@ object UnsafeProjection
InterpretedUnsafeProjection.createProjection(in)
}

/**
 * Rewrites every expression tree in `exprs`, replacing each [[CreateNamedStruct]] node
 * with a [[CreateNamedStructUnsafe]] built from the same children.
 *
 * @param exprs the expressions to rewrite
 * @return one rewritten expression per input expression, in the same order
 */
protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = {
exprs.map(_ transform {
// The only rewrite case: swap the node type, keep the children as they are.
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
}

/**
* Returns an UnsafeProjection for given StructType.
*
Expand All @@ -153,7 +147,7 @@ object UnsafeProjection
* Returns an UnsafeProjection for given sequence of bound Expressions.
*/
def create(exprs: Seq[Expression]): UnsafeProjection = {
createObject(toUnsafeExprs(exprs))
createObject(exprs)
}

/**
 * Convenience overload: wraps the single expression in a one-element sequence and
 * delegates to the `Seq[Expression]` overload of `create`.
 */
def create(expr: Expression): UnsafeProjection = create(Seq(expr))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,21 @@ object CreateStruct extends FunctionBuilder {
}

/**
* Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]].
* Creates a struct with the given field names and values
*
* @param children Seq(name1, val1, name2, val2, ...)
*/
trait CreateNamedStructLike extends Expression {
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.",
examples = """
Examples:
> SELECT _FUNC_("a", 1, "b", 2, "c", 3);
{"a":1,"b":2,"c":3}
""")
// scalastyle:on line.size.limit
case class CreateNamedStruct(children: Seq[Expression]) extends Expression {

lazy val (nameExprs, valExprs) = children.grouped(2).map {
case Seq(name, value) => (name, value)
}.toList.unzip
Expand Down Expand Up @@ -348,69 +360,13 @@ trait CreateNamedStructLike extends Expression {
/**
 * Interpreted evaluation: evaluates every value expression against `input` and packs the
 * results, in order, into a new [[InternalRow]]. The name expressions are not evaluated here.
 *
 * @param input the row each value expression is evaluated against
 * @return an [[InternalRow]] holding one evaluated value per value expression
 */
override def eval(input: InternalRow): Any = {
InternalRow(valExprs.map(_.eval(input)): _*)
}
}

/**
* Creates a struct with the given field names and values
*
* @param children Seq(name1, val1, name2, val2, ...)
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.",
examples = """
Examples:
> SELECT _FUNC_("a", 1, "b", 2, "c", 3);
{"a":1,"b":2,"c":3}
""")
// scalastyle:on line.size.limit
case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike {

/**
 * Generates Java code that evaluates every value expression into an `Object[]` and wraps
 * that array in a [[GenericInternalRow]].
 *
 * @param ctx codegen context used for fresh names, child code generation and splitting
 * @param ev  the expression code to copy; its `value` names the resulting row variable
 * @return a copy of `ev` carrying the generated code, with `isNull` fixed to `FalseLiteral`
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rowClass = classOf[GenericInternalRow].getName
// Fresh identifier for the Object[] that collects the field values in the generated code.
val values = ctx.freshName("values")
// One snippet per value expression: run the child's code, then store either null or the
// child's value in slot `i` of the array.
val valCodes = valExprs.zipWithIndex.map { case (e, i) =>
val eval = e.genCode(ctx)
s"""
|${eval.code}
|if (${eval.isNull}) {
| $values[$i] = null;
|} else {
| $values[$i] = ${eval.value};
|}
""".stripMargin
}
// The array is passed as an extra argument so the snippets can still write to it if they
// are moved into separate functions.
// NOTE(review): presumably the split keeps generated methods under JVM size limits — confirm.
val valuesCode = ctx.splitExpressionsWithCurrentInputs(
expressions = valCodes,
funcName = "createNamedStruct",
extraArguments = "Object[]" -> values :: Nil)

// Allocate the array, fill it, wrap it in the row class, then null out the array variable.
// NOTE(review): the trailing `= null` presumably drops the extra reference — confirm intent.
// `isNull = FalseLiteral`: the struct itself is never reported as null; only fields may be.
ev.copy(code =
code"""
|Object[] $values = new Object[${valExprs.size}];
|$valuesCode
|final InternalRow ${ev.value} = new $rowClass($values);
|$values = null;
""".stripMargin, isNull = FalseLiteral)
}

override def prettyName: String = "named_struct"
}

/**
* Creates a struct with the given field names and values. This is a variant that returns
* UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with
* this expression automatically at runtime.
*
* @param children Seq(name1, val1, name2, val2, ...)
*/
case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = GenerateUnsafeProjection.createCode(ctx, valExprs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any types that GenerateUnsafeProjection doesn't support? Judging from GenerateUnsafeProjection.canSupport, it looks like there are none.

ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value)
}

override def prettyName: String = "named_struct_unsafe"
override def prettyName: String = "named_struct"
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
case a: Aggregate => a
case p => p.transformExpressionsUp {
// Remove redundant field extraction.
case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) =>
createNamedStructLike.valExprs(ordinal)
case GetStructField(createNameStruct: CreateNamedStruct, ordinal, _) =>
createNameStruct.valExprs(ordinal)

// Remove redundant array indexing.
case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
Expand Down Expand Up @@ -114,9 +114,6 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case CreateNamedStruct(children) =>
CreateNamedStruct(children.map(normalize))

case CreateNamedStructUnsafe(children) =>
CreateNamedStructUnsafe(children.map(normalize))

case CreateArray(children) =>
CreateArray(children.map(normalize))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ object OptimizeIn extends Rule[LogicalPlan] {
if (newList.length == 1
// TODO: `EqualTo` for structural types is not working. Until SPARK-24443 is addressed,
// TODO: we exclude them in this rule.
&& !v.isInstanceOf[CreateNamedStructLike]
&& !newList.head.isInstanceOf[CreateNamedStructLike]) {
&& !v.isInstanceOf[CreateNamedStruct]
&& !newList.head.isInstanceOf[CreateNamedStruct]) {
EqualTo(v, newList.head)
} else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
val hSet = newList.map(e => e.eval(EmptyRow))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
val b = AttributeReference("b", IntegerType)()
checkMetadata(CreateStruct(Seq(a, b)))
checkMetadata(CreateNamedStruct(Seq("a", a, "b", b)))
checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b)))
}

test("StringToMap") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
case (result: Float, expected: Float) =>
if (expected.isNaN) result.isNaN else expected == result
case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema)
case (result: Seq[InternalRow], expected: Seq[InternalRow]) =>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was needed because we can't do direct equals() comparison between UnsafeRow and other row classes. After this PR's changes, the "SPARK-14793: split wide struct creation into blocks due to JVM code size limit" test case in CodeGenerationSuite was failing because the new code was producing UnsafeRow but the test code was comparing against GenericInternalRow. In the old code, this comparison between sequences of rows was happening in the default case _ => below, but that case doesn't work when the InternalRow implementations are mismatched.

I'm not sure whether this change-of-internal-row-format will have adverse consequences in other parts of the code.

result.size == expected.size && result.zip(expected).forall { case (r, e) =>
checkResult(r, e, exprDataType, exprNullable)
}
case _ =>
result == expected
}
Expand Down
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class Column(val expr: Expression) extends Logging {
UnresolvedAlias(a, Some(Column.generateAlias))

// Wait until the struct is resolved. This will generate a nicer looking alias.
case struct: CreateNamedStructLike => UnresolvedAlias(struct)
case struct: CreateNamedStruct => UnresolvedAlias(struct)

case expr: Expression => Alias(expr, toPrettySQL(expr))()
}
Expand Down