@@ -127,12 +127,6 @@ object UnsafeProjection
     InterpretedUnsafeProjection.createProjection(in)
   }
 
-  protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = {
-    exprs.map(_ transform {
-      case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
-    })
-  }
-
   /**
    * Returns an UnsafeProjection for given StructType.
    *
@@ -153,7 +147,7 @@ object UnsafeProjection
    * Returns an UnsafeProjection for given sequence of bound Expressions.
    */
   def create(exprs: Seq[Expression]): UnsafeProjection = {
-    createObject(toUnsafeExprs(exprs))
+    createObject(exprs)
  }
 
   def create(expr: Expression): UnsafeProjection = create(Seq(expr))
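
For orientation, a hedged sketch of the path this hunk changes (assumed usage, not part of the patch): `UnsafeProjection.create` now compiles the given expressions as-is instead of first rewriting `CreateNamedStruct` into `CreateNamedStructUnsafe`.

    // Minimal sketch, assuming bound input expressions; illustrative only.
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    val structExpr = CreateNamedStruct(Seq(
      Literal("a"), BoundReference(0, IntegerType, nullable = false)))
    val proj = UnsafeProjection.create(Seq(structExpr))
    // The projection now evaluates CreateNamedStruct directly; no unsafe
    // variant is substituted behind the scenes.
    val row = proj(InternalRow(1))  // an UnsafeRow containing struct<a:int>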

@@ -295,9 +295,20 @@ object CreateStruct extends FunctionBuilder {
 }
 
 /**
- * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]].
+ * Creates a struct with the given field names and values
+ *
+ * @param children Seq(name1, val1, name2, val2, ...)
  */
-trait CreateNamedStructLike extends Expression {
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_("a", 1, "b", 2, "c", 3);
+       {"a":1,"b":2,"c":3}
+  """)
+// scalastyle:on line.size.limit
+case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
   lazy val (nameExprs, valExprs) = children.grouped(2).map {
     case Seq(name, value) => (name, value)
   }.toList.unzip
@@ -348,23 +359,6 @@ trait CreateNamedStructLike extends Expression {
   override def eval(input: InternalRow): Any = {
     InternalRow(valExprs.map(_.eval(input)): _*)
   }
-}
-
-/**
- * Creates a struct with the given field names and values
- *
- * @param children Seq(name1, val1, name2, val2, ...)
- */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.",
-  examples = """
-    Examples:
-      > SELECT _FUNC_("a", 1, "b", 2, "c", 3);
-       {"a":1,"b":2,"c":3}
-  """)
-// scalastyle:on line.size.limit
-case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rowClass = classOf[GenericInternalRow].getName
@@ -397,22 +391,6 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike {
   override def prettyName: String = "named_struct"
 }
 
-/**
- * Creates a struct with the given field names and values. This is a variant that returns
- * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with
- * this expression automatically at runtime.
- *
- * @param children Seq(name1, val1, name2, val2, ...)
- */
-case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike {
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val eval = GenerateUnsafeProjection.createCode(ctx, valExprs)
-    ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value)
-  }
-
-  override def prettyName: String = "named_struct_unsafe"
-}
-
 /**
  * Creates a map after splitting the input text into key/value pairs using delimiters
  */
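
Since `CreateNamedStruct` is now the only implementation, a minimal interpreted-path illustration (assumed REPL-style usage; the children here are literals, so any input row works):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Literal}

    val s = CreateNamedStruct(Seq(Literal("a"), Literal(1), Literal("b"), Literal(2)))
    s.dataType                 // struct<a:int,b:int>; names come from children at even indices
    s.eval(InternalRow.empty)  // a GenericInternalRow holding (1, 2)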

@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
- * Simplify redundant [[CreateNamedStructLike]], [[CreateArray]] and [[CreateMap]] expressions.
+ * Simplify redundant [[CreateNamedStruct]], [[CreateArray]] and [[CreateMap]] expressions.
  */
 object SimplifyExtractValueOps extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -37,8 +37,8 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
     case a: Aggregate => a
     case p => p.transformExpressionsUp {
       // Remove redundant field extraction.
-      case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) =>
-        createNamedStructLike.valExprs(ordinal)
+      case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
+        createNamedStruct.valExprs(ordinal)
 
       // Remove redundant array indexing.
       case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) =>
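
As a hedged illustration of the simplification (not from the patch): extracting a field from an inline-created struct collapses to that field's value expression, so no struct is ever materialized.

    import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Literal}

    val s = CreateNamedStruct(Seq(Literal("a"), Literal(1), Literal("b"), Literal(2)))
    // GetStructField(s, 1) is the expression form of s.b; the rule above
    // rewrites it to s.valExprs(1), i.e. Literal(2).
    val simplified = s.valExprs(1)  // Literal(2)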

@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
@@ -114,9 +114,6 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     case CreateNamedStruct(children) =>
       CreateNamedStruct(children.map(normalize))
 
-    case CreateNamedStructUnsafe(children) =>
-      CreateNamedStructUnsafe(children.map(normalize))
-
     case CreateArray(children) =>
       CreateArray(children.map(normalize))
 
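
For context on why struct children need normalizing at all, a hedged sketch (assumes a SparkSession with `spark.implicits._` in scope; illustrative only): without normalization, `-0.0` and `0.0` would hash to different grouping keys.

    import org.apache.spark.sql.functions.struct

    // -0.0 and 0.0 must land in the same group, even when nested in a struct key.
    val df = Seq((-0.0d, 1), (0.0d, 2)).toDF("k", "v")
    df.groupBy(struct($"k")).count().show()  // expected: one group with count 2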

@@ -227,8 +227,8 @@ object OptimizeIn extends Rule[LogicalPlan] {
       if (newList.length == 1
         // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
         // TODO: we exclude them in this rule.
-        && !v.isInstanceOf[CreateNamedStructLike]
-        && !newList.head.isInstanceOf[CreateNamedStructLike]) {
+        && !v.isInstanceOf[CreateNamedStruct]
+        && !newList.head.isInstanceOf[CreateNamedStruct]) {
         EqualTo(v, newList.head)
       } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
         val hSet = newList.map(e => e.eval(EmptyRow))
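
A hedged illustration of the guard this hunk retargets (not from the patch): with a one-element IN list, `OptimizeIn` normally rewrites to `EqualTo`, but struct-valued operands stay excluded until SPARK-24443 is resolved.

    import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, In, Literal}

    val v = CreateNamedStruct(Seq(Literal("a"), Literal(1)))
    val in = In(v, Seq(v))
    // The guard leaves this In untouched; for a non-struct value such as
    // In(Literal(1), Seq(Literal(1))), the rule would produce an EqualTo instead.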

@@ -369,7 +369,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     val b = AttributeReference("b", IntegerType)()
     checkMetadata(CreateStruct(Seq(a, b)))
     checkMetadata(CreateNamedStruct(Seq("a", a, "b", b)))
-    checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b)))
   }
 
   test("StringToMap") {

sql/core/src/main/scala/org/apache/spark/sql/Column.scala (1 addition & 1 deletion)

@@ -200,7 +200,7 @@ class Column(val expr: Expression) extends Logging {
       UnresolvedAlias(a, Some(Column.generateAlias))
 
     // Wait until the struct is resolved. This will generate a nicer looking alias.
-    case struct: CreateNamedStructLike => UnresolvedAlias(struct)
+    case struct: CreateNamedStruct => UnresolvedAlias(struct)
 
     case expr: Expression => Alias(expr, toPrettySQL(expr))()
   }

@@ -18,8 +18,12 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.DefinedByConstructorParams
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.ArrayType
 
 /**
  * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map).
@@ -64,6 +68,24 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession {
     val ds100_5 = Seq(S100_5()).toDS()
     ds100_5.rdd.count
   }
+
+  test("SPARK-29503 nest unsafe struct inside safe array") {
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+      val df = spark.sparkContext.parallelize(Seq(Seq(1, 2, 3))).toDF("items")
+
+      // items: Seq[Int] => items.map { item => Seq(Struct(item)) }
+      val result = df.select(
+        new Column(MapObjects(
+          (item: Expression) => array(struct(new Column(item))).expr,
+          $"items".expr,
+          df.schema("items").dataType.asInstanceOf[ArrayType].elementType
+        )) as "items"
+      ).collect()
+
+      assert(result.size === 1)
+      assert(result === Row(Seq(Seq(Row(1)), Seq(Row(2)), Seq(Row(3)))) :: Nil)
+    }
+  }
 }
 
 class S100(

Review thread on the `withSQLConf(...)` line:

Member: Is this issue encountered only when whole-stage codegen is disabled?

Author (@HeartSaVioR): At least yes for the provided reproducer; for other cases I'm not sure.

Member: I see. With whole-stage codegen, CreateNamedStruct is not converted to CreateNamedStructUnsafe, so the nested struct is not an unsafe one.

Author (@HeartSaVioR): Ah, OK, got it. Thanks for the explanation.

Review thread on the `MapObjects(...)` line:

Member: Hm, while the fix seems fine to me too, is this the only reproducer? MapObjects is meant for internal use - it's under the catalyst package.

Author (@HeartSaVioR, Oct 29, 2019): I haven't spent more time looking for another one, as this seems to be a clean and simple reproducer, and I'm not sure it stops being a valid reproducer just because it pulls in the catalyst package; Catalyst could analyze a query and inject MapObjects wherever necessary.

I noticed you'd like to revisit #25745 - that was WIP and didn't include any performance numbers. I'd rather choose "safety" over "speed", especially since we haven't established that there is any significant difference between the two. This was the only case where MapObjects could hold an unsafe struct; by allowing it, safe and unsafe rows can get mixed up, leading to corner cases like this one.

@HyukjinKwon (Member, Oct 29, 2019): Yeah, I am not against this change; I think the fix is fine, but I wanted to know whether it actually affects any user-facing surface.

I was also wondering whether we could benefit from #25745, since some investigation toward using the unsafe variant everywhere was already done there.

Author (@HeartSaVioR): My guess is that #25745 was based on the assumption that it's safe to replace CreateNamedStruct with CreateNamedStructUnsafe - we already had one path doing exactly that - and this observation broke the assumption. IMHO, once we have found that it's not safe, the improvement has to prove safety before we take its benefits into account.
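
On the user-facing question above, one plausible (hypothetical, not verified in this PR) path that plans MapObjects over structs nested in arrays without touching catalyst directly is a typed Dataset transformation; the `Item` class and the session imports are assumptions for illustration.

    // Assumes a SparkSession `spark` and `import spark.implicits._`;
    // `Item` must be a top-level case class so an encoder can be derived.
    case class Item(i: Int)

    val ds = Seq(Seq(1, 2, 3)).toDS()          // Dataset[Seq[Int]]
    val mapped = ds.map(xs => xs.map(Item(_))) // Dataset[Seq[Item]]: struct inside array
    mapped.collect()                           // serializer plans MapObjects + CreateNamedStruct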